Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ssl buffers handling #8165

Merged
merged 6 commits into from Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -22,12 +22,14 @@
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine;
Expand All @@ -36,31 +38,39 @@
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSocket;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.eclipse.jetty.client.api.ContentResponse;
import org.eclipse.jetty.client.http.HttpClientTransportOverHTTP;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpHeaderValue;
import org.eclipse.jetty.http.HttpScheme;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.io.ArrayByteBufferPool;
import org.eclipse.jetty.io.ArrayRetainableByteBufferPool;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.ClientConnectionFactory;
import org.eclipse.jetty.io.ClientConnector;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.ConnectionStatistics;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.io.RetainableByteBuffer;
import org.eclipse.jetty.io.RetainableByteBufferPool;
import org.eclipse.jetty.io.ssl.SslClientConnectionFactory;
import org.eclipse.jetty.io.ssl.SslConnection;
import org.eclipse.jetty.io.ssl.SslHandshakeListener;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.HttpConfiguration;
import org.eclipse.jetty.server.HttpConnectionFactory;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.SecureRequestCustomizer;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.server.SslConnectionFactory;
import org.eclipse.jetty.toolchain.test.Net;
import org.eclipse.jetty.util.Pool;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.thread.ExecutorThreadPool;
Expand All @@ -71,9 +81,14 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledForJreRange;
import org.junit.jupiter.api.condition.JRE;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import static org.awaitility.Awaitility.await;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
Expand Down Expand Up @@ -682,12 +697,7 @@ protected int networkFill(ByteBuffer input) throws IOException
// Trigger the creation of a new connection, but don't use it.
ConnectionPoolHelper.tryCreate(connectionPool);
// Verify that the connection has been created.
while (true)
{
Thread.sleep(50);
if (connectionPool.getConnectionCount() == 1)
break;
}
await().atMost(5, TimeUnit.SECONDS).until(connectionPool::getConnectionCount, is(1));

// Wait for the server to idle timeout the connection.
Thread.sleep(idleTimeout + idleTimeout / 2);
Expand All @@ -698,6 +708,299 @@ protected int networkFill(ByteBuffer input) throws IOException
assertEquals(0, clientBytes.get());
}

@Test
public void testEncryptedInputBufferRepooling() throws Exception
{
SslContextFactory.Server serverTLSFactory = createServerSslContextFactory();
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
var retainableByteBufferPool = new ArrayRetainableByteBufferPool()
{
@Override
public Pool<RetainableByteBuffer> poolFor(int capacity, boolean direct)
{
return super.poolFor(capacity, direct);
}
};
server.addBean(retainableByteBufferPool);
HttpConfiguration httpConfig = new HttpConfiguration();
httpConfig.addCustomizer(new SecureRequestCustomizer());
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol())
{
@Override
protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine)
{
ByteBufferPool byteBufferPool = connector.getByteBufferPool();
RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class);
return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption())
{
@Override
protected int networkFill(ByteBuffer input) throws IOException
{
int n = super.networkFill(input);
if (n > 0)
throw new IOException("boom");
return n;
}
};
}
};
connector = new ServerConnector(server, 1, 1, ssl, http);
server.addConnector(connector);
server.setHandler(new EmptyServerHandler());
server.start();

SslContextFactory.Client clientTLSFactory = createClientSslContextFactory();
ClientConnector clientConnector = new ClientConnector();
clientConnector.setSelectors(1);
clientConnector.setSslContextFactory(clientTLSFactory);
QueuedThreadPool clientThreads = new QueuedThreadPool();
clientThreads.setName("client");
clientConnector.setExecutor(clientThreads);
client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector));
client.setExecutor(clientThreads);
client.start();

assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send());

Pool<RetainableByteBuffer> bucket = retainableByteBufferPool.poolFor(16 * 1024 + 1, ssl.isDirectBuffersForEncryption());
assertEquals(1, bucket.size());
assertEquals(1, bucket.getIdleCount());
}

@Test
public void testEncryptedOutputBufferRepooling() throws Exception
{
SslContextFactory.Server serverTLSFactory = createServerSslContextFactory();
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
List<ByteBuffer> leakedBuffers = new ArrayList<>();
ArrayByteBufferPool byteBufferPool = new ArrayByteBufferPool()
{
@Override
public ByteBuffer acquire(int size, boolean direct)
{
ByteBuffer acquired = super.acquire(size, direct);
leakedBuffers.add(acquired);
return acquired;
}

@Override
public void release(ByteBuffer buffer)
{
leakedBuffers.remove(buffer);
super.release(buffer);
}
};
server.addBean(byteBufferPool);
HttpConfiguration httpConfig = new HttpConfiguration();
httpConfig.addCustomizer(new SecureRequestCustomizer());
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol())
{
@Override
protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine)
{
ByteBufferPool byteBufferPool = connector.getByteBufferPool();
RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class);
return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption())
{
@Override
protected boolean networkFlush(ByteBuffer output) throws IOException
{
throw new IOException("bang");
}
};
}
};
connector = new ServerConnector(server, 1, 1, ssl, http);
server.addConnector(connector);
server.setHandler(new EmptyServerHandler());
server.start();

SslContextFactory.Client clientTLSFactory = createClientSslContextFactory();
ClientConnector clientConnector = new ClientConnector();
clientConnector.setSelectors(1);
clientConnector.setSslContextFactory(clientTLSFactory);
QueuedThreadPool clientThreads = new QueuedThreadPool();
clientThreads.setName("client");
clientConnector.setExecutor(clientThreads);
client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector));
client.setExecutor(clientThreads);
client.start();

assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send());

await().atMost(5, TimeUnit.SECONDS).until(() -> leakedBuffers, is(empty()));
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
public void testEncryptedOutputBufferRepoolingAfterNetworkFlushReturnsFalse(boolean close) throws Exception
{
SslContextFactory.Server serverTLSFactory = createServerSslContextFactory();
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
List<ByteBuffer> leakedBuffers = new ArrayList<>();
ArrayByteBufferPool byteBufferPool = new ArrayByteBufferPool()
{
@Override
public ByteBuffer acquire(int size, boolean direct)
{
ByteBuffer acquired = super.acquire(size, direct);
leakedBuffers.add(acquired);
return acquired;
}

@Override
public void release(ByteBuffer buffer)
{
leakedBuffers.remove(buffer);
super.release(buffer);
}
};
server.addBean(byteBufferPool);
HttpConfiguration httpConfig = new HttpConfiguration();
httpConfig.addCustomizer(new SecureRequestCustomizer());
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
AtomicBoolean failFlush = new AtomicBoolean(false);
SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol())
{
@Override
protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine)
{
ByteBufferPool byteBufferPool = connector.getByteBufferPool();
RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class);
return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption())
{
@Override
protected boolean networkFlush(ByteBuffer output) throws IOException
{
if (failFlush.get())
return false;
return super.networkFlush(output);
}
};
}
};
connector = new ServerConnector(server, 1, 1, ssl, http);
server.addConnector(connector);
server.setHandler(new EmptyServerHandler()
{
@Override
protected void service(String target, Request jettyRequest, HttpServletRequest request, HttpServletResponse response)
{
failFlush.set(true);
if (close)
jettyRequest.getHttpChannel().getEndPoint().close();
else
jettyRequest.getHttpChannel().getEndPoint().shutdownOutput();
}
});
server.start();

SslContextFactory.Client clientTLSFactory = createClientSslContextFactory();
ClientConnector clientConnector = new ClientConnector();
clientConnector.setSelectors(1);
clientConnector.setSslContextFactory(clientTLSFactory);
QueuedThreadPool clientThreads = new QueuedThreadPool();
clientThreads.setName("client");
clientConnector.setExecutor(clientThreads);
client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector));
client.setExecutor(clientThreads);
client.start();

assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send());

await().atMost(5, TimeUnit.SECONDS).until(() -> leakedBuffers, is(empty()));
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
public void testEncryptedOutputBufferRepoolingAfterNetworkFlushThrows(boolean close) throws Exception
{
SslContextFactory.Server serverTLSFactory = createServerSslContextFactory();
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
List<ByteBuffer> leakedBuffers = new ArrayList<>();
ArrayByteBufferPool byteBufferPool = new ArrayByteBufferPool()
{
@Override
public ByteBuffer acquire(int size, boolean direct)
{
ByteBuffer acquired = super.acquire(size, direct);
leakedBuffers.add(acquired);
return acquired;
}

@Override
public void release(ByteBuffer buffer)
{
leakedBuffers.remove(buffer);
super.release(buffer);
}
};
server.addBean(byteBufferPool);
HttpConfiguration httpConfig = new HttpConfiguration();
httpConfig.addCustomizer(new SecureRequestCustomizer());
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
AtomicBoolean failFlush = new AtomicBoolean(false);
SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol())
{
@Override
protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine)
{
ByteBufferPool byteBufferPool = connector.getByteBufferPool();
RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class);
return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption())
{
@Override
protected boolean networkFlush(ByteBuffer output) throws IOException
{
if (failFlush.get())
throw new IOException();
return super.networkFlush(output);
}
};
}
};
connector = new ServerConnector(server, 1, 1, ssl, http);
server.addConnector(connector);
server.setHandler(new EmptyServerHandler()
{
@Override
protected void service(String target, Request jettyRequest, HttpServletRequest request, HttpServletResponse response) throws IOException
{
failFlush.set(true);
if (close)
jettyRequest.getHttpChannel().getEndPoint().close();
else
jettyRequest.getHttpChannel().getEndPoint().shutdownOutput();
}
});
server.start();

SslContextFactory.Client clientTLSFactory = createClientSslContextFactory();
ClientConnector clientConnector = new ClientConnector();
clientConnector.setSelectors(1);
clientConnector.setSslContextFactory(clientTLSFactory);
QueuedThreadPool clientThreads = new QueuedThreadPool();
clientThreads.setName("client");
clientConnector.setExecutor(clientThreads);
client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector));
client.setExecutor(clientThreads);
client.start();

assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send());

await().atMost(5, TimeUnit.SECONDS).until(() -> leakedBuffers, is(empty()));
}

@Test
public void testNeverUsedConnectionThenClientIdleTimeout() throws Exception
{
Expand Down Expand Up @@ -780,12 +1083,7 @@ protected int networkFill(ByteBuffer input) throws IOException
// Trigger the creation of a new connection, but don't use it.
ConnectionPoolHelper.tryCreate(connectionPool);
// Verify that the connection has been created.
while (true)
{
Thread.sleep(50);
if (connectionPool.getConnectionCount() == 1)
break;
}
await().atMost(5, TimeUnit.SECONDS).until(connectionPool::getConnectionCount, is(1));

// Wait for the client to idle timeout the connection.
Thread.sleep(idleTimeout + idleTimeout / 2);
Expand Down