diff --git a/src/main/example/PerMessageDeflateExample.java b/src/main/example/PerMessageDeflateExample.java new file mode 100644 index 00000000..94bc3fd2 --- /dev/null +++ b/src/main/example/PerMessageDeflateExample.java @@ -0,0 +1,72 @@ +import org.java_websocket.WebSocket; +import org.java_websocket.client.WebSocketClient; +import org.java_websocket.drafts.Draft; +import org.java_websocket.drafts.Draft_6455; +import org.java_websocket.extensions.permessage_deflate.PerMessageDeflateExtension; +import org.java_websocket.handshake.ClientHandshake; +import org.java_websocket.handshake.ServerHandshake; +import org.java_websocket.server.WebSocketServer; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Collections; + +/** + * This class only serves the purpose of showing how to enable PerMessageDeflateExtension for both server and client sockets.
+ * Extensions are required to be registered in + * @see Draft objects and both + * @see WebSocketClient and + * @see WebSocketServer accept a + * @see Draft object in their constructors. + * This example shows how to achieve it for both server and client sockets. + * Once the connection has been established, PerMessageDeflateExtension will be enabled + * and any messages (binary or text) will be compressed/decompressed automatically.
+ * Since no additional code is required when sending or receiving messages, this example skips those parts. + */ +public class PerMessageDeflateExample { + + private static final Draft perMessageDeflateDraft = new Draft_6455(new PerMessageDeflateExtension()); + private static final int PORT = 8887; + + private static class DeflateClient extends WebSocketClient { + + public DeflateClient() throws URISyntaxException { + super(new URI("ws://localhost:" + PORT), perMessageDeflateDraft); + } + + @Override + public void onOpen(ServerHandshake handshakedata) { } + + @Override + public void onMessage(String message) { } + + @Override + public void onClose(int code, String reason, boolean remote) { } + + @Override + public void onError(Exception ex) { } + } + + private static class DeflateServer extends WebSocketServer { + + public DeflateServer() { + super(new InetSocketAddress(PORT), Collections.singletonList(perMessageDeflateDraft)); + } + + @Override + public void onOpen(WebSocket conn, ClientHandshake handshake) { } + + @Override + public void onClose(WebSocket conn, int code, String reason, boolean remote) { } + + @Override + public void onMessage(WebSocket conn, String message) { } + + @Override + public void onError(WebSocket conn, Exception ex) { } + + @Override + public void onStart() { } + } +} diff --git a/src/main/java/org/java_websocket/drafts/Draft_6455.java b/src/main/java/org/java_websocket/drafts/Draft_6455.java index 126f154d..4c19daf8 100644 --- a/src/main/java/org/java_websocket/drafts/Draft_6455.java +++ b/src/main/java/org/java_websocket/drafts/Draft_6455.java @@ -427,6 +427,12 @@ private ByteBuffer createByteBufferFromFramedata( Framedata framedata ) { byte optcode = fromOpcode( framedata.getOpcode() ); byte one = ( byte ) ( framedata.isFin() ? -128 : 0 ); one |= optcode; + if(framedata.isRSV1()) + one |= getRSVByte(1); + if(framedata.isRSV2()) + one |= getRSVByte(2); + if(framedata.isRSV3()) + one |= getRSVByte(3); buf.put( one ); byte[] payloadlengthbytes = toByteArray( mes.remaining(), sizebytes ); assert ( payloadlengthbytes.length == sizebytes ); @@ -585,6 +591,27 @@ private void translateSingleFrameCheckPacketSize(int maxpacketsize, int realpack } } + /** + * Get a byte that can set RSV bits when OR(|)'d. + * 0 1 2 3 4 5 6 7 + * +-+-+-+-+-------+ + * |F|R|R|R| opcode| + * |I|S|S|S| (4) | + * |N|V|V|V| | + * | |1|2|3| | + * @param rsv Can only be {0, 1, 2, 3} + * @return byte that represents which RSV bit is set. + */ + private byte getRSVByte(int rsv){ + if(rsv == 1) // 0100 0000 + return 0x40; + if(rsv == 2) // 0010 0000 + return 0x20; + if(rsv == 3) // 0001 0000 + return 0x10; + return 0; + } + /** * Get the mask byte if existing * @param mask is mask active or not diff --git a/src/main/java/org/java_websocket/extensions/ExtensionRequestData.java b/src/main/java/org/java_websocket/extensions/ExtensionRequestData.java new file mode 100644 index 00000000..639dd802 --- /dev/null +++ b/src/main/java/org/java_websocket/extensions/ExtensionRequestData.java @@ -0,0 +1,52 @@ +package org.java_websocket.extensions; + +import java.util.LinkedHashMap; +import java.util.Map; + +public class ExtensionRequestData { + + public static String EMPTY_VALUE = ""; + + private Map extensionParameters; + private String extensionName; + + private ExtensionRequestData() { + extensionParameters = new LinkedHashMap(); + } + + public static ExtensionRequestData parseExtensionRequest(String extensionRequest) { + ExtensionRequestData extensionData = new ExtensionRequestData(); + String[] parts = extensionRequest.split(";"); + extensionData.extensionName = parts[0].trim(); + + for(int i = 1; i < parts.length; i++) { + String[] keyValue = parts[i].split("="); + String value = EMPTY_VALUE; + + // Some parameters don't take a value. For those that do, parse the value. + if(keyValue.length > 1) { + String tempValue = keyValue[1].trim(); + + // If the value is wrapped in quotes, just get the data between them. + if((tempValue.startsWith("\"") && tempValue.endsWith("\"")) + || (tempValue.startsWith("'") && tempValue.endsWith("'")) + && tempValue.length() > 2) + tempValue = tempValue.substring(1, tempValue.length() - 1); + + value = tempValue; + } + + extensionData.extensionParameters.put(keyValue[0].trim(), value); + } + + return extensionData; + } + + public String getExtensionName() { + return extensionName; + } + + public Map getExtensionParameters() { + return extensionParameters; + } +} diff --git a/src/main/java/org/java_websocket/extensions/permessage_deflate/PerMessageDeflateExtension.java b/src/main/java/org/java_websocket/extensions/permessage_deflate/PerMessageDeflateExtension.java new file mode 100644 index 00000000..296a20b7 --- /dev/null +++ b/src/main/java/org/java_websocket/extensions/permessage_deflate/PerMessageDeflateExtension.java @@ -0,0 +1,236 @@ +package org.java_websocket.extensions.permessage_deflate; + +import org.java_websocket.enums.Opcode; +import org.java_websocket.exceptions.InvalidDataException; +import org.java_websocket.exceptions.InvalidFrameException; +import org.java_websocket.extensions.CompressionExtension; +import org.java_websocket.extensions.ExtensionRequestData; +import org.java_websocket.extensions.IExtension; +import org.java_websocket.framing.*; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.zip.DataFormatException; +import java.util.zip.Deflater; +import java.util.zip.Inflater; + +public class PerMessageDeflateExtension extends CompressionExtension { + + // Name of the extension as registered by IETF https://tools.ietf.org/html/rfc7692#section-9. + private static final String EXTENSION_REGISTERED_NAME = "permessage-deflate"; + // Below values are defined for convenience. They are not used in the compression/decompression phase. + // They may be needed during the extension-negotiation offer in the future. + private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover"; + private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover"; + private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits"; + private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits"; + private static final int serverMaxWindowBits = 1 << 15; + private static final int clientMaxWindowBits = 1 << 15; + private static final byte[] TAIL_BYTES = {0x00, 0x00, (byte)0xFF, (byte)0xFF}; + private static final int BUFFER_SIZE = 1 << 10; + + private boolean serverNoContextTakeover = true; + private boolean clientNoContextTakeover = false; + + // For WebSocketServers, this variable holds the extension parameters that the peer client has requested. + // For WebSocketClients, this variable holds the extension parameters that client himself has requested. + private Map requestedParameters = new LinkedHashMap(); + private Inflater inflater = new Inflater(true); + private Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true); + + /* + An endpoint uses the following algorithm to decompress a message. + 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the + payload of the message. + 2. Decompress the resulting data using DEFLATE. + See, https://tools.ietf.org/html/rfc7692#section-7.2.2 + */ + @Override + public void decodeFrame(Framedata inputFrame) throws InvalidDataException { + // Only DataFrames can be decompressed. + if(!(inputFrame instanceof DataFrame)) + return; + + // RSV1 bit must be set only for the first frame. + if(inputFrame.getOpcode() == Opcode.CONTINUOUS && inputFrame.isRSV1()) + throw new InvalidDataException(CloseFrame.POLICY_VALIDATION, "RSV1 bit can only be set for the first frame."); + + // Decompressed output buffer. + ByteArrayOutputStream output = new ByteArrayOutputStream(); + try { + decompress(inputFrame.getPayloadData().array(), output); + + /* + If a message is "first fragmented and then compressed", as this project does, then the inflater + can not inflate fragments except the first one. + This behavior occurs most likely because those fragments end with "final deflate blocks". + We can check the getRemaining() method to see whether the data we supplied has been decompressed or not. + And if not, we just reset the inflater and decompress again. + Note that this behavior doesn't occur if the message is "first compressed and then fragmented". + */ + if(inflater.getRemaining() > 0){ + inflater = new Inflater(true); + decompress(inputFrame.getPayloadData().array(), output); + } + + if(inputFrame.isFin()) { + decompress(TAIL_BYTES, output); + // If context takeover is disabled, inflater can be reset. + if(clientNoContextTakeover) + inflater = new Inflater(true); + } + } catch (DataFormatException e) { + throw new InvalidDataException(CloseFrame.POLICY_VALIDATION, e.getMessage()); + } + + // RSV1 bit must be cleared after decoding, so that other extensions don't throw an exception. + if(inputFrame.isRSV1()) + ((DataFrame) inputFrame).setRSV1(false); + + // Set frames payload to the new decompressed data. + ((FramedataImpl1) inputFrame).setPayload(ByteBuffer.wrap(output.toByteArray(), 0, output.size())); + } + + private void decompress(byte[] data, ByteArrayOutputStream outputBuffer) throws DataFormatException{ + inflater.setInput(data); + byte[] buffer = new byte[BUFFER_SIZE]; + + int bytesInflated; + while((bytesInflated = inflater.inflate(buffer)) > 0){ + outputBuffer.write(buffer, 0, bytesInflated); + } + } + + @Override + public void encodeFrame(Framedata inputFrame) { + // Only DataFrames can be decompressed. + if(!(inputFrame instanceof DataFrame)) + return; + + // Only the first frame's RSV1 must be set. + if(!(inputFrame instanceof ContinuousFrame)) + ((DataFrame) inputFrame).setRSV1(true); + + deflater.setInput(inputFrame.getPayloadData().array()); + // Compressed output buffer. + ByteArrayOutputStream output = new ByteArrayOutputStream(); + // Temporary buffer to hold compressed output. + byte[] buffer = new byte[1024]; + int bytesCompressed; + while((bytesCompressed = deflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH)) > 0) { + output.write(buffer, 0, bytesCompressed); + } + + byte outputBytes[] = output.toByteArray(); + int outputLength = outputBytes.length; + + /* + https://tools.ietf.org/html/rfc7692#section-7.2.1 states that if the final fragment's compressed + payload ends with 0x00 0x00 0xff 0xff, they should be removed. + To simulate removal, we just pass 4 bytes less to the new payload + if the frame is final and outputBytes ends with 0x00 0x00 0xff 0xff. + */ + if(inputFrame.isFin()) { + if(endsWithTail(outputBytes)) + outputLength -= TAIL_BYTES.length; + + if(serverNoContextTakeover) { + deflater.end(); + deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true); + } + } + + // Set frames payload to the new compressed data. + ((FramedataImpl1) inputFrame).setPayload(ByteBuffer.wrap(outputBytes, 0, outputLength)); + } + + private boolean endsWithTail(byte[] data){ + if(data.length < 4) + return false; + + int length = data.length; + for(int i = 0; i < TAIL_BYTES.length; i++){ + if(TAIL_BYTES[i] != data[length - TAIL_BYTES.length + i]) + return false; + } + + return true; + } + + @Override + public boolean acceptProvidedExtensionAsServer(String inputExtension) { + String[] requestedExtensions = inputExtension.split(","); + for(String extension : requestedExtensions) { + ExtensionRequestData extensionData = ExtensionRequestData.parseExtensionRequest(extension); + if(!EXTENSION_REGISTERED_NAME.equalsIgnoreCase(extensionData.getExtensionName())) + continue; + + // Holds parameters that peer client has sent. + Map headers = extensionData.getExtensionParameters(); + requestedParameters.putAll(headers); + if(requestedParameters.containsKey(CLIENT_NO_CONTEXT_TAKEOVER)) + clientNoContextTakeover = true; + + return true; + } + + return false; + } + + @Override + public boolean acceptProvidedExtensionAsClient(String inputExtension) { + String[] requestedExtensions = inputExtension.split(","); + for(String extension : requestedExtensions) { + ExtensionRequestData extensionData = ExtensionRequestData.parseExtensionRequest(extension); + if(!EXTENSION_REGISTERED_NAME.equalsIgnoreCase(extensionData.getExtensionName())) + continue; + + // Holds parameters that are sent by the server, as a response to our initial extension request. + Map headers = extensionData.getExtensionParameters(); + // After this point, parameters that the server sent back can be configured, but we don't use them for now. + return true; + } + + return false; + } + + @Override + public String getProvidedExtensionAsClient() { + requestedParameters.put(CLIENT_NO_CONTEXT_TAKEOVER, ExtensionRequestData.EMPTY_VALUE); + requestedParameters.put(SERVER_NO_CONTEXT_TAKEOVER, ExtensionRequestData.EMPTY_VALUE); + + return EXTENSION_REGISTERED_NAME + "; " + SERVER_NO_CONTEXT_TAKEOVER + "; " + CLIENT_NO_CONTEXT_TAKEOVER; + } + + @Override + public String getProvidedExtensionAsServer() { + return EXTENSION_REGISTERED_NAME + + "; " + SERVER_NO_CONTEXT_TAKEOVER + + (clientNoContextTakeover ? "; " + CLIENT_NO_CONTEXT_TAKEOVER : ""); + } + + @Override + public IExtension copyInstance() { + return new PerMessageDeflateExtension(); + } + + /** + * This extension requires the RSV1 bit to be set only for the first frame. + * If the frame is type is CONTINUOUS, RSV1 bit must be unset. + */ + @Override + public void isFrameValid(Framedata inputFrame) throws InvalidDataException { + if((inputFrame instanceof TextFrame || inputFrame instanceof BinaryFrame) && !inputFrame.isRSV1()) + throw new InvalidFrameException("RSV1 bit must be set for DataFrames."); + if((inputFrame instanceof ContinuousFrame) && (inputFrame.isRSV1() || inputFrame.isRSV2() || inputFrame.isRSV3())) + throw new InvalidFrameException( "bad rsv RSV1: " + inputFrame.isRSV1() + " RSV2: " + inputFrame.isRSV2() + " RSV3: " + inputFrame.isRSV3() ); + super.isFrameValid(inputFrame); + } + + @Override + public String toString() { + return "PerMessageDeflateExtension"; + } +} diff --git a/src/test/java/org/java_websocket/example/AutobahnServerTest.java b/src/test/java/org/java_websocket/example/AutobahnServerTest.java index 61eb4291..e07e3355 100644 --- a/src/test/java/org/java_websocket/example/AutobahnServerTest.java +++ b/src/test/java/org/java_websocket/example/AutobahnServerTest.java @@ -28,6 +28,7 @@ import org.java_websocket.WebSocket; import org.java_websocket.drafts.Draft; import org.java_websocket.drafts.Draft_6455; +import org.java_websocket.extensions.permessage_deflate.PerMessageDeflateExtension; import org.java_websocket.handshake.ClientHandshake; import org.java_websocket.server.WebSocketServer; @@ -101,7 +102,7 @@ public static void main( String[] args ) throws UnknownHostException { System.out.println( "No limit specified. Defaulting to MaxInteger" ); limit = Integer.MAX_VALUE; } - AutobahnServerTest test = new AutobahnServerTest( port, limit, new Draft_6455() ); + AutobahnServerTest test = new AutobahnServerTest( port, limit, new Draft_6455( new PerMessageDeflateExtension()) ); test.setConnectionLostTimeout( 0 ); test.start(); } diff --git a/src/test/java/org/java_websocket/extensions/PerMessageDeflateExtensionTest.java b/src/test/java/org/java_websocket/extensions/PerMessageDeflateExtensionTest.java new file mode 100644 index 00000000..d8307279 --- /dev/null +++ b/src/test/java/org/java_websocket/extensions/PerMessageDeflateExtensionTest.java @@ -0,0 +1,120 @@ +package org.java_websocket.extensions; + +import org.java_websocket.exceptions.InvalidDataException; +import org.java_websocket.extensions.permessage_deflate.PerMessageDeflateExtension; +import org.java_websocket.framing.ContinuousFrame; +import org.java_websocket.framing.TextFrame; +import org.junit.Test; + +import java.nio.ByteBuffer; + +import static org.junit.Assert.*; + +public class PerMessageDeflateExtensionTest { + + @Test + public void testDecodeFrame() throws InvalidDataException { + PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension(); + String str = "This is a highly compressable text" + + "This is a highly compressable text" + + "This is a highly compressable text" + + "This is a highly compressable text" + + "This is a highly compressable text"; + byte[] message = str.getBytes(); + TextFrame frame = new TextFrame(); + frame.setPayload(ByteBuffer.wrap(message)); + deflateExtension.encodeFrame(frame); + deflateExtension.decodeFrame(frame); + assertArrayEquals(message, frame.getPayloadData().array()); + } + + @Test + public void testEncodeFrame() { + PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension(); + String str = "This is a highly compressable text" + + "This is a highly compressable text" + + "This is a highly compressable text" + + "This is a highly compressable text" + + "This is a highly compressable text"; + byte[] message = str.getBytes(); + TextFrame frame = new TextFrame(); + frame.setPayload(ByteBuffer.wrap(message)); + deflateExtension.encodeFrame(frame); + assertTrue(message.length > frame.getPayloadData().array().length); + } + + @Test + public void testAcceptProvidedExtensionAsServer() { + PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension(); + assertTrue(deflateExtension.acceptProvidedExtensionAsServer("permessage-deflate")); + assertTrue(deflateExtension.acceptProvidedExtensionAsServer("some-other-extension, permessage-deflate")); + assertFalse(deflateExtension.acceptProvidedExtensionAsServer("wrong-permessage-deflate")); + } + + @Test + public void testAcceptProvidedExtensionAsClient() { + PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension(); + assertTrue(deflateExtension.acceptProvidedExtensionAsClient("permessage-deflate")); + assertTrue(deflateExtension.acceptProvidedExtensionAsClient("some-other-extension, permessage-deflate")); + assertFalse(deflateExtension.acceptProvidedExtensionAsClient("wrong-permessage-deflate")); + } + + @Test + public void testIsFrameValid() { + PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension(); + TextFrame frame = new TextFrame(); + try { + deflateExtension.isFrameValid(frame); + fail("Frame not valid. RSV1 must be set."); + } catch (Exception e) { + // + } + frame.setRSV1(true); + try { + deflateExtension.isFrameValid(frame); + } catch (Exception e) { + fail("Frame is valid."); + } + frame.setRSV2(true); + try { + deflateExtension.isFrameValid(frame); + fail("Only RSV1 bit must be set."); + } catch (Exception e) { + // + } + ContinuousFrame contFrame = new ContinuousFrame(); + contFrame.setRSV1(true); + try { + deflateExtension.isFrameValid(contFrame); + fail("RSV1 must only be set for first fragments.Continuous frames can't have RSV1 bit set."); + } catch (Exception e) { + // + } + contFrame.setRSV1(false); + try { + deflateExtension.isFrameValid(contFrame); + } catch (Exception e) { + fail("Continuous frame is valid."); + } + } + + @Test + public void testGetProvidedExtensionAsClient() { + PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension(); + assertEquals( "permessage-deflate; server_no_context_takeover; client_no_context_takeover", + deflateExtension.getProvidedExtensionAsClient() ); + } + + @Test + public void testGetProvidedExtensionAsServer() { + PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension(); + assertEquals( "permessage-deflate; server_no_context_takeover; client_no_context_takeover", + deflateExtension.getProvidedExtensionAsServer() ); + } + + @Test + public void testToString() throws Exception { + PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension(); + assertEquals( "PerMessageDeflateExtension", deflateExtension.toString() ); + } +}