diff --git a/core/src/main/java/io/undertow/conduits/ReadTimeoutStreamSourceConduit.java b/core/src/main/java/io/undertow/conduits/ReadTimeoutStreamSourceConduit.java index 6dc9cde747..69255122ce 100644 --- a/core/src/main/java/io/undertow/conduits/ReadTimeoutStreamSourceConduit.java +++ b/core/src/main/java/io/undertow/conduits/ReadTimeoutStreamSourceConduit.java @@ -18,11 +18,6 @@ package io.undertow.conduits; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.FileChannel; -import java.util.concurrent.TimeUnit; - import io.undertow.UndertowLogger; import io.undertow.UndertowMessages; import io.undertow.UndertowOptions; @@ -41,6 +36,11 @@ import org.xnio.conduits.ReadReadyHandler; import org.xnio.conduits.StreamSourceConduit; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.util.concurrent.TimeUnit; + /** * Wrapper for read timeout. This should always be the first wrapper applied to the underlying channel. * @@ -49,7 +49,7 @@ */ public final class ReadTimeoutStreamSourceConduit extends AbstractStreamSourceConduit { - private XnioExecutor.Key handle; + private volatile XnioExecutor.Key handle; private final StreamConnection connection; private volatile long expireTime = -1; private final OpenListener openListener; @@ -60,14 +60,21 @@ public final class ReadTimeoutStreamSourceConduit extends AbstractStreamSourceCo private final Runnable timeoutCommand = new Runnable() { @Override public void run() { - handle = null; - if (expireTime == -1) { + synchronized (ReadTimeoutStreamSourceConduit.this) { + handle = null; + } + if (expireTime == -1 || !connection.isOpen()) { return; } long current = System.currentTimeMillis(); if (current < expireTime) { //timeout has been bumped, re-schedule - handle = WorkerUtils.executeAfter(connection.getIoThread(),timeoutCommand, (expireTime - current) + FUZZ_FACTOR, TimeUnit.MILLISECONDS); + if (handle == null) { + synchronized (ReadTimeoutStreamSourceConduit.this) { + if (handle == null) + handle = WorkerUtils.executeAfter(connection.getIoThread(), timeoutCommand, (expireTime - current) + FUZZ_FACTOR, TimeUnit.MILLISECONDS); + } + } return; } UndertowLogger.REQUEST_LOGGER.tracef("Timing out channel %s due to inactivity", connection.getSourceChannel()); @@ -131,12 +138,16 @@ private void handleReadTimeout(final long ret) throws IOException { final long expireTimeVar = expireTime; if (expireTimeVar != -1 && currentTime > expireTimeVar) { IoUtils.safeClose(connection); - throw UndertowMessages.MESSAGES.readTimedOut(this.getTimeout()); + throw UndertowMessages.MESSAGES.readTimedOut(currentTime - (expireTimeVar - this.getTimeout())); } } expireTime = currentTime + timeout; if (handle == null) { - handle = connection.getIoThread().executeAfter(timeoutCommand, timeout, TimeUnit.MILLISECONDS); + synchronized (this) { + if (handle == null) + handle = connection.getIoThread().executeAfter(timeoutCommand, timeout, TimeUnit.MILLISECONDS); + } + } } @@ -232,9 +243,13 @@ public void terminateReads() throws IOException { private void cleanup() { if (handle != null) { - handle.remove(); - handle = null; - expireTime = -1; + synchronized (this) { + if (handle != null) { + handle.remove(); + handle = null; + expireTime = -1; + } + } } } @@ -247,7 +262,7 @@ public void suspendReads() { private void checkExpired() throws ReadTimeoutException { synchronized (this) { if (expired) { - throw UndertowMessages.MESSAGES.readTimedOut(System.currentTimeMillis()); + throw UndertowMessages.MESSAGES.readTimedOut(System.currentTimeMillis() - (expireTime - getTimeout())); } } } diff --git a/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java b/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java index aacebd8e51..4b76338091 100644 --- a/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java +++ b/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java @@ -24,6 +24,7 @@ import io.undertow.websockets.extensions.ExtensionFunction; import org.xnio.ChannelExceptionHandler; import org.xnio.ChannelListener; +import org.xnio.ChannelListener.SimpleSetter; import org.xnio.ChannelListeners; import org.xnio.IoUtils; import org.xnio.OptionMap; @@ -82,6 +83,7 @@ public abstract class WebSocketChannel extends AbstractFramedChannel peerConnections; + private static final CloseMessage CLOSE_MSG = new CloseMessage(CloseMessage.GOING_AWAY, WebSocketMessages.MESSAGES.messageCloseWebSocket()); /** * Create a new {@link WebSocketChannel} * 8 @@ -158,6 +160,15 @@ protected void lastDataRead() { } catch (IOException e) { IoUtils.safeClose(this); } + final ChannelListener listener = ((SimpleSetter)getReceiveSetter()).get(); + if(listener instanceof AbstractReceiveListener) { + final AbstractReceiveListener abstractReceiveListener = (AbstractReceiveListener) listener; + try { + abstractReceiveListener.onCloseMessage(CLOSE_MSG, this); + } catch(Exception e) { + e.printStackTrace(); + } + } } } diff --git a/core/src/main/java/io/undertow/websockets/core/WebSocketMessages.java b/core/src/main/java/io/undertow/websockets/core/WebSocketMessages.java index fc17a8387c..491a33f0e0 100644 --- a/core/src/main/java/io/undertow/websockets/core/WebSocketMessages.java +++ b/core/src/main/java/io/undertow/websockets/core/WebSocketMessages.java @@ -171,4 +171,7 @@ public interface WebSocketMessages { @Message(id = 2045, value = "Unable to send on newly created channel!") IllegalStateException unableToSendOnNewChannel(); + + @Message(id = 2046, value = "Closing WebSocket, peer went away.") + String messageCloseWebSocket(); } diff --git a/core/src/test/java/io/undertow/websockets/core/protocol/AbstractWebSocketServerTest.java b/core/src/test/java/io/undertow/websockets/core/protocol/AbstractWebSocketServerTest.java index ac52118f0b..4561aec3a9 100644 --- a/core/src/test/java/io/undertow/websockets/core/protocol/AbstractWebSocketServerTest.java +++ b/core/src/test/java/io/undertow/websockets/core/protocol/AbstractWebSocketServerTest.java @@ -26,8 +26,10 @@ import io.undertow.websockets.core.AbstractReceiveListener; import io.undertow.websockets.core.BufferedBinaryMessage; import io.undertow.websockets.core.BufferedTextMessage; +import io.undertow.websockets.core.CloseMessage; import io.undertow.websockets.core.WebSocketCallback; import io.undertow.websockets.core.WebSocketChannel; +import io.undertow.websockets.core.WebSocketMessages; import io.undertow.websockets.core.WebSockets; import io.undertow.websockets.spi.WebSocketHttpExchange; import io.undertow.websockets.utils.FrameChecker; @@ -46,6 +48,7 @@ import java.io.IOException; import java.net.URI; import java.nio.ByteBuffer; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -167,6 +170,50 @@ protected void onFullCloseMessage(WebSocketChannel channel, BufferedBinaryMessag client.destroy(); } + @Test + public void testCloseOnPeerGone() throws Exception { + if (getVersion() == WebSocketVersion.V00) { + // ignore 00 tests for now + return; + } + final AtomicBoolean connected = new AtomicBoolean(false); + final FutureResult latch = new FutureResult(); + DefaultServer.setRootHandler(new WebSocketProtocolHandshakeHandler(new WebSocketConnectionCallback() { + @Override + public void onConnect(final WebSocketHttpExchange exchange, final WebSocketChannel channel) { + connected.set(true); + channel.getReceiveSetter().set(new AbstractReceiveListener() { + + @Override + protected void onFullTextMessage(WebSocketChannel channel, BufferedTextMessage message) { + Assert.fail(); + } + + @Override + protected void onCloseMessage(CloseMessage msg, WebSocketChannel channel) { + latch.setResult(msg); + } + + @Override + protected void onError(WebSocketChannel channel, Throwable t) { + Assert.fail(); + } + }); + channel.resumeReceives(); + } + })); + + WebSocketTestClient client = new WebSocketTestClient(getVersion(), + new URI("ws://" + NetworkUtils.formatPossibleIpv6Address(DefaultServer.getHostAddress("default")) + ":" + + DefaultServer.getHostPort("default") + "/")); + client.connect(); + client.destroy(true); + latch.getIoFuture().await(5000, TimeUnit.MILLISECONDS); + final CloseMessage msg = latch.getIoFuture().get(); + Assert.assertNotNull(msg); + Assert.assertEquals(WebSocketMessages.MESSAGES.messageCloseWebSocket(), msg.getReason()); + } + protected WebSocketVersion getVersion() { return WebSocketVersion.V00; } diff --git a/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java b/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java index b50c785369..373c186b55 100644 --- a/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java +++ b/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java @@ -137,7 +137,11 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E * Destroy the client and also close open connections if any exist */ public void destroy() { - if (!closed) { + this.destroy(false); + } + + public void destroy(boolean dirty) { + if (!closed && !dirty) { final CountDownLatch latch = new CountDownLatch(1); send(new CloseWebSocketFrame(), new FrameListener() { @Override diff --git a/servlet/src/main/java/io/undertow/servlet/handlers/security/SecurityPathMatches.java b/servlet/src/main/java/io/undertow/servlet/handlers/security/SecurityPathMatches.java index 04fca8fa68..9cf71e23d8 100644 --- a/servlet/src/main/java/io/undertow/servlet/handlers/security/SecurityPathMatches.java +++ b/servlet/src/main/java/io/undertow/servlet/handlers/security/SecurityPathMatches.java @@ -139,6 +139,14 @@ public SecurityPathMatch getSecurityInfo(final String path, final String method) handleMatch(method, extensionMatch, currentMatch); return new SecurityPathMatch(currentMatch.type, mergeConstraints(currentMatch)); } + + // if nothing else, check for security info defined for URL pattern '/' + match = exactPathRoleInformation.get("/"); + if (match != null) { + handleMatch(method, match, currentMatch); + return new SecurityPathMatch(currentMatch.type, mergeConstraints(currentMatch)); + } + return new SecurityPathMatch(currentMatch.type, mergeConstraints(currentMatch)); } diff --git a/servlet/src/test/java/io/undertow/servlet/test/security/constraint/SecurityConstraintUrlMappingTestCase.java b/servlet/src/test/java/io/undertow/servlet/test/security/constraint/SecurityConstraintUrlMappingTestCase.java index 55cb995dc0..9e92424edc 100644 --- a/servlet/src/test/java/io/undertow/servlet/test/security/constraint/SecurityConstraintUrlMappingTestCase.java +++ b/servlet/src/test/java/io/undertow/servlet/test/security/constraint/SecurityConstraintUrlMappingTestCase.java @@ -18,10 +18,6 @@ package io.undertow.servlet.test.security.constraint; -import java.io.IOException; - -import javax.servlet.ServletException; - import io.undertow.server.handlers.PathHandler; import io.undertow.servlet.api.DeploymentInfo; import io.undertow.servlet.api.DeploymentManager; @@ -36,18 +32,21 @@ import io.undertow.servlet.test.util.TestClassIntrospector; import io.undertow.testutils.DefaultServer; import io.undertow.testutils.HttpClientUtils; +import io.undertow.testutils.TestHttpClient; import io.undertow.util.FlexBase64; +import io.undertow.util.StatusCodes; import org.apache.http.Header; import org.apache.http.HttpResponse; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPost; -import io.undertow.testutils.TestHttpClient; -import io.undertow.util.StatusCodes; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; +import javax.servlet.ServletException; +import java.io.IOException; + import static io.undertow.util.Headers.AUTHORIZATION; import static io.undertow.util.Headers.BASIC; import static io.undertow.util.Headers.WWW_AUTHENTICATE; @@ -196,6 +195,19 @@ public void testAggregatedRoles() throws IOException { runSimpleUrlTest(DefaultServer.getDefaultServerURL() + "/servletContext/secured/1/2/aa", "user1:password1", "user2:password2"); } + @Test + public void testUnknown() throws IOException { + TestHttpClient client = new TestHttpClient(); + try { + HttpGet get = new HttpGet(DefaultServer.getDefaultServerURL() + "/servletContext/unknown"); + HttpResponse result = client.execute(get); + assertEquals(StatusCodes.NOT_FOUND, result.getStatusLine().getStatusCode()); + HttpClientUtils.readResponse(result); + } finally { + client.getConnectionManager().shutdown(); + } + } + @Test public void testHttpMethod() throws IOException { TestHttpClient client = new TestHttpClient(); diff --git a/servlet/src/test/java/io/undertow/servlet/test/security/constraint/SecurityConstraintUrlMappingWithUnspecifiedForbiddenTestCase.java b/servlet/src/test/java/io/undertow/servlet/test/security/constraint/SecurityConstraintUrlMappingWithUnspecifiedForbiddenTestCase.java new file mode 100644 index 0000000000..8b36f2950b --- /dev/null +++ b/servlet/src/test/java/io/undertow/servlet/test/security/constraint/SecurityConstraintUrlMappingWithUnspecifiedForbiddenTestCase.java @@ -0,0 +1,195 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2022 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.undertow.servlet.test.security.constraint; + +import io.undertow.server.handlers.PathHandler; +import io.undertow.servlet.api.DeploymentInfo; +import io.undertow.servlet.api.DeploymentManager; +import io.undertow.servlet.api.LoginConfig; +import io.undertow.servlet.api.SecurityConstraint; +import io.undertow.servlet.api.SecurityInfo; +import io.undertow.servlet.api.ServletContainer; +import io.undertow.servlet.api.ServletInfo; +import io.undertow.servlet.api.WebResourceCollection; +import io.undertow.servlet.test.SimpleServletTestCase; +import io.undertow.servlet.test.util.MessageServlet; +import io.undertow.servlet.test.util.TestClassIntrospector; +import io.undertow.testutils.DefaultServer; +import io.undertow.testutils.HttpClientUtils; +import io.undertow.testutils.TestHttpClient; +import io.undertow.util.StatusCodes; +import org.apache.http.HttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; + +import javax.servlet.ServletException; +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +/** + * Do the same as SecurityConstraintURLMappingTestCase, with a small difference, all access to / are denied and public/* is no longer + * covered by a SecurityConstraint. + * Verify that all works as in the super class test case, except for public/*, that now is forbidden for all HTTP methods. + * + * @author Flavia Rainone + */ +@RunWith(DefaultServer.class) +public class SecurityConstraintUrlMappingWithUnspecifiedForbiddenTestCase extends SecurityConstraintUrlMappingTestCase { + + @BeforeClass + public static void setup() throws ServletException { + + final PathHandler root = new PathHandler(); + final ServletContainer container = ServletContainer.Factory.newInstance(); + + ServletInfo s = new ServletInfo("servlet", AuthenticationMessageServlet.class) + .addInitParam(MessageServlet.MESSAGE, HELLO_WORLD) + .addMapping("/role1") + .addMapping("/role2") + .addMapping("/starstar") + .addMapping("/secured/role2/*") + .addMapping("/secured/1/2/*") + .addMapping("/public/*") + .addMapping("/extension/*"); + + ServletIdentityManager identityManager = new ServletIdentityManager(); + identityManager.addUser("user1", "password1", "role1"); + identityManager.addUser("user2", "password2", "role2", "**"); + identityManager.addUser("user3", "password3", "role1", "role2"); + identityManager.addUser("user4", "password4", "badRole"); + + DeploymentInfo builder = new DeploymentInfo() + .setClassLoader(SimpleServletTestCase.class.getClassLoader()) + .setContextPath("/servletContext") + .setClassIntrospecter(TestClassIntrospector.INSTANCE) + .setDeploymentName("servletContext.war") + .setIdentityManager(identityManager) + .setLoginConfig(new LoginConfig("BASIC", "Test Realm")) + .addServlet(s); + + builder.addSecurityConstraint(new SecurityConstraint() + .addWebResourceCollection(new WebResourceCollection() + .addUrlPattern("/role1")) + .addRoleAllowed("role1")); + + builder.addSecurityConstraint(new SecurityConstraint() + .addWebResourceCollection(new WebResourceCollection() + .addUrlPattern("/starstar")) + .addRoleAllowed("**")); + builder.addSecurityConstraint(new SecurityConstraint() + .addWebResourceCollection(new WebResourceCollection() + .addUrlPattern("/secured/*")) + .addRoleAllowed("role2")); + builder.addSecurityConstraint(new SecurityConstraint() + .addWebResourceCollection(new WebResourceCollection() + .addUrlPattern("/secured/*")) + .addRoleAllowed("role2")); + builder.addSecurityConstraint(new SecurityConstraint() + .addWebResourceCollection(new WebResourceCollection() + .addUrlPattern("/secured/1/*")) + .addRoleAllowed("role1")); + builder.addSecurityConstraint(new SecurityConstraint() + .addWebResourceCollection(new WebResourceCollection() + .addUrlPattern("/secured/1/2/*")) + .addRoleAllowed("role2")); + builder.addSecurityConstraint(new SecurityConstraint() + .addWebResourceCollection(new WebResourceCollection() + .addUrlPattern("*.html")) + .addRoleAllowed("role2")); + builder.addSecurityConstraint(new SecurityConstraint() + .addWebResourceCollection(new WebResourceCollection() + .addUrlPattern("/")).setEmptyRoleSemantic(SecurityInfo.EmptyRoleSemantic.DENY)); + builder.addSecurityConstraint(new SecurityConstraint() + .addWebResourceCollection(new WebResourceCollection() + .addUrlPattern("/public/postSecured/*") + .addHttpMethod("POST")) + .addRoleAllowed("role1")); + + DeploymentManager manager = container.addDeployment(builder); + manager.deploy(); + root.addPrefixPath(builder.getContextPath(), manager.start()); + + builder = new DeploymentInfo() + .setClassLoader(SimpleServletTestCase.class.getClassLoader()) + .setContextPath("/star") + .setClassIntrospecter(TestClassIntrospector.INSTANCE) + .setDeploymentName("servletContext.war") + .setIdentityManager(identityManager) + .setLoginConfig(new LoginConfig("BASIC", "Test Realm")) + .addSecurityRole("**") + .addServlet(s); + + builder.addSecurityConstraint(new SecurityConstraint() + .addWebResourceCollection(new WebResourceCollection() + .addUrlPattern("/starstar")) + .addRoleAllowed("**")); + + manager = container.addDeployment(builder); + manager.deploy(); + root.addPrefixPath(builder.getContextPath(), manager.start()); + DefaultServer.setRootHandler(root); + } + + @Test + @Override + public void testUnknown() throws IOException { + TestHttpClient client = new TestHttpClient(); + try { + HttpGet get = new HttpGet(DefaultServer.getDefaultServerURL() + "/servletContext/unknown"); + HttpResponse result = client.execute(get); + assertEquals(StatusCodes.FORBIDDEN, result.getStatusLine().getStatusCode()); + HttpClientUtils.readResponse(result); + } finally { + client.getConnectionManager().shutdown(); + } + } + + @Test + public void testPublic() throws IOException { + TestHttpClient client = new TestHttpClient(); + try { + HttpGet get = new HttpGet(DefaultServer.getDefaultServerURL() + "/servletContext/public"); + HttpResponse result = client.execute(get); + assertEquals(StatusCodes.FORBIDDEN, result.getStatusLine().getStatusCode()); + HttpClientUtils.readResponse(result); + } finally { + client.getConnectionManager().shutdown(); + } + } + + @Test + @Override + public void testExtensionMatch() throws IOException { + runSimpleUrlTest(DefaultServer.getDefaultServerURL() + "/servletContext/extension/a.html", "user1:password1", "user2:password2"); + TestHttpClient client = new TestHttpClient(); + try { + HttpGet get = new HttpGet(DefaultServer.getDefaultServerURL() + "/servletContext/public/a.html"); + get.addHeader("ExpectedMechanism", "None"); + get.addHeader("ExpectedUser", "None"); + HttpResponse result = client.execute(get); + assertEquals(StatusCodes.UNAUTHORIZED, result.getStatusLine().getStatusCode()); + HttpClientUtils.readResponse(result); + } finally { + client.getConnectionManager().shutdown(); + } + } +}