Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #5170 - fix upgrade bug in HttpReceiverOverHTTP #5266

Merged
merged 4 commits into from
Sep 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class HttpReceiverOverHTTP extends HttpReceiver implements HttpParser.Res
private boolean shutdown;
private boolean complete;
private boolean unsolicited;
private int status;

public HttpReceiverOverHTTP(HttpChannelOverHTTP channel)
{
Expand Down Expand Up @@ -132,17 +133,18 @@ private void releaseNetworkBuffer()

protected ByteBuffer onUpgradeFrom()
{
ByteBuffer upgradeBuffer = null;
if (networkBuffer.hasRemaining())
{
HttpClient client = getHttpDestination().getHttpClient();
ByteBuffer upgradeBuffer = BufferUtil.allocate(networkBuffer.remaining(), client.isUseInputDirectByteBuffers());
upgradeBuffer = BufferUtil.allocate(networkBuffer.remaining(), client.isUseInputDirectByteBuffers());
BufferUtil.clearToFill(upgradeBuffer);
BufferUtil.put(networkBuffer.getBuffer(), upgradeBuffer);
BufferUtil.flipToFlush(upgradeBuffer, 0);
return upgradeBuffer;
}

releaseNetworkBuffer();
return null;
return upgradeBuffer;
}

private void process()
Expand Down Expand Up @@ -230,15 +232,19 @@ private boolean parse()
if (LOG.isDebugEnabled())
LOG.debug("Parse complete={}, remaining {} {}", complete, networkBuffer.remaining(), parser);

if (complete)
{
int status = this.status;
this.status = 0;
if (status == HttpStatus.SWITCHING_PROTOCOLS_101)
return true;
}

if (networkBuffer.isEmpty())
return false;

if (complete)
{
HttpExchange httpExchange = getHttpExchange();
if (httpExchange != null && httpExchange.getResponse().getStatus() == HttpStatus.SWITCHING_PROTOCOLS_101)
return true;

if (LOG.isDebugEnabled())
LOG.debug("Discarding unexpected content after response: {}", networkBuffer);
networkBuffer.clear();
Expand Down Expand Up @@ -281,6 +287,7 @@ public void startResponse(HttpVersion version, int status, String reason)
if (exchange == null)
return;

this.status = status;
String method = exchange.getRequest().getMethod();
parser.setHeadResponse(HttpMethod.HEAD.is(method) ||
(HttpMethod.CONNECT.is(method) && status == HttpStatus.OK_200));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ protected Action process()
protected void onCompleteFailure(Throwable t)
{
if (log.isDebugEnabled())
log.debug("failed to flush", t);
log.debug("onCompleteFailure {}", t.toString());

notifyCallbackFailure(current.callback, t);
current = null;
Expand All @@ -157,14 +157,14 @@ private void notifyCallbackSuccess(Callback callback)
}
catch (Throwable x)
{
log.warn("Exception while notifying success of callback " + callback, x);
log.warn("Exception while notifying success of callback {}", callback, x);
}
}

private void notifyCallbackFailure(Callback callback, Throwable failure)
{
if (log.isDebugEnabled())
log.debug("notifyCallbackFailure {} {}", callback, failure);
log.debug("notifyCallbackFailure {} {}", callback, failure.toString());

try
{
Expand All @@ -173,7 +173,7 @@ private void notifyCallbackFailure(Callback callback, Throwable failure)
}
catch (Throwable x)
{
log.warn("Exception while notifying failure of callback " + callback, x);
log.warn("Exception while notifying failure of callback {}", callback, x);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ public void succeeded()
public void failed(Throwable cause)
{
if (LOG.isDebugEnabled())
LOG.debug("failed onFrame(" + frame + ")", cause);
LOG.debug("failed onFrame({}) {}", frame, cause.toString());

frame.close();
if (referenced != null)
Expand Down Expand Up @@ -470,7 +470,7 @@ private void fillAndParse()
catch (Throwable t)
{
if (LOG.isDebugEnabled())
LOG.debug("Error during fillAndParse()", t);
LOG.debug("Error during fillAndParse() {}", t.toString());

if (networkBuffer != null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.net.SocketAddress;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.WritePendingException;
import java.time.Duration;
import java.util.List;
Expand Down Expand Up @@ -390,7 +391,7 @@ public void closeConnection(CloseStatus closeStatus, Callback callback)
public void processConnectionError(Throwable cause, Callback callback)
{
if (LOG.isDebugEnabled())
LOG.debug("processConnectionError {} {}", this, cause);
LOG.debug("processConnectionError {}", this, cause);

int code;
if (cause instanceof CloseException)
Expand Down Expand Up @@ -424,11 +425,13 @@ else if (cause instanceof WebSocketTimeoutException || cause instanceof TimeoutE
public void processHandlerError(Throwable cause, Callback callback)
{
if (LOG.isDebugEnabled())
LOG.debug("processHandlerError {} {}", this, cause);
LOG.debug("processHandlerError {}", this, cause);

int code;
if (cause instanceof CloseException)
code = ((CloseException)cause).getStatusCode();
else if (cause instanceof ClosedChannelException)
code = CloseStatus.NO_CLOSE;
else if (cause instanceof Utf8Appendable.NotUtf8Exception)
code = CloseStatus.BAD_PAYLOAD;
else if (cause instanceof WebSocketTimeoutException || cause instanceof TimeoutException || cause instanceof SocketTimeoutException)
Expand All @@ -438,7 +441,14 @@ else if (behavior == Behavior.CLIENT)
else
code = CloseStatus.SERVER_ERROR;

close(new CloseStatus(code, cause), callback);
CloseStatus closeStatus = new CloseStatus(code, cause);
if (CloseStatus.isTransmittableStatusCode(code))
close(closeStatus, callback);
else
{
if (sessionState.onClosed(closeStatus))
closeConnection(closeStatus, callback);
}
}

/**
Expand All @@ -458,10 +468,10 @@ public void onOpen()
() ->
{
sessionState.onOpen();
if (!demanding)
connection.demand(1);
if (LOG.isDebugEnabled())
LOG.debug("ConnectionState: Transition to OPEN");
if (!demanding)
connection.demand(1);
},
x ->
{
Expand Down Expand Up @@ -544,9 +554,7 @@ public void sendFrame(Frame frame, Callback callback, boolean batch)
}
catch (Throwable t)
{
if (LOG.isDebugEnabled())
LOG.warn("Invalid outgoing frame: {}", frame, t);

LOG.warn("Invalid outgoing frame: {}", frame, t);
callback.failed(t);
return;
}
Expand Down Expand Up @@ -574,7 +582,7 @@ public void sendFrame(Frame frame, Callback callback, boolean batch)
catch (Throwable t)
{
if (LOG.isDebugEnabled())
LOG.debug("Failed sendFrame()", t);
LOG.debug("Failed sendFrame() {}", t.toString());

if (frame.getOpCode() == OpCode.CLOSE)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,15 @@ else if (frame.isDataFrame())
return false;
}

public boolean onIncomingFrame(Frame frame) throws ProtocolException
public boolean onIncomingFrame(Frame frame) throws ProtocolException, ClosedChannelException
{
byte opcode = frame.getOpCode();
boolean fin = frame.isFin();

try (AutoLock l = lock.lock())
{
if (!isInputOpen())
throw new IllegalStateException(_sessionState.toString());
throw new ClosedChannelException();

if (opcode == OpCode.CLOSE)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,14 @@ public class TestMessageHandler extends MessageHandler
@Override
public void onOpen(CoreSession coreSession, Callback callback)
{
if (LOG.isDebugEnabled())
LOG.debug("onOpen {}", coreSession);
this.coreSession = coreSession;
super.onOpen(coreSession, callback);
this.coreSession = coreSession;
openLatch.countDown();
}

@Override
public void onFrame(Frame frame, Callback callback)
{
if (LOG.isDebugEnabled())
LOG.debug("onFrame {}", frame);
super.onFrame(frame, callback);
}

@Override
public void onError(Throwable cause, Callback callback)
{
if (LOG.isDebugEnabled())
LOG.debug("onError", cause);
super.onError(cause, callback);
error = cause;
errorLatch.countDown();
Expand All @@ -71,8 +59,6 @@ public void onError(Throwable cause, Callback callback)
@Override
public void onClosed(CloseStatus closeStatus, Callback callback)
{
if (LOG.isDebugEnabled())
LOG.debug("onClosed {}", closeStatus);
super.onClosed(closeStatus, callback);
this.closeStatus = closeStatus;
closeLatch.countDown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,22 @@
package org.eclipse.jetty.websocket.core;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Scanner;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.websocket.core.client.CoreClientUpgradeRequest;
import org.eclipse.jetty.websocket.core.client.WebSocketCoreClient;
import org.eclipse.jetty.websocket.core.internal.Generator;
import org.eclipse.jetty.websocket.core.internal.WebSocketCore;
Expand All @@ -43,8 +45,8 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class UpgradeWithLeftOverHttpBytesTest extends WebSocketTester
Expand Down Expand Up @@ -74,10 +76,21 @@ public void stop() throws Exception
@Test
public void testUpgradeWithLeftOverHttpBytes() throws Exception
{
TestMessageHandler clientEndpoint = new TestMessageHandler();
CompletableFuture<CoreSession> clientConnect = client.connect(clientEndpoint, serverUri);
CountDownLatch onOpenWait = new CountDownLatch(1);
TestMessageHandler clientEndpoint = new TestMessageHandler()
{
@Override
public void onOpen(CoreSession coreSession, Callback callback)
{
assertDoesNotThrow(() -> onOpenWait.await(5, TimeUnit.SECONDS));
super.onOpen(coreSession, callback);
}
};
CoreClientUpgradeRequest coreUpgrade = CoreClientUpgradeRequest.from(client, serverUri, clientEndpoint);
client.connect(coreUpgrade);
Socket serverSocket = server.accept();

// Receive the upgrade request with the Socket.
String upgradeRequest = getRequestHeaders(serverSocket.getInputStream());
assertThat(upgradeRequest, containsString("HTTP/1.1"));
assertThat(upgradeRequest, containsString("Upgrade: websocket"));
Expand All @@ -88,21 +101,34 @@ public void testUpgradeWithLeftOverHttpBytes() throws Exception
"Connection: Upgrade\n" +
"Sec-WebSocket-Accept: " + getAcceptKey(upgradeRequest) + "\n" +
"\n";

Frame dataFrame = new Frame(OpCode.TEXT, BufferUtil.toBuffer("first message payload"));
Frame firstFrame = new Frame(OpCode.TEXT, BufferUtil.toBuffer("first message payload"));
byte[] bytes = combineToByteArray(BufferUtil.toBuffer(upgradeResponse), generateFrame(firstFrame));
serverSocket.getOutputStream().write(bytes);

// Now we send the rest of the data.
int numFrames = 1000;
for (int i = 0; i < numFrames; i++)
{
Frame frame = new Frame(OpCode.TEXT, BufferUtil.toBuffer(Integer.toString(i)));
serverSocket.getOutputStream().write(toByteArray(frame));
}
Frame closeFrame = new CloseStatus(CloseStatus.NORMAL, "closed by test").toFrame();
serverSocket.getOutputStream().write(toByteArray(closeFrame));

ByteArrayOutputStream baos = new ByteArrayOutputStream();
baos.write(upgradeResponse.getBytes(StandardCharsets.ISO_8859_1));
BufferUtil.writeTo(generateFrame(dataFrame), baos);
BufferUtil.writeTo(generateFrame(closeFrame), baos);
serverSocket.getOutputStream().write(baos.toByteArray());

// Check the client receives upgrade response and then the two websocket frames.
CoreSession coreSession = clientConnect.get(5, TimeUnit.SECONDS);
assertNotNull(coreSession);
// First payload sent with upgrade request, delay to ensure HttpConnection is not still reading from network.
Thread.sleep(1000);
onOpenWait.countDown();
assertTrue(clientEndpoint.openLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.textMessages.poll(5, TimeUnit.SECONDS), is("first message payload"));

// We receive the rest of the frames all sent as separate writes.
for (int i = 0; i < numFrames; i++)
{
String msg = clientEndpoint.textMessages.poll(5, TimeUnit.SECONDS);
assertThat(msg, is(Integer.toString(i)));
}

// Closed successfully with correct status.
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.closeStatus.getCode(), is(CloseStatus.NORMAL));
assertThat(clientEndpoint.closeStatus.getReason(), is("closed by test"));
Expand Down Expand Up @@ -131,4 +157,20 @@ static String getRequestHeaders(InputStream is)
Scanner s = new Scanner(is).useDelimiter("\r\n\r\n");
return s.hasNext() ? s.next() : "";
}

byte[] combineToByteArray(ByteBuffer... buffers) throws IOException
{
ByteArrayOutputStream baos = new ByteArrayOutputStream();
for (ByteBuffer bb : buffers)
{
BufferUtil.writeTo(bb, baos);
}

return baos.toByteArray();
}

byte[] toByteArray(Frame frame)
{
return BufferUtil.toArray(generateFrame(frame));
}
}