Skip to content

Commit

Permalink
Use load_connections_dict in connections import
Browse files Browse the repository at this point in the history
  • Loading branch information
natanweinberger authored and ashb committed Jun 10, 2021
1 parent 5cd4b84 commit dd2745e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
12 changes: 4 additions & 8 deletions airflow/cli/commands/connection_command.py
Expand Up @@ -28,7 +28,7 @@
from airflow.exceptions import AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.secrets.local_filesystem import _parse_secret_file
from airflow.secrets.local_filesystem import load_connections_dict
from airflow.utils import cli as cli_utils, yaml
from airflow.utils.cli import suppress_logs_and_warning
from airflow.utils.session import create_session
Expand Down Expand Up @@ -247,17 +247,13 @@ def connections_import(args):

def _import_helper(file_path):
"""Load connections from a file and save them to the DB. On collision, skip."""
connections_dict = _parse_secret_file(file_path)
connections_dict = load_connections_dict(file_path)
with create_session() as session:
for conn_id, conn_dict in connections_dict.items():
for conn_id, conn in connections_dict.items():
if session.query(Connection).filter(Connection.conn_id == conn_id).first():
print(f'Could not import connection {conn_id}: connection already exists.')
continue

if "extra_dejson" in conn_dict:
conn_dict["extra"] = conn_dict.pop("extra_dejson")
# Add the connection to the DB
connection = Connection(conn_id, **conn_dict)
session.add(connection)
session.add(conn)
session.commit()
print(f'Imported connection {conn_id}')
6 changes: 4 additions & 2 deletions tests/cli/commands/test_connection_command.py
Expand Up @@ -758,7 +758,7 @@ def test_cli_connections_import_should_return_error_if_file_format_is_invalid(
):
connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))

@mock.patch('airflow.cli.commands.connection_command._parse_secret_file')
@mock.patch('airflow.secrets.local_filesystem._parse_secret_file')
@mock.patch('os.path.exists')
def test_cli_connections_import_should_load_connections(self, mock_exists, mock_parse_secret_file):
mock_exists.return_value = True
Expand Down Expand Up @@ -799,6 +799,7 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_
current_conns = session.query(Connection).all()

comparable_attrs = [
"conn_id",
"conn_type",
"description",
"host",
Expand All @@ -816,7 +817,7 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_
assert expected_connections == current_conns_as_dicts

@provide_session
@mock.patch('airflow.cli.commands.connection_command._parse_secret_file')
@mock.patch('airflow.secrets.local_filesystem._parse_secret_file')
@mock.patch('os.path.exists')
def test_cli_connections_import_should_not_overwrite_existing_connections(
self, mock_exists, mock_parse_secret_file, session=None
Expand Down Expand Up @@ -875,6 +876,7 @@ def test_cli_connections_import_should_not_overwrite_existing_connections(
current_conns = session.query(Connection).all()

comparable_attrs = [
"conn_id",
"conn_type",
"description",
"host",
Expand Down

0 comments on commit dd2745e

Please sign in to comment.