diff --git a/flask_login/login_manager.py b/flask_login/login_manager.py index 5307094b..d77fbd81 100644 --- a/flask_login/login_manager.py +++ b/flask_login/login_manager.py @@ -21,7 +21,7 @@ from .signals import (user_loaded_from_cookie, user_loaded_from_header, user_loaded_from_request, user_unauthorized, user_needs_refresh, user_accessed, session_protected) -from .utils import (_get_user, login_url as make_login_url, _create_identifier, +from .utils import (login_url as make_login_url, _create_identifier, _user_context_processor, encode_cookie, decode_cookie, make_next_param, expand_login_view) @@ -75,17 +75,17 @@ def __init__(self, app=None, add_context_processor=True): #: and ``self.needs_refresh_message`` self.localize_callback = None - self.user_callback = None - self.unauthorized_callback = None self.needs_refresh_callback = None self.id_attribute = ID_ATTRIBUTE - self.header_callback = None + self._user_callback = None + + self._header_callback = None - self.request_callback = None + self._request_callback = None self._session_identifier_generator = _create_identifier @@ -185,7 +185,7 @@ def user_loader(self, callback): :param callback: The callback for retrieving a user object. :type callback: callable ''' - self.user_callback = callback + self._user_callback = callback return callback def header_loader(self, callback): @@ -200,7 +200,9 @@ def header_loader(self, callback): :param callback: The callback for retrieving a user object. :type callback: callable ''' - self.header_callback = callback + print('LoginManager.header_loader is deprecated. Use ' + + 'LoginManager.request_loader instead.') + self._header_callback = callback return callback def request_loader(self, callback): @@ -212,7 +214,7 @@ def request_loader(self, callback): :param callback: The callback for retrieving a user object. :type callback: callable ''' - self.request_callback = callback + self._request_callback = callback return callback def unauthorized_handler(self, callback): @@ -287,87 +289,64 @@ def needs_refresh(self): return redirect(redirect_url) - def reload_user(self, user=None): - ''' - This set the ctx.user with the user object loaded by your customized - user_loader callback function, which should retrieved the user object - with the user_id got from session. - - Syntax example: - from flask_login import LoginManager - @login_manager.user_loader - def any_valid_func_name(user_id): - # get your user object using the given user_id, - # if you use SQLAlchemy, for example: - user_obj = User.query.get(int(user_id)) - return user_obj - - Reason to let YOU define this self.user_callback: - Because we won't know how/where you will load you user object. - ''' - ctx = _request_ctx_stack.top + def _update_request_context_with_user(self, user=None): + '''Store the given user as ctx.user.''' - if user is None: - user_id = session.get('user_id') - if user_id is None: - ctx.user = self.anonymous_user() - else: - if self.user_callback is None: - raise Exception( - "No user_loader has been installed for this " - "LoginManager. Refer to " - "https://flask-login.readthedocs.io/" - "en/latest/#how-it-works for more info.") - user = self.user_callback(user_id) - if user is None: - ctx.user = self.anonymous_user() - else: - ctx.user = user - else: - ctx.user = user + ctx = _request_ctx_stack.top + ctx.user = self.anonymous_user() if user is None else user def _load_user(self): '''Loads user from session or remember_me cookie as applicable''' + + if self._user_callback is None and self._request_callback is None: + raise Exception( + "Missing user_loader or request_loader. Refer to " + "http://flask-login.readthedocs.io/#how-it-works " + "for more info.") + user_accessed.send(current_app._get_current_object()) - # first check SESSION_PROTECTION - config = current_app.config - if config.get('SESSION_PROTECTION', self.session_protection): - deleted = self._session_protection() - if deleted: - return self.reload_user() - - # If a remember cookie is set, and the session is not, move the - # cookie user ID to the session. - # - # However, the session may have been set if the user has been - # logged out on this request, 'remember' would be set to clear, - # so we should check for that and not restore the session. - is_missing_user_id = 'user_id' not in session - if is_missing_user_id: + # Check SESSION_PROTECTION + if self._session_protection_failed(): + return self._update_request_context_with_user() + + user = None + + # Load user from Flask Session + user_id = session.get('user_id') + if user_id is not None and self._user_callback is not None: + user = self._user_callback(user_id) + + # Load user from Remember Me Cookie or Request Loader + if user is None: + config = current_app.config cookie_name = config.get('REMEMBER_COOKIE_NAME', COOKIE_NAME) header_name = config.get('AUTH_HEADER_NAME', AUTH_HEADER_NAME) has_cookie = (cookie_name in request.cookies and session.get('remember') != 'clear') if has_cookie: - return self._load_from_cookie(request.cookies[cookie_name]) - elif self.request_callback: - return self._load_from_request(request) + cookie = request.cookies[cookie_name] + user = self._load_user_from_remember_cookie(cookie) + elif self._request_callback: + user = self._load_user_from_request(request) elif header_name in request.headers: - return self._load_from_header(request.headers[header_name]) + header = request.headers[header_name] + user = self._load_user_from_header(header) - return self.reload_user() + return self._update_request_context_with_user(user) - def _session_protection(self): + def _session_protection_failed(self): sess = session._get_current_object() ident = self._session_identifier_generator() app = current_app._get_current_object() mode = app.config.get('SESSION_PROTECTION', self.session_protection) + if not mode or mode not in ['basic', 'strong']: + return False + # if the sess is empty, it's an anonymous user or just logged out # so we can skip this - if sess and ident != sess.get('_id', None): if mode == 'basic' or sess.permanent: sess['_fresh'] = False @@ -383,39 +362,37 @@ def _session_protection(self): return False - def _load_from_cookie(self, cookie): + def _load_user_from_remember_cookie(self, cookie): user_id = decode_cookie(cookie) if user_id is not None: session['user_id'] = user_id session['_fresh'] = False - - self.reload_user() - - if _request_ctx_stack.top.user is not None: - app = current_app._get_current_object() - user_loaded_from_cookie.send(app, user=_get_user()) - - def _load_from_header(self, header): - user = None - if self.header_callback: - user = self.header_callback(header) - if user is not None: - self.reload_user(user=user) - app = current_app._get_current_object() - user_loaded_from_header.send(app, user=_get_user()) - else: - self.reload_user() - - def _load_from_request(self, request): - user = None - if self.request_callback: - user = self.request_callback(request) - if user is not None: - self.reload_user(user=user) - app = current_app._get_current_object() - user_loaded_from_request.send(app, user=_get_user()) - else: - self.reload_user() + user = None + if self._user_callback: + user = self._user_callback(user_id) + if user is not None: + app = current_app._get_current_object() + user_loaded_from_cookie.send(app, user=user) + return user + return None + + def _load_user_from_header(self, header): + if self._header_callback: + user = self._header_callback(header) + if user is not None: + app = current_app._get_current_object() + user_loaded_from_header.send(app, user=user) + return user + return None + + def _load_user_from_request(self, request): + if self._request_callback: + user = self._request_callback(request) + if user is not None: + app = current_app._get_current_object() + user_loaded_from_request.send(app, user=user) + return user + return None def _update_remember_cookie(self, response): # Don't modify the session unless there's something to do. diff --git a/flask_login/utils.py b/flask_login/utils.py index 623c4d29..904fb005 100644 --- a/flask_login/utils.py +++ b/flask_login/utils.py @@ -176,7 +176,7 @@ def login_user(user, remember=False, duration=None, force=False, fresh=True): raise Exception('duration must be a datetime.timedelta, ' 'instead got: {0}'.format(duration)) - _request_ctx_stack.top.user = user + current_app.login_manager._update_request_context_with_user(user) user_logged_in.send(current_app._get_current_object(), user=_get_user()) return True @@ -203,7 +203,7 @@ def logout_user(): user_logged_out.send(current_app._get_current_object(), user=user) - current_app.login_manager.reload_user() + current_app.login_manager._update_request_context_with_user() return True diff --git a/test_login.py b/test_login.py index be914864..c34089b5 100644 --- a/test_login.py +++ b/test_login.py @@ -117,7 +117,7 @@ def is_active(self): class AboutTestCase(unittest.TestCase): - """Make sure we can get version and other info.""" + '''Make sure we can get version and other info.''' def test_have_about_data(self): self.assertTrue(__title__ is not None) @@ -141,6 +141,10 @@ def test_static_loads_anonymous(self): lm = LoginManager() lm.init_app(app) + @lm.user_loader + def load_user(user_id): + return USERS[int(user_id)] + with app.test_client() as c: c.get('/static/favicon.ico') self.assertTrue(current_user.is_anonymous) @@ -152,6 +156,10 @@ def test_static_loads_without_accessing_session(self): lm = LoginManager() lm.init_app(app) + @lm.user_loader + def load_user(user_id): + return USERS[int(user_id)] + with app.test_client() as c: with listen_to(user_accessed) as listener: c.get('/static/favicon.ico') @@ -185,10 +193,9 @@ def test_no_user_loader_raises(self): with self.app.test_request_context(): session['user_id'] = '2' with self.assertRaises(Exception) as cm: - login_manager.reload_user() - expected_exception_message = 'No user_loader has been installed' - self.assertTrue( - str(cm.exception).startswith(expected_exception_message)) + login_manager._load_user() + expected_message = 'Missing user_loader or request_loader' + self.assertTrue(str(cm.exception).startswith(expected_message)) class MethodViewLoginTestCase(unittest.TestCase): @@ -318,10 +325,6 @@ def handle_404(e): unittest.TestCase.setUp(self) - def _get_remember_cookie(self, test_client): - our_cookies = test_client.cookie_jar._cookies['localhost.local']['/'] - return our_cookies[self.remember_cookie_name] - def _delete_session(self, c): # Helper method to cause the session to be deleted # as if the browser was closed. This will remove @@ -375,7 +378,7 @@ def test_login_inactive_user_forced(self): def test_login_user_with_header(self): user_id = 2 user_name = USERS[user_id].name - self.login_manager.request_callback = None + self.login_manager._request_callback = None with self.app.test_client() as c: basic_fmt = 'Basic {0}' decoded = bytes.decode(base64.b64encode(str.encode(str(user_id)))) @@ -386,7 +389,7 @@ def test_login_user_with_header(self): def test_login_invalid_user_with_header(self): user_id = 9000 user_name = u'Anonymous' - self.login_manager.request_callback = None + self.login_manager._request_callback = None with self.app.test_client() as c: basic_fmt = 'Basic {0}' decoded = bytes.decode(base64.b64encode(str.encode(str(user_id)))) @@ -849,7 +852,7 @@ def test_user_loaded_from_cookie_fired(self): def test_user_loaded_from_header_fired(self): user_id = 1 user_name = USERS[user_id].name - self.login_manager.request_callback = None + self.login_manager._request_callback = None with self.app.test_client() as c: with listen_to(user_loaded_from_header) as listener: headers = [ @@ -998,6 +1001,16 @@ def test_session_not_modified(self): # Ensure that if nothing changed the session is not modified. self.assertFalse(session.modified) + def test_invalid_remember_cookie(self): + domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local' + with self.app.test_client() as c: + c.get('/login-notch-remember') + with c.session_transaction() as sess: + sess['user_id'] = None + c.set_cookie(domain, self.remember_cookie_name, 'foo') + result = c.get('/username') + self.assertEqual(u'Anonymous', result.data.decode('utf-8')) + # # Session Protection # @@ -1246,6 +1259,117 @@ def test_user_context_processor(self): self.assertIsInstance(_ucp()['current_user'], AnonymousUserMixin) +class LoginViaRequestTestCase(unittest.TestCase): + ''' Tests for LoginManager.request_loader.''' + + def setUp(self): + self.app = Flask(__name__) + self.app.config['SECRET_KEY'] = 'deterministic' + self.app.config['SESSION_PROTECTION'] = None + self.remember_cookie_name = 'remember' + self.app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name + self.login_manager = LoginManager() + self.login_manager.init_app(self.app) + self.login_manager._login_disabled = False + + @self.app.route('/') + def index(): + return u'Welcome!' + + @self.app.route('/login-notch') + def login_notch(): + return unicode(login_user(notch)) + + @self.app.route('/username') + def username(): + if current_user.is_authenticated: + return current_user.name + return u'Anonymous', 401 + + @self.app.route('/logout') + def logout(): + return unicode(logout_user()) + + @self.login_manager.request_loader + def load_user_from_request(request): + user_id = request.args.get('user_id') or session.get('user_id') + try: + user_id = int(float(user_id)) + except TypeError: + pass + return USERS.get(user_id) + + # This will help us with the possibility of typoes in the tests. Now + # we shouldn't have to check each response to help us set up state + # (such as login pages) to make sure it worked: we will always + # get an exception raised (rather than return a 404 response) + @self.app.errorhandler(404) + def handle_404(e): + raise e + + unittest.TestCase.setUp(self) + + def test_has_no_user_loader_callback(self): + self.assertIsNone(self.login_manager._user_callback) + + def test_request_context_users_are_anonymous(self): + with self.app.test_request_context(): + self.assertTrue(current_user.is_anonymous) + + def test_defaults_anonymous(self): + with self.app.test_client() as c: + result = c.get('/username') + self.assertEqual(result.status_code, 401) + + def test_login_via_request(self): + user_id = 2 + user_name = USERS[user_id].name + with self.app.test_client() as c: + url = '/username?user_id={user_id}'.format(user_id=user_id) + result = c.get(url) + self.assertEqual(user_name, result.data.decode('utf-8')) + + def test_login_via_request_uses_cookie_when_already_logged_in(self): + user_id = 2 + user_name = notch.name + with self.app.test_client() as c: + c.get('/login-notch') + url = '/username' + result = c.get(url) + self.assertEqual(user_name, result.data.decode('utf-8')) + url = '/username?user_id={user_id}'.format(user_id=user_id) + result = c.get(url) + self.assertEqual(u'Steve', result.data.decode('utf-8')) + + def test_login_invalid_user_with_request(self): + user_id = 9000 + with self.app.test_client() as c: + url = '/username?user_id={user_id}'.format(user_id=user_id) + result = c.get(url) + self.assertEqual(result.status_code, 401) + + def test_login_invalid_user_with_request_when_already_logged_in(self): + user_id = 9000 + with self.app.test_client() as c: + url = '/login-notch' + result = c.get(url) + self.assertEqual(u'True', result.data.decode('utf-8')) + url = '/username?user_id={user_id}'.format(user_id=user_id) + result = c.get(url) + self.assertEqual(result.status_code, 401) + + def test_login_user_with_request_does_not_modify_session(self): + user_id = 2 + user_name = USERS[user_id].name + with self.app.test_client() as c: + url = '/username?user_id={user_id}'.format(user_id=user_id) + result = c.get(url) + self.assertEqual(user_name, result.data.decode('utf-8')) + url = '/username' + result = c.get(url) + self.assertEqual(u'Anonymous', result.data.decode('utf-8')) + + class TestLoginUrlGeneration(unittest.TestCase): def setUp(self): self.app = Flask(__name__)