Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2.2.x backport bug fixes #1361

Merged
merged 4 commits into from Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -18,11 +18,6 @@

package io.undertow.conduits;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;

import io.undertow.UndertowLogger;
import io.undertow.UndertowMessages;
import io.undertow.UndertowOptions;
Expand All @@ -41,6 +36,11 @@
import org.xnio.conduits.ReadReadyHandler;
import org.xnio.conduits.StreamSourceConduit;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;

/**
* Wrapper for read timeout. This should always be the first wrapper applied to the underlying channel.
*
Expand All @@ -49,7 +49,7 @@
*/
public final class ReadTimeoutStreamSourceConduit extends AbstractStreamSourceConduit<StreamSourceConduit> {

private XnioExecutor.Key handle;
private volatile XnioExecutor.Key handle;
private final StreamConnection connection;
private volatile long expireTime = -1;
private final OpenListener openListener;
Expand All @@ -60,14 +60,21 @@ public final class ReadTimeoutStreamSourceConduit extends AbstractStreamSourceCo
private final Runnable timeoutCommand = new Runnable() {
@Override
public void run() {
handle = null;
if (expireTime == -1) {
synchronized (ReadTimeoutStreamSourceConduit.this) {
handle = null;
}
if (expireTime == -1 || !connection.isOpen()) {
return;
}
long current = System.currentTimeMillis();
if (current < expireTime) {
//timeout has been bumped, re-schedule
handle = WorkerUtils.executeAfter(connection.getIoThread(),timeoutCommand, (expireTime - current) + FUZZ_FACTOR, TimeUnit.MILLISECONDS);
if (handle == null) {
synchronized (ReadTimeoutStreamSourceConduit.this) {
if (handle == null)
handle = WorkerUtils.executeAfter(connection.getIoThread(), timeoutCommand, (expireTime - current) + FUZZ_FACTOR, TimeUnit.MILLISECONDS);
}
}
return;
}
UndertowLogger.REQUEST_LOGGER.tracef("Timing out channel %s due to inactivity", connection.getSourceChannel());
Expand Down Expand Up @@ -131,12 +138,16 @@ private void handleReadTimeout(final long ret) throws IOException {
final long expireTimeVar = expireTime;
if (expireTimeVar != -1 && currentTime > expireTimeVar) {
IoUtils.safeClose(connection);
throw UndertowMessages.MESSAGES.readTimedOut(this.getTimeout());
throw UndertowMessages.MESSAGES.readTimedOut(currentTime - (expireTimeVar - this.getTimeout()));
}
}
expireTime = currentTime + timeout;
if (handle == null) {
handle = connection.getIoThread().executeAfter(timeoutCommand, timeout, TimeUnit.MILLISECONDS);
synchronized (this) {
if (handle == null)
handle = connection.getIoThread().executeAfter(timeoutCommand, timeout, TimeUnit.MILLISECONDS);
}

}
}

Expand Down Expand Up @@ -232,9 +243,13 @@ public void terminateReads() throws IOException {

private void cleanup() {
if (handle != null) {
handle.remove();
handle = null;
expireTime = -1;
synchronized (this) {
if (handle != null) {
handle.remove();
handle = null;
expireTime = -1;
}
}
}
}

Expand All @@ -247,7 +262,7 @@ public void suspendReads() {
private void checkExpired() throws ReadTimeoutException {
synchronized (this) {
if (expired) {
throw UndertowMessages.MESSAGES.readTimedOut(System.currentTimeMillis());
throw UndertowMessages.MESSAGES.readTimedOut(System.currentTimeMillis() - (expireTime - getTimeout()));
}
}
}
Expand Down
Expand Up @@ -24,6 +24,7 @@
import io.undertow.websockets.extensions.ExtensionFunction;
import org.xnio.ChannelExceptionHandler;
import org.xnio.ChannelListener;
import org.xnio.ChannelListener.SimpleSetter;
import org.xnio.ChannelListeners;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
Expand Down Expand Up @@ -82,6 +83,7 @@ public abstract class WebSocketChannel extends AbstractFramedChannel<WebSocketCh
*/
private final Set<WebSocketChannel> peerConnections;

private static final CloseMessage CLOSE_MSG = new CloseMessage(CloseMessage.GOING_AWAY, WebSocketMessages.MESSAGES.messageCloseWebSocket());
/**
* Create a new {@link WebSocketChannel}
* 8
Expand Down Expand Up @@ -158,6 +160,15 @@ protected void lastDataRead() {
} catch (IOException e) {
IoUtils.safeClose(this);
}
final ChannelListener<?> listener = ((SimpleSetter<WebSocketChannel>)getReceiveSetter()).get();
if(listener instanceof AbstractReceiveListener) {
final AbstractReceiveListener abstractReceiveListener = (AbstractReceiveListener) listener;
try {
abstractReceiveListener.onCloseMessage(CLOSE_MSG, this);
} catch(Exception e) {
e.printStackTrace();
}
}
}
}

Expand Down
Expand Up @@ -171,4 +171,7 @@ public interface WebSocketMessages {

@Message(id = 2045, value = "Unable to send on newly created channel!")
IllegalStateException unableToSendOnNewChannel();

@Message(id = 2046, value = "Closing WebSocket, peer went away.")
String messageCloseWebSocket();
}
Expand Up @@ -26,8 +26,10 @@
import io.undertow.websockets.core.AbstractReceiveListener;
import io.undertow.websockets.core.BufferedBinaryMessage;
import io.undertow.websockets.core.BufferedTextMessage;
import io.undertow.websockets.core.CloseMessage;
import io.undertow.websockets.core.WebSocketCallback;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSocketMessages;
import io.undertow.websockets.core.WebSockets;
import io.undertow.websockets.spi.WebSocketHttpExchange;
import io.undertow.websockets.utils.FrameChecker;
Expand All @@ -46,6 +48,7 @@
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

/**
Expand Down Expand Up @@ -167,6 +170,50 @@ protected void onFullCloseMessage(WebSocketChannel channel, BufferedBinaryMessag
client.destroy();
}

@Test
public void testCloseOnPeerGone() throws Exception {
if (getVersion() == WebSocketVersion.V00) {
// ignore 00 tests for now
return;
}
final AtomicBoolean connected = new AtomicBoolean(false);
final FutureResult<CloseMessage> latch = new FutureResult();
DefaultServer.setRootHandler(new WebSocketProtocolHandshakeHandler(new WebSocketConnectionCallback() {
@Override
public void onConnect(final WebSocketHttpExchange exchange, final WebSocketChannel channel) {
connected.set(true);
channel.getReceiveSetter().set(new AbstractReceiveListener() {

@Override
protected void onFullTextMessage(WebSocketChannel channel, BufferedTextMessage message) {
Assert.fail();
}

@Override
protected void onCloseMessage(CloseMessage msg, WebSocketChannel channel) {
latch.setResult(msg);
}

@Override
protected void onError(WebSocketChannel channel, Throwable t) {
Assert.fail();
}
});
channel.resumeReceives();
}
}));

WebSocketTestClient client = new WebSocketTestClient(getVersion(),
new URI("ws://" + NetworkUtils.formatPossibleIpv6Address(DefaultServer.getHostAddress("default")) + ":"
+ DefaultServer.getHostPort("default") + "/"));
client.connect();
client.destroy(true);
latch.getIoFuture().await(5000, TimeUnit.MILLISECONDS);
final CloseMessage msg = latch.getIoFuture().get();
Assert.assertNotNull(msg);
Assert.assertEquals(WebSocketMessages.MESSAGES.messageCloseWebSocket(), msg.getReason());
}

protected WebSocketVersion getVersion() {
return WebSocketVersion.V00;
}
Expand Down
Expand Up @@ -137,7 +137,11 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E
* Destroy the client and also close open connections if any exist
*/
public void destroy() {
if (!closed) {
this.destroy(false);
}

public void destroy(boolean dirty) {
if (!closed && !dirty) {
final CountDownLatch latch = new CountDownLatch(1);
send(new CloseWebSocketFrame(), new FrameListener() {
@Override
Expand Down
Expand Up @@ -139,6 +139,14 @@ public SecurityPathMatch getSecurityInfo(final String path, final String method)
handleMatch(method, extensionMatch, currentMatch);
return new SecurityPathMatch(currentMatch.type, mergeConstraints(currentMatch));
}

// if nothing else, check for security info defined for URL pattern '/'
match = exactPathRoleInformation.get("/");
if (match != null) {
handleMatch(method, match, currentMatch);
return new SecurityPathMatch(currentMatch.type, mergeConstraints(currentMatch));
}

return new SecurityPathMatch(currentMatch.type, mergeConstraints(currentMatch));
}

Expand Down
Expand Up @@ -18,10 +18,6 @@

package io.undertow.servlet.test.security.constraint;

import java.io.IOException;

import javax.servlet.ServletException;

import io.undertow.server.handlers.PathHandler;
import io.undertow.servlet.api.DeploymentInfo;
import io.undertow.servlet.api.DeploymentManager;
Expand All @@ -36,18 +32,21 @@
import io.undertow.servlet.test.util.TestClassIntrospector;
import io.undertow.testutils.DefaultServer;
import io.undertow.testutils.HttpClientUtils;
import io.undertow.testutils.TestHttpClient;
import io.undertow.util.FlexBase64;
import io.undertow.util.StatusCodes;
import org.apache.http.Header;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import io.undertow.testutils.TestHttpClient;
import io.undertow.util.StatusCodes;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;

import javax.servlet.ServletException;
import java.io.IOException;

import static io.undertow.util.Headers.AUTHORIZATION;
import static io.undertow.util.Headers.BASIC;
import static io.undertow.util.Headers.WWW_AUTHENTICATE;
Expand Down Expand Up @@ -196,6 +195,19 @@ public void testAggregatedRoles() throws IOException {
runSimpleUrlTest(DefaultServer.getDefaultServerURL() + "/servletContext/secured/1/2/aa", "user1:password1", "user2:password2");
}

@Test
public void testUnknown() throws IOException {
TestHttpClient client = new TestHttpClient();
try {
HttpGet get = new HttpGet(DefaultServer.getDefaultServerURL() + "/servletContext/unknown");
HttpResponse result = client.execute(get);
assertEquals(StatusCodes.NOT_FOUND, result.getStatusLine().getStatusCode());
HttpClientUtils.readResponse(result);
} finally {
client.getConnectionManager().shutdown();
}
}

@Test
public void testHttpMethod() throws IOException {
TestHttpClient client = new TestHttpClient();
Expand Down