Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
package io.numaproj.numaflow.examples.sink.forkjoin;

import io.numaproj.numaflow.sinker.Datum;
import io.numaproj.numaflow.sinker.DatumIterator;
import io.numaproj.numaflow.sinker.Response;
import io.numaproj.numaflow.sinker.ResponseList;
import io.numaproj.numaflow.sinker.Server;
import io.numaproj.numaflow.sinker.Sinker;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;

/**
* ConcurrentSink demonstrates concurrent processing of incoming messages using ThreadPoolExecutor.
* This example shows how to process messages in parallel using a thread pool for
* CPU-intensive operations where parallel processing can improve performance.
*
* Key features:
* - Uses ThreadPoolExecutor for parallel execution
* - Processes each message independently in parallel
* - Demonstrates concurrent data transformation
* - Handles exceptions gracefully in parallel processing
* - Shows how to aggregate results from multiple threads
*/
@Slf4j
public class ConcurrentSink extends Sinker {

private static final int DEFAULT_THREAD_POOL_SIZE = Runtime.getRuntime().availableProcessors();

private final ThreadPoolExecutor threadPool;

public ConcurrentSink() {
this(DEFAULT_THREAD_POOL_SIZE);
}

public ConcurrentSink(int threadPoolSize) {
this.threadPool = new ThreadPoolExecutor(
threadPoolSize,
threadPoolSize,
60L,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>(),
new ThreadFactory() {
private int counter = 0;
@Override
public Thread newThread(Runnable r) {
return new Thread(r, "ConcurrentSink-Worker-" + (++counter));
}
}
);
}

public static void main(String[] args) throws Exception {
ConcurrentSink concurrentSink = new ConcurrentSink();

Server server = new Server(concurrentSink);
server.start();
server.awaitTermination();
server.stop();

concurrentSink.shutdown();
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

invoke shutdown after the server terminates?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added


@Override
public ResponseList processMessages(DatumIterator datumIterator) {
log.info("Starting concurrent processing with thread pool size: {}",
threadPool.getCorePoolSize());

List<Datum> messages = new ArrayList<>();
while (true) {
Datum datum;
try {
datum = datumIterator.next();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
continue;
}
if (datum == null) {
break;
}
messages.add(datum);
}

log.info("Collected {} messages for concurrent processing", messages.size());

if (messages.isEmpty()) {
return ResponseList.newBuilder().build();
}

List<Response> allResponses = processInParallel(messages);

log.info("Completed concurrent processing, generated {} responses", allResponses.size());

ResponseList.ResponseListBuilder responseListBuilder = ResponseList.newBuilder();
for (Response response : allResponses) {
responseListBuilder.addResponse(response);
}

return responseListBuilder.build();
}

/**
* Processes messages in parallel using ThreadPoolExecutor.
* Each message is processed independently in a separate thread.
*/
private List<Response> processInParallel(List<Datum> messages) {
List<Future<Response>> futures = new ArrayList<>();

for (Datum message : messages) {
Future<Response> future = threadPool.submit(new MessageProcessingTask(message));
futures.add(future);
}

List<Response> allResponses = new ArrayList<>();
for (Future<Response> future : futures) {
try {
Response response = future.get(30, TimeUnit.SECONDS);
allResponses.add(response);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.error("Interrupted while waiting for message processing", e);
} catch (ExecutionException e) {
log.error("Error during message processing", e.getCause());
} catch (TimeoutException e) {
log.error("Timeout waiting for message processing", e);
future.cancel(true);
}
}

return allResponses;
}

/**
* Task that processes a single message in a thread.
* This is where the actual CPU-intensive work would be done.
*/
private static class MessageProcessingTask implements Callable<Response> {
private final Datum datum;

public MessageProcessingTask(Datum datum) {
this.datum = datum;
}

@Override
public Response call() {
try {
String message = new String(datum.getValue());
String processedMessage = processMessage(message);

log.debug("Processed message {} -> {}", message, processedMessage);
return Response.responseOK(datum.getId());

} catch (Exception e) {
log.error("Error processing message with ID: {}", datum.getId(), e);
return Response.responseFailure(datum.getId(), e.getMessage());
}
}

/**
* Simulates CPU-intensive message processing.
* In a real-world scenario, this could be data transformation, validation,
* encryption, compression, or any other compute-intensive operation.
*/
private String processMessage(String message) {
StringBuilder processed = new StringBuilder();
processed.append("PROCESSED[")
.append(new StringBuilder(message).reverse())
.append("]-")
.append(Thread.currentThread().getName())
.append("-")
.append(System.currentTimeMillis() % 1000);

for (int i = 0; i < 100; i++) {
Math.sqrt(i * message.hashCode());
}

return processed.toString();
}
}

/**
* Shutdown the thread pool gracefully.
* This should be called when the sink is no longer needed.
*/
public void shutdown() {
log.info("Shutting down concurrent sink thread pool");
threadPool.shutdown();
try {
if (!threadPool.awaitTermination(10, TimeUnit.SECONDS)) {
threadPool.shutdownNow();
}
} catch (InterruptedException e) {
threadPool.shutdownNow();
Thread.currentThread().interrupt();
}
}

/**
* Get current thread pool statistics for monitoring.
*/
public String getThreadPoolStats() {
return String.format("ThreadPool[active=%d, completed=%d, queued=%d, pool=%d/%d]",
threadPool.getActiveCount(),
threadPool.getCompletedTaskCount(),
threadPool.getQueue().size(),
threadPool.getPoolSize(),
threadPool.getMaximumPoolSize());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package io.numaproj.numaflow.examples.sink.forkjoin;

import io.numaproj.numaflow.sinker.Response;
import io.numaproj.numaflow.sinker.ResponseList;
import io.numaproj.numaflow.sinker.SinkerTestKit;
import io.numaproj.numaflow.sinker.Datum;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.runner.RunWith;
import org.mockito.junit.MockitoJUnitRunner;

import static org.mockito.Mockito.mock;


import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.when;

/**
* Comprehensive test suite for ConcurrentSink to verify concurrent processing functionality.
*/
@RunWith(MockitoJUnitRunner.class)
public class ConcurrentSinkTest {

private ConcurrentSink concurrentSink;
private SinkerTestKit.TestListIterator testIterator;

@BeforeEach
void setUp() {
concurrentSink = new ConcurrentSink();
testIterator = new SinkerTestKit.TestListIterator();
}

@AfterEach
void tearDown() {
if (concurrentSink != null) {
concurrentSink.shutdown();
}
}

@Test
@DisplayName("Should process empty message list")
void testEmptyMessageList() {
ResponseList responseList = concurrentSink.processMessages(testIterator);

assertNotNull(responseList);
assertEquals(0, responseList.getResponses().size());
}

@Test
@DisplayName("Should process single message")
void testSingleMessage() {
testIterator.addDatum(createTestDatum("id-1", "test-message"));

ResponseList responseList = concurrentSink.processMessages(testIterator);

assertNotNull(responseList);
assertEquals(1, responseList.getResponses().size());

Response response = responseList.getResponses().get(0);
assertTrue(response.getSuccess());
assertEquals("id-1", response.getId());
}

@Test
@DisplayName("Should process multiple messages")
void testMultipleMessages() {
int messageCount = 15;

for (int i = 0; i < messageCount; i++) {
testIterator.addDatum(createTestDatum("id-" + i, "message-" + i));
}

ResponseList responseList = concurrentSink.processMessages(testIterator);

assertNotNull(responseList);
assertEquals(messageCount, responseList.getResponses().size());

for (Response response : responseList.getResponses()) {
assertTrue(response.getSuccess());
}
}

@Test
@DisplayName("Should handle concurrent processing with custom configuration")
void testCustomConfiguration() {
int threadPoolSize = 2;
ConcurrentSink customSink = new ConcurrentSink(threadPoolSize);

try {
int messageCount = 15;
SinkerTestKit.TestListIterator iterator = new SinkerTestKit.TestListIterator();

for (int i = 0; i < messageCount; i++) {
iterator.addDatum(createTestDatum("id-" + i, "message-" + i));
}

ResponseList responseList = customSink.processMessages(iterator);

assertNotNull(responseList);
assertEquals(messageCount, responseList.getResponses().size());

for (Response response : responseList.getResponses()) {
assertTrue(response.getSuccess());
}

// Verify thread pool stats
String stats = customSink.getThreadPoolStats();
assertNotNull(stats);
assertTrue(stats.contains("ThreadPool"));

} finally {
customSink.shutdown();
}
}

@Test
@DisplayName("Should handle null values gracefully")
void testNullValues() {
testIterator.addDatum(createTestDatum("id-1", null));
testIterator.addDatum(createTestDatum("id-2", "valid-message"));

ResponseList responseList = concurrentSink.processMessages(testIterator);

assertNotNull(responseList);
assertEquals(2, responseList.getResponses().size());

long successCount = responseList.getResponses().stream()
.mapToLong(response -> response.getSuccess() ? 1 : 0)
.sum();
assertEquals(2, successCount);
}

@Test
@DisplayName("Should handle errors gracefully")
void testErrors() {
Datum mockDatum = mock(Datum.class);
testIterator.addDatum(mockDatum);
testIterator.addDatum(mockDatum);

when(mockDatum.getValue()).thenThrow(new RuntimeException("some exception happened"));

ResponseList responseList = concurrentSink.processMessages(testIterator);

assertNotNull(responseList);
assertEquals(2, responseList.getResponses().size());

long errorCount = responseList.getResponses().stream()
.mapToLong(response -> response.getSuccess() ? 0 : 1)
.sum();
assertEquals(2, errorCount);
}

private SinkerTestKit.TestDatum createTestDatum(String id, String value) {
return SinkerTestKit.TestDatum.builder()
.id(id)
.value(value != null ? value.getBytes() : new byte[0])
.build();
}
}
Loading