Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ThreadLimitHandler #11723

Merged
merged 4 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
Expand All @@ -31,6 +30,7 @@
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.http.QuotedCSV;
import org.eclipse.jetty.io.Retainable;
import org.eclipse.jetty.server.ForwardedRequestCustomizer;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request;
Expand Down Expand Up @@ -68,7 +68,7 @@ public class ThreadLimitHandler extends ConditionalHandler.Abstract

private final boolean _rfc7239;
private final String _forwardedHeader;
private final ConcurrentMap<String, Remote> _remotes = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, Remote> _remotes = new ConcurrentHashMap<>();
private volatile boolean _enabled;
private int _threadLimit = 10;

Expand Down Expand Up @@ -163,7 +163,10 @@ public boolean onConditionsMet(Request request, Response response, Callback call
}

// We accept the request and will always handle it.
LimitedRequest limitedRequest = new LimitedRequest(remote, next, request, response, callback);
// Use a compute method to remove the Remote instance as it is necessary for
// the ref counter release and the removal to be atomic.
LimitedRequest limitedRequest = new LimitedRequest(remote, next, request, response, Callback.from(callback, () ->
_remotes.computeIfPresent(remote._ip, (k, v) -> v._referenceCounter.release() ? null : v)));
limitedRequest.handle();
return true;
}
Expand All @@ -177,23 +180,27 @@ protected boolean onConditionsNotMet(Request request, Response response, Callbac
private Remote getRemote(Request baseRequest)
{
String ip = getRemoteIP(baseRequest);
LOG.debug("ip={}", ip);
if (LOG.isDebugEnabled())
LOG.debug("ip={}", ip);
if (ip == null)
return null;

int limit = getThreadLimit(ip);
if (limit <= 0)
return null;

Remote remote = _remotes.get(ip);
if (remote == null)
// Use a compute method to create or retain the Remote instance as it is necessary for
// the ref counter increment or the instance creation to be mutually exclusive.
// The map MUST be a CHM as it guarantees the remapping function is only called once.
return _remotes.compute(ip, (k, v) ->
{
Remote r = new Remote(baseRequest.getContext(), ip, limit);
remote = _remotes.putIfAbsent(ip, r);
if (remote == null)
remote = r;
}
return remote;
if (v != null)
{
v._referenceCounter.retain();
return v;
}
return new Remote(baseRequest.getContext(), k, limit);
});
}

protected String getRemoteIP(Request baseRequest)
Expand All @@ -208,7 +215,7 @@ protected String getRemoteIP(Request baseRequest)
}

// If no remote IP from a header, determine it directly from the channel
// Do not use the request methods, as they may have been lied to by the
// Do not use the request methods, as they may have been lied to by the
// RequestCustomizer!
if (baseRequest.getConnectionMetaData().getRemoteSocketAddress() instanceof InetSocketAddress inetAddr)
{
Expand Down Expand Up @@ -255,7 +262,12 @@ private String getXForwardedFor(Request request)
int comma = forwardedFor.lastIndexOf(',');
return (comma >= 0) ? forwardedFor.substring(comma + 1).trim() : forwardedFor;
}


int getRemoteCount()
{
return _remotes.size();
}

private static class LimitedRequest extends Request.Wrapper
{
private final Remote _remote;
Expand Down Expand Up @@ -517,6 +529,7 @@ public void release()
private static final class Remote
{
private final Executor _executor;
private final Retainable.ReferenceCounter _referenceCounter = new Retainable.ReferenceCounter();
private final String _ip;
private final int _limit;
private final AutoLock _lock = new AutoLock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ public boolean handle(Request request, Response response, Callback callback)
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nForwarded: for=1.2.3.4\r\n\r\n");
assertThat(last.get(), is("0.0.0.0"));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}

@Test
Expand Down Expand Up @@ -147,6 +149,8 @@ public boolean handle(Request request, Response response, Callback callback)
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.1.1.1\r\nX-Forwarded-For: 6.6.6.6,1.2.3.4\r\nForwarded: for=1.2.3.4\r\n\r\n");
assertThat(last.get(), is("1.2.3.4"));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}

@Test
Expand Down Expand Up @@ -190,6 +194,8 @@ public boolean handle(Request request, Response response, Callback callback)
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.1.1.1\r\nForwarded: for=6.6.6.6; for=1.2.3.4\r\nX-Forwarded-For: 6.6.6.6\r\nForwarded: proto=https\r\n\r\n");
assertThat(last.get(), is("1.2.3.4"));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}

@Test
Expand Down Expand Up @@ -248,6 +254,8 @@ public boolean handle(Request request, Response response, Callback callback) thr

await().atMost(10, TimeUnit.SECONDS).until(total::get, is(10));
await().atMost(10, TimeUnit.SECONDS).until(count::get, is(0));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}

@Test
Expand Down Expand Up @@ -367,5 +375,7 @@ public void run()
assertThat(response, containsString(" 200 OK"));
assertThat(response, containsString(" read 2"));
}

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,19 @@
import java.util.Deque;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;

import jakarta.servlet.AsyncContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequestEvent;
import jakarta.servlet.ServletRequestListener;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HostPortHttpField;
import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.QuotedCSV;
import org.eclipse.jetty.io.Retainable;
import org.eclipse.jetty.server.ForwardedRequestCustomizer;
import org.eclipse.jetty.util.IncludeExcludeSet;
import org.eclipse.jetty.util.InetAddressSet;
Expand Down Expand Up @@ -72,7 +74,7 @@ public class ThreadLimitHandler extends HandlerWrapper
private final boolean _rfc7239;
private final String _forwardedHeader;
private final IncludeExcludeSet<String, InetAddress> _includeExcludeSet = new IncludeExcludeSet<>(InetAddressSet.class);
private final ConcurrentMap<String, Remote> _remotes = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, Remote> _remotes = new ConcurrentHashMap<>();
private volatile boolean _enabled;
private int _threadLimit = 10;

Expand Down Expand Up @@ -178,6 +180,17 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques
}
else
{
baseRequest.addEventListener(new ServletRequestListener()
{
@Override
public void requestDestroyed(ServletRequestEvent sre)
{
// Use a compute method to remove the Remote instance as it is necessary for
// the ref counter release and the removal to be atomic.
_remotes.computeIfPresent(remote._ip, (k, v) -> v._referenceCounter.release() ? null : v);
}
});

// Do we already have a future permit from a previous invocation?
Closeable permit = (Closeable)baseRequest.getAttribute(PERMIT);
try
Expand Down Expand Up @@ -249,14 +262,18 @@ private Remote getRemote(Request baseRequest)
if (limit <= 0)
return null;

remote = _remotes.get(ip);
if (remote == null)
// Use a compute method to create or retain the Remote instance as it is necessary for
// the ref counter increment or the instance creation to be mutually exclusive.
// The map MUST be a CHM as it guarantees the remapping function is only called once.
remote = _remotes.compute(ip, (k, v) ->
{
Remote r = new Remote(ip, limit);
remote = _remotes.putIfAbsent(ip, r);
if (remote == null)
remote = r;
}
if (v != null)
{
v._referenceCounter.retain();
return v;
}
return new Remote(k, limit);
});

baseRequest.setAttribute(REMOTE, remote);

Expand Down Expand Up @@ -325,6 +342,7 @@ private static final class Remote implements Closeable
private final String _ip;
private final int _limit;
private final AutoLock _lock = new AutoLock();
private final Retainable.ReferenceCounter _referenceCounter = new Retainable.ReferenceCounter();
private int _permits;
private Deque<CompletableFuture<Closeable>> _queue = new ArrayDeque<>();
private final CompletableFuture<Closeable> _permitted = CompletableFuture.completedFuture(this);
Expand Down