Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #4383 - avoid NPE from MultiPart Cleanup #4388

Closed
wants to merge 9 commits into from
Expand Up @@ -35,6 +35,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import javax.servlet.MultipartConfigElement;
import javax.servlet.ServletInputStream;
import javax.servlet.http.Part;
Expand All @@ -58,19 +59,27 @@
*/
public class MultiPartFormInputStream
{
private enum State
{
UNPARSED,
PARSING,
PARSED,
CLOSING,
CLOSED
}

private static final Logger LOG = Log.getLogger(MultiPartFormInputStream.class);
private static final MultiMap<Part> EMPTY_MAP = new MultiMap<>(Collections.emptyMap());
private InputStream _in;
private MultipartConfigElement _config;
private String _contentType;
private MultiMap<Part> _parts;
private Throwable _err;
private File _tmpDir;
private File _contextTmpDir;
private boolean _deleteOnExit;
private boolean _writeFilesWithFilenames;
private boolean _parsed;
private int _bufferSize = 16 * 1024;
private final MultiMap<Part> _parts = new MultiMap<>();
janbartel marked this conversation as resolved.
Show resolved Hide resolved
private final InputStream _in;
private final MultipartConfigElement _config;
private final File _contextTmpDir;
private final String _contentType;
private volatile Throwable _err;
private volatile File _tmpDir;
private volatile boolean _deleteOnExit;
private volatile boolean _writeFilesWithFilenames;
private volatile int _bufferSize = 16 * 1024;
private State state = State.UNPARSED;

public class MultiPart implements Part
{
Expand Down Expand Up @@ -333,36 +342,33 @@ public String getContentDispositionFilename()
public MultiPartFormInputStream(InputStream in, String contentType, MultipartConfigElement config, File contextTmpDir)
{
_contentType = contentType;
_config = config;
_contextTmpDir = contextTmpDir;
if (_contextTmpDir == null)
_contextTmpDir = new File(System.getProperty("java.io.tmpdir"));

if (_config == null)
_config = new MultipartConfigElement(_contextTmpDir.getAbsolutePath());
_contextTmpDir = (contextTmpDir != null) ? contextTmpDir : new File(System.getProperty("java.io.tmpdir"));
_config = (config != null) ? config : new MultipartConfigElement(_contextTmpDir.getAbsolutePath());

if (in instanceof ServletInputStream)
{
if (((ServletInputStream)in).isFinished())
{
_parts = EMPTY_MAP;
_parsed = true;
_in = null;
state = State.PARSED;
return;
}
}

_in = new BufferedInputStream(in);
}

/**
* @return whether the list of parsed parts is empty
* @deprecated use getParts().isEmpty()
*/
@Deprecated
public boolean isEmpty()
{
if (_parts == null)
if (_parts.isEmpty())
return true;

Collection<List<Part>> values = _parts.values();
for (List<Part> partList : values)
for (List<Part> partList : _parts.values())
{
if (!partList.isEmpty())
return false;
Expand All @@ -379,7 +385,7 @@ public boolean isEmpty()
@Deprecated
public Collection<Part> getParsedParts()
{
if (_parts == null)
if (_parts.isEmpty())
lachlan-roberts marked this conversation as resolved.
Show resolved Hide resolved
return Collections.emptyList();

Collection<List<Part>> values = _parts.values();
Expand All @@ -397,6 +403,26 @@ public Collection<Part> getParsedParts()
*/
public void deleteParts()
{
synchronized (this)
{
switch (state)
{
case CLOSED:
case UNPARSED:
state = State.CLOSED;
return;

case PARSING:
state = State.CLOSING;
return;

case PARSED:
case CLOSING:
state = State.CLOSED;
break;
}
}

MultiException err = null;
for (List<Part> parts : _parts.values())
{
Expand Down Expand Up @@ -428,18 +454,9 @@ public void deleteParts()
*/
public Collection<Part> getParts() throws IOException
{
if (!_parsed)
parse();
parse();
throwIfError();

Collection<List<Part>> values = _parts.values();
List<Part> parts = new ArrayList<>();
for (List<Part> o : values)
{
List<Part> asList = LazyList.getList(o, false);
parts.addAll(asList);
}
return parts;
return _parts.values().stream().flatMap(List::stream).collect(Collectors.toList());
}

/**
Expand All @@ -451,8 +468,7 @@ public Collection<Part> getParts() throws IOException
*/
public Part getPart(String name) throws IOException
{
if (!_parsed)
parse();
parse();
throwIfError();
return _parts.getValue(name, 0);
}
Expand Down Expand Up @@ -483,18 +499,26 @@ protected void throwIfError() throws IOException
*/
protected void parse()
{
// have we already parsed the input?
if (_parsed)
return;
_parsed = true;
synchronized (this)
{
switch (state)
{
case UNPARSED:
state = State.PARSING;
break;

case PARSED:
return;

default:
_err = new IllegalStateException(state.name());
return;
}
}

MultiPartParser parser = null;
Handler handler = new Handler();
try
{
// initialize
_parts = new MultiMap<>();

// if its not a multipart request, don't parse it
if (_contentType == null || !_contentType.startsWith("multipart/form-data"))
return;
Expand Down Expand Up @@ -525,16 +549,23 @@ else if ("".equals(_config.getLocation()))
contentTypeBoundary = QuotedStringTokenizer.unquote(value(_contentType.substring(bstart, bend)).trim());
}

parser = new MultiPartParser(handler, contentTypeBoundary);
parser = new MultiPartParser(new Handler(), contentTypeBoundary);
byte[] data = new byte[_bufferSize];
int len;
long total = 0;

while (true)
{
synchronized (this)
{
if (state != State.PARSING)
{
_err = new IllegalStateException(state.name());
return;
}
}

len = _in.read(data);

if (len > 0)
{
// keep running total of size of bytes read from input and throw an exception if exceeds MultipartConfigElement._maxRequestSize
Expand Down Expand Up @@ -588,6 +619,29 @@ else if (len == -1)
if (parser != null)
parser.parse(BufferUtil.EMPTY_BUFFER, true);
}
finally
{
boolean cleanup = false;
synchronized (this)
{
switch (state)
{
case PARSING:
state = State.PARSED;
break;

case CLOSING:
cleanup = true;
break;

default:
_err = new IllegalStateException(state.name());
}
}

if (cleanup)
deleteParts();
}
}

class Handler implements MultiPartParser.Handler
Expand Down
Expand Up @@ -25,6 +25,8 @@
import java.io.InputStream;
import java.util.Base64;
import java.util.Collection;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.servlet.MultipartConfigElement;
import javax.servlet.ReadListener;
Expand Down Expand Up @@ -509,6 +511,77 @@ public void testPartTmpFileDeletion() throws Exception
assertThat(stuff.exists(), is(false)); //tmp file was removed after cleanup
}

@Test
public void testAsyncCleanUp() throws Exception
{
final CountDownLatch reading = new CountDownLatch(1);
final InputStream wrappedStream = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes());

// This stream won't allow the parser to exit because it will never return anything less than 0.
InputStream slowStream = new InputStream() {
@Override
public int read(byte[] b, int off, int len) throws IOException
{
return Math.max(0, super.read(b, off, len));
}

@Override
public int read() throws IOException
{
reading.countDown();
return wrappedStream.read();
}
};

MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50);
MultiPartFormInputStream mpis = new MultiPartFormInputStream(slowStream, _contentType, config, _tmpDir);

// In another thread delete the parts when we detect that we have started parsing.
CompletableFuture<Throwable> cleanupError = new CompletableFuture<>();
new Thread(() ->
{
try
{
assertTrue(reading.await(5, TimeUnit.SECONDS));
mpis.deleteParts();
cleanupError.complete(null);
}
catch (Throwable t)
{
cleanupError.complete(t);
}
}).start();

// The call to getParts should throw an error.
Throwable error = assertThrows(IllegalStateException.class, mpis::getParts);
assertThat(error.getMessage(), is("CLOSING"));

// There was no error with the cleanup.
assertNull(cleanupError.get());

// No tmp files are remaining.
String[] fileList = _tmpDir.list();
assertNotNull(fileList);
assertThat(fileList.length, is(0));
}

@Test
public void testParseAfterCleanUp() throws Exception
{
final InputStream input = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes());
MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50);
MultiPartFormInputStream mpis = new MultiPartFormInputStream(input, _contentType, config, _tmpDir);

mpis.deleteParts();

// The call to getParts should throw because we have already cleaned up the parts.
Throwable error = assertThrows(IllegalStateException.class, mpis::getParts);
assertThat(error.getMessage(), is("CLOSED"));

// Even though we called getParts() we never even created the tmp directory as we had already called deleteParts().
assertFalse(_tmpDir.exists());
}

@Test
public void testLFOnlyRequest()
throws Exception
Expand Down
Expand Up @@ -46,6 +46,7 @@ public interface MultiParts extends Closeable

Part getPart(String name) throws IOException;

@Deprecated
boolean isEmpty();

ContextHandler.Context getContext();
Expand Down