Skip to content

Commit

Permalink
Backport SocketProtocolFamily (#14043)
Browse files Browse the repository at this point in the history
Motivation:

In main we introduced SocketProtocolFamily to replace our
InternetProtocolFamily. We should backport this change to 4.2 for more
flexibility in the future.

Modifications:

- Backport SocketProtocolFamily
- Deprecate usage of InternetProtocolFamily

Result:

More flexibility in the future
  • Loading branch information
normanmaurer committed May 11, 2024
1 parent 2c17831 commit d3fdf71
Show file tree
Hide file tree
Showing 42 changed files with 479 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package io.netty.handler.codec.dns;

import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.SocketProtocolFamily;
import io.netty.util.NetUtil;
import io.netty.util.internal.UnstableApi;

import java.net.InetAddress;
Expand Down Expand Up @@ -59,16 +61,40 @@ public DefaultDnsOptEcsRecord(int maxPayloadSize, int srcPrefixLength, byte[] ad
/**
* Creates a new instance.
*
* @param maxPayloadSize the suggested max payload size in bytes
* @param protocolFamily the {@link InternetProtocolFamily} to use. This should be the same as the one used to
* send the query.
* @param maxPayloadSize the suggested max payload size in bytes
* @param protocolFamily the {@link InternetProtocolFamily} to use. This should be the same as the one used to
* send the query.
* @deprecated use {@link DefaultDnsOptEcsRecord#DefaultDnsOptEcsRecord(int, SocketProtocolFamily)}
*/
@Deprecated
public DefaultDnsOptEcsRecord(int maxPayloadSize, InternetProtocolFamily protocolFamily) {
this(maxPayloadSize, 0, 0, 0, protocolFamily.localhost().getAddress());
}

/**
* Creates a new instance.
*
* @param maxPayloadSize the suggested max payload size in bytes
* @param socketProtocolFamily the {@link SocketProtocolFamily} to use. This should be the same as the one used to
* send the query.
*/
public DefaultDnsOptEcsRecord(int maxPayloadSize, SocketProtocolFamily socketProtocolFamily) {
this(maxPayloadSize, 0, 0, 0, localAddress(socketProtocolFamily));
}

private static byte[] localAddress(SocketProtocolFamily family) {
switch (family) {
case INET:
return NetUtil.LOCALHOST4.getAddress();
case INET6:
return NetUtil.LOCALHOST6.getAddress();
default:
return null;
}
}

private static byte[] verifyAddress(byte[] bytes) {
if (bytes.length == 4 || bytes.length == 16) {
if (bytes != null && bytes.length == 4 || bytes.length == 16) {
return bytes;
}
throw new IllegalArgumentException("bytes.length must either 4 or 16");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package io.netty.handler.codec.dns;

import io.netty.buffer.ByteBuf;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.handler.codec.UnsupportedMessageTypeException;
import io.netty.util.internal.UnstableApi;

Expand Down Expand Up @@ -94,8 +93,7 @@ private void encodeOptEcsRecord(DnsOptEcsRecord record, ByteBuf out) throws Exce
}

// See https://www.iana.org/assignments/address-family-numbers/address-family-numbers.xhtml
final short addressNumber = (short) (bytes.length == 4 ?
InternetProtocolFamily.IPv4.addressNumber() : InternetProtocolFamily.IPv6.addressNumber());
final short addressNumber = (short) (bytes.length == 4 ? 1 : 2);
int payloadLength = calculateEcsAddressLength(sourcePrefixLength, lowOrderBitsToPreserve);

int fullPayloadLength = 2 + // OPTION-CODE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.SocketUtils;
import io.netty.util.internal.StringUtil;
import org.junit.jupiter.api.Test;

import java.net.Inet4Address;
import java.net.InetAddress;

import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -134,7 +134,7 @@ private static void testIp(InetAddress address, int prefix) throws Exception {
int rdataLength = out.readUnsignedShort();
assertEquals(rdataLength, out.readableBytes());

assertEquals((short) InternetProtocolFamily.of(address).addressNumber(), out.readShort());
assertEquals((short) (address instanceof Inet4Address ? 1 : 2), out.readShort());

assertEquals(prefix, out.readUnsignedByte());
assertEquals(0, out.readUnsignedByte()); // This must be 0 for requests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ List<InetAddress> filterResults(List<InetAddress> unfiltered) {

@Override
boolean isCompleteEarly(InetAddress resolved) {
return completeEarlyIfPossible && parent.preferredAddressType().addressType() == resolved.getClass();
return completeEarlyIfPossible &&
DnsNameResolver.addressType(parent.preferredAddressType()) == resolved.getClass();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.netty.channel.socket.DatagramPacket;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.SocketProtocolFamily;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.handler.codec.dns.DatagramDnsQueryEncoder;
import io.netty.handler.codec.dns.DatagramDnsResponse;
Expand Down Expand Up @@ -105,20 +106,20 @@ public class DnsNameResolver extends InetNameResolver {
private static final DnsRecord[] EMPTY_ADDITIONALS = new DnsRecord[0];
private static final DnsRecordType[] IPV4_ONLY_RESOLVED_RECORD_TYPES =
{DnsRecordType.A};
private static final InternetProtocolFamily[] IPV4_ONLY_RESOLVED_PROTOCOL_FAMILIES =
{InternetProtocolFamily.IPv4};
private static final SocketProtocolFamily[] IPV4_ONLY_RESOLVED_PROTOCOL_FAMILIES =
{SocketProtocolFamily.INET};
private static final DnsRecordType[] IPV4_PREFERRED_RESOLVED_RECORD_TYPES =
{DnsRecordType.A, DnsRecordType.AAAA};
private static final InternetProtocolFamily[] IPV4_PREFERRED_RESOLVED_PROTOCOL_FAMILIES =
{InternetProtocolFamily.IPv4, InternetProtocolFamily.IPv6};
private static final SocketProtocolFamily[] IPV4_PREFERRED_RESOLVED_PROTOCOL_FAMILIES =
{SocketProtocolFamily.INET, SocketProtocolFamily.INET6};
private static final DnsRecordType[] IPV6_ONLY_RESOLVED_RECORD_TYPES =
{DnsRecordType.AAAA};
private static final InternetProtocolFamily[] IPV6_ONLY_RESOLVED_PROTOCOL_FAMILIES =
{InternetProtocolFamily.IPv6};
private static final SocketProtocolFamily[] IPV6_ONLY_RESOLVED_PROTOCOL_FAMILIES =
{SocketProtocolFamily.INET6};
private static final DnsRecordType[] IPV6_PREFERRED_RESOLVED_RECORD_TYPES =
{DnsRecordType.AAAA, DnsRecordType.A};
private static final InternetProtocolFamily[] IPV6_PREFERRED_RESOLVED_PROTOCOL_FAMILIES =
{InternetProtocolFamily.IPv6, InternetProtocolFamily.IPv4};
private static final SocketProtocolFamily[] IPV6_PREFERRED_RESOLVED_PROTOCOL_FAMILIES =
{SocketProtocolFamily.INET6, SocketProtocolFamily.INET};

private static final ChannelHandler NOOP_HANDLER = new ChannelHandlerAdapter() {
@Override
Expand Down Expand Up @@ -255,7 +256,7 @@ protected DnsResponse decodeResponse(ChannelHandlerContext ctx, DatagramPacket p
private final long queryTimeoutMillis;
private final int maxQueriesPerResolve;
private final ResolvedAddressTypes resolvedAddressTypes;
private final InternetProtocolFamily[] resolvedInternetProtocolFamilies;
private final SocketProtocolFamily[] resolvedInternetProtocolFamilies;
private final boolean recursionDesired;
private final int maxPayloadSize;
private final boolean optResourceEnabled;
Expand All @@ -265,7 +266,7 @@ protected DnsResponse decodeResponse(ChannelHandlerContext ctx, DatagramPacket p
private final int ndots;
private final boolean supportsAAAARecords;
private final boolean supportsARecords;
private final InternetProtocolFamily preferredAddressType;
private final SocketProtocolFamily preferredAddressType;
private final DnsRecordType[] resolveRecordTypes;
private final boolean decodeIdn;
private final DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory;
Expand Down Expand Up @@ -479,7 +480,7 @@ dnsServerAddressStreamProvider, new ThreadLocalNameServerAddressStream(dnsServer
}
preferredAddressType = preferredAddressType(this.resolvedAddressTypes);
this.authoritativeDnsServerCache = checkNotNull(authoritativeDnsServerCache, "authoritativeDnsServerCache");
nameServerComparator = new NameServerComparator(preferredAddressType.addressType());
nameServerComparator = new NameServerComparator(addressType(preferredAddressType));
this.maxNumConsolidation = maxNumConsolidation;
if (maxNumConsolidation > 0) {
inflightLookups = new HashMap<String, Future<List<InetAddress>>>();
Expand Down Expand Up @@ -541,14 +542,14 @@ public void operationComplete(ChannelFuture future) {
});
}

static InternetProtocolFamily preferredAddressType(ResolvedAddressTypes resolvedAddressTypes) {
static SocketProtocolFamily preferredAddressType(ResolvedAddressTypes resolvedAddressTypes) {
switch (resolvedAddressTypes) {
case IPV4_ONLY:
case IPV4_PREFERRED:
return InternetProtocolFamily.IPv4;
return SocketProtocolFamily.INET;
case IPV6_ONLY:
case IPV6_PREFERRED:
return InternetProtocolFamily.IPv6;
return SocketProtocolFamily.INET6;
default:
throw new IllegalArgumentException("Unknown ResolvedAddressTypes " + resolvedAddressTypes);
}
Expand Down Expand Up @@ -630,7 +631,7 @@ public ResolvedAddressTypes resolvedAddressTypes() {
return resolvedAddressTypes;
}

InternetProtocolFamily[] resolvedInternetProtocolFamiliesUnsafe() {
SocketProtocolFamily[] resolvedInternetProtocolFamiliesUnsafe() {
return resolvedInternetProtocolFamilies;
}

Expand All @@ -650,7 +651,7 @@ final boolean supportsARecords() {
return supportsARecords;
}

final InternetProtocolFamily preferredAddressType() {
final SocketProtocolFamily preferredAddressType() {
return preferredAddressType;
}

Expand Down Expand Up @@ -959,7 +960,14 @@ private static void validateAdditional(DnsRecord record, boolean validateType) {
}

private InetAddress loopbackAddress() {
return preferredAddressType().localhost();
switch (preferredAddressType()) {
case INET:
return NetUtil.LOCALHOST4;
case INET6:
return NetUtil.LOCALHOST6;
default:
throw new UnsupportedOperationException("Only INET and INET6 are supported");
}
}

/**
Expand Down Expand Up @@ -1008,10 +1016,11 @@ private boolean doResolveCached(String hostname,
if (cause == null) {
final int numEntries = cachedEntries.size();
// Find the first entry with the preferred address type.
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
for (SocketProtocolFamily f : resolvedInternetProtocolFamilies) {
for (int i = 0; i < numEntries; i++) {
final DnsCacheEntry e = cachedEntries.get(i);
if (f.addressType().isInstance(e.address())) {
final Class<? extends InetAddress> addressType = addressType(f);
if (addressType != null && addressType.isInstance(e.address())) {
trySuccess(promise, e.address());
return true;
}
Expand All @@ -1024,6 +1033,17 @@ private boolean doResolveCached(String hostname,
}
}

static Class<? extends InetAddress> addressType(SocketProtocolFamily f) {
switch (f) {
case INET:
return Inet4Address.class;
case INET6:
return Inet6Address.class;
default:
return null;
}
}

static <T> boolean trySuccess(Promise<T> promise, T result) {
final boolean notifiedRecords = promise.trySuccess(result);
if (!notifiedRecords) {
Expand Down Expand Up @@ -1105,7 +1125,7 @@ static boolean doResolveAllCached(String hostname,
DnsRecord[] additionals,
Promise<List<InetAddress>> promise,
DnsCache resolveCache,
InternetProtocolFamily[] resolvedInternetProtocolFamilies) {
SocketProtocolFamily[] resolvedInternetProtocolFamilies) {
final List<? extends DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals);
if (cachedEntries == null || cachedEntries.isEmpty()) {
return false;
Expand All @@ -1115,10 +1135,11 @@ static boolean doResolveAllCached(String hostname,
if (cause == null) {
List<InetAddress> result = null;
final int numEntries = cachedEntries.size();
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
for (SocketProtocolFamily f : resolvedInternetProtocolFamilies) {
for (int i = 0; i < numEntries; i++) {
final DnsCacheEntry e = cachedEntries.get(i);
if (f.addressType().isInstance(e.address())) {
Class<? extends InetAddress> addressType = addressType(f);
if (addressType != null && addressType.isInstance(e.address())) {
if (result == null) {
result = new ArrayList<InetAddress>(numEntries);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.SocketProtocolFamily;
import io.netty.resolver.HostsFileEntriesResolver;
import io.netty.resolver.ResolvedAddressTypes;
import io.netty.util.concurrent.Future;
Expand Down Expand Up @@ -322,25 +323,56 @@ public DnsNameResolverBuilder queryTimeoutMillis(long queryTimeoutMillis) {
* Compute a {@link ResolvedAddressTypes} from some {@link InternetProtocolFamily}s.
* An empty input will return the default value, based on "java.net" System properties.
* Valid inputs are (), (IPv4), (IPv6), (Ipv4, IPv6) and (IPv6, IPv4).
*
* @param internetProtocolFamilies a valid sequence of {@link InternetProtocolFamily}s
* @return a {@link ResolvedAddressTypes}
* @deprecated use {@link #computeResolvedAddressTypes(SocketProtocolFamily...)}
*/
@Deprecated
public static ResolvedAddressTypes computeResolvedAddressTypes(InternetProtocolFamily... internetProtocolFamilies) {
if (internetProtocolFamilies == null || internetProtocolFamilies.length == 0) {
return DnsNameResolver.DEFAULT_RESOLVE_ADDRESS_TYPES;
}
if (internetProtocolFamilies.length > 2) {
throw new IllegalArgumentException("No more than 2 InternetProtocolFamilies");
}
return computeResolvedAddressTypes(toSocketProtocolFamilies(internetProtocolFamilies));
}

private static SocketProtocolFamily[] toSocketProtocolFamilies(InternetProtocolFamily... internetProtocolFamilies) {
if (internetProtocolFamilies == null || internetProtocolFamilies.length == 0) {
return null;
}
SocketProtocolFamily[] socketProtocolFamilies = new SocketProtocolFamily[internetProtocolFamilies.length];
for (int i = 0; i < internetProtocolFamilies.length; i++) {
socketProtocolFamilies[i] = internetProtocolFamilies[i].toSocketProtocolFamily();
}
return socketProtocolFamilies;
}

/**
* Compute a {@link ResolvedAddressTypes} from some {@link SocketProtocolFamily}s.
* An empty input will return the default value, based on "java.net" System properties.
* Valid inputs are (), (IPv4), (IPv6), (Ipv4, IPv6) and (IPv6, IPv4).
* @param socketProtocolFamilies a valid sequence of {@link SocketProtocolFamily}s
* @return a {@link ResolvedAddressTypes}
*/
public static ResolvedAddressTypes computeResolvedAddressTypes(SocketProtocolFamily... socketProtocolFamilies) {
if (socketProtocolFamilies == null || socketProtocolFamilies.length == 0) {
return DnsNameResolver.DEFAULT_RESOLVE_ADDRESS_TYPES;
}
if (socketProtocolFamilies.length > 2) {
throw new IllegalArgumentException("No more than 2 socketProtocolFamilies");
}

switch(internetProtocolFamilies[0]) {
case IPv4:
return (internetProtocolFamilies.length >= 2
&& internetProtocolFamilies[1] == InternetProtocolFamily.IPv6) ?
switch(socketProtocolFamilies[0]) {
case INET:
return (socketProtocolFamilies.length >= 2
&& socketProtocolFamilies[1] == SocketProtocolFamily.INET6) ?
ResolvedAddressTypes.IPV4_PREFERRED: ResolvedAddressTypes.IPV4_ONLY;
case IPv6:
return (internetProtocolFamilies.length >= 2
&& internetProtocolFamilies[1] == InternetProtocolFamily.IPv4) ?
case INET6:
return (socketProtocolFamilies.length >= 2
&& socketProtocolFamilies[1] == SocketProtocolFamily.INET) ?
ResolvedAddressTypes.IPV6_PREFERRED: ResolvedAddressTypes.IPV6_ONLY;
default:
throw new IllegalArgumentException(
Expand Down Expand Up @@ -522,7 +554,8 @@ private AuthoritativeDnsServerCache newAuthoritativeDnsServerCache() {
intValue(minTtl, 0), intValue(maxTtl, Integer.MAX_VALUE),
// Let us use the sane ordering as DnsNameResolver will be used when returning
// nameservers from the cache.
new NameServerComparator(DnsNameResolver.preferredAddressType(resolvedAddressTypes).addressType()));
new NameServerComparator(DnsNameResolver.addressType(
DnsNameResolver.preferredAddressType(resolvedAddressTypes))));
}

private DnsServerAddressStream newQueryServerAddressStream(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package io.netty.resolver.dns;

import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.SocketProtocolFamily;

import java.net.Inet4Address;
import java.net.Inet6Address;
Expand All @@ -27,11 +27,11 @@ final class PreferredAddressTypeComparator implements Comparator<InetAddress> {
private static final PreferredAddressTypeComparator IPv4 = new PreferredAddressTypeComparator(Inet4Address.class);
private static final PreferredAddressTypeComparator IPv6 = new PreferredAddressTypeComparator(Inet6Address.class);

static PreferredAddressTypeComparator comparator(InternetProtocolFamily family) {
static PreferredAddressTypeComparator comparator(SocketProtocolFamily family) {
switch (family) {
case IPv4:
case INET:
return IPv4;
case IPv6:
case INET6:
return IPv6;
default:
throw new IllegalArgumentException();
Expand Down

0 comments on commit d3fdf71

Please sign in to comment.