-
Notifications
You must be signed in to change notification settings - Fork 895
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Inject CORS headers even when server-side errors occur (#5632)
Motivation: There is an issue where CORS headers are not added when exceptions occur while using CorsService. [CorsService does not inject CORS headers into error responses](#5493) Modifications: Created CorsServerErrorHandler to inject CORS headers upon exceptions. Created CorsHeaderUtil and refactored CorsService, CorsPolicy. Result: CORS headers will be added to response headers when an exception is raised. Co-authored-by: Trustin Lee <trustin@linecorp.com> Co-authored-by: minwoox <songmw725@gmail.com>
- Loading branch information
1 parent
6eccf52
commit 450d5d5
Showing
7 changed files
with
426 additions
and
147 deletions.
There are no files selected for viewing
209 changes: 209 additions & 0 deletions
209
core/src/main/java/com/linecorp/armeria/internal/server/CorsHeaderUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
/* | ||
* Copyright 2024 LINE Corporation | ||
* | ||
* LINE Corporation licenses this file to you under the Apache License, | ||
* version 2.0 (the "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at: | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
* License for the specific language governing permissions and limitations | ||
* under the License | ||
*/ | ||
|
||
package com.linecorp.armeria.internal.server; | ||
|
||
import java.util.Set; | ||
|
||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import com.google.common.base.Joiner; | ||
import com.google.common.base.Strings; | ||
|
||
import com.linecorp.armeria.common.HttpHeaderNames; | ||
import com.linecorp.armeria.common.HttpRequest; | ||
import com.linecorp.armeria.common.RequestHeaders; | ||
import com.linecorp.armeria.common.ResponseHeaders; | ||
import com.linecorp.armeria.common.ResponseHeadersBuilder; | ||
import com.linecorp.armeria.common.annotation.Nullable; | ||
import com.linecorp.armeria.server.ServiceRequestContext; | ||
import com.linecorp.armeria.server.cors.CorsConfig; | ||
import com.linecorp.armeria.server.cors.CorsPolicy; | ||
|
||
import io.netty.util.AsciiString; | ||
|
||
/** | ||
* A utility class related to CORS headers. | ||
*/ | ||
public final class CorsHeaderUtil { | ||
|
||
private static final Logger logger = LoggerFactory.getLogger(CorsHeaderUtil.class); | ||
public static final String ANY_ORIGIN = "*"; | ||
public static final String NULL_ORIGIN = "null"; | ||
public static final String DELIMITER = ","; | ||
private static final Joiner HEADER_JOINER = Joiner.on(DELIMITER); | ||
|
||
public static ResponseHeaders addCorsHeaders(ServiceRequestContext ctx, CorsConfig corsConfig, | ||
ResponseHeaders responseHeaders) { | ||
final HttpRequest httpRequest = ctx.request(); | ||
final ResponseHeadersBuilder responseHeadersBuilder = responseHeaders.toBuilder(); | ||
|
||
setCorsResponseHeaders(ctx, httpRequest, responseHeadersBuilder, corsConfig); | ||
|
||
return responseHeadersBuilder.build(); | ||
} | ||
|
||
/** | ||
* Emit CORS headers if origin was found. | ||
* | ||
* @param req the HTTP request with the CORS info | ||
* @param headers the headers to modify | ||
*/ | ||
public static void setCorsResponseHeaders(ServiceRequestContext ctx, HttpRequest req, | ||
ResponseHeadersBuilder headers, CorsConfig config) { | ||
final CorsPolicy policy = setCorsOrigin(ctx, req, headers, config); | ||
if (policy != null) { | ||
setCorsAllowCredentials(headers, policy); | ||
setCorsAllowHeaders(req.headers(), headers, policy); | ||
setCorsExposeHeaders(headers, policy); | ||
} | ||
} | ||
|
||
public static void setCorsAllowCredentials(ResponseHeadersBuilder headers, CorsPolicy policy) { | ||
// The string "*" cannot be used for a resource that supports credentials. | ||
// https://www.w3.org/TR/cors/#resource-requests | ||
if (policy.isCredentialsAllowed() && | ||
!ANY_ORIGIN.equals(headers.get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN))) { | ||
headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); | ||
} | ||
} | ||
|
||
private static void setCorsExposeHeaders(ResponseHeadersBuilder headers, CorsPolicy corsPolicy) { | ||
if (corsPolicy.exposedHeaders().isEmpty()) { | ||
return; | ||
} | ||
|
||
headers.set(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS, joinExposedHeaders(corsPolicy)); | ||
} | ||
|
||
public static void setCorsAllowHeaders(RequestHeaders requestHeaders, ResponseHeadersBuilder headers, | ||
CorsPolicy corsPolicy) { | ||
if (corsPolicy.shouldAllowAllRequestHeaders()) { | ||
final String header = requestHeaders.get(HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS); | ||
if (!Strings.isNullOrEmpty(header)) { | ||
headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, header); | ||
} | ||
|
||
return; | ||
} | ||
|
||
if (corsPolicy.allowedRequestHeaders().isEmpty()) { | ||
return; | ||
} | ||
|
||
headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, joinAllowedRequestHeaders(corsPolicy)); | ||
} | ||
|
||
/** | ||
* Sets origin header according to the given CORS configuration and HTTP request. | ||
* | ||
* @param request the HTTP request | ||
* @param headers the HTTP headers to modify | ||
* | ||
* @return {@code policy} if CORS configuration matches, otherwise {@code null} | ||
*/ | ||
@Nullable | ||
public static CorsPolicy setCorsOrigin(ServiceRequestContext ctx, HttpRequest request, | ||
ResponseHeadersBuilder headers, CorsConfig config) { | ||
|
||
final String origin = request.headers().get(HttpHeaderNames.ORIGIN); | ||
if (origin != null) { | ||
final CorsPolicy policy = config.getPolicy(origin, ctx.routingContext()); | ||
if (policy == null) { | ||
logger.debug( | ||
"{} There is no CORS policy configured for the request origin '{}' and the path '{}'.", | ||
ctx, origin, ctx.path()); | ||
return null; | ||
} | ||
if (NULL_ORIGIN.equals(origin)) { | ||
setCorsNullOrigin(headers); | ||
return policy; | ||
} | ||
if (config.isAnyOriginSupported()) { | ||
if (policy.isCredentialsAllowed()) { | ||
echoCorsRequestOrigin(request, headers); | ||
addCorsVaryHeader(headers); | ||
} else { | ||
setCorsAnyOrigin(headers); | ||
} | ||
return policy; | ||
} | ||
setCorsOrigin(headers, origin); | ||
addCorsVaryHeader(headers); | ||
return policy; | ||
} | ||
return null; | ||
} | ||
|
||
private static void setCorsOrigin(ResponseHeadersBuilder headers, String origin) { | ||
headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin); | ||
} | ||
|
||
private static void echoCorsRequestOrigin(HttpRequest request, ResponseHeadersBuilder headers) { | ||
final String origin = request.headers().get(HttpHeaderNames.ORIGIN); | ||
if (origin != null) { | ||
setCorsOrigin(headers, origin); | ||
} | ||
} | ||
|
||
private static void addCorsVaryHeader(ResponseHeadersBuilder headers) { | ||
headers.add(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN.toString()); | ||
} | ||
|
||
private static void setCorsAnyOrigin(ResponseHeadersBuilder headers) { | ||
setCorsOrigin(headers, ANY_ORIGIN); | ||
} | ||
|
||
private static void setCorsNullOrigin(ResponseHeadersBuilder headers) { | ||
setCorsOrigin(headers, NULL_ORIGIN); | ||
} | ||
|
||
/** | ||
* Joins the given set of headers into a single string. | ||
* This can be useful for creating a comma-separated list of headers. | ||
* | ||
* @param headers The set of headers to be joined. | ||
* @return A {@link String} representing the joined headers. | ||
*/ | ||
private static String joinHeaders(Set<AsciiString> headers) { | ||
return HEADER_JOINER.join(headers); | ||
} | ||
|
||
/** | ||
* Joins the set of exposed headers into a single string. | ||
* This method utilizes {@link #joinHeaders(Set)} to create a comma-separated list of exposed headers. | ||
* | ||
* @return A {@link String} representing the joined exposed headers. | ||
*/ | ||
private static String joinExposedHeaders(CorsPolicy policy) { | ||
return joinHeaders(policy.exposedHeaders()); | ||
} | ||
|
||
/** | ||
* Joins the set of allowed request headers into a single string. | ||
* This method utilizes {@link #joinHeaders(Set)} | ||
* to create a comma-separated list of allowed request headers. | ||
* | ||
* @return A {@link String} representing the joined allowed request headers. | ||
*/ | ||
private static String joinAllowedRequestHeaders(CorsPolicy corsPolicy) { | ||
return joinHeaders(corsPolicy.allowedRequestHeaders()); | ||
} | ||
|
||
private CorsHeaderUtil() { | ||
} | ||
} |
97 changes: 97 additions & 0 deletions
97
core/src/main/java/com/linecorp/armeria/server/CorsServerErrorHandler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
/* | ||
* Copyright 2024 LINE Corporation | ||
* | ||
* LINE Corporation licenses this file to you under the Apache License, | ||
* version 2.0 (the "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at: | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
* License for the specific language governing permissions and limitations | ||
* under the License | ||
*/ | ||
|
||
package com.linecorp.armeria.server; | ||
|
||
import static com.linecorp.armeria.internal.server.CorsHeaderUtil.addCorsHeaders; | ||
|
||
import com.linecorp.armeria.common.AggregatedHttpResponse; | ||
import com.linecorp.armeria.common.HttpResponse; | ||
import com.linecorp.armeria.common.HttpStatus; | ||
import com.linecorp.armeria.common.RequestHeaders; | ||
import com.linecorp.armeria.common.ResponseHeaders; | ||
import com.linecorp.armeria.common.annotation.Nullable; | ||
import com.linecorp.armeria.server.cors.CorsConfig; | ||
import com.linecorp.armeria.server.cors.CorsService; | ||
|
||
/** | ||
* wraps ServerErrorHandler for adding CORS headers to error responses. | ||
*/ | ||
final class CorsServerErrorHandler implements ServerErrorHandler { | ||
ServerErrorHandler serverErrorHandler; | ||
|
||
/** | ||
* Constructs a new {@link CorsServerErrorHandler} instance with a specified {@link ServerErrorHandler}. | ||
* This handler is used to delegate server error handling for CORS-related errors. | ||
* | ||
* @param serverErrorHandler The {@link ServerErrorHandler} to be used for handling server errors. | ||
*/ | ||
CorsServerErrorHandler(ServerErrorHandler serverErrorHandler) { | ||
this.serverErrorHandler = serverErrorHandler; | ||
} | ||
|
||
@Override | ||
public @Nullable AggregatedHttpResponse renderStatus(@Nullable ServiceRequestContext ctx, | ||
ServiceConfig serviceConfig, | ||
@Nullable RequestHeaders headers, | ||
HttpStatus status, @Nullable String description, | ||
@Nullable Throwable cause) { | ||
|
||
if (ctx == null) { | ||
return serverErrorHandler.renderStatus(null, serviceConfig, headers, status, description, cause); | ||
} | ||
|
||
final CorsService corsService = serviceConfig.service().as(CorsService.class); | ||
if (corsService == null) { | ||
return serverErrorHandler.renderStatus(ctx, serviceConfig, headers, status, description, cause); | ||
} | ||
|
||
final AggregatedHttpResponse res = serverErrorHandler.renderStatus(ctx, serviceConfig, headers, status, | ||
description, cause); | ||
|
||
if (res == null) { | ||
return serverErrorHandler.renderStatus(ctx, serviceConfig, headers, status, description, cause); | ||
} | ||
|
||
final CorsConfig corsConfig = corsService.config(); | ||
final ResponseHeaders updatedResponseHeaders = addCorsHeaders(ctx, corsConfig, | ||
res.headers()); | ||
|
||
return AggregatedHttpResponse.of(updatedResponseHeaders, res.content()); | ||
} | ||
|
||
@Override | ||
public @Nullable HttpResponse onServiceException(ServiceRequestContext ctx, Throwable cause) { | ||
if (cause instanceof HttpResponseException) { | ||
final HttpResponse oldRes = serverErrorHandler.onServiceException(ctx, cause); | ||
if (oldRes == null) { | ||
return null; | ||
} | ||
final CorsService corsService = ctx.config().service().as(CorsService.class); | ||
if (corsService == null) { | ||
return oldRes; | ||
} | ||
return oldRes | ||
.recover(HttpResponseException.class, | ||
ex -> ex.httpResponse() | ||
.mapHeaders(oldHeaders -> addCorsHeaders(ctx, | ||
corsService.config(), | ||
oldHeaders))); | ||
} else { | ||
return serverErrorHandler.onServiceException(ctx, cause); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.