diff --git a/src/main/example/PerMessageDeflateExample.java b/src/main/example/PerMessageDeflateExample.java
new file mode 100644
index 00000000..94bc3fd2
--- /dev/null
+++ b/src/main/example/PerMessageDeflateExample.java
@@ -0,0 +1,72 @@
+import org.java_websocket.WebSocket;
+import org.java_websocket.client.WebSocketClient;
+import org.java_websocket.drafts.Draft;
+import org.java_websocket.drafts.Draft_6455;
+import org.java_websocket.extensions.permessage_deflate.PerMessageDeflateExtension;
+import org.java_websocket.handshake.ClientHandshake;
+import org.java_websocket.handshake.ServerHandshake;
+import org.java_websocket.server.WebSocketServer;
+
+import java.net.InetSocketAddress;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Collections;
+
+/**
+ * This class only serves the purpose of showing how to enable PerMessageDeflateExtension for both server and client sockets.
+ * Extensions are required to be registered in
+ * @see Draft objects and both
+ * @see WebSocketClient and
+ * @see WebSocketServer accept a
+ * @see Draft object in their constructors.
+ * This example shows how to achieve it for both server and client sockets.
+ * Once the connection has been established, PerMessageDeflateExtension will be enabled
+ * and any messages (binary or text) will be compressed/decompressed automatically.
+ * Since no additional code is required when sending or receiving messages, this example skips those parts.
+ */
+public class PerMessageDeflateExample {
+
+ private static final Draft perMessageDeflateDraft = new Draft_6455(new PerMessageDeflateExtension());
+ private static final int PORT = 8887;
+
+ private static class DeflateClient extends WebSocketClient {
+
+ public DeflateClient() throws URISyntaxException {
+ super(new URI("ws://localhost:" + PORT), perMessageDeflateDraft);
+ }
+
+ @Override
+ public void onOpen(ServerHandshake handshakedata) { }
+
+ @Override
+ public void onMessage(String message) { }
+
+ @Override
+ public void onClose(int code, String reason, boolean remote) { }
+
+ @Override
+ public void onError(Exception ex) { }
+ }
+
+ private static class DeflateServer extends WebSocketServer {
+
+ public DeflateServer() {
+ super(new InetSocketAddress(PORT), Collections.singletonList(perMessageDeflateDraft));
+ }
+
+ @Override
+ public void onOpen(WebSocket conn, ClientHandshake handshake) { }
+
+ @Override
+ public void onClose(WebSocket conn, int code, String reason, boolean remote) { }
+
+ @Override
+ public void onMessage(WebSocket conn, String message) { }
+
+ @Override
+ public void onError(WebSocket conn, Exception ex) { }
+
+ @Override
+ public void onStart() { }
+ }
+}
diff --git a/src/main/java/org/java_websocket/drafts/Draft_6455.java b/src/main/java/org/java_websocket/drafts/Draft_6455.java
index 126f154d..4c19daf8 100644
--- a/src/main/java/org/java_websocket/drafts/Draft_6455.java
+++ b/src/main/java/org/java_websocket/drafts/Draft_6455.java
@@ -427,6 +427,12 @@ private ByteBuffer createByteBufferFromFramedata( Framedata framedata ) {
byte optcode = fromOpcode( framedata.getOpcode() );
byte one = ( byte ) ( framedata.isFin() ? -128 : 0 );
one |= optcode;
+ if(framedata.isRSV1())
+ one |= getRSVByte(1);
+ if(framedata.isRSV2())
+ one |= getRSVByte(2);
+ if(framedata.isRSV3())
+ one |= getRSVByte(3);
buf.put( one );
byte[] payloadlengthbytes = toByteArray( mes.remaining(), sizebytes );
assert ( payloadlengthbytes.length == sizebytes );
@@ -585,6 +591,27 @@ private void translateSingleFrameCheckPacketSize(int maxpacketsize, int realpack
}
}
+ /**
+ * Get a byte that can set RSV bits when OR(|)'d.
+ * 0 1 2 3 4 5 6 7
+ * +-+-+-+-+-------+
+ * |F|R|R|R| opcode|
+ * |I|S|S|S| (4) |
+ * |N|V|V|V| |
+ * | |1|2|3| |
+ * @param rsv Can only be {0, 1, 2, 3}
+ * @return byte that represents which RSV bit is set.
+ */
+ private byte getRSVByte(int rsv){
+ if(rsv == 1) // 0100 0000
+ return 0x40;
+ if(rsv == 2) // 0010 0000
+ return 0x20;
+ if(rsv == 3) // 0001 0000
+ return 0x10;
+ return 0;
+ }
+
/**
* Get the mask byte if existing
* @param mask is mask active or not
diff --git a/src/main/java/org/java_websocket/extensions/ExtensionRequestData.java b/src/main/java/org/java_websocket/extensions/ExtensionRequestData.java
new file mode 100644
index 00000000..639dd802
--- /dev/null
+++ b/src/main/java/org/java_websocket/extensions/ExtensionRequestData.java
@@ -0,0 +1,52 @@
+package org.java_websocket.extensions;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+public class ExtensionRequestData {
+
+ public static String EMPTY_VALUE = "";
+
+ private Map extensionParameters;
+ private String extensionName;
+
+ private ExtensionRequestData() {
+ extensionParameters = new LinkedHashMap();
+ }
+
+ public static ExtensionRequestData parseExtensionRequest(String extensionRequest) {
+ ExtensionRequestData extensionData = new ExtensionRequestData();
+ String[] parts = extensionRequest.split(";");
+ extensionData.extensionName = parts[0].trim();
+
+ for(int i = 1; i < parts.length; i++) {
+ String[] keyValue = parts[i].split("=");
+ String value = EMPTY_VALUE;
+
+ // Some parameters don't take a value. For those that do, parse the value.
+ if(keyValue.length > 1) {
+ String tempValue = keyValue[1].trim();
+
+ // If the value is wrapped in quotes, just get the data between them.
+ if((tempValue.startsWith("\"") && tempValue.endsWith("\""))
+ || (tempValue.startsWith("'") && tempValue.endsWith("'"))
+ && tempValue.length() > 2)
+ tempValue = tempValue.substring(1, tempValue.length() - 1);
+
+ value = tempValue;
+ }
+
+ extensionData.extensionParameters.put(keyValue[0].trim(), value);
+ }
+
+ return extensionData;
+ }
+
+ public String getExtensionName() {
+ return extensionName;
+ }
+
+ public Map getExtensionParameters() {
+ return extensionParameters;
+ }
+}
diff --git a/src/main/java/org/java_websocket/extensions/permessage_deflate/PerMessageDeflateExtension.java b/src/main/java/org/java_websocket/extensions/permessage_deflate/PerMessageDeflateExtension.java
new file mode 100644
index 00000000..296a20b7
--- /dev/null
+++ b/src/main/java/org/java_websocket/extensions/permessage_deflate/PerMessageDeflateExtension.java
@@ -0,0 +1,236 @@
+package org.java_websocket.extensions.permessage_deflate;
+
+import org.java_websocket.enums.Opcode;
+import org.java_websocket.exceptions.InvalidDataException;
+import org.java_websocket.exceptions.InvalidFrameException;
+import org.java_websocket.extensions.CompressionExtension;
+import org.java_websocket.extensions.ExtensionRequestData;
+import org.java_websocket.extensions.IExtension;
+import org.java_websocket.framing.*;
+
+import java.io.ByteArrayOutputStream;
+import java.nio.ByteBuffer;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.zip.DataFormatException;
+import java.util.zip.Deflater;
+import java.util.zip.Inflater;
+
+public class PerMessageDeflateExtension extends CompressionExtension {
+
+ // Name of the extension as registered by IETF https://tools.ietf.org/html/rfc7692#section-9.
+ private static final String EXTENSION_REGISTERED_NAME = "permessage-deflate";
+ // Below values are defined for convenience. They are not used in the compression/decompression phase.
+ // They may be needed during the extension-negotiation offer in the future.
+ private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover";
+ private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover";
+ private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits";
+ private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits";
+ private static final int serverMaxWindowBits = 1 << 15;
+ private static final int clientMaxWindowBits = 1 << 15;
+ private static final byte[] TAIL_BYTES = {0x00, 0x00, (byte)0xFF, (byte)0xFF};
+ private static final int BUFFER_SIZE = 1 << 10;
+
+ private boolean serverNoContextTakeover = true;
+ private boolean clientNoContextTakeover = false;
+
+ // For WebSocketServers, this variable holds the extension parameters that the peer client has requested.
+ // For WebSocketClients, this variable holds the extension parameters that client himself has requested.
+ private Map requestedParameters = new LinkedHashMap();
+ private Inflater inflater = new Inflater(true);
+ private Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
+
+ /*
+ An endpoint uses the following algorithm to decompress a message.
+ 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the
+ payload of the message.
+ 2. Decompress the resulting data using DEFLATE.
+ See, https://tools.ietf.org/html/rfc7692#section-7.2.2
+ */
+ @Override
+ public void decodeFrame(Framedata inputFrame) throws InvalidDataException {
+ // Only DataFrames can be decompressed.
+ if(!(inputFrame instanceof DataFrame))
+ return;
+
+ // RSV1 bit must be set only for the first frame.
+ if(inputFrame.getOpcode() == Opcode.CONTINUOUS && inputFrame.isRSV1())
+ throw new InvalidDataException(CloseFrame.POLICY_VALIDATION, "RSV1 bit can only be set for the first frame.");
+
+ // Decompressed output buffer.
+ ByteArrayOutputStream output = new ByteArrayOutputStream();
+ try {
+ decompress(inputFrame.getPayloadData().array(), output);
+
+ /*
+ If a message is "first fragmented and then compressed", as this project does, then the inflater
+ can not inflate fragments except the first one.
+ This behavior occurs most likely because those fragments end with "final deflate blocks".
+ We can check the getRemaining() method to see whether the data we supplied has been decompressed or not.
+ And if not, we just reset the inflater and decompress again.
+ Note that this behavior doesn't occur if the message is "first compressed and then fragmented".
+ */
+ if(inflater.getRemaining() > 0){
+ inflater = new Inflater(true);
+ decompress(inputFrame.getPayloadData().array(), output);
+ }
+
+ if(inputFrame.isFin()) {
+ decompress(TAIL_BYTES, output);
+ // If context takeover is disabled, inflater can be reset.
+ if(clientNoContextTakeover)
+ inflater = new Inflater(true);
+ }
+ } catch (DataFormatException e) {
+ throw new InvalidDataException(CloseFrame.POLICY_VALIDATION, e.getMessage());
+ }
+
+ // RSV1 bit must be cleared after decoding, so that other extensions don't throw an exception.
+ if(inputFrame.isRSV1())
+ ((DataFrame) inputFrame).setRSV1(false);
+
+ // Set frames payload to the new decompressed data.
+ ((FramedataImpl1) inputFrame).setPayload(ByteBuffer.wrap(output.toByteArray(), 0, output.size()));
+ }
+
+ private void decompress(byte[] data, ByteArrayOutputStream outputBuffer) throws DataFormatException{
+ inflater.setInput(data);
+ byte[] buffer = new byte[BUFFER_SIZE];
+
+ int bytesInflated;
+ while((bytesInflated = inflater.inflate(buffer)) > 0){
+ outputBuffer.write(buffer, 0, bytesInflated);
+ }
+ }
+
+ @Override
+ public void encodeFrame(Framedata inputFrame) {
+ // Only DataFrames can be decompressed.
+ if(!(inputFrame instanceof DataFrame))
+ return;
+
+ // Only the first frame's RSV1 must be set.
+ if(!(inputFrame instanceof ContinuousFrame))
+ ((DataFrame) inputFrame).setRSV1(true);
+
+ deflater.setInput(inputFrame.getPayloadData().array());
+ // Compressed output buffer.
+ ByteArrayOutputStream output = new ByteArrayOutputStream();
+ // Temporary buffer to hold compressed output.
+ byte[] buffer = new byte[1024];
+ int bytesCompressed;
+ while((bytesCompressed = deflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH)) > 0) {
+ output.write(buffer, 0, bytesCompressed);
+ }
+
+ byte outputBytes[] = output.toByteArray();
+ int outputLength = outputBytes.length;
+
+ /*
+ https://tools.ietf.org/html/rfc7692#section-7.2.1 states that if the final fragment's compressed
+ payload ends with 0x00 0x00 0xff 0xff, they should be removed.
+ To simulate removal, we just pass 4 bytes less to the new payload
+ if the frame is final and outputBytes ends with 0x00 0x00 0xff 0xff.
+ */
+ if(inputFrame.isFin()) {
+ if(endsWithTail(outputBytes))
+ outputLength -= TAIL_BYTES.length;
+
+ if(serverNoContextTakeover) {
+ deflater.end();
+ deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
+ }
+ }
+
+ // Set frames payload to the new compressed data.
+ ((FramedataImpl1) inputFrame).setPayload(ByteBuffer.wrap(outputBytes, 0, outputLength));
+ }
+
+ private boolean endsWithTail(byte[] data){
+ if(data.length < 4)
+ return false;
+
+ int length = data.length;
+ for(int i = 0; i < TAIL_BYTES.length; i++){
+ if(TAIL_BYTES[i] != data[length - TAIL_BYTES.length + i])
+ return false;
+ }
+
+ return true;
+ }
+
+ @Override
+ public boolean acceptProvidedExtensionAsServer(String inputExtension) {
+ String[] requestedExtensions = inputExtension.split(",");
+ for(String extension : requestedExtensions) {
+ ExtensionRequestData extensionData = ExtensionRequestData.parseExtensionRequest(extension);
+ if(!EXTENSION_REGISTERED_NAME.equalsIgnoreCase(extensionData.getExtensionName()))
+ continue;
+
+ // Holds parameters that peer client has sent.
+ Map headers = extensionData.getExtensionParameters();
+ requestedParameters.putAll(headers);
+ if(requestedParameters.containsKey(CLIENT_NO_CONTEXT_TAKEOVER))
+ clientNoContextTakeover = true;
+
+ return true;
+ }
+
+ return false;
+ }
+
+ @Override
+ public boolean acceptProvidedExtensionAsClient(String inputExtension) {
+ String[] requestedExtensions = inputExtension.split(",");
+ for(String extension : requestedExtensions) {
+ ExtensionRequestData extensionData = ExtensionRequestData.parseExtensionRequest(extension);
+ if(!EXTENSION_REGISTERED_NAME.equalsIgnoreCase(extensionData.getExtensionName()))
+ continue;
+
+ // Holds parameters that are sent by the server, as a response to our initial extension request.
+ Map headers = extensionData.getExtensionParameters();
+ // After this point, parameters that the server sent back can be configured, but we don't use them for now.
+ return true;
+ }
+
+ return false;
+ }
+
+ @Override
+ public String getProvidedExtensionAsClient() {
+ requestedParameters.put(CLIENT_NO_CONTEXT_TAKEOVER, ExtensionRequestData.EMPTY_VALUE);
+ requestedParameters.put(SERVER_NO_CONTEXT_TAKEOVER, ExtensionRequestData.EMPTY_VALUE);
+
+ return EXTENSION_REGISTERED_NAME + "; " + SERVER_NO_CONTEXT_TAKEOVER + "; " + CLIENT_NO_CONTEXT_TAKEOVER;
+ }
+
+ @Override
+ public String getProvidedExtensionAsServer() {
+ return EXTENSION_REGISTERED_NAME
+ + "; " + SERVER_NO_CONTEXT_TAKEOVER
+ + (clientNoContextTakeover ? "; " + CLIENT_NO_CONTEXT_TAKEOVER : "");
+ }
+
+ @Override
+ public IExtension copyInstance() {
+ return new PerMessageDeflateExtension();
+ }
+
+ /**
+ * This extension requires the RSV1 bit to be set only for the first frame.
+ * If the frame is type is CONTINUOUS, RSV1 bit must be unset.
+ */
+ @Override
+ public void isFrameValid(Framedata inputFrame) throws InvalidDataException {
+ if((inputFrame instanceof TextFrame || inputFrame instanceof BinaryFrame) && !inputFrame.isRSV1())
+ throw new InvalidFrameException("RSV1 bit must be set for DataFrames.");
+ if((inputFrame instanceof ContinuousFrame) && (inputFrame.isRSV1() || inputFrame.isRSV2() || inputFrame.isRSV3()))
+ throw new InvalidFrameException( "bad rsv RSV1: " + inputFrame.isRSV1() + " RSV2: " + inputFrame.isRSV2() + " RSV3: " + inputFrame.isRSV3() );
+ super.isFrameValid(inputFrame);
+ }
+
+ @Override
+ public String toString() {
+ return "PerMessageDeflateExtension";
+ }
+}
diff --git a/src/test/java/org/java_websocket/example/AutobahnServerTest.java b/src/test/java/org/java_websocket/example/AutobahnServerTest.java
index 61eb4291..e07e3355 100644
--- a/src/test/java/org/java_websocket/example/AutobahnServerTest.java
+++ b/src/test/java/org/java_websocket/example/AutobahnServerTest.java
@@ -28,6 +28,7 @@
import org.java_websocket.WebSocket;
import org.java_websocket.drafts.Draft;
import org.java_websocket.drafts.Draft_6455;
+import org.java_websocket.extensions.permessage_deflate.PerMessageDeflateExtension;
import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.server.WebSocketServer;
@@ -101,7 +102,7 @@ public static void main( String[] args ) throws UnknownHostException {
System.out.println( "No limit specified. Defaulting to MaxInteger" );
limit = Integer.MAX_VALUE;
}
- AutobahnServerTest test = new AutobahnServerTest( port, limit, new Draft_6455() );
+ AutobahnServerTest test = new AutobahnServerTest( port, limit, new Draft_6455( new PerMessageDeflateExtension()) );
test.setConnectionLostTimeout( 0 );
test.start();
}
diff --git a/src/test/java/org/java_websocket/extensions/PerMessageDeflateExtensionTest.java b/src/test/java/org/java_websocket/extensions/PerMessageDeflateExtensionTest.java
new file mode 100644
index 00000000..d8307279
--- /dev/null
+++ b/src/test/java/org/java_websocket/extensions/PerMessageDeflateExtensionTest.java
@@ -0,0 +1,120 @@
+package org.java_websocket.extensions;
+
+import org.java_websocket.exceptions.InvalidDataException;
+import org.java_websocket.extensions.permessage_deflate.PerMessageDeflateExtension;
+import org.java_websocket.framing.ContinuousFrame;
+import org.java_websocket.framing.TextFrame;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+
+import static org.junit.Assert.*;
+
+public class PerMessageDeflateExtensionTest {
+
+ @Test
+ public void testDecodeFrame() throws InvalidDataException {
+ PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
+ String str = "This is a highly compressable text"
+ + "This is a highly compressable text"
+ + "This is a highly compressable text"
+ + "This is a highly compressable text"
+ + "This is a highly compressable text";
+ byte[] message = str.getBytes();
+ TextFrame frame = new TextFrame();
+ frame.setPayload(ByteBuffer.wrap(message));
+ deflateExtension.encodeFrame(frame);
+ deflateExtension.decodeFrame(frame);
+ assertArrayEquals(message, frame.getPayloadData().array());
+ }
+
+ @Test
+ public void testEncodeFrame() {
+ PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
+ String str = "This is a highly compressable text"
+ + "This is a highly compressable text"
+ + "This is a highly compressable text"
+ + "This is a highly compressable text"
+ + "This is a highly compressable text";
+ byte[] message = str.getBytes();
+ TextFrame frame = new TextFrame();
+ frame.setPayload(ByteBuffer.wrap(message));
+ deflateExtension.encodeFrame(frame);
+ assertTrue(message.length > frame.getPayloadData().array().length);
+ }
+
+ @Test
+ public void testAcceptProvidedExtensionAsServer() {
+ PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
+ assertTrue(deflateExtension.acceptProvidedExtensionAsServer("permessage-deflate"));
+ assertTrue(deflateExtension.acceptProvidedExtensionAsServer("some-other-extension, permessage-deflate"));
+ assertFalse(deflateExtension.acceptProvidedExtensionAsServer("wrong-permessage-deflate"));
+ }
+
+ @Test
+ public void testAcceptProvidedExtensionAsClient() {
+ PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
+ assertTrue(deflateExtension.acceptProvidedExtensionAsClient("permessage-deflate"));
+ assertTrue(deflateExtension.acceptProvidedExtensionAsClient("some-other-extension, permessage-deflate"));
+ assertFalse(deflateExtension.acceptProvidedExtensionAsClient("wrong-permessage-deflate"));
+ }
+
+ @Test
+ public void testIsFrameValid() {
+ PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
+ TextFrame frame = new TextFrame();
+ try {
+ deflateExtension.isFrameValid(frame);
+ fail("Frame not valid. RSV1 must be set.");
+ } catch (Exception e) {
+ //
+ }
+ frame.setRSV1(true);
+ try {
+ deflateExtension.isFrameValid(frame);
+ } catch (Exception e) {
+ fail("Frame is valid.");
+ }
+ frame.setRSV2(true);
+ try {
+ deflateExtension.isFrameValid(frame);
+ fail("Only RSV1 bit must be set.");
+ } catch (Exception e) {
+ //
+ }
+ ContinuousFrame contFrame = new ContinuousFrame();
+ contFrame.setRSV1(true);
+ try {
+ deflateExtension.isFrameValid(contFrame);
+ fail("RSV1 must only be set for first fragments.Continuous frames can't have RSV1 bit set.");
+ } catch (Exception e) {
+ //
+ }
+ contFrame.setRSV1(false);
+ try {
+ deflateExtension.isFrameValid(contFrame);
+ } catch (Exception e) {
+ fail("Continuous frame is valid.");
+ }
+ }
+
+ @Test
+ public void testGetProvidedExtensionAsClient() {
+ PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
+ assertEquals( "permessage-deflate; server_no_context_takeover; client_no_context_takeover",
+ deflateExtension.getProvidedExtensionAsClient() );
+ }
+
+ @Test
+ public void testGetProvidedExtensionAsServer() {
+ PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
+ assertEquals( "permessage-deflate; server_no_context_takeover; client_no_context_takeover",
+ deflateExtension.getProvidedExtensionAsServer() );
+ }
+
+ @Test
+ public void testToString() throws Exception {
+ PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
+ assertEquals( "PerMessageDeflateExtension", deflateExtension.toString() );
+ }
+}