Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #4824 - add configuration on RemoteEndpoint for maxOutgoingFrames #5059

Merged
merged 7 commits into from Sep 10, 2020
@@ -0,0 +1,187 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//

package org.eclipse.jetty.websocket.tests;

import java.net.URI;
import java.nio.channels.WritePendingException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.websocket.api.BatchMode;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.common.extensions.AbstractExtension;
import org.eclipse.jetty.websocket.common.io.FutureWriteCallback;
import org.eclipse.jetty.websocket.server.NativeWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.server.WebSocketUpgradeFilter;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class MaxOutgoingFramesTest
{
public static CountDownLatch outgoingBlocked;
public static CountDownLatch firstFrameBlocked;

private final EventSocket serverSocket = new EventSocket();
private Server server;
private ServerConnector connector;
private WebSocketClient client;

@BeforeEach
public void start() throws Exception
{
outgoingBlocked = new CountDownLatch(1);
firstFrameBlocked = new CountDownLatch(1);

server = new Server();
connector = new ServerConnector(server);
server.addConnector(connector);

ServletContextHandler contextHandler = new ServletContextHandler(ServletContextHandler.SESSIONS);
contextHandler.setContextPath("/");
NativeWebSocketServletContainerInitializer.configure(contextHandler, (context, container) ->
{
container.addMapping("/", (req, resp) -> serverSocket);
container.getFactory().getExtensionFactory().register(BlockingOutgoingExtension.class.getName(), BlockingOutgoingExtension.class);
});

WebSocketUpgradeFilter.configure(contextHandler);
server.setHandler(contextHandler);

client = new WebSocketClient();
server.start();
client.start();
}

@AfterEach
public void stop() throws Exception
{
outgoingBlocked.countDown();
server.stop();
client.stop();
}

public static class BlockingOutgoingExtension extends AbstractExtension
{
@Override
public String getName()
{
return BlockingOutgoingExtension.class.getName();
}

@Override
public void incomingFrame(Frame frame)
{
getNextIncoming().incomingFrame(frame);
}

@Override
public void outgoingFrame(Frame frame, WriteCallback callback, BatchMode batchMode)
{
try
{
firstFrameBlocked.countDown();
outgoingBlocked.await();
getNextOutgoing().outgoingFrame(frame, callback, batchMode);
}
catch (InterruptedException e)
{
throw new RuntimeException(e);
}
}
}

public static class CountingCallback implements WriteCallback
{
private final CountDownLatch successes;

public CountingCallback(int count)
{
successes = new CountDownLatch(count);
}

@Override
public void writeSuccess()
{
successes.countDown();
}

@Override
public void writeFailed(Throwable t)
{
t.printStackTrace();
}
}

@Test
public void testMaxOutgoingFrames() throws Exception
{
// We need to have the frames queued but not yet sent, we do this by blocking in the ExtensionStack.
client.getExtensionFactory().register(BlockingOutgoingExtension.class.getName(), BlockingOutgoingExtension.class);

URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/");
EventSocket socket = new EventSocket();
ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.addExtensions(BlockingOutgoingExtension.class.getName());
client.connect(socket, uri, upgradeRequest).get(5, TimeUnit.SECONDS);
assertTrue(socket.openLatch.await(5, TimeUnit.SECONDS));

int numFrames = 30;
RemoteEndpoint remote = socket.session.getRemote();
remote.setMaxOutgoingFrames(numFrames);

// Verify that we can send up to numFrames without any problem.
// First send will block in the Extension so it needs to be done in new thread, others frames will be queued.
CountingCallback countingCallback = new CountingCallback(numFrames);
new Thread(() -> remote.sendString("0", countingCallback)).start();
assertTrue(firstFrameBlocked.await(5, TimeUnit.SECONDS));
for (int i = 1; i < numFrames; i++)
{
remote.sendString(Integer.toString(i), countingCallback);
}

// Sending any more frames will result in WritePendingException.
FutureWriteCallback callback = new FutureWriteCallback();
remote.sendString("fail", callback);
ExecutionException executionException = assertThrows(ExecutionException.class, () -> callback.get(5, TimeUnit.SECONDS));
assertThat(executionException.getCause(), instanceOf(WritePendingException.class));

// Check that all callbacks are succeeded when the server processes the frames.
outgoingBlocked.countDown();
assertTrue(countingCallback.successes.await(5, TimeUnit.SECONDS));

// Close successfully.
socket.session.close();
assertTrue(serverSocket.closeLatch.await(5, TimeUnit.SECONDS));
assertTrue(socket.closeLatch.await(5, TimeUnit.SECONDS));
}
}
Expand Up @@ -141,6 +141,28 @@ public interface RemoteEndpoint
*/
void setBatchMode(BatchMode mode);

/**
* Get the maximum number of data frames allowed to be waiting to be sent at any one time.
* The default value is -1, this indicates there is no limit on how many frames can be
* queued to be sent by the implementation. If the limit is exceeded, subsequent frames
* sent are failed with a {@link java.nio.channels.WritePendingException} but
* the connection is not failed and will remain open.
*
* @return the max number of frames.
*/
int getMaxOutgoingFrames();

/**
* Set the maximum number of data frames allowed to be waiting to be sent at any one time.
* The default value is -1, this indicates there is no limit on how many frames can be
* queued to be sent by the implementation. If the limit is exceeded, subsequent frames
* sent are failed with a {@link java.nio.channels.WritePendingException} but
* the connection is not failed and will remain open.
*
* @param maxOutgoingFrames the max number of frames.
*/
void setMaxOutgoingFrames(int maxOutgoingFrames);

/**
* Get the InetSocketAddress for the established connection.
*
Expand Down
Expand Up @@ -21,6 +21,7 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.WritePendingException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -81,7 +82,9 @@ public void writeFailed(Throwable x)
private final OutgoingFrames outgoing;
private final AtomicInteger msgState = new AtomicInteger();
private final BlockingWriteCallback blocker = new BlockingWriteCallback();
private final AtomicInteger numOutgoingFrames = new AtomicInteger();
private volatile BatchMode batchMode;
private int maxNumOutgoingFrames = -1;

public WebSocketRemoteEndpoint(LogicalConnection connection, OutgoingFrames outgoing)
{
Expand Down Expand Up @@ -303,6 +306,19 @@ public void uncheckedSendFrame(WebSocketFrame frame, WriteCallback callback)
BatchMode batchMode = BatchMode.OFF;
if (frame.isDataFrame())
batchMode = getBatchMode();

if (maxNumOutgoingFrames > 0 && frame.isDataFrame())
{
// Increase the number of outgoing frames, will be decremented when callback is completed.
int outgoingFrames = numOutgoingFrames.incrementAndGet();
callback = from(callback, numOutgoingFrames::decrementAndGet);
if (outgoingFrames > maxNumOutgoingFrames)
{
callback.writeFailed(new WritePendingException());
return;
}
}

outgoing.outgoingFrame(frame, callback, batchMode);
}

Expand Down Expand Up @@ -439,6 +455,18 @@ public void setBatchMode(BatchMode batchMode)
this.batchMode = batchMode;
}

@Override
public int getMaxOutgoingFrames()
{
return maxNumOutgoingFrames;
}

@Override
public void setMaxOutgoingFrames(int maxOutgoingFrames)
{
this.maxNumOutgoingFrames = maxOutgoingFrames;
}

@Override
public void flush() throws IOException
{
Expand All @@ -459,4 +487,36 @@ public String toString()
{
return String.format("%s@%x[batching=%b]", getClass().getSimpleName(), hashCode(), getBatchMode());
}

private static WriteCallback from(WriteCallback callback, Runnable completed)
{
return new WriteCallback()
{
@Override
public void writeFailed(Throwable x)
{
try
{
callback.writeFailed(x);
}
finally
{
completed.run();
}
}

@Override
public void writeSuccess()
{
try
{
callback.writeSuccess();
}
finally
{
completed.run();
}
}
};
}
}