Skip to content

Commit

Permalink
Make SparkSqlHook use Connection (#15794)
Browse files Browse the repository at this point in the history
* Make SparkSqlHook use Connection

This allows a SparkSqlHook to be created without a backing Connection,
and if a backing Connection *is* found, use it to provide the default
arguments not explicitly passed into the hook.

* Properly clean connections for Spark tests

* Expected Connection values in SparkSqlHook tests

Now that SparkSqlHook defaults to read values from Connection if a
config is not explicitly provided, we need to tweak the tests to reflect
this expectation.
  • Loading branch information
uranusjr committed Jun 14, 2021
1 parent 943292b commit 5c86e3d
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 16 deletions.
36 changes: 30 additions & 6 deletions airflow/providers/apache/spark/hooks/spark_sql.py
Expand Up @@ -17,11 +17,14 @@
# under the License.
#
import subprocess
from typing import Any, List, Optional, Union
from typing import TYPE_CHECKING, Any, List, Optional, Union

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
from airflow.models.connection import Connection


class SparkSqlHook(BaseHook):
"""
Expand All @@ -45,14 +48,16 @@ class SparkSqlHook(BaseHook):
:param keytab: Full path to the file that contains the keytab
:type keytab: str
:param master: spark://host:port, mesos://host:port, yarn, or local
(Default: The ``host`` and ``port`` set in the Connection, or ``"yarn"``)
:type master: str
:param name: Name of the job.
:type name: str
:param num_executors: Number of executors to launch
:type num_executors: int
:param verbose: Whether to pass the verbose flag to spark-sql
:type verbose: bool
:param yarn_queue: The YARN queue to submit to (Default: "default")
:param yarn_queue: The YARN queue to submit to
(Default: The ``queue`` value set in the Connection, or ``"default"``)
:type yarn_queue: str
"""

Expand All @@ -72,16 +77,35 @@ def __init__(
executor_memory: Optional[str] = None,
keytab: Optional[str] = None,
principal: Optional[str] = None,
master: str = 'yarn',
master: Optional[str] = None,
name: str = 'default-name',
num_executors: Optional[int] = None,
verbose: bool = True,
yarn_queue: str = 'default',
yarn_queue: Optional[str] = None,
) -> None:
super().__init__()

try:
conn: "Optional[Connection]" = self.get_connection(conn_id)
except AirflowNotFoundException:
conn = None
options = {}
else:
options = conn.extra_dejson

# Set arguments to values set in Connection if not explicitly provided.
if master is None:
if conn is None:
master = "yarn"
elif conn.port:
master = f"{conn.host}:{conn.port}"
else:
master = conn.host
if yarn_queue is None:
yarn_queue = options.get("queue", "default")

self._sql = sql
self._conf = conf
self._conn = self.get_connection(conn_id)
self._total_executor_cores = total_executor_cores
self._executor_cores = executor_cores
self._executor_memory = executor_memory
Expand Down
8 changes: 5 additions & 3 deletions airflow/providers/apache/spark/operators/spark_sql.py
Expand Up @@ -47,14 +47,16 @@ class SparkSqlOperator(BaseOperator):
:param keytab: Full path to the file that contains the keytab
:type keytab: str
:param master: spark://host:port, mesos://host:port, yarn, or local
(Default: The ``host`` and ``port`` set in the Connection, or ``"yarn"``)
:type master: str
:param name: Name of the job
:type name: str
:param num_executors: Number of executors to launch
:type num_executors: int
:param verbose: Whether to pass the verbose flag to spark-sql
:type verbose: bool
:param yarn_queue: The YARN queue to submit to (Default: "default")
:param yarn_queue: The YARN queue to submit to
(Default: The ``queue`` value set in the Connection, or ``"default"``)
:type yarn_queue: str
"""

Expand All @@ -73,11 +75,11 @@ def __init__(
executor_memory: Optional[str] = None,
keytab: Optional[str] = None,
principal: Optional[str] = None,
master: str = 'yarn',
master: Optional[str] = None,
name: str = 'default-name',
num_executors: Optional[int] = None,
verbose: bool = True,
yarn_queue: str = 'default',
yarn_queue: Optional[str] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand Down
Expand Up @@ -27,7 +27,7 @@ The Apache Spark connection type enables connection to Apache Spark.
Default Connection IDs
----------------------

Spark Submit and Spark JDBC hooks and operators use ``spark_default`` by default, Spark SQL hooks and operators point to ``spark_sql_default`` by default, but don't use it.
Spark Submit and Spark JDBC hooks and operators use ``spark_default`` by default. Spark SQL hooks and operators point to ``spark_sql_default`` by default.

Configuring the Connection
--------------------------
Expand Down
18 changes: 12 additions & 6 deletions tests/providers/apache/spark/hooks/test_spark_sql.py
Expand Up @@ -27,6 +27,7 @@
from airflow.models import Connection
from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook
from airflow.utils import db
from tests.test_utils.db import clear_db_connections


def get_after(sentinel, iterable):
Expand All @@ -49,10 +50,15 @@ class TestSparkSqlHook(unittest.TestCase):
'conf': 'key=value,PROP=VALUE',
}

def setUp(self):

@classmethod
def setUpClass(cls) -> None:
clear_db_connections(add_default_connections_back=False)
db.merge_conn(Connection(conn_id='spark_default', conn_type='spark', host='yarn://yarn-master'))

@classmethod
def tearDownClass(cls) -> None:
clear_db_connections(add_default_connections_back=True)

def test_build_command(self):
hook = SparkSqlHook(**self._config)

Expand Down Expand Up @@ -95,7 +101,7 @@ def test_spark_process_runcmd(self, mock_popen):
'-e',
'SELECT 1',
'--master',
'yarn',
'yarn://yarn-master',
'--name',
'default-name',
'--verbose',
Expand All @@ -112,7 +118,7 @@ def test_spark_process_runcmd(self, mock_popen):
'-e',
'SELECT 1',
'--master',
'yarn',
'yarn://yarn-master',
'--name',
'default-name',
'--verbose',
Expand All @@ -139,7 +145,7 @@ def test_spark_process_runcmd_with_str(self, mock_popen):
'-e',
'SELECT 1',
'--master',
'yarn',
'yarn://yarn-master',
'--name',
'default-name',
'--verbose',
Expand Down Expand Up @@ -168,7 +174,7 @@ def test_spark_process_runcmd_with_list(self, mock_popen):
'-e',
'SELECT 1',
'--master',
'yarn',
'yarn://yarn-master',
'--name',
'default-name',
'--verbose',
Expand Down

0 comments on commit 5c86e3d

Please sign in to comment.