diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulator.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulator.java index 3b56753e7b8e..3e40d0b673af 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulator.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulator.java @@ -22,18 +22,48 @@ import java.util.ArrayList; import java.util.List; +import org.eclipse.jetty.io.ByteBufferPool; import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.websocket.api.MessageTooLargeException; public class ByteAccumulator { - private final List chunks = new ArrayList<>(); + private final List chunks = new ArrayList<>(); private final int maxSize; private int length = 0; + private final ByteBufferPool bufferPool; public ByteAccumulator(int maxOverallBufferSize) + { + this(maxOverallBufferSize, null); + } + + public ByteAccumulator(int maxOverallBufferSize, ByteBufferPool bufferPool) { this.maxSize = maxOverallBufferSize; + this.bufferPool = bufferPool; + } + + public void copyChunk(ByteBuffer buffer) + { + int length = buffer.remaining(); + if (this.length + length > maxSize) + { + String err = String.format("Resulting message size [%,d] is too large for configured max of [%,d]", this.length + length, maxSize); + throw new MessageTooLargeException(err); + } + + if (buffer.hasRemaining()) + { + chunks.add(buffer); + this.length += length; + } + else + { + // release 0 length buffer directly + if (bufferPool != null) + bufferPool.release((ByteBuffer)buffer.clear()); + } } public void copyChunk(byte[] buf, int offset, int length) @@ -43,11 +73,7 @@ public void copyChunk(byte[] buf, int offset, int length) String err = String.format("Resulting message size [%,d] is too large for configured max of [%,d]", this.length + length, maxSize); throw new MessageTooLargeException(err); } - - byte[] copy = new byte[length - offset]; - System.arraycopy(buf, offset, copy, 0, length); - - chunks.add(copy); + chunks.add(ByteBuffer.wrap(buf, offset, length)); this.length += length; } @@ -56,6 +82,20 @@ public int getLength() return length; } + int getMaxSize() + { + return maxSize; + } + + ByteBuffer newByteBuffer(int size) + { + if (bufferPool == null) + { + return ByteBuffer.allocate(size); + } + return (ByteBuffer)bufferPool.acquire(size, false).clear(); + } + public void transferTo(ByteBuffer buffer) { if (buffer.remaining() < length) @@ -65,10 +105,26 @@ public void transferTo(ByteBuffer buffer) } int position = buffer.position(); - for (byte[] chunk : chunks) + for (ByteBuffer chunk : chunks) { - buffer.put(chunk, 0, chunk.length); + buffer.put(chunk); } BufferUtil.flipToFlush(buffer, position); } + + void recycle() + { + length = 0; + + if (bufferPool == null) + { + return; + } + for (ByteBuffer chunk : chunks) + { + bufferPool.release((ByteBuffer)chunk.clear()); + } + + chunks.clear(); + } } diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java index 6952ccb67ea0..f7fbb6a1c1cc 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java @@ -46,7 +46,7 @@ public abstract class CompressExtension extends AbstractExtension protected static final byte[] TAIL_BYTES = new byte[]{0x00, 0x00, (byte)0xFF, (byte)0xFF}; protected static final ByteBuffer TAIL_BYTES_BUF = ByteBuffer.wrap(TAIL_BYTES); private static final Logger LOG = Log.getLogger(CompressExtension.class); - + /** * Never drop tail bytes 0000FFFF, from any frame type */ @@ -92,6 +92,7 @@ public abstract class CompressExtension extends AbstractExtension private InflaterPool inflaterPool; private Deflater deflaterImpl; private Inflater inflaterImpl; + protected ByteAccumulator accumulator; protected AtomicInteger decompressCount = new AtomicInteger(0); private int tailDrop = TAIL_DROP_NEVER; private int rsvUse = RSV_USE_ALWAYS; @@ -176,7 +177,29 @@ protected void forwardIncoming(Frame frame, ByteAccumulator accumulator) protected ByteAccumulator newByteAccumulator() { int maxSize = Math.max(getPolicy().getMaxTextMessageSize(), getPolicy().getMaxBinaryMessageSize()); - return new ByteAccumulator(maxSize); + if (accumulator == null || accumulator.getMaxSize() != maxSize) + { + accumulator = new ByteAccumulator(maxSize, getBufferPool()); + } + return accumulator; + } + + int copyChunk(Inflater inflater, ByteAccumulator accumulator, ByteBuffer buf) throws DataFormatException + { + int position = 0; + int capacity = buf.capacity(); + while (position < capacity) + { + int read = inflater.inflate(buf.array(), position, capacity - position); + if (read <= 0) + { + accumulator.copyChunk((ByteBuffer)buf.position(position).flip()); + return read; + } + position += read; + } + accumulator.copyChunk((ByteBuffer)buf.position(position).flip()); + return position; } protected void decompress(ByteAccumulator accumulator, ByteBuffer buf) throws DataFormatException @@ -185,10 +208,10 @@ protected void decompress(ByteAccumulator accumulator, ByteBuffer buf) throws Da { return; } - byte[] output = new byte[DECOMPRESS_BUF_SIZE]; + Inflater inflater = getInflater(); - + while (buf.hasRemaining() && inflater.needsInput()) { if (!supplyInput(inflater, buf)) @@ -198,22 +221,14 @@ protected void decompress(ByteAccumulator accumulator, ByteBuffer buf) throws Da return; } - int read; - while ((read = inflater.inflate(output)) >= 0) + while (true) { - if (read == 0) + ByteBuffer output = accumulator.newByteBuffer(DECOMPRESS_BUF_SIZE); + int read = copyChunk(inflater, accumulator, output); + if (read <= 0) { - if (LOG.isDebugEnabled()) - LOG.debug("Decompress: read 0 {}", toDetail(inflater)); break; } - else - { - // do something with output - if (LOG.isDebugEnabled()) - LOG.debug("Decompressed {} bytes: {}", read, toDetail(inflater)); - accumulator.copyChunk(output, 0, read); - } } } diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/DeflateFrameExtension.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/DeflateFrameExtension.java index 0476c0fcc441..465204cee297 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/DeflateFrameExtension.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/DeflateFrameExtension.java @@ -47,7 +47,7 @@ int getTailDropMode() { return TAIL_DROP_ALWAYS; } - + @Override public void incomingFrame(Frame frame) { @@ -63,10 +63,11 @@ public void incomingFrame(Frame frame) try { - ByteAccumulator accumulator = newByteAccumulator(); + accumulator = newByteAccumulator(); decompress(accumulator, frame.getPayload()); decompress(accumulator, TAIL_BYTES_BUF.slice()); forwardIncoming(frame, accumulator); + accumulator.recycle(); } catch (DataFormatException e) { diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/PerMessageDeflateExtension.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/PerMessageDeflateExtension.java index 37482f8bd678..4a34d8783760 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/PerMessageDeflateExtension.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/PerMessageDeflateExtension.java @@ -78,10 +78,9 @@ public void incomingFrame(Frame frame) throw new ProtocolException("Invalid RSV1 set on permessage-deflate CONTINUATION frame"); } - ByteAccumulator accumulator = newByteAccumulator(); - try { + accumulator = newByteAccumulator(); ByteBuffer payload = frame.getPayload(); decompress(accumulator, payload); if (frame.isFin()) @@ -90,6 +89,7 @@ public void incomingFrame(Frame frame) } forwardIncoming(frame, accumulator); + accumulator.recycle(); } catch (DataFormatException e) { diff --git a/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulatorTest.java b/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulatorTest.java index 5aa41765e070..1c3dd1ffaead 100644 --- a/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulatorTest.java +++ b/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulatorTest.java @@ -20,6 +20,7 @@ import java.nio.ByteBuffer; +import org.eclipse.jetty.io.ArrayByteBufferPool; import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.websocket.api.MessageTooLargeException; import org.junit.jupiter.api.Test; @@ -91,4 +92,77 @@ public void testCopyChunkNotEnoughSpace() MessageTooLargeException e = assertThrows(MessageTooLargeException.class, () -> accumulator.copyChunk(world, 0, world.length)); assertThat(e.getMessage(), containsString("too large for configured max")); } + + @Test + public void testRecycle() + { + ByteAccumulator accumulator = new ByteAccumulator(10_000, new ArrayByteBufferPool()); + ByteBuffer out0 = ByteBuffer.allocate(200); + ByteBuffer out1 = ByteBuffer.allocate(200); + { + // 1 + ByteBuffer buf = accumulator.newByteBuffer(10); + byte[] hello = "Hello".getBytes(UTF_8); + buf.put(hello).flip(); + accumulator.copyChunk(buf); + + // 2 + buf = accumulator.newByteBuffer(10); + byte[] space = " ".getBytes(UTF_8); + buf.put(space).flip(); + accumulator.copyChunk(buf); + + // 3 + buf = accumulator.newByteBuffer(10); + byte[] world = "World".getBytes(UTF_8); + buf.put(world).flip(); + accumulator.copyChunk(buf); + + assertThat("Length", accumulator.getLength(), is(hello.length + space.length + world.length)); + + accumulator.transferTo(out0); + + // reuse that byte[] + accumulator.recycle(); + } + + { + // 1 + ByteBuffer buf = accumulator.newByteBuffer(10); + byte[] olleh = "olleH".getBytes(UTF_8); + buf.put(olleh).flip(); + accumulator.copyChunk(buf); + + // 2 + buf = accumulator.newByteBuffer(10); + byte[] space = " ".getBytes(UTF_8); + buf.put(space).flip(); + accumulator.copyChunk(buf); + + // 3 + buf = accumulator.newByteBuffer(10); + byte[] dlrow = "dlroW".getBytes(UTF_8); + buf.put(dlrow).flip(); + accumulator.copyChunk(buf); + + // 4 + buf = accumulator.newByteBuffer(10); + byte[] done = " enoD".getBytes(UTF_8); + buf.put(done).flip(); + accumulator.copyChunk(buf); + + assertThat("Length", accumulator.getLength(), is(olleh.length + space.length + dlrow.length + done.length)); + + accumulator.transferTo(out1); + + // reuse that byte[] + accumulator.recycle(); + } + + String result0 = BufferUtil.toUTF8String(out0); + assertThat("result", result0, is("Hello World")); + + String result1 = BufferUtil.toUTF8String(out1); + assertThat("result", result1, is("olleH dlroW enoD")); + } }