Skip to content

Commit

Permalink
fix: Jetty WS requests and responses intercepted
Browse files Browse the repository at this point in the history
Signed-off-by: Marc Nuri <marc@marcnuri.com>
  • Loading branch information
manusa committed Nov 30, 2022
1 parent 97fcaca commit c134637
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 33 deletions.
Expand Up @@ -268,7 +268,7 @@ public <T> CompletableFuture<AsyncResponse<T>> sendAsync(HttpRequest request,
JdkHttpRequestImpl jdkRequest = (JdkHttpRequestImpl) request;
JdkHttpRequestImpl.BuilderImpl builderImpl = jdkRequest.newBuilder();
for (Interceptor interceptor : builder.getInterceptors().values()) {
Interceptor.useConfig(interceptor, config).before(builderImpl, jdkRequest);
Interceptor.useConfig(config).apply(interceptor).before(builderImpl, jdkRequest);
jdkRequest = builderImpl.build();
}

Expand All @@ -281,7 +281,7 @@ public <T> CompletableFuture<AsyncResponse<T>> sendAsync(HttpRequest request,
cf = cf.thenCompose(ar -> {
java.net.http.HttpResponse<T> response = ar.response;
if (response != null && !HttpResponse.isSuccessful(response.statusCode())) {
return Interceptor.useConfig(interceptor, config).afterFailure(builderImpl, new JdkHttpResponseImpl<>(response))
return Interceptor.useConfig(config).apply(interceptor).afterFailure(builderImpl, new JdkHttpResponseImpl<>(response))
.thenCompose(b -> {
if (b) {
HandlerAndAsyncBody<T> interceptedHandlerAndAsyncBody = handlerAndAsyncBodySupplier.get();
Expand Down Expand Up @@ -326,9 +326,8 @@ public WebSocketResponse(WebSocket w, java.net.http.WebSocketHandshakeException
public CompletableFuture<WebSocket> buildAsync(JdkWebSocketImpl.BuilderImpl webSocketBuilder, Listener listener) {
JdkWebSocketImpl.BuilderImpl copy = webSocketBuilder.copy();

for (Interceptor interceptor : builder.getInterceptors().values()) {
Interceptor.useConfig(interceptor, config).before(copy, new JdkHttpRequestImpl(null, copy.asRequest()));
}
builder.getInterceptors().values().stream().map(Interceptor.useConfig(config))
.forEach(i -> i.before(copy, new JdkHttpRequestImpl(null, copy.asRequest())));

CompletableFuture<WebSocket> result = new CompletableFuture<>();

Expand All @@ -337,7 +336,7 @@ public CompletableFuture<WebSocket> buildAsync(JdkWebSocketImpl.BuilderImpl webS
for (Interceptor interceptor : builder.getInterceptors().values()) {
cf = cf.thenCompose(response -> {
if (response.wshse != null && response.wshse.getResponse() != null) {
return Interceptor.useConfig(interceptor, config)
return Interceptor.useConfig(config).apply(interceptor)
.afterFailure(copy, new JdkHttpResponseImpl<>(response.wshse.getResponse())).thenCompose(b -> {
if (b) {
return this.internalBuildAsync(copy, listener);
Expand Down
Expand Up @@ -30,12 +30,14 @@
import org.eclipse.jetty.websocket.client.WebSocketClient;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;

import static io.fabric8.kubernetes.client.http.BufferUtil.copy;
import static io.fabric8.kubernetes.client.http.StandardMediaTypes.APPLICATION_OCTET_STREAM;
Expand Down Expand Up @@ -92,7 +94,9 @@ protected void onContent(ByteBuffer content) throws Exception {

@Override
public WebSocket.Builder newWebSocketBuilder() {
return new JettyWebSocketBuilder(jettyWs, builder.getReadTimeout());
return new JettyWebSocketBuilder(
jettyWs, builder.getReadTimeout(),
interceptors.stream().map(Interceptor.useConfig(config)).collect(Collectors.toCollection(ArrayList::new)));
}

@Override
Expand All @@ -111,7 +115,7 @@ private Request newRequest(StandardHttpRequest originalRequest) {
throw KubernetesClientException.launderThrowable(e);
}
final var requestBuilder = originalRequest.toBuilder();
interceptors.forEach(i -> Interceptor.useConfig(i, config).before(requestBuilder, originalRequest));
interceptors.stream().map(Interceptor.useConfig(config)).forEach(i -> i.before(requestBuilder, originalRequest));
final var request = requestBuilder.build();

var jettyRequest = jetty.newRequest(request.uri()).method(request.method());
Expand All @@ -133,7 +137,7 @@ private <T> CompletableFuture<HttpResponse<T>> interceptResponse(
for (var interceptor : interceptors) {
originalResponse = originalResponse.thenCompose(r -> {
if (!r.isSuccessful()) {
return Interceptor.useConfig(interceptor, config).afterFailure(builder, r)
return Interceptor.useConfig(config).apply(interceptor).afterFailure(builder, r)
.thenCompose(b -> {
if (Boolean.TRUE.equals(b)) {
return function.apply(builder.build());
Expand Down
Expand Up @@ -17,6 +17,8 @@

import io.fabric8.kubernetes.client.KubernetesClientException;
import io.fabric8.kubernetes.client.http.AbstractBasicBuilder;
import io.fabric8.kubernetes.client.http.Interceptor;
import io.fabric8.kubernetes.client.http.StandardHttpHeaders;
import io.fabric8.kubernetes.client.http.StandardHttpRequest;
import io.fabric8.kubernetes.client.http.WebSocket;
import io.fabric8.kubernetes.client.http.WebSocketHandshakeException;
Expand All @@ -27,6 +29,8 @@
import org.eclipse.jetty.websocket.client.WebSocketClient;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
Expand All @@ -37,27 +41,32 @@ public class JettyWebSocketBuilder extends AbstractBasicBuilder<JettyWebSocketBu

private final WebSocketClient webSocketClient;
private final Duration handshakeTimeout;
private final Collection<Interceptor> interceptors;
private String subprotocol;

public JettyWebSocketBuilder(WebSocketClient webSocketClient, Duration handshakeTimeout) {
public JettyWebSocketBuilder(
WebSocketClient webSocketClient, Duration handshakeTimeout, Collection<Interceptor> interceptors) {
this.webSocketClient = webSocketClient;
this.handshakeTimeout = handshakeTimeout;
this.interceptors = interceptors;
}

@Override
public CompletableFuture<WebSocket> buildAsync(WebSocket.Listener listener) {
try {
webSocketClient.start();
final var requestBuilder = copy(this);
interceptors.forEach(i -> i.before(requestBuilder, new StandardHttpHeaders(requestBuilder.getHeaders())));
final ClientUpgradeRequest cur = new ClientUpgradeRequest();
if (Utils.isNotNullOrEmpty(subprotocol)) {
cur.setSubProtocols(subprotocol);
if (Utils.isNotNullOrEmpty(requestBuilder.subprotocol)) {
cur.setSubProtocols(requestBuilder.subprotocol);
}
cur.setHeaders(getHeaders());
cur.setTimeout(handshakeTimeout.toMillis(), TimeUnit.MILLISECONDS);
cur.setHeaders(requestBuilder.getHeaders());
cur.setTimeout(requestBuilder.handshakeTimeout.toMillis(), TimeUnit.MILLISECONDS);
// Extra-future required because we can't Map the UpgradeException to a WebSocketHandshakeException easily
final CompletableFuture<WebSocket> future = new CompletableFuture<>();
final var webSocket = new JettyWebSocket(listener);
return webSocketClient.connect(webSocket, Objects.requireNonNull(WebSocket.toWebSocketUri(getUri())), cur)
return webSocketClient.connect(webSocket, Objects.requireNonNull(WebSocket.toWebSocketUri(requestBuilder.getUri())), cur)
.thenApply(s -> webSocket)
.exceptionally(ex -> {
if (ex instanceof CompletionException && ex.getCause() instanceof UpgradeException) {
Expand Down Expand Up @@ -91,4 +100,13 @@ private static WebSocketHandshakeException toHandshakeException(UpgradeException
null))
.initCause(ex);
}

private static JettyWebSocketBuilder copy(JettyWebSocketBuilder original) {
final var copy = new JettyWebSocketBuilder(
original.webSocketClient, original.handshakeTimeout, new ArrayList<>(original.interceptors));
copy.uri(original.getUri());
original.getHeaders().forEach((h, values) -> values.forEach(v -> copy.header(h, v)));
copy.subprotocol(original.subprotocol);
return copy;
}
}
Expand Up @@ -57,7 +57,7 @@ void buildAsyncConnectsAndUpgrades() throws Exception {
.done()
.always();
final var open = new AtomicBoolean(false);
new JettyWebSocketBuilder(new WebSocketClient(new HttpClient()), Duration.ZERO)
new JettyWebSocketBuilder(new WebSocketClient(new HttpClient()), Duration.ZERO, Collections.emptyList())
.uri(URI.create(server.url("/websocket-test")))
.buildAsync(new WebSocket.Listener() {
@Override
Expand All @@ -71,7 +71,7 @@ public void onOpen(WebSocket webSocket) {
@Test
void buildAsyncCantUpgradeThrowsWebSocketHandshakeException() {
final var result = assertThrows(ExecutionException.class,
() -> new JettyWebSocketBuilder(new WebSocketClient(new HttpClient()), Duration.ZERO)
() -> new JettyWebSocketBuilder(new WebSocketClient(new HttpClient()), Duration.ZERO, Collections.emptyList())
.uri(URI.create(server.url("/not-found")))
.buildAsync(new WebSocket.Listener() {
})
Expand All @@ -87,7 +87,7 @@ void buildAsyncIncludesRequiredHeadersAndPropagatesConfigured() throws Exception
.done()
.always();
final var open = new AtomicBoolean(false);
new JettyWebSocketBuilder(new WebSocketClient(new HttpClient()), Duration.ZERO)
new JettyWebSocketBuilder(new WebSocketClient(new HttpClient()), Duration.ZERO, Collections.emptyList())
.header("A-Random-Header", "A-Random-Value")
.subprotocol("amqp")
.uri(URI.create(server.url("/websocket-headers-test")))
Expand Down
Expand Up @@ -65,13 +65,13 @@ public Response intercept(Chain chain) throws IOException {
Request.Builder requestBuilder = chain.request().newBuilder();
Config config = chain.request().tag(Config.class);
OkHttpRequestImpl.BuilderImpl builderImpl = new OkHttpRequestImpl.BuilderImpl(requestBuilder);
io.fabric8.kubernetes.client.http.Interceptor.useConfig(interceptor, config)
io.fabric8.kubernetes.client.http.Interceptor.useConfig(config).apply(interceptor)
.before(new OkHttpRequestImpl.BuilderImpl(requestBuilder), new OkHttpRequestImpl(chain.request()));
Response response = chain.proceed(requestBuilder.build());
if (!response.isSuccessful()) {
// for okhttp this token refresh will be blocking
try {
boolean call = io.fabric8.kubernetes.client.http.Interceptor.useConfig(interceptor, config)
boolean call = io.fabric8.kubernetes.client.http.Interceptor.useConfig(config).apply(interceptor)
.afterFailure(builderImpl, new OkHttpResponseImpl<>(response, InputStream.class)).get();
if (call) {
response.close();
Expand Down
Expand Up @@ -20,6 +20,7 @@
import io.fabric8.kubernetes.client.RequestConfig;

import java.util.concurrent.CompletableFuture;
import java.util.function.UnaryOperator;

public interface Interceptor {

Expand All @@ -34,11 +35,13 @@ default Interceptor withConfig(Config config) {
return this;
}

static Interceptor useConfig(Interceptor interceptor, Config config) {
if (config == null) {
return interceptor;
}
return interceptor.withConfig(config);
static UnaryOperator<Interceptor> useConfig(Config config) {
return interceptor -> {
if (config == null) {
return interceptor;
}
return interceptor.withConfig(config);
};
}

/**
Expand Down
Expand Up @@ -17,8 +17,8 @@

import io.fabric8.kubernetes.client.Config;
import io.fabric8.mockwebserver.DefaultMockServer;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

Expand All @@ -34,22 +34,22 @@ public abstract class AbstractInterceptorTest {

private static DefaultMockServer server;

@BeforeAll
static void beforeAll() {
@BeforeEach
void startServer() {
server = new DefaultMockServer(false);
server.start();
}

@AfterAll
static void afterAll() {
@AfterEach
void stopServer() {
server.shutdown();
}

protected abstract HttpClient.Factory getHttpClientFactory();

@Test
@DisplayName("before, should add a header to the HTTP request")
public void beforeAddsHeaderToRequest() throws Exception {
@DisplayName("before (HTTP), should add a header to the HTTP request")
public void beforeHttpAddsHeaderToRequest() throws Exception {
// Given
final HttpClient.Builder builder = getHttpClientFactory().newBuilder()
.addOrReplaceInterceptor("test", new Interceptor() {
Expand All @@ -68,6 +68,79 @@ public void before(BasicBuilder builder, HttpHeaders headers) {
.containsEntry("test-header", Collections.singletonList("Test-Value"));
}

@Test
@DisplayName("before (HTTP), should modify the HTTP request URI")
public void beforeHttpModifiesRequestUri() throws Exception {
// Given
final HttpClient.Builder builder = getHttpClientFactory().newBuilder()
.addOrReplaceInterceptor("test", new Interceptor() {
@Override
public void before(BasicBuilder builder, HttpHeaders headers) {
builder.uri(URI.create(server.url("valid-url")));
}
});
// When
try (HttpClient client = builder.build()) {
client.sendAsync(client.newHttpRequestBuilder().uri(server.url("/invalid-url")).build(), String.class)
.get(10L, TimeUnit.SECONDS);
}
// Then
assertThat(server.getRequestCount()).isEqualTo(1);
assertThat(server.getLastRequest().getPath()).isEqualTo("/valid-url");
}

@Test
@DisplayName("before (WS), should add a header to the HTTP request")
public void beforeWsAddsHeaderToRequest() throws Exception {
// Given
server.expect().withPath("/intercept-before")
.andUpgradeToWebSocket()
.open().done().always();
final HttpClient.Builder builder = getHttpClientFactory().newBuilder()
.addOrReplaceInterceptor("test", new Interceptor() {
@Override
public void before(BasicBuilder builder, HttpHeaders headers) {
builder.header("Test-Header", "Test-Value");
}
});
try (HttpClient client = builder.build()) {
// When
client.newWebSocketBuilder()
.uri(URI.create(server.url("intercept-before")))
.buildAsync(new WebSocket.Listener() {
}).get(10L, TimeUnit.SECONDS);
}
// Then
assertThat(server.getLastRequest().getHeaders().toMultimap())
.containsEntry("test-header", Collections.singletonList("Test-Value"));
}

@Test
@DisplayName("before (WS), should modify the HTTP request URI")
public void beforeWsModifiesRequestUri() throws Exception {
// Given
server.expect().withPath("/valid-url")
.andUpgradeToWebSocket()
.open().done().always();
final HttpClient.Builder builder = getHttpClientFactory().newBuilder()
.addOrReplaceInterceptor("test", new Interceptor() {
@Override
public void before(BasicBuilder builder, HttpHeaders headers) {
builder.uri(URI.create(server.url("valid-url")));
}
});
try (HttpClient client = builder.build()) {
// When
client.newWebSocketBuilder()
.uri(URI.create(server.url("invalid-url")))
.buildAsync(new WebSocket.Listener() {
}).get(10L, TimeUnit.SECONDS);
}
// Then
assertThat(server.getRequestCount()).isEqualTo(1);
assertThat(server.getLastRequest().getPath()).isEqualTo("/valid-url");
}

@Test
@DisplayName("afterFailure (HTTP), replaces the HttpResponse produced by HttpClient.sendAsync")
public void afterHttpFailureReplacesResponseInSendAsync() throws Exception {
Expand Down

0 comments on commit c134637

Please sign in to comment.