Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Redis channel subscription issue when there is no context available #27361

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -0,0 +1,137 @@
package io.quarkus.redis.client.deployment.patterns;

import static org.awaitility.Awaitility.await;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

import javax.annotation.PreDestroy;
import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.redis.client.deployment.RedisTestResource;
import io.quarkus.redis.datasource.ReactiveRedisDataSource;
import io.quarkus.redis.datasource.RedisDataSource;
import io.quarkus.redis.datasource.pubsub.PubSubCommands;
import io.quarkus.redis.datasource.pubsub.ReactivePubSubCommands;
import io.quarkus.redis.datasource.string.StringCommands;
import io.quarkus.runtime.Startup;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.QuarkusTestResource;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.subscription.Cancellable;

@QuarkusTestResource(RedisTestResource.class)
public class PubSubOnStartupTest {
@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(MyCache.class)
.addClass(BusinessObject.class).addClass(Notification.class).addClass(MySubscriber.class))
.overrideConfigKey("quarkus.redis.hosts", "${quarkus.redis.tr}");

@Inject
MyCache cache;

@Inject
MySubscriber subscriber;

@Test
void cacheWithPubSub() {
BusinessObject foo = cache.get("ps-foo-2");
BusinessObject bar = cache.get("ps-bar-2");
Assertions.assertNull(foo);
Assertions.assertNull(bar);

cache.set("ps-foo-2", new BusinessObject("ps-foo-2"));
cache.set("ps-bar-2", new BusinessObject("ps-bar-2"));

await().until(() -> subscriber.list().size() == 2);
}

public static final class BusinessObject {
public String result;

public BusinessObject() {

}

public BusinessObject(String v) {
this.result = v;
}
}

public static final class Notification {
public String key;
public BusinessObject bo;

public Notification() {

}

public Notification(String key, BusinessObject bo) {
this.key = key;
this.bo = bo;
}
}

@ApplicationScoped
@Startup
public static class MySubscriber implements Consumer<Notification> {
private final ReactivePubSubCommands<Notification> pub;

private final Cancellable cancellable;

public List<Notification> list = new ArrayList<>();

public MySubscriber(ReactiveRedisDataSource ds) {
pub = ds.pubsub(Notification.class);
Multi<Notification> multi = pub.subscribe("notifications");
cancellable = multi.subscribe().with(n -> list.add(n));
}

@PreDestroy
public void terminate() {
cancellable.cancel();
}

@Override
public void accept(Notification notification) {
// Received the notification
list.add(notification);
}

public List<Notification> list() {
return list;
}
}

@ApplicationScoped
public static class MyCache {

private final StringCommands<String, BusinessObject> commands;
private final PubSubCommands<Notification> pub;

public MyCache(RedisDataSource ds) {
commands = ds.string(BusinessObject.class);
pub = ds.pubsub(Notification.class);
}

public BusinessObject get(String key) {
return commands.get(key);
}

public void set(String key, BusinessObject bo) {
commands.set(key, bo);
pub.publish("notifications", new Notification(key, bo));
}

}

}
Expand Up @@ -33,14 +33,15 @@ public class RedisClientRecorder {
private final RedisConfig config;
private static final Map<String, RedisClientAndApi> clients = new HashMap<>();
private static final Map<String, ReactiveRedisDataSourceImpl> dataSources = new HashMap<>();
private Vertx vertx;

public RedisClientRecorder(RedisConfig rc) {
this.config = rc;
}

public void initialize(RuntimeValue<io.vertx.core.Vertx> vertx, Set<String> names) {
Vertx v = Vertx.newInstance(vertx.getValue());
_initialize(v, names);
this.vertx = Vertx.newInstance(vertx.getValue());
_initialize(this.vertx, names);
}

private void closeAllClients() {
Expand Down Expand Up @@ -138,7 +139,7 @@ public ReactiveRedisDataSource get() {
RedisClientAndApi redisClientAndApi = clients.get(name);
Redis redis = redisClientAndApi.redis;
RedisAPI api = redisClientAndApi.api;
return new ReactiveRedisDataSourceImpl(redis, api);
return new ReactiveRedisDataSourceImpl(vertx, redis, api);
});
}
};
Expand Down
Expand Up @@ -22,6 +22,7 @@
import io.quarkus.redis.datasource.transactions.OptimisticLockingTransactionResult;
import io.quarkus.redis.datasource.transactions.TransactionResult;
import io.quarkus.redis.datasource.transactions.TransactionalRedisDataSource;
import io.vertx.mutiny.core.Vertx;
import io.vertx.mutiny.redis.client.Command;
import io.vertx.mutiny.redis.client.Redis;
import io.vertx.mutiny.redis.client.RedisAPI;
Expand All @@ -35,8 +36,8 @@ public class BlockingRedisDataSourceImpl implements RedisDataSource {
final ReactiveRedisDataSourceImpl reactive;
final RedisConnection connection;

public BlockingRedisDataSourceImpl(Redis redis, RedisAPI api, Duration timeout) {
this(new ReactiveRedisDataSourceImpl(redis, api), timeout);
public BlockingRedisDataSourceImpl(Vertx vertx, Redis redis, RedisAPI api, Duration timeout) {
this(new ReactiveRedisDataSourceImpl(vertx, redis, api), timeout);
}

public BlockingRedisDataSourceImpl(ReactiveRedisDataSourceImpl reactive, Duration timeout) {
Expand All @@ -45,13 +46,14 @@ public BlockingRedisDataSourceImpl(ReactiveRedisDataSourceImpl reactive, Duratio
this.connection = reactive.connection;
}

public BlockingRedisDataSourceImpl(Redis redis, RedisConnection connection, Duration timeout) {
this(new ReactiveRedisDataSourceImpl(redis, connection), timeout);
public BlockingRedisDataSourceImpl(Vertx vertx, Redis redis, RedisConnection connection, Duration timeout) {
this(new ReactiveRedisDataSourceImpl(vertx, redis, connection), timeout);
}

public TransactionResult withTransaction(Consumer<TransactionalRedisDataSource> ds) {
RedisConnection connection = reactive.redis.connect().await().atMost(timeout);
ReactiveRedisDataSourceImpl dataSource = new ReactiveRedisDataSourceImpl(reactive.redis, connection);
ReactiveRedisDataSourceImpl dataSource = new ReactiveRedisDataSourceImpl(reactive.getVertx(), reactive.redis,
connection);
TransactionHolder th = new TransactionHolder();
BlockingTransactionalRedisDataSourceImpl source = new BlockingTransactionalRedisDataSourceImpl(
new ReactiveTransactionalRedisDataSourceImpl(dataSource, th), timeout);
Expand All @@ -73,7 +75,8 @@ public TransactionResult withTransaction(Consumer<TransactionalRedisDataSource>
@Override
public TransactionResult withTransaction(Consumer<TransactionalRedisDataSource> ds, String... watchedKeys) {
RedisConnection connection = reactive.redis.connect().await().atMost(timeout);
ReactiveRedisDataSourceImpl dataSource = new ReactiveRedisDataSourceImpl(reactive.redis, connection);
ReactiveRedisDataSourceImpl dataSource = new ReactiveRedisDataSourceImpl(reactive.getVertx(), reactive.redis,
connection);
TransactionHolder th = new TransactionHolder();
BlockingTransactionalRedisDataSourceImpl source = new BlockingTransactionalRedisDataSourceImpl(
new ReactiveTransactionalRedisDataSourceImpl(dataSource, th), timeout);
Expand Down Expand Up @@ -104,7 +107,8 @@ public TransactionResult withTransaction(Consumer<TransactionalRedisDataSource>
public <I> OptimisticLockingTransactionResult<I> withTransaction(Function<RedisDataSource, I> preTxBlock,
BiConsumer<I, TransactionalRedisDataSource> tx, String... watchedKeys) {
RedisConnection connection = reactive.redis.connect().await().atMost(timeout);
ReactiveRedisDataSourceImpl dataSource = new ReactiveRedisDataSourceImpl(reactive.redis, connection);
ReactiveRedisDataSourceImpl dataSource = new ReactiveRedisDataSourceImpl(reactive.getVertx(), reactive.redis,
connection);
TransactionHolder th = new TransactionHolder();
BlockingTransactionalRedisDataSourceImpl source = new BlockingTransactionalRedisDataSourceImpl(
new ReactiveTransactionalRedisDataSourceImpl(dataSource, th), timeout);
Expand All @@ -116,7 +120,8 @@ public <I> OptimisticLockingTransactionResult<I> withTransaction(Function<RedisD
}
connection.send(cmd).await().atMost(timeout);

I input = preTxBlock.apply(new BlockingRedisDataSourceImpl(reactive.redis, connection, timeout));
I input = preTxBlock
.apply(new BlockingRedisDataSourceImpl(reactive.getVertx(), reactive.redis, connection, timeout));

connection.send(Request.cmd(Command.MULTI)).await().atMost(timeout);

Expand All @@ -143,7 +148,7 @@ public void withConnection(Consumer<RedisDataSource> consumer) {
}

BlockingRedisDataSourceImpl source = reactive.redis.connect()
.map(rc -> new BlockingRedisDataSourceImpl(reactive.redis, rc, timeout))
.map(rc -> new BlockingRedisDataSourceImpl(reactive.getVertx(), reactive.redis, rc, timeout))
.await().atMost(timeout);

try {
Expand Down
Expand Up @@ -14,21 +14,25 @@
import io.smallrye.common.vertx.VertxContext;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
import io.smallrye.mutiny.subscription.UniEmitter;
import io.vertx.core.Context;
import io.vertx.core.Vertx;
import io.vertx.mutiny.redis.client.Command;
import io.vertx.mutiny.redis.client.Redis;
import io.vertx.mutiny.redis.client.RedisAPI;
import io.vertx.mutiny.redis.client.RedisConnection;
import io.vertx.mutiny.redis.client.Response;

public class ReactivePubSubCommandsImpl<V> extends AbstractRedisCommands implements ReactivePubSubCommands<V> {

private final Class<V> classOfMessage;
private final Redis ds;
private final Redis client;
private final ReactiveRedisDataSourceImpl datasource;

public ReactivePubSubCommandsImpl(ReactiveRedisDataSourceImpl ds, Class<V> classOfMessage) {
super(ds, new Marshaller(classOfMessage));
this.ds = ds.redis;
this.client = ds.redis;
this.datasource = ds;
this.classOfMessage = classOfMessage;
}

Expand Down Expand Up @@ -67,7 +71,7 @@ public Uni<ReactiveRedisSubscriber> subscribeToPatterns(List<String> patterns, C
}
}

return ds.connect()
return client.connect()
.chain(conn -> {
RedisAPI api = RedisAPI.api(conn);
ReactiveRedisPatternSubscriberImpl subscriber = new ReactiveRedisPatternSubscriberImpl(conn, api, onMessage,
Expand All @@ -91,7 +95,7 @@ public Uni<ReactiveRedisSubscriber> subscribe(List<String> channels, Consumer<V>
}
}

return ds.connect()
return client.connect()
.chain(conn -> {
RedisAPI api = RedisAPI.api(conn);
ReactiveAbstractRedisSubscriberImpl subscriber = new ReactiveAbstractRedisSubscriberImpl(conn, api,
Expand Down Expand Up @@ -146,14 +150,13 @@ public Uni<String> subscribe() {
Uni<Void> handled = Uni.createFrom().emitter(emitter -> {
connection.handler(r -> {
if (r != null && r.size() > 0) {
Context context = VertxContext.getOrCreateDuplicatedContext(Vertx.currentContext());
String command = r.get(0).toString();
if ("subscribe".equalsIgnoreCase(command) || "psubscribe".equalsIgnoreCase(command)) {
emitter.complete(null); // Subscribed
} else if ("message".equalsIgnoreCase(command)) {
context.runOnContext(x -> onMessage.accept(marshaller.decode(classOfMessage, r.get(2))));
} else if ("pmessage".equalsIgnoreCase(command)) {
context.runOnContext(x -> onMessage.accept(marshaller.decode(classOfMessage, r.get(3))));
Context ctxt = Vertx.currentContext();
if (ctxt != null) {
handleRedisEvent(emitter, r);
} else {
datasource.getVertx().runOnContext(() -> {
handleRedisEvent(emitter, r);
});
}
}
});
Expand All @@ -165,6 +168,18 @@ public Uni<String> subscribe() {
.replaceWith(id);
}

private void handleRedisEvent(UniEmitter<? super Void> emitter, Response r) {
Context context = VertxContext.getOrCreateDuplicatedContext(Vertx.currentContext());
String command = r.get(0).toString();
if ("subscribe".equalsIgnoreCase(command) || "psubscribe".equalsIgnoreCase(command)) {
emitter.complete(null); // Subscribed
} else if ("message".equalsIgnoreCase(command)) {
context.runOnContext(x -> onMessage.accept(marshaller.decode(classOfMessage, r.get(2))));
} else if ("pmessage".equalsIgnoreCase(command)) {
context.runOnContext(x -> onMessage.accept(marshaller.decode(classOfMessage, r.get(3))));
}
}

public Uni<Void> closeAndUnregister(Collection<?> collection) {
if (collection.isEmpty()) {
return connection.close();
Expand Down