diff --git a/jetty-util/src/main/java/org/eclipse/jetty/util/compression/CompressionPool.java b/jetty-util/src/main/java/org/eclipse/jetty/util/compression/CompressionPool.java index ce972ceae8ad..8cf5c9781a24 100644 --- a/jetty-util/src/main/java/org/eclipse/jetty/util/compression/CompressionPool.java +++ b/jetty-util/src/main/java/org/eclipse/jetty/util/compression/CompressionPool.java @@ -16,9 +16,11 @@ import java.io.Closeable; import org.eclipse.jetty.util.Pool; -import org.eclipse.jetty.util.component.AbstractLifeCycle; +import org.eclipse.jetty.util.annotation.ManagedObject; +import org.eclipse.jetty.util.component.ContainerLifeCycle; -public abstract class CompressionPool extends AbstractLifeCycle +@ManagedObject +public abstract class CompressionPool extends ContainerLifeCycle { public static final int DEFAULT_CAPACITY = 1024; @@ -51,6 +53,11 @@ public void setCapacity(int capacity) _capacity = capacity; } + public Pool getPool() + { + return _pool; + } + protected abstract T newPooled(); protected abstract void end(T object); @@ -85,7 +92,10 @@ public void release(Entry entry) protected void doStart() throws Exception { if (_capacity > 0) + { _pool = new Pool<>(Pool.StrategyType.RANDOM, _capacity, true); + addBean(_pool); + } super.doStart(); } @@ -95,6 +105,7 @@ public void doStop() throws Exception if (_pool != null) { _pool.close(); + removeBean(_pool); _pool = null; } super.doStop(); diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/Extension.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/Extension.java index 107540084416..33f0f05e04e0 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/Extension.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/Extension.java @@ -13,16 +13,25 @@ package org.eclipse.jetty.websocket.core; +import java.io.Closeable; + /** * Interface for WebSocket Extensions. *

* That {@link Frame}s are passed through the Extension via the {@link IncomingFrames} and {@link OutgoingFrames} interfaces */ -public interface Extension extends IncomingFrames, OutgoingFrames +public interface Extension extends IncomingFrames, OutgoingFrames, Closeable { void init(ExtensionConfig config, WebSocketComponents components); + /** + * Used to clean up any resources after connection close. + */ + default void close() + { + } + /** * The active configuration for this extension. * diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/ExtensionStack.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/ExtensionStack.java index 8204ec3da9e6..b894877b8b1e 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/ExtensionStack.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/ExtensionStack.java @@ -60,6 +60,22 @@ public ExtensionStack(WebSocketComponents components, Behavior behavior) this.behavior = behavior; } + public void close() + { + for (Extension ext : extensions) + { + try + { + ext.close(); + } + catch (Throwable t) + { + if (LOG.isDebugEnabled()) + LOG.debug("Extension Error During Close", t); + } + } + } + @ManagedAttribute(name = "Extension List", readonly = true) public List getExtensions() { diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java index b5ac85408c52..e157ce65177b 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java @@ -44,7 +44,6 @@ public class FrameFlusher extends IteratingCallback { public static final Frame FLUSH_FRAME = new Frame(OpCode.BINARY); private static final Logger LOG = LoggerFactory.getLogger(FrameFlusher.class); - private static final Throwable CLOSED_CHANNEL = new ClosedChannelException(); private final AutoLock lock = new AutoLock(); private final LongAdder messagesOut = new LongAdder(); @@ -185,7 +184,15 @@ public void onClose(Throwable cause) { try (AutoLock l = lock.lock()) { - closedCause = cause == null ? CLOSED_CHANNEL : cause; + // TODO: find a way to not create exception if cause is null. + closedCause = cause == null ? new ClosedChannelException() + { + @Override + public Throwable fillInStackTrace() + { + return this; + } + } : cause; } iterate(); } diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/PerMessageDeflateExtension.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/PerMessageDeflateExtension.java index ec95d8faba1d..a3aa23628ea2 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/PerMessageDeflateExtension.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/PerMessageDeflateExtension.java @@ -14,6 +14,7 @@ package org.eclipse.jetty.websocket.core.internal; import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; @@ -146,6 +147,24 @@ public void init(final ExtensionConfig config, WebSocketComponents components) super.init(configNegotiated, components); } + @Override + public void close() + { + // TODO: use IteratingCallback.close() instead of creating exception with failFlusher methods. + ClosedChannelException exception = new ClosedChannelException() + { + @Override + public Throwable fillInStackTrace() + { + return this; + } + }; + incomingFlusher.failFlusher(exception); + outgoingFlusher.failFlusher(exception); + releaseInflater(); + releaseDeflater(); + } + private static String toDetail(Inflater inflater) { return String.format("Inflater[finished=%b,read=%d,written=%d,remaining=%d,in=%d,out=%d]", inflater.finished(), inflater.getBytesRead(), diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/TransformingFlusher.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/TransformingFlusher.java index f2fd8844929a..5bed90a3d448 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/TransformingFlusher.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/TransformingFlusher.java @@ -77,6 +77,34 @@ public final void sendFrame(Frame frame, Callback callback, boolean batch) notifyCallbackFailure(callback, failure); } + /** + * Used to fail this flusher possibly from an external event such as a callback. + * @param t the failure. + */ + public void failFlusher(Throwable t) + { + // TODO: find a way to close the flusher in non error case without exception. + boolean failed = false; + try (AutoLock l = lock.lock()) + { + if (failure == null) + { + failure = t; + failed = true; + } + else + { + failure.addSuppressed(t); + } + } + + if (failed) + { + flusher.failed(t); + flusher.iterate(); + } + } + private void onFailure(Throwable t) { try (AutoLock l = lock.lock()) @@ -103,8 +131,14 @@ private class Flusher extends IteratingCallback implements Callback private FrameEntry current; @Override - protected Action process() + protected Action process() throws Throwable { + try (AutoLock l = lock.lock()) + { + if (failure != null) + throw failure; + } + if (finished) { if (current != null) @@ -134,8 +168,11 @@ protected void onCompleteFailure(Throwable t) if (log.isDebugEnabled()) log.debug("onCompleteFailure {}", t.toString()); - notifyCallbackFailure(current.callback, t); - current = null; + if (current != null) + { + notifyCallbackFailure(current.callback, t); + current = null; + } onFailure(t); } } diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketCoreSession.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketCoreSession.java index e340faf922a5..757c4a947687 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketCoreSession.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketCoreSession.java @@ -254,12 +254,13 @@ public void onEof() closeConnection(sessionState.getCloseStatus(), Callback.NOOP); } - public void closeConnection(CloseStatus closeStatus, Callback callback) + private void closeConnection(CloseStatus closeStatus, Callback callback) { if (LOG.isDebugEnabled()) LOG.debug("closeConnection() {} {}", closeStatus, this); abort(); + extensionStack.close(); // Forward Errors to Local WebSocket EndPoint if (closeStatus.isAbnormal() && closeStatus.getCause() != null) diff --git a/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/PermessageDeflateBufferTest.java b/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/PermessageDeflateBufferTest.java index 5b0e74322be9..9ead21f3ca9b 100644 --- a/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/PermessageDeflateBufferTest.java +++ b/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/PermessageDeflateBufferTest.java @@ -13,30 +13,47 @@ package org.eclipse.jetty.websocket.tests; +import java.io.IOException; import java.net.URI; import java.nio.ByteBuffer; +import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.Random; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.util.BlockingArrayQueue; import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.util.compression.CompressionPool; +import org.eclipse.jetty.util.compression.DeflaterPool; +import org.eclipse.jetty.util.compression.InflaterPool; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.StatusCode; +import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose; +import org.eclipse.jetty.websocket.api.annotations.OnWebSocketConnect; +import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage; import org.eclipse.jetty.websocket.api.annotations.WebSocket; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.eclipse.jetty.websocket.common.WebSocketSession; +import org.eclipse.jetty.websocket.core.internal.WebSocketCoreSession; +import org.eclipse.jetty.websocket.server.JettyWebSocketServerContainer; import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; public class PermessageDeflateBufferTest @@ -44,6 +61,10 @@ public class PermessageDeflateBufferTest private Server server; private ServerConnector connector; private WebSocketClient client; + private JettyWebSocketServerContainer serverContainer; + private final FailEndPointOutgoing outgoingFailEndPoint = new FailEndPointOutgoing(); + private final FailEndPointIncoming incomingFailEndPoint = new FailEndPointIncoming(); + private final ServerSocket serverSocket = new ServerSocket(); // @checkstyle-disable-check : AvoidEscapedUnicodeCharactersCheck private static final List DICT = Arrays.asList( @@ -83,7 +104,7 @@ private static ByteBuffer randomBytes(int size) public void before() throws Exception { server = new Server(); - connector = new ServerConnector(server); + connector = new ServerConnector(server, 1, 1); server.addConnector(connector); ServletContextHandler contextHandler = new ServletContextHandler(ServletContextHandler.SESSIONS); @@ -93,10 +114,13 @@ public void before() throws Exception { container.setMaxTextMessageSize(65535); container.setInputBufferSize(16384); - container.addMapping("/", ServerSocket.class); + container.addMapping("/", (req, resp) -> serverSocket); + container.addMapping("/outgoingFail", (req, resp) -> outgoingFailEndPoint); + container.addMapping("/incomingFail", (req, resp) -> incomingFailEndPoint); }); server.start(); + serverContainer = JettyWebSocketServerContainer.getContainer(contextHandler.getServletContext()); client = new WebSocketClient(); client.start(); } @@ -157,4 +181,178 @@ public void testPermessageDeflateFragmentedBinaryMessage() throws Exception assertTrue(socket.closeLatch.await(5, TimeUnit.SECONDS)); assertThat(socket.closeCode, equalTo(StatusCode.NORMAL)); } + + @Test + public void testClientPartialMessageThenServerIdleTimeout() throws Exception + { + Duration idleTimeout = Duration.ofMillis(1000); + serverContainer.setIdleTimeout(idleTimeout); + + ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest(); + clientUpgradeRequest.addExtensions("permessage-deflate"); + + EventSocket socket = new EventSocket(); + URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/incomingFail"); + Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS); + + session.getRemote().sendPartialString("partial", false); + + // Wait for the idle timeout to elapse. + assertTrue(incomingFailEndPoint.closeLatch.await(5, TimeUnit.SECONDS)); + + server.getContainedBeans(InflaterPool.class).stream() + .map(CompressionPool::getPool) + .forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump())); + server.getContainedBeans(DeflaterPool.class).stream() + .map(CompressionPool::getPool) + .forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump())); + } + + @Test + public void testClientPartialMessageThenClientClose() throws Exception + { + ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest(); + clientUpgradeRequest.addExtensions("permessage-deflate"); + + PartialTextSocket socket = new PartialTextSocket(); + URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/incomingFail"); + Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS); + + session.getRemote().sendPartialString("partial", false); + // Wait for the server to process the partial message. + assertThat(socket.partialMessages.poll(5, TimeUnit.SECONDS), equalTo("partial" + "last=true")); + + // Abruptly close the connection from the client. + ((WebSocketCoreSession)((WebSocketSession)session).getCoreSession()).getConnection().getEndPoint().close(); + + // Wait for the server to process the close. + assertTrue(incomingFailEndPoint.closeLatch.await(5, TimeUnit.SECONDS)); + + server.getContainedBeans(InflaterPool.class).stream() + .map(CompressionPool::getPool) + .forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump())); + server.getContainedBeans(DeflaterPool.class).stream() + .map(CompressionPool::getPool) + .forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump())); + } + + @Test + public void testServerPartialMessageThenServerIdleTimeout() throws Exception + { + Duration idleTimeout = Duration.ofMillis(1000); + serverContainer.setIdleTimeout(idleTimeout); + + ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest(); + clientUpgradeRequest.addExtensions("permessage-deflate"); + + EventSocket socket = new EventSocket(); + URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/outgoingFail"); + Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS); + + session.getRemote().sendString("hello"); + + // Wait for the idle timeout to elapse. + assertTrue(outgoingFailEndPoint.closeLatch.await(2 * idleTimeout.toMillis(), TimeUnit.SECONDS)); + + server.getContainedBeans(InflaterPool.class).stream() + .map(CompressionPool::getPool) + .forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump())); + server.getContainedBeans(DeflaterPool.class).stream() + .map(CompressionPool::getPool) + .forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump())); + } + + @Test + public void testServerPartialMessageThenClientClose() throws Exception + { + ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest(); + clientUpgradeRequest.addExtensions("permessage-deflate"); + + PartialTextSocket socket = new PartialTextSocket(); + URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/outgoingFail"); + Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS); + + session.getRemote().sendString("hello"); + // Wait for the server to process the message. + assertThat(socket.partialMessages.poll(5, TimeUnit.SECONDS), equalTo("hello" + "last=false")); + + // Abruptly close the connection from the client. + ((WebSocketCoreSession)((WebSocketSession)session).getCoreSession()).getConnection().getEndPoint().close(); + + // Wait for the server to process the close. + assertTrue(outgoingFailEndPoint.closeLatch.await(5, TimeUnit.SECONDS)); + + server.getContainedBeans(InflaterPool.class).stream() + .map(CompressionPool::getPool) + .forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump())); + server.getContainedBeans(DeflaterPool.class).stream() + .map(CompressionPool::getPool) + .forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump())); + } + + @WebSocket + public static class PartialTextSocket + { + private static final Logger LOG = LoggerFactory.getLogger(EventSocket.class); + + public Session session; + public BlockingQueue partialMessages = new BlockingArrayQueue<>(); + public CountDownLatch openLatch = new CountDownLatch(1); + public CountDownLatch closeLatch = new CountDownLatch(1); + + @OnWebSocketConnect + public void onOpen(Session session) + { + this.session = session; + openLatch.countDown(); + } + + @OnWebSocketMessage + public void onMessage(String message, boolean last) throws IOException + { + partialMessages.offer(message + "last=" + last); + } + + @OnWebSocketClose + public void onClose(int statusCode, String reason) + { + closeLatch.countDown(); + } + } + + @WebSocket + public static class FailEndPointOutgoing + { + public CountDownLatch closeLatch = new CountDownLatch(1); + + @OnWebSocketMessage + public void onMessage(Session session, String message) throws IOException + { + session.getRemote().sendPartialString(message, false); + } + + @OnWebSocketClose + public void onClose(int statusCode, String reason) + { + closeLatch.countDown(); + } + } + + @WebSocket + public static class FailEndPointIncoming + { + public CountDownLatch closeLatch = new CountDownLatch(1); + + @OnWebSocketMessage + public void onMessage(Session session, String message, boolean last) throws IOException + { + session.getRemote().sendString(message); + } + + @OnWebSocketClose + public void onClose(int statusCode, String reason) + { + closeLatch.countDown(); + } + } }