From 9e1e0985833e33a97eacdd972b97825eb21cbac0 Mon Sep 17 00:00:00 2001 From: gregw Date: Tue, 6 Oct 2020 14:51:19 +0200 Subject: [PATCH] Fixes #5378 Setting Holders during STARTING Holders are now started/initialized if needed by a new utility method --- .../eclipse/jetty/servlet/ListenerHolder.java | 42 ++--- .../servlet/ServletContextHandlerTest.java | 165 ++++++++++++++++-- 2 files changed, 165 insertions(+), 42 deletions(-) diff --git a/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ListenerHolder.java b/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ListenerHolder.java index ef10588e5d3c..42d4c0549a82 100644 --- a/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ListenerHolder.java +++ b/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ListenerHolder.java @@ -19,7 +19,6 @@ package org.eclipse.jetty.servlet; import java.util.EventListener; -import javax.servlet.ServletContext; import javax.servlet.ServletException; import org.eclipse.jetty.server.handler.ContextHandler; @@ -78,33 +77,30 @@ public void doStart() throws Exception throw new IllegalStateException(msg); } - ContextHandler contextHandler = ContextHandler.getCurrentContext().getContextHandler(); - if (contextHandler != null) + ContextHandler contextHandler = null; + if (getServletHandler() != null) + contextHandler = getServletHandler().getServletContextHandler(); + if (contextHandler == null && ContextHandler.getCurrentContext() != null) + contextHandler = ContextHandler.getCurrentContext().getContextHandler(); + if (contextHandler == null) + throw new IllegalStateException("No Context"); + + _listener = getInstance(); + if (_listener == null) { - _listener = getInstance(); - if (_listener == null) + //create an instance of the listener and decorate it + try { - //create an instance of the listener and decorate it - try - { - ServletContext context = contextHandler.getServletContext(); - _listener = (context != null) - ? context.createListener(getHeldClass()) - : getHeldClass().getDeclaredConstructor().newInstance(); - } - catch (ServletException ex) - { - Throwable cause = ex.getRootCause(); - if (cause instanceof InstantiationException) - throw (InstantiationException)cause; - if (cause instanceof IllegalAccessException) - throw (IllegalAccessException)cause; - throw ex; - } + _listener = contextHandler.getServletContext().createListener(getHeldClass()); } + catch (ServletException ex) + { + throw ex; + } + _listener = wrap(_listener, WrapFunction.class, WrapFunction::wrapEventListener); - contextHandler.addEventListener(_listener); } + contextHandler.addEventListener(_listener); } @Override diff --git a/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java index efab1cb97921..3119dec21a9d 100644 --- a/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java +++ b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java @@ -1644,52 +1644,179 @@ else if ("delete".equalsIgnoreCase(action)) } } + public static class TestPListener implements ServletRequestListener + { + @Override + public void requestInitialized(ServletRequestEvent sre) + { + ServletRequest request = sre.getServletRequest(); + Integer count = (Integer)request.getAttribute("testRequestListener"); + request.setAttribute("testRequestListener", count == null ? 1 : count + 1); + } + + @Override + public void requestDestroyed(ServletRequestEvent sre) + { + } + } + @Test - public void testProgrammaticFilterServlet() throws Exception + public void testProgrammaticListener() throws Exception { ServletContextHandler context = new ServletContextHandler(); ServletHandler handler = new ServletHandler(); _server.setHandler(context); context.setHandler(handler); - handler.addServletWithMapping(new ServletHolder(new TestServlet() + + // Add a servlet to report number of listeners + handler.addServletWithMapping(new ServletHolder(new HttpServlet() { @Override - public void init() throws ServletException + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - handler.addFilterWithMapping(new FilterHolder(new MyFilter()), "/test/*", EnumSet.of(DispatcherType.REQUEST)); + resp.getOutputStream().print("Listeners=" + req.getAttribute("testRequestListener")); } - }), "/test"); + }), "/"); + // Add a listener in STOPPED, STARTING and STARTED states + handler.addListener(new ListenerHolder(TestPListener.class)); + handler.addServlet(new ServletHolder(new HttpServlet() + { + @Override + public void init() throws ServletException + { + handler.addListener(new ListenerHolder(TestPListener.class)); + } + }) + { + { + setInitOrder(1); + } + }); _server.start(); + handler.addListener(new ListenerHolder(TestPListener.class)); String request = "GET /test HTTP/1.0\n" + - "Host: localhost\n" + - "\n"; + "Host: localhost\n" + + "\n"; String response = _connector.getResponse(request); assertThat(response, containsString("200 OK")); - assertThat(response, containsString("filter: filter")); - assertThat(response, containsString("Test")); + assertThat(response, containsString("Listeners=3")); + } + + public static class TestPFilter implements Filter + { + @Override + public void init(FilterConfig filterConfig) throws ServletException + { + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException + { + Integer count = (Integer)request.getAttribute("testFilter"); + request.setAttribute("testFilter", count == null ? 1 : count + 1); + chain.doFilter(request, response); + } + + @Override + public void destroy() + { + } + } + + @Test + public void testProgrammaticFilters() throws Exception + { + ServletContextHandler context = new ServletContextHandler(); + ServletHandler handler = new ServletHandler(); + _server.setHandler(context); + context.setHandler(handler); + // Add a servlet to report number of filters + handler.addServletWithMapping(new ServletHolder(new HttpServlet() + { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + resp.getOutputStream().print("Filters=" + req.getAttribute("testFilter")); + } + }), "/"); - handler.addServletWithMapping(new ServletHolder(new HelloServlet() + // Add a filter in STOPPED, STARTING and STARTED states + handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST)); + handler.addServlet(new ServletHolder(new HttpServlet() { @Override public void init() throws ServletException { - handler.addFilterWithMapping(new FilterHolder(new MyFilter()), "/hello/*", EnumSet.of(DispatcherType.REQUEST)); + handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST)); } - }), "/hello/*"); + }) + { + { + setInitOrder(1); + } + }); + _server.start(); + handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST)); + + String request = + "GET /test HTTP/1.0\n" + + "Host: localhost\n" + + "\n"; + String response = _connector.getResponse(request); + assertThat(response, containsString("200 OK")); + assertThat(response, containsString("Filters=3")); + } - _server.dumpStdErr(); + public static class TestPServlet extends HttpServlet + { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + resp.getOutputStream().println(req.getRequestURI()); + } + } - request = - "GET /hello HTTP/1.0\n" + - "Host: localhost\n" + - "\n"; + @Test + public void testProgrammaticServlets() throws Exception + { + ServletContextHandler context = new ServletContextHandler(); + ServletHandler handler = new ServletHandler(); + _server.setHandler(context); + context.setHandler(handler); + + // Add a filter in STOPPED, STARTING and STARTED states + handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/one"); + handler.addServlet(new ServletHolder(new HttpServlet() + { + @Override + public void init() throws ServletException + { + handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/two"); + } + }) + { + { + setInitOrder(1); + } + }); + _server.start(); + handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/three"); + + String request = "GET /one HTTP/1.0\n" + "Host: localhost\n" + "\n"; + String response = _connector.getResponse(request); + assertThat(response, containsString("200 OK")); + assertThat(response, containsString("/one")); + request = "GET /two HTTP/1.0\n" + "Host: localhost\n" + "\n"; response = _connector.getResponse(request); assertThat(response, containsString("200 OK")); - assertThat(response, containsString("filter: filter")); - assertThat(response, containsString("Hello World")); + assertThat(response, containsString("/two")); + request = "GET /three HTTP/1.0\n" + "Host: localhost\n" + "\n"; + response = _connector.getResponse(request); + assertThat(response, containsString("200 OK")); + assertThat(response, containsString("/three")); } } \ No newline at end of file