Skip to content

Commit

Permalink
Issue #5104 - Fix protocol version in Via header to work with H2 and …
Browse files Browse the repository at this point in the history
…other protocols

Signed-off-by: Travis Spencer <travis@curity.io>
  • Loading branch information
travisspencer authored and sbordet committed Aug 12, 2020
1 parent 4a0af04 commit b2ab05c
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 4 deletions.
Expand Up @@ -109,6 +109,7 @@ public abstract class AbstractProxyServlet extends HttpServlet
private String _viaHost;
private HttpClient _client;
private long _timeout;
private boolean oldAddViaHeaderCalled;

@Override
public void init() throws ServletException
Expand Down Expand Up @@ -167,6 +168,9 @@ public String getHostHeader()

public String getViaHost()
{
if (_viaHost == null)
_viaHost = viaHost();

return _viaHost;
}

Expand Down Expand Up @@ -509,13 +513,59 @@ protected Set<String> findConnectionHeaders(HttpServletRequest clientRequest)

protected void addProxyHeaders(HttpServletRequest clientRequest, Request proxyRequest)
{
addViaHeader(proxyRequest);
addViaHeader(clientRequest, proxyRequest);
addXForwardedHeaders(clientRequest, proxyRequest);
}

/**
* Adds the HTTP Via header to the proxied request.
*
* @deprecated Use {@link #addViaHeader(HttpServletRequest, Request)} instead.
* @param proxyRequest the request being proxied
*/
@Deprecated
protected void addViaHeader(Request proxyRequest)
{
proxyRequest.header(HttpHeader.VIA, "http/1.1 " + getViaHost());
oldAddViaHeaderCalled = true;
}

/**
* Adds the HTTP Via header to the proxied request, taking into account data present in the client request.
* This method considers the protocol of the client request when forming the proxied request. If it
* is HTTP, then the protocol name will not be included in the Via header that is sent by the proxy, and only
* the protocol version will be sent. If it is not, the entire protocol (name and version) will be included.
* If the client request includes a Via header, the result will be appended to that to form a chain.
*
* @param clientRequest the client request
* @param proxyRequest the request being proxied
* @see <a href="https://tools.ietf.org/html/rfc7230#section-5.7.1">RFC 7230 section 5.7.1</a>
*/
protected void addViaHeader(HttpServletRequest clientRequest, Request proxyRequest)
{
// For backward compatibility reasons, call old, deprecated version of this method.
// If our flag isn't set, the deprecated method was overridden and we shouldn't do
// anything more.

oldAddViaHeaderCalled = false;
addViaHeader(proxyRequest);

if (!oldAddViaHeaderCalled)
return; // Old method was overridden, so bail out.

// Old version of this method wasn't overridden, so do the new logic instead.

String protocol = clientRequest.getProtocol();
String[] parts = protocol.split("/", 2);
String protocolName = parts.length == 2 && "HTTP".equals(parts[0]) ? parts[1] : protocol;
String viaHeaderValue = "";
String clientViaHeader = clientRequest.getHeader(HttpHeader.VIA.name());

if (clientViaHeader != null)
viaHeaderValue = clientViaHeader;

viaHeaderValue += protocolName + " " + getViaHost();

proxyRequest.header(HttpHeader.VIA, viaHeaderValue);
}

protected void addXForwardedHeaders(HttpServletRequest clientRequest, Request proxyRequest)
Expand Down
Expand Up @@ -26,6 +26,8 @@
import java.io.PrintWriter;
import java.net.ConnectException;
import java.net.HttpCookie;
import java.net.InetAddress;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -63,6 +65,7 @@
import org.eclipse.jetty.client.DuplexConnectionPool;
import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.client.HttpProxy;
import org.eclipse.jetty.client.HttpRequest;
import org.eclipse.jetty.client.api.ContentResponse;
import org.eclipse.jetty.client.api.Request;
import org.eclipse.jetty.client.api.Response;
Expand Down Expand Up @@ -116,6 +119,42 @@ public static Stream<Arguments> impls()
).map(Arguments::of);
}

public static Stream<Arguments> implsWithProtocols()
{
String[] protocols = {"HTTP/1.1", "HTTP/2.0", "OTHER/0.9"};

return impls()
.flatMap(impl -> Arrays.stream(protocols)
.flatMap(p -> Stream.of(Arguments.of(impl.get()[0], p))));
}

public static Stream<Arguments> subclassesWithProtocols()
{
ProxyServlet subclass1 = new ProxyServlet()
{
@Override
protected void addViaHeader(Request proxyRequest)
{
System.err.println("addViaHeader called: " + proxyRequest);
super.addViaHeader(proxyRequest);
}
};
String proto = "MY_GOOD_PROTO/0.8";
ProxyServlet subclass2 = new ProxyServlet()
{
@Override
protected void addViaHeader(Request proxyRequest)
{
proxyRequest.header(HttpHeader.VIA, proto + " " + getViaHost());
}
};

return Stream.of(
Arguments.of(subclass1, "1.1"), // HTTP 1.1 used by this proxy (w/ the connector created in startServer)
Arguments.of(subclass2, proto)
);
}

private HttpClient client;
private Server proxy;
private ServerConnector proxyConnector;
Expand Down Expand Up @@ -145,6 +184,12 @@ private void startProxy(Class<? extends ProxyServlet> proxyServletClass) throws
}

private void startProxy(Class<? extends ProxyServlet> proxyServletClass, Map<String, String> initParams) throws Exception
{
proxyServlet = proxyServletClass.getDeclaredConstructor().newInstance();
startProxy(proxyServlet, initParams);
}

private void startProxy(AbstractProxyServlet proxyServlet, Map<String, String> initParams) throws Exception
{
QueuedThreadPool proxyPool = new QueuedThreadPool();
proxyPool.setName("proxy");
Expand All @@ -159,8 +204,6 @@ private void startProxy(Class<? extends ProxyServlet> proxyServletClass, Map<Str
proxyConnector = new ServerConnector(proxy, new HttpConnectionFactory(configuration));
proxy.addConnector(proxyConnector);

proxyServlet = proxyServletClass.getDeclaredConstructor().newInstance();

proxyContext = new ServletContextHandler(proxy, "/", true, false);
ServletHolder proxyServletHolder = new ServletHolder(proxyServlet);
proxyServletHolder.setInitParameters(initParams);
Expand All @@ -185,6 +228,26 @@ private HttpClient prepareClient() throws Exception
return result;
}

private static HttpServletRequest mockClientRequest(String protocol)
{
return new org.eclipse.jetty.server.Request(null, null)
{
@Override
public String getProtocol()
{
return protocol;
}
};
}

private static HttpRequest mockProxyRequest()
{
return new HttpRequest(new HttpClient(), null, URI.create("https://example.com"))
{

};
}

@AfterEach
public void dispose() throws Exception
{
Expand Down Expand Up @@ -549,6 +612,79 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws Se
Matchers.equalTo("localhost:" + serverConnector.getLocalPort()));
}

@ParameterizedTest
@MethodSource("subclassesWithProtocols")
public void testInheritance(ProxyServlet derivedProxyServlet, String protocol) throws Exception
{
startServer(new HttpServlet()
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
PrintWriter writer = resp.getWriter();
writer.write(req.getHeader("Via"));
writer.flush();
}
});
String viaHost = "my-good-via-host.example.org";
startProxy(derivedProxyServlet, Collections.singletonMap("viaHost", viaHost));
startClient();

HttpRequest proxyRequest = mockProxyRequest();
derivedProxyServlet.addViaHeader(proxyRequest);

ContentResponse response = client.GET("http://localhost:" + serverConnector.getLocalPort());
String expectedVia = protocol + " " + viaHost;

assertThat("Response expected to contain content of Via Header from the request",
response.getContentAsString(),
Matchers.equalTo(expectedVia));
}

@ParameterizedTest
@MethodSource("impls")
public void testProxyViaHeaderIsPresent(Class<? extends ProxyServlet> proxyServletClass) throws Exception
{
startServer(new HttpServlet()
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
PrintWriter writer = resp.getWriter();
writer.write(req.getHeader("Via"));
writer.flush();
}
});
String viaHost = "my-good-via-host.example.org";
startProxy(proxyServletClass, Collections.singletonMap("viaHost", viaHost));
startClient();

ContentResponse response = client.GET("http://localhost:" + serverConnector.getLocalPort());
assertThat("Response expected to contain content of Via Header from the request",
response.getContentAsString(),
Matchers.equalTo("1.1 " + viaHost));
}

@ParameterizedTest
@MethodSource("implsWithProtocols")
public void testProxyViaHeaderForVariousProtocols(Class<? extends ProxyServlet> proxyServletClass, String protocol) throws Exception
{
AbstractProxyServlet proxyServlet = proxyServletClass.getDeclaredConstructor().newInstance();
String host = InetAddress.getLocalHost().getHostName();
HttpServletRequest clientRequest = mockClientRequest(protocol);
HttpRequest proxyRequest = mockProxyRequest();

proxyServlet.addViaHeader(clientRequest, proxyRequest);

String expectedProtocol = protocol.startsWith("HTTP") ? protocol.split("/", 2)[1] : protocol;
String expectedVia = expectedProtocol + " " + host;
String expectedViaWithLocalhost = expectedProtocol + " localhost";

assertThat("Response expected to contain a Via header with the right protocol version and host",
proxyRequest.getHeaders().getField("Via").getValue(),
Matchers.anyOf(Matchers.equalTo(expectedVia), Matchers.equalTo(expectedViaWithLocalhost)));
}

@ParameterizedTest
@MethodSource("impls")
public void testProxyWhiteList(Class<? extends ProxyServlet> proxyServletClass) throws Exception
Expand Down

0 comments on commit b2ab05c

Please sign in to comment.