Skip to content

Commit

Permalink
Fix issue jetty#5499
Browse files Browse the repository at this point in the history
this PR let the ByteAccumulator recyclable. after invoke ByteAccumulator.transferTo method
we can invoke ByteAccumulator.recycle method to reuse byte[] via ByteAccumulator.newByteBuffer method

Signed-off-by: Baoyi Chen <chen.bao.yi@qq.com>
  • Loading branch information
leonchen83 authored and Baoyi Chen committed Oct 30, 2020
1 parent 47885f7 commit 7cbe501
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 28 deletions.
Expand Up @@ -22,18 +22,50 @@
import java.util.ArrayList;
import java.util.List;

import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.websocket.api.MessageTooLargeException;

public class ByteAccumulator
{
private final List<byte[]> chunks = new ArrayList<>();
private final List<ByteBuffer> chunks = new ArrayList<>();
private final int maxSize;
private int length = 0;
private final ByteBufferPool bufferPool;

public ByteAccumulator(int maxOverallBufferSize)
{
this(maxOverallBufferSize, null);
}

public ByteAccumulator(int maxOverallBufferSize, ByteBufferPool bufferPool)
{
this.maxSize = maxOverallBufferSize;
this.bufferPool = bufferPool;
}

public void copyChunk(ByteBuffer buffer)
{
int length = buffer.remaining();
if (this.length + length > maxSize)
{
if (bufferPool != null)
bufferPool.release(buffer);
String err = String.format("Resulting message size [%,d] is too large for configured max of [%,d]", this.length + length, maxSize);
throw new MessageTooLargeException(err);
}

if (buffer.hasRemaining())
{
chunks.add(buffer);
this.length += length;
}
else
{
// release 0 length buffer directly
if (bufferPool != null)
bufferPool.release(buffer);
}
}

public void copyChunk(byte[] buf, int offset, int length)
Expand All @@ -43,11 +75,7 @@ public void copyChunk(byte[] buf, int offset, int length)
String err = String.format("Resulting message size [%,d] is too large for configured max of [%,d]", this.length + length, maxSize);
throw new MessageTooLargeException(err);
}

byte[] copy = new byte[length - offset];
System.arraycopy(buf, offset, copy, 0, length);

chunks.add(copy);
chunks.add(ByteBuffer.wrap(buf, offset, length));
this.length += length;
}

Expand All @@ -56,6 +84,20 @@ public int getLength()
return length;
}

int getMaxSize()
{
return maxSize;
}

ByteBuffer newByteBuffer(int size)
{
if (bufferPool == null)
{
return ByteBuffer.allocate(size);
}
return (ByteBuffer)bufferPool.acquire(size, false).clear();
}

public void transferTo(ByteBuffer buffer)
{
if (buffer.remaining() < length)
Expand All @@ -65,10 +107,25 @@ public void transferTo(ByteBuffer buffer)
}

int position = buffer.position();
for (byte[] chunk : chunks)
for (ByteBuffer chunk : chunks)
{
buffer.put(chunk, 0, chunk.length);
buffer.put(chunk);
}
BufferUtil.flipToFlush(buffer, position);
}

void recycle()
{
length = 0;

if (bufferPool != null)
{
for (ByteBuffer chunk : chunks)
{
bufferPool.release(chunk);
}
}

chunks.clear();
}
}
Expand Up @@ -46,7 +46,7 @@ public abstract class CompressExtension extends AbstractExtension
protected static final byte[] TAIL_BYTES = new byte[]{0x00, 0x00, (byte)0xFF, (byte)0xFF};
protected static final ByteBuffer TAIL_BYTES_BUF = ByteBuffer.wrap(TAIL_BYTES);
private static final Logger LOG = Log.getLogger(CompressExtension.class);

/**
* Never drop tail bytes 0000FFFF, from any frame type
*/
Expand Down Expand Up @@ -92,6 +92,7 @@ public abstract class CompressExtension extends AbstractExtension
private InflaterPool inflaterPool;
private Deflater deflaterImpl;
private Inflater inflaterImpl;
protected ByteAccumulator accumulator;
protected AtomicInteger decompressCount = new AtomicInteger(0);
private int tailDrop = TAIL_DROP_NEVER;
private int rsvUse = RSV_USE_ALWAYS;
Expand Down Expand Up @@ -176,7 +177,39 @@ protected void forwardIncoming(Frame frame, ByteAccumulator accumulator)
protected ByteAccumulator newByteAccumulator()
{
int maxSize = Math.max(getPolicy().getMaxTextMessageSize(), getPolicy().getMaxBinaryMessageSize());
return new ByteAccumulator(maxSize);
if (accumulator == null || accumulator.getMaxSize() != maxSize)
{
accumulator = new ByteAccumulator(maxSize, getBufferPool());
}
return accumulator;
}

int copyChunk(Inflater inflater, ByteAccumulator accumulator) throws DataFormatException
{
ByteBuffer buf = accumulator.newByteBuffer(DECOMPRESS_BUF_SIZE);
while (buf.hasRemaining())
{
try
{
int read = inflater.inflate(buf.array(), buf.position(), buf.remaining());
if (read <= 0)
{
accumulator.copyChunk((ByteBuffer)buf.flip());
return read;
}
buf.position(buf.position() + read);
}
catch (DataFormatException e)
{
// must add chunk to accumulator
// so that recycle in subclass's finally block
accumulator.copyChunk((ByteBuffer)buf.flip());
throw e;
}
}
int position = buf.position();
accumulator.copyChunk((ByteBuffer)buf.flip());
return position;
}

protected void decompress(ByteAccumulator accumulator, ByteBuffer buf) throws DataFormatException
Expand All @@ -185,10 +218,10 @@ protected void decompress(ByteAccumulator accumulator, ByteBuffer buf) throws Da
{
return;
}
byte[] output = new byte[DECOMPRESS_BUF_SIZE];


Inflater inflater = getInflater();

while (buf.hasRemaining() && inflater.needsInput())
{
if (!supplyInput(inflater, buf))
Expand All @@ -198,22 +231,12 @@ protected void decompress(ByteAccumulator accumulator, ByteBuffer buf) throws Da
return;
}

int read;
while ((read = inflater.inflate(output)) >= 0)
while (true)
{
if (read == 0)
if (copyChunk(inflater, accumulator) <= 0)
{
if (LOG.isDebugEnabled())
LOG.debug("Decompress: read 0 {}", toDetail(inflater));
break;
}
else
{
// do something with output
if (LOG.isDebugEnabled())
LOG.debug("Decompressed {} bytes: {}", read, toDetail(inflater));
accumulator.copyChunk(output, 0, read);
}
}
}

Expand Down
Expand Up @@ -47,7 +47,7 @@ int getTailDropMode()
{
return TAIL_DROP_ALWAYS;
}

@Override
public void incomingFrame(Frame frame)
{
Expand All @@ -63,7 +63,7 @@ public void incomingFrame(Frame frame)

try
{
ByteAccumulator accumulator = newByteAccumulator();
accumulator = newByteAccumulator();
decompress(accumulator, frame.getPayload());
decompress(accumulator, TAIL_BYTES_BUF.slice());
forwardIncoming(frame, accumulator);
Expand All @@ -72,5 +72,10 @@ public void incomingFrame(Frame frame)
{
throw new BadPayloadException(e);
}
finally
{
if (accumulator != null)
accumulator.recycle();
}
}
}
Expand Up @@ -78,10 +78,9 @@ public void incomingFrame(Frame frame)
throw new ProtocolException("Invalid RSV1 set on permessage-deflate CONTINUATION frame");
}

ByteAccumulator accumulator = newByteAccumulator();

try
{
accumulator = newByteAccumulator();
ByteBuffer payload = frame.getPayload();
decompress(accumulator, payload);
if (frame.isFin())
Expand All @@ -90,11 +89,17 @@ public void incomingFrame(Frame frame)
}

forwardIncoming(frame, accumulator);

}
catch (DataFormatException e)
{
throw new BadPayloadException(e);
}
finally
{
if (accumulator != null)
accumulator.recycle();
}

if (frame.isFin())
incomingCompressed = false;
Expand Down
Expand Up @@ -20,6 +20,7 @@

import java.nio.ByteBuffer;

import org.eclipse.jetty.io.ArrayByteBufferPool;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.websocket.api.MessageTooLargeException;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -91,4 +92,77 @@ public void testCopyChunkNotEnoughSpace()
MessageTooLargeException e = assertThrows(MessageTooLargeException.class, () -> accumulator.copyChunk(world, 0, world.length));
assertThat(e.getMessage(), containsString("too large for configured max"));
}

@Test
public void testRecycle()
{
ByteAccumulator accumulator = new ByteAccumulator(10_000, new ArrayByteBufferPool());
ByteBuffer out0 = ByteBuffer.allocate(200);
ByteBuffer out1 = ByteBuffer.allocate(200);
{
// 1
ByteBuffer buf = accumulator.newByteBuffer(10);
byte[] hello = "Hello".getBytes(UTF_8);
buf.put(hello).flip();
accumulator.copyChunk(buf);

// 2
buf = accumulator.newByteBuffer(10);
byte[] space = " ".getBytes(UTF_8);
buf.put(space).flip();
accumulator.copyChunk(buf);

// 3
buf = accumulator.newByteBuffer(10);
byte[] world = "World".getBytes(UTF_8);
buf.put(world).flip();
accumulator.copyChunk(buf);

assertThat("Length", accumulator.getLength(), is(hello.length + space.length + world.length));

accumulator.transferTo(out0);

// reuse that byte[]
accumulator.recycle();
}

{
// 1
ByteBuffer buf = accumulator.newByteBuffer(10);
byte[] olleh = "olleH".getBytes(UTF_8);
buf.put(olleh).flip();
accumulator.copyChunk(buf);

// 2
buf = accumulator.newByteBuffer(10);
byte[] space = " ".getBytes(UTF_8);
buf.put(space).flip();
accumulator.copyChunk(buf);

// 3
buf = accumulator.newByteBuffer(10);
byte[] dlrow = "dlroW".getBytes(UTF_8);
buf.put(dlrow).flip();
accumulator.copyChunk(buf);

// 4
buf = accumulator.newByteBuffer(10);
byte[] done = " enoD".getBytes(UTF_8);
buf.put(done).flip();
accumulator.copyChunk(buf);

assertThat("Length", accumulator.getLength(), is(olleh.length + space.length + dlrow.length + done.length));

accumulator.transferTo(out1);

// reuse that byte[]
accumulator.recycle();
}

String result0 = BufferUtil.toUTF8String(out0);
assertThat("result", result0, is("Hello World"));

String result1 = BufferUtil.toUTF8String(out1);
assertThat("result", result1, is("olleH dlroW enoD"));
}
}

0 comments on commit 7cbe501

Please sign in to comment.