Skip to content

Commit

Permalink
Deprecate some functions in the experimental API (#19931)
Browse files Browse the repository at this point in the history
This PR seeks to deprecate some functions in the experimental API.
Some of the deprecated functions are only used in the experimental REST API;
others that are still valid are being moved out of the experimental package.

(cherry picked from commit 6239ae9)
  • Loading branch information
ephraimbuddy authored and jedcunningham committed Feb 17, 2022
1 parent daebc58 commit 663bb54
Show file tree
Hide file tree
Showing 20 changed files with 435 additions and 198 deletions.
29 changes: 21 additions & 8 deletions airflow/api/client/local_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
"""Local client API"""

from airflow.api.client import api_client
from airflow.api.common.experimental import delete_dag, pool, trigger_dag
from airflow.api.common import delete_dag, trigger_dag
from airflow.api.common.experimental.get_lineage import get_lineage as get_lineage_api
from airflow.exceptions import AirflowBadRequest, PoolNotFound
from airflow.models.pool import Pool


class Client(api_client.Client):
Expand All @@ -36,19 +38,30 @@ def delete_dag(self, dag_id):
return f"Removed {count} record(s)"

def get_pool(self, name):
the_pool = pool.get_pool(name=name)
return the_pool.pool, the_pool.slots, the_pool.description
pool = Pool.get_pool(pool_name=name)
if not pool:
raise PoolNotFound(f"Pool {name} not found")
return pool.pool, pool.slots, pool.description

def get_pools(self):
return [(p.pool, p.slots, p.description) for p in pool.get_pools()]
return [(p.pool, p.slots, p.description) for p in Pool.get_pools()]

def create_pool(self, name, slots, description):
the_pool = pool.create_pool(name=name, slots=slots, description=description)
return the_pool.pool, the_pool.slots, the_pool.description
if not (name and name.strip()):
raise AirflowBadRequest("Pool name shouldn't be empty")
pool_name_length = Pool.pool.property.columns[0].type.length
if len(name) > pool_name_length:
raise AirflowBadRequest(f"pool name cannot be more than {pool_name_length} characters")
try:
slots = int(slots)
except ValueError:
raise AirflowBadRequest(f"Bad value for `slots`: {slots}")
pool = Pool.create_or_update_pool(name=name, slots=slots, description=description)
return pool.pool, pool.slots, pool.description

def delete_pool(self, name):
the_pool = pool.delete_pool(name=name)
return the_pool.pool, the_pool.slots, the_pool.description
pool = Pool.delete_pool(name=name)
return pool.pool, pool.slots, pool.description

def get_lineage(self, dag_id, execution_date):
lineage = get_lineage_api(dag_id=dag_id, execution_date=execution_date)
Expand Down
83 changes: 83 additions & 0 deletions airflow/api/common/delete_dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Delete DAGs APIs."""
import logging

from sqlalchemy import or_

from airflow import models
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models import DagModel, TaskFail
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils.db import get_sqla_model_classes
from airflow.utils.session import provide_session
from airflow.utils.state import State

log = logging.getLogger(__name__)


@provide_session
def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int:
    """
    Delete a DAG, and every DAG whose id starts with ``<dag_id>.``, from the metadata database.

    :param dag_id: the dag_id of the DAG to delete
    :param keep_records_in_log: whether keep records of the given dag_id
        in the Log table in the backend database (for reasons like auditing).
        The default value is True.
    :param session: session used
    :return: total number of rows deleted from the models that expose a ``dag_id``
        attribute, plus the parent's TaskFail/TaskInstance rows when the DAG is a subdag;
        removed import errors are not included in the count
    :raises AirflowException: if the DAG still has a task instance in the RUNNING state
    :raises DagNotFound: if no DagModel row exists for ``dag_id``
    """
    log.info("Deleting DAG: %s", dag_id)
    running_tis = (
        session.query(models.TaskInstance.state)
        .filter(models.TaskInstance.dag_id == dag_id)
        .filter(models.TaskInstance.state == State.RUNNING)
        .first()
    )
    if running_tis:
        raise AirflowException("TaskInstances still running")
    dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
    if dag is None:
        raise DagNotFound(f"Dag id {dag_id} not found")

    # Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
    # There may be a lag, so explicitly removes serialized DAG here.
    if SerializedDagModel.has_dag(dag_id=dag_id, session=session):
        SerializedDagModel.remove_dag(dag_id=dag_id, session=session)

    count = 0

    for model in get_sqla_model_classes():
        if hasattr(model, "dag_id"):
            if keep_records_in_log and model.__name__ == 'Log':
                continue
            # ``startswith(..., autoescape=True)`` escapes ``%`` and ``_`` in ``dag_id`` so they
            # match literally. A plain ``LIKE dag_id + ".%"`` treats the very common ``_`` as a
            # single-character wildcard and would also delete rows of unrelated DAGs
            # (e.g. deleting ``my_dag`` would match ``my-dag.other``).
            cond = or_(model.dag_id == dag_id, model.dag_id.startswith(dag_id + ".", autoescape=True))
            count += session.query(model).filter(cond).delete(synchronize_session='fetch')
    if dag.is_subdag:
        # A subdag id has the form ``<parent_dag_id>.<task_id>``: also drop the rows that
        # belong to the parent's task which wraps this subdag.
        parent_dag_id, task_id = dag_id.rsplit(".", 1)
        for model in TaskFail, models.TaskInstance:
            count += (
                session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete()
            )

    # Delete entries in Import Errors table for a deleted DAG
    # This handles the case when the dag_id is changed in the file
    session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete(
        synchronize_session='fetch'
    )

    return count
70 changes: 7 additions & 63 deletions airflow/api/common/experimental/delete_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,68 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Delete DAGs APIs."""
import logging
import warnings

from sqlalchemy import or_
from airflow.api.common.delete_dag import * # noqa

from airflow import models
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models import DagModel, TaskFail
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils.session import provide_session
from airflow.utils.state import State

log = logging.getLogger(__name__)


@provide_session
def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int:
    """
    Delete a DAG, and every DAG whose id starts with ``<dag_id>.``, from the metadata database.

    :param dag_id: the dag_id of the DAG to delete
    :param keep_records_in_log: whether keep records of the given dag_id
        in the Log table in the backend database (for reasons like auditing).
        The default value is True.
    :param session: session used
    :return: total number of rows deleted from the registered models that expose a
        ``dag_id`` attribute, plus the parent's TaskFail/TaskInstance rows when the DAG
        is a subdag; removed import errors are not included in the count
    :raises AirflowException: if the DAG still has a task instance in the RUNNING state
    :raises DagNotFound: if no DagModel row exists for ``dag_id``
    """
    log.info("Deleting DAG: %s", dag_id)
    running_tis = (
        session.query(models.TaskInstance.state)
        .filter(models.TaskInstance.dag_id == dag_id)
        .filter(models.TaskInstance.state == State.RUNNING)
        .first()
    )
    if running_tis:
        raise AirflowException("TaskInstances still running")
    dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
    if dag is None:
        raise DagNotFound(f"Dag id {dag_id} not found")

    # Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
    # There may be a lag, so explicitly removes serialized DAG here.
    if SerializedDagModel.has_dag(dag_id=dag_id, session=session):
        SerializedDagModel.remove_dag(dag_id=dag_id, session=session)

    count = 0

    for model in models.base.Base._decl_class_registry.values():
        if hasattr(model, "dag_id"):
            if keep_records_in_log and model.__name__ == 'Log':
                continue
            # Escape ``%`` and ``_`` in ``dag_id`` so they match literally; an unescaped
            # ``LIKE dag_id + ".%"`` lets ``_`` match any character and can delete rows
            # that belong to unrelated DAGs.
            cond = or_(model.dag_id == dag_id, model.dag_id.startswith(dag_id + ".", autoescape=True))
            count += session.query(model).filter(cond).delete(synchronize_session='fetch')
    if dag.is_subdag:
        # A subdag id has the form ``<parent_dag_id>.<task_id>``: also drop the rows that
        # belong to the parent's task which wraps this subdag.
        parent_dag_id, task_id = dag_id.rsplit(".", 1)
        for model in TaskFail, models.TaskInstance:
            count += (
                session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete()
            )

    # Delete entries in Import Errors table for a deleted DAG
    # This handles the case when the dag_id is changed in the file
    session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete(
        synchronize_session='fetch'
    )

    return count
# Module-level deprecation notice: the implementation is re-exported from
# `airflow.api.common.delete_dag`, and this warning points users at the new location.
# `stacklevel=2` asks `warnings` to attribute the warning to the frame above this
# module-level call rather than to this line.
warnings.warn(
"This module is deprecated. Please use `airflow.api.common.delete_dag` instead.",
DeprecationWarning,
stacklevel=2,
)
3 changes: 3 additions & 0 deletions airflow/api/common/experimental/get_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
# specific language governing permissions and limitations
# under the License.
"""Get code APIs."""
from deprecated import deprecated

from airflow.api.common.experimental import check_and_get_dag
from airflow.exceptions import AirflowException, DagCodeNotFound
from airflow.models.dagcode import DagCode


@deprecated(reason="Use DagCode().get_code_by_fileloc() instead", version="2.2.3")
def get_code(dag_id: str) -> str:
"""Return python code of a given dag_id.
Expand Down
3 changes: 3 additions & 0 deletions airflow/api/common/experimental/get_dag_run_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
from datetime import datetime
from typing import Dict

from deprecated import deprecated

from airflow.api.common.experimental import check_and_get_dag, check_and_get_dagrun


@deprecated(reason="Use DagRun().get_state() instead", version="2.2.3")
def get_dag_run_state(dag_id: str, execution_date: datetime) -> Dict[str, str]:
"""Return the Dag Run state identified by the given dag_id and execution_date.
Expand Down
3 changes: 3 additions & 0 deletions airflow/api/common/experimental/get_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
# specific language governing permissions and limitations
# under the License.
"""Task APIs.."""
from deprecated import deprecated

from airflow.api.common.experimental import check_and_get_dag
from airflow.models import TaskInstance


@deprecated(reason="Use DAG().get_task", version="2.2.3")
def get_task(dag_id: str, task_id: str) -> TaskInstance:
"""Return the task object identified by the given dag_id and task_id."""
dag = check_and_get_dag(dag_id, task_id)
Expand Down
3 changes: 3 additions & 0 deletions airflow/api/common/experimental/get_task_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
"""Task Instance APIs."""
from datetime import datetime

from deprecated import deprecated

from airflow.api.common.experimental import check_and_get_dag, check_and_get_dagrun
from airflow.exceptions import TaskInstanceNotFound
from airflow.models import TaskInstance


@deprecated(version="2.2.3", reason="Use DagRun.get_task_instance instead")
def get_task_instance(dag_id: str, task_id: str, execution_date: datetime) -> TaskInstance:
"""Return the task instance identified by the given dag_id, task_id and execution_date."""
dag = check_and_get_dag(dag_id, task_id)
Expand Down
6 changes: 6 additions & 0 deletions airflow/api/common/experimental/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
# specific language governing permissions and limitations
# under the License.
"""Pool APIs."""
from deprecated import deprecated

from airflow.exceptions import AirflowBadRequest, PoolNotFound
from airflow.models import Pool
from airflow.utils.session import provide_session


@deprecated(reason="Use Pool.get_pool() instead", version="2.2.3")
@provide_session
def get_pool(name, session=None):
"""Get pool by a given name."""
Expand All @@ -34,12 +37,14 @@ def get_pool(name, session=None):
return pool


@deprecated(reason="Use Pool.get_pools() instead", version="2.2.3")
@provide_session
def get_pools(session=None):
    """Return every Pool row found through the given session."""
    pool_query = session.query(Pool)
    return pool_query.all()


@deprecated(reason="Use Pool.create_pool() instead", version="2.2.3")
@provide_session
def create_pool(name, slots, description, session=None):
"""Create a pool with a given parameters."""
Expand Down Expand Up @@ -70,6 +75,7 @@ def create_pool(name, slots, description, session=None):
return pool


@deprecated(reason="Use Pool.delete_pool() instead", version="2.2.3")
@provide_session
def delete_pool(name, session=None):
"""Delete pool by a given name."""
Expand Down

0 comments on commit 663bb54

Please sign in to comment.