Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.mongodb.internal.connection;

import com.mongodb.ClusterFixture;
import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoCommandException;
Expand All @@ -41,11 +42,11 @@
import org.bson.Document;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;

import java.io.IOException;
import java.lang.reflect.Field;
Expand Down Expand Up @@ -79,7 +80,6 @@
import static com.mongodb.MongoCredential.TOKEN_RESOURCE_KEY;
import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.testing.MongoAssertions.assertCause;
import static java.lang.Math.min;
import static java.lang.String.format;
import static java.lang.System.getenv;
import static java.util.Arrays.asList;
Expand Down Expand Up @@ -215,9 +215,9 @@ public void test2p1ValidCallbackInputs() {
+ " expectedTimeoutThreshold={3}")
@MethodSource
void testValidCallbackInputsTimeoutWhenTimeoutMsIsSet(final String testName,
final int timeoutMs,
final int serverSelectionTimeoutMS,
final int expectedTimeoutThreshold) {
final long timeoutMs,
final long serverSelectionTimeoutMS,
final long expectedTimeoutThreshold) {
TestCallback callback1 = createCallback();

OidcCallback callback2 = (context) -> {
Expand All @@ -242,40 +242,50 @@ void testValidCallbackInputsTimeoutWhenTimeoutMsIsSet(final String testName,
assertEquals(1, callback1.getInvocations());
long elapsed = msElapsedSince(start);

assertFalse(elapsed > (timeoutMs == 0 ? serverSelectionTimeoutMS : min(serverSelectionTimeoutMS, timeoutMs)),

assertFalse(elapsed > minTimeout(timeoutMs, serverSelectionTimeoutMS),
format("Elapsed time %d is greater then minimum of serverSelectionTimeoutMS and timeoutMs, which is %d. "
+ "This indicates that the callback was not called with the expected timeout.",
min(serverSelectionTimeoutMS, timeoutMs),
elapsed));
elapsed,
minTimeout(timeoutMs, serverSelectionTimeoutMS)));

}
}

private static Stream<Arguments> testValidCallbackInputsTimeoutWhenTimeoutMsIsSet() {
long rtt = ClusterFixture.getPrimaryRTT();
return Stream.of(
Arguments.of("serverSelectionTimeoutMS honored for oidc callback if it's lower than timeoutMS",
1000, // timeoutMS
500, // serverSelectionTimeoutMS
499), // expectedTimeoutThreshold
1000 + rtt, // timeoutMS
500 + rtt, // serverSelectionTimeoutMS
499 + rtt), // expectedTimeoutThreshold
Arguments.of("timeoutMS honored for oidc callback if it's lower than serverSelectionTimeoutMS",
500, // timeoutMS
1000, // serverSelectionTimeoutMS
499), // expectedTimeoutThreshold
500 + rtt, // timeoutMS
1000 + rtt, // serverSelectionTimeoutMS
499 + rtt), // expectedTimeoutThreshold
Arguments.of("timeoutMS honored for oidc callback if serverSelectionTimeoutMS is infinite",
500 + rtt, // timeoutMS
-1, // serverSelectionTimeoutMS
499 + rtt), // expectedTimeoutThreshold,
Arguments.of("serverSelectionTimeoutMS honored for oidc callback if timeoutMS=0",
0, // infinite timeoutMS
500, // serverSelectionTimeoutMS
499) // expectedTimeoutThreshold
500 + rtt, // serverSelectionTimeoutMS
499 + rtt) // expectedTimeoutThreshold
);
}

// Not a prose test
@ParameterizedTest(name = "test callback timeout when server selection timeout is "
+ "infinite and timeoutMs is set to {0}")
@ValueSource(ints = {0, 100})
void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet(final int timeoutMs) {
@Test
@DisplayName("test callback timeout when serverSelectionTimeoutMS and timeoutMS are infinite")
void testCallbackTimeoutWhenServerSelectionTimeoutMsIsInfiniteTimeoutMsIsSet() {
TestCallback callback1 = createCallback();
Duration expectedTimeout = ChronoUnit.FOREVER.getDuration();

OidcCallback callback2 = (context) -> {
assertEquals(context.getTimeout(), ChronoUnit.FOREVER.getDuration());
assertEquals(expectedTimeout, context.getTimeout(),
format("Expected timeout to be infinite (%s), but was %s",
expectedTimeout, context.getTimeout()));

return callback1.onRequest(context);
};

Expand All @@ -284,7 +294,7 @@ void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet(final
builder.serverSelectionTimeout(
-1, // -1 means infinite
TimeUnit.MILLISECONDS))
.timeout(timeoutMs, TimeUnit.MILLISECONDS)
.timeout(0, TimeUnit.MILLISECONDS)
.build();

try (MongoClient mongoClient = createMongoClient(clientSettings)) {
Expand Down Expand Up @@ -1242,4 +1252,10 @@ public TestCallback createHumanCallback() {
private long msElapsedSince(final long timeOfStart) {
return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - timeOfStart);
}

private static long minTimeout(final long timeoutMs, final long serverSelectionTimeoutMS) {
long timeoutMsEffective = timeoutMs != 0 ? timeoutMs : Long.MAX_VALUE;
long serverSelectionTimeoutMSEffective = serverSelectionTimeoutMS != -1 ? serverSelectionTimeoutMS : Long.MAX_VALUE;
return Math.min(timeoutMsEffective, serverSelectionTimeoutMSEffective);
}
}