Skip to content

Commit

Permalink
Merge pull request #4429 from eclipse/jetty-9.4.x-4421-httpclient_pro…
Browse files Browse the repository at this point in the history
…xy_protocol2

Fixes #4421 - HttpClient support for PROXY protocol.
  • Loading branch information
sbordet committed Dec 17, 2019
2 parents 129a51c + bea7f1a commit fcc18b0
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 33 deletions.
Expand Up @@ -23,8 +23,8 @@
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -206,7 +206,7 @@ protected ProxyProtocolConnection newProxyProtocolConnection(EndPoint endPoint,
InetSocketAddress local = endPoint.getLocalAddress();
InetSocketAddress remote = endPoint.getRemoteAddress();
boolean ipv4 = local.getAddress() instanceof Inet4Address;
tag = new Tag(Tag.Command.PROXY, ipv4 ? Tag.Family.INET4 : Tag.Family.INET6, Tag.Protocol.STREAM, local.getAddress().getHostAddress(), local.getPort(), remote.getAddress().getHostAddress(), remote.getPort());
tag = new Tag(Tag.Command.PROXY, ipv4 ? Tag.Family.INET4 : Tag.Family.INET6, Tag.Protocol.STREAM, local.getAddress().getHostAddress(), local.getPort(), remote.getAddress().getHostAddress(), remote.getPort(), null);
}
return new ProxyProtocolConnectionV2(endPoint, executor, getClientConnectionFactory(), context, tag);
}
Expand All @@ -223,7 +223,7 @@ public static class Tag implements ClientConnectionFactory.Decorator
/**
* The PROXY V2 Tag typically used to "ping" the server.
*/
public static final Tag LOCAL = new Tag(Command.LOCAL, Family.UNSPEC, Protocol.UNSPEC, null, 0, null, 0);
public static final Tag LOCAL = new Tag(Command.LOCAL, Family.UNSPEC, Protocol.UNSPEC, null, 0, null, 0, null);

private Command command;
private Family family;
Expand All @@ -232,7 +232,7 @@ public static class Tag implements ClientConnectionFactory.Decorator
private int srcPort;
private String dstIP;
private int dstPort;
private Map<Integer, byte[]> vectors;
private List<TLV> tlvs;

/**
* <p>Creates a Tag whose metadata will be derived from the underlying EndPoint.</p>
Expand All @@ -251,7 +251,20 @@ public Tag()
*/
public Tag(String srcIP, int srcPort)
{
this(Command.PROXY, null, Protocol.STREAM, srcIP, srcPort, null, 0);
this(Command.PROXY, null, Protocol.STREAM, srcIP, srcPort, null, 0, null);
}

/**
* <p>Creates a Tag with the given source metadata and Type-Length-Value (TLV) objects.</p>
* <p>The destination metadata will be derived from the underlying EndPoint.</p>
*
* @param srcIP the source IP address
* @param srcPort the source port
* @param tlvs the TLV objects
*/
public Tag(String srcIP, int srcPort, List<TLV> tlvs)
{
this(Command.PROXY, null, Protocol.STREAM, srcIP, srcPort, null, 0, tlvs);
}

/**
Expand All @@ -264,8 +277,9 @@ public Tag(String srcIP, int srcPort)
* @param srcPort the source port
* @param dstIP the destination IP address
* @param dstPort the destination port
* @param tlvs the TLV objects
*/
public Tag(Command command, Family family, Protocol protocol, String srcIP, int srcPort, String dstIP, int dstPort)
public Tag(Command command, Family family, Protocol protocol, String srcIP, int srcPort, String dstIP, int dstPort, List<TLV> tlvs)
{
this.command = command;
this.family = family;
Expand All @@ -274,17 +288,7 @@ public Tag(Command command, Family family, Protocol protocol, String srcIP, int
this.srcPort = srcPort;
this.dstIP = dstIP;
this.dstPort = dstPort;
}

public void put(int type, byte[] data)
{
if (type < 0 || type > 255)
throw new IllegalArgumentException("Invalid type: " + type);
if (data != null && data.length > 65535)
throw new IllegalArgumentException("Invalid data length: " + data.length);
if (vectors == null)
vectors = new HashMap<>();
vectors.put(type, data);
this.tlvs = tlvs;
}

public Command getCommand()
Expand Down Expand Up @@ -322,9 +326,9 @@ public int getDestinationPort()
return dstPort;
}

public Map<Integer, byte[]> getVectors()
public List<TLV> getTLVs()
{
return vectors != null ? vectors : Collections.emptyMap();
return tlvs;
}

@Override
Expand All @@ -347,13 +351,14 @@ public boolean equals(Object obj)
Objects.equals(srcIP, that.srcIP) &&
srcPort == that.srcPort &&
Objects.equals(dstIP, that.dstIP) &&
dstPort == that.dstPort;
dstPort == that.dstPort &&
Objects.equals(tlvs, that.tlvs);
}

@Override
public int hashCode()
{
return Objects.hash(command, family, protocol, srcIP, srcPort, dstIP, dstPort);
return Objects.hash(command, family, protocol, srcIP, srcPort, dstIP, dstPort, tlvs);
}

public enum Command
Expand All @@ -370,6 +375,51 @@ public enum Protocol
{
UNSPEC, STREAM, DGRAM
}

public static class TLV
{
private final int type;
private final byte[] value;

public TLV(int type, byte[] value)
{
if (type < 0 || type > 255)
throw new IllegalArgumentException("Invalid type: " + type);
if (value != null && value.length > 65535)
throw new IllegalArgumentException("Invalid value length: " + value.length);
this.type = type;
this.value = Objects.requireNonNull(value);
}

public int getType()
{
return type;
}

public byte[] getValue()
{
return value;
}

@Override
public boolean equals(Object obj)
{
if (this == obj)
return true;
if (obj == null || getClass() != obj.getClass())
return false;
TLV that = (TLV)obj;
return type == that.type && Arrays.equals(value, that.value);
}

@Override
public int hashCode()
{
int result = Objects.hash(type);
result = 31 * result + Arrays.hashCode(value);
return result;
}
}
}
}

Expand Down Expand Up @@ -533,9 +583,9 @@ protected void writePROXYBytes(EndPoint endPoint, Callback callback)
capacity += 1; // family and protocol
capacity += 2; // length
capacity += 216; // max address length
Map<Integer, byte[]> vectors = tag.getVectors();
int vectorsLength = vectors.values().stream()
.mapToInt(data -> 1 + 2 + data.length)
List<V2.Tag.TLV> tlvs = tag.getTLVs();
int vectorsLength = tlvs == null ? 0 : tlvs.stream()
.mapToInt(tlv -> 1 + 2 + tlv.getValue().length)
.sum();
capacity += vectorsLength;
ByteBuffer buffer = ByteBuffer.allocateDirect(capacity);
Expand Down Expand Up @@ -602,12 +652,15 @@ protected void writePROXYBytes(EndPoint endPoint, Callback callback)
default:
throw new IllegalStateException();
}
for (Map.Entry<Integer, byte[]> entry : vectors.entrySet())
if (tlvs != null)
{
buffer.put(entry.getKey().byteValue());
byte[] data = entry.getValue();
buffer.putShort((short)data.length);
buffer.put(data);
for (V2.Tag.TLV tlv : tlvs)
{
buffer.put((byte)tlv.getType());
byte[] data = tlv.getValue();
buffer.putShort((short)data.length);
buffer.put(data);
}
}
buffer.flip();
endPoint.write(callback, buffer);
Expand Down
Expand Up @@ -20,6 +20,7 @@

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import javax.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -174,7 +175,8 @@ protected void service(String target, Request jettyRequest, HttpServletRequest r
EndPoint endPoint = jettyRequest.getHttpChannel().getEndPoint();
assertTrue(endPoint instanceof ProxyConnectionFactory.ProxyEndPoint);
ProxyConnectionFactory.ProxyEndPoint proxyEndPoint = (ProxyConnectionFactory.ProxyEndPoint)endPoint;
assertEquals(tlsVersion, proxyEndPoint.getAttribute(ProxyConnectionFactory.TLS_VERSION));
if (target.equals("/tls_version"))
assertEquals(tlsVersion, proxyEndPoint.getAttribute(ProxyConnectionFactory.TLS_VERSION));
response.setContentType(MimeTypes.Type.TEXT_PLAIN.asString());
response.getOutputStream().print(request.getRemotePort());
}
Expand All @@ -184,21 +186,34 @@ protected void service(String target, Request jettyRequest, HttpServletRequest r
int serverPort = connector.getLocalPort();

int clientPort = ThreadLocalRandom.current().nextInt(1024, 65536);
V2.Tag tag = new V2.Tag("127.0.0.1", clientPort);
int typeTLS = 0x20;
byte[] dataTLS = new byte[1 + 4 + (1 + 2 + tlsVersionBytes.length)];
dataTLS[0] = 0x01; // CLIENT_SSL
dataTLS[5] = 0x21; // SUBTYPE_SSL_VERSION
dataTLS[6] = 0x00; // Length, hi byte
dataTLS[7] = (byte)tlsVersionBytes.length; // Length, lo byte
System.arraycopy(tlsVersionBytes, 0, dataTLS, 8, tlsVersionBytes.length);
tag.put(typeTLS, dataTLS);
V2.Tag.TLV tlv = new V2.Tag.TLV(typeTLS, dataTLS);
V2.Tag tag = new V2.Tag("127.0.0.1", clientPort, Collections.singletonList(tlv));

ContentResponse response = client.newRequest("localhost", serverPort)
.path("/tls_version")
.tag(tag)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
assertEquals(String.valueOf(clientPort), response.getContentAsString());

// Make another request with the same address information, but different TLV.
V2.Tag.TLV tlv2 = new V2.Tag.TLV(0x01, "http/1.1".getBytes(StandardCharsets.UTF_8));
V2.Tag tag2 = new V2.Tag("127.0.0.1", clientPort, Collections.singletonList(tlv2));
response = client.newRequest("localhost", serverPort)
.tag(tag2)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
assertEquals(String.valueOf(clientPort), response.getContentAsString());

// Make sure the two TLVs created two destinations.
assertEquals(2, client.getDestinations().size());
}

@Test
Expand Down

0 comments on commit fcc18b0

Please sign in to comment.