diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/MultiPartFormInputStream.java b/jetty-server/src/main/java/org/eclipse/jetty/server/MultiPartFormInputStream.java index 12e40428398d..8a53207e2f24 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/MultiPartFormInputStream.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/MultiPartFormInputStream.java @@ -31,17 +31,16 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardCopyOption; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; import javax.servlet.MultipartConfigElement; import javax.servlet.ServletInputStream; import javax.servlet.http.Part; import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.ByteArrayOutputStream2; -import org.eclipse.jetty.util.LazyList; import org.eclipse.jetty.util.MultiException; import org.eclipse.jetty.util.MultiMap; import org.eclipse.jetty.util.QuotedStringTokenizer; @@ -58,19 +57,27 @@ */ public class MultiPartFormInputStream { + private enum State + { + UNPARSED, + PARSING, + PARSED, + CLOSING, + CLOSED + } + private static final Logger LOG = Log.getLogger(MultiPartFormInputStream.class); - private static final MultiMap EMPTY_MAP = new MultiMap<>(Collections.emptyMap()); - private final MultiMap _parts; - private InputStream _in; - private MultipartConfigElement _config; - private String _contentType; - private Throwable _err; - private File _tmpDir; - private File _contextTmpDir; - private boolean _deleteOnExit; - private boolean _writeFilesWithFilenames; - private boolean _parsed; - private int _bufferSize = 16 * 1024; + private final MultiMap _parts = new MultiMap<>(); + private final InputStream _in; + private final MultipartConfigElement _config; + private final File _contextTmpDir; + private final String _contentType; + private volatile Throwable _err; + private volatile File _tmpDir; + private volatile boolean _deleteOnExit; + private volatile boolean _writeFilesWithFilenames; + private volatile int _bufferSize = 16 * 1024; + private volatile State state = State.UNPARSED; public class MultiPart implements Part { @@ -333,39 +340,33 @@ public String getContentDispositionFilename() public MultiPartFormInputStream(InputStream in, String contentType, MultipartConfigElement config, File contextTmpDir) { _contentType = contentType; - _config = config; - _contextTmpDir = contextTmpDir; - if (_contextTmpDir == null) - _contextTmpDir = new File(System.getProperty("java.io.tmpdir")); - - if (_config == null) - _config = new MultipartConfigElement(_contextTmpDir.getAbsolutePath()); - - MultiMap parts = new MultiMap(); + _contextTmpDir = (contextTmpDir != null) ? contextTmpDir : new File(System.getProperty("java.io.tmpdir")); + _config = (config != null) ? config : new MultipartConfigElement(_contextTmpDir.getAbsolutePath()); if (in instanceof ServletInputStream) { if (((ServletInputStream)in).isFinished()) { - parts = EMPTY_MAP; - _parsed = true; + _in = null; + state = State.PARSED; + return; } } - if (!_parsed) - _in = new BufferedInputStream(in); - _parts = parts; + + _in = new BufferedInputStream(in); } /** * @return whether the list of parsed parts is empty + * @deprecated use getParts().isEmpty() */ + @Deprecated public boolean isEmpty() { - if (_parts == null) + if (_parts.isEmpty()) return true; - Collection> values = _parts.values(); - for (List partList : values) + for (List partList : _parts.values()) { if (!partList.isEmpty()) return false; @@ -379,6 +380,26 @@ public boolean isEmpty() */ public void deleteParts() { + synchronized (this) + { + switch (state) + { + case CLOSED: + case UNPARSED: + state = State.CLOSED; + return; + + case PARSING: + state = State.CLOSING; + return; + + case PARSED: + case CLOSING: + state = State.CLOSED; + break; + } + } + MultiException err = null; for (List parts : _parts.values()) { @@ -410,21 +431,9 @@ public void deleteParts() */ public Collection getParts() throws IOException { - if (!_parsed) - parse(); + parse(); throwIfError(); - - if (_parts.isEmpty()) - return Collections.emptyList(); - - Collection> values = _parts.values(); - List parts = new ArrayList<>(); - for (List o : values) - { - List asList = LazyList.getList(o, false); - parts.addAll(asList); - } - return parts; + return _parts.values().stream().flatMap(List::stream).collect(Collectors.toList()); } /** @@ -436,8 +445,7 @@ public Collection getParts() throws IOException */ public Part getPart(String name) throws IOException { - if (!_parsed) - parse(); + parse(); throwIfError(); return _parts.getValue(name, 0); } @@ -468,13 +476,24 @@ protected void throwIfError() throws IOException */ protected void parse() { - // have we already parsed the input? - if (_parsed) - return; - _parsed = true; + synchronized (this) + { + switch (state) + { + case UNPARSED: + state = State.PARSING; + break; + + case PARSED: + return; + + default: + _err = new IllegalStateException(state.name()); + return; + } + } MultiPartParser parser = null; - Handler handler = new Handler(); try { // if its not a multipart request, don't parse it @@ -507,16 +526,23 @@ else if ("".equals(_config.getLocation())) contentTypeBoundary = QuotedStringTokenizer.unquote(value(_contentType.substring(bstart, bend)).trim()); } - parser = new MultiPartParser(handler, contentTypeBoundary); + parser = new MultiPartParser(new Handler(), contentTypeBoundary); byte[] data = new byte[_bufferSize]; int len; long total = 0; while (true) { + synchronized (this) + { + if (state != State.PARSING) + { + _err = new IllegalStateException(state.name()); + return; + } + } len = _in.read(data); - if (len > 0) { // keep running total of size of bytes read from input and throw an exception if exceeds MultipartConfigElement._maxRequestSize @@ -570,6 +596,29 @@ else if (len == -1) if (parser != null) parser.parse(BufferUtil.EMPTY_BUFFER, true); } + finally + { + boolean cleanup = false; + synchronized (this) + { + switch (state) + { + case PARSING: + state = State.PARSED; + break; + + case CLOSING: + cleanup = true; + break; + + default: + _err = new IllegalStateException(state.name()); + } + } + + if (cleanup) + deleteParts(); + } } class Handler implements MultiPartParser.Handler diff --git a/jetty-server/src/test/java/org/eclipse/jetty/server/MultiPartFormInputStreamTest.java b/jetty-server/src/test/java/org/eclipse/jetty/server/MultiPartFormInputStreamTest.java index b80787e13158..4862e32b62e4 100644 --- a/jetty-server/src/test/java/org/eclipse/jetty/server/MultiPartFormInputStreamTest.java +++ b/jetty-server/src/test/java/org/eclipse/jetty/server/MultiPartFormInputStreamTest.java @@ -25,6 +25,8 @@ import java.io.InputStream; import java.util.Base64; import java.util.Collection; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import javax.servlet.MultipartConfigElement; import javax.servlet.ReadListener; @@ -518,6 +520,78 @@ public void testDeleteNPE() mpis.deleteParts(); // this should not be an NPE } + + @Test + public void testAsyncCleanUp() throws Exception + { + final CountDownLatch reading = new CountDownLatch(1); + final InputStream wrappedStream = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes()); + + // This stream won't allow the parser to exit because it will never return anything less than 0. + InputStream slowStream = new InputStream() + { + @Override + public int read(byte[] b, int off, int len) throws IOException + { + return Math.max(0, super.read(b, off, len)); + } + + @Override + public int read() throws IOException + { + reading.countDown(); + return wrappedStream.read(); + } + }; + + MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50); + MultiPartFormInputStream mpis = new MultiPartFormInputStream(slowStream, _contentType, config, _tmpDir); + + // In another thread delete the parts when we detect that we have started parsing. + CompletableFuture cleanupError = new CompletableFuture<>(); + new Thread(() -> + { + try + { + assertTrue(reading.await(5, TimeUnit.SECONDS)); + mpis.deleteParts(); + cleanupError.complete(null); + } + catch (Throwable t) + { + cleanupError.complete(t); + } + }).start(); + + // The call to getParts should throw an error. + Throwable error = assertThrows(IllegalStateException.class, mpis::getParts); + assertThat(error.getMessage(), is("CLOSING")); + + // There was no error with the cleanup. + assertNull(cleanupError.get()); + + // No tmp files are remaining. + String[] fileList = _tmpDir.list(); + assertNotNull(fileList); + assertThat(fileList.length, is(0)); + } + + @Test + public void testParseAfterCleanUp() throws Exception + { + final InputStream input = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes()); + MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50); + MultiPartFormInputStream mpis = new MultiPartFormInputStream(input, _contentType, config, _tmpDir); + + mpis.deleteParts(); + + // The call to getParts should throw because we have already cleaned up the parts. + Throwable error = assertThrows(IllegalStateException.class, mpis::getParts); + assertThat(error.getMessage(), is("CLOSED")); + + // Even though we called getParts() we never even created the tmp directory as we had already called deleteParts(). + assertFalse(_tmpDir.exists()); + } @Test public void testLFOnlyRequest()