diff --git a/airflow/providers/apache/spark/hooks/spark_sql.py b/airflow/providers/apache/spark/hooks/spark_sql.py
index b690f2cf78637..8842f820f9968 100644
--- a/airflow/providers/apache/spark/hooks/spark_sql.py
+++ b/airflow/providers/apache/spark/hooks/spark_sql.py
@@ -17,11 +17,14 @@
 # under the License.
 #
 import subprocess
-from typing import Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.hooks.base import BaseHook
 
+if TYPE_CHECKING:
+    from airflow.models.connection import Connection
+
 
 class SparkSqlHook(BaseHook):
     """
@@ -45,6 +48,7 @@ class SparkSqlHook(BaseHook):
     :param keytab: Full path to the file that contains the keytab
     :type keytab: str
     :param master: spark://host:port, mesos://host:port, yarn, or local
+        (Default: The ``host`` and ``port`` set in the Connection, or ``"yarn"``)
     :type master: str
     :param name: Name of the job.
     :type name: str
@@ -52,7 +56,8 @@ class SparkSqlHook(BaseHook):
     :type num_executors: int
     :param verbose: Whether to pass the verbose flag to spark-sql
     :type verbose: bool
-    :param yarn_queue: The YARN queue to submit to (Default: "default")
+    :param yarn_queue: The YARN queue to submit to
+        (Default: The ``queue`` value set in the Connection, or ``"default"``)
     :type yarn_queue: str
     """
 
@@ -72,16 +77,35 @@ def __init__(
         executor_memory: Optional[str] = None,
         keytab: Optional[str] = None,
         principal: Optional[str] = None,
-        master: str = 'yarn',
+        master: Optional[str] = None,
         name: str = 'default-name',
         num_executors: Optional[int] = None,
         verbose: bool = True,
-        yarn_queue: str = 'default',
+        yarn_queue: Optional[str] = None,
     ) -> None:
         super().__init__()
+
+        try:
+            conn: "Optional[Connection]" = self.get_connection(conn_id)
+        except AirflowNotFoundException:
+            conn = None
+            options = {}
+        else:
+            options = conn.extra_dejson
+
+        # Set arguments to values set in Connection if not explicitly provided.
+        if master is None:
+            if conn is None:
+                master = "yarn"
+            elif conn.port:
+                master = f"{conn.host}:{conn.port}"
+            else:
+                master = conn.host
+        if yarn_queue is None:
+            yarn_queue = options.get("queue", "default")
+
         self._sql = sql
         self._conf = conf
-        self._conn = self.get_connection(conn_id)
         self._total_executor_cores = total_executor_cores
         self._executor_cores = executor_cores
         self._executor_memory = executor_memory
diff --git a/airflow/providers/apache/spark/operators/spark_sql.py b/airflow/providers/apache/spark/operators/spark_sql.py
index 6c52fa2d58c03..af02092cfce7f 100644
--- a/airflow/providers/apache/spark/operators/spark_sql.py
+++ b/airflow/providers/apache/spark/operators/spark_sql.py
@@ -47,6 +47,7 @@ class SparkSqlOperator(BaseOperator):
     :param keytab: Full path to the file that contains the keytab
     :type keytab: str
     :param master: spark://host:port, mesos://host:port, yarn, or local
+        (Default: The ``host`` and ``port`` set in the Connection, or ``"yarn"``)
     :type master: str
     :param name: Name of the job
     :type name: str
@@ -54,7 +55,8 @@ class SparkSqlOperator(BaseOperator):
     :type num_executors: int
     :param verbose: Whether to pass the verbose flag to spark-sql
     :type verbose: bool
-    :param yarn_queue: The YARN queue to submit to (Default: "default")
+    :param yarn_queue: The YARN queue to submit to
+        (Default: The ``queue`` value set in the Connection, or ``"default"``)
     :type yarn_queue: str
     """
 
@@ -73,11 +75,11 @@ def __init__(
         executor_memory: Optional[str] = None,
         keytab: Optional[str] = None,
         principal: Optional[str] = None,
-        master: str = 'yarn',
+        master: Optional[str] = None,
         name: str = 'default-name',
         num_executors: Optional[int] = None,
         verbose: bool = True,
-        yarn_queue: str = 'default',
+        yarn_queue: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
diff --git a/docs/apache-airflow-providers-apache-spark/connections/spark.rst b/docs/apache-airflow-providers-apache-spark/connections/spark.rst
index 09d7d4ec4dc90..adc88b7f839c2 100644
--- a/docs/apache-airflow-providers-apache-spark/connections/spark.rst
+++ b/docs/apache-airflow-providers-apache-spark/connections/spark.rst
@@ -27,7 +27,7 @@ The Apache Spark connection type enables connection to Apache Spark.
 Default Connection IDs
 ----------------------
 
-Spark Submit and Spark JDBC hooks and operators use ``spark_default`` by default, Spark SQL hooks and operators point to ``spark_sql_default`` by default, but don't use it.
+Spark Submit and Spark JDBC hooks and operators use ``spark_default`` by default. Spark SQL hooks and operators point to ``spark_sql_default`` by default.
 
 Configuring the Connection
 --------------------------
diff --git a/tests/providers/apache/spark/hooks/test_spark_sql.py b/tests/providers/apache/spark/hooks/test_spark_sql.py
index 35e4330b882d5..7fb44d62200d4 100644
--- a/tests/providers/apache/spark/hooks/test_spark_sql.py
+++ b/tests/providers/apache/spark/hooks/test_spark_sql.py
@@ -27,6 +27,7 @@
 from airflow.models import Connection
 from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook
 from airflow.utils import db
+from tests.test_utils.db import clear_db_connections
 
 
 def get_after(sentinel, iterable):
@@ -49,10 +50,15 @@ class TestSparkSqlHook(unittest.TestCase):
         'conf': 'key=value,PROP=VALUE',
     }
 
-    def setUp(self):
-
+    @classmethod
+    def setUpClass(cls) -> None:
+        clear_db_connections(add_default_connections_back=False)
         db.merge_conn(Connection(conn_id='spark_default', conn_type='spark', host='yarn://yarn-master'))
 
+    @classmethod
+    def tearDownClass(cls) -> None:
+        clear_db_connections(add_default_connections_back=True)
+
     def test_build_command(self):
         hook = SparkSqlHook(**self._config)
 
@@ -95,7 +101,7 @@ def test_spark_process_runcmd(self, mock_popen):
                 '-e',
                 'SELECT 1',
                 '--master',
-                'yarn',
+                'yarn://yarn-master',
                 '--name',
                 'default-name',
                 '--verbose',
@@ -112,7 +118,7 @@
                 '-e',
                 'SELECT 1',
                 '--master',
-                'yarn',
+                'yarn://yarn-master',
                 '--name',
                 'default-name',
                 '--verbose',
@@ -139,7 +145,7 @@ def test_spark_process_runcmd_with_str(self, mock_popen):
                 '-e',
                 'SELECT 1',
                 '--master',
-                'yarn',
+                'yarn://yarn-master',
                 '--name',
                 'default-name',
                 '--verbose',
@@ -168,7 +174,7 @@ def test_spark_process_runcmd_with_list(self, mock_popen):
                 '-e',
                 'SELECT 1',
                 '--master',
-                'yarn',
+                'yarn://yarn-master',
                 '--name',
                 'default-name',
                 '--verbose',
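
Taken together, these changes make SparkSqlHook derive its --master and --queue defaults from the configured Connection instead of hardcoding 'yarn' and 'default'. Below is a minimal sketch of the resulting behavior, not part of this diff: it assumes an initialized Airflow metadata database, registers a connection the same way the tests above do, and uses illustrative values (the host 'spark://spark-master', port 7077, and the '{"queue": "etl"}' extra are made up for the example).

    from airflow.models import Connection
    from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook
    from airflow.utils import db

    # Register a Spark connection whose host/port and extras supply the defaults.
    # All values here are illustrative, not taken from this diff.
    db.merge_conn(
        Connection(
            conn_id='spark_sql_default',
            conn_type='spark',
            host='spark://spark-master',
            port=7077,
            extra='{"queue": "etl"}',
        )
    )

    # master is omitted, so the hook falls back to f"{conn.host}:{conn.port}",
    # i.e. 'spark://spark-master:7077'; yarn_queue is omitted, so it falls back
    # to the 'queue' extra ('etl'). If no connection can be found at all, the
    # hook still defaults to 'yarn' and 'default' as before.
    hook = SparkSqlHook(sql='SELECT 1', conn_id='spark_sql_default')

Explicit keyword arguments still win: passing master= or yarn_queue= to the hook (or to SparkSqlOperator, which forwards them) bypasses the Connection lookup entirely, since the fallback only runs when the argument is None.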