Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;

import org.apache.commons.logging.Log;
Expand Down Expand Up @@ -51,6 +55,7 @@
*
* @author Mark Paluch
* @author Christoph Strobl
* @author Han Li
* @since 2.2
*/
class DefaultStreamMessageListenerContainer<K, V extends Record<K, ?>> implements StreamMessageListenerContainer<K, V> {
Expand Down Expand Up @@ -160,9 +165,22 @@ public void stop() {
synchronized (lifecycleMonitor) {

if (this.running) {

subscriptions.forEach(Cancelable::cancel);

subscriptions.stream()
.map(subscription -> CompletableFuture.runAsync(() -> {
subscription.cancel();
while (subscription.isActive()) {
// NO-OP
}
}, taskExecutor))
.forEach(f -> {
try {
f.get(this.containerOptions.getShutdownTimeout().toNanos(), TimeUnit.NANOSECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (ExecutionException | TimeoutException e) {
// ignore
}
});
running = false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,12 +494,13 @@ class StreamMessageListenerContainerOptions<K, V extends Record<K, ?>> {
private final @Nullable HashMapper<Object, Object, Object> hashMapper;
private final ErrorHandler errorHandler;
private final Executor executor;
private final Duration shutdownTimeout;

@SuppressWarnings("unchecked")
private StreamMessageListenerContainerOptions(Duration pollTimeout, @Nullable Integer batchSize,
RedisSerializer<K> keySerializer, RedisSerializer<Object> hashKeySerializer,
RedisSerializer<Object> hashValueSerializer, @Nullable Class<?> targetType,
@Nullable HashMapper<V, ?, ?> hashMapper, ErrorHandler errorHandler, Executor executor) {
@Nullable HashMapper<V, ?, ?> hashMapper, ErrorHandler errorHandler, Executor executor, Duration shutdownTimeout) {
this.pollTimeout = pollTimeout;
this.batchSize = batchSize;
this.keySerializer = keySerializer;
Expand All @@ -509,6 +510,7 @@ private StreamMessageListenerContainerOptions(Duration pollTimeout, @Nullable In
this.hashMapper = (HashMapper) hashMapper;
this.errorHandler = errorHandler;
this.executor = executor;
this.shutdownTimeout = shutdownTimeout;
}

/**
Expand Down Expand Up @@ -589,6 +591,15 @@ public Executor getExecutor() {
return executor;
}

/**
* Timeout for shutdown container.
*
* @return the timeout.
*/
public Duration getShutdownTimeout() {
return shutdownTimeout;
}

}

/**
Expand All @@ -609,6 +620,7 @@ class StreamMessageListenerContainerOptionsBuilder<K, V extends Record<K, ?>> {
private @Nullable Class<?> targetType;
private ErrorHandler errorHandler = LoggingErrorHandler.INSTANCE;
private Executor executor = new SimpleAsyncTaskExecutor();
private Duration shutdownTimeout = Duration.ofSeconds(1);

private StreamMessageListenerContainerOptionsBuilder() {}

Expand All @@ -627,6 +639,21 @@ public StreamMessageListenerContainerOptionsBuilder<K, V> pollTimeout(Duration p
return this;
}

/**
* Configure a timeout for shutdown container.
*
* @param shutdownTimeout must not be {@literal null} or negative.
* @return {@code this} {@link StreamMessageListenerContainerOptionsBuilder}.
*/
public StreamMessageListenerContainerOptionsBuilder<K, V> shutdownTimeout(Duration shutdownTimeout) {

Assert.notNull(shutdownTimeout, "Shutdown timeout must not be null");
Assert.isTrue(!shutdownTimeout.isNegative(), "Shutdown timeout must not be negative");

this.shutdownTimeout = shutdownTimeout;
return this;
}

/**
* Configure a batch size for the {@code COUNT} option during reading.
*
Expand Down Expand Up @@ -777,7 +804,7 @@ public <NV> StreamMessageListenerContainerOptionsBuilder<K, ObjectRecord<K, NV>>
*/
public StreamMessageListenerContainerOptions<K, V> build() {
return new StreamMessageListenerContainerOptions<>(pollTimeout, batchSize, keySerializer, hashKeySerializer,
hashValueSerializer, targetType, hashMapper, errorHandler, executor);
hashValueSerializer, targetType, hashMapper, errorHandler, executor, shutdownTimeout);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ abstract class AbstractStreamMessageListenerContainerIntegrationTests {
private final RedisConnectionFactory connectionFactory;
private final StringRedisTemplate redisTemplate;
private final StreamMessageListenerContainerOptions<String, MapRecord<String, String, String>> containerOptions = StreamMessageListenerContainerOptions
.builder().pollTimeout(Duration.ofMillis(100)).build();
.builder().pollTimeout(Duration.ofMillis(100)).shutdownTimeout(Duration.ofMillis(2000)).build();

AbstractStreamMessageListenerContainerIntegrationTests(RedisConnectionFactory connectionFactory) {
this.connectionFactory = connectionFactory;
Expand Down Expand Up @@ -383,6 +383,26 @@ void containerRestartShouldRestartSubscription() throws InterruptedException {

cancelAwait(subscription);
}
@Test // GH-2261
void containerShouldStopGracefully() throws InterruptedException {
StreamMessageListenerContainer<String, MapRecord<String, String, String>> container = StreamMessageListenerContainer
.create(connectionFactory, containerOptions);

BlockingQueue<MapRecord<String, String, String>> queue = new LinkedBlockingQueue<>();
container.start();
Subscription subscription = container.receive(StreamOffset.create("my-stream", ReadOffset.from("0-0")), r -> {
try {
Thread.sleep(1500);
} catch (InterruptedException e) {
// ignore
}
queue.add(r);
});
redisTemplate.opsForStream().add("my-stream", Collections.singletonMap("key", "value1"));
subscription.await(DEFAULT_TIMEOUT);
container.stop();
assertThat(queue.poll(500, TimeUnit.MILLISECONDS)).isNotNull();
}

private static void cancelAwait(Subscription subscription) {

Expand Down