Skip to content

Commit

Permalink
Issue #4383 - atomic state to MultiPart for multi-thread synchronization
Browse files Browse the repository at this point in the history
Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
  • Loading branch information
lachlan-roberts committed Jan 20, 2020
1 parent b520ca6 commit c5d074e
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 54 deletions.
Expand Up @@ -31,17 +31,16 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import javax.servlet.MultipartConfigElement;
import javax.servlet.ServletInputStream;
import javax.servlet.http.Part;

import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.ByteArrayOutputStream2;
import org.eclipse.jetty.util.LazyList;
import org.eclipse.jetty.util.MultiException;
import org.eclipse.jetty.util.MultiMap;
import org.eclipse.jetty.util.QuotedStringTokenizer;
Expand All @@ -58,19 +57,27 @@
*/
public class MultiPartFormInputStream
{
private enum State
{
UNPARSED,
PARSING,
PARSED,
CLOSING,
CLOSED
}

private static final Logger LOG = Log.getLogger(MultiPartFormInputStream.class);
private static final MultiMap<Part> EMPTY_MAP = new MultiMap<>(Collections.emptyMap());
private final MultiMap<Part> _parts;
private InputStream _in;
private MultipartConfigElement _config;
private String _contentType;
private Throwable _err;
private File _tmpDir;
private File _contextTmpDir;
private boolean _deleteOnExit;
private boolean _writeFilesWithFilenames;
private boolean _parsed;
private int _bufferSize = 16 * 1024;
private final MultiMap<Part> _parts = new MultiMap<>();
private final InputStream _in;
private final MultipartConfigElement _config;
private final File _contextTmpDir;
private final String _contentType;
private volatile Throwable _err;
private volatile File _tmpDir;
private volatile boolean _deleteOnExit;
private volatile boolean _writeFilesWithFilenames;
private volatile int _bufferSize = 16 * 1024;
private volatile State state = State.UNPARSED;

public class MultiPart implements Part
{
Expand Down Expand Up @@ -333,39 +340,33 @@ public String getContentDispositionFilename()
public MultiPartFormInputStream(InputStream in, String contentType, MultipartConfigElement config, File contextTmpDir)
{
_contentType = contentType;
_config = config;
_contextTmpDir = contextTmpDir;
if (_contextTmpDir == null)
_contextTmpDir = new File(System.getProperty("java.io.tmpdir"));

if (_config == null)
_config = new MultipartConfigElement(_contextTmpDir.getAbsolutePath());

MultiMap parts = new MultiMap();
_contextTmpDir = (contextTmpDir != null) ? contextTmpDir : new File(System.getProperty("java.io.tmpdir"));
_config = (config != null) ? config : new MultipartConfigElement(_contextTmpDir.getAbsolutePath());

if (in instanceof ServletInputStream)
{
if (((ServletInputStream)in).isFinished())
{
parts = EMPTY_MAP;
_parsed = true;
_in = null;
state = State.PARSED;
return;
}
}
if (!_parsed)
_in = new BufferedInputStream(in);
_parts = parts;

_in = new BufferedInputStream(in);
}

/**
* @return whether the list of parsed parts is empty
* @deprecated use getParts().isEmpty()
*/
@Deprecated
public boolean isEmpty()
{
if (_parts == null)
if (_parts.isEmpty())
return true;

Collection<List<Part>> values = _parts.values();
for (List<Part> partList : values)
for (List<Part> partList : _parts.values())
{
if (!partList.isEmpty())
return false;
Expand All @@ -379,6 +380,26 @@ public boolean isEmpty()
*/
public void deleteParts()
{
synchronized (this)
{
switch (state)
{
case CLOSED:
case UNPARSED:
state = State.CLOSED;
return;

case PARSING:
state = State.CLOSING;
return;

case PARSED:
case CLOSING:
state = State.CLOSED;
break;
}
}

MultiException err = null;
for (List<Part> parts : _parts.values())
{
Expand Down Expand Up @@ -410,21 +431,9 @@ public void deleteParts()
*/
public Collection<Part> getParts() throws IOException
{
if (!_parsed)
parse();
parse();
throwIfError();

if (_parts.isEmpty())
return Collections.emptyList();

Collection<List<Part>> values = _parts.values();
List<Part> parts = new ArrayList<>();
for (List<Part> o : values)
{
List<Part> asList = LazyList.getList(o, false);
parts.addAll(asList);
}
return parts;
return _parts.values().stream().flatMap(List::stream).collect(Collectors.toList());
}

/**
Expand All @@ -436,8 +445,7 @@ public Collection<Part> getParts() throws IOException
*/
public Part getPart(String name) throws IOException
{
if (!_parsed)
parse();
parse();
throwIfError();
return _parts.getValue(name, 0);
}
Expand Down Expand Up @@ -468,13 +476,24 @@ protected void throwIfError() throws IOException
*/
protected void parse()
{
// have we already parsed the input?
if (_parsed)
return;
_parsed = true;
synchronized (this)
{
switch (state)
{
case UNPARSED:
state = State.PARSING;
break;

case PARSED:
return;

default:
_err = new IllegalStateException(state.name());
return;
}
}

MultiPartParser parser = null;
Handler handler = new Handler();
try
{
// if its not a multipart request, don't parse it
Expand Down Expand Up @@ -507,16 +526,23 @@ else if ("".equals(_config.getLocation()))
contentTypeBoundary = QuotedStringTokenizer.unquote(value(_contentType.substring(bstart, bend)).trim());
}

parser = new MultiPartParser(handler, contentTypeBoundary);
parser = new MultiPartParser(new Handler(), contentTypeBoundary);
byte[] data = new byte[_bufferSize];
int len;
long total = 0;

while (true)
{
synchronized (this)
{
if (state != State.PARSING)
{
_err = new IllegalStateException(state.name());
return;
}
}

len = _in.read(data);

if (len > 0)
{
// keep running total of size of bytes read from input and throw an exception if exceeds MultipartConfigElement._maxRequestSize
Expand Down Expand Up @@ -570,6 +596,29 @@ else if (len == -1)
if (parser != null)
parser.parse(BufferUtil.EMPTY_BUFFER, true);
}
finally
{
boolean cleanup = false;
synchronized (this)
{
switch (state)
{
case PARSING:
state = State.PARSED;
break;

case CLOSING:
cleanup = true;
break;

default:
_err = new IllegalStateException(state.name());
}
}

if (cleanup)
deleteParts();
}
}

class Handler implements MultiPartParser.Handler
Expand Down
Expand Up @@ -25,6 +25,8 @@
import java.io.InputStream;
import java.util.Base64;
import java.util.Collection;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.servlet.MultipartConfigElement;
import javax.servlet.ReadListener;
Expand Down Expand Up @@ -518,6 +520,78 @@ public void testDeleteNPE()

mpis.deleteParts(); // this should not be an NPE
}

@Test
public void testAsyncCleanUp() throws Exception
{
final CountDownLatch reading = new CountDownLatch(1);
final InputStream wrappedStream = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes());

// This stream won't allow the parser to exit because it will never return anything less than 0.
InputStream slowStream = new InputStream()
{
@Override
public int read(byte[] b, int off, int len) throws IOException
{
return Math.max(0, super.read(b, off, len));
}

@Override
public int read() throws IOException
{
reading.countDown();
return wrappedStream.read();
}
};

MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50);
MultiPartFormInputStream mpis = new MultiPartFormInputStream(slowStream, _contentType, config, _tmpDir);

// In another thread delete the parts when we detect that we have started parsing.
CompletableFuture<Throwable> cleanupError = new CompletableFuture<>();
new Thread(() ->
{
try
{
assertTrue(reading.await(5, TimeUnit.SECONDS));
mpis.deleteParts();
cleanupError.complete(null);
}
catch (Throwable t)
{
cleanupError.complete(t);
}
}).start();

// The call to getParts should throw an error.
Throwable error = assertThrows(IllegalStateException.class, mpis::getParts);
assertThat(error.getMessage(), is("CLOSING"));

// There was no error with the cleanup.
assertNull(cleanupError.get());

// No tmp files are remaining.
String[] fileList = _tmpDir.list();
assertNotNull(fileList);
assertThat(fileList.length, is(0));
}

@Test
public void testParseAfterCleanUp() throws Exception
{
final InputStream input = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes());
MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50);
MultiPartFormInputStream mpis = new MultiPartFormInputStream(input, _contentType, config, _tmpDir);

mpis.deleteParts();

// The call to getParts should throw because we have already cleaned up the parts.
Throwable error = assertThrows(IllegalStateException.class, mpis::getParts);
assertThat(error.getMessage(), is("CLOSED"));

// Even though we called getParts() we never even created the tmp directory as we had already called deleteParts().
assertFalse(_tmpDir.exists());
}

@Test
public void testLFOnlyRequest()
Expand Down

0 comments on commit c5d074e

Please sign in to comment.