From bb2082165c2880eabd05d7cdc60ac0ce3034e93e Mon Sep 17 00:00:00 2001 From: Marc Nuri Date: Wed, 18 May 2022 15:49:46 +0200 Subject: [PATCH] refactor: extracted PortForwarderWebsocketListener and added tests Signed-off-by: Marc Nuri --- .../dsl/internal/PortForwarderWebsocket.java | 166 +------------ .../PortForwarderWebsocketListener.java | 230 ++++++++++++++++++ .../PortForwarderWebsocketListenerTest.java | 216 ++++++++++++++++ 3 files changed, 452 insertions(+), 160 deletions(-) create mode 100644 kubernetes-client/src/main/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocketListener.java create mode 100644 kubernetes-client/src/test/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocketListenerTest.java diff --git a/kubernetes-client/src/main/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocket.java b/kubernetes-client/src/main/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocket.java index a3039a78cb..972ea06e49 100644 --- a/kubernetes-client/src/main/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocket.java +++ b/kubernetes-client/src/main/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocket.java @@ -21,7 +21,6 @@ import io.fabric8.kubernetes.client.http.WebSocket; import io.fabric8.kubernetes.client.utils.URLUtils; import io.fabric8.kubernetes.client.utils.Utils; -import io.fabric8.kubernetes.client.utils.internal.SerialExecutor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,15 +29,12 @@ import java.net.InetSocketAddress; import java.net.URI; import java.net.URL; -import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.nio.channels.WritableByteChannel; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutorService; @@ -170,157 +166,7 @@ public Collection getServerThrowables() { @Override public PortForward forward(URL resourceBaseUrl, int port, final ReadableByteChannel in, final WritableByteChannel out) { - final AtomicBoolean alive = new AtomicBoolean(true); - final AtomicBoolean errorOccurred = new AtomicBoolean(false); - final Collection clientThrowables = Collections.synchronizedCollection(new ArrayList<>()); - final Collection serverThrowables = Collections.synchronizedCollection(new ArrayList<>()); - final String logPrefix = "FWD"; - - WebSocket.Listener listener = new WebSocket.Listener() { - private int messagesRead = 0; - - private final ExecutorService pumperService = Executors.newSingleThreadExecutor(); - private final SerialExecutor serialExecutor = new SerialExecutor(Utils.getCommonExecutorSerive()); - - @Override - public void onOpen(final WebSocket webSocket) { - LOG.debug("{}: onOpen", logPrefix); - - if (in != null) { - pumperService.execute(() -> { - ByteBuffer buffer = ByteBuffer.allocate(4096); - int read; - try { - do { - buffer.clear(); - buffer.put((byte) 0); // channel byte - read = in.read(buffer); - if (read > 0) { - buffer.flip(); - webSocket.send(buffer); - } else if (read == 0) { - // in is non-blocking, prevent a busy loop - Thread.sleep(50); - } - } while (alive.get() && read >= 0); - - } catch (IOException | InterruptedException e) { - if (alive.get()) { - clientThrowables.add(e); - LOG.error("Error while writing client data"); - closeBothWays(webSocket, 1001, "Client error"); - } - } - }); - } - } - - @Override - public void onMessage(WebSocket webSocket, String text) { - LOG.debug("{}: onMessage(String)", logPrefix); - onMessage(webSocket, ByteBuffer.wrap(text.getBytes(StandardCharsets.UTF_8))); - } - - @Override - public void onMessage(WebSocket webSocket, ByteBuffer buffer) { - messagesRead++; - if (messagesRead <= 2) { - // skip the first two messages, containing the ports used internally - webSocket.request(); - return; - } - - if (!buffer.hasRemaining()) { - errorOccurred.set(true); - LOG.error("Received an empty message"); - closeBothWays(webSocket, 1002, "Protocol error"); - } - - byte channel = buffer.get(); - if (channel < 0 || channel > 1) { - errorOccurred.set(true); - LOG.error("Received a wrong channel from the remote socket: {}", channel); - closeBothWays(webSocket, 1002, "Protocol error"); - } else if (channel == 1) { - // Error channel - errorOccurred.set(true); - LOG.error("Received an error from the remote socket"); - closeForwarder(); - } else { - // Data - if (out != null) { - serialExecutor.execute(() -> { - try { - while (buffer.hasRemaining()) { - int written = out.write(buffer); // channel byte already skipped - if (written == 0) { - // out is non-blocking, prevent a busy loop - Thread.sleep(50); - } - } - webSocket.request(); - } catch (IOException | InterruptedException e) { - if (alive.get()) { - clientThrowables.add(e); - LOG.error("Error while forwarding data to the client", e); - closeBothWays(webSocket, 1002, "Protocol error"); - } - } - }); - } - } - } - - @Override - public void onClose(WebSocket webSocket, int code, String reason) { - LOG.debug("{}: onClose. Code={}, Reason={}", logPrefix, code, reason); - if (alive.get()) { - closeForwarder(); - } - } - - @Override - public void onError(WebSocket webSocket, Throwable t) { - LOG.debug("{}: onFailure", logPrefix); - if (alive.get()) { - serverThrowables.add(t); - LOG.error("{}: Throwable received from websocket", logPrefix, t); - closeForwarder(); - } - } - - private void closeBothWays(WebSocket webSocket, int code, String message) { - LOG.debug("{}: Closing with code {} and reason: {}", logPrefix, code, message); - alive.set(false); - try { - webSocket.sendClose(code, message); - } catch (Exception e) { - serverThrowables.add(e); - LOG.error("Error while closing the websocket", e); - } - closeForwarder(); - } - - private void closeForwarder() { - alive.set(false); - if (in != null) { - try { - in.close(); - } catch (IOException e) { - LOG.error("{}: Error while closing the client input channel", logPrefix, e); - } - } - if (out != null && out != in) { - try { - out.close(); - } catch (IOException e) { - LOG.error("{}: Error while closing the client output channel", logPrefix, e); - } - } - pumperService.shutdownNow(); - serialExecutor.shutdownNow(); - } - }; + final PortForwarderWebsocketListener listener = new PortForwarderWebsocketListener(in, out); CompletableFuture socket = client .newWebSocketBuilder() .uri(URI.create(URLUtils.join(resourceBaseUrl.toString(), "portforward?ports=" + port))) @@ -334,7 +180,7 @@ private void closeForwarder() { return new PortForward() { @Override - public void close() throws IOException { + public void close() { socket.cancel(true); socket.whenComplete((w, t) -> { if (w != null) { @@ -345,22 +191,22 @@ public void close() throws IOException { @Override public boolean isAlive() { - return alive.get(); + return listener.isAlive(); } @Override public boolean errorOccurred() { - return errorOccurred.get() || !clientThrowables.isEmpty() || !serverThrowables.isEmpty(); + return listener.errorOccurred(); } @Override public Collection getClientThrowables() { - return clientThrowables; + return listener.getClientThrowables(); } @Override public Collection getServerThrowables() { - return serverThrowables; + return listener.getServerThrowables(); } }; } diff --git a/kubernetes-client/src/main/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocketListener.java b/kubernetes-client/src/main/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocketListener.java new file mode 100644 index 0000000000..d6e9085450 --- /dev/null +++ b/kubernetes-client/src/main/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocketListener.java @@ -0,0 +1,230 @@ +/** + * Copyright (C) 2015 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.fabric8.kubernetes.client.dsl.internal; + +import io.fabric8.kubernetes.client.http.WebSocket; +import io.fabric8.kubernetes.client.utils.Utils; +import io.fabric8.kubernetes.client.utils.internal.SerialExecutor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BooleanSupplier; + +public class PortForwarderWebsocketListener implements WebSocket.Listener { + + private static final Logger logger = LoggerFactory.getLogger(PortForwarderWebsocketListener.class); + private static final String LOG_PREFIX = "FWD"; + private static final String PROTOCOL_ERROR = "Protocol error"; + private static final int BUFFER_SIZE = 4096; + + private final ExecutorService pumperService = Executors.newSingleThreadExecutor(); + + private final SerialExecutor serialExecutor = new SerialExecutor(Utils.getCommonExecutorSerive()); + + private final AtomicBoolean alive = new AtomicBoolean(true); + + private final AtomicBoolean errorOccurred = new AtomicBoolean(false); + + final Collection clientThrowables = new CopyOnWriteArrayList<>(); + + final Collection serverThrowables = new CopyOnWriteArrayList<>(); + + private final ReadableByteChannel in; + + private final WritableByteChannel out; + + private int messagesRead = 0; + + public PortForwarderWebsocketListener(ReadableByteChannel in, WritableByteChannel out) { + this.in = in; + this.out = out; + } + + @Override + public void onOpen(final WebSocket webSocket) { + logger.debug("{}: onOpen", LOG_PREFIX); + if (in != null) { + pumperService.execute(() -> { + try { + pipe(in, webSocket, alive::get); + } catch (IOException | InterruptedException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + if (alive.get()) { + clientThrowables.add(e); + logger.error("Error while writing client data"); + closeBothWays(webSocket, 1001, "Client error"); + } + } + }); + } + } + + @Override + public void onMessage(WebSocket webSocket, String text) { + logger.debug("{}: onMessage(String)", LOG_PREFIX); + onMessage(webSocket, ByteBuffer.wrap(text.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public void onMessage(WebSocket webSocket, ByteBuffer buffer) { + messagesRead++; + if (messagesRead <= 2) { + // skip the first two messages, containing the ports used internally + webSocket.request(); + return; + } + + if (!buffer.hasRemaining()) { + errorOccurred.set(true); + logger.error("Received an empty message"); + closeBothWays(webSocket, 1002, PROTOCOL_ERROR); + return; + } + + byte channel = buffer.get(); + if (channel < 0 || channel > 1) { + errorOccurred.set(true); + logger.error("Received a wrong channel from the remote socket: {}", channel); + closeBothWays(webSocket, 1002, PROTOCOL_ERROR); + } else if (channel == 1) { + // Error channel + errorOccurred.set(true); + logger.error("Received an error from the remote socket"); + closeForwarder(); + } else { + // Data + if (out != null) { + serialExecutor.execute(() -> { + try { + while (buffer.hasRemaining()) { + int written = out.write(buffer); // channel byte already skipped + if (written == 0) { + // out is non-blocking, prevent a busy loop + Thread.sleep(50); + } + } + webSocket.request(); + } catch (IOException | InterruptedException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + if (alive.get()) { + clientThrowables.add(e); + logger.error("Error while forwarding data to the client", e); + closeBothWays(webSocket, 1002, PROTOCOL_ERROR); + } + } + }); + } + } + } + + @Override + public void onClose(WebSocket webSocket, int code, String reason) { + logger.debug("{}: onClose. Code={}, Reason={}", LOG_PREFIX, code, reason); + if (alive.get()) { + closeForwarder(); + } + } + + @Override + public void onError(WebSocket webSocket, Throwable t) { + logger.debug("{}: onFailure", LOG_PREFIX); + if (alive.get()) { + serverThrowables.add(t); + logger.error("{}: Throwable received from websocket", LOG_PREFIX, t); + closeForwarder(); + } + } + + boolean isAlive() { + return alive.get(); + } + + boolean errorOccurred() { + return errorOccurred.get() || !clientThrowables.isEmpty() || !serverThrowables.isEmpty(); + } + + Collection getClientThrowables() { + return clientThrowables; + } + + Collection getServerThrowables() { + return serverThrowables; + } + + private void closeBothWays(WebSocket webSocket, int code, String message) { + logger.debug("{}: Closing with code {} and reason: {}", LOG_PREFIX, code, message); + alive.set(false); + try { + webSocket.sendClose(code, message); + } catch (Exception e) { + serverThrowables.add(e); + logger.error("Error while closing the websocket", e); + } + closeForwarder(); + } + + private void closeForwarder() { + alive.set(false); + if (in != null) { + try { + in.close(); + } catch (IOException e) { + logger.error("{}: Error while closing the client input channel", LOG_PREFIX, e); + } + } + if (out != null && out != in) { + try { + out.close(); + } catch (IOException e) { + logger.error("{}: Error while closing the client output channel", LOG_PREFIX, e); + } + } + pumperService.shutdownNow(); + serialExecutor.shutdownNow(); + } + + private static void pipe(ReadableByteChannel in, WebSocket webSocket, BooleanSupplier isAlive) + throws IOException, InterruptedException { + final ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE); + int read; + do { + buffer.clear(); + buffer.put((byte) 0); // channel byte + read = in.read(buffer); + if (read > 0) { + buffer.flip(); + webSocket.send(buffer); + } else if (read == 0) { + // in is non-blocking, prevent a busy loop + Thread.sleep(50); + } + } while (isAlive.getAsBoolean() && read >= 0); + } +} diff --git a/kubernetes-client/src/test/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocketListenerTest.java b/kubernetes-client/src/test/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocketListenerTest.java new file mode 100644 index 0000000000..5cd6d8c346 --- /dev/null +++ b/kubernetes-client/src/test/java/io/fabric8/kubernetes/client/dsl/internal/PortForwarderWebsocketListenerTest.java @@ -0,0 +1,216 @@ +/** + * Copyright (C) 2015 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.fabric8.kubernetes.client.dsl.internal; + +import io.fabric8.kubernetes.client.http.WebSocket; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.MockedStatic; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.TimeUnit; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class PortForwarderWebsocketListenerTest { + + private WebSocket webSocket; + private ReadableByteChannel in; + private WritableByteChannel out; + private ByteArrayOutputStream outputContent; + private PortForwarderWebsocketListener listener; + + @BeforeEach + void setUp() { + webSocket = mock(WebSocket.class); + in = Channels.newChannel(new ByteArrayInputStream("THIS IS A TEST".getBytes(StandardCharsets.UTF_8))); + outputContent = new ByteArrayOutputStream(); + out = Channels.newChannel(outputContent); + } + + @AfterEach + void tearDown() throws IOException { + if (listener != null) { + listener.onClose(null, 1337, "Test ended"); + } + out.close(); + outputContent.close(); + in.close(); + } + + @Test + void onOpen_shouldPipeInChannelToWebSocket() { + listener = new PortForwarderWebsocketListener(in, out); + listener.onOpen(webSocket); + ArgumentCaptor contentTypeCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + // Then + verify(webSocket, timeout(10_000).times(1)).send(contentTypeCaptor.capture()); + assertThat(contentTypeCaptor.getValue()) + .extracting(StandardCharsets.UTF_8::decode) + .extracting(CharBuffer::toString).asString() + .startsWith("THIS IS A TEST"); + assertThat(in.isOpen()).isTrue(); + assertThat(out.isOpen()).isTrue(); + } + + @Test + void onOpen_withException_shouldCloseWebSocketAndStoreException() throws IOException { + final ReadableByteChannel inWithException = mock(ReadableByteChannel.class); + when(inWithException.read(any())).thenThrow(new IOException("Error reading packets")); + listener = new PortForwarderWebsocketListener(inWithException, out); + listener.onOpen(webSocket); + // Then + verify(webSocket, timeout(10_000).times(1)).sendClose(anyInt(), anyString()); + assertThat(listener.getClientThrowables()) + .singleElement() + .asInstanceOf(InstanceOfAssertFactories.throwable(IOException.class)) + .hasMessage("Error reading packets"); + } + + @Test + void onError_shouldStoreExceptionAndCloseChannels() { + listener = new PortForwarderWebsocketListener(in, out); + listener.onError(webSocket, new RuntimeException("Server error")); + // Then + assertThat(listener.getServerThrowables()) + .singleElement() + .asInstanceOf(InstanceOfAssertFactories.throwable(RuntimeException.class)) + .hasMessage("Server error"); + assertThat(in.isOpen()).isFalse(); + assertThat(out.isOpen()).isFalse(); + } + + @Test + void onClose_shouldCloseChannels() { + listener = new PortForwarderWebsocketListener(in, out); + listener.onClose(webSocket, 1337, "Test ended"); + // Then + assertThat(listener.getServerThrowables()).isEmpty(); + assertThat(in.isOpen()).isFalse(); + assertThat(out.isOpen()).isFalse(); + } + + @Test + void onMessage_shouldSkipTwoMessagesAndPipeTheThird() { + listener = new PortForwarderWebsocketListener(in, out); + doAnswer(i -> { + listener.onMessage(webSocket, "SKIP 2"); + return true; + }).doAnswer(i -> { + listener.onMessage(webSocket, ByteBuffer.wrap( + ByteBuffer.allocate(18).put((byte) 0).put("PROCESSED MESSAGE".getBytes(StandardCharsets.UTF_8)).array())); + return true; + }) + .doNothing() + .when(webSocket).request(); + listener.onMessage(webSocket, "SKIP 1"); + // Then + verify(webSocket, timeout(10_000).times(3)).request(); + assertThat(outputContent.toString()).contains("PROCESSED MESSAGE"); + } + + @Test + void onMessage_withEmptyMessage_shouldEndWithError() { + listener = new PortForwarderWebsocketListener(in, out); + doAnswer(i -> { + listener.onMessage(webSocket, "SKIP 2"); + return true; + }).doAnswer(i -> { + listener.onMessage(webSocket, ByteBuffer.wrap(new byte[0])); + return true; + }).when(webSocket).request(); + listener.onMessage(webSocket, "SKIP 1"); + // Then + verify(webSocket, timeout(10_000)).sendClose(1002, "Protocol error"); + assertThat(outputContent.toString()).isEmpty(); + assertThat(listener.errorOccurred()).isTrue(); + assertThat(listener.getServerThrowables()).isEmpty(); + assertThat(in.isOpen()).isFalse(); + assertThat(out.isOpen()).isFalse(); + } + + @Test + void onMessage_withServerClose_shouldSkipTwoMessagesAndPipeTheThird() { + listener = new PortForwarderWebsocketListener(in, out); + doAnswer(i -> { + listener.onMessage(webSocket, "SKIP 2"); + return true; + }).doAnswer(i -> { + listener.onMessage(webSocket, ByteBuffer.wrap( + ByteBuffer.allocate(18).put((byte) 0).put("PROCESSED MESSAGE".getBytes(StandardCharsets.UTF_8)).array())); + return true; + }).doAnswer(i -> { + listener.onClose(webSocket, 31337, "Transmission complete"); + return true; + }).when(webSocket).request(); + listener.onMessage(webSocket, "SKIP 1"); + // Then + await().atMost(10, TimeUnit.SECONDS).until(() -> !listener.isAlive()); + assertThat(outputContent.toString()).contains("PROCESSED MESSAGE"); + assertThat(listener.errorOccurred()).isFalse(); + assertThat(in.isOpen()).isFalse(); + assertThat(out.isOpen()).isFalse(); + } + + @Test + void onMessage_withWrongChannel_shouldLogAndEndWithError() { + try (MockedStatic loggerFactory = mockStatic(LoggerFactory.class)) { + final Logger logger = mock(Logger.class); + loggerFactory.when(() -> LoggerFactory.getLogger(PortForwarderWebsocketListener.class)).thenReturn(logger); + listener = new PortForwarderWebsocketListener(in, out); + doAnswer(i -> { + listener.onMessage(webSocket, "SKIP 2"); + return true; + }).doAnswer(i -> { + listener.onMessage(webSocket, ByteBuffer.wrap( + ByteBuffer.allocate(18).put((byte) 5).put("WRONG CHANNEL".getBytes(StandardCharsets.UTF_8)).array())); + return true; + }) + .doNothing() + .when(webSocket).request(); + listener.onMessage(webSocket, "SKIP 1"); + // Then + verify(webSocket, timeout(10_000)).sendClose(1002, "Protocol error"); + assertThat(outputContent.toString()).isEmpty(); + assertThat(listener.errorOccurred()).isTrue(); + verify(logger).error("Received a wrong channel from the remote socket: {}", (byte) 5); + } + + } +}