Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add always_send, enable by default #156

Merged
merged 2 commits into from May 23, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 29 additions & 9 deletions flask_cors/core.py
Expand Up @@ -39,7 +39,8 @@
'CORS_EXPOSE_HEADERS', 'CORS_SUPPORTS_CREDENTIALS',
'CORS_MAX_AGE', 'CORS_SEND_WILDCARD',
'CORS_AUTOMATIC_OPTIONS', 'CORS_VARY_HEADER',
'CORS_RESOURCES', 'CORS_INTERCEPT_EXCEPTIONS']
'CORS_RESOURCES', 'CORS_INTERCEPT_EXCEPTIONS',
'CORS_ALWAYS_SEND']
# Attribute added to request object by decorator to indicate that CORS
# was evaluated, in case the decorator and extension are both applied
# to a view.
Expand All @@ -58,7 +59,8 @@
automatic_options=True,
vary_header=True,
resources=r'/*',
intercept_exceptions=True)
intercept_exceptions=True,
always_send=True)


def parse_resources(resources):
Expand Down Expand Up @@ -108,7 +110,7 @@ def get_regexp_pattern(regexp):
return str(regexp)


def get_cors_origin(options, request_origin):
def get_cors_origins(options, request_origin):
origins = options.get('origins')
wildcard = r'.*' in origins

Expand All @@ -120,18 +122,32 @@ def get_cors_origin(options, request_origin):
# If the allowed origins is an asterisk or 'wildcard', always match
if wildcard and options.get('send_wildcard'):
LOG.debug("Allowed origins are set to '*'. Sending wildcard CORS header.")
return '*'
return ['*']
# If the value of the Origin header is a case-sensitive match
# for any of the values in list of origins
elif try_match_any(request_origin, origins):
LOG.debug("The request's Origin header matches. Sending CORS headers.", )
# Add a single Access-Control-Allow-Origin header, with either
# the value of the Origin header or the string "*" as value.
# -- W3Spec
return request_origin
return [request_origin]
else:
LOG.debug("The request's Origin header does not match any of allowed origins.")
return None


elif options.get('always_send'):
if wildcard:
# If wildcard is in the origins, even if 'send_wildcard' is False,
# simply send the wildcard. It is the most-likely to be correct
# thing to do (the only other option is to return nothing, which)
# pretty is probably not whawt you want if you specify origins as
# '*'
return ['*']
else:
# Return all origins that are not regexes.
return sorted([o for o in origins if not probably_regex(o)])

# Terminate these steps, return the original request untouched.
else:
LOG.debug("The request did not contain an 'Origin' header. This means the browser or client did not request CORS, ensure the Origin Header is set.")
Expand All @@ -154,13 +170,15 @@ def get_allow_headers(options, acl_request_headers):


def get_cors_headers(options, request_headers, request_method, response_headers):
origin_to_set = get_cors_origin(options, request_headers.get('Origin'))
origins_to_set = get_cors_origins(options, request_headers.get('Origin'))
headers = MultiDict()

if origin_to_set is None: # CORS is not enabled for this route
if not origins_to_set: # CORS is not enabled for this route
return headers

headers[ACL_ORIGIN] = origin_to_set
for origin in origins_to_set:
headers.add(ACL_ORIGIN, origin)

headers[ACL_EXPOSE_HEADERS] = options.get('expose_headers')

if options.get('supports_credentials'):
Expand Down Expand Up @@ -191,7 +209,9 @@ def get_cors_headers(options, request_headers, request_method, response_headers)
# origins that can be matched.
if headers[ACL_ORIGIN] == '*':
pass
elif len(options.get('origins')) > 1 or any(map(probably_regex, options.get('origins'))):
elif (len(options.get('origins')) > 1 or
len(origins_to_set) > 1 or
any(map(probably_regex, options.get('origins')))):
headers.add('Vary', 'Origin')

return MultiDict((k, v) for k, v in headers.items() if v)
Expand Down
2 changes: 1 addition & 1 deletion flask_cors/version.py
@@ -1 +1 @@
__version__ = '2.1.3'
__version__ = '3.0.0'
3 changes: 0 additions & 3 deletions tests/decorator/test_credentials.py
Expand Up @@ -42,9 +42,6 @@ def test_credentials_supported(self):
resp = self.get('/test_credentials_supported', origin='www.example.com')
self.assertEquals(resp.headers.get(ACL_CREDENTIALS), 'true')

resp = self.get('/test_credentials_supported')
self.assertEquals(resp.headers.get(ACL_CREDENTIALS), None )

def test_default(self):
''' The default behavior should be to disallow credentials.
'''
Expand Down
151 changes: 32 additions & 119 deletions tests/decorator/test_origins.py
Expand Up @@ -17,7 +17,6 @@

letters = 'abcdefghijklmnopqrstuvwxyz' # string.letters is not PY3 compatible


class OriginsTestCase(FlaskCorsTestCase):
def setUp(self):
self.app = Flask(__name__)
Expand All @@ -27,9 +26,19 @@ def setUp(self):
def wildcard():
return 'Welcome!'

@self.app.route('/test_always_send')
@cross_origin(always_send=True)
def test_always_send():
return 'Welcome!'

@self.app.route('/test_always_send_no_wildcard')
@cross_origin(always_send=True, send_wildcard=False)
def test_always_send_no_wildcard():
return 'Welcome!'

@self.app.route('/test_send_wildcard_with_origin')
@cross_origin(send_wildcard=True)
def send_wildcard():
def test_send_wildcard_with_origin():
return 'Welcome!'

@self.app.route('/test_list')
Expand All @@ -49,31 +58,30 @@ def test_set():

@self.app.route('/test_subdomain_regex')
@cross_origin(origins=r"http?://\w*\.?example\.com:?\d*/?.*")
def _test_subdomain_regex():
def test_subdomain_regex():
return ''

@self.app.route('/test_compiled_subdomain_regex')
@cross_origin(origins=re.compile(r"http?://\w*\.?example\.com:?\d*/?.*"))
def _test_compiled_subdomain_regex():
def test_compiled_subdomain_regex():
return ''

@self.app.route('/test_regex_list')
@cross_origin(origins=[r".*.example.com", r".*.otherexample.com"])
def _test_regex_list():
def test_regex_list():
return ''

@self.app.route('/test_regex_mixed_list')
@cross_origin(origins=["http://example.com", r".*.otherexample.com"])
def _test_regex_mixed_list():
def test_regex_mixed_list():
return ''

def test_defaults_no_origin(self):
''' If there is no Origin header in the request, the
Access-Control-Allow-Origin header should not be included,
according to the w3 spec.
Access-Control-Allow-Origin header should be '*' by default.
'''
for resp in self.iter_responses('/'):
self.assertEqual(resp.headers.get(ACL_ORIGIN), None)
self.assertEqual(resp.headers.get(ACL_ORIGIN), '*')

def test_defaults_with_origin(self):
''' If there is an Origin header in the request, the
Expand All @@ -83,6 +91,21 @@ def test_defaults_with_origin(self):
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.headers.get(ACL_ORIGIN), 'http://example.com')

def test_always_send_no_wildcard(self):
'''
If send_wildcard=False, but the there is '*' in the
allowed origins, we should send it anyways.
'''
for resp in self.iter_responses('/'):
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.headers.get(ACL_ORIGIN), '*')

def test_always_send_no_wildcard_origins(self):
for resp in self.iter_responses('/'):
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.headers.get(ACL_ORIGIN), '*')


def test_send_wildcard_with_origin(self):
''' If there is an Origin header in the request, the
Access-Control-Allow-Origin header should be included.
Expand Down Expand Up @@ -166,115 +189,5 @@ def test_regex_mixed_list(self):
self.get('/test_regex_mixed_list', origin='http://example.com').headers.get(ACL_ORIGIN))


class AppConfigOriginsTestCase(AppConfigTest, OriginsTestCase):
def __init__(self, *args, **kwargs):
super(AppConfigOriginsTestCase, self).__init__(*args, **kwargs)

def test_defaults_no_origin(self):
@self.app.route('/')
@cross_origin()
def wildcard():
return 'Welcome!'

super(AppConfigOriginsTestCase, self).test_defaults_no_origin()

def test_defaults_with_origin(self):
@self.app.route('/')
@cross_origin()
def wildcard():
return 'Welcome!'
super(AppConfigOriginsTestCase, self).test_defaults_with_origin()

def test_send_wildcard_with_origin(self):
@self.app.route('/test_send_wildcard_with_origin')
@cross_origin(send_wildcard=True)
def send_wildcard():
return 'Welcome!'
super(AppConfigOriginsTestCase, self).test_send_wildcard_with_origin()

def test_list_serialized(self):
self.app.config['CORS_ORIGINS'] = ["http://foo.com", "http://bar.com"]

@self.app.route('/test_list')
@cross_origin()
def test_list():
return 'Welcome!'

super(AppConfigOriginsTestCase, self).test_list_serialized()

def test_string_serialized(self):
self.app.config['CORS_ORIGINS'] = "http://foo.com"

@self.app.route('/test_string')
@cross_origin()
def test_string():
return 'Welcome!'

super(AppConfigOriginsTestCase, self).test_string_serialized()

def test_set_serialized(self):
self.app.config['CORS_ORIGINS'] = set(["http://foo.com",
"http://bar.com"])

@self.app.route('/test_set')
@cross_origin()
def test_set():
return 'Welcome!'

super(AppConfigOriginsTestCase, self).test_set_serialized()

def test_not_matching_origins(self):
self.app.config['CORS_ORIGINS'] = ["http://foo.com", "http://bar.com"]

@self.app.route('/test_list')
@cross_origin()
def test_list():
return 'Welcome!'

super(AppConfigOriginsTestCase, self).test_not_matching_origins()

def test_regex_list(self):
@self.app.route('/test_regex_list')
@cross_origin()
def _test_regex_list():
return 'Welcome!'

self.app.config['CORS_ORIGINS'] = [r".*.example.com",
r".*.otherexample.com"]
super(AppConfigOriginsTestCase, self).test_regex_list()

def test_subdomain_regex(self):
self.app.config['CORS_ORIGINS'] = r"http?://\w*\.?example\.com:?\d*/?.*"

@self.app.route('/test_subdomain_regex')
@cross_origin()
def _test_subdomain_regex():
return ''

super(AppConfigOriginsTestCase, self).test_subdomain_regex()

def test_compiled_subdomain_regex(self):
self.app.config['CORS_ORIGINS'] = r"http?://\w*\.?example\.com:?\d*/?.*"

@self.app.route('/test_compiled_subdomain_regex')
@cross_origin()
def _test_compiled_subdomain_regex():
return ''

super(AppConfigOriginsTestCase, self).test_compiled_subdomain_regex()

def test_regex_mixed_list(self):
self.app.config['CORS_ORIGINS'] = ["http://example.com",
r".*.otherexample.com"]

@self.app.route('/test_regex_mixed_list')
@cross_origin()
def _test_regex_mixed_list():
return ''

super(AppConfigOriginsTestCase, self).test_regex_mixed_list()



if __name__ == "__main__":
unittest.main()