Skip to content

Commit

Permalink
Merge pull request #553 from nscuro/request-id-filter
Browse files Browse the repository at this point in the history
Add filter to track request ID
  • Loading branch information
stevespringett committed May 5, 2024
2 parents e7736ee + 0f7c17c commit 7b725ec
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 0 deletions.
@@ -0,0 +1,59 @@
/*
* This file is part of Alpine.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
* Copyright (c) Steve Springett. All Rights Reserved.
*/
package alpine.server.filters;

import org.slf4j.MDC;

import javax.annotation.Priority;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.container.ContainerResponseContext;
import javax.ws.rs.container.ContainerResponseFilter;
import javax.ws.rs.ext.Provider;
import java.io.IOException;
import java.util.UUID;
import java.util.regex.Pattern;

@Provider
@Priority(1)
public class RequestIdFilter implements ContainerRequestFilter, ContainerResponseFilter {

private static final Pattern REQUEST_ID_PATTERN = Pattern.compile("^[A-Za-z0-9._\\-=+]{16,192}$");

@Override
public void filter(final ContainerRequestContext requestContext) throws IOException {
String requestId = requestContext.getHeaderString("X-Request-Id");
if (requestId == null || !REQUEST_ID_PATTERN.matcher(requestId).matches()) {
requestId = UUID.randomUUID().toString();
}

requestContext.setProperty("requestId", requestId);
MDC.put("requestId", requestId);
}

@Override
public void filter(final ContainerRequestContext requestContext, final ContainerResponseContext responseContext) throws IOException {
if (requestContext.getProperty("requestId") instanceof final String requestId) {
responseContext.getHeaders().putSingle("X-Request-Id", requestId);
}

MDC.remove("requestId");
}

}
@@ -0,0 +1,98 @@
/*
* This file is part of Alpine.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
* Copyright (c) Steve Springett. All Rights Reserved.
*/
package alpine.server.filters;

import org.assertj.core.api.SoftAssertions;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerResponseContext;
import javax.ws.rs.core.MultivaluedHashMap;
import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

public class RequestIdFilterTest {

private RequestIdFilter requestIdFilter;
private ContainerRequestContext requestContextMock;
private ContainerResponseContext responseContextMock;

@Before
public void setUp() {
requestIdFilter = new RequestIdFilter();
requestContextMock = mock(ContainerRequestContext.class);
responseContextMock = mock(ContainerResponseContext.class);
}

@Test
public void testProvidedRequestId() throws Exception {
final Map<String, Boolean> testCases = Map.ofEntries(
Map.entry("a".repeat(15), false),
Map.entry("a".repeat(16), true),
Map.entry("a".repeat(192), true),
Map.entry("a".repeat(193), false),
Map.entry("Zm9vYmFyYmF6cXV4cXV1eA==", true),
Map.entry("112bfb53-eb65-41b5-a093-b73902f43447", true),
Map.entry("foo%24bar%40baz%C2%A7", false)
);

final var softAssertions = new SoftAssertions();
for (final Map.Entry<String, Boolean> entry : testCases.entrySet()) {
final String providedRequestId = entry.getKey();
final boolean shouldTakeProvidedRequestId = entry.getValue();

doReturn(providedRequestId).when(requestContextMock).getHeaderString(eq("X-Request-Id"));
requestIdFilter.filter(requestContextMock);

final ArgumentCaptor<String> requestIdCaptor = ArgumentCaptor.forClass(String.class);
verify(requestContextMock).setProperty(eq("requestId"), requestIdCaptor.capture());
Mockito.reset(requestContextMock);

if (shouldTakeProvidedRequestId) {
softAssertions.assertThat(requestIdCaptor.getValue()).isEqualTo(providedRequestId);
} else {
softAssertions.assertThat(requestIdCaptor.getValue())
.matches("^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$");
}
}

softAssertions.assertAll();
}

@Test
public void testResponseHeader() throws Exception {
final var headers = new MultivaluedHashMap<String, Object>();
doReturn(headers).when(responseContextMock).getHeaders();

doReturn("foobarbazquxquux").when(requestContextMock).getProperty("requestId");
requestIdFilter.filter(requestContextMock, responseContextMock);

assertThat(headers).containsEntry("X-Request-Id", List.of("foobarbazquxquux"));
}

}

0 comments on commit 7b725ec

Please sign in to comment.