diff --git a/example_dags/example_amazon_s3_snowflake_transform.py b/example_dags/example_amazon_s3_snowflake_transform.py
index 8c57b5ff19..e4c8c6a03f 100644
--- a/example_dags/example_amazon_s3_snowflake_transform.py
+++ b/example_dags/example_amazon_s3_snowflake_transform.py
@@ -21,17 +21,17 @@ def combine_data(center_1: Table, center_2: Table):
 @aql.transform()
 def clean_data(input_table: Table):
     return """SELECT *
-    FROM {{input_table}} WHERE TYPE NOT LIKE 'Guinea Pig'
+    FROM {{input_table}} WHERE type NOT LIKE 'Guinea Pig'
     """
 
 
-@aql.dataframe(identifiers_as_lower=False)
+@aql.dataframe()
 def aggregate_data(df: pd.DataFrame):
-    adoption_reporting_dataframe = df.pivot_table(
+    new_df = df.pivot_table(
         index="date", values="name", columns=["type"], aggfunc="count"
     ).reset_index()
-
-    return adoption_reporting_dataframe
+    new_df.columns = new_df.columns.str.lower()
+    return new_df
 
 
 @dag(
@@ -67,11 +67,11 @@ def example_amazon_s3_snowflake_transform():
     )
 
     temp_table_1 = aql.load_file(
-        input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_1.csv"),
+        input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_1_unquoted.csv"),
         output_table=input_table_1,
     )
     temp_table_2 = aql.load_file(
-        input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_2.csv"),
+        input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_2_unquoted.csv"),
         output_table=input_table_2,
     )
 
@@ -85,7 +85,10 @@ def example_amazon_s3_snowflake_transform():
         cleaned_data,
         output_table=Table(
             name="aggregated_adoptions_" + str(int(time.time())),
-            metadata=Metadata(schema=os.environ["SNOWFLAKE_SCHEMA"]),
+            metadata=Metadata(
+                schema=os.environ["SNOWFLAKE_SCHEMA"],
+                database=os.environ["SNOWFLAKE_DATABASE"],
+            ),
             conn_id="snowflake_conn",
         ),
     )
diff --git a/src/astro/databases/snowflake.py b/src/astro/databases/snowflake.py
index d6c2fe0e8d..a59df7fd7d 100644
--- a/src/astro/databases/snowflake.py
+++ b/src/astro/databases/snowflake.py
@@ -1,5 +1,6 @@
 """Snowflake database implementation."""
 import logging
+import os
 import random
 import string
 from dataclasses import dataclass, field
@@ -7,10 +8,10 @@
 import pandas as pd
 from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
-from pandas.io.sql import SQLDatabase
 from snowflake.connector import pandas_tools
 from snowflake.connector.errors import ProgrammingError
 
+from astro import settings
 from astro.constants import (
     DEFAULT_CHUNK_SIZE,
     FileLocation,
@@ -36,6 +37,14 @@
     FileType.PARQUET: "MATCH_BY_COLUMN_NAME=CASE_INSENSITIVE",
 }
 
+DEFAULT_STORAGE_INTEGRATION = {
+    FileLocation.S3: settings.SNOWFLAKE_STORAGE_INTEGRATION_AMAZON,
+    FileLocation.GS: settings.SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE,
+}
+
+NATIVE_LOAD_SUPPORTED_FILE_TYPES = (FileType.CSV, FileType.NDJSON, FileType.PARQUET)
+NATIVE_LOAD_SUPPORTED_FILE_LOCATIONS = (FileLocation.GS, FileLocation.S3)
+
 
 @dataclass
 class SnowflakeStage:
@@ -138,7 +147,6 @@ class SnowflakeDatabase(BaseDatabase):
     """
 
     def __init__(self, conn_id: str = DEFAULT_CONN_ID):
-        self.storage_integration: Optional[str] = None
         super().__init__(conn_id)
 
     @property
@@ -192,7 +200,9 @@ def _create_stage_auth_sub_statement(
         :param storage_integration: Previously created Snowflake storage integration
         :return: String containing line to be used for authentication on the remote storage
         """
-
+        storage_integration = storage_integration or DEFAULT_STORAGE_INTEGRATION.get(
+            file.location.location_type
+        )
         if storage_integration is not None:
             auth = f"storage_integration = {storage_integration};"
         else:
@@ -291,6 +301,93 @@ def drop_stage(self, stage: SnowflakeStage) -> None:
     # ---------------------------------------------------------
     # Table load methods
     # ---------------------------------------------------------
+    def create_table_using_schema_autodetection(
+        self,
+        table: Table,
+        file: Optional[File] = None,
+        dataframe: Optional[pd.DataFrame] = None,
+    ) -> None:
+        """
+        Create a SQL table, automatically inferring the schema using the given file.
+
+        :param table: The table to be created.
+        :param file: File used to infer the new table columns.
+        :param dataframe: Dataframe used to infer the new table columns if there is no file.
+        """
+        if file:
+            dataframe = file.export_to_dataframe(
+                nrows=settings.LOAD_TABLE_AUTODETECT_ROWS_COUNT
+            )
+
+        # Snowflake does not handle mixed-capitalisation column names well;
+        # this is handled more gracefully in a separate PR.
+        if dataframe is not None:
+            dataframe.columns = dataframe.columns.str.upper()
+
+        super().create_table_using_schema_autodetection(table, dataframe=dataframe)
+
+    def is_native_load_file_available(
+        self, source_file: File, target_table: Table
+    ) -> bool:
+        """
+        Check if there is an optimised path for source to destination.
+
+        :param source_file: File from which we need to transfer data
+        :param target_table: Table that needs to be populated with file data
+        """
+        is_file_type_supported = (
+            source_file.type.name in NATIVE_LOAD_SUPPORTED_FILE_TYPES
+        )
+        is_file_location_supported = (
+            source_file.location.location_type in NATIVE_LOAD_SUPPORTED_FILE_LOCATIONS
+        )
+        return is_file_type_supported and is_file_location_supported
+
+    def load_file_to_table_natively(
+        self,
+        source_file: File,
+        target_table: Table,
+        if_exists: LoadExistStrategy = "replace",
+        native_support_kwargs: Optional[Dict] = None,
+        **kwargs,
+    ) -> None:
+        """
+        Load the content of a file to an existing Snowflake table natively by:
+        - Creating a Snowflake external stage
+        - Using the Snowflake COPY INTO statement
+
+        Requirements:
+        - The user must have permissions to create a STAGE in Snowflake.
+        - If loading from GCP Cloud Storage, `native_support_kwargs` must define `storage_integration`.
+        - If loading from AWS S3, the credentials for creating the stage may be
+          retrieved from the Airflow connection or from the `storage_integration`
+          attribute within `native_support_kwargs`.
+
+        :param source_file: File from which we need to transfer data
+        :param target_table: Table to which the content of the file will be loaded
+        :param if_exists: Strategy used to load (currently supported: "append" or "replace")
+        :param native_support_kwargs: may be used for the stage creation, as described above.
+
+        .. seealso::
+            `Snowflake official documentation on COPY INTO
+            <https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html>`_
+            `Snowflake official documentation on CREATE STAGE
+            <https://docs.snowflake.com/en/sql-reference/sql/create-stage.html>`_
+
+        """
+        native_support_kwargs = native_support_kwargs or {}
+        storage_integration = native_support_kwargs.get("storage_integration")
+        stage = self.create_stage(
+            file=source_file, storage_integration=storage_integration
+        )
+        table_name = self.get_table_qualified_name(target_table)
+        file_path = os.path.basename(source_file.path) or ""
+        sql_statement = (
+            f"COPY INTO {table_name} FROM @{stage.qualified_name}/{file_path}"
+        )
+        self.hook.run(sql_statement)
+        self.drop_stage(stage)
+
     def load_pandas_dataframe_to_table(
         self,
         source_dataframe: pd.DataFrame,
@@ -307,27 +404,15 @@ def load_pandas_dataframe_to_table(
         :param if_exists: Strategy to be used in case the target table already exists.
         :param chunk_size: Specify the number of rows in each batch to be written at a time.
""" - db = SQLDatabase(engine=self.sqlalchemy_engine) - # Make columns uppercase to prevent weird errors in snowflake - source_dataframe.columns = source_dataframe.columns.str.upper() - schema = None - if target_table.metadata: - schema = getattr(target_table.metadata, "schema", None) - - # within prep_table() we use pandas drop() function which is used when we pass 'if_exists=replace'. - # There is an issue where has_table() works with uppercase table names but the function meta.reflect() don't. - # To prevent the issue we are passing table name in lowercase. - db.prep_table( - source_dataframe, - target_table.name.lower(), - schema=schema, - if_exists=if_exists, - index=False, - ) + self.create_table(target_table, dataframe=source_dataframe) + + self.table_exists(target_table) pandas_tools.write_pandas( - self.hook.get_conn(), - source_dataframe, - target_table.name, + conn=self.hook.get_conn(), + df=source_dataframe, + table_name=target_table.name, + schema=target_table.metadata.schema, + database=target_table.metadata.database, chunk_size=chunk_size, quote_identifiers=False, ) diff --git a/tests/benchmark/Dockerfile b/tests/benchmark/Dockerfile index ece9eb23d4..1d2d241bd9 100644 --- a/tests/benchmark/Dockerfile +++ b/tests/benchmark/Dockerfile @@ -11,6 +11,7 @@ ENV AIRFLOW_HOME=/opt/app/ ENV PYTHONPATH=/opt/app/ ENV ASTRO_PUBLISH_BENCHMARK_DATA=True ENV GCP_BUCKET=dag-authoring +ENV AIRFLOW__ASTRO_SDK__SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE=gcs_int_python_sdk # Debian Bullseye is shipped with Python 3.9 # Upgrade built-in pip diff --git a/tests/benchmark/Makefile b/tests/benchmark/Makefile index 6b736d0ba2..e56b3dfdf7 100644 --- a/tests/benchmark/Makefile +++ b/tests/benchmark/Makefile @@ -19,6 +19,10 @@ clean: @rm -f unittests.cfg @rm -f unittests.db @rm -f webserver_config.py + @rm -f ../../unittests.cfg + @rm -f ../../unittests.db + @rm -f ../../airflow.cfg + @rm -f ../../airflow.db # Takes approximately 7min setup_gke: @@ -45,8 +49,7 @@ local: check_google_credentials benchmark @rm -rf astro-sdk - -run_job: check_google_credentials +run_job: @gcloud container clusters get-credentials astro-sdk --zone us-central1-a --project ${GCP_PROJECT} @kubectl apply -f infrastructure/kubernetes/namespace.yaml @kubectl apply -f infrastructure/kubernetes/postgres.yaml diff --git a/tests/benchmark/config.json b/tests/benchmark/config.json index 14c598462d..0e0d9eba06 100644 --- a/tests/benchmark/config.json +++ b/tests/benchmark/config.json @@ -1,5 +1,11 @@ { "databases": [ + { + "name": "snowflake", + "params": { + "conn_id": "snowflake_conn" + } + }, { "name": "postgres", "params": { @@ -17,12 +23,6 @@ "database": "bigquery" } } - }, - { - "name": "snowflake", - "params": { - "conn_id": "snowflake_conn" - } } ], "datasets": [ diff --git a/tests/benchmark/debug.yaml b/tests/benchmark/debug.yaml new file mode 100644 index 0000000000..0a80758833 --- /dev/null +++ b/tests/benchmark/debug.yaml @@ -0,0 +1,12 @@ +apiVersion: v1 +kind: Pod +metadata: + name: troubleshoot + namespace: benchmark +spec: + containers: + - name: troubleshoot-benchmark + image: gcr.io/astronomer-dag-authoring/benchmark + # Just spin & wait forever + command: [ "/bin/bash", "-c", "--" ] + args: [ "while true; do sleep 30; done;" ] diff --git a/tests/benchmark/results.md b/tests/benchmark/results.md index 1b2dddc3c7..001fa926ca 100644 --- a/tests/benchmark/results.md +++ b/tests/benchmark/results.md @@ -85,7 +85,6 @@ The benchmark was run as a Kubernetes job in GKE: * Container resource limit: * Memory: 10 Gi - | database | 
 | database   | dataset    | total_time   | memory_rss   | cpu_time_user   | cpu_time_system   |
 |:-----------|:-----------|:-------------|:-------------|:----------------|:------------------|
 | snowflake  | ten_kb     | 4.75s        | 59.3MB       | 1.45s           | 100.0ms           |
@@ -95,6 +94,25 @@
 | snowflake  | five_gb    | 24.46min     | 97.85MB      | 1.43min         | 5.94s             |
 | snowflake  | ten_gb     | 50.85min     | 104.53MB     | 2.7min          | 12.11s            |
 
+### With native support
+
+The benchmark was run as a Kubernetes job in GKE:
+
+* Version: `astro-sdk-python` 1.0.0a1 (`bc58830`)
+* Machine type: `n2-standard-4`
+  * vCPU: 4
+  * Memory: 16 GB RAM
+* Container resource limit:
+  * Memory: 10 Gi
+
+| database   | dataset    | total_time   | memory_rss   | cpu_time_user   | cpu_time_system   |
+|:-----------|:-----------|:-------------|:-------------|:----------------|:------------------|
+| snowflake  | ten_kb     | 9.1s         | 56.45MB      | 2.56s           | 110.0ms           |
+| snowflake  | hundred_kb | 9.19s        | 45.4MB       | 2.55s           | 120.0ms           |
+| snowflake  | ten_mb     | 10.9s        | 47.51MB      | 2.58s           | 160.0ms           |
+| snowflake  | one_gb     | 1.07min      | 47.94MB      | 8.7s            | 5.67s             |
+| snowflake  | five_gb    | 5.49min      | 53.69MB      | 18.76s          | 1.6s              |
+
 ### Database: postgres
 
 | database   | dataset    | total_time   | memory_rss   | cpu_time_user   | cpu_time_system   |
diff --git a/tests/data/README.md b/tests/data/README.md
deleted file mode 100644
index c67bf1939a..0000000000
--- a/tests/data/README.md
+++ /dev/null
@@ -1,6 +0,0 @@
-# Test data
-
-Most of the data in this directory was created for the purposes of this project.
-
-The following exceptions apply:
-* imdb.csv: copied from [saipranava/IMDB](https://github.com/saipranava/IMDB/blob/master/IMDB.csv)
diff --git a/tests/databases/test_snowflake.py b/tests/databases/test_snowflake.py
index 2035371b8f..44d85375a1 100644
--- a/tests/databases/test_snowflake.py
+++ b/tests/databases/test_snowflake.py
@@ -451,3 +451,40 @@ def test_create_stage_amazon_fails_due_to_no_credentials(get_credentials):
         "In order to create an stage for S3, one of the following is required"
     )
     assert exc_info.match(expected_msg)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "database_table_fixture",
+    [
+        {"database": Database.SNOWFLAKE},
+    ],
+    indirect=True,
+    ids=["snowflake"],
+)
+@pytest.mark.parametrize(
+    "remote_files_fixture",
+    [
+        {"provider": "amazon", "filetype": FileType.CSV},
+    ],
+    indirect=True,
+    ids=["amazon_csv"],
+)
+def test_load_file_to_table_natively(remote_files_fixture, database_table_fixture):
+    """Load a file to a Snowflake table using the native optimisation."""
+    filepath = remote_files_fixture[0]
+    database, target_table = database_table_fixture
+    database.load_file_to_table(
+        File(filepath), target_table, {}, use_native_support=True
+    )
+
+    df = database.hook.get_pandas_df(f"SELECT * FROM {target_table.name}")
+    assert len(df) == 3
+    expected = pd.DataFrame(
+        [
+            {"id": 1, "name": "First"},
+            {"id": 2, "name": "Second"},
+            {"id": 3, "name": "Third with unicode पांचाल"},
+        ]
+    )
+    test_utils.assert_dataframes_are_equal(df, expected)
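
A minimal, hypothetical DAG sketch of how the native Snowflake load path introduced in this patch might be exercised end to end is shown below. It is not part of the diff: the bucket, storage integration name, and environment variables are placeholders, the import paths follow the example DAGs of this version of the SDK, and it assumes the operator-level `aql.load_file` forwards `use_native_support` / `native_support_kwargs` to the database layer the same way the integration test above calls `database.load_file_to_table`. When no `storage_integration` is supplied, `DEFAULT_STORAGE_INTEGRATION` falls back to the `AIRFLOW__ASTRO_SDK__SNOWFLAKE_STORAGE_INTEGRATION_AMAZON` / `..._GOOGLE` settings, as configured for GCS in the benchmark Dockerfile.

```python
# Hypothetical usage sketch (not part of this patch): load an S3 CSV into
# Snowflake through the COPY INTO path added by SnowflakeDatabase above.
import os
import time
from datetime import datetime

from airflow.decorators import dag

from astro import sql as aql
from astro.files import File
from astro.sql.table import Metadata, Table


@dag(start_date=datetime(2022, 1, 1), schedule_interval=None, catchup=False)
def example_native_snowflake_load():
    aql.load_file(
        input_file=File(path="s3://my-bucket/ADOPTION_CENTER_1_unquoted.csv"),
        output_table=Table(
            name="adoptions_native_" + str(int(time.time())),
            metadata=Metadata(
                schema=os.environ["SNOWFLAKE_SCHEMA"],
                database=os.environ["SNOWFLAKE_DATABASE"],
            ),
            conn_id="snowflake_conn",
        ),
        # Route the load through SnowflakeDatabase.load_file_to_table_natively().
        use_native_support=True,
        # Optional: when omitted, DEFAULT_STORAGE_INTEGRATION falls back to the
        # AIRFLOW__ASTRO_SDK__SNOWFLAKE_STORAGE_INTEGRATION_AMAZON setting.
        native_support_kwargs={"storage_integration": "my_s3_integration"},
    )


example_native_snowflake_load_dag = example_native_snowflake_load()
```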