Skip to content

Commit

Permalink
Fixes #5104 - AbstractProxyServlet include incorrect protocol version…
Browse files Browse the repository at this point in the history
… in Via header when accessed over H2.

* Introduced HttpFields.computeField() to put/append header values.
* Reworked AbstractProxyServlet.addViaHeader().

Signed-off-by: Simone Bordet <simone.bordet@gmail.com>
  • Loading branch information
sbordet committed Aug 12, 2020
1 parent b2ab05c commit 79d340f
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 159 deletions.
126 changes: 126 additions & 0 deletions jetty-http/src/main/java/org/eclipse/jetty/http/HttpFields.java
Expand Up @@ -32,6 +32,7 @@
import java.util.Objects;
import java.util.Set;
import java.util.StringTokenizer;
import java.util.function.BiFunction;
import java.util.function.ToIntFunction;
import java.util.stream.Stream;

Expand Down Expand Up @@ -88,6 +89,115 @@ public HttpFields(HttpFields fields)
_size = fields._size;
}

/**
* <p>Computes a single field for the given HTTP header name and for existing fields with the same name.</p>
*
* <p>The compute function receives the field name and a list of fields with the same name
* so that their values can be used to compute the value of the field that is returned
* by the compute function.
* If the compute function returns {@code null}, the fields with the given name are removed.</p>
* <p>This method comes handy when you want to add an HTTP header if it does not exist,
* or add a value if the HTTP header already exists, similarly to
* {@link Map#compute(Object, BiFunction)}.</p>
*
* <p>This method can be used to {@link #put(HttpField) put} a new field (or blindly replace its value):</p>
* <pre>
* httpFields.computeField("X-New-Header",
* (name, fields) -> new HttpField(name, "NewValue"));
* </pre>
*
* <p>This method can be used to coalesce many fields into one:</p>
* <pre>
* // Input:
* GET / HTTP/1.1
* Host: localhost
* Cookie: foo=1
* Cookie: bar=2,baz=3
* User-Agent: Jetty
*
* // Computation:
* httpFields.computeField("Cookie", (name, fields) ->
* {
* // No cookies, nothing to do.
* if (fields == null)
* return null;
*
* // Coalesces all cookies.
* String coalesced = fields.stream()
* .flatMap(field -> Stream.of(field.getValues()))
* .collect(Collectors.joining(", "));
*
* // Returns a single Cookie header with all cookies.
* return new HttpField(name, coalesced);
* }
*
* // Output:
* GET / HTTP/1.1
* Host: localhost
* Cookie: foo=1, bar=2, baz=3
* User-Agent: Jetty
* </pre>
*
* <p>This method can be used to replace a field:</p>
* <pre>
* httpFields.computeField("X-Length", (name, fields) ->
* {
* if (fields == null)
* return null;
*
* // Get any value among the X-Length headers.
* String length = fields.stream()
* .map(HttpField::getValue)
* .findAny()
* .orElse("0");
*
* // Replace X-Length headers with X-Capacity header.
* return new HttpField("X-Capacity", length);
* });
* </pre>
*
* <p>This method can be used to remove a field:</p>
* <pre>
* httpFields.computeField("Connection", (name, fields) -> null);
* </pre>
*
* @param name the HTTP header name
* @param computeFn the compute function
*/
public void computeField(String name, BiFunction<String, List<HttpField>, HttpField> computeFn)
{
boolean found = false;
ListIterator<HttpField> iterator = listIterator();
while (iterator.hasNext())
{
HttpField field = iterator.next();
if (field.getName().equalsIgnoreCase(name))
{
if (found)
{
// Remove other headers with the same name, since
// we have computed one from all of them already.
iterator.remove();
}
else
{
found = true;
HttpField newField = computeFn.apply(name, Collections.unmodifiableList(getFields(name)));
if (newField == null)
iterator.remove();
else
iterator.set(newField);
}
}
}
if (!found)
{
HttpField newField = computeFn.apply(name, null);
if (newField != null)
put(newField);
}
}

public int size()
{
return _size;
Expand Down Expand Up @@ -189,6 +299,22 @@ public List<HttpField> getFields(HttpHeader header)
return fields == null ? Collections.emptyList() : fields;
}

public List<HttpField> getFields(String name)
{
List<HttpField> fields = null;
for (int i = 0; i < _size; i++)
{
HttpField f = _fields[i];
if (f.getName().equalsIgnoreCase(name))
{
if (fields == null)
fields = new ArrayList<>();
fields.add(f);
}
}
return fields == null ? Collections.emptyList() : fields;
}

public boolean contains(HttpField field)
{
for (int i = _size; i-- > 0; )
Expand Down
Expand Up @@ -29,6 +29,8 @@
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.servlet.AsyncContext;
import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
Expand All @@ -48,6 +50,7 @@
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpHeaderValue;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.util.HttpCookieStore;
import org.eclipse.jetty.util.ProcessorUtils;
import org.eclipse.jetty.util.StringUtil;
Expand Down Expand Up @@ -109,7 +112,6 @@ public abstract class AbstractProxyServlet extends HttpServlet
private String _viaHost;
private HttpClient _client;
private long _timeout;
private boolean oldAddViaHeaderCalled;

@Override
public void init() throws ServletException
Expand Down Expand Up @@ -168,9 +170,6 @@ public String getHostHeader()

public String getViaHost()
{
if (_viaHost == null)
_viaHost = viaHost();

return _viaHost;
}

Expand Down Expand Up @@ -456,6 +455,14 @@ protected boolean expects100Continue(HttpServletRequest request)
return HttpHeaderValue.CONTINUE.is(request.getHeader(HttpHeader.EXPECT.asString()));
}

protected Request newProxyRequest(HttpServletRequest request, String rewrittenTarget)
{
return getHttpClient().newRequest(rewrittenTarget)
.method(request.getMethod())
.version(HttpVersion.fromString(request.getProtocol()))
.attribute(CLIENT_REQUEST_ATTRIBUTE, request);
}

protected void copyRequestHeaders(HttpServletRequest clientRequest, Request proxyRequest)
{
// First clear possibly existing headers, as we are going to copy those from the client request.
Expand Down Expand Up @@ -513,59 +520,54 @@ protected Set<String> findConnectionHeaders(HttpServletRequest clientRequest)

protected void addProxyHeaders(HttpServletRequest clientRequest, Request proxyRequest)
{
addViaHeader(clientRequest, proxyRequest);
addViaHeader(proxyRequest);
addXForwardedHeaders(clientRequest, proxyRequest);
}

/**
* Adds the HTTP Via header to the proxied request.
* Adds the HTTP {@code Via} header to the proxied request.
*
* @deprecated Use {@link #addViaHeader(HttpServletRequest, Request)} instead.
* @param proxyRequest the request being proxied
* @see #addViaHeader(HttpServletRequest, Request)
*/
@Deprecated
protected void addViaHeader(Request proxyRequest)
{
oldAddViaHeaderCalled = true;
HttpServletRequest clientRequest = (HttpServletRequest)proxyRequest.getAttributes().get(CLIENT_REQUEST_ATTRIBUTE);
addViaHeader(clientRequest, proxyRequest);
}

/**
* Adds the HTTP Via header to the proxied request, taking into account data present in the client request.
* This method considers the protocol of the client request when forming the proxied request. If it
* is HTTP, then the protocol name will not be included in the Via header that is sent by the proxy, and only
* <p>Adds the HTTP {@code Via} header to the proxied request, taking into account data present in the client request.</p>
* <p>This method considers the protocol of the client request when forming the proxied request. If it
* is HTTP, then the protocol name will not be included in the {@code Via} header that is sent by the proxy, and only
* the protocol version will be sent. If it is not, the entire protocol (name and version) will be included.
* If the client request includes a Via header, the result will be appended to that to form a chain.
* If the client request includes a {@code Via} header, the result will be appended to that to form a chain.</p>
*
* @param clientRequest the client request
* @param proxyRequest the request being proxied
* @see <a href="https://tools.ietf.org/html/rfc7230#section-5.7.1">RFC 7230 section 5.7.1</a>
*/
protected void addViaHeader(HttpServletRequest clientRequest, Request proxyRequest)
{
// For backward compatibility reasons, call old, deprecated version of this method.
// If our flag isn't set, the deprecated method was overridden and we shouldn't do
// anything more.

oldAddViaHeaderCalled = false;
addViaHeader(proxyRequest);

if (!oldAddViaHeaderCalled)
return; // Old method was overridden, so bail out.

// Old version of this method wasn't overridden, so do the new logic instead.

String protocol = clientRequest.getProtocol();
String[] parts = protocol.split("/", 2);
String protocolName = parts.length == 2 && "HTTP".equals(parts[0]) ? parts[1] : protocol;
String viaHeaderValue = "";
String clientViaHeader = clientRequest.getHeader(HttpHeader.VIA.name());

if (clientViaHeader != null)
viaHeaderValue = clientViaHeader;

viaHeaderValue += protocolName + " " + getViaHost();

proxyRequest.header(HttpHeader.VIA, viaHeaderValue);
// Retain only the version if the protocol is HTTP.
String protocolPart = parts.length == 2 && "HTTP".equalsIgnoreCase(parts[0]) ? parts[1] : protocol;
String viaHeaderValue = protocolPart + " " + getViaHost();
proxyRequest.getHeaders().computeField(HttpHeader.VIA.asString(), (name, viaFields) ->
{
if (viaFields == null || viaFields.isEmpty())
return new HttpField(name, viaHeaderValue);
String separator = ", ";
String newValue = viaFields.stream()
.flatMap(field -> Stream.of(field.getValues()))
.filter(value -> !StringUtil.isBlank(value))
.collect(Collectors.joining(separator));
if (newValue.length() > 0)
newValue += separator;
newValue += viaHeaderValue;
return new HttpField(HttpHeader.VIA, newValue);
});
}

protected void addXForwardedHeaders(HttpServletRequest clientRequest, Request proxyRequest)
Expand Down
Expand Up @@ -47,7 +47,6 @@
import org.eclipse.jetty.client.api.Result;
import org.eclipse.jetty.client.util.DeferredContentProvider;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.RuntimeIOException;
import org.eclipse.jetty.util.BufferUtil;
Expand Down Expand Up @@ -92,9 +91,7 @@ protected void service(HttpServletRequest clientRequest, HttpServletResponse pro
return;
}

final Request proxyRequest = getHttpClient().newRequest(rewrittenTarget)
.method(clientRequest.getMethod())
.version(HttpVersion.fromString(clientRequest.getProtocol()));
Request proxyRequest = newProxyRequest(clientRequest, rewrittenTarget);

copyRequestHeaders(clientRequest, proxyRequest);

Expand All @@ -115,7 +112,6 @@ protected void service(HttpServletRequest clientRequest, HttpServletResponse pro

if (expects100Continue(clientRequest))
{
proxyRequest.attribute(CLIENT_REQUEST_ATTRIBUTE, clientRequest);
proxyRequest.attribute(CONTINUE_ACTION_ATTRIBUTE, (Runnable)() ->
{
try
Expand Down
Expand Up @@ -38,7 +38,6 @@
import org.eclipse.jetty.client.api.Result;
import org.eclipse.jetty.client.util.DeferredContentProvider;
import org.eclipse.jetty.client.util.InputStreamContentProvider;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.IteratingCallback;

Expand Down Expand Up @@ -76,9 +75,7 @@ protected void service(final HttpServletRequest request, final HttpServletRespon
return;
}

final Request proxyRequest = getHttpClient().newRequest(rewrittenTarget)
.method(request.getMethod())
.version(HttpVersion.fromString(request.getProtocol()));
Request proxyRequest = newProxyRequest(request, rewrittenTarget);

copyRequestHeaders(request, proxyRequest);

Expand All @@ -95,7 +92,6 @@ protected void service(final HttpServletRequest request, final HttpServletRespon
{
DeferredContentProvider deferred = new DeferredContentProvider();
proxyRequest.content(deferred);
proxyRequest.attribute(CLIENT_REQUEST_ATTRIBUTE, request);
proxyRequest.attribute(CONTINUE_ACTION_ATTRIBUTE, (Runnable)() ->
{
try
Expand Down

0 comments on commit 79d340f

Please sign in to comment.