Skip to content

Commit

Permalink
Fixes #5378 Setting Holders during STARTING
Browse files Browse the repository at this point in the history
Holders are now started/initialized if needed by a new utility method
  • Loading branch information
gregw committed Oct 6, 2020
1 parent 9a1cada commit 9e1e098
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 42 deletions.
Expand Up @@ -19,7 +19,6 @@
package org.eclipse.jetty.servlet;

import java.util.EventListener;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;

import org.eclipse.jetty.server.handler.ContextHandler;
Expand Down Expand Up @@ -78,33 +77,30 @@ public void doStart() throws Exception
throw new IllegalStateException(msg);
}

ContextHandler contextHandler = ContextHandler.getCurrentContext().getContextHandler();
if (contextHandler != null)
ContextHandler contextHandler = null;
if (getServletHandler() != null)
contextHandler = getServletHandler().getServletContextHandler();
if (contextHandler == null && ContextHandler.getCurrentContext() != null)
contextHandler = ContextHandler.getCurrentContext().getContextHandler();
if (contextHandler == null)
throw new IllegalStateException("No Context");

_listener = getInstance();
if (_listener == null)
{
_listener = getInstance();
if (_listener == null)
//create an instance of the listener and decorate it
try
{
//create an instance of the listener and decorate it
try
{
ServletContext context = contextHandler.getServletContext();
_listener = (context != null)
? context.createListener(getHeldClass())
: getHeldClass().getDeclaredConstructor().newInstance();
}
catch (ServletException ex)
{
Throwable cause = ex.getRootCause();
if (cause instanceof InstantiationException)
throw (InstantiationException)cause;
if (cause instanceof IllegalAccessException)
throw (IllegalAccessException)cause;
throw ex;
}
_listener = contextHandler.getServletContext().createListener(getHeldClass());
}
catch (ServletException ex)
{
throw ex;
}

_listener = wrap(_listener, WrapFunction.class, WrapFunction::wrapEventListener);
contextHandler.addEventListener(_listener);
}
contextHandler.addEventListener(_listener);
}

@Override
Expand Down
Expand Up @@ -1644,52 +1644,179 @@ else if ("delete".equalsIgnoreCase(action))
}
}

public static class TestPListener implements ServletRequestListener
{
@Override
public void requestInitialized(ServletRequestEvent sre)
{
ServletRequest request = sre.getServletRequest();
Integer count = (Integer)request.getAttribute("testRequestListener");
request.setAttribute("testRequestListener", count == null ? 1 : count + 1);
}

@Override
public void requestDestroyed(ServletRequestEvent sre)
{
}
}

@Test
public void testProgrammaticFilterServlet() throws Exception
public void testProgrammaticListener() throws Exception
{
ServletContextHandler context = new ServletContextHandler();
ServletHandler handler = new ServletHandler();
_server.setHandler(context);
context.setHandler(handler);
handler.addServletWithMapping(new ServletHolder(new TestServlet()

// Add a servlet to report number of listeners
handler.addServletWithMapping(new ServletHolder(new HttpServlet()
{
@Override
public void init() throws ServletException
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
handler.addFilterWithMapping(new FilterHolder(new MyFilter()), "/test/*", EnumSet.of(DispatcherType.REQUEST));
resp.getOutputStream().print("Listeners=" + req.getAttribute("testRequestListener"));
}
}), "/test");
}), "/");

// Add a listener in STOPPED, STARTING and STARTED states
handler.addListener(new ListenerHolder(TestPListener.class));
handler.addServlet(new ServletHolder(new HttpServlet()
{
@Override
public void init() throws ServletException
{
handler.addListener(new ListenerHolder(TestPListener.class));
}
})
{
{
setInitOrder(1);
}
});
_server.start();
handler.addListener(new ListenerHolder(TestPListener.class));

String request =
"GET /test HTTP/1.0\n" +
"Host: localhost\n" +
"\n";
"Host: localhost\n" +
"\n";
String response = _connector.getResponse(request);
assertThat(response, containsString("200 OK"));
assertThat(response, containsString("filter: filter"));
assertThat(response, containsString("Test"));
assertThat(response, containsString("Listeners=3"));
}

public static class TestPFilter implements Filter
{
@Override
public void init(FilterConfig filterConfig) throws ServletException
{
}

@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException
{
Integer count = (Integer)request.getAttribute("testFilter");
request.setAttribute("testFilter", count == null ? 1 : count + 1);
chain.doFilter(request, response);
}

@Override
public void destroy()
{
}
}

@Test
public void testProgrammaticFilters() throws Exception
{
ServletContextHandler context = new ServletContextHandler();
ServletHandler handler = new ServletHandler();
_server.setHandler(context);
context.setHandler(handler);

// Add a servlet to report number of filters
handler.addServletWithMapping(new ServletHolder(new HttpServlet()
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
resp.getOutputStream().print("Filters=" + req.getAttribute("testFilter"));
}
}), "/");

handler.addServletWithMapping(new ServletHolder(new HelloServlet()
// Add a filter in STOPPED, STARTING and STARTED states
handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST));
handler.addServlet(new ServletHolder(new HttpServlet()
{
@Override
public void init() throws ServletException
{
handler.addFilterWithMapping(new FilterHolder(new MyFilter()), "/hello/*", EnumSet.of(DispatcherType.REQUEST));
handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST));
}
}), "/hello/*");
})
{
{
setInitOrder(1);
}
});
_server.start();
handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST));

String request =
"GET /test HTTP/1.0\n" +
"Host: localhost\n" +
"\n";
String response = _connector.getResponse(request);
assertThat(response, containsString("200 OK"));
assertThat(response, containsString("Filters=3"));
}

_server.dumpStdErr();
public static class TestPServlet extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
resp.getOutputStream().println(req.getRequestURI());
}
}

request =
"GET /hello HTTP/1.0\n" +
"Host: localhost\n" +
"\n";
@Test
public void testProgrammaticServlets() throws Exception
{
ServletContextHandler context = new ServletContextHandler();
ServletHandler handler = new ServletHandler();
_server.setHandler(context);
context.setHandler(handler);

// Add a filter in STOPPED, STARTING and STARTED states
handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/one");
handler.addServlet(new ServletHolder(new HttpServlet()
{
@Override
public void init() throws ServletException
{
handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/two");
}
})
{
{
setInitOrder(1);
}
});
_server.start();
handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/three");

String request = "GET /one HTTP/1.0\n" + "Host: localhost\n" + "\n";
String response = _connector.getResponse(request);
assertThat(response, containsString("200 OK"));
assertThat(response, containsString("/one"));
request = "GET /two HTTP/1.0\n" + "Host: localhost\n" + "\n";
response = _connector.getResponse(request);
assertThat(response, containsString("200 OK"));
assertThat(response, containsString("filter: filter"));
assertThat(response, containsString("Hello World"));
assertThat(response, containsString("/two"));
request = "GET /three HTTP/1.0\n" + "Host: localhost\n" + "\n";
response = _connector.getResponse(request);
assertThat(response, containsString("200 OK"));
assertThat(response, containsString("/three"));
}
}

0 comments on commit 9e1e098

Please sign in to comment.