Optimize Snowflake load_file using native COPY INTO #544

Merged: 17 commits, Jul 26, 2022
19 changes: 11 additions & 8 deletions example_dags/example_amazon_s3_snowflake_transform.py
@@ -21,17 +21,17 @@ def combine_data(center_1: Table, center_2: Table):
@aql.transform()
def clean_data(input_table: Table):
return """SELECT *
FROM {{input_table}} WHERE TYPE NOT LIKE 'Guinea Pig'
FROM {{input_table}} WHERE type NOT LIKE 'Guinea Pig'
"""


@aql.dataframe(identifiers_as_lower=False)
@aql.dataframe()
Review comment from @tatiana (collaborator, author), Jul 26, 2022:

We should be able to handle capitalization more consistently across dataframes and file loading after implementing #564.
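
A minimal sketch of the direction #564 could take (an assumption, not part of this PR): normalise dataframe column names in one helper so that `load_file` and `@aql.dataframe` agree on identifier casing.

```python
import pandas as pd


# Hypothetical helper, not part of this PR: normalise column names in one place
# so that dataframe operators and file loading use consistent identifier casing.
def normalise_columns(df: pd.DataFrame, uppercase: bool = True) -> pd.DataFrame:
    df = df.copy()
    df.columns = df.columns.str.upper() if uppercase else df.columns.str.lower()
    return df
```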

def aggregate_data(df: pd.DataFrame):
adoption_reporting_dataframe = df.pivot_table(
new_df = df.pivot_table(
index="date", values="name", columns=["type"], aggfunc="count"
).reset_index()

return adoption_reporting_dataframe
new_df.columns = new_df.columns.str.lower()
return new_df


@dag(
@@ -67,11 +67,11 @@ def example_amazon_s3_snowflake_transform():
)

temp_table_1 = aql.load_file(
input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_1.csv"),
input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_1_unquoted.csv"),
Review comment from @tatiana (collaborator, author):

The CSV was previously in an odd format:

"header1","header2"
"value1","value2"

The COPY INTO command did not accept this format, so we decided this file could be cleaned up beforehand. Users may face similar issues in the future.
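
For reference, a hedged alternative to pre-cleaning (not what this PR does): Snowflake's COPY INTO can ingest quoted CSVs when the file format declares the quoting. The stage and table names below are hypothetical.

```python
# Sketch only; this PR chose to clean the files instead.
QUOTED_CSV_COPY_INTO = """
COPY INTO ADOPTION_CENTER_1
FROM @example_stage/ADOPTION_CENTER_1.csv
FILE_FORMAT = (TYPE = CSV FIELD_OPTIONALLY_ENCLOSED_BY = '"' SKIP_HEADER = 1)
"""
```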

output_table=input_table_1,
)
temp_table_2 = aql.load_file(
input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_2.csv"),
input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_2_unquoted.csv"),
output_table=input_table_2,
)

@@ -85,7 +85,10 @@ def example_amazon_s3_snowflake_transform():
cleaned_data,
output_table=Table(
name="aggregated_adoptions_" + str(int(time.time())),
metadata=Metadata(schema=os.environ["SNOWFLAKE_SCHEMA"]),
metadata=Metadata(
schema=os.environ["SNOWFLAKE_SCHEMA"],
database=os.environ["SNOWFLAKE_DATABASE"],
),
conn_id="snowflake_conn",
),
)
131 changes: 108 additions & 23 deletions src/astro/databases/snowflake.py
@@ -1,16 +1,17 @@
"""Snowflake database implementation."""
import logging
import os
import random
import string
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import pandas as pd
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from pandas.io.sql import SQLDatabase
from snowflake.connector import pandas_tools
from snowflake.connector.errors import ProgrammingError

from astro import settings
from astro.constants import (
DEFAULT_CHUNK_SIZE,
FileLocation,
@@ -36,6 +37,14 @@
FileType.PARQUET: "MATCH_BY_COLUMN_NAME=CASE_INSENSITIVE",
}

DEFAULT_STORAGE_INTEGRATION = {
FileLocation.S3: settings.SNOWFLAKE_STORAGE_INTEGRATION_AMAZON,
FileLocation.GS: settings.SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE,
}

NATIVE_LOAD_SUPPORTED_FILE_TYPES = (FileType.CSV, FileType.NDJSON, FileType.PARQUET)
NATIVE_LOAD_SUPPORTED_FILE_LOCATIONS = (FileLocation.GS, FileLocation.S3)
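
A note on where these defaults come from (partly an assumption): the `settings.SNOWFLAKE_STORAGE_INTEGRATION_*` values are astro-sdk settings, and this PR's benchmark Dockerfile drives the Google one through an Airflow-style environment variable; the Amazon variable name below mirrors it by analogy and is not confirmed by the diff.

```python
import os

# Sketch: configure the default storage integrations via environment variables
# instead of passing native_support_kwargs on every load_file call.
# The GOOGLE variable appears in this PR's benchmark Dockerfile; the AMAZON name
# and every value other than gcs_int_python_sdk are placeholders.
os.environ["AIRFLOW__ASTRO_SDK__SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE"] = "gcs_int_python_sdk"
os.environ["AIRFLOW__ASTRO_SDK__SNOWFLAKE_STORAGE_INTEGRATION_AMAZON"] = "s3_int_python_sdk"
```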


@dataclass
class SnowflakeStage:
@@ -138,7 +147,6 @@ class SnowflakeDatabase(BaseDatabase):
"""

def __init__(self, conn_id: str = DEFAULT_CONN_ID):
self.storage_integration: Optional[str] = None
super().__init__(conn_id)

@property
@@ -192,7 +200,9 @@ def _create_stage_auth_sub_statement(
:param storage_integration: Previously created Snowflake storage integration
:return: String containing line to be used for authentication on the remote storage
"""

storage_integration = storage_integration or DEFAULT_STORAGE_INTEGRATION.get(
file.location.location_type
)
if storage_integration is not None:
auth = f"storage_integration = {storage_integration};"
else:
@@ -291,6 +301,93 @@ def drop_stage(self, stage: SnowflakeStage) -> None:
# Table load methods
# ---------------------------------------------------------

def create_table_using_schema_autodetection(
self,
table: Table,
file: Optional[File] = None,
dataframe: Optional[pd.DataFrame] = None,
) -> None:
"""
Create a SQL table, automatically inferring the schema using the given file.

:param table: The table to be created.
:param file: File used to infer the new table columns.
:param dataframe: Dataframe used to infer the new table columns if there is no file
"""
if file:
dataframe = file.export_to_dataframe(
nrows=settings.LOAD_TABLE_AUTODETECT_ROWS_COUNT
)

# Snowflake does not handle mixed capitalisation of column names well;
# we are handling this more gracefully in a separate PR
if dataframe is not None:
    dataframe.columns = dataframe.columns.str.upper()
Review comment from @tatiana (collaborator, author):

We should be able to handle capitalization more consistently across dataframes and file loading after implementing #564.


super().create_table_using_schema_autodetection(table, dataframe=dataframe)

def is_native_load_file_available(
self, source_file: File, target_table: Table
) -> bool:
"""
Check if there is an optimised path for source to destination.

:param source_file: File from which we need to transfer data
:param target_table: Table that needs to be populated with file data
"""
is_file_type_supported = (
source_file.type.name in NATIVE_LOAD_SUPPORTED_FILE_TYPES
)
is_file_location_supported = (
source_file.location.location_type in NATIVE_LOAD_SUPPORTED_FILE_LOCATIONS
)
return is_file_type_supported and is_file_location_supported

def load_file_to_table_natively(
self,
source_file: File,
target_table: Table,
if_exists: LoadExistStrategy = "replace",
native_support_kwargs: Optional[Dict] = None,
**kwargs,
) -> None:
"""
Load the content of a file to an existing Snowflake table natively by:
- Creating a Snowflake external stage
- Using Snowflake COPY INTO statement

Requirements:
- The user must have permissions to create a STAGE in Snowflake.
- If loading from GCP Cloud Storage, `native_support_kwargs` must define `storage_integration`
- If loading from AWS S3, the credentials for creating the stage may be
retrieved from the Airflow connection or from the `storage_integration`
attribute within `native_support_kwargs`.

:param source_file: File from which we need to transfer data
:param target_table: Table to which the content of the file will be loaded
:param if_exists: Strategy used to load (currently supported: "append" or "replace")
:param native_support_kwargs: may be used for the stage creation, as described above.

.. seealso::
`Snowflake official documentation on COPY INTO
<https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html>`_
`Snowflake official documentation on CREATE STAGE
<https://docs.snowflake.com/en/sql-reference/sql/create-stage.html>`_

"""
native_support_kwargs = native_support_kwargs or {}
storage_integration = native_support_kwargs.get("storage_integration")
stage = self.create_stage(
file=source_file, storage_integration=storage_integration
)
table_name = self.get_table_qualified_name(target_table)
file_path = os.path.basename(source_file.path) or ""
sql_statement = (
f"COPY INTO {table_name} FROM @{stage.qualified_name}/{file_path}"
)
self.hook.run(sql_statement)
self.drop_stage(stage)
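
For context, a hedged usage sketch of the optimised path implemented above, assuming `aql.load_file` forwards `use_native_support` and `native_support_kwargs` down to this method; import paths, connection IDs, bucket and integration names are assumptions.

```python
from astro import sql as aql
from astro.files import File
from astro.sql.table import Table

# Sketch: load an S3 CSV into Snowflake through the native COPY INTO path.
# Connection IDs, bucket name and storage integration are placeholders.
adoption_center_1 = aql.load_file(
    input_file=File(
        path="s3://my-bucket/ADOPTION_CENTER_1_unquoted.csv", conn_id="aws_conn"
    ),
    output_table=Table(name="adoption_center_1", conn_id="snowflake_conn"),
    use_native_support=True,
    native_support_kwargs={"storage_integration": "s3_int_python_sdk"},
)
```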

def load_pandas_dataframe_to_table(
self,
source_dataframe: pd.DataFrame,
@@ -307,27 +404,15 @@ def load_pandas_dataframe_to_table(
:param if_exists: Strategy to be used in case the target table already exists.
:param chunk_size: Specify the number of rows in each batch to be written at a time.
"""
db = SQLDatabase(engine=self.sqlalchemy_engine)
# Make columns uppercase to prevent weird errors in snowflake
source_dataframe.columns = source_dataframe.columns.str.upper()
schema = None
if target_table.metadata:
schema = getattr(target_table.metadata, "schema", None)

# within prep_table() we use pandas drop() function which is used when we pass 'if_exists=replace'.
# There is an issue where has_table() works with uppercase table names but the function meta.reflect() don't.
# To prevent the issue we are passing table name in lowercase.
db.prep_table(
source_dataframe,
target_table.name.lower(),
schema=schema,
if_exists=if_exists,
index=False,
)
self.create_table(target_table, dataframe=source_dataframe)

self.table_exists(target_table)
pandas_tools.write_pandas(
self.hook.get_conn(),
source_dataframe,
target_table.name,
conn=self.hook.get_conn(),
df=source_dataframe,
table_name=target_table.name,
schema=target_table.metadata.schema,
database=target_table.metadata.database,
chunk_size=chunk_size,
quote_identifiers=False,
)
1 change: 1 addition & 0 deletions tests/benchmark/Dockerfile
@@ -11,6 +11,7 @@ ENV AIRFLOW_HOME=/opt/app/
ENV PYTHONPATH=/opt/app/
ENV ASTRO_PUBLISH_BENCHMARK_DATA=True
ENV GCP_BUCKET=dag-authoring
ENV AIRFLOW__ASTRO_SDK__SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE=gcs_int_python_sdk

# Debian Bullseye is shipped with Python 3.9
# Upgrade built-in pip
Expand Down
7 changes: 5 additions & 2 deletions tests/benchmark/Makefile
@@ -19,6 +19,10 @@ clean:
@rm -f unittests.cfg
@rm -f unittests.db
@rm -f webserver_config.py
@rm -f ../../unittests.cfg
@rm -f ../../unittests.db
@rm -f ../../airflow.cfg
@rm -f ../../airflow.db

# Takes approximately 7min
setup_gke:
@@ -45,8 +49,7 @@ local: check_google_credentials
benchmark
@rm -rf astro-sdk


run_job: check_google_credentials
run_job:
@gcloud container clusters get-credentials astro-sdk --zone us-central1-a --project ${GCP_PROJECT}
@kubectl apply -f infrastructure/kubernetes/namespace.yaml
@kubectl apply -f infrastructure/kubernetes/postgres.yaml
12 changes: 6 additions & 6 deletions tests/benchmark/config.json
@@ -1,5 +1,11 @@
{
"databases": [
{
"name": "snowflake",
"params": {
"conn_id": "snowflake_conn"
}
},
{
"name": "postgres",
"params": {
@@ -17,12 +23,6 @@
"database": "bigquery"
}
}
},
{
"name": "snowflake",
"params": {
"conn_id": "snowflake_conn"
}
}
],
"datasets": [
12 changes: 12 additions & 0 deletions tests/benchmark/debug.yaml
@@ -0,0 +1,12 @@
apiVersion: v1
kind: Pod
metadata:
name: troubleshoot
namespace: benchmark
spec:
containers:
- name: troubleshoot-benchmark
image: gcr.io/astronomer-dag-authoring/benchmark
# Just spin & wait forever
command: [ "/bin/bash", "-c", "--" ]
args: [ "while true; do sleep 30; done;" ]
20 changes: 19 additions & 1 deletion tests/benchmark/results.md
@@ -85,7 +85,6 @@ The benchmark was run as a Kubernetes job in GKE:
* Container resource limit:
* Memory: 10 Gi


| database | dataset | total_time | memory_rss | cpu_time_user | cpu_time_system |
|:-----------|:-----------|:-------------|:-------------|:----------------|:------------------|
| snowflake | ten_kb | 4.75s | 59.3MB | 1.45s | 100.0ms |
@@ -95,6 +94,25 @@ The benchmark was run as a Kubernetes job in GKE:
| snowflake | five_gb | 24.46min | 97.85MB | 1.43min | 5.94s |
| snowflake | ten_gb | 50.85min | 104.53MB | 2.7min | 12.11s |

### With native support

The benchmark was run as a Kubernetes job in GKE:

* Version: `astro-sdk-python` 1.0.0a1 (`bc58830`)
* Machine type: `n2-standard-4`
* vCPU: 4
* Memory: 16 GB RAM
* Container resource limit:
* Memory: 10 Gi

| database | dataset | total_time | memory_rss | cpu_time_user | cpu_time_system |
|:-----------|:-----------|:-------------|:-------------|:----------------|:------------------|
| snowflake | ten_kb | 9.1s | 56.45MB | 2.56s | 110.0ms |
| snowflake | hundred_kb | 9.19s | 45.4MB | 2.55s | 120.0ms |
| snowflake | ten_mb | 10.9s | 47.51MB | 2.58s | 160.0ms |
| snowflake | one_gb | 1.07min | 47.94MB | 8.7s | 5.67s |
| snowflake | five_gb | 5.49min | 53.69MB | 18.76s | 1.6s |
Review comment from @tatiana (collaborator, author):

I'll log a ticket for us to further investigate why we weren't able to run the benchmark with 10 GB for this implementation.


### Database: postgres

| database | dataset | total_time | memory_rss | cpu_time_user | cpu_time_system |
6 changes: 0 additions & 6 deletions tests/data/README.md

This file was deleted.

37 changes: 37 additions & 0 deletions tests/databases/test_snowflake.py
@@ -451,3 +451,40 @@ def test_create_stage_amazon_fails_due_to_no_credentials(get_credentials):
"In order to create an stage for S3, one of the following is required"
)
assert exc_info.match(expected_msg)


@pytest.mark.integration
@pytest.mark.parametrize(
"database_table_fixture",
[
{"database": Database.SNOWFLAKE},
],
indirect=True,
ids=["snowflake"],
)
@pytest.mark.parametrize(
"remote_files_fixture",
[
{"provider": "amazon", "filetype": FileType.CSV},
],
indirect=True,
ids=["amazon_csv"],
)
def test_load_file_to_table_natively(remote_files_fixture, database_table_fixture):
"""Load a file to a Snowflake table using the native optimisation."""
filepath = remote_files_fixture[0]
database, target_table = database_table_fixture
database.load_file_to_table(
File(filepath), target_table, {}, use_native_support=True
)

df = database.hook.get_pandas_df(f"SELECT * FROM {target_table.name}")
assert len(df) == 3
expected = pd.DataFrame(
[
{"id": 1, "name": "First"},
{"id": 2, "name": "Second"},
{"id": 3, "name": "Third with unicode पांचाल"},
]
)
test_utils.assert_dataframes_are_equal(df, expected)
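
For completeness, a sketch of running only this integration test locally; it assumes Snowflake and S3 credentials plus the `snowflake_conn` connection are configured and that the `integration` marker is registered in the project's pytest configuration.

```python
import pytest

# Sketch: invoke only the new native-load integration test.
pytest.main(
    [
        "-m",
        "integration",
        "-k",
        "test_load_file_to_table_natively",
        "tests/databases/test_snowflake.py",
    ]
)
```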