Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support regex pattern in SFTPHOOK #15409

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
55 changes: 47 additions & 8 deletions airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# under the License.
"""This module contains SFTP hook."""
import datetime
import os
import re
import stat
from typing import Dict, List, Optional, Tuple

Expand Down Expand Up @@ -166,16 +168,21 @@ def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]:
}
return files

def list_directory(self, path: str) -> List[str]:
def list_directory(self, path: str, regex_pattern: Optional[str] = None) -> List[str]:
    """
    Returns a list of files on the remote system.

    :param path: full path to the remote directory to list
    :type path: str
    :param regex_pattern: optional regular expression; when given, only the
        entries whose name matches it from the start (``re.match`` semantics)
        are returned
    :type regex_pattern: Optional[str]
    :return: entry names of ``path`` as reported by the connection's ``listdir``
    :rtype: List[str]
    """
    conn = self.get_conn()
    # List the directory once and filter afterwards, so the unfiltered and the
    # filtered result always come from the same listing.
    entries = conn.listdir(path)
    if not regex_pattern:
        return entries

    pattern = re.compile(regex_pattern)
    return [entry for entry in entries if pattern.match(entry)]

def create_directory(self, path: str, mode: int = 777) -> None:
"""
Expand All @@ -198,7 +205,9 @@ def delete_directory(self, path: str) -> None:
conn = self.get_conn()
conn.rmdir(path)

def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
def retrieve_file(
self, remote_full_path: str, local_full_path: str, regex_pattern: Optional[str] = None
) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I really dislike about this implementation is that it subtly changes semantics depending on the arguments. In this function, if regex_pattern is None, remote_full_path should point to a file, but when pattern matching is performed, it should point to a directory instead (with the file names matched against the pattern). This will be very unintuitive in practice and will likely cause headaches for both users and maintainers.

Pattern matching is definitely a useful feature here, but I really don’t think this is how it should be done.

"""
Transfers the remote file to a local location.
If local_full_path is a string path, the file will be put
Expand All @@ -208,11 +217,21 @@ def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
:type remote_full_path: str
:param local_full_path: full path to the local file
:type local_full_path: str
:param regex_pattern: optional pattern to filter the remote_full_path files
:type: regex_pattern: Optional[str]
"""
conn = self.get_conn()
self.log.info('Retrieving file from FTP: %s', remote_full_path)
conn.get(remote_full_path, local_full_path)
self.log.info('Finished retrieving file from FTP: %s', remote_full_path)
if regex_pattern:
pattern = re.compile(regex_pattern)
for file in conn.listdir(remote_full_path):
if pattern.match(file):
conn.get(os.path.join(remote_full_path, file), os.path.join(local_full_path, file))
self.log.info('Finished retrieving file from FTP: %s', file)

else:
conn.get(remote_full_path, local_full_path)
self.log.info('Finished retrieving file from FTP: %s', remote_full_path)

def store_file(self, remote_full_path: str, local_full_path: str) -> None:
"""
Expand All @@ -238,14 +257,23 @@ def delete_file(self, path: str) -> None:
conn = self.get_conn()
conn.remove(path)

def get_mod_time(self, path: str) -> str:
def get_mod_time(self, path: str, regex_pattern: Optional[str] = None) -> str:
    """
    Returns modification time.

    :param path: full path to the remote file or, when ``regex_pattern`` is
        given, to the remote directory that is searched
    :type path: str
    :param regex_pattern: optional regular expression; when given, ``path`` is
        listed and the first entry whose name matches from the start
        (``re.match`` semantics) is the file whose time is returned
    :type regex_pattern: Optional[str]
    :return: modification time formatted as ``%Y%m%d%H%M%S``
    :rtype: str
    :raises OSError: with ``errno.ENOENT`` when ``regex_pattern`` is given and
        no entry of ``path`` matches it
    """
    # Local imports: only this method needs them.
    import errno
    import posixpath

    conn = self.get_conn()
    if regex_pattern:
        pattern = re.compile(regex_pattern)
        matched = next((name for name in conn.listdir(path) if pattern.match(name)), None)
        if matched is None:
            # Without this, the directory itself would be stat-ed and its
            # mtime reported as if a matching file had been found.
            # NOTE(review): errno ENOENT is presumably what callers compare
            # against for "no such file" -- confirm against the sensor.
            raise OSError(errno.ENOENT, 'No file matching the pattern found', path)
        # listdir yields bare names; stat needs the name joined to the
        # directory. Remote SFTP paths use "/" regardless of the local OS.
        path = posixpath.join(path, matched)

    ftp_mdtm = conn.stat(path).st_mtime
    return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S')

Expand Down Expand Up @@ -279,7 +307,11 @@ def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[
return True

def get_tree_map(
self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None
self,
path: str,
prefix: Optional[str] = None,
delimiter: Optional[str] = None,
regex_pattern: Optional[str] = None,
) -> Tuple[List[str], List[str], List[str]]:
"""
Return tuple with recursive lists of files, directories and unknown paths from given path.
Expand All @@ -293,6 +325,8 @@ def get_tree_map(
:type delimiter: str
:return: tuple with list of files, dirs and unknown items
:rtype: Tuple[List[str], List[str], List[str]]
:param regex_pattern: optional pattern to filter the remote_full_path files
:type: regex_pattern: Optional[str]
"""
conn = self.get_conn()
files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str]
Expand All @@ -308,4 +342,9 @@ def append_matching_path_callback(list_):
recurse=True,
)

if regex_pattern:
pattern = re.compile(regex_pattern)
files = [file for file in files if pattern.match(file)]
unknowns = [file for file in unknowns if pattern.match(file)]
dirs = [file for file in dirs if pattern.match(file)]
return files, dirs, unknowns
11 changes: 8 additions & 3 deletions airflow/providers/sftp/sensors/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,27 @@ class SFTPSensor(BaseSensorOperator):
:type path: str
:param sftp_conn_id: The connection to run the sensor against
:type sftp_conn_id: str
:param regex_pattern: optional pattern to filter the path files
:type: regex_pattern: Optional[str]
"""

template_fields = ('path',)
template_fields = ('path', 'regex_pattern')

@apply_defaults
def __init__(self, *, path: str, sftp_conn_id: str = 'sftp_default', **kwargs) -> None:
def __init__(
    self,
    *,
    path: str,
    regex_pattern: Optional[str] = None,
    sftp_conn_id: str = 'sftp_default',
    **kwargs,
) -> None:
    """
    Store the sensor configuration; no hook is created here.

    :param path: remote path handed to the hook when poking
    :type path: str
    :param regex_pattern: optional pattern to filter the path files
    :type regex_pattern: Optional[str]
    :param sftp_conn_id: The connection to run the sensor against
    :type sftp_conn_id: str
    """
    super().__init__(**kwargs)
    # Connection settings first, then what to look for.
    self.sftp_conn_id = sftp_conn_id
    self.hook: Optional[SFTPHook] = None
    self.path = path
    self.regex_pattern = regex_pattern

def poke(self, context: dict) -> bool:
self.hook = SFTPHook(self.sftp_conn_id)
self.log.info('Poking for %s', self.path)
try:
mod_time = self.hook.get_mod_time(self.path)
mod_time = self.hook.get_mod_time(self.path, self.regex_pattern)
self.log.info('Found File %s last modified: %s', str(self.path), str(mod_time))
except OSError as e:
if e.errno != SFTP_NO_SUCH_FILE:
Expand Down
32 changes: 28 additions & 4 deletions tests/providers/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def generate_host_key(pkey: paramiko.PKey):
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
SUB_DIR = "sub_dir"
TMP_FILE_FOR_TESTS = 'test_file.txt'
CSV_TMP_FILE_FOR_TESTS = 'test_file.csv'

SFTP_CONNECTION_USER = "root"

Expand All @@ -68,6 +69,8 @@ def setUp(self):
file.write('Test file')
with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
file.write('Test file')
with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, CSV_TMP_FILE_FOR_TESTS), 'a') as file:
file.write('Test file')

def test_get_conn(self):
output = self.hook.get_conn()
Expand All @@ -85,7 +88,13 @@ def test_describe_directory(self):

def test_list_directory(self):
output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
assert output == [SUB_DIR]
assert output == [SUB_DIR, CSV_TMP_FILE_FOR_TESTS]

def test_list_directory_with_pattern(self):
    """Only the entries whose name matches ``regex_pattern`` are listed."""
    # Raw string with an escaped dot: ".*.csv" would also accept names such
    # as "notcsv", where any character stands in for the dot.
    output = self.hook.list_directory(
        path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), regex_pattern=r".*\.csv"
    )
    assert output == [CSV_TMP_FILE_FOR_TESTS]

def test_create_and_delete_directory(self):
new_dir_name = 'new_dir'
Expand Down Expand Up @@ -117,7 +126,7 @@ def test_store_retrieve_and_delete_file(self):
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS),
)
output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
assert output == [SUB_DIR, TMP_FILE_FOR_TESTS]
assert output == [SUB_DIR, CSV_TMP_FILE_FOR_TESTS, TMP_FILE_FOR_TESTS]
retrieved_file_name = 'retrieved.txt'
self.hook.retrieve_file(
remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
Expand All @@ -127,7 +136,15 @@ def test_store_retrieve_and_delete_file(self):
os.remove(os.path.join(TMP_PATH, retrieved_file_name))
self.hook.delete_file(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
assert output == [SUB_DIR]
assert output == [SUB_DIR, CSV_TMP_FILE_FOR_TESTS]

def test_retrieve_with_pattern(self):
    """Files of a remote directory that match the pattern are downloaded."""
    retrieved_path = os.path.join(TMP_PATH, CSV_TMP_FILE_FOR_TESTS)
    try:
        self.hook.retrieve_file(
            remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS),
            local_full_path=os.path.join(TMP_PATH),
            # Escaped dot so that only a literal ".csv" is accepted.
            regex_pattern=r".*\.csv",
        )
        assert CSV_TMP_FILE_FOR_TESTS in os.listdir(TMP_PATH)
    finally:
        # Remove the downloaded copy so it cannot leak into other tests or
        # make a later run of this test pass vacuously.
        if os.path.exists(retrieved_path):
            os.remove(retrieved_path)

def test_get_mod_time(self):
self.hook.store_file(
Expand Down Expand Up @@ -252,10 +269,17 @@ def test_get_tree_map(self):
tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
files, dirs, unknowns = tree_map

assert files == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)]
assert files == [
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS),
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, CSV_TMP_FILE_FOR_TESTS),
]
assert dirs == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)]
assert unknowns == []

def test_get_tree_map_with_pattern(self):
    """Only the file paths that match ``regex_pattern`` are returned."""
    # Walk only the directory created for these tests rather than all of
    # TMP_PATH, whose other contents are outside the test's control.
    files, _, _ = self.hook.get_tree_map(
        path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), regex_pattern=r".*\.csv"
    )
    # Exact equality also proves that the non-matching .txt file was filtered out.
    assert files == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, CSV_TMP_FILE_FOR_TESTS)]

def tearDown(self):
shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
Expand Down
17 changes: 14 additions & 3 deletions tests/providers/sftp/sensors/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,16 @@ def test_file_present(self, sftp_hook_mock):
sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/1970-01-01.txt')
context = {'ds': '1970-01-01'}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt')
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt', None)
assert output

@patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
def test_file_present_with_pattern(self, sftp_hook_mock):
    """The sensor forwards ``regex_pattern`` to the hook unchanged."""
    # Single definition keeps the constructor argument and the expected call
    # in sync; the dot is escaped so the pattern means a literal ".txt".
    pattern = r".*\.txt"
    sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
    sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', regex_pattern=pattern)
    context = {'ds': '1970-01-01'}
    output = sftp_sensor.poke(context)
    sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/', pattern)
    assert output

@patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
Expand All @@ -41,7 +50,7 @@ def test_file_absent(self, sftp_hook_mock):
sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/1970-01-01.txt')
context = {'ds': '1970-01-01'}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt')
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt', None)
assert not output

@patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
Expand All @@ -51,7 +60,9 @@ def test_sftp_failure(self, sftp_hook_mock):
context = {'ds': '1970-01-01'}
with pytest.raises(OSError):
sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt')
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
'/path/to/file/1970-01-01.txt', None
)

def test_hook_not_created_during_init(self):
sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/1970-01-01.txt')
Expand Down