Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #4383 - avoid NPE from MultiPart Cleanup #4388

Closed
wants to merge 9 commits into from
Expand Up @@ -35,6 +35,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import javax.servlet.MultipartConfigElement;
import javax.servlet.ServletInputStream;
import javax.servlet.http.Part;
Expand All @@ -59,18 +60,17 @@
public class MultiPartFormInputStream
{
private static final Logger LOG = Log.getLogger(MultiPartFormInputStream.class);
private static final MultiMap<Part> EMPTY_MAP = new MultiMap<>(Collections.emptyMap());
private InputStream _in;
private MultipartConfigElement _config;
private String _contentType;
private MultiMap<Part> _parts;
private Throwable _err;
private File _tmpDir;
private File _contextTmpDir;
private boolean _deleteOnExit;
private boolean _writeFilesWithFilenames;
private boolean _parsed;
private int _bufferSize = 16 * 1024;
private final MultiMap<Part> _parts = new MultiMap<>();
janbartel marked this conversation as resolved.
Show resolved Hide resolved
private final InputStream _in;
private final MultipartConfigElement _config;
private final File _contextTmpDir;
private final String _contentType;
private volatile Throwable _err;
private volatile File _tmpDir;
private volatile boolean _deleteOnExit;
private volatile boolean _writeFilesWithFilenames;
private volatile boolean _parsed;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need _parsed if you have the state variable.

private volatile int _bufferSize = 16 * 1024;

public class MultiPart implements Part
{
Expand Down Expand Up @@ -333,42 +333,45 @@ public String getContentDispositionFilename()
public MultiPartFormInputStream(InputStream in, String contentType, MultipartConfigElement config, File contextTmpDir)
{
_contentType = contentType;
_config = config;
_contextTmpDir = contextTmpDir;
if (_contextTmpDir == null)
_contextTmpDir = new File(System.getProperty("java.io.tmpdir"));

if (_config == null)
_config = new MultipartConfigElement(_contextTmpDir.getAbsolutePath());
_contextTmpDir = (contextTmpDir != null) ? contextTmpDir : new File(System.getProperty("java.io.tmpdir"));
_config = (config != null) ? config : new MultipartConfigElement(_contextTmpDir.getAbsolutePath());

if (in instanceof ServletInputStream)
{
if (((ServletInputStream)in).isFinished())
{
_parts = EMPTY_MAP;
_in = null;
_parsed = true;
return;
}
}

_in = new BufferedInputStream(in);
}

/**
* @return whether the list of parsed parts is empty
* @deprecated use getParts().isEmpty()
*/
@Deprecated
public boolean isEmpty()
{
if (_parts == null)
return true;

Collection<List<Part>> values = _parts.values();
for (List<Part> partList : values)
synchronized (this)
{
if (!partList.isEmpty())
return false;
}
if (!_parsed)
throw new IllegalStateException();

if (_parts.isEmpty())
return true;

return true;
for (List<Part> partList : _parts.values())
{
if (!partList.isEmpty())
return false;
}

return true;
}
}

/**
Expand All @@ -379,45 +382,52 @@ public boolean isEmpty()
@Deprecated
public Collection<Part> getParsedParts()
{
if (_parts == null)
return Collections.emptyList();

Collection<List<Part>> values = _parts.values();
List<Part> parts = new ArrayList<>();
for (List<Part> o : values)
synchronized (this)
{
List<Part> asList = LazyList.getList(o, false);
parts.addAll(asList);
if (_parts.isEmpty())
return Collections.emptyList();

Collection<List<Part>> values = _parts.values();
List<Part> parts = new ArrayList<>();
for (List<Part> o : values)
{
List<Part> asList = LazyList.getList(o, false);
parts.addAll(asList);
}
return parts;
}
return parts;
}

/**
* Delete any tmp storage for parts, and clear out the parts list.
*/
public void deleteParts()
{
MultiException err = null;
for (List<Part> parts : _parts.values())
// TODO: Can we cancel parsing somehow instead of blocking.
synchronized (this)
{
for (Part p : parts)
MultiException err = null;
for (List<Part> parts : _parts.values())
{
try
{
((MultiPart)p).cleanUp();
}
catch (Exception e)
for (Part p : parts)
{
if (err == null)
err = new MultiException();
err.add(e);
try
{
((MultiPart)p).cleanUp();
}
catch (Exception e)
{
if (err == null)
err = new MultiException();
err.add(e);
}
}
}
}
_parts.clear();
_parts.clear();

if (err != null)
err.ifExceptionThrowRuntime();
if (err != null)
err.ifExceptionThrowRuntime();
}
}

/**
Expand All @@ -428,18 +438,13 @@ public void deleteParts()
*/
public Collection<Part> getParts() throws IOException
{
if (!_parsed)
parse();
throwIfError();

Collection<List<Part>> values = _parts.values();
List<Part> parts = new ArrayList<>();
for (List<Part> o : values)
synchronized (this)
{
List<Part> asList = LazyList.getList(o, false);
parts.addAll(asList);
if (!_parsed)
parse();
throwIfError();
return _parts.values().stream().flatMap(List::stream).collect(Collectors.toList());
}
return parts;
}

/**
Expand All @@ -451,10 +456,13 @@ public Collection<Part> getParts() throws IOException
*/
public Part getPart(String name) throws IOException
{
if (!_parsed)
parse();
throwIfError();
return _parts.getValue(name, 0);
synchronized (this)
{
if (!_parsed)
parse();
throwIfError();
return _parts.getValue(name, 0);
}
}

/**
Expand Down Expand Up @@ -489,12 +497,8 @@ protected void parse()
_parsed = true;

MultiPartParser parser = null;
Handler handler = new Handler();
try
{
// initialize
_parts = new MultiMap<>();

// if its not a multipart request, don't parse it
if (_contentType == null || !_contentType.startsWith("multipart/form-data"))
return;
Expand Down Expand Up @@ -525,14 +529,13 @@ else if ("".equals(_config.getLocation()))
contentTypeBoundary = QuotedStringTokenizer.unquote(value(_contentType.substring(bstart, bend)).trim());
}

parser = new MultiPartParser(handler, contentTypeBoundary);
parser = new MultiPartParser(new Handler(), contentTypeBoundary);
byte[] data = new byte[_bufferSize];
int len;
long total = 0;

while (true)
{

len = _in.read(data);

if (len > 0)
Expand Down
Expand Up @@ -46,6 +46,7 @@ public interface MultiParts extends Closeable

Part getPart(String name) throws IOException;

@Deprecated
boolean isEmpty();

ContextHandler.Context getContext();
Expand Down
Expand Up @@ -2310,7 +2310,7 @@ public Collection<Part> getParts() throws IOException, ServletException
return getParts(null);
}

private Collection<Part> getParts(MultiMap<String> params) throws IOException
private synchronized Collection<Part> getParts(MultiMap<String> params) throws IOException
lachlan-roberts marked this conversation as resolved.
Show resolved Hide resolved
{
if (_multiParts == null)
_multiParts = (MultiParts)getAttribute(MULTIPARTS);
Expand Down
Expand Up @@ -351,16 +351,23 @@ public void testMultiPart() throws Exception
@Override
public void requestDestroyed(ServletRequestEvent sre)
{
MultiParts m = (MultiParts)sre.getServletRequest().getAttribute(Request.MULTIPARTS);
assertNotNull(m);
ContextHandler.Context c = m.getContext();
assertNotNull(c);
assertTrue(c == sre.getServletContext());
assertTrue(!m.isEmpty());
assertTrue(testTmpDir.list().length == 2);
super.requestDestroyed(sre);
String[] files = testTmpDir.list();
assertTrue(files.length == 0);
try
{
MultiParts m = (MultiParts)sre.getServletRequest().getAttribute(Request.MULTIPARTS);
assertNotNull(m);
ContextHandler.Context c = m.getContext();
assertNotNull(c);
assertTrue(c == sre.getServletContext());
assertTrue(!m.getParts().isEmpty());
assertTrue(testTmpDir.list().length == 2);
super.requestDestroyed(sre);
String[] files = testTmpDir.list();
assertTrue(files.length == 0);
}
catch (IOException e)
{
throw new RuntimeException(e);
}
}
});
_server.stop();
Expand Down Expand Up @@ -411,16 +418,23 @@ public void testUtilMultiPart() throws Exception
@Override
public void requestDestroyed(ServletRequestEvent sre)
{
MultiParts m = (MultiParts)sre.getServletRequest().getAttribute(Request.MULTIPARTS);
assertNotNull(m);
ContextHandler.Context c = m.getContext();
assertNotNull(c);
assertTrue(c == sre.getServletContext());
assertTrue(!m.isEmpty());
assertTrue(testTmpDir.list().length == 2);
super.requestDestroyed(sre);
String[] files = testTmpDir.list();
assertTrue(files.length == 0);
try
{
MultiParts m = (MultiParts)sre.getServletRequest().getAttribute(Request.MULTIPARTS);
assertNotNull(m);
ContextHandler.Context c = m.getContext();
assertNotNull(c);
assertTrue(c == sre.getServletContext());
assertTrue(!m.getParts().isEmpty());
assertTrue(testTmpDir.list().length == 2);
super.requestDestroyed(sre);
String[] files = testTmpDir.list();
assertTrue(files.length == 0);
}
catch (IOException t)
{
throw new RuntimeException(t);
}
}
});
_server.stop();
Expand Down