Skip to content

Commit

Permalink
Inject CORS headers even when server-side errors occur (#5632)
Browse files Browse the repository at this point in the history
Motivation:

There is an issue where CORS headers are not added when exceptions occur while using CorsService. 
[CorsService does not inject CORS headers into error responses](#5493)

Modifications:

Created CorsServerErrorHandler to inject CORS headers upon exceptions. Created CorsHeaderUtil and refactored CorsService, CorsPolicy.

Result:

CORS headers will be added to response headers when an exception is raised.
Co-authored-by: Trustin Lee <trustin@linecorp.com>
Co-authored-by: minwoox <songmw725@gmail.com>
  • Loading branch information
3 people committed May 13, 2024
1 parent 6eccf52 commit 450d5d5
Show file tree
Hide file tree
Showing 7 changed files with 426 additions and 147 deletions.
@@ -0,0 +1,209 @@
/*
* Copyright 2024 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License
*/

package com.linecorp.armeria.internal.server;

import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Joiner;
import com.google.common.base.Strings;

import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.ResponseHeadersBuilder;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.cors.CorsConfig;
import com.linecorp.armeria.server.cors.CorsPolicy;

import io.netty.util.AsciiString;

/**
* A utility class related to CORS headers.
*/
public final class CorsHeaderUtil {

private static final Logger logger = LoggerFactory.getLogger(CorsHeaderUtil.class);
public static final String ANY_ORIGIN = "*";
public static final String NULL_ORIGIN = "null";
public static final String DELIMITER = ",";
private static final Joiner HEADER_JOINER = Joiner.on(DELIMITER);

public static ResponseHeaders addCorsHeaders(ServiceRequestContext ctx, CorsConfig corsConfig,
ResponseHeaders responseHeaders) {
final HttpRequest httpRequest = ctx.request();
final ResponseHeadersBuilder responseHeadersBuilder = responseHeaders.toBuilder();

setCorsResponseHeaders(ctx, httpRequest, responseHeadersBuilder, corsConfig);

return responseHeadersBuilder.build();
}

/**
* Emit CORS headers if origin was found.
*
* @param req the HTTP request with the CORS info
* @param headers the headers to modify
*/
public static void setCorsResponseHeaders(ServiceRequestContext ctx, HttpRequest req,
ResponseHeadersBuilder headers, CorsConfig config) {
final CorsPolicy policy = setCorsOrigin(ctx, req, headers, config);
if (policy != null) {
setCorsAllowCredentials(headers, policy);
setCorsAllowHeaders(req.headers(), headers, policy);
setCorsExposeHeaders(headers, policy);
}
}

public static void setCorsAllowCredentials(ResponseHeadersBuilder headers, CorsPolicy policy) {
// The string "*" cannot be used for a resource that supports credentials.
// https://www.w3.org/TR/cors/#resource-requests
if (policy.isCredentialsAllowed() &&
!ANY_ORIGIN.equals(headers.get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN))) {
headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
}
}

private static void setCorsExposeHeaders(ResponseHeadersBuilder headers, CorsPolicy corsPolicy) {
if (corsPolicy.exposedHeaders().isEmpty()) {
return;
}

headers.set(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS, joinExposedHeaders(corsPolicy));
}

public static void setCorsAllowHeaders(RequestHeaders requestHeaders, ResponseHeadersBuilder headers,
CorsPolicy corsPolicy) {
if (corsPolicy.shouldAllowAllRequestHeaders()) {
final String header = requestHeaders.get(HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS);
if (!Strings.isNullOrEmpty(header)) {
headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, header);
}

return;
}

if (corsPolicy.allowedRequestHeaders().isEmpty()) {
return;
}

headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, joinAllowedRequestHeaders(corsPolicy));
}

/**
* Sets origin header according to the given CORS configuration and HTTP request.
*
* @param request the HTTP request
* @param headers the HTTP headers to modify
*
* @return {@code policy} if CORS configuration matches, otherwise {@code null}
*/
@Nullable
public static CorsPolicy setCorsOrigin(ServiceRequestContext ctx, HttpRequest request,
ResponseHeadersBuilder headers, CorsConfig config) {

final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
if (origin != null) {
final CorsPolicy policy = config.getPolicy(origin, ctx.routingContext());
if (policy == null) {
logger.debug(
"{} There is no CORS policy configured for the request origin '{}' and the path '{}'.",
ctx, origin, ctx.path());
return null;
}
if (NULL_ORIGIN.equals(origin)) {
setCorsNullOrigin(headers);
return policy;
}
if (config.isAnyOriginSupported()) {
if (policy.isCredentialsAllowed()) {
echoCorsRequestOrigin(request, headers);
addCorsVaryHeader(headers);
} else {
setCorsAnyOrigin(headers);
}
return policy;
}
setCorsOrigin(headers, origin);
addCorsVaryHeader(headers);
return policy;
}
return null;
}

private static void setCorsOrigin(ResponseHeadersBuilder headers, String origin) {
headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
}

private static void echoCorsRequestOrigin(HttpRequest request, ResponseHeadersBuilder headers) {
final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
if (origin != null) {
setCorsOrigin(headers, origin);
}
}

private static void addCorsVaryHeader(ResponseHeadersBuilder headers) {
headers.add(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN.toString());
}

private static void setCorsAnyOrigin(ResponseHeadersBuilder headers) {
setCorsOrigin(headers, ANY_ORIGIN);
}

private static void setCorsNullOrigin(ResponseHeadersBuilder headers) {
setCorsOrigin(headers, NULL_ORIGIN);
}

/**
* Joins the given set of headers into a single string.
* This can be useful for creating a comma-separated list of headers.
*
* @param headers The set of headers to be joined.
* @return A {@link String} representing the joined headers.
*/
private static String joinHeaders(Set<AsciiString> headers) {
return HEADER_JOINER.join(headers);
}

/**
* Joins the set of exposed headers into a single string.
* This method utilizes {@link #joinHeaders(Set)} to create a comma-separated list of exposed headers.
*
* @return A {@link String} representing the joined exposed headers.
*/
private static String joinExposedHeaders(CorsPolicy policy) {
return joinHeaders(policy.exposedHeaders());
}

/**
* Joins the set of allowed request headers into a single string.
* This method utilizes {@link #joinHeaders(Set)}
* to create a comma-separated list of allowed request headers.
*
* @return A {@link String} representing the joined allowed request headers.
*/
private static String joinAllowedRequestHeaders(CorsPolicy corsPolicy) {
return joinHeaders(corsPolicy.allowedRequestHeaders());
}

private CorsHeaderUtil() {
}
}
@@ -0,0 +1,97 @@
/*
* Copyright 2024 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License
*/

package com.linecorp.armeria.server;

import static com.linecorp.armeria.internal.server.CorsHeaderUtil.addCorsHeaders;

import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.server.cors.CorsConfig;
import com.linecorp.armeria.server.cors.CorsService;

/**
* wraps ServerErrorHandler for adding CORS headers to error responses.
*/
final class CorsServerErrorHandler implements ServerErrorHandler {
ServerErrorHandler serverErrorHandler;

/**
* Constructs a new {@link CorsServerErrorHandler} instance with a specified {@link ServerErrorHandler}.
* This handler is used to delegate server error handling for CORS-related errors.
*
* @param serverErrorHandler The {@link ServerErrorHandler} to be used for handling server errors.
*/
CorsServerErrorHandler(ServerErrorHandler serverErrorHandler) {
this.serverErrorHandler = serverErrorHandler;
}

@Override
public @Nullable AggregatedHttpResponse renderStatus(@Nullable ServiceRequestContext ctx,
ServiceConfig serviceConfig,
@Nullable RequestHeaders headers,
HttpStatus status, @Nullable String description,
@Nullable Throwable cause) {

if (ctx == null) {
return serverErrorHandler.renderStatus(null, serviceConfig, headers, status, description, cause);
}

final CorsService corsService = serviceConfig.service().as(CorsService.class);
if (corsService == null) {
return serverErrorHandler.renderStatus(ctx, serviceConfig, headers, status, description, cause);
}

final AggregatedHttpResponse res = serverErrorHandler.renderStatus(ctx, serviceConfig, headers, status,
description, cause);

if (res == null) {
return serverErrorHandler.renderStatus(ctx, serviceConfig, headers, status, description, cause);
}

final CorsConfig corsConfig = corsService.config();
final ResponseHeaders updatedResponseHeaders = addCorsHeaders(ctx, corsConfig,
res.headers());

return AggregatedHttpResponse.of(updatedResponseHeaders, res.content());
}

@Override
public @Nullable HttpResponse onServiceException(ServiceRequestContext ctx, Throwable cause) {
if (cause instanceof HttpResponseException) {
final HttpResponse oldRes = serverErrorHandler.onServiceException(ctx, cause);
if (oldRes == null) {
return null;
}
final CorsService corsService = ctx.config().service().as(CorsService.class);
if (corsService == null) {
return oldRes;
}
return oldRes
.recover(HttpResponseException.class,
ex -> ex.httpResponse()
.mapHeaders(oldHeaders -> addCorsHeaders(ctx,
corsService.config(),
oldHeaders)));
} else {
return serverErrorHandler.onServiceException(ctx, cause);
}
}
}
Expand Up @@ -2191,13 +2191,10 @@ private DefaultServerConfig buildServerConfig(List<ServerPort> serverPorts) {
unloggedExceptionsReporter = null;
}

final ServerErrorHandler errorHandler;
if (this.errorHandler == null) {
errorHandler = ServerErrorHandler.ofDefault();
} else {
// Ensure that ServerErrorHandler never returns null by falling back to the default.
errorHandler = this.errorHandler.orElse(ServerErrorHandler.ofDefault());
}
final ServerErrorHandler errorHandler =
new CorsServerErrorHandler(
this.errorHandler == null ? ServerErrorHandler.ofDefault()
: this.errorHandler.orElse(ServerErrorHandler.ofDefault()));
final VirtualHost defaultVirtualHost =
defaultVirtualHostBuilder.build(virtualHostTemplate, dependencyInjector,
unloggedExceptionsReporter, errorHandler);
Expand Down
Expand Up @@ -27,6 +27,7 @@
import com.google.common.collect.Iterables;

import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.internal.server.CorsHeaderUtil;
import com.linecorp.armeria.server.Route;
import com.linecorp.armeria.server.RoutingContext;

Expand Down Expand Up @@ -100,7 +101,7 @@ public CorsPolicy getPolicy(@Nullable String origin, RoutingContext routingConte
}

final String lowerCaseOrigin = Ascii.toLowerCase(origin);
final boolean isNullOrigin = CorsService.NULL_ORIGIN.equals(lowerCaseOrigin);
final boolean isNullOrigin = CorsHeaderUtil.NULL_ORIGIN.equals(lowerCaseOrigin);
for (final CorsPolicy policy : policies) {
if (isNullOrigin && policy.isNullOriginAllowed() &&
isPathMatched(policy, routingContext)) {
Expand Down

0 comments on commit 450d5d5

Please sign in to comment.