Skip to content

Commit

Permalink
Merge pull request #5377 from eclipse/jetty-9.4.x-5368-WebSocketInput…
Browse files Browse the repository at this point in the history
…Stream

Issue #5368 - ensure onMessage exits before next frame is read
  • Loading branch information
lachlan-roberts committed Oct 16, 2020
2 parents 9ad6beb + be041d3 commit f99b4ca
Show file tree
Hide file tree
Showing 8 changed files with 251 additions and 93 deletions.
Expand Up @@ -118,10 +118,8 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
if (LOG.isDebugEnabled())
LOG.debug("Binary Message InputStream");

final MessageInputStream stream = new MessageInputStream(session);
MessageInputStream stream = new MessageInputStream(session);
activeMessage = stream;

// Always dispatch streaming read to another thread.
dispatch(() ->
{
try
Expand All @@ -133,7 +131,7 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
session.close(e);
}

stream.close();
stream.handlerComplete();
});
}
}
Expand Down Expand Up @@ -329,11 +327,8 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
if (LOG.isDebugEnabled())
LOG.debug("Text Message Writer");

MessageInputStream inputStream = new MessageInputStream(session);
final MessageReader reader = new MessageReader(inputStream);
activeMessage = inputStream;

// Always dispatch streaming read to another thread.
MessageReader reader = new MessageReader(session);
activeMessage = reader;
dispatch(() ->
{
try
Expand All @@ -343,9 +338,10 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
catch (Throwable e)
{
session.close(e);
return;
}

inputStream.close();
reader.handlerComplete();
});
}
}
Expand Down
Expand Up @@ -100,9 +100,10 @@ else if (wrapper.wantsStreams())
catch (Throwable t)
{
session.close(t);
return;
}

inputStream.close();
inputStream.handlerComplete();
});
}
else
Expand Down Expand Up @@ -197,8 +198,7 @@ else if (wrapper.wantsStreams())
{
@SuppressWarnings("unchecked")
MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler();
MessageInputStream inputStream = new MessageInputStream(session);
MessageReader reader = new MessageReader(inputStream);
MessageReader reader = new MessageReader(session);
activeMessage = reader;
dispatch(() ->
{
Expand All @@ -209,9 +209,10 @@ else if (wrapper.wantsStreams())
catch (Throwable t)
{
session.close(t);
return;
}

inputStream.close();
reader.handlerComplete();
});
}
else
Expand Down
Expand Up @@ -27,7 +27,9 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpoint;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
Expand All @@ -37,11 +39,15 @@
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.websocket.jsr356.server.deploy.WebSocketServerContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand Down Expand Up @@ -127,13 +133,100 @@ public void testMoreThanLargestMessageOneByteAtATime() throws Exception
assertArrayEquals(data, client.getEcho());
}

@Test
public void testNotReadingToEndOfStream() throws Exception
{
int size = 32;
byte[] data = randomBytes(size);
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + PATH);

CountDownLatch handlerComplete = new CountDownLatch(1);
BasicClientBinaryStreamer client = new BasicClientBinaryStreamer((session, inputStream) ->
{
byte[] recv = new byte[16];
int read = inputStream.read(recv);
assertThat(read, not(is(0)));
handlerComplete.countDown();
});

Session session = wsClient.connectToServer(client, uri);
session.getBasicRemote().sendBinary(BufferUtil.toBuffer(data));
assertTrue(handlerComplete.await(5, TimeUnit.SECONDS));

session.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "close from test"));
assertTrue(client.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(client.closeReason.getCloseCode(), is(CloseReason.CloseCodes.NORMAL_CLOSURE));
assertThat(client.closeReason.getReasonPhrase(), is("close from test"));
}

@Test
public void testClosingBeforeReadingToEndOfStream() throws Exception
{
int size = 32;
byte[] data = randomBytes(size);
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + PATH);

CountDownLatch handlerComplete = new CountDownLatch(1);
BasicClientBinaryStreamer client = new BasicClientBinaryStreamer((session, inputStream) ->
{
byte[] recv = new byte[16];
int read = inputStream.read(recv);
assertThat(read, not(is(0)));

inputStream.close();
read = inputStream.read(recv);
assertThat(read, is(-1));
handlerComplete.countDown();
});

Session session = wsClient.connectToServer(client, uri);
session.getBasicRemote().sendBinary(BufferUtil.toBuffer(data));
assertTrue(handlerComplete.await(5, TimeUnit.SECONDS));

session.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "close from test"));
assertTrue(client.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(client.closeReason.getCloseCode(), is(CloseReason.CloseCodes.NORMAL_CLOSURE));
assertThat(client.closeReason.getReasonPhrase(), is("close from test"));
}

private byte[] randomBytes(int size)
{
byte[] data = new byte[size];
new Random().nextBytes(data);
return data;
}

@ClientEndpoint
public static class BasicClientBinaryStreamer
{
public interface MessageHandler
{
void accept(Session session, InputStream inputStream) throws Exception;
}

private final MessageHandler handler;
private final CountDownLatch closeLatch = new CountDownLatch(1);
private CloseReason closeReason;

public BasicClientBinaryStreamer(MessageHandler consumer)
{
this.handler = consumer;
}

@OnMessage
public void echoed(Session session, InputStream input) throws Exception
{
handler.accept(session, input);
}

@OnClose
public void onClosed(CloseReason closeReason)
{
this.closeReason = closeReason;
closeLatch.countDown();
}
}

@ClientEndpoint
public static class ClientBinaryStreamer
{
Expand Down
Expand Up @@ -32,7 +32,6 @@
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.common.CloseInfo;
import org.eclipse.jetty.websocket.common.message.MessageAppender;
import org.eclipse.jetty.websocket.common.message.MessageInputStream;
import org.eclipse.jetty.websocket.common.message.MessageReader;
import org.eclipse.jetty.websocket.common.message.NullMessage;
Expand Down Expand Up @@ -105,7 +104,7 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
}
else if (events.onBinary.isStreaming())
{
final MessageInputStream inputStream = new MessageInputStream(session);
MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = inputStream;
dispatch(() ->
{
Expand All @@ -115,11 +114,11 @@ else if (events.onBinary.isStreaming())
}
catch (Throwable t)
{
// dispatched calls need to be reported
session.close(t);
return;
}

inputStream.close();
inputStream.handlerComplete();
});
}
else
Expand Down Expand Up @@ -262,22 +261,21 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
}
else if (events.onText.isStreaming())
{
MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = new MessageReader(inputStream);
final MessageAppender msg = activeMessage;
MessageReader reader = new MessageReader(session);
activeMessage = reader;
dispatch(() ->
{
try
{
events.onText.call(websocket, session, msg);
events.onText.call(websocket, session, reader);
}
catch (Throwable t)
{
// dispatched calls need to be reported
session.close(t);
return;
}

inputStream.close();
reader.handlerComplete();
});
}
else
Expand Down
Expand Up @@ -521,7 +521,7 @@ public void resume()
{
ByteBuffer resume = readState.resume();
if (resume != null)
onFillable(resume);
getExecutor().execute(() -> onFillable(resume));
}

@Override
Expand Down

0 comments on commit f99b4ca

Please sign in to comment.