Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: extracted PortForwarderWebsocketListener and added tests #4159

Merged
merged 1 commit into from May 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -21,7 +21,6 @@
import io.fabric8.kubernetes.client.http.WebSocket;
import io.fabric8.kubernetes.client.utils.URLUtils;
import io.fabric8.kubernetes.client.utils.Utils;
import io.fabric8.kubernetes.client.utils.internal.SerialExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -30,15 +29,12 @@
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -170,157 +166,7 @@ public Collection<Throwable> getServerThrowables() {

@Override
public PortForward forward(URL resourceBaseUrl, int port, final ReadableByteChannel in, final WritableByteChannel out) {
final AtomicBoolean alive = new AtomicBoolean(true);
final AtomicBoolean errorOccurred = new AtomicBoolean(false);
final Collection<Throwable> clientThrowables = Collections.synchronizedCollection(new ArrayList<>());
final Collection<Throwable> serverThrowables = Collections.synchronizedCollection(new ArrayList<>());
final String logPrefix = "FWD";

WebSocket.Listener listener = new WebSocket.Listener() {
private int messagesRead = 0;

private final ExecutorService pumperService = Executors.newSingleThreadExecutor();
private final SerialExecutor serialExecutor = new SerialExecutor(Utils.getCommonExecutorSerive());

@Override
public void onOpen(final WebSocket webSocket) {
LOG.debug("{}: onOpen", logPrefix);

if (in != null) {
pumperService.execute(() -> {
ByteBuffer buffer = ByteBuffer.allocate(4096);
int read;
try {
do {
buffer.clear();
buffer.put((byte) 0); // channel byte
read = in.read(buffer);
if (read > 0) {
buffer.flip();
webSocket.send(buffer);
} else if (read == 0) {
// in is non-blocking, prevent a busy loop
Thread.sleep(50);
}
} while (alive.get() && read >= 0);

} catch (IOException | InterruptedException e) {
if (alive.get()) {
clientThrowables.add(e);
LOG.error("Error while writing client data");
closeBothWays(webSocket, 1001, "Client error");
}
}
});
}
}

@Override
public void onMessage(WebSocket webSocket, String text) {
LOG.debug("{}: onMessage(String)", logPrefix);
onMessage(webSocket, ByteBuffer.wrap(text.getBytes(StandardCharsets.UTF_8)));
}

@Override
public void onMessage(WebSocket webSocket, ByteBuffer buffer) {
messagesRead++;
if (messagesRead <= 2) {
// skip the first two messages, containing the ports used internally
webSocket.request();
return;
}

if (!buffer.hasRemaining()) {
errorOccurred.set(true);
LOG.error("Received an empty message");
closeBothWays(webSocket, 1002, "Protocol error");
}

byte channel = buffer.get();
if (channel < 0 || channel > 1) {
errorOccurred.set(true);
LOG.error("Received a wrong channel from the remote socket: {}", channel);
closeBothWays(webSocket, 1002, "Protocol error");
} else if (channel == 1) {
// Error channel
errorOccurred.set(true);
LOG.error("Received an error from the remote socket");
closeForwarder();
} else {
// Data
if (out != null) {
serialExecutor.execute(() -> {
try {
while (buffer.hasRemaining()) {
int written = out.write(buffer); // channel byte already skipped
if (written == 0) {
// out is non-blocking, prevent a busy loop
Thread.sleep(50);
}
}
webSocket.request();
} catch (IOException | InterruptedException e) {
if (alive.get()) {
clientThrowables.add(e);
LOG.error("Error while forwarding data to the client", e);
closeBothWays(webSocket, 1002, "Protocol error");
}
}
});
}
}
}

@Override
public void onClose(WebSocket webSocket, int code, String reason) {
LOG.debug("{}: onClose. Code={}, Reason={}", logPrefix, code, reason);
if (alive.get()) {
closeForwarder();
}
}

@Override
public void onError(WebSocket webSocket, Throwable t) {
LOG.debug("{}: onFailure", logPrefix);
if (alive.get()) {
serverThrowables.add(t);
LOG.error("{}: Throwable received from websocket", logPrefix, t);
closeForwarder();
}
}

private void closeBothWays(WebSocket webSocket, int code, String message) {
LOG.debug("{}: Closing with code {} and reason: {}", logPrefix, code, message);
alive.set(false);
try {
webSocket.sendClose(code, message);
} catch (Exception e) {
serverThrowables.add(e);
LOG.error("Error while closing the websocket", e);
}
closeForwarder();
}

private void closeForwarder() {
alive.set(false);
if (in != null) {
try {
in.close();
} catch (IOException e) {
LOG.error("{}: Error while closing the client input channel", logPrefix, e);
}
}
if (out != null && out != in) {
try {
out.close();
} catch (IOException e) {
LOG.error("{}: Error while closing the client output channel", logPrefix, e);
}
}
pumperService.shutdownNow();
serialExecutor.shutdownNow();
}
};
final PortForwarderWebsocketListener listener = new PortForwarderWebsocketListener(in, out);
CompletableFuture<WebSocket> socket = client
.newWebSocketBuilder()
.uri(URI.create(URLUtils.join(resourceBaseUrl.toString(), "portforward?ports=" + port)))
Expand All @@ -334,7 +180,7 @@ private void closeForwarder() {

return new PortForward() {
@Override
public void close() throws IOException {
public void close() {
socket.cancel(true);
socket.whenComplete((w, t) -> {
if (w != null) {
Expand All @@ -345,22 +191,22 @@ public void close() throws IOException {

@Override
public boolean isAlive() {
return alive.get();
return listener.isAlive();
}

@Override
public boolean errorOccurred() {
return errorOccurred.get() || !clientThrowables.isEmpty() || !serverThrowables.isEmpty();
return listener.errorOccurred();
}

@Override
public Collection<Throwable> getClientThrowables() {
return clientThrowables;
return listener.getClientThrowables();
}

@Override
public Collection<Throwable> getServerThrowables() {
return serverThrowables;
return listener.getServerThrowables();
}
};
}
Expand Down