Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #5378 Setting Holders during STARTING #5397

Merged
merged 2 commits into from Oct 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -19,7 +19,6 @@
package org.eclipse.jetty.servlet;

import java.util.EventListener;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;

import org.eclipse.jetty.server.handler.ContextHandler;
Expand Down Expand Up @@ -78,33 +77,30 @@ public void doStart() throws Exception
throw new IllegalStateException(msg);
}

ContextHandler contextHandler = ContextHandler.getCurrentContext().getContextHandler();
if (contextHandler != null)
ContextHandler contextHandler = null;
if (getServletHandler() != null)
contextHandler = getServletHandler().getServletContextHandler();
if (contextHandler == null && ContextHandler.getCurrentContext() != null)
contextHandler = ContextHandler.getCurrentContext().getContextHandler();
if (contextHandler == null)
throw new IllegalStateException("No Context");

_listener = getInstance();
if (_listener == null)
{
_listener = getInstance();
if (_listener == null)
//create an instance of the listener and decorate it
try
{
//create an instance of the listener and decorate it
try
{
ServletContext context = contextHandler.getServletContext();
_listener = (context != null)
? context.createListener(getHeldClass())
: getHeldClass().getDeclaredConstructor().newInstance();
}
catch (ServletException ex)
{
Throwable cause = ex.getRootCause();
if (cause instanceof InstantiationException)
throw (InstantiationException)cause;
if (cause instanceof IllegalAccessException)
throw (IllegalAccessException)cause;
throw ex;
}
_listener = contextHandler.getServletContext().createListener(getHeldClass());
}
catch (ServletException ex)
{
throw ex;
}

_listener = wrap(_listener, WrapFunction.class, WrapFunction::wrapEventListener);
contextHandler.addEventListener(_listener);
}
contextHandler.addEventListener(_listener);
}

@Override
Expand Down
Expand Up @@ -773,6 +773,29 @@ public boolean isInitialized()
return _initialized;
}

protected void initializeHolders(BaseHolder<?>[] holders)
{
for (BaseHolder<?> holder : holders)
{
holder.setServletHandler(this);
if (isInitialized())
{
try
{
if (!holder.isStarted())
{
holder.start();
holder.initialize();
}
}
catch (Exception e)
{
throw new RuntimeException(e);
}
}
}
}

/**
* @return whether the filter chains are cached.
*/
Expand Down Expand Up @@ -800,10 +823,7 @@ public ListenerHolder[] getListeners()
public void setListeners(ListenerHolder[] listeners)
{
if (listeners != null)
for (ListenerHolder holder : listeners)
{
holder.setServletHandler(this);
}
initializeHolders(listeners);
gregw marked this conversation as resolved.
Show resolved Hide resolved
updateBeans(_listeners,listeners);
_listeners = listeners;
}
Expand Down Expand Up @@ -865,9 +885,6 @@ public void addServletWithMapping(ServletHolder servlet, String pathSpec)
{
Objects.requireNonNull(servlet);
ServletHolder[] holders = getServlets();
if (holders != null)
holders = holders.clone();

try
{
synchronized (this)
Expand Down Expand Up @@ -979,8 +996,6 @@ public void addFilterWithMapping(FilterHolder holder, String pathSpec, EnumSet<D
{
Objects.requireNonNull(holder);
FilterHolder[] holders = getFilters();
if (holders != null)
holders = holders.clone();

try
{
Expand Down Expand Up @@ -1435,16 +1450,6 @@ else if (isAllowDuplicateMappings())
LOG.debug("servletPathMap=" + _servletPathMap);
LOG.debug("servletNameMap=" + _servletNameMap);
}

try
{
if (_contextHandler != null && _contextHandler.isStarted() || _contextHandler == null && isStarted())
initialize();
}
catch (Exception e)
{
throw new RuntimeException(e);
}
}

protected void notFound(Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException
Expand Down Expand Up @@ -1496,18 +1501,15 @@ public void setFilterMappings(FilterMapping[] filterMappings)
{
updateBeans(_filterMappings,filterMappings);
_filterMappings = filterMappings;
if (isStarted())
if (isRunning())
updateMappings();
invalidateChainsCache();
}

public synchronized void setFilters(FilterHolder[] holders)
{
if (holders != null)
for (FilterHolder holder : holders)
{
holder.setServletHandler(this);
}
initializeHolders(holders);
updateBeans(_filters,holders);
_filters = holders;
updateNameMappings();
Expand All @@ -1521,7 +1523,7 @@ public void setServletMappings(ServletMapping[] servletMappings)
{
updateBeans(_servletMappings,servletMappings);
_servletMappings = servletMappings;
if (isStarted())
if (isRunning())
updateMappings();
invalidateChainsCache();
}
Expand All @@ -1534,10 +1536,7 @@ public void setServletMappings(ServletMapping[] servletMappings)
public synchronized void setServlets(ServletHolder[] holders)
{
if (holders != null)
for (ServletHolder holder : holders)
{
holder.setServletHandler(this);
}
initializeHolders(holders);
updateBeans(_servlets,holders);
_servlets = holders;
updateNameMappings();
Expand Down
Expand Up @@ -1644,38 +1644,179 @@ else if ("delete".equalsIgnoreCase(action))
}
}

public static class TestPListener implements ServletRequestListener
{
@Override
public void requestInitialized(ServletRequestEvent sre)
{
ServletRequest request = sre.getServletRequest();
Integer count = (Integer)request.getAttribute("testRequestListener");
request.setAttribute("testRequestListener", count == null ? 1 : count + 1);
}

@Override
public void requestDestroyed(ServletRequestEvent sre)
{
}
}

@Test
public void testProgrammaticFilterServlet() throws Exception
public void testProgrammaticListener() throws Exception
{
ServletContextHandler context = new ServletContextHandler();
ServletHandler handler = new ServletHandler();
_server.setHandler(context);
context.setHandler(handler);
handler.addServletWithMapping(new ServletHolder(new TestServlet()), "/");

// Add a servlet to report number of listeners
handler.addServletWithMapping(new ServletHolder(new HttpServlet()
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
resp.getOutputStream().print("Listeners=" + req.getAttribute("testRequestListener"));
}
}), "/");

// Add a listener in STOPPED, STARTING and STARTED states
handler.addListener(new ListenerHolder(TestPListener.class));
handler.addServlet(new ServletHolder(new HttpServlet()
{
@Override
public void init() throws ServletException
{
handler.addListener(new ListenerHolder(TestPListener.class));
}
})
{
{
setInitOrder(1);
}
});
_server.start();
handler.addListener(new ListenerHolder(TestPListener.class));

String request =
"GET /test HTTP/1.0\n" +
"Host: localhost\n" +
"\n";
String response = _connector.getResponse(request);
assertThat(response, containsString("200 OK"));
assertThat(response, containsString("Listeners=3"));
}

public static class TestPFilter implements Filter
{
@Override
public void init(FilterConfig filterConfig) throws ServletException
{
}

@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException
{
Integer count = (Integer)request.getAttribute("testFilter");
request.setAttribute("testFilter", count == null ? 1 : count + 1);
chain.doFilter(request, response);
}

@Override
public void destroy()
{
}
}

@Test
public void testProgrammaticFilters() throws Exception
{
ServletContextHandler context = new ServletContextHandler();
ServletHandler handler = new ServletHandler();
_server.setHandler(context);
context.setHandler(handler);

// Add a servlet to report number of filters
handler.addServletWithMapping(new ServletHolder(new HttpServlet()
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
resp.getOutputStream().print("Filters=" + req.getAttribute("testFilter"));
}
}), "/");

// Add a filter in STOPPED, STARTING and STARTED states
handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST));
handler.addServlet(new ServletHolder(new HttpServlet()
{
@Override
public void init() throws ServletException
{
handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST));
}
})
{
{
setInitOrder(1);
}
});
_server.start();
handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST));

String request =
"GET /hello HTTP/1.0\n" +
"Host: localhost\n" +
"\n";
"GET /test HTTP/1.0\n" +
"Host: localhost\n" +
"\n";
String response = _connector.getResponse(request);
assertThat(response, containsString("200 OK"));
assertThat(response, containsString("Test"));
assertThat(response, containsString("Filters=3"));
}

public static class TestPServlet extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
resp.getOutputStream().println(req.getRequestURI());
}
}

handler.addFilterWithMapping(new FilterHolder(new MyFilter()), "/*", EnumSet.of(DispatcherType.REQUEST));
handler.addServletWithMapping(new ServletHolder(new HelloServlet()), "/hello/*");
@Test
public void testProgrammaticServlets() throws Exception
{
ServletContextHandler context = new ServletContextHandler();
ServletHandler handler = new ServletHandler();
_server.setHandler(context);
context.setHandler(handler);

_server.dumpStdErr();
// Add a filter in STOPPED, STARTING and STARTED states
handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/one");
handler.addServlet(new ServletHolder(new HttpServlet()
{
@Override
public void init() throws ServletException
{
handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/two");
}
})
{
{
setInitOrder(1);
}
});
_server.start();
handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/three");

request =
"GET /hello HTTP/1.0\n" +
"Host: localhost\n" +
"\n";
String request = "GET /one HTTP/1.0\n" + "Host: localhost\n" + "\n";
String response = _connector.getResponse(request);
assertThat(response, containsString("200 OK"));
assertThat(response, containsString("/one"));
request = "GET /two HTTP/1.0\n" + "Host: localhost\n" + "\n";
response = _connector.getResponse(request);
assertThat(response, containsString("200 OK"));
assertThat(response, containsString("filter: filter"));
assertThat(response, containsString("Hello World"));
assertThat(response, containsString("/two"));
request = "GET /three HTTP/1.0\n" + "Host: localhost\n" + "\n";
response = _connector.getResponse(request);
assertThat(response, containsString("200 OK"));
assertThat(response, containsString("/three"));
}
}