Skip to content

Commit

Permalink
Support for maxInMemorySize in SSE reader
Browse files Browse the repository at this point in the history
Closes gh-24312
  • Loading branch information
rstoyanchev committed Jan 13, 2020
1 parent a741ae4 commit cbc5746
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 26 deletions.
Expand Up @@ -21,7 +21,7 @@
* This can be raised when data buffers are cached and aggregated, e.g.
* {@link DataBufferUtils#join}. Or it could also be raised when data buffers
* have been released but a parsed representation is being aggregated, e.g. async
* parsing with Jackson.
* parsing with Jackson, SSE parsing and aggregating lines per event.
*
* @author Rossen Stoyanchev
* @since 5.1.11
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@
import org.springframework.core.codec.StringDecoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
Expand All @@ -48,14 +49,16 @@ public class ServerSentEventHttpMessageReader implements HttpMessageReader<Objec

private static final DataBufferFactory bufferFactory = new DefaultDataBufferFactory();

private static final StringDecoder stringDecoder = StringDecoder.textPlainOnly();

private static final ResolvableType STRING_TYPE = ResolvableType.forClass(String.class);


@Nullable
private final Decoder<?> decoder;

private final StringDecoder lineDecoder = StringDecoder.textPlainOnly();




/**
* Constructor without a {@code Decoder}. In this mode only {@code String}
Expand All @@ -82,6 +85,29 @@ public Decoder<?> getDecoder() {
return this.decoder;
}

/**
* Configure a limit on the maximum number of bytes per SSE event which are
* buffered before the event is parsed.
* <p>Note that the {@link #getDecoder() data decoder}, if provided, must
* also be customized accordingly to raise the limit if necessary in order
* to be able to parse the data portion of the event.
* <p>By default this is set to 256K.
* @param byteCount the max number of bytes to buffer, or -1 for unlimited
* @since 5.1.13
*/
public void setMaxInMemorySize(int byteCount) {
this.lineDecoder.setMaxInMemorySize(byteCount);
}

/**
* Return the {@link #setMaxInMemorySize configured} byte count limit.
* @since 5.1.13
*/
public int getMaxInMemorySize() {
return this.lineDecoder.getMaxInMemorySize();
}


@Override
public List<MediaType> getReadableMediaTypes() {
return Collections.singletonList(MediaType.TEXT_EVENT_STREAM);
Expand All @@ -101,12 +127,15 @@ private boolean isServerSentEvent(ResolvableType elementType) {
public Flux<Object> read(
ResolvableType elementType, ReactiveHttpInputMessage message, Map<String, Object> hints) {

LimitTracker limitTracker = new LimitTracker();

boolean shouldWrap = isServerSentEvent(elementType);
ResolvableType valueType = (shouldWrap ? elementType.getGeneric() : elementType);

return stringDecoder.decode(message.getBody(), STRING_TYPE, null, hints)
return this.lineDecoder.decode(message.getBody(), STRING_TYPE, null, hints)
.doOnNext(limitTracker::afterLineParsed)
.bufferUntil(String::isEmpty)
.concatMap(lines -> Mono.justOrEmpty(buildEvent(lines, valueType, shouldWrap, hints)));
.map(lines -> buildEvent(lines, valueType, shouldWrap, hints));
}

@Nullable
Expand Down Expand Up @@ -172,16 +201,47 @@ private Object decodeData(String data, ResolvableType dataType, Map<String, Obje
public Mono<Object> readMono(
ResolvableType elementType, ReactiveHttpInputMessage message, Map<String, Object> hints) {

// We're ahead of String + "*/*"
// Let's see if we can aggregate the output (lest we time out)...
// In order of readers, we're ahead of String + "*/*"
// If this is called, simply delegate to StringDecoder

if (elementType.resolve() == String.class) {
Flux<DataBuffer> body = message.getBody();
return stringDecoder.decodeToMono(body, elementType, null, null).cast(Object.class);
return this.lineDecoder.decodeToMono(body, elementType, null, null).cast(Object.class);
}

return Mono.error(new UnsupportedOperationException(
"ServerSentEventHttpMessageReader only supports reading stream of events as a Flux"));
}


private class LimitTracker {

private int accumulated = 0;


public void afterLineParsed(String line) {
if (getMaxInMemorySize() < 0) {
return;
}
if (line.isEmpty()) {
this.accumulated = 0;
}
if (line.length() > Integer.MAX_VALUE - this.accumulated) {
raiseLimitException();
}
else {
this.accumulated += line.length();
if (this.accumulated > getMaxInMemorySize()) {
raiseLimitException();
}
}
}

private void raiseLimitException() {
// Do not release here, it's likely down via doOnDiscard..
throw new DataBufferLimitException(
"Exceeded limit on max bytes to buffer : " + getMaxInMemorySize());
}
}

}
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -238,9 +238,6 @@ private void initCodec(@Nullable Object codec) {
if (codec instanceof DecoderHttpMessageReader) {
codec = ((DecoderHttpMessageReader) codec).getDecoder();
}
else if (codec instanceof ServerSentEventHttpMessageReader) {
codec = ((ServerSentEventHttpMessageReader) codec).getDecoder();
}

if (codec == null) {
return;
Expand Down Expand Up @@ -269,6 +266,10 @@ else if (codec instanceof ServerSentEventHttpMessageReader) {
if (codec instanceof FormHttpMessageReader) {
((FormHttpMessageReader) codec).setMaxInMemorySize(size);
}
if (codec instanceof ServerSentEventHttpMessageReader) {
((ServerSentEventHttpMessageReader) codec).setMaxInMemorySize(size);
initCodec(((ServerSentEventHttpMessageReader) codec).getDecoder());
}
if (synchronossMultipartPresent) {
if (codec instanceof SynchronossPartHttpMessageReader) {
((SynchronossPartHttpMessageReader) codec).setMaxInMemorySize(size);
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,6 +27,7 @@

import org.springframework.core.ResolvableType;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.testfixture.io.buffer.AbstractLeakCheckingTests;
import org.springframework.http.MediaType;
import org.springframework.http.codec.json.Jackson2JsonDecoder;
Expand All @@ -42,20 +43,21 @@
*/
public class ServerSentEventHttpMessageReaderTests extends AbstractLeakCheckingTests {

private ServerSentEventHttpMessageReader messageReader =
new ServerSentEventHttpMessageReader(new Jackson2JsonDecoder());
private Jackson2JsonDecoder jsonDecoder = new Jackson2JsonDecoder();

private ServerSentEventHttpMessageReader reader = new ServerSentEventHttpMessageReader(this.jsonDecoder);


@Test
public void cantRead() {
assertThat(messageReader.canRead(ResolvableType.forClass(Object.class), new MediaType("foo", "bar"))).isFalse();
assertThat(messageReader.canRead(ResolvableType.forClass(Object.class), null)).isFalse();
assertThat(reader.canRead(ResolvableType.forClass(Object.class), new MediaType("foo", "bar"))).isFalse();
assertThat(reader.canRead(ResolvableType.forClass(Object.class), null)).isFalse();
}

@Test
public void canRead() {
assertThat(messageReader.canRead(ResolvableType.forClass(Object.class), new MediaType("text", "event-stream"))).isTrue();
assertThat(messageReader.canRead(ResolvableType.forClass(ServerSentEvent.class), new MediaType("foo", "bar"))).isTrue();
assertThat(reader.canRead(ResolvableType.forClass(Object.class), new MediaType("text", "event-stream"))).isTrue();
assertThat(reader.canRead(ResolvableType.forClass(ServerSentEvent.class), new MediaType("foo", "bar"))).isTrue();
}

@Test
Expand All @@ -66,7 +68,7 @@ public void readServerSentEvents() {
"id:c42\nevent:foo\nretry:123\n:bla\n:bla bla\n:bla bla bla\ndata:bar\n\n" +
"id:c43\nevent:bar\nretry:456\ndata:baz\n\n")));

Flux<ServerSentEvent> events = this.messageReader
Flux<ServerSentEvent> events = this.reader
.read(ResolvableType.forClassWithGenerics(ServerSentEvent.class, String.class),
request, Collections.emptyMap()).cast(ServerSentEvent.class);

Expand Down Expand Up @@ -98,7 +100,7 @@ public void readServerSentEventsWithMultipleChunks() {
stringBuffer("ent:foo\nretry:123\n:bla\n:bla bla\n:bla bla bla\ndata:"),
stringBuffer("bar\n\nid:c43\nevent:bar\nretry:456\ndata:baz\n\n")));

Flux<ServerSentEvent> events = messageReader
Flux<ServerSentEvent> events = reader
.read(ResolvableType.forClassWithGenerics(ServerSentEvent.class, String.class),
request, Collections.emptyMap()).cast(ServerSentEvent.class);

Expand Down Expand Up @@ -126,7 +128,7 @@ public void readString() {
MockServerHttpRequest request = MockServerHttpRequest.post("/")
.body(Mono.just(stringBuffer("data:foo\ndata:bar\n\ndata:baz\n\n")));

Flux<String> data = messageReader.read(ResolvableType.forClass(String.class),
Flux<String> data = reader.read(ResolvableType.forClass(String.class),
request, Collections.emptyMap()).cast(String.class);

StepVerifier.create(data)
Expand All @@ -143,7 +145,7 @@ public void readPojo() {
"data:{\"foo\": \"foofoo\", \"bar\": \"barbar\"}\n\n" +
"data:{\"foo\": \"foofoofoo\", \"bar\": \"barbarbar\"}\n\n")));

Flux<Pojo> data = messageReader.read(ResolvableType.forClass(Pojo.class), request,
Flux<Pojo> data = reader.read(ResolvableType.forClass(Pojo.class), request,
Collections.emptyMap()).cast(Pojo.class);

StepVerifier.create(data)
Expand All @@ -165,7 +167,7 @@ public void decodeFullContentAsString() {
MockServerHttpRequest request = MockServerHttpRequest.post("/")
.body(Mono.just(stringBuffer(body)));

String actual = messageReader
String actual = reader
.readMono(ResolvableType.forClass(String.class), request, Collections.emptyMap())
.cast(String.class)
.block(Duration.ZERO);
Expand All @@ -182,7 +184,7 @@ public void readError() {
MockServerHttpRequest request = MockServerHttpRequest.post("/")
.body(body);

Flux<String> data = messageReader.read(ResolvableType.forClass(String.class),
Flux<String> data = reader.read(ResolvableType.forClass(String.class),
request, Collections.emptyMap()).cast(String.class);

StepVerifier.create(data)
Expand All @@ -192,6 +194,54 @@ public void readError() {
.verify();
}

@Test
public void maxInMemoryLimit() {

this.reader.setMaxInMemorySize(17);

MockServerHttpRequest request = MockServerHttpRequest.post("/")
.body(Flux.just(stringBuffer("data:\"TOO MUCH DATA\"\ndata:bar\n\ndata:baz\n\n")));

Flux<String> data = this.reader.read(ResolvableType.forClass(String.class),
request, Collections.emptyMap()).cast(String.class);

StepVerifier.create(data)
.expectError(DataBufferLimitException.class)
.verify();
}

@Test // gh-24312
public void maxInMemoryLimitAllowsReadingPojoLargerThanDefaultSize() {

int limit = this.jsonDecoder.getMaxInMemorySize();

String fooValue = getStringOfSize(limit) + "and then some more";
String content = "data:{\"foo\": \"" + fooValue + "\"}\n\n";
MockServerHttpRequest request = MockServerHttpRequest.post("/").body(Mono.just(stringBuffer(content)));

Jackson2JsonDecoder jacksonDecoder = new Jackson2JsonDecoder();
ServerSentEventHttpMessageReader messageReader = new ServerSentEventHttpMessageReader(jacksonDecoder);

jacksonDecoder.setMaxInMemorySize(limit + 1024);
messageReader.setMaxInMemorySize(limit + 1024);

Flux<Pojo> data = messageReader.read(ResolvableType.forClass(Pojo.class), request,
Collections.emptyMap()).cast(Pojo.class);

StepVerifier.create(data)
.consumeNextWith(pojo -> assertThat(pojo.getFoo()).isEqualTo(fooValue))
.expectComplete()
.verify();
}

private static String getStringOfSize(long size) {
StringBuilder content = new StringBuilder("Aa");
while (content.length() < size) {
content.append(content);
}
return content.toString();
}

private DataBuffer stringBuffer(String value) {
byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length);
Expand Down
Expand Up @@ -140,6 +140,7 @@ public void maxInMemorySize() {
assertThat(((Jaxb2XmlDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);

ServerSentEventHttpMessageReader reader = (ServerSentEventHttpMessageReader) nextReader(readers);
assertThat(reader.getMaxInMemorySize()).isEqualTo(size);
assertThat(((Jackson2JsonDecoder) reader.getDecoder()).getMaxInMemorySize()).isEqualTo(size);

assertThat(((StringDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);
Expand Down

0 comments on commit cbc5746

Please sign in to comment.