diff --git a/jetty-http/src/main/java/org/eclipse/jetty/http/MultiPartFormInputStream.java b/jetty-http/src/main/java/org/eclipse/jetty/http/MultiPartFormInputStream.java index 0dd1ef6c81d7..03a4d8a65a72 100644 --- a/jetty-http/src/main/java/org/eclipse/jetty/http/MultiPartFormInputStream.java +++ b/jetty-http/src/main/java/org/eclipse/jetty/http/MultiPartFormInputStream.java @@ -35,6 +35,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import javax.servlet.MultipartConfigElement; import javax.servlet.ServletInputStream; @@ -59,7 +60,16 @@ */ public class MultiPartFormInputStream { + private enum State + { + UNPARSED, + PARSING, + ERROR, + COMPLETED + } + private static final Logger LOG = Log.getLogger(MultiPartFormInputStream.class); + private final AtomicReference state = new AtomicReference<>(State.UNPARSED); private final MultiMap _parts = new MultiMap<>(); private final InputStream _in; private final MultipartConfigElement _config; @@ -356,22 +366,19 @@ public MultiPartFormInputStream(InputStream in, String contentType, MultipartCon @Deprecated public boolean isEmpty() { - synchronized (this) - { - if (!_parsed) - throw new IllegalStateException(); - - if (_parts.isEmpty()) - return true; - - for (List partList : _parts.values()) - { - if (!partList.isEmpty()) - return false; - } + if (!_parsed) + throw new IllegalStateException(); + if (_parts.isEmpty()) return true; + + for (List partList : _parts.values()) + { + if (!partList.isEmpty()) + return false; } + + return true; } /** @@ -382,20 +389,17 @@ public boolean isEmpty() @Deprecated public Collection getParsedParts() { - synchronized (this) - { - if (_parts.isEmpty()) - return Collections.emptyList(); + if (_parts.isEmpty()) + return Collections.emptyList(); - Collection> values = _parts.values(); - List parts = new ArrayList<>(); - for (List o : values) - { - List asList = LazyList.getList(o, false); - parts.addAll(asList); - } - return parts; + Collection> values = _parts.values(); + List parts = new ArrayList<>(); + for (List o : values) + { + List asList = LazyList.getList(o, false); + parts.addAll(asList); } + return parts; } /** @@ -403,31 +407,52 @@ public Collection getParsedParts() */ public void deleteParts() { - // TODO: Can we cancel parsing somehow instead of blocking. - synchronized (this) + while (true) + { + switch (state.get()) + { + case PARSING: + state.compareAndSet(State.PARSING, State.ERROR); + Thread.yield(); + continue; + + case UNPARSED: + if (!state.compareAndSet(State.UNPARSED, State.COMPLETED)) + continue; + break; + + case ERROR: + Thread.yield(); + continue; + + case COMPLETED: + break; + } + + break; + } + + MultiException err = null; + for (List parts : _parts.values()) { - MultiException err = null; - for (List parts : _parts.values()) + for (Part p : parts) { - for (Part p : parts) + try { - try - { - ((MultiPart)p).cleanUp(); - } - catch (Exception e) - { - if (err == null) - err = new MultiException(); - err.add(e); - } + ((MultiPart)p).cleanUp(); + } + catch (Exception e) + { + if (err == null) + err = new MultiException(); + err.add(e); } } - _parts.clear(); - - if (err != null) - err.ifExceptionThrowRuntime(); } + _parts.clear(); + + if (err != null) + err.ifExceptionThrowRuntime(); } /** @@ -438,13 +463,10 @@ public void deleteParts() */ public Collection getParts() throws IOException { - synchronized (this) - { - if (!_parsed) - parse(); - throwIfError(); - return _parts.values().stream().flatMap(List::stream).collect(Collectors.toList()); - } + if (!_parsed) + parse(); + throwIfError(); + return _parts.values().stream().flatMap(List::stream).collect(Collectors.toList()); } /** @@ -456,13 +478,10 @@ public Collection getParts() throws IOException */ public Part getPart(String name) throws IOException { - synchronized (this) - { - if (!_parsed) - parse(); - throwIfError(); - return _parts.getValue(name, 0); - } + if (!_parsed) + parse(); + throwIfError(); + return _parts.getValue(name, 0); } /** @@ -534,8 +553,15 @@ else if ("".equals(_config.getLocation())) int len; long total = 0; + if (!state.compareAndSet(State.UNPARSED, State.PARSING)) + throw new IllegalStateException("Could not start parsing " + state.get()); + while (true) { + State currentState = state.get(); + if (currentState != State.PARSING) + throw new IllegalStateException("Unexpected state " + currentState); + len = _in.read(data); if (len > 0) @@ -591,6 +617,27 @@ else if (len == -1) if (parser != null) parser.parse(BufferUtil.EMPTY_BUFFER, true); } + finally + { + while (true) + { + switch (state.get()) + { + case PARSING: + if (!state.compareAndSet(State.PARSING, State.COMPLETED)) + continue; + break; + + case ERROR: + state.compareAndSet(State.ERROR, State.COMPLETED); + break; + + default: + break; + } + break; + } + } } class Handler implements MultiPartParser.Handler diff --git a/jetty-http/src/test/java/org/eclipse/jetty/http/MultiPartFormInputStreamTest.java b/jetty-http/src/test/java/org/eclipse/jetty/http/MultiPartFormInputStreamTest.java index a8c1d9861190..0081e9f08973 100644 --- a/jetty-http/src/test/java/org/eclipse/jetty/http/MultiPartFormInputStreamTest.java +++ b/jetty-http/src/test/java/org/eclipse/jetty/http/MultiPartFormInputStreamTest.java @@ -25,6 +25,8 @@ import java.io.InputStream; import java.util.Base64; import java.util.Collection; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import javax.servlet.MultipartConfigElement; import javax.servlet.ReadListener; @@ -509,6 +511,79 @@ public void testPartTmpFileDeletion() throws Exception assertThat(stuff.exists(), is(false)); //tmp file was removed after cleanup } + @Test + public void testAsyncCleanUp() throws Exception + { + final CountDownLatch reading = new CountDownLatch(1); + final InputStream wrappedStream = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes()); + + // This stream won't allow the parser to exit because it will never return anything less than 0. + InputStream slowStream = new InputStream() { + @Override + public int read(byte[] b, int off, int len) throws IOException + { + return Math.max(0, super.read(b, off, len)); + } + + @Override + public int read() throws IOException + { + reading.countDown(); + return wrappedStream.read(); + } + }; + + MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50); + MultiPartFormInputStream mpis = new MultiPartFormInputStream(slowStream, _contentType, config, _tmpDir); + + // In another thread delete the parts when we detect that we have started parsing. + CompletableFuture cleanupError = new CompletableFuture<>(); + new Thread(() -> + { + try + { + assertTrue(reading.await(5, TimeUnit.SECONDS)); + mpis.deleteParts(); + cleanupError.complete(null); + } + catch (Throwable t) + { + cleanupError.complete(t); + } + }).start(); + + // The call to getParts should throw an error. + Throwable error = assertThrows(IllegalStateException.class, mpis::getParts); + assertThat(error.getMessage(), is("Unexpected state ERROR")); + + // There was no error with the cleanup. + assertNull(cleanupError.get()); + + // No tmp files are remaining. + String[] fileList = _tmpDir.list(); + assertNotNull(fileList); + assertThat(fileList.length, is(0)); + } + + @Test + public void testParseAfterCleanUp() throws Exception + { + final InputStream input = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes()); + MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50); + MultiPartFormInputStream mpis = new MultiPartFormInputStream(input, _contentType, config, _tmpDir); + + mpis.deleteParts(); + + // The call to getParts should throw because we have already cleaned up the parts. + Throwable error = assertThrows(IllegalStateException.class, mpis::getParts); + assertThat(error.getMessage(), is("Could not start parsing COMPLETED")); + + // No tmp files are remaining. + String[] fileList = _tmpDir.list(); + assertNotNull(fileList); + assertThat(fileList.length, is(0)); + } + @Test public void testLFOnlyRequest() throws Exception diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/Request.java b/jetty-server/src/main/java/org/eclipse/jetty/server/Request.java index 1727d40ee6ee..f02ad81dbdab 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/Request.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/Request.java @@ -2310,7 +2310,7 @@ public Collection getParts() throws IOException, ServletException return getParts(null); } - private synchronized Collection getParts(MultiMap params) throws IOException + private Collection getParts(MultiMap params) throws IOException { if (_multiParts == null) _multiParts = (MultiParts)getAttribute(MULTIPARTS);