Skip to content

Commit

Permalink
Add support for client-side TCP FastOpen to KQueue MacOS (#11560)
Browse files Browse the repository at this point in the history
Motivation:
The MacOS-specific `connectx(2)` system call make it possible to establish client-side connections with TCP FastOpen.

Modification:
Add support for TCP FastOpen to the KQueue transport, and add the `connectx(2)` system call to `BsdSocket`.

Result:
It's now possible to use TCP FastOpen when initiating connections on MacOS.
  • Loading branch information
chrisvest committed Aug 12, 2021
1 parent bcdc07f commit 25699e4
Show file tree
Hide file tree
Showing 14 changed files with 300 additions and 28 deletions.
Expand Up @@ -30,11 +30,9 @@
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.StringUtil;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.Timeout;
import org.opentest4j.TestAbortedException;

import java.io.ByteArrayOutputStream;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -177,8 +175,9 @@ private static void connectAndVerifyDataTransfer(Bootstrap cb, Channel sc)
}

protected void enableTcpFastOpen(ServerBootstrap sb, Bootstrap cb) {
throw new TestAbortedException(
"Support for testing TCP_FASTOPEN not enabled for " + StringUtil.simpleClassName(this));
// TFO is an almost-pure optimisation and should not change any observable behaviour in our tests.
sb.option(ChannelOption.TCP_FASTOPEN, 5);
cb.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}

private static void assertLocalAddress(InetSocketAddress address) {
Expand Down
Expand Up @@ -29,10 +29,4 @@ public class EpollSocketConnectTest extends SocketConnectTest {
protected List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> newFactories() {
return EpollSocketTestPermutation.INSTANCE.socketWithoutFastOpen();
}

@Override
protected void enableTcpFastOpen(ServerBootstrap sb, Bootstrap cb) {
sb.option(ChannelOption.TCP_FASTOPEN, 5);
cb.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}
}
58 changes: 57 additions & 1 deletion transport-native-kqueue/src/main/c/netty_kqueue_bsdsocket.c
Expand Up @@ -13,6 +13,7 @@
* License for the specific language governing permissions and limitations
* under the License.
*/
#include <assert.h>
#include <stdlib.h>
#include <errno.h>
#include <string.h>
Expand Down Expand Up @@ -83,6 +84,60 @@ static jlong netty_kqueue_bsdsocket_sendFile(JNIEnv* env, jclass clazz, jint soc
return res < 0 ? -err : 0;
}

static jint netty_kqueue_bsdsocket_connectx(JNIEnv* env, jclass clazz,
jint socketFd,
jint socketInterface,
jboolean sourceIPv6, jbyteArray sourceAddress, jint sourceScopeId, jint sourcePort,
jboolean destinationIPv6, jbyteArray destinationAddress, jint destinationScopeId, jint destinationPort,
jint flags,
jlong iovAddress, jint iovCount, jint iovDataLength) {
#ifdef __APPLE__ // connectx(2) is only defined on Darwin.
sa_endpoints_t endpoints;
endpoints.sae_srcif = (unsigned int) socketInterface;
endpoints.sae_srcaddr = NULL;
endpoints.sae_srcaddrlen = 0;
endpoints.sae_dstaddr = NULL;
endpoints.sae_dstaddrlen = 0;

struct sockaddr_storage srcaddr;
socklen_t srcaddrlen;
struct sockaddr_storage dstaddr;
socklen_t dstaddrlen;

if (NULL != sourceAddress) {
if (-1 == netty_unix_socket_initSockaddr(env,
sourceIPv6, sourceAddress, sourceScopeId, sourcePort, &srcaddr, &srcaddrlen)) {
netty_unix_errors_throwIOException(env,
"Source address specified, but could not be converted to sockaddr.");
return -EINVAL;
}
endpoints.sae_srcaddr = (const struct sockaddr*) &srcaddr;
endpoints.sae_srcaddrlen = srcaddrlen;
}

assert(destinationAddress != NULL); // Java side will ensure destination is never null.
if (-1 == netty_unix_socket_initSockaddr(env,
destinationIPv6, destinationAddress, destinationScopeId, destinationPort, &dstaddr, &dstaddrlen)) {
netty_unix_errors_throwIOException(env, "Destination address could not be converted to sockaddr.");
return -EINVAL;
}
endpoints.sae_dstaddr = (const struct sockaddr*) &dstaddr;
endpoints.sae_dstaddrlen = dstaddrlen;

int socket = (int) socketFd;
const struct iovec* iov = (const struct iovec*) iovAddress;
unsigned int iovcnt = (unsigned int) iovCount;
size_t len = (size_t) iovDataLength;
int result = connectx(socket, &endpoints, SAE_ASSOCID_ANY, flags, iov, iovcnt, &len, NULL);
if (result == -1) {
return -errno;
}
return (jint) len;
#else
return -ENOSYS;
#endif
}

static void netty_kqueue_bsdsocket_setAcceptFilter(JNIEnv* env, jclass clazz, jint fd, jstring afName, jstring afArg) {
#ifdef SO_ACCEPTFILTER
struct accept_filter_arg af;
Expand Down Expand Up @@ -196,7 +251,8 @@ static const JNINativeMethod fixed_method_table[] = {
{ "setSndLowAt", "(II)V", (void *) netty_kqueue_bsdsocket_setSndLowAt },
{ "getAcceptFilter", "(I)[Ljava/lang/String;", (void *) netty_kqueue_bsdsocket_getAcceptFilter },
{ "getTcpNoPush", "(I)I", (void *) netty_kqueue_bsdsocket_getTcpNoPush },
{ "getSndLowAt", "(I)I", (void *) netty_kqueue_bsdsocket_getSndLowAt }
{ "getSndLowAt", "(I)I", (void *) netty_kqueue_bsdsocket_getSndLowAt },
{ "connectx", "(IIZ[BIIZ[BIIIJII)I", (void *) netty_kqueue_bsdsocket_connectx }
};

static const jint fixed_method_table_size = sizeof(fixed_method_table) / sizeof(fixed_method_table[0]);
Expand Down
24 changes: 23 additions & 1 deletion transport-native-kqueue/src/main/c/netty_kqueue_native.c
Expand Up @@ -60,6 +60,12 @@
#ifndef NOTE_DISCONNECTED
#define NOTE_DISCONNECTED 0x00001000
#endif /* NOTE_DISCONNECTED */
#ifndef CONNECT_RESUME_ON_READ_WRITE
#define CONNECT_RESUME_ON_READ_WRITE 0x1
#endif /* CONNECT_RESUME_ON_READ_WRITE */
#ifndef CONNECT_DATA_IDEMPOTENT
#define CONNECT_DATA_IDEMPOTENT 0x2
#endif /* CONNECT_DATA_IDEMPOTENT */
#else
#ifndef EVFILT_SOCK
#define EVFILT_SOCK 0 // Disabled
Expand All @@ -73,6 +79,12 @@
#ifndef NOTE_DISCONNECTED
#define NOTE_DISCONNECTED 0
#endif /* NOTE_DISCONNECTED */
#ifndef CONNECT_RESUME_ON_READ_WRITE
#define CONNECT_RESUME_ON_READ_WRITE 0
#endif /* CONNECT_RESUME_ON_READ_WRITE */
#ifndef CONNECT_DATA_IDEMPOTENT
#define CONNECT_DATA_IDEMPOTENT 0
#endif /* CONNECT_DATA_IDEMPOTENT */
#endif /* __APPLE__ */

static clockid_t waitClockId = 0; // initialized by netty_unix_util_initialize_wait_clock
Expand Down Expand Up @@ -247,6 +259,14 @@ static jshort netty_kqueue_native_noteDisconnected(JNIEnv* env, jclass clazz) {
return NOTE_DISCONNECTED;
}

static jint netty_kqueue_bsdsocket_connectResumeOnReadWrite(JNIEnv *env) {
return CONNECT_RESUME_ON_READ_WRITE;
}

static jint netty_kqueue_bsdsocket_connectDataIdempotent(JNIEnv *env) {
return CONNECT_DATA_IDEMPOTENT;
}

// JNI Method Registration Table Begin
static const JNINativeMethod statically_referenced_fixed_method_table[] = {
{ "evfiltRead", "()S", (void *) netty_kqueue_native_evfiltRead },
Expand All @@ -262,7 +282,9 @@ static const JNINativeMethod statically_referenced_fixed_method_table[] = {
{ "evError", "()S", (void *) netty_kqueue_native_evError },
{ "noteReadClosed", "()S", (void *) netty_kqueue_native_noteReadClosed },
{ "noteConnReset", "()S", (void *) netty_kqueue_native_noteConnReset },
{ "noteDisconnected", "()S", (void *) netty_kqueue_native_noteDisconnected }
{ "noteDisconnected", "()S", (void *) netty_kqueue_native_noteDisconnected },
{ "connectResumeOnReadWrite", "()I", (void *) netty_kqueue_bsdsocket_connectResumeOnReadWrite },
{ "connectDataIdempotent", "()I", (void *) netty_kqueue_bsdsocket_connectDataIdempotent }
};
static const jint statically_referenced_fixed_method_table_size = sizeof(statically_referenced_fixed_method_table) / sizeof(statically_referenced_fixed_method_table[0]);
static const JNINativeMethod fixed_method_table[] = {
Expand Down
Expand Up @@ -387,7 +387,7 @@ final void readReadyBefore() {
final void readReadyFinally(ChannelConfig config) {
maybeMoreDataToRead = allocHandle.maybeMoreDataToRead();

if (allocHandle.isReadEOF() || (readPending && maybeMoreDataToRead)) {
if (allocHandle.isReadEOF() || readPending && maybeMoreDataToRead) {
// trigger a read again as there may be something left to read and because of ET we
// will not get notified again until we read everything from the socket
//
Expand Down Expand Up @@ -691,7 +691,7 @@ protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddr
socket.bind(localAddress);
}

boolean connected = doConnect0(remoteAddress);
boolean connected = doConnect0(remoteAddress, localAddress);
if (connected) {
remote = remoteSocketAddr == null?
remoteAddress : computeRemoteAddr(remoteSocketAddr, socket.remoteAddress());
Expand All @@ -703,10 +703,10 @@ protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddr
return connected;
}

private boolean doConnect0(SocketAddress remote) throws Exception {
protected boolean doConnect0(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
boolean success = false;
try {
boolean connected = socket.connect(remote);
boolean connected = socket.connect(remoteAddress);
if (!connected) {
writeFilter(true);
}
Expand Down
Expand Up @@ -16,13 +16,21 @@
package io.netty.channel.kqueue;

import io.netty.channel.DefaultFileRegion;
import io.netty.channel.unix.IovArray;
import io.netty.channel.unix.PeerCredentials;
import io.netty.channel.unix.Socket;

import java.io.IOException;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;

import static io.netty.channel.kqueue.AcceptFilter.PLATFORM_UNSUPPORTED;
import static io.netty.channel.kqueue.Native.CONNECT_TCP_FASTOPEN;
import static io.netty.channel.unix.Errors.ERRNO_EINPROGRESS_NEGATIVE;
import static io.netty.channel.unix.Errors.ioResult;
import static io.netty.channel.unix.NativeInetAddress.ipv4MappedIpv6Address;
import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
* A socket which provides access BSD native methods.
Expand All @@ -34,6 +42,12 @@ final class BsdSocket extends Socket {
private static final int APPLE_SND_LOW_AT_MAX = 1 << 17;
private static final int FREEBSD_SND_LOW_AT_MAX = 1 << 15;
static final int BSD_SND_LOW_AT_MAX = Math.min(APPLE_SND_LOW_AT_MAX, FREEBSD_SND_LOW_AT_MAX);
/**
* The `endpoints` structure passed to `connectx(2)` has an optional "source interface" field,
* which is the index of the network interface to use.
* According to `if_nametoindex(3)`, the value 0 is used when no interface is specified.
*/
private static final int UNSPECIFIED_SOURCE_INTERFACE = 0;

BsdSocket(int fd) {
super(fd);
Expand All @@ -51,7 +65,7 @@ void setSndLowAt(int lowAt) throws IOException {
setSndLowAt(intValue(), lowAt);
}

boolean isTcpNoPush() throws IOException {
boolean isTcpNoPush() throws IOException {
return getTcpNoPush(intValue()) != 0;
}

Expand Down Expand Up @@ -80,6 +94,96 @@ long sendFile(DefaultFileRegion src, long baseOffset, long offset, long length)
return ioResult("sendfile", (int) res);
}

/**
* Establish a connection to the given destination address, and send the given data to it.
*
* <strong>Note:</strong> This method relies on the {@code connectx(2)} system call, which is MacOS specific.
*
* @param source the source address we are connecting from.
* @param destination the destination address we are connecting to.
* @param data the data to copy to the kernel-side socket buffer.
* @param tcpFastOpen if {@code true}, set the flags needed to enable TCP FastOpen connecting.
* @return The number of bytes copied to the kernel-side socket buffer, or the number of bytes sent to the
* destination. This number is <em>negative</em> if connecting is left in an in-progress state,
* or <em>positive</em> if the connection was immediately established.
* @throws IOException if an IO error occurs, if the {@code data} is too big to send in one go,
* or if the system call is not supported on your platform.
*/
int connectx(InetSocketAddress source, InetSocketAddress destination, IovArray data, boolean tcpFastOpen)
throws IOException {
checkNotNull(destination, "Destination InetSocketAddress cannot be null.");
int flags = tcpFastOpen ? CONNECT_TCP_FASTOPEN : 0;

boolean sourceIPv6;
byte[] sourceAddress;
int sourceScopeId;
int sourcePort;
if (source == null) {
sourceIPv6 = false;
sourceAddress = null;
sourceScopeId = 0;
sourcePort = 0;
} else {
InetAddress sourceInetAddress = source.getAddress();
sourceIPv6 = sourceInetAddress instanceof Inet6Address;
if (sourceIPv6) {
sourceAddress = sourceInetAddress.getAddress();
sourceScopeId = ((Inet6Address) sourceInetAddress).getScopeId();
} else {
// convert to ipv4 mapped ipv6 address;
sourceScopeId = 0;
sourceAddress = ipv4MappedIpv6Address(sourceInetAddress.getAddress());
}
sourcePort = source.getPort();
}

InetAddress destinationInetAddress = destination.getAddress();
boolean destinationIPv6 = destinationInetAddress instanceof Inet6Address;
byte[] destinationAddress;
int destinationScopeId;
if (destinationIPv6) {
destinationAddress = destinationInetAddress.getAddress();
destinationScopeId = ((Inet6Address) destinationInetAddress).getScopeId();
} else {
// convert to ipv4 mapped ipv6 address;
destinationScopeId = 0;
destinationAddress = ipv4MappedIpv6Address(destinationInetAddress.getAddress());
}
int destinationPort = destination.getPort();

long iovAddress;
int iovCount;
int iovDataLength;
if (data == null || data.count() == 0) {
iovAddress = 0;
iovCount = 0;
iovDataLength = 0;
} else {
iovAddress = data.memoryAddress(0);
iovCount = data.count();
long size = data.size();
if (size > Integer.MAX_VALUE) {
throw new IOException("IovArray.size() too big: " + size + " bytes.");
}
iovDataLength = (int) size;
}

int result = connectx(intValue(),
UNSPECIFIED_SOURCE_INTERFACE, sourceIPv6, sourceAddress, sourceScopeId, sourcePort,
destinationIPv6, destinationAddress, destinationScopeId, destinationPort,
flags, iovAddress, iovCount, iovDataLength);
if (result == ERRNO_EINPROGRESS_NEGATIVE) {
// This is normal for non-blocking sockets.
// We'll know the connection has been established when the socket is selectable for writing.
// Tell the channel the data was written, so the outbound buffer can update its position.
return -iovDataLength;
}
if (result < 0) {
return ioResult("connectx", result);
}
return result;
}

public static BsdSocket newSocketStream() {
return new BsdSocket(newSocketStream0());
}
Expand All @@ -99,12 +203,32 @@ public static BsdSocket newSocketDomainDgram() {
private static native long sendFile(int socketFd, DefaultFileRegion src, long baseOffset,
long offset, long length) throws IOException;

/**
* @return If successful, zero or positive number of bytes transfered, otherwise negative errno.
*/
private static native int connectx(
int socketFd,
// sa_endpoints_t *endpoints:
int sourceInterface,
boolean sourceIPv6, byte[] sourceAddress, int sourceScopeId, int sourcePort,
boolean destinationIPv6, byte[] destinationAddress, int destinationScopeId, int destinationPort,
// sae_associd_t associd is reserved
int flags,
long iovAddress, int iovCount, int iovDataLength
// sae_connid_t *connid is reserved
);

private static native String[] getAcceptFilter(int fd) throws IOException;

private static native int getTcpNoPush(int fd) throws IOException;

private static native int getSndLowAt(int fd) throws IOException;

private static native PeerCredentials getPeerCredentials(int fd) throws IOException;

private static native void setAcceptFilter(int fd, String filterName, String filterArgs) throws IOException;

private static native void setTcpNoPush(int fd, int tcpNoPush) throws IOException;

private static native void setSndLowAt(int fd, int lowAt) throws IOException;
}

0 comments on commit 25699e4

Please sign in to comment.