Skip to content

Commit

Permalink
Issue #4383 - use atomic state for multipart cleanup
Browse files Browse the repository at this point in the history
- Removed synchronization for parsing by two threads.

- Introduced an atomic state to decide when it is safe to remove
the parts. The call to deleteParts will now cancel the parsing and
only delete the parts when the parser exits.

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
  • Loading branch information
lachlan-roberts committed Dec 16, 2019
1 parent 9fdd454 commit 6d15071
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 59 deletions.
Expand Up @@ -35,6 +35,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import javax.servlet.MultipartConfigElement;
import javax.servlet.ServletInputStream;
Expand All @@ -59,7 +60,16 @@
*/
public class MultiPartFormInputStream
{
private enum State
{
UNPARSED,
PARSING,
ERROR,
COMPLETED
}

private static final Logger LOG = Log.getLogger(MultiPartFormInputStream.class);
private final AtomicReference<State> state = new AtomicReference<>(State.UNPARSED);
private final MultiMap<Part> _parts = new MultiMap<>();
private final InputStream _in;
private final MultipartConfigElement _config;
Expand Down Expand Up @@ -356,22 +366,19 @@ public MultiPartFormInputStream(InputStream in, String contentType, MultipartCon
@Deprecated
public boolean isEmpty()
{
synchronized (this)
{
if (!_parsed)
throw new IllegalStateException();

if (_parts.isEmpty())
return true;

for (List<Part> partList : _parts.values())
{
if (!partList.isEmpty())
return false;
}
if (!_parsed)
throw new IllegalStateException();

if (_parts.isEmpty())
return true;

for (List<Part> partList : _parts.values())
{
if (!partList.isEmpty())
return false;
}

return true;
}

/**
Expand All @@ -382,52 +389,70 @@ public boolean isEmpty()
@Deprecated
public Collection<Part> getParsedParts()
{
synchronized (this)
{
if (_parts.isEmpty())
return Collections.emptyList();
if (_parts.isEmpty())
return Collections.emptyList();

Collection<List<Part>> values = _parts.values();
List<Part> parts = new ArrayList<>();
for (List<Part> o : values)
{
List<Part> asList = LazyList.getList(o, false);
parts.addAll(asList);
}
return parts;
Collection<List<Part>> values = _parts.values();
List<Part> parts = new ArrayList<>();
for (List<Part> o : values)
{
List<Part> asList = LazyList.getList(o, false);
parts.addAll(asList);
}
return parts;
}

/**
* Delete any tmp storage for parts, and clear out the parts list.
*/
public void deleteParts()
{
// TODO: Can we cancel parsing somehow instead of blocking.
synchronized (this)
while (true)
{
switch (state.get())
{
case PARSING:
state.compareAndSet(State.PARSING, State.ERROR);
Thread.yield();
continue;

case UNPARSED:
if (!state.compareAndSet(State.UNPARSED, State.COMPLETED))
continue;
break;

case ERROR:
Thread.yield();
continue;

case COMPLETED:
break;
}

break;
}

MultiException err = null;
for (List<Part> parts : _parts.values())
{
MultiException err = null;
for (List<Part> parts : _parts.values())
for (Part p : parts)
{
for (Part p : parts)
try
{
try
{
((MultiPart)p).cleanUp();
}
catch (Exception e)
{
if (err == null)
err = new MultiException();
err.add(e);
}
((MultiPart)p).cleanUp();
}
catch (Exception e)
{
if (err == null)
err = new MultiException();
err.add(e);
}
}
_parts.clear();

if (err != null)
err.ifExceptionThrowRuntime();
}
_parts.clear();

if (err != null)
err.ifExceptionThrowRuntime();
}

/**
Expand All @@ -438,13 +463,10 @@ public void deleteParts()
*/
public Collection<Part> getParts() throws IOException
{
synchronized (this)
{
if (!_parsed)
parse();
throwIfError();
return _parts.values().stream().flatMap(List::stream).collect(Collectors.toList());
}
if (!_parsed)
parse();
throwIfError();
return _parts.values().stream().flatMap(List::stream).collect(Collectors.toList());
}

/**
Expand All @@ -456,13 +478,10 @@ public Collection<Part> getParts() throws IOException
*/
public Part getPart(String name) throws IOException
{
synchronized (this)
{
if (!_parsed)
parse();
throwIfError();
return _parts.getValue(name, 0);
}
if (!_parsed)
parse();
throwIfError();
return _parts.getValue(name, 0);
}

/**
Expand Down Expand Up @@ -534,8 +553,15 @@ else if ("".equals(_config.getLocation()))
int len;
long total = 0;

if (!state.compareAndSet(State.UNPARSED, State.PARSING))
throw new IllegalStateException("Could not start parsing " + state.get());

while (true)
{
State currentState = state.get();
if (currentState != State.PARSING)
throw new IllegalStateException("Unexpected state " + currentState);

len = _in.read(data);

if (len > 0)
Expand Down Expand Up @@ -591,6 +617,27 @@ else if (len == -1)
if (parser != null)
parser.parse(BufferUtil.EMPTY_BUFFER, true);
}
finally
{
while (true)
{
switch (state.get())
{
case PARSING:
if (!state.compareAndSet(State.PARSING, State.COMPLETED))
continue;
break;

case ERROR:
state.compareAndSet(State.ERROR, State.COMPLETED);
break;

default:
break;
}
break;
}
}
}

class Handler implements MultiPartParser.Handler
Expand Down
Expand Up @@ -25,6 +25,8 @@
import java.io.InputStream;
import java.util.Base64;
import java.util.Collection;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.servlet.MultipartConfigElement;
import javax.servlet.ReadListener;
Expand Down Expand Up @@ -509,6 +511,79 @@ public void testPartTmpFileDeletion() throws Exception
assertThat(stuff.exists(), is(false)); //tmp file was removed after cleanup
}

@Test
public void testAsyncCleanUp() throws Exception
{
final CountDownLatch reading = new CountDownLatch(1);
final InputStream wrappedStream = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes());

// This stream won't allow the parser to exit because it will never return anything less than 0.
InputStream slowStream = new InputStream() {
@Override
public int read(byte[] b, int off, int len) throws IOException
{
return Math.max(0, super.read(b, off, len));
}

@Override
public int read() throws IOException
{
reading.countDown();
return wrappedStream.read();
}
};

MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50);
MultiPartFormInputStream mpis = new MultiPartFormInputStream(slowStream, _contentType, config, _tmpDir);

// In another thread delete the parts when we detect that we have started parsing.
CompletableFuture<Throwable> cleanupError = new CompletableFuture<>();
new Thread(() ->
{
try
{
assertTrue(reading.await(5, TimeUnit.SECONDS));
mpis.deleteParts();
cleanupError.complete(null);
}
catch (Throwable t)
{
cleanupError.complete(t);
}
}).start();

// The call to getParts should throw an error.
Throwable error = assertThrows(IllegalStateException.class, mpis::getParts);
assertThat(error.getMessage(), is("Unexpected state ERROR"));

// There was no error with the cleanup.
assertNull(cleanupError.get());

// No tmp files are remaining.
String[] fileList = _tmpDir.list();
assertNotNull(fileList);
assertThat(fileList.length, is(0));
}

@Test
public void testParseAfterCleanUp() throws Exception
{
final InputStream input = new ByteArrayInputStream(createMultipartRequestString("myFile").getBytes());
MultipartConfigElement config = new MultipartConfigElement(_dirname, 1024, 1024, 50);
MultiPartFormInputStream mpis = new MultiPartFormInputStream(input, _contentType, config, _tmpDir);

mpis.deleteParts();

// The call to getParts should throw because we have already cleaned up the parts.
Throwable error = assertThrows(IllegalStateException.class, mpis::getParts);
assertThat(error.getMessage(), is("Could not start parsing COMPLETED"));

// No tmp files are remaining.
String[] fileList = _tmpDir.list();
assertNotNull(fileList);
assertThat(fileList.length, is(0));
}

@Test
public void testLFOnlyRequest()
throws Exception
Expand Down
Expand Up @@ -2310,7 +2310,7 @@ public Collection<Part> getParts() throws IOException, ServletException
return getParts(null);
}

private synchronized Collection<Part> getParts(MultiMap<String> params) throws IOException
private Collection<Part> getParts(MultiMap<String> params) throws IOException
{
if (_multiParts == null)
_multiParts = (MultiParts)getAttribute(MULTIPARTS);
Expand Down

0 comments on commit 6d15071

Please sign in to comment.