diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy index 72f5e124c25..7fad9272fa0 100644 --- a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy @@ -16,19 +16,28 @@ package io.micronaut.http.server.netty.cors import io.micronaut.context.ApplicationContext +import io.micronaut.core.annotation.Nullable +import io.micronaut.core.async.publisher.Publishers +import io.micronaut.core.util.StringUtils import io.micronaut.http.* import io.micronaut.http.annotation.Controller import io.micronaut.http.annotation.Get +import io.micronaut.http.filter.ServerFilterChain import io.micronaut.http.server.HttpServerConfiguration import io.micronaut.http.server.cors.CorsFilter import io.micronaut.http.server.cors.CorsOriginConfiguration +import io.micronaut.http.server.util.HttpHostResolver import io.micronaut.runtime.server.EmbeddedServer import io.micronaut.web.router.RouteMatch import io.micronaut.web.router.Router +import io.micronaut.web.router.UriRouteMatch import org.apache.http.client.utils.URIBuilder +import org.reactivestreams.Publisher +import reactor.core.publisher.Mono import spock.lang.AutoCleanup import spock.lang.Shared import spock.lang.Specification +import spock.lang.Unroll import java.util.stream.Collectors @@ -36,70 +45,82 @@ import static io.micronaut.http.HttpHeaders.* class CorsFilterSpec extends Specification { - @Shared @AutoCleanup + @Shared + @AutoCleanup EmbeddedServer embeddedServer = ApplicationContext.run(EmbeddedServer) - CorsFilter buildCorsHandler(HttpServerConfiguration.CorsConfiguration config) { - new CorsFilter(config ?: new HttpServerConfiguration.CorsConfiguration()) - } - - void "test handleRequest for non CORS request"() { + void "non CORS request is passed through"() { given: - def config = new HttpServerConfiguration.CorsConfiguration() - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - headers.getOrigin() >> Optional.empty() + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration() CorsFilter corsHandler = buildCorsHandler(config) + HttpRequest request = createRequest(null as String) when: - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() then: "the request is passed through" - !result.isPresent() + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() + response.headers.names().isEmpty() } - void "test handleRequest with no matching configuration"() { + void "request with origin and no matching configuration"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.bar.com' + HttpRequest request = createRequest(origin) CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) CorsFilter corsHandler = buildCorsHandler(config) when: - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() then: "the request is passed through because no configuration matches the origin" - 2 * headers.getOrigin() >> Optional.of('http://www.bar.com') - !result.isPresent() + HttpStatus.OK == response.status() + response.headers.names().isEmpty() } - void "test handleRequest with regex matching configuration"() { + @Unroll + void "regex matching configuration"(List regex, String origin) { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers + HttpRequest request = createRequest(origin) request.getAttribute(HttpAttributes.ROUTE_MATCH, RouteMatch.class) >> Optional.empty() - def config = new HttpServerConfiguration.CorsConfiguration() CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = regex - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) CorsFilter corsHandler = buildCorsHandler(config) when: - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() - then: "the request is passed through because no configuration matches the origin" - 2 * headers.getOrigin() >> Optional.of(origin) - !result.isPresent() + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() + response.headers.names().size() == 3 + response.headers.find { it.key == 'Access-Control-Allow-Origin' } + response.headers.find { it.key == 'Vary' } + response.headers.find { it.key == 'Access-Control-Allow-Credentials' } + response.headers.find { it.key == 'Access-Control-Allow-Origin' }.value == [origin] + response.headers.find { it.key == 'Vary' }.value == ['Origin'] + response.headers.find { it.key == 'Access-Control-Allow-Credentials' }.value == [StringUtils.TRUE] where: regex | origin @@ -112,198 +133,251 @@ class CorsFilterSpec extends Specification { void "test handleRequest with disallowed method"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers + String origin = 'http://www.foo.com' + HttpRequest request = createRequest(origin) - def config = new HttpServerConfiguration.CorsConfiguration() CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] originConfig.allowedMethods = [HttpMethod.GET] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) when: - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() - then: "the request is rejected because the method is not in the list of allowedMethods" - 2 * headers.getOrigin() >> Optional.of('http://www.foo.com') - 1 * request.getMethod() >> HttpMethod.POST + then: result.isPresent() - result.get().status == HttpStatus.FORBIDDEN + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.FORBIDDEN == response.status() + response.headers.names().isEmpty() } - void "test handleRequest with disallowed header (not preflight)"() { + void "with disallowed header (not preflight) the request is passed through because allowed headers are only checked for preflight requests"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers + String origin = 'http://www.foo.com' + HttpRequest request = createRequest(origin) + request.getMethod() >> HttpMethod.GET - def config = new HttpServerConfiguration.CorsConfiguration() CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] originConfig.allowedMethods = [HttpMethod.GET] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) CorsFilter corsHandler = buildCorsHandler(config) when: - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() - then: "the request is passed through because allowed headers are only checked for preflight requests" - 2 * headers.getOrigin() >> Optional.of('http://www.foo.com') - 1 * request.getMethod() >> HttpMethod.GET - !result.isPresent() - 0 * headers.get(ACCESS_CONTROL_REQUEST_HEADERS, _) + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() + response.headers.names().size() == 3 + response.headers.find { it.key == 'Access-Control-Allow-Origin' } + response.headers.find { it.key == 'Vary' } + response.headers.find { it.key == 'Access-Control-Allow-Credentials' } + response.headers.find { it.key == 'Access-Control-Allow-Origin' }.value == ['http://www.foo.com'] + response.headers.find { it.key == 'Vary' }.value == ['Origin'] + response.headers.find { it.key == 'Access-Control-Allow-Credentials' }.value == [StringUtils.TRUE] } void "test preflight handleRequest with disallowed header"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.foo.com' + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of(origin) + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['foo', 'bar']) + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + } + HttpRequest request = createRequest(headers) + request.getMethod() >> HttpMethod.OPTIONS + request.getUri() >> new URIBuilder( '/example' ).build() + List> routes = embeddedServer.getApplicationContext().getBean(Router). + findAny(request.getUri().toString(), request) + .collect(Collectors.toList()) + + request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] originConfig.allowedMethods = [HttpMethod.GET] originConfig.allowedHeaders = ['foo'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - request.getMethod() >> HttpMethod.OPTIONS - def uri = new URIBuilder( '/example' ) - request.getUri() >> uri.build() - def routes = embeddedServer.getApplicationContext().getBean(Router). - findAny(request.getUri().toString(), request) - .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + when: + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - def result = corsHandler.handleRequest(request) + MutableHttpResponse response = result.get() then: "the request is rejected because bar is not allowed" - 2 * headers.getOrigin() >> Optional.of('http://www.foo.com') - 1 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) - 1 * headers.get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['foo', 'bar']) - result.get().status == HttpStatus.FORBIDDEN + HttpStatus.FORBIDDEN == response.status() } - void "test preflight handleRequest with allowed header"() { + void "test preflight with allowed header"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.foo.com' + + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of(origin) + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['foo']) + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + } + HttpRequest request = createRequest(headers) + request.getMethod() >> HttpMethod.OPTIONS + request.getUri() >> new URIBuilder( '/example' ).build() + List> routes = embeddedServer.getApplicationContext().getBean(Router). + findAny(request.getUri().toString(), request) + .collect(Collectors.toList()) + request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] originConfig.allowedMethods = [HttpMethod.GET] originConfig.allowedHeaders = ['foo', 'bar'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - request.getMethod() >> HttpMethod.OPTIONS - def uri = new URIBuilder( '/example' ) - request.getUri() >> uri.build() - def routes = embeddedServer.getApplicationContext().getBean(Router). - findAny(request.getUri().toString(), request) - .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - def result = corsHandler.handleRequest(request) - - then: "the request is successful" - 4 * headers.getOrigin() >> Optional.of('http://www.foo.com') - 2 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) - 2 * headers.get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['foo']) - result.get().status == HttpStatus.OK + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() + response.headers.names().size() == 6 + response.headers.find { it.key == 'Access-Control-Allow-Origin' } + response.headers.find { it.key == 'Vary' } + response.headers.find { it.key == 'Access-Control-Allow-Credentials' } + response.headers.find { it.key == 'Access-Control-Allow-Methods' } + response.headers.find { it.key == 'Access-Control-Allow-Headers' } + response.headers.find { it.key == 'Access-Control-Max-Age' } + response.headers.find { it.key == 'Access-Control-Allow-Origin' }.value == ['http://www.foo.com'] + response.headers.find { it.key == 'Vary' }.value == ['Origin'] + response.headers.find { it.key == 'Access-Control-Allow-Credentials' }.value == [StringUtils.TRUE] + response.headers.find { it.key == 'Access-Control-Allow-Methods' }.value == ['GET'] + response.headers.find { it.key == 'Access-Control-Allow-Headers' }.value == ['foo'] + response.headers.find { it.key == 'Access-Control-Max-Age' }.value == ['1800'] } void "test handleResponse when configuration not present"() { given: - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.bar.com' + HttpServerConfiguration.CorsConfiguration config = new HttpServerConfiguration.CorsConfiguration() CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + config.setConfigurations([foo: originConfig]) CorsFilter corsHandler = buildCorsHandler(config) - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of(origin) + } + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + } when: - def result = corsHandler.handleRequest(request) + Optional> result = corsHandler.handleRequest(request) then: "the response is not modified" - 2 * headers.getOrigin() >> Optional.of('http://www.bar.com') notThrown(NullPointerException) !result.isPresent() } - void "test handleResponse for normal request"() { + void "verify behaviour for normal request"() { given: - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.foo.com' + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of(origin) + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + } + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + } + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.exposedHeaders = ['Foo-Header', 'Bar-Header'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) CorsFilter corsHandler = buildCorsHandler(config) - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - headers.getOrigin() >> Optional.of('http://www.foo.com') when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() then: - !result.isPresent() + result.isPresent() when: - MutableHttpResponse response = HttpResponse.ok() - corsHandler.handleResponse(request, response) + MutableHttpResponse response = result.get() - then: "the response is not modified" + then: + HttpStatus.OK == response.status() + response.headers.names().size() == 5 response.getHeaders().get(ACCESS_CONTROL_ALLOW_ORIGIN) == 'http://www.foo.com' // The origin is echo'd response.getHeaders().get(VARY) == 'Origin' // The vary header is set response.getHeaders().getAll(ACCESS_CONTROL_EXPOSE_HEADERS) == ['Foo-Header', 'Bar-Header' ]// Expose headers are set from config response.getHeaders().get(ACCESS_CONTROL_ALLOW_CREDENTIALS) == 'true' // Allow credentials header is set + response.getHeaders().get(ACCESS_CONTROL_MAX_AGE) == '1800' } void "test handleResponse for preflight request"() { given: - def config = new HttpServerConfiguration.CorsConfiguration() + HttpHeaders headers = Stub(HttpHeaders) { + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['X-Header', 'Y-Header']) + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + getOrigin() >> Optional.of('http://www.foo.com') + } + URI uri = new URIBuilder('/example').build() + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + getMethod() >> HttpMethod.OPTIONS + getUri() >> uri + } + List> routes = embeddedServer.getApplicationContext().getBean(Router). + findAny(uri.toString(), request) + .collect(Collectors.toList()) + request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route -> route.getHttpMethod()).collect(Collectors.toList())) + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.exposedHeaders = ['Foo-Header', 'Bar-Header'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - headers.getOrigin() >> Optional.of('http://www.foo.com') - request.getMethod() >> HttpMethod.OPTIONS - def uri = new URIBuilder( '/example' ) - request.getUri() >> uri.build() - def routes = embeddedServer.getApplicationContext().getBean(Router). - findAny(request.getUri().toString(), request) - .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + when: + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - HttpResponse response = corsHandler.handleRequest(request).get() + MutableHttpResponse response = result.get() - then: "the response is not modified" - 2 * headers.get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['X-Header', 'Y-Header']) - 2 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + then: + HttpStatus.OK == response.status() + response.headers.names().size() == 7 response.getHeaders().get(ACCESS_CONTROL_ALLOW_METHODS) == 'GET' response.getHeaders().get(ACCESS_CONTROL_ALLOW_ORIGIN) == 'http://www.foo.com' // The origin is echo'd response.getHeaders().get(VARY) == 'Origin' // The vary header is set @@ -315,32 +389,44 @@ class CorsFilterSpec extends Specification { void "test handleResponse for preflight request with single header"() { given: - def config = new HttpServerConfiguration.CorsConfiguration(singleHeader: true) CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.exposedHeaders = ['Foo-Header', 'Bar-Header'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = new HttpServerConfiguration.CorsConfiguration(singleHeader: true, enabled: true) + config.setConfigurations([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - headers.getOrigin() >> Optional.of('http://www.foo.com') - request.getMethod() >> HttpMethod.OPTIONS - def uri = new URIBuilder( '/example' ) - request.getUri() >> uri.build() - def routes = embeddedServer.getApplicationContext().getBean(Router). + + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of('http://www.foo.com') + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['X-Header', 'Y-Header']) + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + } + URI uri = new URIBuilder( '/example' ).build() + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + getMethod() >> HttpMethod.OPTIONS + getUri() >> uri + } + List> routes = embeddedServer.getApplicationContext().getBean(Router). findAny(request.getUri().toString(), request) .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - HttpResponse response = corsHandler.handleRequest(request).get() + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() then: "the response is not modified" - 2 * headers.get(ACCESS_CONTROL_REQUEST_HEADERS, _) >> Optional.of(['X-Header', 'Y-Header']) - 2 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) response.getHeaders().get(ACCESS_CONTROL_ALLOW_METHODS) == 'GET' response.getHeaders().get(ACCESS_CONTROL_ALLOW_ORIGIN) == 'http://www.foo.com' // The origin is echo'd response.getHeaders().get(VARY) == 'Origin' // The vary header is set @@ -352,63 +438,84 @@ class CorsFilterSpec extends Specification { void "test preflight handleRequest on route that doesn't exists"() { given: - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - def uri = new URIBuilder( '/doesnt-exists-route' ) - request.getUri() >> uri.build() - def config = new HttpServerConfiguration.CorsConfiguration() + String origin = 'http://www.foo.com' + HttpHeaders headers = Stub(HttpHeaders) { + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) + getOrigin() >> Optional.of(origin) + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + } + URI uri = new URIBuilder( '/doesnt-exists-route' ).build() + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + getUri() >> uri + getMethod() >> HttpMethod.OPTIONS + } + List> routes = embeddedServer.getApplicationContext().getBean(Router). + findAny(uri.toString(), request) + .collect(Collectors.toList()) + request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.allowedOrigins = ['http://www.foo.com'] originConfig.allowedMethods = [HttpMethod.GET] originConfig.allowedHeaders = ['foo', 'bar'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - request.getMethod() >> HttpMethod.OPTIONS - def routes = embeddedServer.getApplicationContext().getBean(Router). - findAny(request.getUri().toString(), request) - .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + when: + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + + then: + result.isPresent() when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - 1 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.GET) - def result = corsHandler.handleRequest(request) + MutableHttpResponse response = result.get() - then: "the request is successful" - 2 * headers.getOrigin() >> Optional.of('http://www.foo.com') - !result.isPresent() + then: + HttpStatus.OK == response.status() } void "test preflight handleRequest on route that does exist but doesn't handle requested HTTP Method"() { given: - def config = new HttpServerConfiguration.CorsConfiguration() + CorsOriginConfiguration originConfig = new CorsOriginConfiguration() originConfig.exposedHeaders = ['Foo-Header', 'Bar-Header'] - config.configurations = new LinkedHashMap() - config.configurations.put('foo', originConfig) + + HttpServerConfiguration.CorsConfiguration config = enabledCorsConfiguration([foo: originConfig]) + CorsFilter corsHandler = buildCorsHandler(config) - HttpRequest request = Mock(HttpRequest) - HttpHeaders headers = Mock(HttpHeaders) - request.getHeaders() >> headers - headers.getOrigin() >> Optional.of('http://www.foo.com') - request.getMethod() >> HttpMethod.OPTIONS - def uri = new URIBuilder( '/example' ) - request.getUri() >> uri.build() - def routes = embeddedServer.getApplicationContext().getBean(Router). + + String origin = 'http://www.foo.com' + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.of(origin) + getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.POST) + contains(ACCESS_CONTROL_REQUEST_METHOD) >> true + } + URI uri = new URIBuilder( '/example' ).build() + HttpRequest request = Stub(HttpRequest) { + getHeaders() >> headers + getMethod() >> HttpMethod.OPTIONS + getUri() >> uri + } + + List> routes = embeddedServer.getApplicationContext().getBean(Router). findAny(request.getUri().toString(), request) .collect(Collectors.toList()) - request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) + when: - headers.contains(ACCESS_CONTROL_REQUEST_METHOD) >> true - def result = corsHandler.handleRequest(request) + Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() - then: "the request is successful" - 1 * headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, _) >> Optional.of(HttpMethod.POST) - !result.isPresent() + then: + result.isPresent() + + when: + MutableHttpResponse response = result.get() + + then: + HttpStatus.OK == response.status() } @Controller @@ -417,4 +524,43 @@ class CorsFilterSpec extends Specification { @Get("/example") String example() { return "Example"} } + + private HttpRequest createRequest(String originHeader) { + HttpHeaders headers = Stub(HttpHeaders) { + getOrigin() >> Optional.ofNullable(originHeader) + } + createRequest(headers) + } + + private HttpRequest createRequest(HttpHeaders headers) { + Stub(HttpRequest) { + getHeaders() >> headers + } + } + + private ServerFilterChain okChain() { + new ServerFilterChain() { + @Override + Publisher> proceed(HttpRequest req) { + Publishers.just(HttpResponse.ok()) + } + } + } + + private HttpServerConfiguration.CorsConfiguration enabledCorsConfiguration(Map corsConfigurationMap = null) { + HttpServerConfiguration.CorsConfiguration config = new HttpServerConfiguration.CorsConfiguration() { + @Override + boolean isEnabled() { + true + } + } + if (corsConfigurationMap != null) { + config.setConfigurations(corsConfigurationMap) + } + config + } + + private CorsFilter buildCorsHandler(HttpServerConfiguration.CorsConfiguration config) { + new CorsFilter(config ?: enabledCorsConfiguration()) + } } diff --git a/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java b/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java index 705b5c55fea..a7039180c88 100644 --- a/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java +++ b/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java @@ -104,8 +104,9 @@ protected void handleResponse(HttpRequest request, MutableHttpResponse res CorsOriginConfiguration config = optionalConfig.get(); if (CorsUtil.isPreflightRequest(request)) { - Optional result = headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, CONVERSION_CONTEXT_HTTP_METHOD); - setAllowMethods(result.get(), response); + headers.getFirst(ACCESS_CONTROL_REQUEST_METHOD, CONVERSION_CONTEXT_HTTP_METHOD) + .ifPresent(result -> setAllowMethods(result, response)); + Optional> allowedHeaders = headers.get(ACCESS_CONTROL_REQUEST_HEADERS, ConversionContext.LIST_OF_STRING); allowedHeaders.ifPresent(val -> setAllowHeaders(val, response)