Skip to content

Commit

Permalink
Implement OIDC auth for async (#1131)
Browse files Browse the repository at this point in the history
  • Loading branch information
katcharov committed Feb 13, 2024
1 parent 5172061 commit ec9887b
Show file tree
Hide file tree
Showing 10 changed files with 228 additions and 80 deletions.
36 changes: 0 additions & 36 deletions driver-core/src/main/com/mongodb/assertions/Assertions.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package com.mongodb.assertions;

import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.lang.Nullable;

import java.util.Collection;
Expand Down Expand Up @@ -79,25 +78,6 @@ public static <T> Iterable<T> notNullElements(final String name, final Iterable<
return values;
}

/**
* Throw IllegalArgumentException if the value is null.
*
* @param name the parameter name
* @param value the value that should not be null
* @param callback the callback that also is passed the exception if the value is null
* @param <T> the value type
* @return the value
* @throws java.lang.IllegalArgumentException if value is null
*/
public static <T> T notNull(final String name, final T value, final SingleResultCallback<?> callback) {
if (value == null) {
IllegalArgumentException exception = new IllegalArgumentException(name + " can not be null");
callback.completeExceptionally(exception);
throw exception;
}
return value;
}

/**
* Throw IllegalStateException if the condition if false.
*
Expand All @@ -111,22 +91,6 @@ public static void isTrue(final String name, final boolean condition) {
}
}

/**
* Throw IllegalStateException if the condition if false.
*
* @param name the name of the state that is being checked
* @param condition the condition about the parameter to check
* @param callback the callback that also is passed the exception if the condition is not true
* @throws java.lang.IllegalStateException if the condition is false
*/
public static void isTrue(final String name, final boolean condition, final SingleResultCallback<?> callback) {
if (!condition) {
IllegalStateException exception = new IllegalStateException("state should be: " + name);
callback.completeExceptionally(exception);
throw exception;
}
}

/**
* Throw IllegalArgumentException if the condition if false.
*
Expand Down
22 changes: 20 additions & 2 deletions driver-core/src/main/com/mongodb/internal/Locks.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package com.mongodb.internal;

import com.mongodb.MongoInterruptedException;
import com.mongodb.internal.async.AsyncRunnable;
import com.mongodb.internal.async.SingleResultCallback;

import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
Expand All @@ -36,7 +38,23 @@ public static void withLock(final Lock lock, final Runnable action) {
});
}

public static <V> V withLock(final StampedLock lock, final Supplier<V> supplier) {
public static void withLockAsync(final StampedLock lock, final AsyncRunnable runnable,
final SingleResultCallback<Void> callback) {
long stamp;
try {
stamp = lock.writeLockInterruptibly();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
callback.onResult(null, new MongoInterruptedException("Interrupted waiting for lock", e));
return;
}

runnable.thenAlwaysRunAndFinish(() -> {
lock.unlockWrite(stamp);
}, callback);
}

public static void withLock(final StampedLock lock, final Runnable runnable) {
long stamp;
try {
stamp = lock.writeLockInterruptibly();
Expand All @@ -45,7 +63,7 @@ public static <V> V withLock(final StampedLock lock, final Supplier<V> supplier)
throw new MongoInterruptedException("Interrupted waiting for lock", e);
}
try {
return supplier.get();
runnable.run();
} finally {
lock.unlockWrite(stamp);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.mongodb.lang.Nullable;

import static com.mongodb.assertions.Assertions.notNull;
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;

/**
* <p>This class is not part of the public API and may be removed or changed at any time</p>
Expand Down Expand Up @@ -104,4 +105,10 @@ public void reauthenticate(final InternalConnection connection) {
authenticate(connection, connection.getDescription());
}

public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback<Void> callback) {
beginAsync().thenRun((c) -> {
authenticateAsync(connection, connection.getDescription(), c);
}).finish(callback);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public interface InternalConnection extends BufferProvider {
ServerDescription getInitialServerDescription();

/**
* Opens the connection so its ready for use
* Opens the connection so its ready for use. Will perform a handshake.
*/
void open();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.mongodb.event.CommandListener;
import com.mongodb.internal.ResourceUtil;
import com.mongodb.internal.VisibleForTesting;
import com.mongodb.internal.async.AsyncSupplier;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.internal.diagnostics.logging.Logger;
import com.mongodb.internal.diagnostics.logging.Loggers;
Expand All @@ -68,9 +69,12 @@
import java.util.function.Supplier;

import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.assertions.Assertions.assertNull;
import static com.mongodb.assertions.Assertions.isTrue;
import static com.mongodb.assertions.Assertions.notNull;
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback;
import static com.mongodb.internal.connection.Authenticator.shouldAuthenticate;
import static com.mongodb.internal.connection.CommandHelper.HELLO;
import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO;
import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO_LOWER;
Expand Down Expand Up @@ -238,7 +242,7 @@ public void open() {

@Override
public void openAsync(final SingleResultCallback<Void> callback) {
isTrue("Open already called", stream == null, callback);
assertNull(stream);
try {
stream = streamFactory.create(serverId.getAddress());
stream.openAsync(new AsyncCompletionHandler<Void>() {
Expand Down Expand Up @@ -364,17 +368,48 @@ public <T> T sendAndReceive(final CommandMessage message, final Decoder<T> decod
try {
return sendAndReceiveInternal.get();
} catch (MongoCommandException e) {
if (triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description)) {
authenticated.set(false);
authenticator.reauthenticate(this);
authenticated.set(true);
return sendAndReceiveInternal.get();
if (reauthenticationIsTriggered(e)) {
return reauthenticateAndRetry(sendAndReceiveInternal);
}
throw e;
}
}

public static boolean triggersReauthentication(@Nullable final Throwable t) {
@Override
public <T> void sendAndReceiveAsync(final CommandMessage message, final Decoder<T> decoder, final SessionContext sessionContext,
final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback<T> callback) {

AsyncSupplier<T> sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal(
message, decoder, sessionContext, requestContext, operationContext, c);
beginAsync().<T>thenSupply(c -> {
sendAndReceiveAsyncInternal.getAsync(c);
}).onErrorIf(e -> reauthenticationIsTriggered(e), c -> {
reauthenticateAndRetryAsync(sendAndReceiveAsyncInternal, c);
}).finish(callback);
}

private <T> T reauthenticateAndRetry(final Supplier<T> operation) {
authenticated.set(false);
assertNotNull(authenticator).reauthenticate(this);
authenticated.set(true);
return operation.get();
}

private <T> void reauthenticateAndRetryAsync(final AsyncSupplier<T> operation,
final SingleResultCallback<T> callback) {
beginAsync().thenRun(c -> {
authenticated.set(false);
assertNotNull(authenticator).reauthenticateAsync(this, c);
}).<T>thenSupply((c) -> {
authenticated.set(true);
operation.getAsync(c);
}).finish(callback);
}

public boolean reauthenticationIsTriggered(@Nullable final Throwable t) {
if (!shouldAuthenticate(authenticator, this.description)) {
return false;
}
if (t instanceof MongoCommandException) {
MongoCommandException e = (MongoCommandException) t;
return e.getErrorCode() == 391;
Expand Down Expand Up @@ -501,11 +536,8 @@ private <T> T receiveCommandMessageResponse(final Decoder<T> decoder,
}
}

@Override
public <T> void sendAndReceiveAsync(final CommandMessage message, final Decoder<T> decoder, final SessionContext sessionContext,
private <T> void sendAndReceiveAsyncInternal(final CommandMessage message, final Decoder<T> decoder, final SessionContext sessionContext,
final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback<T> callback) {
notNull("stream is open", stream, callback);

if (isClosed()) {
callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress()));
return;
Expand Down Expand Up @@ -616,7 +648,7 @@ public void sendMessage(final List<ByteBuf> byteBuffers, final int lastRequestId

@Override
public ResponseBuffers receiveMessage(final int responseTo) {
notNull("stream is open", stream);
assertNotNull(stream);
if (isClosed()) {
throw new MongoSocketClosedException("Cannot read from a closed stream", getServerAddress());
}
Expand All @@ -634,8 +666,9 @@ private ResponseBuffers receiveMessageWithAdditionalTimeout(final int additional
}

@Override
public void sendMessageAsync(final List<ByteBuf> byteBuffers, final int lastRequestId, final SingleResultCallback<Void> callback) {
notNull("stream is open", stream, callback);
public void sendMessageAsync(final List<ByteBuf> byteBuffers, final int lastRequestId,
final SingleResultCallback<Void> callback) {
assertNotNull(stream);

if (isClosed()) {
callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress()));
Expand Down Expand Up @@ -667,7 +700,7 @@ public void failed(final Throwable t) {

@Override
public void receiveMessageAsync(final int responseTo, final SingleResultCallback<ResponseBuffers> callback) {
isTrue("stream is open", stream != null, callback);
assertNotNull(stream);

if (isClosed()) {
callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.concurrent.locks.StampedLock;

import static com.mongodb.internal.Locks.withInterruptibleLock;
import static com.mongodb.internal.Locks.withLock;
import static com.mongodb.internal.connection.OidcAuthenticator.OidcCacheEntry;

/**
Expand Down

0 comments on commit ec9887b

Please sign in to comment.