Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import software.amazon.cloudformation.proxy.OperationStatus;
import software.amazon.cloudformation.proxy.ProgressEvent;
import software.amazon.cloudformation.proxy.ResourceHandlerRequest;
import software.amazon.cloudformation.proxy.WaitStrategy;
import software.amazon.cloudformation.resource.ResourceTypeSchema;
import software.amazon.cloudformation.resource.SchemaValidator;
import software.amazon.cloudformation.resource.Serializer;
Expand Down Expand Up @@ -302,10 +303,9 @@ public void handleRequest(final InputStream inputStream, final OutputStream outp
// in a non-AWS model)
AmazonWebServicesClientProxy awsClientProxy = null;
if (request.getRequestData().getCallerCredentials() != null) {
awsClientProxy = new AmazonWebServicesClientProxy(callbackContext == null, this.loggerProxy,
request.getRequestData().getCallerCredentials(),
() -> (long) context.getRemainingTimeInMillis(),
DelayFactory.CONSTANT_DEFAULT_DELAY_FACTORY);
awsClientProxy = new AmazonWebServicesClientProxy(this.loggerProxy, request.getRequestData().getCallerCredentials(),
DelayFactory.CONSTANT_DEFAULT_DELAY_FACTORY,
WaitStrategy.scheduleForCallbackStrategy());
}

ProgressEvent<ResourceT, CallbackT> handlerResponse = wrapInvocationAndHandleErrors(awsClientProxy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicSessionCredentials;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.Uninterruptibles;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
Expand Down Expand Up @@ -64,10 +62,9 @@ public class AmazonWebServicesClientProxy implements CallChain {

private final AWSCredentialsProvider v1CredentialsProvider;
private final AwsCredentialsProvider v2CredentialsProvider;
private final Supplier<Long> remainingTimeInMillis;
private final boolean inHandshakeMode;
private final LoggerProxy loggerProxy;
private final DelayFactory override;
private final WaitStrategy waitStrategy;

public AmazonWebServicesClientProxy(final LoggerProxy loggerProxy,
final Credentials credentials,
Expand All @@ -79,18 +76,14 @@ public AmazonWebServicesClientProxy(final LoggerProxy loggerProxy,
final Credentials credentials,
final Supplier<Long> remainingTimeToExecute,
final DelayFactory override) {
this(false, loggerProxy, credentials, remainingTimeToExecute, override);
this(loggerProxy, credentials, override, WaitStrategy.newLocalLoopAwaitStrategy(remainingTimeToExecute));
}

public AmazonWebServicesClientProxy(final boolean inHandshakeMode,
final LoggerProxy loggerProxy,
public AmazonWebServicesClientProxy(final LoggerProxy loggerProxy,
final Credentials credentials,
final Supplier<Long> remainingTimeToExecute,
final DelayFactory override) {
this.inHandshakeMode = inHandshakeMode;
final DelayFactory override,
final WaitStrategy waitStrategy) {
this.loggerProxy = loggerProxy;
this.remainingTimeInMillis = remainingTimeToExecute;

BasicSessionCredentials basicSessionCredentials = new BasicSessionCredentials(credentials.getAccessKeyId(),
credentials.getSecretAccessKey(),
credentials.getSessionToken());
Expand All @@ -100,6 +93,7 @@ public AmazonWebServicesClientProxy(final boolean inHandshakeMode,
credentials.getSecretAccessKey(), credentials.getSessionToken());
this.v2CredentialsProvider = StaticCredentialsProvider.create(awsSessionCredentials);
this.override = Objects.requireNonNull(override);
this.waitStrategy = Objects.requireNonNull(waitStrategy);
}

public <ClientT> ProxyClient<ClientT> newProxy(@Nonnull Supplier<ClientT> client) {
Expand Down Expand Up @@ -395,14 +389,6 @@ public ProgressEvent<ModelT, CallbackT> done(Callback<RequestT, ResponseT, Clien
event = exceptionHandler.invoke(req, e, client, model, context);
}

if (event != null && (event.isFailed() || event.isSuccess())) {
return event;
}

if (inHandshakeMode) {
return ProgressEvent.defaultInProgressHandler(context, 60, model);
}

if (event != null) {
return event;
}
Expand All @@ -422,15 +408,10 @@ public ProgressEvent<ModelT, CallbackT> done(Callback<RequestT, ResponseT, Clien
return ProgressEvent.failed(model, context, HandlerErrorCode.NotStabilized,
"Exceeded attempts to wait");
}
long remainingTime = getRemainingTimeInMillis();
long localWait = next.toMillis() + 2 * elapsed + 100;
if (remainingTime > localWait) {
loggerProxy.log("Waiting for " + next.getSeconds() + " for call " + callGraph);
Uninterruptibles.sleepUninterruptibly(next.getSeconds(), TimeUnit.SECONDS);
continue;
event = AmazonWebServicesClientProxy.this.waitStrategy.await(elapsed, next, context, model);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Would be good to use a different variable as it's about whether to continue.

final Optional<ProgressEvent> completeLocalExecution = AmazonWebServicesClientProxy.this.waitStrategy.await(elapsed, next, context, model);
if (completeLocalExecution.present()) {
 return completeLocalExecution.get();
}

if (event != null) {
return event;
}
return ProgressEvent.defaultInProgressHandler(context, Math.max((int) next.getSeconds(), 60),
model);
}
} finally {
//
Expand All @@ -455,10 +436,6 @@ public ProgressEvent<ModelT, CallbackT> done(Function<ResponseT, ProgressEvent<M

}

public final long getRemainingTimeInMillis() {
return remainingTimeInMillis.get();
}

public <RequestT extends AmazonWebServiceRequest, ResultT extends AmazonWebServiceResult<ResponseMetadata>>
ResultT
injectCredentialsAndInvoke(final RequestT request, final Function<RequestT, ResultT> requestFunction) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package software.amazon.cloudformation.proxy;

import com.google.common.util.concurrent.Uninterruptibles;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

public interface WaitStrategy {
<ModelT, CallbackT>
ProgressEvent<ModelT, CallbackT>
await(long operationElapsedTime, Duration nextAttempt, CallbackT context, ModelT model);

static WaitStrategy newLocalLoopAwaitStrategy(final Supplier<Long> remainingTimeToExecute) {
return new WaitStrategy() {
@Override
public <ModelT, CallbackT>
ProgressEvent<ModelT, CallbackT>
await(long operationElapsedTime, Duration next, CallbackT context, ModelT model) {
long remainingTime = remainingTimeToExecute.get();
long localWait = next.toMillis() + 2 * operationElapsedTime + 100;
if (remainingTime > localWait) {
Uninterruptibles.sleepUninterruptibly(next.getSeconds(), TimeUnit.SECONDS);
return null;
}
return ProgressEvent.defaultInProgressHandler(context, (int) next.getSeconds(), model);
}
};
}

static WaitStrategy scheduleForCallbackStrategy() {
return new WaitStrategy() {
@Override
public <ModelT, CallbackT>
ProgressEvent<ModelT, CallbackT>
await(long operationElapsedTime, Duration next, CallbackT context, ModelT model) {
return ProgressEvent.defaultInProgressHandler(context, (int) next.getSeconds(), model);
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ public AwsErrorDetails awsErrorDetails() {
@Order(30)
@Test
public void createHandlerThrottleException() throws Exception {
final HandlerRequest<Model, StdCallbackContext> request = prepareRequest(Model.builder().repoName("repository").build());
HandlerRequest<Model, StdCallbackContext> request = prepareRequest(Model.builder().repoName("repository").build());
request.setAction(Action.CREATE);
final Serializer serializer = new Serializer();
final InputStream stream = prepareStream(serializer, request);
Expand Down Expand Up @@ -364,25 +364,27 @@ public AwsErrorDetails awsErrorDetails() {

};
when(client.describeRepository(eq(describeRequest))).thenThrow(throttleException);
when(client.createRepository(any())).thenReturn(mock(CreateResponse.class));

final SdkHttpClient httpClient = mock(SdkHttpClient.class);
final ServiceHandlerWrapper wrapper = new ServiceHandlerWrapper(providerLoggingCredentialsProvider,
mock(CloudWatchLogPublisher.class),
mock(LogPublisher.class), mock(MetricsPublisher.class),
new Validator(), serializer, client, httpClient);

// Bail early for the handshake. Reinvoke handler again
wrapper.handleRequest(stream, output, cxt);
ProgressEvent<Model, StdCallbackContext> event = serializer.deserialize(output.toString("UTF8"),
new TypeReference<ProgressEvent<Model, StdCallbackContext>>() {
});
request.setCallbackContext(event.getCallbackContext());
output = new ByteArrayOutputStream(2048);
wrapper.handleRequest(prepareStream(serializer, request), output, cxt);

// Handshake mode 1 try, Throttle retries 4 times (1, 0s), (2, 3s), (3, 6s), (4,
// 9s)
verify(client, times(5)).describeRepository(eq(describeRequest));
ProgressEvent<Model, StdCallbackContext> progress;
do {
output = new ByteArrayOutputStream(2048);
wrapper.handleRequest(prepareStream(serializer, request), output, cxt);
progress = serializer.deserialize(output.toString(StandardCharsets.UTF_8.name()),
new TypeReference<ProgressEvent<Model, StdCallbackContext>>() {
});
request = prepareRequest(progress.getResourceModel());
request.setCallbackContext(progress.getCallbackContext());
} while (progress.isInProgressCallbackDelay());

// Throttle retries 4 times (1, 0s), (2, 3s), (3, 6s), (4, 9s)
verify(client, times(4)).describeRepository(eq(describeRequest));

ProgressEvent<Model, StdCallbackContext> response = serializer.deserialize(output.toString(StandardCharsets.UTF_8.name()),
new TypeReference<ProgressEvent<Model, StdCallbackContext>>() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,34 @@
*/
package software.amazon.cloudformation.proxy.service;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import software.amazon.awssdk.awscore.AwsRequest;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.SdkPojo;
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.protocol.MarshallingType;
import software.amazon.awssdk.core.traits.LocationTrait;
import software.amazon.awssdk.utils.builder.SdkBuilder;

@lombok.Getter
@lombok.EqualsAndHashCode(callSuper = false)
@lombok.ToString(callSuper = true)
public class CreateRequest extends AwsRequest {

private static final SdkField<String> REPO_NAME_FIELD = SdkField.<String>builder(MarshallingType.STRING)
.getter((obj) -> ((CreateRequest) obj).repoName).setter((obj, val) -> ((Builder) obj).repoName(val))
.traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("repoName").build()).build();

private static final SdkField<String> USER_NAME_FIELD = SdkField.<String>builder(MarshallingType.STRING)
.getter((obj) -> ((CreateRequest) obj).getUserName()).setter((obj, val) -> ((Builder) obj).repoName(val))
.traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("userName").build()).build();

private static final List<
SdkField<?>> SDK_FIELDS = Collections.unmodifiableList(Arrays.asList(REPO_NAME_FIELD, USER_NAME_FIELD));

private final String repoName;
private final String userName;

Expand All @@ -49,7 +65,7 @@ public List<SdkField<?>> sdkFields() {
@lombok.Getter
@lombok.EqualsAndHashCode(callSuper = true)
@lombok.ToString(callSuper = true)
public static class Builder extends BuilderImpl implements SdkPojo {
public static class Builder extends BuilderImpl implements SdkPojo, SdkBuilder<Builder, CreateRequest> {
private String repoName;
private String userName;

Expand All @@ -76,7 +92,7 @@ public Builder overrideConfiguration(AwsRequestOverrideConfiguration awsRequestO

@Override
public List<SdkField<?>> sdkFields() {
return Collections.emptyList();
return SDK_FIELDS;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,32 @@
*/
package software.amazon.cloudformation.proxy.service;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import software.amazon.awssdk.awscore.AwsResponse;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.SdkPojo;
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.protocol.MarshallingType;
import software.amazon.awssdk.core.traits.LocationTrait;
import software.amazon.awssdk.utils.builder.SdkBuilder;

@lombok.Getter
@lombok.EqualsAndHashCode(callSuper = true)
@lombok.ToString
public class CreateResponse extends AwsResponse {

private static final SdkField<String> REPO_NAME_FIELD = SdkField.<String>builder(MarshallingType.STRING)
.getter((obj) -> ((CreateResponse) obj).getRepoName()).setter((obj, val) -> ((CreateResponse.Builder) obj).repoName(val))
.traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("repoName").build()).build();

private static final SdkField<String> ERROR_FIELD = SdkField.<String>builder(MarshallingType.STRING)
.getter((obj) -> ((CreateResponse) obj).getError()).setter((obj, val) -> ((Builder) obj).error(val))
.traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("userName").build()).build();

private static final List<SdkField<?>> SDK_FIELDS = Collections.unmodifiableList(Arrays.asList(REPO_NAME_FIELD, ERROR_FIELD));

private final String repoName;
private final String error;

Expand All @@ -43,7 +59,7 @@ public List<SdkField<?>> sdkFields() {
return Collections.emptyList();
}

public static class Builder extends BuilderImpl implements SdkPojo {
public static class Builder extends BuilderImpl implements SdkPojo, SdkBuilder<Builder, CreateResponse> {
private String repoName;
private String error;

Expand Down