Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #4383 - synchronize multiparts for cleanup from different thread #4498

Merged
merged 5 commits into from Jan 24, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -31,17 +31,16 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import javax.servlet.MultipartConfigElement;
import javax.servlet.ServletInputStream;
import javax.servlet.http.Part;

import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.ByteArrayOutputStream2;
import org.eclipse.jetty.util.LazyList;
import org.eclipse.jetty.util.MultiException;
import org.eclipse.jetty.util.MultiMap;
import org.eclipse.jetty.util.QuotedStringTokenizer;
Expand All @@ -53,24 +52,49 @@
* MultiPartInputStream
* <p>
* Handle a MultiPart Mime input stream, breaking it up on the boundary into files and strings.
*
* </p>
* <p>
* Deleting the parts can be done from a different thread if the parts are parsed asynchronously.
* Because of this we use the state to fail the parsing and coordinate which thread will delete any remaining parts.
* The deletion of parts is done by the cleanup thread in all cases except the transition from ERROR-&gt;DELETED which
* is done by the parsing thread.
* </p>
* <pre>{@code
* deleteParts()
* +--------------------------------------------------------------+
* | |
* | deleteParts() v
* UNPARSED -------> PARSING --------> PARSED ------------------>DELETED
* | ^
* | |
* +----------------> ERROR ---------------------+
lachlan-roberts marked this conversation as resolved.
Show resolved Hide resolved
* deleteParts() parsing thread
* }</pre>
* @see <a href="https://tools.ietf.org/html/rfc7578">https://tools.ietf.org/html/rfc7578</a>
*/
public class MultiPartFormInputStream
{
private enum State
{
UNPARSED,
PARSING,
PARSED,
DELETED,
ERROR
}

lachlan-roberts marked this conversation as resolved.
Show resolved Hide resolved
private static final Logger LOG = Log.getLogger(MultiPartFormInputStream.class);
private static final MultiMap<Part> EMPTY_MAP = new MultiMap<>(Collections.emptyMap());
private final MultiMap<Part> _parts;
private InputStream _in;
private MultipartConfigElement _config;
private String _contentType;
private Throwable _err;
private File _tmpDir;
private File _contextTmpDir;
private boolean _deleteOnExit;
private boolean _writeFilesWithFilenames;
private boolean _parsed;
private int _bufferSize = 16 * 1024;
private final MultiMap<Part> _parts = new MultiMap<>();
private final InputStream _in;
private final MultipartConfigElement _config;
private final File _contextTmpDir;
private final String _contentType;
private volatile Throwable _err;
lachlan-roberts marked this conversation as resolved.
Show resolved Hide resolved
private volatile File _tmpDir;
private volatile boolean _deleteOnExit;
private volatile boolean _writeFilesWithFilenames;
private volatile int _bufferSize = 16 * 1024;
private volatile State state = State.UNPARSED;

public class MultiPart implements Part
{
Expand Down Expand Up @@ -333,39 +357,33 @@ public String getContentDispositionFilename()
public MultiPartFormInputStream(InputStream in, String contentType, MultipartConfigElement config, File contextTmpDir)
{
_contentType = contentType;
_config = config;
_contextTmpDir = contextTmpDir;
if (_contextTmpDir == null)
_contextTmpDir = new File(System.getProperty("java.io.tmpdir"));

if (_config == null)
_config = new MultipartConfigElement(_contextTmpDir.getAbsolutePath());

MultiMap parts = new MultiMap();
_contextTmpDir = (contextTmpDir != null) ? contextTmpDir : new File(System.getProperty("java.io.tmpdir"));
_config = (config != null) ? config : new MultipartConfigElement(_contextTmpDir.getAbsolutePath());

if (in instanceof ServletInputStream)
{
if (((ServletInputStream)in).isFinished())
{
parts = EMPTY_MAP;
_parsed = true;
_in = null;
state = State.PARSED;
janbartel marked this conversation as resolved.
Show resolved Hide resolved
return;
}
}
if (!_parsed)
_in = new BufferedInputStream(in);
_parts = parts;

_in = new BufferedInputStream(in);
}

/**
* @return whether the list of parsed parts is empty
* @deprecated use getParts().isEmpty()
*/
@Deprecated
public boolean isEmpty()
{
if (_parts == null)
if (_parts.isEmpty())
return true;

Collection<List<Part>> values = _parts.values();
for (List<Part> partList : values)
for (List<Part> partList : _parts.values())
{
if (!partList.isEmpty())
return false;
Expand All @@ -378,6 +396,33 @@ public boolean isEmpty()
* Delete any tmp storage for parts, and clear out the parts list.
*/
public void deleteParts()
{
synchronized (this)
{
switch (state)
{
case DELETED:
case ERROR:
return;

case PARSING:
state = State.ERROR;
return;

case UNPARSED:
state = State.DELETED;
return;

case PARSED:
state = State.DELETED;
break;
}
}

uncheckedDeleteParts();
}

private void uncheckedDeleteParts()
{
MultiException err = null;
for (List<Part> parts : _parts.values())
Expand Down Expand Up @@ -410,21 +455,9 @@ public void deleteParts()
*/
public Collection<Part> getParts() throws IOException
{
if (!_parsed)
parse();
parse();
throwIfError();

if (_parts.isEmpty())
return Collections.emptyList();

Collection<List<Part>> values = _parts.values();
List<Part> parts = new ArrayList<>();
for (List<Part> o : values)
{
List<Part> asList = LazyList.getList(o, false);
parts.addAll(asList);
}
return parts;
return _parts.values().stream().flatMap(List::stream).collect(Collectors.toList());
}

/**
Expand All @@ -436,8 +469,7 @@ public Collection<Part> getParts() throws IOException
*/
public Part getPart(String name) throws IOException
{
if (!_parsed)
parse();
parse();
throwIfError();
return _parts.getValue(name, 0);
}
Expand Down Expand Up @@ -468,13 +500,24 @@ protected void throwIfError() throws IOException
*/
protected void parse()
{
// have we already parsed the input?
if (_parsed)
return;
_parsed = true;
synchronized (this)
{
switch (state)
{
case UNPARSED:
state = State.PARSING;
break;

case PARSED:
return;

default:
_err = new IOException(state.name());
return;
}
}

MultiPartParser parser = null;
Handler handler = new Handler();
try
{
// if its not a multipart request, don't parse it
Expand Down Expand Up @@ -507,16 +550,23 @@ else if ("".equals(_config.getLocation()))
contentTypeBoundary = QuotedStringTokenizer.unquote(value(_contentType.substring(bstart, bend)).trim());
}

parser = new MultiPartParser(handler, contentTypeBoundary);
parser = new MultiPartParser(new Handler(), contentTypeBoundary);
byte[] data = new byte[_bufferSize];
int len;
long total = 0;

while (true)
{
synchronized (this)
{
if (state != State.PARSING)
janbartel marked this conversation as resolved.
Show resolved Hide resolved
{
_err = new IOException(state.name());
return;
}
}

len = _in.read(data);

if (len > 0)
{
// keep running total of size of bytes read from input and throw an exception if exceeds MultipartConfigElement._maxRequestSize
Expand Down Expand Up @@ -570,6 +620,30 @@ else if (len == -1)
if (parser != null)
parser.parse(BufferUtil.EMPTY_BUFFER, true);
}
finally
{
boolean cleanup = false;
synchronized (this)
{
switch (state)
{
case PARSING:
state = State.PARSED;
break;

case ERROR:
state = State.DELETED;
cleanup = true;
break;

default:
_err = new IllegalStateException(state.name());
}
}

if (cleanup)
janbartel marked this conversation as resolved.
Show resolved Hide resolved
uncheckedDeleteParts();
}
}

class Handler implements MultiPartParser.Handler
Expand Down
Expand Up @@ -25,6 +25,8 @@
import java.io.InputStream;
import java.util.Base64;
import java.util.Collection;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.servlet.MultipartConfigElement;
import javax.servlet.ReadListener;
Expand Down Expand Up @@ -518,6 +520,78 @@ public void testDeleteNPE()

mpis.deleteParts(); // this should not be an NPE
}

@Test
public void testAsyncCleanUp() throws Exception
{
final CountDownLatch reading = new CountDownLatch(1);
final InputStream wrappedStream = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes());

// This stream won't allow the parser to exit because it will never return anything less than 0.
InputStream slowStream = new InputStream()
{
@Override
public int read(byte[] b, int off, int len) throws IOException
{
return Math.max(0, super.read(b, off, len));
}

@Override
public int read() throws IOException
{
reading.countDown();
return wrappedStream.read();
}
};

MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50);
MultiPartFormInputStream mpis = new MultiPartFormInputStream(slowStream, _contentType, config, _tmpDir);

// In another thread delete the parts when we detect that we have started parsing.
CompletableFuture<Throwable> cleanupError = new CompletableFuture<>();
new Thread(() ->
{
try
{
assertTrue(reading.await(5, TimeUnit.SECONDS));
mpis.deleteParts();
cleanupError.complete(null);
}
catch (Throwable t)
{
cleanupError.complete(t);
}
}).start();

// The call to getParts should throw an error.
Throwable error = assertThrows(IOException.class, mpis::getParts);
assertThat(error.getMessage(), is("ERROR"));

// There was no error with the cleanup.
assertNull(cleanupError.get());

// No tmp files are remaining.
String[] fileList = _tmpDir.list();
assertNotNull(fileList);
assertThat(fileList.length, is(0));
}

@Test
public void testParseAfterCleanUp() throws Exception
{
final InputStream input = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes());
MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50);
MultiPartFormInputStream mpis = new MultiPartFormInputStream(input, _contentType, config, _tmpDir);

mpis.deleteParts();

// The call to getParts should throw because we have already cleaned up the parts.
Throwable error = assertThrows(IOException.class, mpis::getParts);
assertThat(error.getMessage(), is("DELETED"));

// Even though we called getParts() we never even created the tmp directory as we had already called deleteParts().
assertFalse(_tmpDir.exists());
}

@Test
public void testLFOnlyRequest()
Expand Down