Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check user is active #26635

Merged
merged 2 commits into from Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion airflow/www/app.py
Expand Up @@ -39,7 +39,11 @@
from airflow.www.extensions.init_jinja_globals import init_jinja_globals
from airflow.www.extensions.init_manifest_files import configure_manifest_files
from airflow.www.extensions.init_robots import init_robots
from airflow.www.extensions.init_security import init_api_experimental_auth, init_xframe_protection
from airflow.www.extensions.init_security import (
init_api_experimental_auth,
init_check_user_active,
init_xframe_protection,
)
from airflow.www.extensions.init_session import init_airflow_session_interface
from airflow.www.extensions.init_views import (
init_api_connexion,
Expand Down Expand Up @@ -152,6 +156,7 @@ def create_app(config=None, testing=False):
init_jinja_globals(flask_app)
init_xframe_protection(flask_app)
init_airflow_session_interface(flask_app)
init_check_user_active(flask_app)
return flask_app


Expand Down
11 changes: 11 additions & 0 deletions airflow/www/extensions/init_security.py
Expand Up @@ -19,6 +19,9 @@
import logging
from importlib import import_module

from flask import g, redirect, url_for
from flask_login import logout_user

from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, AirflowException

Expand Down Expand Up @@ -60,3 +63,11 @@ def init_api_experimental_auth(app):
except ImportError as err:
log.critical("Cannot import %s for API authentication due to: %s", backend, err)
raise AirflowException(err)


def init_check_user_active(app):
@app.before_request
def check_user_active():
if g.user is not None and not g.user.is_anonymous and not g.user.is_active:
logout_user()
return redirect(url_for(app.appbuilder.sm.auth_view.endpoint + ".login"))
1 change: 1 addition & 0 deletions tests/test_utils/decorators.py
Expand Up @@ -45,6 +45,7 @@ def no_op(*args, **kwargs):
"init_xframe_protection",
"init_airflow_session_interface",
"init_appbuilder",
"init_check_user_active",
]

@functools.wraps(f)
Expand Down
1 change: 1 addition & 0 deletions tests/www/views/conftest.py
Expand Up @@ -58,6 +58,7 @@ def app(examples_dag_bag):
"init_jinja_globals",
"init_plugins",
"init_airflow_session_interface",
"init_check_user_active",
]
)
def factory():
Expand Down
14 changes: 14 additions & 0 deletions tests/www/views/test_session.py
Expand Up @@ -88,3 +88,17 @@ def test_session_id_rotates(app, user_client):
new_session_cookie = get_session_cookie(user_client)
assert new_session_cookie is not None
assert old_session_cookie.value != new_session_cookie.value


def test_check_active_user(app, user_client):
user = app.appbuilder.sm.find_user(username="test_user")
user.active = False
resp = user_client.get("/home")
assert resp.status_code == 302
assert "/login" in resp.headers.get("Location")

# And they were logged out
user.active = True
resp = user_client.get("/home")
assert resp.status_code == 302
assert "/login" in resp.headers.get("Location")
13 changes: 11 additions & 2 deletions tests/www/views/test_views_base.py
Expand Up @@ -30,9 +30,18 @@
from tests.test_utils.www import check_content_in_response, check_content_not_in_response


def test_index(admin_client):
def test_index_redirect(admin_client):
resp = admin_client.get('/')
assert resp.status_code == 302
assert '/home' in resp.headers.get("Location")

resp = admin_client.get('/', follow_redirects=True)
check_content_in_response('DAGs', resp)


def test_homepage_query_count(admin_client):
with assert_queries_count(16):
resp = admin_client.get('/', follow_redirects=True)
resp = admin_client.get('/home')
check_content_in_response('DAGs', resp)


Expand Down