Skip to content

Commit

Permalink
Add TCP FastOpen to the KQueue test permutations
Browse files Browse the repository at this point in the history
Also fix a bug where using TFO with KQueue would prematurely consider the socket connected.
Instead, the socket should assume to be connect-in-progress when using TFO since all our sockets are non-blocking.
  • Loading branch information
chrisvest committed Aug 11, 2021
1 parent 14306fc commit 4577097
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 54 deletions.
Expand Up @@ -28,11 +28,9 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.StringUtil;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.Timeout;
import org.opentest4j.TestAbortedException;

import java.io.ByteArrayOutputStream;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -189,8 +187,9 @@ private static void connectAndVerifyDataTransfer(Bootstrap cb, Channel sc)
}

protected void enableTcpFastOpen(ServerBootstrap sb, Bootstrap cb) {
throw new TestAbortedException(
"Support for testing TCP_FASTOPEN not enabled for " + StringUtil.simpleClassName(this));
// TFO is an almost-pure optimisation and should not change any observable behaviour in our tests.
sb.option(ChannelOption.TCP_FASTOPEN, 5);
cb.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}

private static void assertLocalAddress(InetSocketAddress address) {
Expand Down
Expand Up @@ -29,10 +29,4 @@ public class EpollSocketConnectTest extends SocketConnectTest {
protected List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> newFactories() {
return EpollSocketTestPermutation.INSTANCE.socketWithoutFastOpen();
}

@Override
protected void enableTcpFastOpen(ServerBootstrap sb, Bootstrap cb) {
sb.option(ChannelOption.TCP_FASTOPEN, 5);
cb.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}
}
Expand Up @@ -68,7 +68,6 @@ public List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstra
return list;
}

@SuppressWarnings("unchecked")
@Override
public List<BootstrapFactory<ServerBootstrap>> serverSocket() {
List<BootstrapFactory<ServerBootstrap>> toReturn = new ArrayList<BootstrapFactory<ServerBootstrap>>();
Expand Down Expand Up @@ -207,10 +206,7 @@ public String toString() {
}

public List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> domainSocket() {

List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> list =
combo(serverDomainSocket(), clientDomainSocket());
return list;
return combo(serverDomainSocket(), clientDomainSocket());
}

public List<BootstrapFactory<ServerBootstrap>> serverDomainSocket() {
Expand Down
Expand Up @@ -390,7 +390,7 @@ final void readReadyBefore() {
final void readReadyFinally(ChannelConfig config) {
maybeMoreDataToRead = allocHandle.maybeMoreDataToRead();

if (allocHandle.isReadEOF() || (readPending && maybeMoreDataToRead)) {
if (allocHandle.isReadEOF() || readPending && maybeMoreDataToRead) {
// trigger a read again as there may be something left to read and because of ET we
// will not get notified again until we read everything from the socket
//
Expand Down Expand Up @@ -699,7 +699,7 @@ protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddr
socket.bind(localAddress);
}

boolean connected = doConnect0(remoteAddress);
boolean connected = doConnect0(remoteAddress, localAddress);
if (connected) {
remote = remoteSocketAddr == null?
remoteAddress : computeRemoteAddr(remoteSocketAddr, socket.remoteAddress());
Expand All @@ -711,10 +711,10 @@ protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddr
return connected;
}

private boolean doConnect0(SocketAddress remote) throws Exception {
protected boolean doConnect0(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
boolean success = false;
try {
boolean connected = socket.connect(remote);
boolean connected = socket.connect(remoteAddress);
if (!connected) {
writeFilter(true);
}
Expand Down
Expand Up @@ -68,7 +68,7 @@ public ServerSocketChannel parent() {
}

@Override
protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
protected boolean doConnect0(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
if (config.isTcpFastOpenConnect()) {
ChannelOutboundBuffer outbound = unsafe().outboundBuffer();
outbound.addFlush();
Expand All @@ -81,17 +81,16 @@ protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddr
iov.add(initialData, initialData.readerIndex(), initialData.readableBytes());
int bytesSent = socket.connectx(
(InetSocketAddress) localAddress, (InetSocketAddress) remoteAddress, iov, true);
if (bytesSent > 0) {
outbound.removeBytes(bytesSent);
return true;
}
writeFilter(true);
outbound.removeBytes(bytesSent);
return false; // 'false' because we assume connecting to be in-progress.
} finally {
iov.release();
}
}
}
}
return super.doConnect(remoteAddress, localAddress);
return super.doConnect0(remoteAddress, localAddress);
}

@Override
Expand Down
Expand Up @@ -24,6 +24,6 @@
public class KQueueSocketChannelNotYetConnectedTest extends SocketChannelNotYetConnectedTest {
@Override
protected List<TestsuitePermutation.BootstrapFactory<Bootstrap>> newFactories() {
return KQueueSocketTestPermutation.INSTANCE.clientSocket();
return KQueueSocketTestPermutation.INSTANCE.clientSocketWithFastOpen();
}
}
Expand Up @@ -19,6 +19,7 @@
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.nio.NioDatagramChannel;
Expand All @@ -30,8 +31,6 @@
import io.netty.testsuite.transport.TestsuitePermutation.BootstrapFactory;
import io.netty.testsuite.transport.socket.SocketTestPermutation;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -47,8 +46,6 @@ class KQueueSocketTestPermutation extends SocketTestPermutation {
static final EventLoopGroup KQUEUE_WORKER_GROUP =
new KQueueEventLoopGroup(WORKERS, new DefaultThreadFactory("testsuite-KQueue-worker", true));

private static final InternalLogger logger = InternalLoggerFactory.getInstance(KQueueSocketTestPermutation.class);

@Override
public List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> socket() {

Expand All @@ -60,7 +57,6 @@ public List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstra
return list;
}

@SuppressWarnings("unchecked")
@Override
public List<BootstrapFactory<ServerBootstrap>> serverSocket() {
List<BootstrapFactory<ServerBootstrap>> toReturn = new ArrayList<BootstrapFactory<ServerBootstrap>>();
Expand All @@ -83,30 +79,47 @@ public ServerBootstrap newInstance() {
return toReturn;
}

@SuppressWarnings("unchecked")
@Override
public List<BootstrapFactory<Bootstrap>> clientSocket() {
return Arrays.asList(
new BootstrapFactory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
return new Bootstrap().group(KQUEUE_WORKER_GROUP).channel(KQueueSocketChannel.class);
}
},
new BootstrapFactory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
return new Bootstrap().group(nioWorkerGroup).channel(NioSocketChannel.class);
}
}
);
List<BootstrapFactory<Bootstrap>> toReturn = new ArrayList<BootstrapFactory<Bootstrap>>();

toReturn.add(new BootstrapFactory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
return new Bootstrap().group(KQUEUE_WORKER_GROUP).channel(KQueueSocketChannel.class);
}
});

toReturn.add(new BootstrapFactory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
return new Bootstrap().group(nioWorkerGroup).channel(NioSocketChannel.class);
}
});

return toReturn;
}

@Override
public List<BootstrapFactory<Bootstrap>> clientSocketWithFastOpen() {
List<BootstrapFactory<Bootstrap>> factories = clientSocket();

int insertIndex = factories.size() - 1; // Keep NIO fixture last.
factories.add(insertIndex, new BootstrapFactory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
return new Bootstrap().group(KQUEUE_WORKER_GROUP).channel(KQueueSocketChannel.class)
.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}
});

return factories;
}

@Override
public List<TestsuitePermutation.BootstrapComboFactory<Bootstrap, Bootstrap>> datagram(
final InternetProtocolFamily family) {
// Make the list of Bootstrap factories.
@SuppressWarnings("unchecked")
List<BootstrapFactory<Bootstrap>> bfs = Arrays.asList(
new BootstrapFactory<Bootstrap>() {
@Override
Expand Down Expand Up @@ -135,10 +148,7 @@ public Bootstrap newInstance() {
}

public List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> domainSocket() {

List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> list =
combo(serverDomainSocket(), clientDomainSocket());
return list;
return combo(serverDomainSocket(), clientDomainSocket());
}

public List<BootstrapFactory<ServerBootstrap>> serverDomainSocket() {
Expand Down
Expand Up @@ -25,6 +25,6 @@ public class KqueueWriteBeforeRegisteredTest extends WriteBeforeRegisteredTest {

@Override
protected List<TestsuitePermutation.BootstrapFactory<Bootstrap>> newFactories() {
return KQueueSocketTestPermutation.INSTANCE.clientSocket();
return KQueueSocketTestPermutation.INSTANCE.clientSocketWithFastOpen();
}
}
Expand Up @@ -52,7 +52,7 @@ public class Socket extends FileDescriptor {

public Socket(int fd) {
super(fd);
this.ipv6 = isIPv6(fd);
ipv6 = isIPv6(fd);
}

/**
Expand All @@ -72,7 +72,7 @@ public final void shutdown(boolean read, boolean write) throws IOException {
// shutdown anything. This is because if the underlying FD is reused and we still have an object which
// represents the previous incarnation of the FD we need to be sure we don't inadvertently shutdown the
// "new" FD without explicitly having a change.
final int oldState = this.state;
final int oldState = state;
if (isClosed(oldState)) {
throw new ClosedChannelException();
}
Expand Down

0 comments on commit 4577097

Please sign in to comment.