From 0eb363de94c0f0e17ba717590993570f8908c445 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 27 Sep 2019 00:55:44 -0700 Subject: [PATCH] Close eventfd shutdown/wakeup race by closely tracking epoll edges (#9586) (#9612) Motivation This is another iteration of #9476. Modifications Instead of maintaining a count of all writes performed and then using reads during shutdown to ensure all are accounted for, just set a flag after each write and don't reset it until the corresponding event has been returned from epoll_wait. This requires that while a write is still pending we don't reset wakenUp, i.e. continue to block writes from the wakeup() method. Result Race condition eliminated. Fixes #9362 Co-authored-by: Norman Maurer --- .../src/main/c/netty_epoll_native.c | 65 ++++++-------- .../io/netty/channel/epoll/EpollHandler.java | 87 ++++++++++++++----- .../java/io/netty/channel/epoll/Native.java | 22 ++++- 3 files changed, 112 insertions(+), 62 deletions(-) diff --git a/transport-native-epoll/src/main/c/netty_epoll_native.c b/transport-native-epoll/src/main/c/netty_epoll_native.c index 79135788248..4d74c8faedc 100644 --- a/transport-native-epoll/src/main/c/netty_epoll_native.c +++ b/transport-native-epoll/src/main/c/netty_epoll_native.c @@ -187,48 +187,38 @@ static jint netty_epoll_native_epollCreate(JNIEnv* env, jclass clazz) { return efd; } -static jint netty_epoll_native_epollWait0(JNIEnv* env, jclass clazz, jint efd, jlong address, jint len, jint timerFd, jint tvSec, jint tvNsec) { +static jint netty_epoll_native_epollWait(JNIEnv* env, jclass clazz, jint efd, jlong address, jint len, jint timeout) { struct epoll_event *ev = (struct epoll_event*) (intptr_t) address; int result, err; + do { + result = epoll_wait(efd, ev, len, timeout); + if (result >= 0) { + return result; + } + } while((err = errno) == EINTR); + return -err; +} + +// This method is deprecated! +static jint netty_epoll_native_epollWait0(JNIEnv* env, jclass clazz, jint efd, jlong address, jint len, jint timerFd, jint tvSec, jint tvNsec) { if (tvSec == 0 && tvNsec == 0) { // Zeros = poll (aka return immediately). - do { - result = epoll_wait(efd, ev, len, 0); - if (result >= 0) { - return result; - } - } while((err = errno) == EINTR); - } else { - // only reschedule the timer if there is a newer event. - // -1 is a special value used by EpollEventLoop. - if (tvSec != ((jint) -1) && tvNsec != ((jint) -1)) { - struct itimerspec ts; - memset(&ts.it_interval, 0, sizeof(struct timespec)); - ts.it_value.tv_sec = tvSec; - ts.it_value.tv_nsec = tvNsec; - if (timerfd_settime(timerFd, 0, &ts, NULL) < 0) { - netty_unix_errors_throwChannelExceptionErrorNo(env, "timerfd_settime() failed: ", errno); - return -1; - } - } - do { - result = epoll_wait(efd, ev, len, -1); - if (result > 0) { - // Detect timeout, and preserve the epoll_wait API. - if (result == 1 && ev[0].data.fd == timerFd) { - // We assume that timerFD is in ET mode. So we must consume this event to ensure we are notified - // of future timer events because ET mode only notifies a single time until the event is consumed. - uint64_t timerFireCount; - // We don't care what the result is. We just want to consume the wakeup event and reset ET. - result = read(timerFd, &timerFireCount, sizeof(uint64_t)); - return 0; - } - return result; - } - } while((err = errno) == EINTR); + return netty_epoll_native_epollWait(env, clazz, efd, address, len, 0); } - return -err; + // only reschedule the timer if there is a newer event. + // -1 is a special value used by EpollEventLoop. + if (tvSec != ((jint) -1) && tvNsec != ((jint) -1)) { + struct itimerspec ts; + memset(&ts.it_interval, 0, sizeof(struct timespec)); + ts.it_value.tv_sec = tvSec; + ts.it_value.tv_nsec = tvNsec; + if (timerfd_settime(timerFd, 0, &ts, NULL) < 0) { + netty_unix_errors_throwChannelExceptionErrorNo(env, "timerfd_settime() failed: ", errno); + return -1; + } + } + return netty_epoll_native_epollWait(env, clazz, efd, address, len, -1); } static inline void cpu_relax() { @@ -497,7 +487,8 @@ static const JNINativeMethod fixed_method_table[] = { { "eventFdRead", "(I)V", (void *) netty_epoll_native_eventFdRead }, { "timerFdRead", "(I)V", (void *) netty_epoll_native_timerFdRead }, { "epollCreate", "()I", (void *) netty_epoll_native_epollCreate }, - { "epollWait0", "(IJIIII)I", (void *) netty_epoll_native_epollWait0 }, + { "epollWait0", "(IJIIII)I", (void *) netty_epoll_native_epollWait0 }, // This method is deprecated! + { "epollWait", "(IJII)I", (void *) netty_epoll_native_epollWait }, { "epollBusyWait0", "(IJI)I", (void *) netty_epoll_native_epollBusyWait0 }, { "epollCtlAdd0", "(III)I", (void *) netty_epoll_native_epollCtlAdd0 }, { "epollCtlMod0", "(III)I", (void *) netty_epoll_native_epollCtlMod0 }, diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollHandler.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollHandler.java index d978402c414..463cae67841 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollHandler.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollHandler.java @@ -37,6 +37,7 @@ import java.io.IOException; import java.util.BitSet; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; @@ -48,8 +49,6 @@ */ public class EpollHandler implements IoHandler { private static final InternalLogger logger = InternalLoggerFactory.getInstance(EpollHandler.class); - private static final AtomicIntegerFieldUpdater WAKEN_UP_UPDATER = - AtomicIntegerFieldUpdater.newUpdater(EpollHandler.class, "wakenUp"); static { // Ensure JNI is initialized by the time this class is loaded by this time! @@ -73,8 +72,8 @@ public class EpollHandler implements IoHandler { private final SelectStrategy selectStrategy; private final IntSupplier selectNowSupplier = this::epollWaitNow; - @SuppressWarnings("unused") // AtomicIntegerFieldUpdater - private volatile int wakenUp; + private final AtomicInteger wakenUp = new AtomicInteger(1); + private boolean pendingWakeup; // See http://man7.org/linux/man-pages/man2/timerfd_create.2.html. private static final long MAX_SCHEDULED_TIMERFD_NS = 999999999; @@ -219,7 +218,7 @@ public final void deregister(Channel channel) throws Exception { @Override public final void wakeup(boolean inEventLoop) { - if (!inEventLoop && WAKEN_UP_UPDATER.getAndSet(this, 1) == 0) { + if (!inEventLoop && wakenUp.getAndSet(1) == 0) { // write to the evfd which will then wake-up epoll_wait(...) Native.eventFdWrite(eventFd.intValue(), 1L); } @@ -309,13 +308,18 @@ private int epollWait(IoExecutionContext context) throws IOException { } private int epollWaitNow() throws IOException { - return Native.epollWait(epollFd, events, timerFd, 0, 0); + return Native.epollWait(epollFd, events, true); } private int epollBusyWait() throws IOException { return Native.epollBusyWait(epollFd, events); } + private int epollWaitTimeboxed() throws IOException { + // Wait with 1 second "safeguard" timeout + return Native.epollWait(epollFd, events, 1000); + } + @Override public final int run(IoExecutionContext context) { int handled = 0; @@ -331,15 +335,36 @@ public final int run(IoExecutionContext context) { break; case SelectStrategy.SELECT: - if (wakenUp == 1) { - Native.eventFdWrite(eventFd.intValue(), 1L); - wakenUp = 0; - } - if (context.canBlock()) { - strategy = epollWait(context); + if (pendingWakeup) { + // We are going to be immediately woken so no need to reset wakenUp + // or check for timerfd adjustment. + strategy = epollWaitTimeboxed(); + if (strategy != 0) { + break; + } + // We timed out so assume that we missed the write event due to an + // abnormally failed syscall (the write itself or a prior epoll_wait) + logger.warn("Missed eventfd write (not seen after > 1 second)"); + pendingWakeup = false; + if (!context.canBlock()) { + break; + } + // fall-through } - // fallthrough + wakenUp.set(0); + try { + if (context.canBlock()) { + strategy = epollWait(context); + } + } finally { + // Try get() first to avoid much more expensive CAS in the case we + // were woken via the wakeup() method (submitted task) + if (wakenUp.get() == 1 || wakenUp.getAndSet(1) == 1) { + pendingWakeup = true; + } + } + // fall-through default: } if (strategy > 0) { @@ -373,12 +398,6 @@ void handleLoopException(Throwable t) { @Override public void prepareToDestroy() { - try { - epollWaitNow(); - } catch (IOException ignore) { - // ignore on close - } - // Using the intermediate collection to prevent ConcurrentModificationException. // In the `close()` method, the channel is deleted from `channels` map. AbstractEpollChannel[] localChannels = channels.values().toArray(new AbstractEpollChannel[0]); @@ -391,7 +410,9 @@ public void prepareToDestroy() { private void processReady(EpollEventArray events, int ready) { for (int i = 0; i < ready; i ++) { final int fd = events.fd(i); - if (fd == eventFd.intValue() || fd == timerFd.intValue()) { + if (fd == eventFd.intValue()) { + pendingWakeup = false; + } else if (fd == timerFd.intValue()) { // Just ignore as we use ET mode for the eventfd and timerfd. // // See also https://stackoverflow.com/a/12492308/1074097 @@ -453,10 +474,23 @@ private void processReady(EpollEventArray events, int ready) { @Override public final void destroy() { try { - try { - epollFd.close(); - } catch (IOException e) { - logger.warn("Failed to close the epoll fd.", e); + // Ensure any in-flight wakeup writes have been performed prior to closing eventFd. + while (pendingWakeup) { + try { + int count = epollWaitTimeboxed(); + if (count == 0) { + // We timed-out so assume that the write we're expecting isn't coming + break; + } + for (int i = 0; i < count; i++) { + if (events.fd(i) == eventFd.intValue()) { + pendingWakeup = false; + break; + } + } + } catch (IOException ignore) { + // ignore + } } try { eventFd.close(); @@ -468,6 +502,11 @@ public final void destroy() { } catch (IOException e) { logger.warn("Failed to close the timer fd.", e); } + try { + epollFd.close(); + } catch (IOException e) { + logger.warn("Failed to close the epoll fd.", e); + } } finally { // release native memory if (iovArray != null) { diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java index 296ea3c2048..9fb92aaf1ca 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java @@ -91,6 +91,10 @@ public static FileDescriptor newEpollCreate() { private static native int epollCreate(); + /** + * @deprecated this method is no longer supported. This functionality is internal to this package. + */ + @Deprecated public static int epollWait(FileDescriptor epollFd, EpollEventArray events, FileDescriptor timerFd, int timeoutSec, int timeoutNs) throws IOException { int ready = epollWait0(epollFd.intValue(), events.memoryAddress(), events.length(), timerFd.intValue(), @@ -100,7 +104,21 @@ public static int epollWait(FileDescriptor epollFd, EpollEventArray events, File } return ready; } - private static native int epollWait0(int efd, long address, int len, int timerFd, int timeoutSec, int timeoutNs); + + static int epollWait(FileDescriptor epollFd, EpollEventArray events, boolean immediatePoll) throws IOException { + return epollWait(epollFd, events, immediatePoll ? 0 : -1); + } + + /** + * This uses epoll's own timeout and does not reset/re-arm any timerfd + */ + static int epollWait(FileDescriptor epollFd, EpollEventArray events, int timeoutMillis) throws IOException { + int ready = epollWait(epollFd.intValue(), events.memoryAddress(), events.length(), timeoutMillis); + if (ready < 0) { + throw newIOException("epoll_wait", ready); + } + return ready; + } /** * Non-blocking variant of @@ -115,6 +133,8 @@ public static int epollBusyWait(FileDescriptor epollFd, EpollEventArray events) return ready; } + private static native int epollWait0(int efd, long address, int len, int timerFd, int timeoutSec, int timeoutNs); + private static native int epollWait(int efd, long address, int len, int timeout); private static native int epollBusyWait0(int efd, long address, int len); public static void epollCtlAdd(int efd, final int fd, final int flags) throws IOException {