Skip to content

Commit

Permalink
Merge pull request #1361 from fl4via/2.2.x_backport_bug_fixes
Browse files Browse the repository at this point in the history
2.2.x backport bug fixes
  • Loading branch information
fl4via committed Aug 11, 2022
2 parents 9a06b56 + 215316d commit e52cefb
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 22 deletions.
Expand Up @@ -18,11 +18,6 @@

package io.undertow.conduits;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;

import io.undertow.UndertowLogger;
import io.undertow.UndertowMessages;
import io.undertow.UndertowOptions;
Expand All @@ -41,6 +36,11 @@
import org.xnio.conduits.ReadReadyHandler;
import org.xnio.conduits.StreamSourceConduit;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;

/**
* Wrapper for read timeout. This should always be the first wrapper applied to the underlying channel.
*
Expand All @@ -49,7 +49,7 @@
*/
public final class ReadTimeoutStreamSourceConduit extends AbstractStreamSourceConduit<StreamSourceConduit> {

private XnioExecutor.Key handle;
private volatile XnioExecutor.Key handle;
private final StreamConnection connection;
private volatile long expireTime = -1;
private final OpenListener openListener;
Expand All @@ -60,14 +60,21 @@ public final class ReadTimeoutStreamSourceConduit extends AbstractStreamSourceCo
private final Runnable timeoutCommand = new Runnable() {
@Override
public void run() {
handle = null;
if (expireTime == -1) {
synchronized (ReadTimeoutStreamSourceConduit.this) {
handle = null;
}
if (expireTime == -1 || !connection.isOpen()) {
return;
}
long current = System.currentTimeMillis();
if (current < expireTime) {
//timeout has been bumped, re-schedule
handle = WorkerUtils.executeAfter(connection.getIoThread(),timeoutCommand, (expireTime - current) + FUZZ_FACTOR, TimeUnit.MILLISECONDS);
if (handle == null) {
synchronized (ReadTimeoutStreamSourceConduit.this) {
if (handle == null)
handle = WorkerUtils.executeAfter(connection.getIoThread(), timeoutCommand, (expireTime - current) + FUZZ_FACTOR, TimeUnit.MILLISECONDS);
}
}
return;
}
UndertowLogger.REQUEST_LOGGER.tracef("Timing out channel %s due to inactivity", connection.getSourceChannel());
Expand Down Expand Up @@ -131,12 +138,16 @@ private void handleReadTimeout(final long ret) throws IOException {
final long expireTimeVar = expireTime;
if (expireTimeVar != -1 && currentTime > expireTimeVar) {
IoUtils.safeClose(connection);
throw UndertowMessages.MESSAGES.readTimedOut(this.getTimeout());
throw UndertowMessages.MESSAGES.readTimedOut(currentTime - (expireTimeVar - this.getTimeout()));
}
}
expireTime = currentTime + timeout;
if (handle == null) {
handle = connection.getIoThread().executeAfter(timeoutCommand, timeout, TimeUnit.MILLISECONDS);
synchronized (this) {
if (handle == null)
handle = connection.getIoThread().executeAfter(timeoutCommand, timeout, TimeUnit.MILLISECONDS);
}

}
}

Expand Down Expand Up @@ -232,9 +243,13 @@ public void terminateReads() throws IOException {

private void cleanup() {
if (handle != null) {
handle.remove();
handle = null;
expireTime = -1;
synchronized (this) {
if (handle != null) {
handle.remove();
handle = null;
expireTime = -1;
}
}
}
}

Expand All @@ -247,7 +262,7 @@ public void suspendReads() {
private void checkExpired() throws ReadTimeoutException {
synchronized (this) {
if (expired) {
throw UndertowMessages.MESSAGES.readTimedOut(System.currentTimeMillis());
throw UndertowMessages.MESSAGES.readTimedOut(System.currentTimeMillis() - (expireTime - getTimeout()));
}
}
}
Expand Down
Expand Up @@ -24,6 +24,7 @@
import io.undertow.websockets.extensions.ExtensionFunction;
import org.xnio.ChannelExceptionHandler;
import org.xnio.ChannelListener;
import org.xnio.ChannelListener.SimpleSetter;
import org.xnio.ChannelListeners;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
Expand Down Expand Up @@ -82,6 +83,7 @@ public abstract class WebSocketChannel extends AbstractFramedChannel<WebSocketCh
*/
private final Set<WebSocketChannel> peerConnections;

private static final CloseMessage CLOSE_MSG = new CloseMessage(CloseMessage.GOING_AWAY, WebSocketMessages.MESSAGES.messageCloseWebSocket());
/**
* Create a new {@link WebSocketChannel}
* 8
Expand Down Expand Up @@ -158,6 +160,15 @@ protected void lastDataRead() {
} catch (IOException e) {
IoUtils.safeClose(this);
}
final ChannelListener<?> listener = ((SimpleSetter<WebSocketChannel>)getReceiveSetter()).get();
if(listener instanceof AbstractReceiveListener) {
final AbstractReceiveListener abstractReceiveListener = (AbstractReceiveListener) listener;
try {
abstractReceiveListener.onCloseMessage(CLOSE_MSG, this);
} catch(Exception e) {
e.printStackTrace();
}
}
}
}

Expand Down
Expand Up @@ -171,4 +171,7 @@ public interface WebSocketMessages {

@Message(id = 2045, value = "Unable to send on newly created channel!")
IllegalStateException unableToSendOnNewChannel();

@Message(id = 2046, value = "Closing WebSocket, peer went away.")
String messageCloseWebSocket();
}
Expand Up @@ -26,8 +26,10 @@
import io.undertow.websockets.core.AbstractReceiveListener;
import io.undertow.websockets.core.BufferedBinaryMessage;
import io.undertow.websockets.core.BufferedTextMessage;
import io.undertow.websockets.core.CloseMessage;
import io.undertow.websockets.core.WebSocketCallback;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSocketMessages;
import io.undertow.websockets.core.WebSockets;
import io.undertow.websockets.spi.WebSocketHttpExchange;
import io.undertow.websockets.utils.FrameChecker;
Expand All @@ -46,6 +48,7 @@
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

/**
Expand Down Expand Up @@ -167,6 +170,50 @@ protected void onFullCloseMessage(WebSocketChannel channel, BufferedBinaryMessag
client.destroy();
}

@Test
public void testCloseOnPeerGone() throws Exception {
if (getVersion() == WebSocketVersion.V00) {
// ignore 00 tests for now
return;
}
final AtomicBoolean connected = new AtomicBoolean(false);
final FutureResult<CloseMessage> latch = new FutureResult();
DefaultServer.setRootHandler(new WebSocketProtocolHandshakeHandler(new WebSocketConnectionCallback() {
@Override
public void onConnect(final WebSocketHttpExchange exchange, final WebSocketChannel channel) {
connected.set(true);
channel.getReceiveSetter().set(new AbstractReceiveListener() {

@Override
protected void onFullTextMessage(WebSocketChannel channel, BufferedTextMessage message) {
Assert.fail();
}

@Override
protected void onCloseMessage(CloseMessage msg, WebSocketChannel channel) {
latch.setResult(msg);
}

@Override
protected void onError(WebSocketChannel channel, Throwable t) {
Assert.fail();
}
});
channel.resumeReceives();
}
}));

WebSocketTestClient client = new WebSocketTestClient(getVersion(),
new URI("ws://" + NetworkUtils.formatPossibleIpv6Address(DefaultServer.getHostAddress("default")) + ":"
+ DefaultServer.getHostPort("default") + "/"));
client.connect();
client.destroy(true);
latch.getIoFuture().await(5000, TimeUnit.MILLISECONDS);
final CloseMessage msg = latch.getIoFuture().get();
Assert.assertNotNull(msg);
Assert.assertEquals(WebSocketMessages.MESSAGES.messageCloseWebSocket(), msg.getReason());
}

protected WebSocketVersion getVersion() {
return WebSocketVersion.V00;
}
Expand Down
Expand Up @@ -137,7 +137,11 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E
* Destroy the client and also close open connections if any exist
*/
public void destroy() {
if (!closed) {
this.destroy(false);
}

public void destroy(boolean dirty) {
if (!closed && !dirty) {
final CountDownLatch latch = new CountDownLatch(1);
send(new CloseWebSocketFrame(), new FrameListener() {
@Override
Expand Down
Expand Up @@ -139,6 +139,14 @@ public SecurityPathMatch getSecurityInfo(final String path, final String method)
handleMatch(method, extensionMatch, currentMatch);
return new SecurityPathMatch(currentMatch.type, mergeConstraints(currentMatch));
}

// if nothing else, check for security info defined for URL pattern '/'
match = exactPathRoleInformation.get("/");
if (match != null) {
handleMatch(method, match, currentMatch);
return new SecurityPathMatch(currentMatch.type, mergeConstraints(currentMatch));
}

return new SecurityPathMatch(currentMatch.type, mergeConstraints(currentMatch));
}

Expand Down
Expand Up @@ -18,10 +18,6 @@

package io.undertow.servlet.test.security.constraint;

import java.io.IOException;

import javax.servlet.ServletException;

import io.undertow.server.handlers.PathHandler;
import io.undertow.servlet.api.DeploymentInfo;
import io.undertow.servlet.api.DeploymentManager;
Expand All @@ -36,18 +32,21 @@
import io.undertow.servlet.test.util.TestClassIntrospector;
import io.undertow.testutils.DefaultServer;
import io.undertow.testutils.HttpClientUtils;
import io.undertow.testutils.TestHttpClient;
import io.undertow.util.FlexBase64;
import io.undertow.util.StatusCodes;
import org.apache.http.Header;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import io.undertow.testutils.TestHttpClient;
import io.undertow.util.StatusCodes;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;

import javax.servlet.ServletException;
import java.io.IOException;

import static io.undertow.util.Headers.AUTHORIZATION;
import static io.undertow.util.Headers.BASIC;
import static io.undertow.util.Headers.WWW_AUTHENTICATE;
Expand Down Expand Up @@ -196,6 +195,19 @@ public void testAggregatedRoles() throws IOException {
runSimpleUrlTest(DefaultServer.getDefaultServerURL() + "/servletContext/secured/1/2/aa", "user1:password1", "user2:password2");
}

@Test
public void testUnknown() throws IOException {
TestHttpClient client = new TestHttpClient();
try {
HttpGet get = new HttpGet(DefaultServer.getDefaultServerURL() + "/servletContext/unknown");
HttpResponse result = client.execute(get);
assertEquals(StatusCodes.NOT_FOUND, result.getStatusLine().getStatusCode());
HttpClientUtils.readResponse(result);
} finally {
client.getConnectionManager().shutdown();
}
}

@Test
public void testHttpMethod() throws IOException {
TestHttpClient client = new TestHttpClient();
Expand Down

0 comments on commit e52cefb

Please sign in to comment.