diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulator.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulator.java index 3b56753e7b8e..6e4ce2f7ef5f 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulator.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulator.java @@ -27,9 +27,11 @@ public class ByteAccumulator { - private final List chunks = new ArrayList<>(); + private List prevChunks = null; + private List nextChunks = new ArrayList<>(); private final int maxSize; private int length = 0; + private int index; public ByteAccumulator(int maxOverallBufferSize) { @@ -43,11 +45,7 @@ public void copyChunk(byte[] buf, int offset, int length) String err = String.format("Resulting message size [%,d] is too large for configured max of [%,d]", this.length + length, maxSize); throw new MessageTooLargeException(err); } - - byte[] copy = new byte[length - offset]; - System.arraycopy(buf, offset, copy, 0, length); - - chunks.add(copy); + nextChunks.add(ByteBuffer.wrap(buf, offset, length)); this.length += length; } @@ -56,6 +54,26 @@ public int getLength() return length; } + int getMaxSize() + { + return maxSize; + } + + ByteBuffer newByteBuffer(int size) + { + ByteBuffer buf; + if (prevChunks != null && prevChunks.size() > index) + { + buf = prevChunks.get(index); + } + else + { + buf = ByteBuffer.allocate(size); + } + index++; + return buf; + } + public void transferTo(ByteBuffer buffer) { if (buffer.remaining() < length) @@ -65,10 +83,29 @@ public void transferTo(ByteBuffer buffer) } int position = buffer.position(); - for (byte[] chunk : chunks) + for (ByteBuffer chunk : nextChunks) { - buffer.put(chunk, 0, chunk.length); + buffer.put(chunk); } BufferUtil.flipToFlush(buffer, position); } + + void recycle() + { + index = 0; + length = 0; + + int resize = 16; + if (prevChunks == null || nextChunks.size() > prevChunks.size()) + { + prevChunks = nextChunks; + } + + // keep prevChunks retain max resize elements + if (prevChunks.size() > resize) + { + prevChunks.subList(resize, prevChunks.size()).clear(); + } + nextChunks = new ArrayList<>(); + } } diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java index 6952ccb67ea0..5e4d9bf56015 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/CompressExtension.java @@ -46,6 +46,11 @@ public abstract class CompressExtension extends AbstractExtension protected static final byte[] TAIL_BYTES = new byte[]{0x00, 0x00, (byte)0xFF, (byte)0xFF}; protected static final ByteBuffer TAIL_BYTES_BUF = ByteBuffer.wrap(TAIL_BYTES); private static final Logger LOG = Log.getLogger(CompressExtension.class); + + /** + * Accumulator + */ + protected ByteAccumulator accumulator; /** * Never drop tail bytes 0000FFFF, from any frame type @@ -176,7 +181,11 @@ protected void forwardIncoming(Frame frame, ByteAccumulator accumulator) protected ByteAccumulator newByteAccumulator() { int maxSize = Math.max(getPolicy().getMaxTextMessageSize(), getPolicy().getMaxBinaryMessageSize()); - return new ByteAccumulator(maxSize); + if (accumulator == null || accumulator.getMaxSize() != maxSize) + { + accumulator = new ByteAccumulator(maxSize); + } + return accumulator; } protected void decompress(ByteAccumulator accumulator, ByteBuffer buf) throws DataFormatException @@ -185,10 +194,10 @@ protected void decompress(ByteAccumulator accumulator, ByteBuffer buf) throws Da { return; } - byte[] output = new byte[DECOMPRESS_BUF_SIZE]; + Inflater inflater = getInflater(); - + while (buf.hasRemaining() && inflater.needsInput()) { if (!supplyInput(inflater, buf)) @@ -199,21 +208,25 @@ protected void decompress(ByteAccumulator accumulator, ByteBuffer buf) throws Da } int read; - while ((read = inflater.inflate(output)) >= 0) + + loop: + while (true) { - if (read == 0) + ByteBuffer output = accumulator.newByteBuffer(DECOMPRESS_BUF_SIZE); + int offset = 0; + while (offset < output.capacity()) { - if (LOG.isDebugEnabled()) - LOG.debug("Decompress: read 0 {}", toDetail(inflater)); - break; - } - else - { - // do something with output - if (LOG.isDebugEnabled()) - LOG.debug("Decompressed {} bytes: {}", read, toDetail(inflater)); - accumulator.copyChunk(output, 0, read); + read = inflater.inflate(output.array(), offset, output.capacity() - offset); + if (read <= 0) + { + // last chunk + if (offset > 0) + accumulator.copyChunk(output.array(), 0, offset); + break loop; + } + offset += read; } + accumulator.copyChunk(output.array(), 0, offset); } } diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/DeflateFrameExtension.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/DeflateFrameExtension.java index 0476c0fcc441..465204cee297 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/DeflateFrameExtension.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/DeflateFrameExtension.java @@ -47,7 +47,7 @@ int getTailDropMode() { return TAIL_DROP_ALWAYS; } - + @Override public void incomingFrame(Frame frame) { @@ -63,10 +63,11 @@ public void incomingFrame(Frame frame) try { - ByteAccumulator accumulator = newByteAccumulator(); + accumulator = newByteAccumulator(); decompress(accumulator, frame.getPayload()); decompress(accumulator, TAIL_BYTES_BUF.slice()); forwardIncoming(frame, accumulator); + accumulator.recycle(); } catch (DataFormatException e) { diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/PerMessageDeflateExtension.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/PerMessageDeflateExtension.java index 37482f8bd678..4a34d8783760 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/PerMessageDeflateExtension.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/extensions/compress/PerMessageDeflateExtension.java @@ -78,10 +78,9 @@ public void incomingFrame(Frame frame) throw new ProtocolException("Invalid RSV1 set on permessage-deflate CONTINUATION frame"); } - ByteAccumulator accumulator = newByteAccumulator(); - try { + accumulator = newByteAccumulator(); ByteBuffer payload = frame.getPayload(); decompress(accumulator, payload); if (frame.isFin()) @@ -90,6 +89,7 @@ public void incomingFrame(Frame frame) } forwardIncoming(frame, accumulator); + accumulator.recycle(); } catch (DataFormatException e) { diff --git a/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulatorTest.java b/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulatorTest.java index 5aa41765e070..caa69e2c975b 100644 --- a/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulatorTest.java +++ b/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/extensions/compress/ByteAccumulatorTest.java @@ -91,4 +91,77 @@ public void testCopyChunkNotEnoughSpace() MessageTooLargeException e = assertThrows(MessageTooLargeException.class, () -> accumulator.copyChunk(world, 0, world.length)); assertThat(e.getMessage(), containsString("too large for configured max")); } + + @Test + public void testRecycle() + { + ByteAccumulator accumulator = new ByteAccumulator(10_000); + ByteBuffer out0 = ByteBuffer.allocate(200); + ByteBuffer out1 = ByteBuffer.allocate(200); + { + // 1 + ByteBuffer buf = accumulator.newByteBuffer(10); + byte[] hello = "Hello".getBytes(UTF_8); + System.arraycopy(hello, 0, buf.array(), 0, hello.length); + accumulator.copyChunk(buf.array(), 0, hello.length); + + // 2 + buf = accumulator.newByteBuffer(10); + byte[] space = " ".getBytes(UTF_8); + System.arraycopy(space, 0, buf.array(), 0, space.length); + accumulator.copyChunk(buf.array(), 0, space.length); + + // 3 + buf = accumulator.newByteBuffer(10); + byte[] world = "World".getBytes(UTF_8); + System.arraycopy(world, 0, buf.array(), 0, world.length); + accumulator.copyChunk(buf.array(), 0, world.length); + + assertThat("Length", accumulator.getLength(), is(hello.length + space.length + world.length)); + + accumulator.transferTo(out0); + + // reuse that byte[] + accumulator.recycle(); + } + + { + // 1 + ByteBuffer buf = accumulator.newByteBuffer(10); + byte[] olleh = "olleH".getBytes(UTF_8); + System.arraycopy(olleh, 0, buf.array(), 0, olleh.length); + accumulator.copyChunk(buf.array(), 0, olleh.length); + + // 2 + buf = accumulator.newByteBuffer(10); + byte[] space = " ".getBytes(UTF_8); + System.arraycopy(space, 0, buf.array(), 0, space.length); + accumulator.copyChunk(buf.array(), 0, space.length); + + // 3 + buf = accumulator.newByteBuffer(10); + byte[] dlrow = "dlroW".getBytes(UTF_8); + System.arraycopy(dlrow, 0, buf.array(), 0, dlrow.length); + accumulator.copyChunk(buf.array(), 0, dlrow.length); + + // 4 + buf = accumulator.newByteBuffer(10); + byte[] done = " enoD".getBytes(UTF_8); + System.arraycopy(done, 0, buf.array(), 0, done.length); + accumulator.copyChunk(buf.array(), 0, done.length); + + assertThat("Length", accumulator.getLength(), is(olleh.length + space.length + dlrow.length + done.length)); + + accumulator.transferTo(out1); + + // reuse that byte[] + accumulator.recycle(); + } + + String result0 = BufferUtil.toUTF8String(out0); + assertThat("result", result0, is("Hello World")); + + String result1 = BufferUtil.toUTF8String(out1); + assertThat("result", result1, is("olleH dlroW enoD")); + } }