Skip to content

Commit

Permalink
Issue #5368 - ensure onMessage exits before next frame is read
Browse files Browse the repository at this point in the history
Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
  • Loading branch information
lachlan-roberts committed Oct 1, 2020
1 parent e3ed05f commit 941ffce
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 67 deletions.
Expand Up @@ -118,10 +118,8 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
if (LOG.isDebugEnabled())
LOG.debug("Binary Message InputStream");

final MessageInputStream stream = new MessageInputStream(session);
MessageInputStream stream = new MessageInputStream(session);
activeMessage = stream;

// Always dispatch streaming read to another thread.
dispatch(() ->
{
try
Expand Down Expand Up @@ -329,11 +327,8 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
if (LOG.isDebugEnabled())
LOG.debug("Text Message Writer");

MessageInputStream inputStream = new MessageInputStream(session);
final MessageReader reader = new MessageReader(inputStream);
activeMessage = inputStream;

// Always dispatch streaming read to another thread.
MessageReader reader = new MessageReader(session);
activeMessage = reader;
dispatch(() ->
{
try
Expand All @@ -343,9 +338,10 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
catch (Throwable e)
{
session.close(e);
return;
}

inputStream.close();
reader.handlerComplete();
});
}
}
Expand Down
Expand Up @@ -100,9 +100,10 @@ else if (wrapper.wantsStreams())
catch (Throwable t)
{
session.close(t);
return;
}

inputStream.close();
inputStream.handlerComplete();
});
}
else
Expand Down Expand Up @@ -197,8 +198,7 @@ else if (wrapper.wantsStreams())
{
@SuppressWarnings("unchecked")
MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler();
MessageInputStream inputStream = new MessageInputStream(session);
MessageReader reader = new MessageReader(inputStream);
MessageReader reader = new MessageReader(session);
activeMessage = reader;
dispatch(() ->
{
Expand All @@ -209,9 +209,10 @@ else if (wrapper.wantsStreams())
catch (Throwable t)
{
session.close(t);
return;
}

inputStream.close();
reader.handlerComplete();
});
}
else
Expand Down
Expand Up @@ -32,7 +32,6 @@
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.common.CloseInfo;
import org.eclipse.jetty.websocket.common.message.MessageAppender;
import org.eclipse.jetty.websocket.common.message.MessageInputStream;
import org.eclipse.jetty.websocket.common.message.MessageReader;
import org.eclipse.jetty.websocket.common.message.NullMessage;
Expand Down Expand Up @@ -105,7 +104,7 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
}
else if (events.onBinary.isStreaming())
{
final MessageInputStream inputStream = new MessageInputStream(session);
MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = inputStream;
dispatch(() ->
{
Expand All @@ -115,11 +114,11 @@ else if (events.onBinary.isStreaming())
}
catch (Throwable t)
{
// dispatched calls need to be reported
session.close(t);
return;
}

inputStream.close();
inputStream.handlerComplete();
});
}
else
Expand Down Expand Up @@ -262,22 +261,21 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
}
else if (events.onText.isStreaming())
{
MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = new MessageReader(inputStream);
final MessageAppender msg = activeMessage;
MessageReader reader = new MessageReader(session);
activeMessage = reader;
dispatch(() ->
{
try
{
events.onText.call(websocket, session, msg);
events.onText.call(websocket, session, reader);
}
catch (Throwable t)
{
// dispatched calls need to be reported
session.close(t);
return;
}

inputStream.close();
reader.handlerComplete();
});
}
else
Expand Down
Expand Up @@ -55,6 +55,7 @@ private enum State
{
RESUMED,
SUSPENDED,
COMPLETE,
CLOSED
}

Expand All @@ -76,23 +77,11 @@ public void appendFrame(ByteBuffer framePayload, boolean fin) throws IOException
if (LOG.isDebugEnabled())
LOG.debug("Appending {} chunk: {}", fin ? "final" : "non-final", BufferUtil.toDetailString(framePayload));

// Early non atomic test that we aren't closed to avoid an unnecessary copy (will be checked again later).
if (state == State.CLOSED)
return;

// Put the payload into the queue, by copying it.
// Copying is necessary because the payload will
// be processed after this method returns.
try
{
if (framePayload == null || !framePayload.hasRemaining())
if (BufferUtil.isEmpty(framePayload))
return;

ByteBuffer copy = acquire(framePayload.remaining(), framePayload.isDirect());
BufferUtil.clearToFill(copy);
copy.put(framePayload);
BufferUtil.flipToFlush(copy, 0);

synchronized (this)
{
switch (state)
Expand All @@ -105,11 +94,14 @@ public void appendFrame(ByteBuffer framePayload, boolean fin) throws IOException
state = State.SUSPENDED;
break;

case SUSPENDED:
default:
throw new IllegalStateException();
}

buffers.put(copy);
// Put the payload into the queue, by copying it.
// Copying is necessary because the payload will
// be processed after this method returns.
buffers.put(copy(framePayload));
}
}
catch (InterruptedException e)
Expand All @@ -121,7 +113,23 @@ public void appendFrame(ByteBuffer framePayload, boolean fin) throws IOException
@Override
public void close()
{
SuspendToken resume = null;
synchronized (this)
{
if (state == State.CLOSED)
return;

state = State.CLOSED;
buffers.clear();
buffers.offer(EOF);
}
}

@Override
public void messageComplete()
{
if (LOG.isDebugEnabled())
LOG.debug("Message completed");

synchronized (this)
{
switch (state)
Expand All @@ -130,43 +138,33 @@ public void close()
return;

case SUSPENDED:
resume = suspendToken;
suspendToken = null;
state = State.CLOSED;
break;

case RESUMED:
state = State.CLOSED;
state = State.COMPLETE;
break;

default:
throw new IllegalStateException();
}

buffers.clear();
buffers.offer(EOF);
}

// May need to resume to discard until we reach next message.
if (resume != null)
resume.resume();
}

@Override
public void mark(int readlimit)
{
// Not supported.
}

@Override
public boolean markSupported()
public void handlerComplete()
{
return false;
}
// May need to resume to resume and read to the next message.
SuspendToken resume;
synchronized (this)
{
state = State.CLOSED;
resume = suspendToken;
suspendToken = null;
buffers.clear();
buffers.offer(EOF);
}

@Override
public void messageComplete()
{
if (LOG.isDebugEnabled())
LOG.debug("Message completed");
buffers.offer(EOF);
if (resume != null)
resume.resume();
}

@Override
Expand All @@ -186,6 +184,7 @@ public int read() throws IOException
{
if (LOG.isDebugEnabled())
LOG.debug("Waiting {} ms to read", timeoutMs);

if (timeoutMs < 0)
{
// Wait forever until a buffer is available.
Expand All @@ -212,7 +211,6 @@ public int read() throws IOException
int result = activeBuffer.get() & 0xFF;
if (!activeBuffer.hasRemaining())
{

SuspendToken resume = null;
synchronized (this)
{
Expand All @@ -221,6 +219,11 @@ public int read() throws IOException
case CLOSED:
return -1;

case COMPLETE:
// If we are complete we have read the last frame but
// don't want to resume reading until onMessage() exits.
break;

case SUSPENDED:
resume = suspendToken;
suspendToken = null;
Expand Down Expand Up @@ -254,6 +257,27 @@ public void reset() throws IOException
throw new IOException("reset() not supported");
}

@Override
public void mark(int readlimit)
{
// Not supported.
}

@Override
public boolean markSupported()
{
return false;
}

private ByteBuffer copy(ByteBuffer buffer)
{
ByteBuffer copy = acquire(buffer.remaining(), buffer.isDirect());
BufferUtil.clearToFill(copy);
copy.put(buffer);
BufferUtil.flipToFlush(copy, 0);
return copy;
}

private ByteBuffer acquire(int capacity, boolean direct)
{
ByteBuffer buffer;
Expand Down
Expand Up @@ -24,6 +24,8 @@
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

import org.eclipse.jetty.websocket.api.Session;

/**
* Support class for reading a (single) WebSocket TEXT message via a Reader.
* <p>
Expand All @@ -33,6 +35,11 @@ public class MessageReader extends InputStreamReader implements MessageAppender
{
private final MessageInputStream stream;

public MessageReader(Session session)
{
this(new MessageInputStream(session));
}

public MessageReader(MessageInputStream stream)
{
super(stream, StandardCharsets.UTF_8);
Expand All @@ -50,4 +57,9 @@ public void messageComplete()
{
this.stream.messageComplete();
}

public void handlerComplete()
{
this.stream.handlerComplete();
}
}

0 comments on commit 941ffce

Please sign in to comment.