Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #5368 - ensure onMessage exits before next frame is read #5377

Merged
merged 8 commits into from Oct 16, 2020
Expand Up @@ -118,10 +118,8 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
if (LOG.isDebugEnabled())
LOG.debug("Binary Message InputStream");

final MessageInputStream stream = new MessageInputStream(session);
MessageInputStream stream = new MessageInputStream(session);
activeMessage = stream;

// Always dispatch streaming read to another thread.
dispatch(() ->
{
try
Expand All @@ -133,7 +131,7 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
session.close(e);
}

stream.close();
stream.handlerComplete();
});
}
}
Expand Down Expand Up @@ -329,11 +327,8 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
if (LOG.isDebugEnabled())
LOG.debug("Text Message Writer");

MessageInputStream inputStream = new MessageInputStream(session);
final MessageReader reader = new MessageReader(inputStream);
activeMessage = inputStream;

// Always dispatch streaming read to another thread.
MessageReader reader = new MessageReader(session);
activeMessage = reader;
dispatch(() ->
{
try
Expand All @@ -343,9 +338,10 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
catch (Throwable e)
{
session.close(e);
return;
joakime marked this conversation as resolved.
Show resolved Hide resolved
}

inputStream.close();
reader.handlerComplete();
});
}
}
Expand Down
Expand Up @@ -100,9 +100,10 @@ else if (wrapper.wantsStreams())
catch (Throwable t)
{
session.close(t);
return;
}

inputStream.close();
inputStream.handlerComplete();
});
}
else
Expand Down Expand Up @@ -197,8 +198,7 @@ else if (wrapper.wantsStreams())
{
@SuppressWarnings("unchecked")
MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler();
MessageInputStream inputStream = new MessageInputStream(session);
MessageReader reader = new MessageReader(inputStream);
MessageReader reader = new MessageReader(session);
activeMessage = reader;
dispatch(() ->
{
Expand All @@ -209,9 +209,10 @@ else if (wrapper.wantsStreams())
catch (Throwable t)
{
session.close(t);
return;
joakime marked this conversation as resolved.
Show resolved Hide resolved
}

inputStream.close();
reader.handlerComplete();
});
}
else
Expand Down
Expand Up @@ -27,7 +27,9 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpoint;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
Expand All @@ -37,11 +39,15 @@
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.websocket.jsr356.server.deploy.WebSocketServerContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand Down Expand Up @@ -127,13 +133,100 @@ public void testMoreThanLargestMessageOneByteAtATime() throws Exception
assertArrayEquals(data, client.getEcho());
}

@Test
public void testNotReadingToEndOfStream() throws Exception
{
int size = 32;
byte[] data = randomBytes(size);
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + PATH);

CountDownLatch handlerComplete = new CountDownLatch(1);
BasicClientBinaryStreamer client = new BasicClientBinaryStreamer((session, inputStream) ->
{
byte[] recv = new byte[16];
int read = inputStream.read(recv);
assertThat(read, not(is(0)));
handlerComplete.countDown();
});

Session session = wsClient.connectToServer(client, uri);
session.getBasicRemote().sendBinary(BufferUtil.toBuffer(data));
assertTrue(handlerComplete.await(5, TimeUnit.SECONDS));

session.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "close from test"));
assertTrue(client.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(client.closeReason.getCloseCode(), is(CloseReason.CloseCodes.NORMAL_CLOSURE));
assertThat(client.closeReason.getReasonPhrase(), is("close from test"));
}

@Test
public void testClosingBeforeReadingToEndOfStream() throws Exception
{
int size = 32;
byte[] data = randomBytes(size);
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + PATH);

CountDownLatch handlerComplete = new CountDownLatch(1);
BasicClientBinaryStreamer client = new BasicClientBinaryStreamer((session, inputStream) ->
{
byte[] recv = new byte[16];
int read = inputStream.read(recv);
assertThat(read, not(is(0)));

inputStream.close();
read = inputStream.read(recv);
assertThat(read, is(-1));
handlerComplete.countDown();
});

Session session = wsClient.connectToServer(client, uri);
session.getBasicRemote().sendBinary(BufferUtil.toBuffer(data));
assertTrue(handlerComplete.await(5, TimeUnit.SECONDS));

session.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "close from test"));
assertTrue(client.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(client.closeReason.getCloseCode(), is(CloseReason.CloseCodes.NORMAL_CLOSURE));
assertThat(client.closeReason.getReasonPhrase(), is("close from test"));
}

private byte[] randomBytes(int size)
{
byte[] data = new byte[size];
new Random().nextBytes(data);
return data;
}

@ClientEndpoint
public static class BasicClientBinaryStreamer
{
public interface MessageHandler
{
void accept(Session session, InputStream inputStream) throws Exception;
}

private final MessageHandler handler;
private final CountDownLatch closeLatch = new CountDownLatch(1);
private CloseReason closeReason;

public BasicClientBinaryStreamer(MessageHandler consumer)
{
this.handler = consumer;
}

@OnMessage
public void echoed(Session session, InputStream input) throws Exception
{
handler.accept(session, input);
}

@OnClose
public void onClosed(CloseReason closeReason)
{
this.closeReason = closeReason;
closeLatch.countDown();
}
}

@ClientEndpoint
public static class ClientBinaryStreamer
{
Expand Down
Expand Up @@ -32,7 +32,6 @@
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.common.CloseInfo;
import org.eclipse.jetty.websocket.common.message.MessageAppender;
import org.eclipse.jetty.websocket.common.message.MessageInputStream;
import org.eclipse.jetty.websocket.common.message.MessageReader;
import org.eclipse.jetty.websocket.common.message.NullMessage;
Expand Down Expand Up @@ -105,7 +104,7 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
}
else if (events.onBinary.isStreaming())
{
final MessageInputStream inputStream = new MessageInputStream(session);
MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = inputStream;
dispatch(() ->
{
Expand All @@ -115,11 +114,11 @@ else if (events.onBinary.isStreaming())
}
catch (Throwable t)
{
// dispatched calls need to be reported
session.close(t);
return;
joakime marked this conversation as resolved.
Show resolved Hide resolved
}

inputStream.close();
inputStream.handlerComplete();
});
}
else
Expand Down Expand Up @@ -262,22 +261,21 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
}
else if (events.onText.isStreaming())
{
MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = new MessageReader(inputStream);
final MessageAppender msg = activeMessage;
MessageReader reader = new MessageReader(session);
activeMessage = reader;
dispatch(() ->
{
try
{
events.onText.call(websocket, session, msg);
events.onText.call(websocket, session, reader);
}
catch (Throwable t)
{
// dispatched calls need to be reported
session.close(t);
return;
joakime marked this conversation as resolved.
Show resolved Hide resolved
}

inputStream.close();
reader.handlerComplete();
});
}
else
Expand Down
Expand Up @@ -521,7 +521,7 @@ public void resume()
{
ByteBuffer resume = readState.resume();
if (resume != null)
onFillable(resume);
getExecutor().execute(() -> onFillable(resume));
joakime marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
Expand Down