From e86730f95ef14b697f6cfaa6a7984d5fbd42185a Mon Sep 17 00:00:00 2001 From: diwakar Date: Fri, 24 Apr 2020 12:53:15 -0700 Subject: [PATCH 1/3] Allow for automated call graph generation to keep contexts unique within StdCallbackContext for replay deduping. Now developers do not need to provide a name when calling services. There is a unique call graph context maintained for each service call made. The context is request aware, so different requests made will have their own independent context for dedupe Dedupe identical requests ```java ProgressEvent result = initiator.translateToServiceRequest(m -> createRepository) .makeServiceCall((r, c) -> c.injectCredentialsAndInvokeV2( r, c.client()::createRepository)) .success(); ProgressEvent result_2 = // make same request call initiator.translateToServiceRequest(m -> createRepository) .makeServiceCall((r, c) -> c.injectCredentialsAndInvokeV2( r, c.client()::createRepository)) .success(); assertThat(result).isEqualsTo(result_2); ``` --- .../proxy/AmazonWebServicesClientProxy.java | 31 ++++ .../cloudformation/proxy/CallChain.java | 7 +- .../proxy/CallGraphNameGenerator.java | 22 +++ .../proxy/StdCallbackContext.java | 37 +++++ .../AmazonWebServicesClientProxyTest.java | 142 ++++++++++++++++-- .../proxy/service/ServiceClient.java | 9 +- 6 files changed, 233 insertions(+), 15 deletions(-) create mode 100644 src/main/java/software/amazon/cloudformation/proxy/CallGraphNameGenerator.java diff --git a/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java b/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java index f4c46404..265e5185 100644 --- a/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java +++ b/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java @@ -40,6 +40,7 @@ import software.amazon.awssdk.awscore.AwsResponse; import software.amazon.awssdk.awscore.exception.AwsErrorDetails; import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.SdkClient; import software.amazon.awssdk.core.exception.NonRetryableException; import software.amazon.awssdk.core.exception.RetryableException; import software.amazon.awssdk.core.pagination.sync.SdkIterable; @@ -169,6 +170,12 @@ public RequestMaker initiate(String callGraph) { return new CallContext<>(callGraph, client, model, callback); } + @Override + public < + RequestT> Caller translateToServiceRequest(Function maker) { + return initiate("").translateToServiceRequest(maker); + } + @Override public ModelT getResourceModel() { return model; @@ -192,6 +199,11 @@ public Initiator rebindModel(NewModel Preconditions.checkNotNull(callback, "cxt can not be null"); return new StdInitiator<>(client, model, callback); } + + @Override + public Logger getLogger() { + return AmazonWebServicesClientProxy.this.loggerProxy; + } } @Override @@ -234,6 +246,22 @@ class CallContext RequestT> Caller translateToServiceRequest(Function maker) { return new Caller() { + private final CallGraphNameGenerator generator = (incoming, model_, reqMaker, client_, context_) -> { + final RequestT request = reqMaker.apply(model_); + String objectHash = String.valueOf(Objects.hashCode(request)); + String serviceName = (client_ == null + ? "" + : (client_ instanceof SdkClient) + ? ((SdkClient) client_).serviceName() + : client_.getClass().getSimpleName()); + String requestName = request != null ? request.getClass().getSimpleName().replace("Request", "") : ""; + String callGraph = serviceName + ":" + requestName + "-" + (incoming != null ? incoming : "") + "-" + + objectHash; + context_.request(callGraph, (ignored -> request)).apply(model_); + return callGraph; + }; + @Override public Caller backoffDelay(Delay delay) { CallContext.this.delay = delay; @@ -315,6 +343,8 @@ public ProgressEvent done(Callback reqMaker = context.request(callGraph, maker); BiFunction, ResponseT> resMaker = context.response(callGraph, caller); @@ -377,6 +407,7 @@ public ProgressEvent done(Callback localWait) { + loggerProxy.log("Waiting for " + next.getSeconds() + " for call " + callGraph); Uninterruptibles.sleepUninterruptibly(next.getSeconds(), TimeUnit.SECONDS); continue; } diff --git a/src/main/java/software/amazon/cloudformation/proxy/CallChain.java b/src/main/java/software/amazon/cloudformation/proxy/CallChain.java index c85f2f66..dde56006 100644 --- a/src/main/java/software/amazon/cloudformation/proxy/CallChain.java +++ b/src/main/java/software/amazon/cloudformation/proxy/CallChain.java @@ -53,7 +53,7 @@ public interface CallChain { * @param the model object being worked on * @param the callback context */ - interface Initiator { + interface Initiator extends RequestMaker { /** * Each service call must be first initiated. Every call is provided a separate * name called call graph. This is essential from both a tracing perspective as @@ -74,6 +74,11 @@ interface Initiator { */ CallbackT getCallbackContext(); + /** + * @return logger associated to log messages + */ + Logger getLogger(); + /** * Can rebind a new model to the call chain while retaining the client and * callback context diff --git a/src/main/java/software/amazon/cloudformation/proxy/CallGraphNameGenerator.java b/src/main/java/software/amazon/cloudformation/proxy/CallGraphNameGenerator.java new file mode 100644 index 00000000..87aff452 --- /dev/null +++ b/src/main/java/software/amazon/cloudformation/proxy/CallGraphNameGenerator.java @@ -0,0 +1,22 @@ +/* +* Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"). +* You may not use this file except in compliance with the License. +* A copy of the License is located at +* +* http://aws.amazon.com/apache2.0 +* +* or in the "license" file accompanying this file. This file is distributed +* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +* express or implied. See the License for the specific language governing +* permissions and limitations under the License. +*/ +package software.amazon.cloudformation.proxy; + +import java.util.function.Function; + +@FunctionalInterface +public interface CallGraphNameGenerator { + String callGraph(String incoming, ModelT model, Function reqMaker, ClientT client, CallbackT context); +} diff --git a/src/main/java/software/amazon/cloudformation/proxy/StdCallbackContext.java b/src/main/java/software/amazon/cloudformation/proxy/StdCallbackContext.java index 316637d8..8dab2f99 100644 --- a/src/main/java/software/amazon/cloudformation/proxy/StdCallbackContext.java +++ b/src/main/java/software/amazon/cloudformation/proxy/StdCallbackContext.java @@ -32,9 +32,13 @@ import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; import javax.annotation.concurrent.ThreadSafe; /** @@ -272,6 +276,39 @@ public ResponseT response(String callGraph) { return (ResponseT) callGraphs.get(callGraph + ".response"); } + @SuppressWarnings("unchecked") + public RequestT findFirstRequestByContains(String contains) { + return (RequestT) findFirst((key) -> key.contains(contains) && key.endsWith(".request")); + } + + @SuppressWarnings("unchecked") + public List findAllRequestByContains(String contains) { + return (List) findAll((key) -> key.contains(contains) && key.endsWith(".request")); + } + + @SuppressWarnings("unchecked") + public ResponseT findFirstResponseByContains(String contains) { + return (ResponseT) findFirst((key) -> key.contains(contains) && key.endsWith(".response")); + } + + @SuppressWarnings("unchecked") + public List findAllResponseByContains(String contains) { + return (List) findAll((key) -> key.contains(contains) && key.endsWith(".response")); + } + + Object findFirst(Predicate contains) { + Objects.requireNonNull(contains); + return callGraphs.entrySet().stream().filter(e -> contains.test(e.getKey())).findFirst().map(Map.Entry::getValue) + .orElse(null); + + } + + List findAll(Predicate contains) { + Objects.requireNonNull(contains); + return callGraphs.entrySet().stream().filter(e -> contains.test(e.getKey())).map(Map.Entry::getValue) + .collect(Collectors.toList()); + } + CallChain.Callback stabilize(String callGraph, CallChain.Callback callback) { diff --git a/src/test/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxyTest.java b/src/test/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxyTest.java index e09a2b80..701afaea 100644 --- a/src/test/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxyTest.java +++ b/src/test/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxyTest.java @@ -16,12 +16,12 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.amazonaws.auth.AWSStaticCredentialsProvider; @@ -29,7 +29,9 @@ import com.amazonaws.services.cloudformation.model.DescribeStackEventsRequest; import com.amazonaws.services.cloudformation.model.DescribeStackEventsResult; import java.time.Duration; +import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -243,7 +245,8 @@ public void badRequest() { final SdkHttpResponse sdkHttpResponse = mock(SdkHttpResponse.class); when(sdkHttpResponse.statusCode()).thenReturn(400); final ProgressEvent result = proxy.initiate("client:createRespository", proxy.newProxy(() -> null), model, context) + StdCallbackContext> result = proxy + .initiate("client:createRespository", proxy.newProxy(() -> mock(ServiceClient.class)), model, context) .translateToServiceRequest(m -> new CreateRequest.Builder().repoName(m.getRepoName()).build()) .makeServiceCall((r, c) -> { throw new BadRequestException(mock(AwsServiceException.Builder.class)) { @@ -273,7 +276,8 @@ public void notFound() { final SdkHttpResponse sdkHttpResponse = mock(SdkHttpResponse.class); when(sdkHttpResponse.statusCode()).thenReturn(404); ProgressEvent result = proxy.initiate("client:createRespository", proxy.newProxy(() -> null), model, context) + StdCallbackContext> result = proxy + .initiate("client:createRespository", proxy.newProxy(() -> mock(ServiceClient.class)), model, context) .translateToServiceRequest(m -> new CreateRequest.Builder().repoName(m.getRepoName()).build()) .makeServiceCall((r, c) -> { throw new NotFoundException(mock(AwsServiceException.Builder.class)) { @@ -303,7 +307,8 @@ public void accessDenied() { final SdkHttpResponse sdkHttpResponse = mock(SdkHttpResponse.class); when(sdkHttpResponse.statusCode()).thenReturn(401); ProgressEvent result = proxy.initiate("client:createRespository", proxy.newProxy(() -> null), model, context) + StdCallbackContext> result = proxy + .initiate("client:createRespository", proxy.newProxy(() -> mock(ServiceClient.class)), model, context) .translateToServiceRequest(m -> new CreateRequest.Builder().repoName(m.getRepoName()).build()) .makeServiceCall((r, c) -> { throw new AccessDenied(AwsServiceException.builder()) { @@ -395,15 +400,21 @@ public AwsErrorDetails awsErrorDetails() { assertThat(resultModel.getArn()).isNotNull(); assertThat(resultModel.getCreated()).isNotNull(); - Map callGraphs = context.callGraphs(); - assertThat(callGraphs.containsKey("client:createRepository.request")).isEqualTo(true); - assertSame(requests[0], callGraphs.get("client:createRepository.request")); - assertThat(callGraphs.containsKey("client:createRepository.response")).isEqualTo(true); - assertSame(responses[0], callGraphs.get("client:createRepository.response")); - assertThat(callGraphs.containsKey("client:readRepository.request")).isEqualTo(true); - assertSame(describeRequests[0], callGraphs.get("client:readRepository.request")); - assertThat(callGraphs.containsKey("client:readRepository.response")).isEqualTo(true); - assertSame(describeResponses[0], callGraphs.get("client:readRepository.response")); + Object objToCmp = context.findFirstRequestByContains("client:createRepository"); + assertThat(objToCmp).isNotNull(); + assertThat(requests[0]).isSameAs(objToCmp); + + objToCmp = context.findFirstResponseByContains("client:createRepository"); + assertThat(objToCmp).isNotNull(); + assertThat(responses[0]).isSameAs(objToCmp); + + objToCmp = context.findFirstRequestByContains("client:readRepository"); + assertThat(objToCmp).isNotNull(); + assertThat(describeRequests[0]).isSameAs(objToCmp); + + objToCmp = context.findFirstResponseByContains("client:readRepository"); + assertThat(objToCmp).isNotNull(); + assertThat(describeResponses[0]).isSameAs(objToCmp); } @Test @@ -590,6 +601,111 @@ public void thenChainPattern() { return ProgressEvent.success(model, context); }); assertThat(event.isFailed()).isTrue(); + } + + @Test + public void automaticNamedRequests() { + AmazonWebServicesClientProxy proxy = new AmazonWebServicesClientProxy(mock(LoggerProxy.class), MOCK, + () -> Duration.ofSeconds(1).toMillis()); + final String repoName = "NewRepo"; + final Model model = new Model(); + model.setRepoName(repoName); + final StdCallbackContext context = new StdCallbackContext(); + // + // Mock calls + // + final ServiceClient client = mock(ServiceClient.class); + when(client.createRepository(any(CreateRequest.class))) + .thenReturn(new CreateResponse.Builder().repoName(model.getRepoName()).build()); + when(client.serviceName()).thenReturn("repositoryService"); + + CallChain.Initiator initiator = proxy.newInitiator(() -> client, model, context); + + final CreateRequest createRepository = new CreateRequest.Builder().repoName(repoName).build(); + ProgressEvent result = initiator.translateToServiceRequest(m -> createRepository) + .makeServiceCall((r, c) -> c.injectCredentialsAndInvokeV2(r, c.client()::createRepository)).success(); + + ProgressEvent result_2 = initiator.translateToServiceRequest(m -> createRepository) + .makeServiceCall((r, c) -> c.injectCredentialsAndInvokeV2(r, c.client()::createRepository)).success(); + + assertThat(result).isNotSameAs(result_2); + assertThat(result_2).isEqualTo(result); + + assertThat(result).isNotNull(); + assertThat(result.isSuccess()).isTrue(); + CreateRequest internal = context.findFirstRequestByContains("repositoryService:Create"); + assertThat(internal).isNotNull(); + assertThat(internal).isSameAs(createRepository); + + Map callGraphs = context.callGraphs(); + assertThat(callGraphs.size()).isEqualTo(3); + // verify this was called only once for both requests. + verify(client).createRepository(any(CreateRequest.class)); + } + + @Test + public void automaticNamedUniqueRequests() { + AmazonWebServicesClientProxy proxy = new AmazonWebServicesClientProxy(mock(LoggerProxy.class), MOCK, + () -> Duration.ofSeconds(1).toMillis()); + final String repoName = "NewRepo"; + final Model model = new Model(); + model.setRepoName(repoName); + final StdCallbackContext context = new StdCallbackContext(); + // + // TODO add the mocks needed + // + final ServiceClient client = mock(ServiceClient.class); + when(client.createRepository(any(CreateRequest.class))) + .thenAnswer(invocation -> new CreateResponse.Builder().repoName(model.getRepoName()).build()); + when(client.serviceName()).thenReturn("repositoryService"); + + final CallChain.Initiator initiator = proxy.newInitiator(() -> client, model, context); + + final CreateRequest createRepository = new CreateRequest.Builder().repoName(repoName).build(); + ProgressEvent result = initiator.translateToServiceRequest(m -> createRepository) + .makeServiceCall((r, c) -> c.injectCredentialsAndInvokeV2(r, c.client()::createRepository)).success(); + + model.setRepoName(repoName + "-2"); + ProgressEvent result_2 = initiator.rebindModel(Model.builder().repoName(repoName + "-2").build()) + .translateToServiceRequest(m -> new CreateRequest.Builder().repoName(model.getRepoName()).build()) + .makeServiceCall((r, c) -> c.injectCredentialsAndInvokeV2(r, c.client()::createRepository)).success(); + model.setRepoName(repoName); + + assertThat(result).isNotEqualTo(result_2); + CreateRequest internal = context.findFirstRequestByContains("repositoryService:Create"); + assertThat(internal).isNotNull(); + assertThat(internal).isSameAs(createRepository); // we picked the one with the first call + + List responses = context.findAllResponseByContains("repositoryService:Create"); + assertThat(responses.size()).isEqualTo(2); + List expected = Arrays.asList(new CreateResponse.Builder().repoName(repoName).build(), + new CreateResponse.Builder().repoName(repoName + "-2").build()); + assertThat(responses).isEqualTo(expected); + + verify(client, times(2)).createRepository(any(CreateRequest.class)); + } + + @Test + public void nullRequestTest() { + AmazonWebServicesClientProxy proxy = new AmazonWebServicesClientProxy(mock(LoggerProxy.class), MOCK, + () -> Duration.ofSeconds(1).toMillis()); + final String repoName = "NewRepo"; + final Model model = new Model(); + model.setRepoName(repoName); + final StdCallbackContext context = new StdCallbackContext(); + // + // Mock calls + // + final ServiceClient client = mock(ServiceClient.class); + final CallChain.Initiator initiator = proxy.newInitiator(() -> client, model, context); + ProgressEvent result = initiator.translateToServiceRequest(m -> (CreateRequest) null) + .makeServiceCall((r, c) -> c.injectCredentialsAndInvokeV2(r, c.client()::createRepository)).success(); + + assertThat(result).isNotNull(); } } diff --git a/src/test/java/software/amazon/cloudformation/proxy/service/ServiceClient.java b/src/test/java/software/amazon/cloudformation/proxy/service/ServiceClient.java index f146a9b0..b1290419 100644 --- a/src/test/java/software/amazon/cloudformation/proxy/service/ServiceClient.java +++ b/src/test/java/software/amazon/cloudformation/proxy/service/ServiceClient.java @@ -14,7 +14,14 @@ */ package software.amazon.cloudformation.proxy.service; -public interface ServiceClient { +import software.amazon.awssdk.core.SdkClient; + +public interface ServiceClient extends SdkClient { + + default String serviceName() { + return "serviceClient"; + } + CreateResponse createRepository(CreateRequest r); DescribeResponse describeRepository(DescribeRequest r); From 547f3dace30e88617f7d236863e016b494b3521d Mon Sep 17 00:00:00 2001 From: diwakar Date: Tue, 28 Apr 2020 14:22:05 -0700 Subject: [PATCH 2/3] Prevent ConcurrentModificationExceptions in stabilize calls if they access the map and attempt to modify it --- .../proxy/StdCallbackContext.java | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/main/java/software/amazon/cloudformation/proxy/StdCallbackContext.java b/src/main/java/software/amazon/cloudformation/proxy/StdCallbackContext.java index 8dab2f99..c555ee19 100644 --- a/src/main/java/software/amazon/cloudformation/proxy/StdCallbackContext.java +++ b/src/main/java/software/amazon/cloudformation/proxy/StdCallbackContext.java @@ -313,9 +313,30 @@ List findAll(Predicate contains) { CallChain.Callback stabilize(String callGraph, CallChain.Callback callback) { return (request1, response1, client, model, context) -> { - Boolean result = (Boolean) callGraphs.computeIfAbsent(callGraph + ".stabilize", - (ign) -> callback.invoke(request1, response1, client, model, context) ? Boolean.TRUE : null); - return result != null ? Boolean.TRUE : Boolean.FALSE; + final String key = callGraph + ".stabilize"; + Boolean result = (Boolean) callGraphs.getOrDefault(key, Boolean.FALSE); + if (!result) { + // + // The StdCallbackContext can be shared. However the call to stabilize for a + // given content + // is usually confined to one thread. If for some reason we spread that across + // threads, the + // worst that can happen is a double compute for stabilize. This isn't the + // intended pattern. + // Why are we changing it from computeIfAbsent pattern? For the callback we send + // in the + // StdCallbackContext which can be used to add things into context. That will + // lead to + // ConcurrentModificationExceptions when the compute running added things into + // context when + // needed + // + result = callback.invoke(request1, response1, client, model, context); + if (result) { + callGraphs.put(key, Boolean.TRUE); + } + } + return result; }; } From b106098ee22355a1cd091591b3d3a32f05772098 Mon Sep 17 00:00:00 2001 From: diwakar Date: Mon, 6 Jul 2020 22:12:30 -0700 Subject: [PATCH 3/3] - Remove 60s hard coded callback - Remove handshake mode - Adding wait strategy that can be changed, defaulting to re-schedules with no local waits for Lambda binding - changes unit tests to reflect removal of handshake mode, fixed unit tests for serialization errors --- .../amazon/cloudformation/LambdaWrapper.java | 8 +-- .../proxy/AmazonWebServicesClientProxy.java | 41 ++++---------- .../cloudformation/proxy/WaitStrategy.java | 54 +++++++++++++++++++ .../proxy/End2EndCallChainTest.java | 28 +++++----- .../proxy/service/CreateRequest.java | 20 ++++++- .../proxy/service/CreateResponse.java | 18 ++++++- 6 files changed, 117 insertions(+), 52 deletions(-) create mode 100644 src/main/java/software/amazon/cloudformation/proxy/WaitStrategy.java diff --git a/src/main/java/software/amazon/cloudformation/LambdaWrapper.java b/src/main/java/software/amazon/cloudformation/LambdaWrapper.java index c8bfdcad..d45601a4 100644 --- a/src/main/java/software/amazon/cloudformation/LambdaWrapper.java +++ b/src/main/java/software/amazon/cloudformation/LambdaWrapper.java @@ -64,6 +64,7 @@ import software.amazon.cloudformation.proxy.OperationStatus; import software.amazon.cloudformation.proxy.ProgressEvent; import software.amazon.cloudformation.proxy.ResourceHandlerRequest; +import software.amazon.cloudformation.proxy.WaitStrategy; import software.amazon.cloudformation.resource.ResourceTypeSchema; import software.amazon.cloudformation.resource.SchemaValidator; import software.amazon.cloudformation.resource.Serializer; @@ -303,10 +304,9 @@ public void handleRequest(final InputStream inputStream, final OutputStream outp // in a non-AWS model) AmazonWebServicesClientProxy awsClientProxy = null; if (request.getRequestData().getCallerCredentials() != null) { - awsClientProxy = new AmazonWebServicesClientProxy(callbackContext == null, this.loggerProxy, - request.getRequestData().getCallerCredentials(), - () -> (long) context.getRemainingTimeInMillis(), - DelayFactory.CONSTANT_DEFAULT_DELAY_FACTORY); + awsClientProxy = new AmazonWebServicesClientProxy(this.loggerProxy, request.getRequestData().getCallerCredentials(), + DelayFactory.CONSTANT_DEFAULT_DELAY_FACTORY, + WaitStrategy.scheduleForCallbackStrategy()); } ProgressEvent handlerResponse = wrapInvocationAndHandleErrors(awsClientProxy, diff --git a/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java b/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java index 3faf10c6..d07bc429 100644 --- a/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java +++ b/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java @@ -21,13 +21,11 @@ import com.amazonaws.auth.AWSStaticCredentialsProvider; import com.amazonaws.auth.BasicSessionCredentials; import com.google.common.base.Preconditions; -import com.google.common.util.concurrent.Uninterruptibles; import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Objects; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -64,10 +62,9 @@ public class AmazonWebServicesClientProxy implements CallChain { private final AWSCredentialsProvider v1CredentialsProvider; private final AwsCredentialsProvider v2CredentialsProvider; - private final Supplier remainingTimeInMillis; - private final boolean inHandshakeMode; private final LoggerProxy loggerProxy; private final DelayFactory override; + private final WaitStrategy waitStrategy; public AmazonWebServicesClientProxy(final LoggerProxy loggerProxy, final Credentials credentials, @@ -79,18 +76,14 @@ public AmazonWebServicesClientProxy(final LoggerProxy loggerProxy, final Credentials credentials, final Supplier remainingTimeToExecute, final DelayFactory override) { - this(false, loggerProxy, credentials, remainingTimeToExecute, override); + this(loggerProxy, credentials, override, WaitStrategy.newLocalLoopAwaitStrategy(remainingTimeToExecute)); } - public AmazonWebServicesClientProxy(final boolean inHandshakeMode, - final LoggerProxy loggerProxy, + public AmazonWebServicesClientProxy(final LoggerProxy loggerProxy, final Credentials credentials, - final Supplier remainingTimeToExecute, - final DelayFactory override) { - this.inHandshakeMode = inHandshakeMode; + final DelayFactory override, + final WaitStrategy waitStrategy) { this.loggerProxy = loggerProxy; - this.remainingTimeInMillis = remainingTimeToExecute; - BasicSessionCredentials basicSessionCredentials = new BasicSessionCredentials(credentials.getAccessKeyId(), credentials.getSecretAccessKey(), credentials.getSessionToken()); @@ -100,6 +93,7 @@ public AmazonWebServicesClientProxy(final boolean inHandshakeMode, credentials.getSecretAccessKey(), credentials.getSessionToken()); this.v2CredentialsProvider = StaticCredentialsProvider.create(awsSessionCredentials); this.override = Objects.requireNonNull(override); + this.waitStrategy = Objects.requireNonNull(waitStrategy); } public ProxyClient newProxy(@Nonnull Supplier client) { @@ -395,14 +389,6 @@ public ProgressEvent done(Callback done(Callback localWait) { - loggerProxy.log("Waiting for " + next.getSeconds() + " for call " + callGraph); - Uninterruptibles.sleepUninterruptibly(next.getSeconds(), TimeUnit.SECONDS); - continue; + event = AmazonWebServicesClientProxy.this.waitStrategy.await(elapsed, next, context, model); + if (event != null) { + return event; } - return ProgressEvent.defaultInProgressHandler(context, Math.max((int) next.getSeconds(), 60), - model); } } finally { // @@ -455,10 +436,6 @@ public ProgressEvent done(Function> ResultT injectCredentialsAndInvoke(final RequestT request, final Function requestFunction) { diff --git a/src/main/java/software/amazon/cloudformation/proxy/WaitStrategy.java b/src/main/java/software/amazon/cloudformation/proxy/WaitStrategy.java new file mode 100644 index 00000000..ce03448c --- /dev/null +++ b/src/main/java/software/amazon/cloudformation/proxy/WaitStrategy.java @@ -0,0 +1,54 @@ +/* +* Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"). +* You may not use this file except in compliance with the License. +* A copy of the License is located at +* +* http://aws.amazon.com/apache2.0 +* +* or in the "license" file accompanying this file. This file is distributed +* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +* express or implied. See the License for the specific language governing +* permissions and limitations under the License. +*/ +package software.amazon.cloudformation.proxy; + +import com.google.common.util.concurrent.Uninterruptibles; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +public interface WaitStrategy { + + ProgressEvent + await(long operationElapsedTime, Duration nextAttempt, CallbackT context, ModelT model); + + static WaitStrategy newLocalLoopAwaitStrategy(final Supplier remainingTimeToExecute) { + return new WaitStrategy() { + @Override + public + ProgressEvent + await(long operationElapsedTime, Duration next, CallbackT context, ModelT model) { + long remainingTime = remainingTimeToExecute.get(); + long localWait = next.toMillis() + 2 * operationElapsedTime + 100; + if (remainingTime > localWait) { + Uninterruptibles.sleepUninterruptibly(next.getSeconds(), TimeUnit.SECONDS); + return null; + } + return ProgressEvent.defaultInProgressHandler(context, (int) next.getSeconds(), model); + } + }; + } + + static WaitStrategy scheduleForCallbackStrategy() { + return new WaitStrategy() { + @Override + public + ProgressEvent + await(long operationElapsedTime, Duration next, CallbackT context, ModelT model) { + return ProgressEvent.defaultInProgressHandler(context, (int) next.getSeconds(), model); + } + }; + } +} diff --git a/src/test/java/software/amazon/cloudformation/proxy/End2EndCallChainTest.java b/src/test/java/software/amazon/cloudformation/proxy/End2EndCallChainTest.java index facd9f4a..f4fb32b7 100644 --- a/src/test/java/software/amazon/cloudformation/proxy/End2EndCallChainTest.java +++ b/src/test/java/software/amazon/cloudformation/proxy/End2EndCallChainTest.java @@ -333,7 +333,7 @@ public AwsErrorDetails awsErrorDetails() { @Order(30) @Test public void createHandlerThrottleException() throws Exception { - final HandlerRequest request = prepareRequest(Model.builder().repoName("repository").build()); + HandlerRequest request = prepareRequest(Model.builder().repoName("repository").build()); request.setAction(Action.CREATE); final Serializer serializer = new Serializer(); final InputStream stream = prepareStream(serializer, request); @@ -364,6 +364,7 @@ public AwsErrorDetails awsErrorDetails() { }; when(client.describeRepository(eq(describeRequest))).thenThrow(throttleException); + when(client.createRepository(any())).thenReturn(mock(CreateResponse.class)); final SdkHttpClient httpClient = mock(SdkHttpClient.class); final ServiceHandlerWrapper wrapper = new ServiceHandlerWrapper(providerLoggingCredentialsProvider, @@ -371,18 +372,19 @@ public AwsErrorDetails awsErrorDetails() { mock(LogPublisher.class), mock(MetricsPublisher.class), new Validator(), serializer, client, httpClient); - // Bail early for the handshake. Reinvoke handler again - wrapper.handleRequest(stream, output, cxt); - ProgressEvent event = serializer.deserialize(output.toString("UTF8"), - new TypeReference>() { - }); - request.setCallbackContext(event.getCallbackContext()); - output = new ByteArrayOutputStream(2048); - wrapper.handleRequest(prepareStream(serializer, request), output, cxt); - - // Handshake mode 1 try, Throttle retries 4 times (1, 0s), (2, 3s), (3, 6s), (4, - // 9s) - verify(client, times(5)).describeRepository(eq(describeRequest)); + ProgressEvent progress; + do { + output = new ByteArrayOutputStream(2048); + wrapper.handleRequest(prepareStream(serializer, request), output, cxt); + progress = serializer.deserialize(output.toString(StandardCharsets.UTF_8.name()), + new TypeReference>() { + }); + request = prepareRequest(progress.getResourceModel()); + request.setCallbackContext(progress.getCallbackContext()); + } while (progress.isInProgressCallbackDelay()); + + // Throttle retries 4 times (1, 0s), (2, 3s), (3, 6s), (4, 9s) + verify(client, times(4)).describeRepository(eq(describeRequest)); ProgressEvent response = serializer.deserialize(output.toString(StandardCharsets.UTF_8.name()), new TypeReference>() { diff --git a/src/test/java/software/amazon/cloudformation/proxy/service/CreateRequest.java b/src/test/java/software/amazon/cloudformation/proxy/service/CreateRequest.java index ac77b5e3..268e72fd 100644 --- a/src/test/java/software/amazon/cloudformation/proxy/service/CreateRequest.java +++ b/src/test/java/software/amazon/cloudformation/proxy/service/CreateRequest.java @@ -14,18 +14,34 @@ */ package software.amazon.cloudformation.proxy.service; +import java.util.Arrays; import java.util.Collections; import java.util.List; import software.amazon.awssdk.awscore.AwsRequest; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.core.protocol.MarshallLocation; +import software.amazon.awssdk.core.protocol.MarshallingType; +import software.amazon.awssdk.core.traits.LocationTrait; +import software.amazon.awssdk.utils.builder.SdkBuilder; @lombok.Getter @lombok.EqualsAndHashCode(callSuper = false) @lombok.ToString(callSuper = true) public class CreateRequest extends AwsRequest { + private static final SdkField REPO_NAME_FIELD = SdkField.builder(MarshallingType.STRING) + .getter((obj) -> ((CreateRequest) obj).repoName).setter((obj, val) -> ((Builder) obj).repoName(val)) + .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("repoName").build()).build(); + + private static final SdkField USER_NAME_FIELD = SdkField.builder(MarshallingType.STRING) + .getter((obj) -> ((CreateRequest) obj).getUserName()).setter((obj, val) -> ((Builder) obj).repoName(val)) + .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("userName").build()).build(); + + private static final List< + SdkField> SDK_FIELDS = Collections.unmodifiableList(Arrays.asList(REPO_NAME_FIELD, USER_NAME_FIELD)); + private final String repoName; private final String userName; @@ -49,7 +65,7 @@ public List> sdkFields() { @lombok.Getter @lombok.EqualsAndHashCode(callSuper = true) @lombok.ToString(callSuper = true) - public static class Builder extends BuilderImpl implements SdkPojo { + public static class Builder extends BuilderImpl implements SdkPojo, SdkBuilder { private String repoName; private String userName; @@ -76,7 +92,7 @@ public Builder overrideConfiguration(AwsRequestOverrideConfiguration awsRequestO @Override public List> sdkFields() { - return Collections.emptyList(); + return SDK_FIELDS; } } diff --git a/src/test/java/software/amazon/cloudformation/proxy/service/CreateResponse.java b/src/test/java/software/amazon/cloudformation/proxy/service/CreateResponse.java index 76b54a76..c6d36974 100644 --- a/src/test/java/software/amazon/cloudformation/proxy/service/CreateResponse.java +++ b/src/test/java/software/amazon/cloudformation/proxy/service/CreateResponse.java @@ -14,16 +14,32 @@ */ package software.amazon.cloudformation.proxy.service; +import java.util.Arrays; import java.util.Collections; import java.util.List; import software.amazon.awssdk.awscore.AwsResponse; import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.core.protocol.MarshallLocation; +import software.amazon.awssdk.core.protocol.MarshallingType; +import software.amazon.awssdk.core.traits.LocationTrait; +import software.amazon.awssdk.utils.builder.SdkBuilder; @lombok.Getter @lombok.EqualsAndHashCode(callSuper = true) @lombok.ToString public class CreateResponse extends AwsResponse { + + private static final SdkField REPO_NAME_FIELD = SdkField.builder(MarshallingType.STRING) + .getter((obj) -> ((CreateResponse) obj).getRepoName()).setter((obj, val) -> ((CreateResponse.Builder) obj).repoName(val)) + .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("repoName").build()).build(); + + private static final SdkField ERROR_FIELD = SdkField.builder(MarshallingType.STRING) + .getter((obj) -> ((CreateResponse) obj).getError()).setter((obj, val) -> ((Builder) obj).error(val)) + .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("userName").build()).build(); + + private static final List> SDK_FIELDS = Collections.unmodifiableList(Arrays.asList(REPO_NAME_FIELD, ERROR_FIELD)); + private final String repoName; private final String error; @@ -43,7 +59,7 @@ public List> sdkFields() { return Collections.emptyList(); } - public static class Builder extends BuilderImpl implements SdkPojo { + public static class Builder extends BuilderImpl implements SdkPojo, SdkBuilder { private String repoName; private String error;