diff --git a/src/main/java/software/amazon/cloudformation/LambdaWrapper.java b/src/main/java/software/amazon/cloudformation/LambdaWrapper.java index be35beca..502f391c 100644 --- a/src/main/java/software/amazon/cloudformation/LambdaWrapper.java +++ b/src/main/java/software/amazon/cloudformation/LambdaWrapper.java @@ -64,6 +64,7 @@ import software.amazon.cloudformation.proxy.OperationStatus; import software.amazon.cloudformation.proxy.ProgressEvent; import software.amazon.cloudformation.proxy.ResourceHandlerRequest; +import software.amazon.cloudformation.proxy.WaitStrategy; import software.amazon.cloudformation.resource.ResourceTypeSchema; import software.amazon.cloudformation.resource.SchemaValidator; import software.amazon.cloudformation.resource.Serializer; @@ -302,10 +303,9 @@ public void handleRequest(final InputStream inputStream, final OutputStream outp // in a non-AWS model) AmazonWebServicesClientProxy awsClientProxy = null; if (request.getRequestData().getCallerCredentials() != null) { - awsClientProxy = new AmazonWebServicesClientProxy(callbackContext == null, this.loggerProxy, - request.getRequestData().getCallerCredentials(), - () -> (long) context.getRemainingTimeInMillis(), - DelayFactory.CONSTANT_DEFAULT_DELAY_FACTORY); + awsClientProxy = new AmazonWebServicesClientProxy(this.loggerProxy, request.getRequestData().getCallerCredentials(), + DelayFactory.CONSTANT_DEFAULT_DELAY_FACTORY, + WaitStrategy.scheduleForCallbackStrategy()); } ProgressEvent handlerResponse = wrapInvocationAndHandleErrors(awsClientProxy, diff --git a/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java b/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java index 3faf10c6..d07bc429 100644 --- a/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java +++ b/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java @@ -21,13 +21,11 @@ import com.amazonaws.auth.AWSStaticCredentialsProvider; import com.amazonaws.auth.BasicSessionCredentials; import com.google.common.base.Preconditions; -import com.google.common.util.concurrent.Uninterruptibles; import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Objects; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -64,10 +62,9 @@ public class AmazonWebServicesClientProxy implements CallChain { private final AWSCredentialsProvider v1CredentialsProvider; private final AwsCredentialsProvider v2CredentialsProvider; - private final Supplier remainingTimeInMillis; - private final boolean inHandshakeMode; private final LoggerProxy loggerProxy; private final DelayFactory override; + private final WaitStrategy waitStrategy; public AmazonWebServicesClientProxy(final LoggerProxy loggerProxy, final Credentials credentials, @@ -79,18 +76,14 @@ public AmazonWebServicesClientProxy(final LoggerProxy loggerProxy, final Credentials credentials, final Supplier remainingTimeToExecute, final DelayFactory override) { - this(false, loggerProxy, credentials, remainingTimeToExecute, override); + this(loggerProxy, credentials, override, WaitStrategy.newLocalLoopAwaitStrategy(remainingTimeToExecute)); } - public AmazonWebServicesClientProxy(final boolean inHandshakeMode, - final LoggerProxy loggerProxy, + public AmazonWebServicesClientProxy(final LoggerProxy loggerProxy, final Credentials credentials, - final Supplier remainingTimeToExecute, - final DelayFactory override) { - this.inHandshakeMode = inHandshakeMode; + final DelayFactory override, + final WaitStrategy waitStrategy) { this.loggerProxy = loggerProxy; - this.remainingTimeInMillis = remainingTimeToExecute; - BasicSessionCredentials basicSessionCredentials = new BasicSessionCredentials(credentials.getAccessKeyId(), credentials.getSecretAccessKey(), credentials.getSessionToken()); @@ -100,6 +93,7 @@ public AmazonWebServicesClientProxy(final boolean inHandshakeMode, credentials.getSecretAccessKey(), credentials.getSessionToken()); this.v2CredentialsProvider = StaticCredentialsProvider.create(awsSessionCredentials); this.override = Objects.requireNonNull(override); + this.waitStrategy = Objects.requireNonNull(waitStrategy); } public ProxyClient newProxy(@Nonnull Supplier client) { @@ -395,14 +389,6 @@ public ProgressEvent done(Callback done(Callback localWait) { - loggerProxy.log("Waiting for " + next.getSeconds() + " for call " + callGraph); - Uninterruptibles.sleepUninterruptibly(next.getSeconds(), TimeUnit.SECONDS); - continue; + event = AmazonWebServicesClientProxy.this.waitStrategy.await(elapsed, next, context, model); + if (event != null) { + return event; } - return ProgressEvent.defaultInProgressHandler(context, Math.max((int) next.getSeconds(), 60), - model); } } finally { // @@ -455,10 +436,6 @@ public ProgressEvent done(Function> ResultT injectCredentialsAndInvoke(final RequestT request, final Function requestFunction) { diff --git a/src/main/java/software/amazon/cloudformation/proxy/WaitStrategy.java b/src/main/java/software/amazon/cloudformation/proxy/WaitStrategy.java new file mode 100644 index 00000000..ce03448c --- /dev/null +++ b/src/main/java/software/amazon/cloudformation/proxy/WaitStrategy.java @@ -0,0 +1,54 @@ +/* +* Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"). +* You may not use this file except in compliance with the License. +* A copy of the License is located at +* +* http://aws.amazon.com/apache2.0 +* +* or in the "license" file accompanying this file. This file is distributed +* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +* express or implied. See the License for the specific language governing +* permissions and limitations under the License. +*/ +package software.amazon.cloudformation.proxy; + +import com.google.common.util.concurrent.Uninterruptibles; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +public interface WaitStrategy { + + ProgressEvent + await(long operationElapsedTime, Duration nextAttempt, CallbackT context, ModelT model); + + static WaitStrategy newLocalLoopAwaitStrategy(final Supplier remainingTimeToExecute) { + return new WaitStrategy() { + @Override + public + ProgressEvent + await(long operationElapsedTime, Duration next, CallbackT context, ModelT model) { + long remainingTime = remainingTimeToExecute.get(); + long localWait = next.toMillis() + 2 * operationElapsedTime + 100; + if (remainingTime > localWait) { + Uninterruptibles.sleepUninterruptibly(next.getSeconds(), TimeUnit.SECONDS); + return null; + } + return ProgressEvent.defaultInProgressHandler(context, (int) next.getSeconds(), model); + } + }; + } + + static WaitStrategy scheduleForCallbackStrategy() { + return new WaitStrategy() { + @Override + public + ProgressEvent + await(long operationElapsedTime, Duration next, CallbackT context, ModelT model) { + return ProgressEvent.defaultInProgressHandler(context, (int) next.getSeconds(), model); + } + }; + } +} diff --git a/src/test/java/software/amazon/cloudformation/proxy/End2EndCallChainTest.java b/src/test/java/software/amazon/cloudformation/proxy/End2EndCallChainTest.java index facd9f4a..f4fb32b7 100644 --- a/src/test/java/software/amazon/cloudformation/proxy/End2EndCallChainTest.java +++ b/src/test/java/software/amazon/cloudformation/proxy/End2EndCallChainTest.java @@ -333,7 +333,7 @@ public AwsErrorDetails awsErrorDetails() { @Order(30) @Test public void createHandlerThrottleException() throws Exception { - final HandlerRequest request = prepareRequest(Model.builder().repoName("repository").build()); + HandlerRequest request = prepareRequest(Model.builder().repoName("repository").build()); request.setAction(Action.CREATE); final Serializer serializer = new Serializer(); final InputStream stream = prepareStream(serializer, request); @@ -364,6 +364,7 @@ public AwsErrorDetails awsErrorDetails() { }; when(client.describeRepository(eq(describeRequest))).thenThrow(throttleException); + when(client.createRepository(any())).thenReturn(mock(CreateResponse.class)); final SdkHttpClient httpClient = mock(SdkHttpClient.class); final ServiceHandlerWrapper wrapper = new ServiceHandlerWrapper(providerLoggingCredentialsProvider, @@ -371,18 +372,19 @@ public AwsErrorDetails awsErrorDetails() { mock(LogPublisher.class), mock(MetricsPublisher.class), new Validator(), serializer, client, httpClient); - // Bail early for the handshake. Reinvoke handler again - wrapper.handleRequest(stream, output, cxt); - ProgressEvent event = serializer.deserialize(output.toString("UTF8"), - new TypeReference>() { - }); - request.setCallbackContext(event.getCallbackContext()); - output = new ByteArrayOutputStream(2048); - wrapper.handleRequest(prepareStream(serializer, request), output, cxt); - - // Handshake mode 1 try, Throttle retries 4 times (1, 0s), (2, 3s), (3, 6s), (4, - // 9s) - verify(client, times(5)).describeRepository(eq(describeRequest)); + ProgressEvent progress; + do { + output = new ByteArrayOutputStream(2048); + wrapper.handleRequest(prepareStream(serializer, request), output, cxt); + progress = serializer.deserialize(output.toString(StandardCharsets.UTF_8.name()), + new TypeReference>() { + }); + request = prepareRequest(progress.getResourceModel()); + request.setCallbackContext(progress.getCallbackContext()); + } while (progress.isInProgressCallbackDelay()); + + // Throttle retries 4 times (1, 0s), (2, 3s), (3, 6s), (4, 9s) + verify(client, times(4)).describeRepository(eq(describeRequest)); ProgressEvent response = serializer.deserialize(output.toString(StandardCharsets.UTF_8.name()), new TypeReference>() { diff --git a/src/test/java/software/amazon/cloudformation/proxy/service/CreateRequest.java b/src/test/java/software/amazon/cloudformation/proxy/service/CreateRequest.java index ac77b5e3..268e72fd 100644 --- a/src/test/java/software/amazon/cloudformation/proxy/service/CreateRequest.java +++ b/src/test/java/software/amazon/cloudformation/proxy/service/CreateRequest.java @@ -14,18 +14,34 @@ */ package software.amazon.cloudformation.proxy.service; +import java.util.Arrays; import java.util.Collections; import java.util.List; import software.amazon.awssdk.awscore.AwsRequest; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.core.protocol.MarshallLocation; +import software.amazon.awssdk.core.protocol.MarshallingType; +import software.amazon.awssdk.core.traits.LocationTrait; +import software.amazon.awssdk.utils.builder.SdkBuilder; @lombok.Getter @lombok.EqualsAndHashCode(callSuper = false) @lombok.ToString(callSuper = true) public class CreateRequest extends AwsRequest { + private static final SdkField REPO_NAME_FIELD = SdkField.builder(MarshallingType.STRING) + .getter((obj) -> ((CreateRequest) obj).repoName).setter((obj, val) -> ((Builder) obj).repoName(val)) + .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("repoName").build()).build(); + + private static final SdkField USER_NAME_FIELD = SdkField.builder(MarshallingType.STRING) + .getter((obj) -> ((CreateRequest) obj).getUserName()).setter((obj, val) -> ((Builder) obj).repoName(val)) + .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("userName").build()).build(); + + private static final List< + SdkField> SDK_FIELDS = Collections.unmodifiableList(Arrays.asList(REPO_NAME_FIELD, USER_NAME_FIELD)); + private final String repoName; private final String userName; @@ -49,7 +65,7 @@ public List> sdkFields() { @lombok.Getter @lombok.EqualsAndHashCode(callSuper = true) @lombok.ToString(callSuper = true) - public static class Builder extends BuilderImpl implements SdkPojo { + public static class Builder extends BuilderImpl implements SdkPojo, SdkBuilder { private String repoName; private String userName; @@ -76,7 +92,7 @@ public Builder overrideConfiguration(AwsRequestOverrideConfiguration awsRequestO @Override public List> sdkFields() { - return Collections.emptyList(); + return SDK_FIELDS; } } diff --git a/src/test/java/software/amazon/cloudformation/proxy/service/CreateResponse.java b/src/test/java/software/amazon/cloudformation/proxy/service/CreateResponse.java index 76b54a76..c6d36974 100644 --- a/src/test/java/software/amazon/cloudformation/proxy/service/CreateResponse.java +++ b/src/test/java/software/amazon/cloudformation/proxy/service/CreateResponse.java @@ -14,16 +14,32 @@ */ package software.amazon.cloudformation.proxy.service; +import java.util.Arrays; import java.util.Collections; import java.util.List; import software.amazon.awssdk.awscore.AwsResponse; import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.core.protocol.MarshallLocation; +import software.amazon.awssdk.core.protocol.MarshallingType; +import software.amazon.awssdk.core.traits.LocationTrait; +import software.amazon.awssdk.utils.builder.SdkBuilder; @lombok.Getter @lombok.EqualsAndHashCode(callSuper = true) @lombok.ToString public class CreateResponse extends AwsResponse { + + private static final SdkField REPO_NAME_FIELD = SdkField.builder(MarshallingType.STRING) + .getter((obj) -> ((CreateResponse) obj).getRepoName()).setter((obj, val) -> ((CreateResponse.Builder) obj).repoName(val)) + .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("repoName").build()).build(); + + private static final SdkField ERROR_FIELD = SdkField.builder(MarshallingType.STRING) + .getter((obj) -> ((CreateResponse) obj).getError()).setter((obj, val) -> ((Builder) obj).error(val)) + .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("userName").build()).build(); + + private static final List> SDK_FIELDS = Collections.unmodifiableList(Arrays.asList(REPO_NAME_FIELD, ERROR_FIELD)); + private final String repoName; private final String error; @@ -43,7 +59,7 @@ public List> sdkFields() { return Collections.emptyList(); } - public static class Builder extends BuilderImpl implements SdkPojo { + public static class Builder extends BuilderImpl implements SdkPojo, SdkBuilder { private String repoName; private String error;