diff --git a/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/JettyWebSocketFilterTest.java b/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/JettyWebSocketFilterTest.java index e1786c196b8c..c0be53db2630 100644 --- a/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/JettyWebSocketFilterTest.java +++ b/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/JettyWebSocketFilterTest.java @@ -21,19 +21,24 @@ import java.net.URI; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import javax.servlet.http.HttpServlet; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.servlet.FilterHolder; import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.eclipse.jetty.util.component.AbstractLifeCycle; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.server.JettyWebSocketServerContainer; import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.eclipse.jetty.websocket.util.server.WebSocketUpgradeFilter; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -45,8 +50,17 @@ public class JettyWebSocketFilterTest private WebSocketClient client; private ServletContextHandler contextHandler; - @BeforeEach - public void start() throws Exception + public void start(JettyWebSocketServletContainerInitializer.Configurator configurator) throws Exception + { + start(configurator, null); + } + + public void start(ServletHolder servletHolder) throws Exception + { + start(null, servletHolder); + } + + public void start(JettyWebSocketServletContainerInitializer.Configurator configurator, ServletHolder servletHolder) throws Exception { server = new Server(); connector = new ServerConnector(server); @@ -54,9 +68,11 @@ public void start() throws Exception contextHandler = new ServletContextHandler(ServletContextHandler.SESSIONS); contextHandler.setContextPath("/"); + if (servletHolder != null) + contextHandler.addServlet(servletHolder, "/"); server.setHandler(contextHandler); - JettyWebSocketServletContainerInitializer.configure(contextHandler, null); + JettyWebSocketServletContainerInitializer.configure(contextHandler, configurator); server.start(); client = new WebSocketClient(); @@ -70,9 +86,37 @@ public void stop() throws Exception server.stop(); } + @Test + public void testWebSocketUpgradeFilter() throws Exception + { + start((context, container) -> container.addMapping("/", EchoSocket.class)); + + // After mapping is added we have an UpgradeFilter. + assertThat(contextHandler.getServletHandler().getFilters().length, is(1)); + FilterHolder filterHolder = contextHandler.getServletHandler().getFilter("WebSocketUpgradeFilter"); + assertNotNull(filterHolder); + assertThat(filterHolder.getState(), is(AbstractLifeCycle.STARTED)); + assertThat(filterHolder.getFilter(), instanceOf(WebSocketUpgradeFilter.class)); + + // Test we can upgrade to websocket and send a message. + URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/filterPath"); + EventSocket socket = new EventSocket(); + CompletableFuture connect = client.connect(socket, uri); + try (Session session = connect.get(5, TimeUnit.SECONDS)) + { + session.getRemote().sendString("hello world"); + } + assertTrue(socket.closeLatch.await(10, TimeUnit.SECONDS)); + + String msg = socket.textMessages.poll(); + assertThat(msg, is("hello world")); + } + @Test public void testLazyWebSocketUpgradeFilter() throws Exception { + start(null, null); + // JettyWebSocketServerContainer has already been created. JettyWebSocketServerContainer container = JettyWebSocketServerContainer.getContainer(contextHandler.getServletContext()); assertNotNull(container); @@ -83,6 +127,47 @@ public void testLazyWebSocketUpgradeFilter() throws Exception // After mapping is added we have an UpgradeFilter. container.addMapping("/", EchoSocket.class); assertThat(contextHandler.getServletHandler().getFilters().length, is(1)); + FilterHolder filterHolder = contextHandler.getServletHandler().getFilter("WebSocketUpgradeFilter"); + assertNotNull(filterHolder); + assertThat(filterHolder.getState(), is(AbstractLifeCycle.STARTED)); + assertThat(filterHolder.getFilter(), instanceOf(WebSocketUpgradeFilter.class)); + + // Test we can upgrade to websocket and send a message. + URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/filterPath"); + EventSocket socket = new EventSocket(); + CompletableFuture connect = client.connect(socket, uri); + try (Session session = connect.get(5, TimeUnit.SECONDS)) + { + session.getRemote().sendString("hello world"); + } + assertTrue(socket.closeLatch.await(10, TimeUnit.SECONDS)); + + String msg = socket.textMessages.poll(); + assertThat(msg, is("hello world")); + } + + @Test + public void testWebSocketUpgradeFilterWhileStarting() throws Exception + { + start(new ServletHolder(new HttpServlet() + { + @Override + public void init() + { + JettyWebSocketServerContainer container = JettyWebSocketServerContainer.getContainer(getServletContext()); + if (container == null) + throw new IllegalArgumentException("Missing JettyWebSocketServerContainer"); + + container.addMapping("/", EchoSocket.class); + } + })); + + // After mapping is added we have an UpgradeFilter. + assertThat(contextHandler.getServletHandler().getFilters().length, is(1)); + FilterHolder filterHolder = contextHandler.getServletHandler().getFilter("WebSocketUpgradeFilter"); + assertNotNull(filterHolder); + assertThat(filterHolder.getState(), is(AbstractLifeCycle.STARTED)); + assertThat(filterHolder.getFilter(), instanceOf(WebSocketUpgradeFilter.class)); // Test we can upgrade to websocket and send a message. URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/filterPath");