From b383f53c1a56af195c5cf82565bbcea84733fd9e Mon Sep 17 00:00:00 2001 From: Ludovic Orban Date: Tue, 14 Jun 2022 17:03:01 +0200 Subject: [PATCH] #8161 improve SSLConnection buffers handling Signed-off-by: Ludovic Orban --- .../jetty/client/HttpClientTLSTest.java | 322 +++++++++++++++++- .../io/ArrayRetainableByteBufferPool.java | 5 + .../eclipse/jetty/io/ssl/SslConnection.java | 97 ++++-- 3 files changed, 386 insertions(+), 38 deletions(-) diff --git a/jetty-client/src/test/java/org/eclipse/jetty/client/HttpClientTLSTest.java b/jetty-client/src/test/java/org/eclipse/jetty/client/HttpClientTLSTest.java index d8ad678c430c..e4d5358cf78d 100644 --- a/jetty-client/src/test/java/org/eclipse/jetty/client/HttpClientTLSTest.java +++ b/jetty-client/src/test/java/org/eclipse/jetty/client/HttpClientTLSTest.java @@ -22,12 +22,14 @@ import java.net.SocketTimeoutException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLEngine; @@ -36,6 +38,8 @@ import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSocket; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import org.eclipse.jetty.client.api.ContentResponse; import org.eclipse.jetty.client.http.HttpClientTransportOverHTTP; @@ -43,12 +47,16 @@ import org.eclipse.jetty.http.HttpHeaderValue; import org.eclipse.jetty.http.HttpScheme; import org.eclipse.jetty.http.HttpStatus; +import org.eclipse.jetty.io.ArrayByteBufferPool; +import org.eclipse.jetty.io.ArrayRetainableByteBufferPool; import org.eclipse.jetty.io.ByteBufferPool; import org.eclipse.jetty.io.ClientConnectionFactory; import org.eclipse.jetty.io.ClientConnector; import org.eclipse.jetty.io.Connection; import org.eclipse.jetty.io.ConnectionStatistics; import org.eclipse.jetty.io.EndPoint; +import org.eclipse.jetty.io.RetainableByteBuffer; +import org.eclipse.jetty.io.RetainableByteBufferPool; import org.eclipse.jetty.io.ssl.SslClientConnectionFactory; import org.eclipse.jetty.io.ssl.SslConnection; import org.eclipse.jetty.io.ssl.SslHandshakeListener; @@ -56,11 +64,13 @@ import org.eclipse.jetty.server.Handler; import org.eclipse.jetty.server.HttpConfiguration; import org.eclipse.jetty.server.HttpConnectionFactory; +import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.SecureRequestCustomizer; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.server.SslConnectionFactory; import org.eclipse.jetty.toolchain.test.Net; +import org.eclipse.jetty.util.Pool; import org.eclipse.jetty.util.StringUtil; import org.eclipse.jetty.util.ssl.SslContextFactory; import org.eclipse.jetty.util.thread.ExecutorThreadPool; @@ -71,9 +81,14 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledForJreRange; import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import static org.awaitility.Awaitility.await; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -682,12 +697,7 @@ protected int networkFill(ByteBuffer input) throws IOException // Trigger the creation of a new connection, but don't use it. ConnectionPoolHelper.tryCreate(connectionPool); // Verify that the connection has been created. - while (true) - { - Thread.sleep(50); - if (connectionPool.getConnectionCount() == 1) - break; - } + await().atMost(5, TimeUnit.SECONDS).until(connectionPool::getConnectionCount, is(1)); // Wait for the server to idle timeout the connection. Thread.sleep(idleTimeout + idleTimeout / 2); @@ -698,6 +708,299 @@ protected int networkFill(ByteBuffer input) throws IOException assertEquals(0, clientBytes.get()); } + @Test + public void testEncryptedInputBufferRepooling() throws Exception + { + SslContextFactory.Server serverTLSFactory = createServerSslContextFactory(); + QueuedThreadPool serverThreads = new QueuedThreadPool(); + serverThreads.setName("server"); + server = new Server(serverThreads); + var retainableByteBufferPool = new ArrayRetainableByteBufferPool() + { + @Override + public Pool poolFor(int capacity, boolean direct) + { + return super.poolFor(capacity, direct); + } + }; + server.addBean(retainableByteBufferPool); + HttpConfiguration httpConfig = new HttpConfiguration(); + httpConfig.addCustomizer(new SecureRequestCustomizer()); + HttpConnectionFactory http = new HttpConnectionFactory(httpConfig); + SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol()) + { + @Override + protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine) + { + ByteBufferPool byteBufferPool = connector.getByteBufferPool(); + RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class); + return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption()) + { + @Override + protected int networkFill(ByteBuffer input) throws IOException + { + int n = super.networkFill(input); + if (n > 0) + throw new IOException("boom"); + return n; + } + }; + } + }; + connector = new ServerConnector(server, 1, 1, ssl, http); + server.addConnector(connector); + server.setHandler(new EmptyServerHandler()); + server.start(); + + SslContextFactory.Client clientTLSFactory = createClientSslContextFactory(); + ClientConnector clientConnector = new ClientConnector(); + clientConnector.setSelectors(1); + clientConnector.setSslContextFactory(clientTLSFactory); + QueuedThreadPool clientThreads = new QueuedThreadPool(); + clientThreads.setName("client"); + clientConnector.setExecutor(clientThreads); + client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector)); + client.setExecutor(clientThreads); + client.start(); + + assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send()); + + Pool bucket = retainableByteBufferPool.poolFor(16 * 1024 + 1, ssl.isDirectBuffersForEncryption()); + assertEquals(1, bucket.size()); + assertEquals(1, bucket.getIdleCount()); + } + + @Test + public void testEncryptedOutputBufferRepooling() throws Exception + { + SslContextFactory.Server serverTLSFactory = createServerSslContextFactory(); + QueuedThreadPool serverThreads = new QueuedThreadPool(); + serverThreads.setName("server"); + server = new Server(serverThreads); + List leakedBuffers = new ArrayList<>(); + ArrayByteBufferPool byteBufferPool = new ArrayByteBufferPool() + { + @Override + public ByteBuffer acquire(int size, boolean direct) + { + ByteBuffer acquired = super.acquire(size, direct); + leakedBuffers.add(acquired); + return acquired; + } + + @Override + public void release(ByteBuffer buffer) + { + leakedBuffers.remove(buffer); + super.release(buffer); + } + }; + server.addBean(byteBufferPool); + HttpConfiguration httpConfig = new HttpConfiguration(); + httpConfig.addCustomizer(new SecureRequestCustomizer()); + HttpConnectionFactory http = new HttpConnectionFactory(httpConfig); + SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol()) + { + @Override + protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine) + { + ByteBufferPool byteBufferPool = connector.getByteBufferPool(); + RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class); + return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption()) + { + @Override + protected boolean networkFlush(ByteBuffer output) throws IOException + { + throw new IOException("bang"); + } + }; + } + }; + connector = new ServerConnector(server, 1, 1, ssl, http); + server.addConnector(connector); + server.setHandler(new EmptyServerHandler()); + server.start(); + + SslContextFactory.Client clientTLSFactory = createClientSslContextFactory(); + ClientConnector clientConnector = new ClientConnector(); + clientConnector.setSelectors(1); + clientConnector.setSslContextFactory(clientTLSFactory); + QueuedThreadPool clientThreads = new QueuedThreadPool(); + clientThreads.setName("client"); + clientConnector.setExecutor(clientThreads); + client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector)); + client.setExecutor(clientThreads); + client.start(); + + assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send()); + + await().atMost(5, TimeUnit.SECONDS).until(() -> leakedBuffers, is(empty())); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testEncryptedOutputBufferRepoolingAfterNetworkFlushReturnsFalse(boolean close) throws Exception + { + SslContextFactory.Server serverTLSFactory = createServerSslContextFactory(); + QueuedThreadPool serverThreads = new QueuedThreadPool(); + serverThreads.setName("server"); + server = new Server(serverThreads); + List leakedBuffers = new ArrayList<>(); + ArrayByteBufferPool byteBufferPool = new ArrayByteBufferPool() + { + @Override + public ByteBuffer acquire(int size, boolean direct) + { + ByteBuffer acquired = super.acquire(size, direct); + leakedBuffers.add(acquired); + return acquired; + } + + @Override + public void release(ByteBuffer buffer) + { + leakedBuffers.remove(buffer); + super.release(buffer); + } + }; + server.addBean(byteBufferPool); + HttpConfiguration httpConfig = new HttpConfiguration(); + httpConfig.addCustomizer(new SecureRequestCustomizer()); + HttpConnectionFactory http = new HttpConnectionFactory(httpConfig); + AtomicBoolean failFlush = new AtomicBoolean(false); + SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol()) + { + @Override + protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine) + { + ByteBufferPool byteBufferPool = connector.getByteBufferPool(); + RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class); + return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption()) + { + @Override + protected boolean networkFlush(ByteBuffer output) throws IOException + { + if (failFlush.get()) + return false; + return super.networkFlush(output); + } + }; + } + }; + connector = new ServerConnector(server, 1, 1, ssl, http); + server.addConnector(connector); + server.setHandler(new EmptyServerHandler() + { + @Override + protected void service(String target, Request jettyRequest, HttpServletRequest request, HttpServletResponse response) + { + failFlush.set(true); + if (close) + jettyRequest.getHttpChannel().getEndPoint().close(); + else + jettyRequest.getHttpChannel().getEndPoint().shutdownOutput(); + } + }); + server.start(); + + SslContextFactory.Client clientTLSFactory = createClientSslContextFactory(); + ClientConnector clientConnector = new ClientConnector(); + clientConnector.setSelectors(1); + clientConnector.setSslContextFactory(clientTLSFactory); + QueuedThreadPool clientThreads = new QueuedThreadPool(); + clientThreads.setName("client"); + clientConnector.setExecutor(clientThreads); + client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector)); + client.setExecutor(clientThreads); + client.start(); + + assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send()); + + await().atMost(5, TimeUnit.SECONDS).until(() -> leakedBuffers, is(empty())); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testEncryptedOutputBufferRepoolingAfterNetworkFlushThrows(boolean close) throws Exception + { + SslContextFactory.Server serverTLSFactory = createServerSslContextFactory(); + QueuedThreadPool serverThreads = new QueuedThreadPool(); + serverThreads.setName("server"); + server = new Server(serverThreads); + List leakedBuffers = new ArrayList<>(); + ArrayByteBufferPool byteBufferPool = new ArrayByteBufferPool() + { + @Override + public ByteBuffer acquire(int size, boolean direct) + { + ByteBuffer acquired = super.acquire(size, direct); + leakedBuffers.add(acquired); + return acquired; + } + + @Override + public void release(ByteBuffer buffer) + { + leakedBuffers.remove(buffer); + super.release(buffer); + } + }; + server.addBean(byteBufferPool); + HttpConfiguration httpConfig = new HttpConfiguration(); + httpConfig.addCustomizer(new SecureRequestCustomizer()); + HttpConnectionFactory http = new HttpConnectionFactory(httpConfig); + AtomicBoolean failFlush = new AtomicBoolean(false); + SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol()) + { + @Override + protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine) + { + ByteBufferPool byteBufferPool = connector.getByteBufferPool(); + RetainableByteBufferPool retainableByteBufferPool = connector.getBean(RetainableByteBufferPool.class); + return new SslConnection(retainableByteBufferPool, byteBufferPool, connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption()) + { + @Override + protected boolean networkFlush(ByteBuffer output) throws IOException + { + if (failFlush.get()) + throw new IOException(); + return super.networkFlush(output); + } + }; + } + }; + connector = new ServerConnector(server, 1, 1, ssl, http); + server.addConnector(connector); + server.setHandler(new EmptyServerHandler() + { + @Override + protected void service(String target, Request jettyRequest, HttpServletRequest request, HttpServletResponse response) throws IOException + { + failFlush.set(true); + if (close) + jettyRequest.getHttpChannel().getEndPoint().close(); + else + jettyRequest.getHttpChannel().getEndPoint().shutdownOutput(); + } + }); + server.start(); + + SslContextFactory.Client clientTLSFactory = createClientSslContextFactory(); + ClientConnector clientConnector = new ClientConnector(); + clientConnector.setSelectors(1); + clientConnector.setSslContextFactory(clientTLSFactory); + QueuedThreadPool clientThreads = new QueuedThreadPool(); + clientThreads.setName("client"); + clientConnector.setExecutor(clientThreads); + client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector)); + client.setExecutor(clientThreads); + client.start(); + + assertThrows(Exception.class, () -> client.newRequest("localhost", connector.getLocalPort()).scheme(HttpScheme.HTTPS.asString()).send()); + + await().atMost(5, TimeUnit.SECONDS).until(() -> leakedBuffers, is(empty())); + } + @Test public void testNeverUsedConnectionThenClientIdleTimeout() throws Exception { @@ -780,12 +1083,7 @@ protected int networkFill(ByteBuffer input) throws IOException // Trigger the creation of a new connection, but don't use it. ConnectionPoolHelper.tryCreate(connectionPool); // Verify that the connection has been created. - while (true) - { - Thread.sleep(50); - if (connectionPool.getConnectionCount() == 1) - break; - } + await().atMost(5, TimeUnit.SECONDS).until(connectionPool::getConnectionCount, is(1)); // Wait for the client to idle timeout the connection. Thread.sleep(idleTimeout + idleTimeout / 2); diff --git a/jetty-io/src/main/java/org/eclipse/jetty/io/ArrayRetainableByteBufferPool.java b/jetty-io/src/main/java/org/eclipse/jetty/io/ArrayRetainableByteBufferPool.java index 3fe42c147ee1..cc2807918fb8 100644 --- a/jetty-io/src/main/java/org/eclipse/jetty/io/ArrayRetainableByteBufferPool.java +++ b/jetty-io/src/main/java/org/eclipse/jetty/io/ArrayRetainableByteBufferPool.java @@ -156,6 +156,11 @@ private RetainableByteBuffer newRetainableByteBuffer(int capacity, boolean direc return retainableByteBuffer; } + protected Pool poolFor(int capacity, boolean direct) + { + return bucketFor(capacity, direct); + } + private Bucket bucketFor(int capacity, boolean direct) { if (capacity < _minCapacity) diff --git a/jetty-io/src/main/java/org/eclipse/jetty/io/ssl/SslConnection.java b/jetty-io/src/main/java/org/eclipse/jetty/io/ssl/SslConnection.java index aeb34eab9c25..0aeb72f7e320 100644 --- a/jetty-io/src/main/java/org/eclipse/jetty/io/ssl/SslConnection.java +++ b/jetty-io/src/main/java/org/eclipse/jetty/io/ssl/SslConnection.java @@ -421,6 +421,8 @@ public String toConnectionString() private void releaseEncryptedInputBuffer() { + if (!_lock.isHeldByCurrentThread()) + throw new IllegalStateException(); if (_encryptedInput != null && !_encryptedInput.hasRemaining()) { _encryptedInput.release(); @@ -428,8 +430,10 @@ private void releaseEncryptedInputBuffer() } } - protected void releaseDecryptedInputBuffer() + private void releaseDecryptedInputBuffer() { + if (!_lock.isHeldByCurrentThread()) + throw new IllegalStateException(); if (_decryptedInput != null && !_decryptedInput.hasRemaining()) { _bufferPool.release(_decryptedInput); @@ -437,6 +441,21 @@ protected void releaseDecryptedInputBuffer() } } + private void releaseInputBuffers() + { + releaseEncryptedInputBuffer(); + releaseDecryptedInputBuffer(); + } + + private void clearInputBuffers() + { + if (!_lock.isHeldByCurrentThread()) + throw new IllegalStateException(); + if (_encryptedInput != null) + _encryptedInput.clear(); + BufferUtil.clear(_decryptedInput); + } + private void releaseEncryptedOutputBuffer() { if (!_lock.isHeldByCurrentThread()) @@ -790,6 +809,9 @@ public int fill(ByteBuffer buffer) throws IOException } catch (Throwable x) { + // Clear the input buffers so that releaseInputBuffers() + // in the finally block will not leak them. + clearInputBuffers(); Throwable f = handleException(x, "fill"); Throwable failure = handshakeFailed(f); if (_flushState == FlushState.WAIT_FOR_FILL) @@ -801,8 +823,7 @@ public int fill(ByteBuffer buffer) throws IOException } finally { - releaseEncryptedInputBuffer(); - releaseDecryptedInputBuffer(); + releaseInputBuffers(); if (_flushState == FlushState.WAIT_FOR_FILL) { @@ -988,26 +1009,26 @@ public boolean flush(ByteBuffer... appOuts) throws IOException } } - // finish of any previous flushes - if (_encryptedOutput != null) + Boolean result = null; + try { - int remaining = _encryptedOutput.remaining(); - if (remaining > 0) + // finish of any previous flushes + if (_encryptedOutput != null) { - boolean flushed = networkFlush(_encryptedOutput); - int written = remaining - _encryptedOutput.remaining(); - if (written > 0) - _bytesOut.addAndGet(written); - if (!flushed) - return false; + int remaining = _encryptedOutput.remaining(); + if (remaining > 0) + { + boolean flushed = networkFlush(_encryptedOutput); + int written = remaining - _encryptedOutput.remaining(); + if (written > 0) + _bytesOut.addAndGet(written); + if (!flushed) + return false; + } } - } - boolean isEmpty = BufferUtil.isEmpty(appOuts); + boolean isEmpty = BufferUtil.isEmpty(appOuts); - Boolean result = null; - try - { if (_flushState != FlushState.IDLE) return result = false; @@ -1159,6 +1180,10 @@ public boolean flush(ByteBuffer... appOuts) throws IOException } catch (Throwable x) { + // Clear the encrypted output buffer so that + // releaseEncryptedOutputBuffer() in the finally block + // will not leak it. + BufferUtil.clear(_encryptedOutput); Throwable failure = handleException(x, "flush"); throw handshakeFailed(failure); } @@ -1274,11 +1299,15 @@ else if (fillInterest) @Override public void doShutdownOutput() + { + doShutdownOutput(false); + } + + private void doShutdownOutput(boolean close) { EndPoint endPoint = getEndPoint(); try { - boolean close; boolean flush = false; try (AutoLock l = _lock.lock()) { @@ -1296,7 +1325,8 @@ public void doShutdownOutput() flush = !oshut; } - close = ishut; + if (!close) + close = ishut; } if (flush) @@ -1323,21 +1353,32 @@ public void doShutdownOutput() _flushState = FlushState.IDLE; releaseEncryptedOutputBuffer(); } - }, t -> endPoint.close()), write); + }, t -> disconnect()), write); } } } if (close) - endPoint.close(); + disconnect(); else ensureFillInterested(); } catch (Throwable x) { - LOG.trace("IGNORED", x); - endPoint.close(); + if (LOG.isTraceEnabled()) + LOG.trace("IGNORED", x); + disconnect(); + } + } + + private void disconnect() + { + try (AutoLock l = _lock.lock()) + { + BufferUtil.clear(_encryptedOutput); + releaseEncryptedOutputBuffer(); } + getEndPoint().close(); } private void closeOutbound() @@ -1382,9 +1423,13 @@ private boolean isOutboundDone() @Override public void doClose() { + try (AutoLock l = _lock.lock()) + { + clearInputBuffers(); + releaseInputBuffers(); + } // First send the TLS Close Alert, then the FIN. - doShutdownOutput(); - getEndPoint().close(); + doShutdownOutput(true); super.doClose(); }