Skip to content

Commit

Permalink
make get_dag endpoint consistent w/ serializeddag
Browse files Browse the repository at this point in the history
update get_dag tests to handle the new behavior
update dag_schema to get tags the same way as dag_details_schema
  • Loading branch information
psg2 committed Jul 7, 2021
1 parent 2b7c596 commit fcae7a2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 41 deletions.
10 changes: 5 additions & 5 deletions airflow/api_connexion/endpoints/dag_endpoint.py
Expand Up @@ -34,14 +34,14 @@


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
@provide_session
def get_dag(dag_id, session):
def get_dag(dag_id):
"""Get basic information about a DAG."""
dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none()

try:
dag: DAG = current_app.dag_bag.get_dag(dag_id)
except SerializedDagNotFound:
raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found")
if dag is None:
raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found")

return dag_schema.dump(dag)


Expand Down
10 changes: 9 additions & 1 deletion airflow/api_connexion/schemas/dag_schema.py
Expand Up @@ -56,7 +56,15 @@ class Meta:
owners = fields.Method("get_owners", dump_only=True)
description = auto_field(dump_only=True)
schedule_interval = fields.Nested(ScheduleIntervalSchema)
tags = fields.List(fields.Nested(DagTagSchema), dump_only=True)
tags = fields.Method("get_tags", dump_only=True)

@staticmethod
def get_tags(obj: DAG):
"""Dumps tags as objects"""
tags = obj.tags
if tags:
return [DagTagSchema().dump(dict(name=tag)) for tag in tags]
return []

@staticmethod
def get_owners(obj: DagModel):
Expand Down
75 changes: 40 additions & 35 deletions tests/api_connexion/endpoints/test_dag_endpoint.py
Expand Up @@ -31,7 +31,6 @@
from airflow.security import permissions
from airflow.utils.session import provide_session
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags

SERIALIZER = URLSafeSerializer(conf.get('webserver', 'secret_key'))
Expand All @@ -40,6 +39,7 @@
TASK_ID = "op1"
DAG2_ID = "test_dag2"
DAG3_ID = "test_dag3"
DAG4_ID = "test_dag4"


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -83,8 +83,25 @@ def configured_app(minimal_app_for_api):
with DAG(DAG3_ID) as dag3: # DAG start_date set to None
DummyOperator(task_id=TASK_ID, start_date=datetime(2019, 6, 12))

with DAG(DAG4_ID, schedule_interval=None) as dag4: # DAG schedule_interval set to None
DummyOperator(task_id=TASK_ID, start_date=datetime(2019, 6, 12))

dag_bag = DagBag(os.devnull, include_examples=False)
dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3}
dag_bag.dags = {
dag.dag_id: dag,
dag2.dag_id: dag2,
dag3.dag_id: dag3,
dag4.dag_id: dag4,
}

app.appbuilder.sm.sync_perm_for_dag( # type: ignore
dag.dag_id,
access_control={'TestGranularDag': [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]},
)
app.appbuilder.sm.sync_perm_for_dag( # type: ignore
dag.dag_id,
access_control={'TestGranularDag': [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]},
)

app.dag_bag = dag_bag

Expand All @@ -110,6 +127,7 @@ def setup_attrs(self, configured_app) -> None:
self.dag_id = DAG_ID
self.dag2_id = DAG2_ID
self.dag3_id = DAG3_ID
self.dag4_id = DAG4_ID

def teardown_method(self) -> None:
self.clean_db()
Expand Down Expand Up @@ -137,54 +155,44 @@ def _create_deactivated_dag(self, session=None):


class TestGetDag(TestDagEndpoint):
@conf_vars({("webserver", "secret_key"): "mysecret"})
def test_should_respond_200(self):
self._create_dag_models(1)
response = self.client.get("/api/v1/dags/TEST_DAG_1", environ_overrides={'REMOTE_USER': "test"})
response = self.client.get(f"/api/v1/dags/{self.dag_id}", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
assert {
"dag_id": "TEST_DAG_1",
"dag_id": self.dag_id,
"description": None,
"fileloc": "/tmp/dag_1.py",
"file_token": 'Ii90bXAvZGFnXzEucHki.EnmIdPaUPo26lHQClbWMbDFD1Pk',
"is_paused": False,
"is_active": True,
"fileloc": __file__,
"file_token": FILE_TOKEN,
"is_paused": None,
"is_subdag": False,
"owners": [],
"root_dag_id": None,
"schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"},
"tags": [],
"schedule_interval": {
"__type": "TimeDelta",
"days": 1,
"microseconds": 0,
"seconds": 0,
},
"tags": [{'name': 'example'}],
} == response.json

@conf_vars({("webserver", "secret_key"): "mysecret"})
def test_should_respond_200_with_schedule_interval_none(self, session):
dag_model = DagModel(
dag_id="TEST_DAG_1",
fileloc="/tmp/dag_1.py",
schedule_interval=None,
)
session.add(dag_model)
session.commit()
response = self.client.get("/api/v1/dags/TEST_DAG_1", environ_overrides={'REMOTE_USER': "test"})
response = self.client.get(f"/api/v1/dags/{self.dag4_id}", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
assert {
"dag_id": "TEST_DAG_1",
"dag_id": self.dag4_id,
"description": None,
"fileloc": "/tmp/dag_1.py",
"file_token": 'Ii90bXAvZGFnXzEucHki.EnmIdPaUPo26lHQClbWMbDFD1Pk',
"is_paused": False,
"is_active": False,
"fileloc": __file__,
"file_token": FILE_TOKEN,
"is_paused": None,
"is_subdag": False,
"owners": [],
"root_dag_id": None,
"schedule_interval": None,
"tags": [],
} == response.json

def test_should_respond_200_with_granular_dag_access(self):
self._create_dag_models(1)
response = self.client.get(
"/api/v1/dags/TEST_DAG_1", environ_overrides={'REMOTE_USER': "test_granular_permissions"}
f"/api/v1/dags/{self.dag_id}", environ_overrides={'REMOTE_USER': "test_granular_permissions"}
)
assert response.status_code == 200

Expand All @@ -193,9 +201,7 @@ def test_should_respond_404(self):
assert response.status_code == 404

def test_should_raises_401_unauthenticated(self):
self._create_dag_models(1)

response = self.client.get("/api/v1/dags/TEST_DAG_1")
response = self.client.get(f"/api/v1/dags/{self.dag_id}")

assert_401(response)

Expand All @@ -206,9 +212,8 @@ def test_should_raise_403_forbidden(self):
assert response.status_code == 403

def test_should_respond_403_with_granular_access_for_different_dag(self):
self._create_dag_models(3)
response = self.client.get(
"/api/v1/dags/TEST_DAG_2", environ_overrides={'REMOTE_USER': "test_granular_permissions"}
"/api/v1/dags/{self.dag2_id}", environ_overrides={'REMOTE_USER': "test_granular_permissions"}
)
assert response.status_code == 403

Expand Down

0 comments on commit fcae7a2

Please sign in to comment.