diff --git a/jetty-websocket/javax-websocket-client-impl/src/main/java/org/eclipse/jetty/websocket/jsr356/endpoints/JsrAnnotatedEventDriver.java b/jetty-websocket/javax-websocket-client-impl/src/main/java/org/eclipse/jetty/websocket/jsr356/endpoints/JsrAnnotatedEventDriver.java index 4f2369e3a01c..32cc32868018 100644 --- a/jetty-websocket/javax-websocket-client-impl/src/main/java/org/eclipse/jetty/websocket/jsr356/endpoints/JsrAnnotatedEventDriver.java +++ b/jetty-websocket/javax-websocket-client-impl/src/main/java/org/eclipse/jetty/websocket/jsr356/endpoints/JsrAnnotatedEventDriver.java @@ -118,10 +118,8 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException if (LOG.isDebugEnabled()) LOG.debug("Binary Message InputStream"); - final MessageInputStream stream = new MessageInputStream(session); + MessageInputStream stream = new MessageInputStream(session); activeMessage = stream; - - // Always dispatch streaming read to another thread. dispatch(() -> { try @@ -133,7 +131,7 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException session.close(e); } - stream.close(); + stream.handlerComplete(); }); } } @@ -329,11 +327,8 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException if (LOG.isDebugEnabled()) LOG.debug("Text Message Writer"); - MessageInputStream inputStream = new MessageInputStream(session); - final MessageReader reader = new MessageReader(inputStream); - activeMessage = inputStream; - - // Always dispatch streaming read to another thread. + MessageReader reader = new MessageReader(session); + activeMessage = reader; dispatch(() -> { try @@ -343,9 +338,10 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException catch (Throwable e) { session.close(e); + return; } - inputStream.close(); + reader.handlerComplete(); }); } } diff --git a/jetty-websocket/javax-websocket-client-impl/src/main/java/org/eclipse/jetty/websocket/jsr356/endpoints/JsrEndpointEventDriver.java b/jetty-websocket/javax-websocket-client-impl/src/main/java/org/eclipse/jetty/websocket/jsr356/endpoints/JsrEndpointEventDriver.java index 4dc9f01272bb..7e5af34c62e9 100644 --- a/jetty-websocket/javax-websocket-client-impl/src/main/java/org/eclipse/jetty/websocket/jsr356/endpoints/JsrEndpointEventDriver.java +++ b/jetty-websocket/javax-websocket-client-impl/src/main/java/org/eclipse/jetty/websocket/jsr356/endpoints/JsrEndpointEventDriver.java @@ -100,9 +100,10 @@ else if (wrapper.wantsStreams()) catch (Throwable t) { session.close(t); + return; } - inputStream.close(); + inputStream.handlerComplete(); }); } else @@ -197,8 +198,7 @@ else if (wrapper.wantsStreams()) { @SuppressWarnings("unchecked") MessageHandler.Whole handler = (Whole)wrapper.getHandler(); - MessageInputStream inputStream = new MessageInputStream(session); - MessageReader reader = new MessageReader(inputStream); + MessageReader reader = new MessageReader(session); activeMessage = reader; dispatch(() -> { @@ -209,9 +209,10 @@ else if (wrapper.wantsStreams()) catch (Throwable t) { session.close(t); + return; } - inputStream.close(); + reader.handlerComplete(); }); } else diff --git a/jetty-websocket/javax-websocket-server-impl/src/test/java/org/eclipse/jetty/websocket/jsr356/server/BinaryStreamTest.java b/jetty-websocket/javax-websocket-server-impl/src/test/java/org/eclipse/jetty/websocket/jsr356/server/BinaryStreamTest.java index e5fe7a366824..ae9a31b280af 100644 --- a/jetty-websocket/javax-websocket-server-impl/src/test/java/org/eclipse/jetty/websocket/jsr356/server/BinaryStreamTest.java +++ b/jetty-websocket/javax-websocket-server-impl/src/test/java/org/eclipse/jetty/websocket/jsr356/server/BinaryStreamTest.java @@ -27,7 +27,9 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import javax.websocket.ClientEndpoint; +import javax.websocket.CloseReason; import javax.websocket.ContainerProvider; +import javax.websocket.OnClose; import javax.websocket.OnMessage; import javax.websocket.Session; import javax.websocket.WebSocketContainer; @@ -37,11 +39,15 @@ import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.websocket.jsr356.server.deploy.WebSocketServerContainerInitializer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -127,6 +133,62 @@ public void testMoreThanLargestMessageOneByteAtATime() throws Exception assertArrayEquals(data, client.getEcho()); } + @Test + public void testNotReadingToEndOfStream() throws Exception + { + int size = 32; + byte[] data = randomBytes(size); + URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + PATH); + + CountDownLatch handlerComplete = new CountDownLatch(1); + BasicClientBinaryStreamer client = new BasicClientBinaryStreamer((session, inputStream) -> + { + byte[] recv = new byte[16]; + int read = inputStream.read(recv); + assertThat(read, not(is(0))); + handlerComplete.countDown(); + }); + + Session session = wsClient.connectToServer(client, uri); + session.getBasicRemote().sendBinary(BufferUtil.toBuffer(data)); + assertTrue(handlerComplete.await(5, TimeUnit.SECONDS)); + + session.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "close from test")); + assertTrue(client.closeLatch.await(5, TimeUnit.SECONDS)); + assertThat(client.closeReason.getCloseCode(), is(CloseReason.CloseCodes.NORMAL_CLOSURE)); + assertThat(client.closeReason.getReasonPhrase(), is("close from test")); + } + + @Test + public void testClosingBeforeReadingToEndOfStream() throws Exception + { + int size = 32; + byte[] data = randomBytes(size); + URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + PATH); + + CountDownLatch handlerComplete = new CountDownLatch(1); + BasicClientBinaryStreamer client = new BasicClientBinaryStreamer((session, inputStream) -> + { + byte[] recv = new byte[16]; + int read = inputStream.read(recv); + assertThat(read, not(is(0))); + + inputStream.close(); + read = inputStream.read(recv); + assertThat(read, is(-1)); + handlerComplete.countDown(); + }); + + Session session = wsClient.connectToServer(client, uri); + session.getBasicRemote().sendBinary(BufferUtil.toBuffer(data)); + assertTrue(handlerComplete.await(5, TimeUnit.SECONDS)); + + session.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "close from test")); + assertTrue(client.closeLatch.await(5, TimeUnit.SECONDS)); + assertThat(client.closeReason.getCloseCode(), is(CloseReason.CloseCodes.NORMAL_CLOSURE)); + assertThat(client.closeReason.getReasonPhrase(), is("close from test")); + } + private byte[] randomBytes(int size) { byte[] data = new byte[size]; @@ -134,6 +196,37 @@ private byte[] randomBytes(int size) return data; } + @ClientEndpoint + public static class BasicClientBinaryStreamer + { + public interface MessageHandler + { + void accept(Session session, InputStream inputStream) throws Exception; + } + + private final MessageHandler handler; + private final CountDownLatch closeLatch = new CountDownLatch(1); + private CloseReason closeReason; + + public BasicClientBinaryStreamer(MessageHandler consumer) + { + this.handler = consumer; + } + + @OnMessage + public void echoed(Session session, InputStream input) throws Exception + { + handler.accept(session, input); + } + + @OnClose + public void onClosed(CloseReason closeReason) + { + this.closeReason = closeReason; + closeLatch.countDown(); + } + } + @ClientEndpoint public static class ClientBinaryStreamer { diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/events/JettyAnnotatedEventDriver.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/events/JettyAnnotatedEventDriver.java index c48d0d9de245..7adfafda7428 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/events/JettyAnnotatedEventDriver.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/events/JettyAnnotatedEventDriver.java @@ -32,7 +32,6 @@ import org.eclipse.jetty.websocket.api.annotations.WebSocket; import org.eclipse.jetty.websocket.api.extensions.Frame; import org.eclipse.jetty.websocket.common.CloseInfo; -import org.eclipse.jetty.websocket.common.message.MessageAppender; import org.eclipse.jetty.websocket.common.message.MessageInputStream; import org.eclipse.jetty.websocket.common.message.MessageReader; import org.eclipse.jetty.websocket.common.message.NullMessage; @@ -105,7 +104,7 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException } else if (events.onBinary.isStreaming()) { - final MessageInputStream inputStream = new MessageInputStream(session); + MessageInputStream inputStream = new MessageInputStream(session); activeMessage = inputStream; dispatch(() -> { @@ -115,11 +114,11 @@ else if (events.onBinary.isStreaming()) } catch (Throwable t) { - // dispatched calls need to be reported session.close(t); + return; } - inputStream.close(); + inputStream.handlerComplete(); }); } else @@ -262,22 +261,21 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException } else if (events.onText.isStreaming()) { - MessageInputStream inputStream = new MessageInputStream(session); - activeMessage = new MessageReader(inputStream); - final MessageAppender msg = activeMessage; + MessageReader reader = new MessageReader(session); + activeMessage = reader; dispatch(() -> { try { - events.onText.call(websocket, session, msg); + events.onText.call(websocket, session, reader); } catch (Throwable t) { - // dispatched calls need to be reported session.close(t); + return; } - inputStream.close(); + reader.handlerComplete(); }); } else diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/io/AbstractWebSocketConnection.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/io/AbstractWebSocketConnection.java index 11583bd58a05..1e3c98b63532 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/io/AbstractWebSocketConnection.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/io/AbstractWebSocketConnection.java @@ -521,7 +521,7 @@ public void resume() { ByteBuffer resume = readState.resume(); if (resume != null) - onFillable(resume); + getExecutor().execute(() -> onFillable(resume)); } @Override diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/message/MessageInputStream.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/message/MessageInputStream.java index 346076dd0ecf..96c9f99deeb9 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/message/MessageInputStream.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/message/MessageInputStream.java @@ -53,8 +53,24 @@ public class MessageInputStream extends InputStream implements MessageAppender private enum State { + /** + * Open and waiting for a frame to be delivered in {@link #appendFrame(ByteBuffer, boolean)}. + */ RESUMED, + + /** + * We have suspended the session after reading a websocket frame but have not reached the end of the message. + */ SUSPENDED, + + /** + * We have received a frame with fin==true and have suspended until we are signaled that onMessage method exited. + */ + COMPLETE, + + /** + * We have read to EOF or someone has called InputStream.close(), any further reads will result in reading -1. + */ CLOSED } @@ -76,40 +92,46 @@ public void appendFrame(ByteBuffer framePayload, boolean fin) throws IOException if (LOG.isDebugEnabled()) LOG.debug("Appending {} chunk: {}", fin ? "final" : "non-final", BufferUtil.toDetailString(framePayload)); - // Early non atomic test that we aren't closed to avoid an unnecessary copy (will be checked again later). - if (state == State.CLOSED) + // Avoid entering synchronized block if there is nothing to do. + boolean bufferIsEmpty = BufferUtil.isEmpty(framePayload); + if (bufferIsEmpty && !fin) return; - // Put the payload into the queue, by copying it. - // Copying is necessary because the payload will - // be processed after this method returns. try { - if (framePayload == null || !framePayload.hasRemaining()) - return; - - ByteBuffer copy = acquire(framePayload.remaining(), framePayload.isDirect()); - BufferUtil.clearToFill(copy); - copy.put(framePayload); - BufferUtil.flipToFlush(copy, 0); - synchronized (this) { - switch (state) + if (!bufferIsEmpty) { - case CLOSED: - return; + switch (state) + { + case CLOSED: + return; - case RESUMED: - suspendToken = session.suspend(); - state = State.SUSPENDED; - break; + case RESUMED: + suspendToken = session.suspend(); + state = State.SUSPENDED; + break; + + default: + throw new IllegalStateException("Incorrect State: " + state.name()); + } - case SUSPENDED: - throw new IllegalStateException(); + // Put the payload into the queue, by copying it. + // Copying is necessary because the payload will + // be processed after this method returns. + ByteBuffer copy = acquire(framePayload.remaining(), framePayload.isDirect()); + BufferUtil.clearToFill(copy); + copy.put(framePayload); + BufferUtil.flipToFlush(copy, 0); + buffers.put(copy); } - buffers.put(copy); + if (fin) + { + buffers.add(EOF); + state = State.COMPLETE; + } } } catch (InterruptedException e) @@ -121,56 +143,59 @@ public void appendFrame(ByteBuffer framePayload, boolean fin) throws IOException @Override public void close() { - SuspendToken resume = null; synchronized (this) { - switch (state) - { - case CLOSED: - return; - - case SUSPENDED: - resume = suspendToken; - suspendToken = null; - state = State.CLOSED; - break; - - case RESUMED: - state = State.CLOSED; - break; - } + if (state == State.CLOSED) + return; + + boolean remainingContent = (state != State.COMPLETE) || + (!buffers.isEmpty() && buffers.peek() != EOF) || + (activeBuffer != null && activeBuffer.hasRemaining()); + if (remainingContent) + LOG.warn("MessageInputStream closed without fully consuming content {}", session); + + state = State.CLOSED; buffers.clear(); - buffers.offer(EOF); + buffers.add(EOF); } - - // May need to resume to discard until we reach next message. - if (resume != null) - resume.resume(); } - @Override - public void mark(int readlimit) + public void handlerComplete() { - // Not supported. - } + // Close the InputStream. + close(); - @Override - public boolean markSupported() - { - return false; + // May need to resume to resume and read to the next message. + SuspendToken resume; + synchronized (this) + { + resume = suspendToken; + suspendToken = null; + } + + if (resume != null) + resume.resume(); } @Override - public void messageComplete() + public int read() throws IOException { - if (LOG.isDebugEnabled()) - LOG.debug("Message completed"); - buffers.offer(EOF); + byte[] bytes = new byte[1]; + while (true) + { + int read = read(bytes, 0, 1); + if (read < 0) + return -1; + if (read == 0) + continue; + + return bytes[0] & 0xFF; + } } @Override - public int read() throws IOException + public int read(byte[] b, int off, int len) throws IOException { try { @@ -186,6 +211,7 @@ public int read() throws IOException { if (LOG.isDebugEnabled()) LOG.debug("Waiting {} ms to read", timeoutMs); + if (timeoutMs < 0) { // Wait forever until a buffer is available. @@ -209,10 +235,14 @@ public int read() throws IOException } } - int result = activeBuffer.get() & 0xFF; + ByteBuffer buffer = BufferUtil.toBuffer(b, off, len); + BufferUtil.clearToFill(buffer); + int written = BufferUtil.put(activeBuffer, buffer); + BufferUtil.flipToFlush(buffer, 0); + + // If we have no more content we may need to resume to get more data. if (!activeBuffer.hasRemaining()) { - SuspendToken resume = null; synchronized (this) { @@ -221,6 +251,11 @@ public int read() throws IOException case CLOSED: return -1; + case COMPLETE: + // If we are complete we have read the last frame but + // don't want to resume reading until onMessage() exits. + break; + case SUSPENDED: resume = suspendToken; suspendToken = null; @@ -228,7 +263,7 @@ public int read() throws IOException break; case RESUMED: - throw new IllegalStateException(); + throw new IllegalStateException("Incorrect State: " + state.name()); } } @@ -237,7 +272,7 @@ public int read() throws IOException resume.resume(); } - return result; + return written; } catch (InterruptedException x) { @@ -248,12 +283,30 @@ public int read() throws IOException } } + @Override + public void messageComplete() + { + // We handle this case in appendFrame with fin==true. + } + @Override public void reset() throws IOException { throw new IOException("reset() not supported"); } + @Override + public void mark(int readlimit) + { + // Not supported. + } + + @Override + public boolean markSupported() + { + return false; + } + private ByteBuffer acquire(int capacity, boolean direct) { ByteBuffer buffer; diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/message/MessageReader.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/message/MessageReader.java index fbbdfc5ec3a5..9d1422d46a81 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/message/MessageReader.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/message/MessageReader.java @@ -24,6 +24,8 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; +import org.eclipse.jetty.websocket.api.Session; + /** * Support class for reading a (single) WebSocket TEXT message via a Reader. *

@@ -33,6 +35,11 @@ public class MessageReader extends InputStreamReader implements MessageAppender { private final MessageInputStream stream; + public MessageReader(Session session) + { + this(new MessageInputStream(session)); + } + public MessageReader(MessageInputStream stream) { super(stream, StandardCharsets.UTF_8); @@ -50,4 +57,9 @@ public void messageComplete() { this.stream.messageComplete(); } + + public void handlerComplete() + { + this.stream.handlerComplete(); + } } diff --git a/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/message/MessageInputStreamTest.java b/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/message/MessageInputStreamTest.java index 4266ed515426..96983a391a77 100644 --- a/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/message/MessageInputStreamTest.java +++ b/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/message/MessageInputStreamTest.java @@ -18,6 +18,7 @@ package org.eclipse.jetty.websocket.common.message; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -32,6 +33,7 @@ import org.eclipse.jetty.toolchain.test.jupiter.WorkDirExtension; import org.eclipse.jetty.util.BlockingArrayQueue; import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.util.IO; import org.eclipse.jetty.websocket.api.SuspendToken; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -111,9 +113,10 @@ public void testBlockOnRead() throws Exception startLatch.await(); // Read it from the stream. - byte[] buf = new byte[32]; - int len = stream.read(buf); - String message = new String(buf, 0, len, StandardCharsets.UTF_8); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + IO.copy(stream, out); + byte[] bytes = out.toByteArray(); + String message = new String(bytes, 0, bytes.length, StandardCharsets.UTF_8); // Test it assertThat("Error when appending", hadError.get(), is(false)); @@ -169,9 +172,10 @@ public void testReadByteNoBuffersClosed() throws IOException { // wait for a little bit before sending input closed TimeUnit.MILLISECONDS.sleep(1000); + stream.appendFrame(null, true); stream.messageComplete(); } - catch (InterruptedException e) + catch (InterruptedException | IOException e) { hadError.set(true); e.printStackTrace(System.err); @@ -206,9 +210,10 @@ public void testSplitMessageWithEmptyPayloads() throws IOException session.provideContent(); // Read entire message it from the stream. - byte[] buf = new byte[32]; - int len = stream.read(buf); - String message = new String(buf, 0, len, StandardCharsets.UTF_8); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + IO.copy(stream, out); + byte[] bytes = out.toByteArray(); + String message = new String(bytes, 0, bytes.length, StandardCharsets.UTF_8); // Test it assertThat("Message", message, is("Hello World!"));