DurableContext.java
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package software.amazon.lambda.durable;
import com.amazonaws.services.lambda.runtime.Context;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.util.HexFormat;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;
import org.slf4j.LoggerFactory;
import software.amazon.lambda.durable.execution.ExecutionManager;
import software.amazon.lambda.durable.execution.ThreadType;
import software.amazon.lambda.durable.logging.DurableLogger;
import software.amazon.lambda.durable.model.OperationSubType;
import software.amazon.lambda.durable.operation.CallbackOperation;
import software.amazon.lambda.durable.operation.ChildContextOperation;
import software.amazon.lambda.durable.operation.InvokeOperation;
import software.amazon.lambda.durable.operation.StepOperation;
import software.amazon.lambda.durable.operation.WaitOperation;
import software.amazon.lambda.durable.validation.ParameterValidator;
/**
 * Context handed to durable handler code for starting durable operations: steps, waits, chained invokes, callbacks,
 * child contexts and {@code waitForCallback}.
 *
 * <p>Every operation started through this context gets an ID derived from a per-context sequential counter (see
 * {@link #nextOperationId()}). The ID of an operation therefore depends only on how many operations were started
 * before it in the same context.
 */
public class DurableContext extends BaseContext {
    /** Suffix appended to a waitForCallback name to form the name of its callback operation. */
    private static final String WAIT_FOR_CALLBACK_CALLBACK_SUFFIX = "-callback";
    /** Suffix appended to a waitForCallback name to form the name of its submitter step. */
    private static final String WAIT_FOR_CALLBACK_SUBMITTER_SUFFIX = "-submitter";
    /**
     * Maximum waitForCallback name length: the regular operation-name limit minus the longest suffix, so the derived
     * callback/submitter operation names stay within {@link ParameterValidator#MAX_OPERATION_NAME_LENGTH}.
     */
    private static final int MAX_WAIT_FOR_CALLBACK_NAME_LENGTH = ParameterValidator.MAX_OPERATION_NAME_LENGTH
            - Math.max(WAIT_FOR_CALLBACK_CALLBACK_SUFFIX.length(), WAIT_FOR_CALLBACK_SUBMITTER_SUFFIX.length());

    /** Sequential counter feeding {@link #nextOperationId()}; starts at 0, so the first raw ID is "1". */
    private final AtomicInteger operationCounter;
    /** Created lazily by {@link #getLogger()}; volatile for its double-checked initialization. */
    private volatile DurableLogger logger;

    /** Shared initialization — sets all fields. */
    private DurableContext(
            ExecutionManager executionManager,
            DurableConfig durableConfig,
            Context lambdaContext,
            String contextId,
            String contextName) {
        super(executionManager, durableConfig, lambdaContext, contextId, contextName, ThreadType.CONTEXT);
        this.operationCounter = new AtomicInteger(0);
    }

    /**
     * Creates a root context, i.e. one whose contextId and contextName are both {@code null}.
     *
     * @param executionManager the execution manager
     * @param durableConfig the durable configuration
     * @param lambdaContext the Lambda context
     * @return a new root DurableContext
     */
    public static DurableContext createRootContext(
            ExecutionManager executionManager, DurableConfig durableConfig, Context lambdaContext) {
        return new DurableContext(executionManager, durableConfig, lambdaContext, null, null);
    }

    /**
     * Creates a child context sharing this context's execution manager, durable config and Lambda context.
     *
     * @param childContextId the child context's ID (the CONTEXT operation's operation ID)
     * @param childContextName the child context's name
     * @return a new DurableContext for the child context
     */
    public DurableContext createChildContext(String childContextId, String childContextName) {
        return new DurableContext(
                getExecutionManager(), getDurableConfig(), getLambdaContext(), childContextId, childContextName);
    }

    /**
     * Creates a step context for executing step operations.
     *
     * @param stepOperationId the ID of the step operation (used for thread registration)
     * @param stepOperationName the name of the step operation
     * @param attempt the attempt number handed to the new StepContext
     * @return a new StepContext instance
     */
    public StepContext createStepContext(String stepOperationId, String stepOperationName, int attempt) {
        return new StepContext(
                getExecutionManager(),
                getDurableConfig(),
                getLambdaContext(),
                stepOperationId,
                stepOperationName,
                attempt);
    }

    // ========== step methods ==========

    /**
     * Runs a step with a default {@link StepConfig} and blocks on its result.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> T step(String name, Class<T> resultType, Function<StepContext, T> func) {
        return step(name, TypeToken.get(resultType), func, StepConfig.builder().build());
    }

    /**
     * Runs a step with the given config and blocks on its result.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> T step(String name, Class<T> resultType, Function<StepContext, T> func, StepConfig config) {
        return stepAsync(name, resultType, func, config).get();
    }

    /**
     * Runs a step with a default {@link StepConfig} and blocks on its result.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> T step(String name, TypeToken<T> typeToken, Function<StepContext, T> func) {
        return step(name, typeToken, func, StepConfig.builder().build());
    }

    /**
     * Runs a step with the given config and blocks on its result.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> T step(String name, TypeToken<T> typeToken, Function<StepContext, T> func, StepConfig config) {
        return stepAsync(name, typeToken, func, config).get();
    }

    /**
     * Starts a step with a default {@link StepConfig} without blocking.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> DurableFuture<T> stepAsync(String name, Class<T> resultType, Function<StepContext, T> func) {
        return stepAsync(name, TypeToken.get(resultType), func, StepConfig.builder().build());
    }

    /**
     * Starts a step with the given config without blocking.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> DurableFuture<T> stepAsync(
            String name, Class<T> resultType, Function<StepContext, T> func, StepConfig config) {
        return stepAsync(name, TypeToken.get(resultType), func, config);
    }

    /**
     * Starts a step with a default {@link StepConfig} without blocking.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> DurableFuture<T> stepAsync(String name, TypeToken<T> typeToken, Function<StepContext, T> func) {
        return stepAsync(name, typeToken, func, StepConfig.builder().build());
    }

    /**
     * Starts a step and returns immediately. Every other step/stepAsync overload funnels into this method.
     *
     * <p>If {@code config} has no SerDes, the durable config's SerDes is used.
     *
     * @param name the operation name (validated by {@link ParameterValidator#validateOperationName})
     * @param typeToken the result type
     * @param func the step body, given a {@link StepContext}
     * @param config the step configuration
     * @return a future for the step's result
     * @throws NullPointerException if {@code config} or {@code typeToken} is null
     */
    public <T> DurableFuture<T> stepAsync(
            String name, TypeToken<T> typeToken, Function<StepContext, T> func, StepConfig config) {
        Objects.requireNonNull(config, "config cannot be null");
        Objects.requireNonNull(typeToken, "typeToken cannot be null");
        ParameterValidator.validateOperationName(name);
        if (config.serDes() == null) {
            config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build();
        }
        var operationId = nextOperationId();
        var operation = new StepOperation<>(operationId, name, func, typeToken, config, this);
        operation.execute(); // starts the step and returns immediately
        return operation;
    }

    /**
     * Runs a {@link Supplier}-based step with a default {@link StepConfig} and blocks on its result.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> T step(String name, Class<T> resultType, Supplier<T> func) {
        return step(name, TypeToken.get(resultType), func, StepConfig.builder().build());
    }

    /**
     * Runs a {@link Supplier}-based step with the given config and blocks on its result.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> T step(String name, Class<T> resultType, Supplier<T> func, StepConfig config) {
        return stepAsync(name, resultType, func, config).get();
    }

    /**
     * Runs a {@link Supplier}-based step with a default {@link StepConfig} and blocks on its result.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> T step(String name, TypeToken<T> typeToken, Supplier<T> func) {
        return step(name, typeToken, func, StepConfig.builder().build());
    }

    /**
     * Runs a {@link Supplier}-based step with the given config and blocks on its result.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> T step(String name, TypeToken<T> typeToken, Supplier<T> func, StepConfig config) {
        return stepAsync(name, typeToken, func, config).get();
    }

    /**
     * Starts a {@link Supplier}-based step with a default {@link StepConfig} without blocking.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> DurableFuture<T> stepAsync(String name, Class<T> resultType, Supplier<T> func) {
        return stepAsync(name, TypeToken.get(resultType), func, StepConfig.builder().build());
    }

    /**
     * Starts a {@link Supplier}-based step with the given config without blocking.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> DurableFuture<T> stepAsync(String name, Class<T> resultType, Supplier<T> func, StepConfig config) {
        return stepAsync(name, TypeToken.get(resultType), func, config);
    }

    /**
     * Starts a {@link Supplier}-based step with a default {@link StepConfig} without blocking.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> DurableFuture<T> stepAsync(String name, TypeToken<T> typeToken, Supplier<T> func) {
        return stepAsync(name, typeToken, func, StepConfig.builder().build());
    }

    /**
     * Starts a {@link Supplier}-based step; the supplier is adapted to a function that ignores its StepContext.
     *
     * @see #stepAsync(String, TypeToken, Function, StepConfig)
     */
    public <T> DurableFuture<T> stepAsync(String name, TypeToken<T> typeToken, Supplier<T> func, StepConfig config) {
        return stepAsync(name, typeToken, stepContext -> func.get(), config);
    }

    // ========== wait methods ==========

    /**
     * Starts a durable wait and blocks on it. Blocking may throw SuspendExecutionException when there is no active
     * thread.
     *
     * @see #waitAsync(String, Duration)
     */
    public Void wait(String name, Duration duration) {
        return waitAsync(name, duration).get();
    }

    /**
     * Starts (checkpoints) a durable wait and returns immediately.
     *
     * @param name the operation name
     * @param duration the wait duration (validated by {@link ParameterValidator#validateDuration})
     * @return a future that completes when the wait is over
     */
    public DurableFuture<Void> waitAsync(String name, Duration duration) {
        ParameterValidator.validateDuration(duration, "Wait duration");
        ParameterValidator.validateOperationName(name);
        var operationId = nextOperationId();
        var operation = new WaitOperation(operationId, name, duration, this);
        operation.execute(); // checkpoints the wait
        return operation;
    }

    // ========== chained invoke methods ==========

    /**
     * Invokes another function with a default {@link InvokeConfig} and blocks on its result.
     *
     * @see #invokeAsync(String, String, Object, TypeToken, InvokeConfig)
     */
    public <T, U> T invoke(String name, String functionName, U payload, Class<T> resultType) {
        return invokeAsync(name, functionName, payload, resultType, InvokeConfig.builder().build())
                .get();
    }

    /**
     * Invokes another function with the given config and blocks on its result.
     *
     * @see #invokeAsync(String, String, Object, TypeToken, InvokeConfig)
     */
    public <T, U> T invoke(String name, String functionName, U payload, Class<T> resultType, InvokeConfig config) {
        return invokeAsync(name, functionName, payload, TypeToken.get(resultType), config)
                .get();
    }

    /**
     * Invokes another function with a default {@link InvokeConfig} and blocks on its result.
     *
     * @see #invokeAsync(String, String, Object, TypeToken, InvokeConfig)
     */
    public <T, U> T invoke(String name, String functionName, U payload, TypeToken<T> typeToken) {
        return invokeAsync(name, functionName, payload, typeToken, InvokeConfig.builder().build())
                .get();
    }

    /**
     * Invokes another function with the given config and blocks on its result.
     *
     * @see #invokeAsync(String, String, Object, TypeToken, InvokeConfig)
     */
    public <T, U> T invoke(String name, String functionName, U payload, TypeToken<T> typeToken, InvokeConfig config) {
        return invokeAsync(name, functionName, payload, typeToken, config).get();
    }

    /**
     * Starts a chained invoke with the given config without blocking.
     *
     * @see #invokeAsync(String, String, Object, TypeToken, InvokeConfig)
     */
    public <T, U> DurableFuture<T> invokeAsync(
            String name, String functionName, U payload, Class<T> resultType, InvokeConfig config) {
        return invokeAsync(name, functionName, payload, TypeToken.get(resultType), config);
    }

    /**
     * Starts a chained invoke with a default {@link InvokeConfig} without blocking.
     *
     * @see #invokeAsync(String, String, Object, TypeToken, InvokeConfig)
     */
    public <T, U> DurableFuture<T> invokeAsync(String name, String functionName, U payload, Class<T> resultType) {
        return invokeAsync(name, functionName, payload, TypeToken.get(resultType), InvokeConfig.builder().build());
    }

    /**
     * Starts a chained invoke with a default {@link InvokeConfig} without blocking.
     *
     * @see #invokeAsync(String, String, Object, TypeToken, InvokeConfig)
     */
    public <T, U> DurableFuture<T> invokeAsync(String name, String functionName, U payload, TypeToken<T> resultType) {
        return invokeAsync(name, functionName, payload, resultType, InvokeConfig.builder().build());
    }

    /**
     * Starts (checkpoints) a chained invoke of {@code functionName} and returns immediately. Every other
     * invoke/invokeAsync overload funnels into this method.
     *
     * <p>If {@code config} lacks a result SerDes or a payload SerDes, the durable config's SerDes fills each gap.
     *
     * @param name the operation name
     * @param functionName the function to invoke
     * @param payload the payload sent to the invoked function
     * @param typeToken the result type
     * @param config the invoke configuration
     * @return a future for the invoke's result
     * @throws NullPointerException if {@code config} or {@code typeToken} is null
     */
    public <T, U> DurableFuture<T> invokeAsync(
            String name, String functionName, U payload, TypeToken<T> typeToken, InvokeConfig config) {
        Objects.requireNonNull(config, "config cannot be null");
        Objects.requireNonNull(typeToken, "typeToken cannot be null");
        ParameterValidator.validateOperationName(name);
        if (config.serDes() == null) {
            config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build();
        }
        if (config.payloadSerDes() == null) {
            config = config.toBuilder()
                    .payloadSerDes(getDurableConfig().getSerDes())
                    .build();
        }
        var operationId = nextOperationId();
        var operation = new InvokeOperation<>(operationId, name, functionName, payload, typeToken, config, this);
        operation.execute(); // checkpoints the invoke operation
        // Does not block here; the invoke(...) overloads block via get().
        return operation;
    }

    // ========== createCallback methods ==========

    /**
     * Creates a callback with the given config.
     *
     * @see #createCallback(String, TypeToken, CallbackConfig)
     */
    public <T> DurableCallbackFuture<T> createCallback(String name, Class<T> resultType, CallbackConfig config) {
        return createCallback(name, TypeToken.get(resultType), config);
    }

    /**
     * Creates a callback with a default {@link CallbackConfig}.
     *
     * @see #createCallback(String, TypeToken, CallbackConfig)
     */
    public <T> DurableCallbackFuture<T> createCallback(String name, TypeToken<T> typeToken) {
        return createCallback(name, typeToken, CallbackConfig.builder().build());
    }

    /**
     * Creates a callback with a default {@link CallbackConfig}.
     *
     * @see #createCallback(String, TypeToken, CallbackConfig)
     */
    public <T> DurableCallbackFuture<T> createCallback(String name, Class<T> resultType) {
        return createCallback(name, resultType, CallbackConfig.builder().build());
    }

    /**
     * Creates and starts a callback operation. Every other createCallback overload funnels into this method.
     *
     * <p>If {@code config} has no SerDes, the durable config's SerDes is used.
     *
     * @param name the operation name
     * @param typeToken the callback result type
     * @param config the callback configuration
     * @return the callback future (exposes the callback ID and the eventual result)
     * @throws NullPointerException if {@code config} or {@code typeToken} is null
     */
    public <T> DurableCallbackFuture<T> createCallback(String name, TypeToken<T> typeToken, CallbackConfig config) {
        // Fail fast with a clear message, consistent with stepAsync/invokeAsync.
        Objects.requireNonNull(config, "config cannot be null");
        Objects.requireNonNull(typeToken, "typeToken cannot be null");
        ParameterValidator.validateOperationName(name);
        if (config.serDes() == null) {
            config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build();
        }
        var operationId = nextOperationId();
        var operation = new CallbackOperation<>(operationId, name, typeToken, config, this);
        operation.execute();
        return operation;
    }

    // ========== runInChildContext methods ==========

    /**
     * Runs {@code func} in a child context and blocks on its result.
     *
     * @see #runInChildContextAsync(String, TypeToken, Function)
     */
    public <T> T runInChildContext(String name, Class<T> resultType, Function<DurableContext, T> func) {
        return runInChildContextAsync(name, TypeToken.get(resultType), func).get();
    }

    /**
     * Runs {@code func} in a child context and blocks on its result.
     *
     * @see #runInChildContextAsync(String, TypeToken, Function)
     */
    public <T> T runInChildContext(String name, TypeToken<T> typeToken, Function<DurableContext, T> func) {
        return runInChildContextAsync(name, typeToken, func).get();
    }

    /**
     * Starts {@code func} in a child context without blocking.
     *
     * @see #runInChildContextAsync(String, TypeToken, Function)
     */
    public <T> DurableFuture<T> runInChildContextAsync(
            String name, Class<T> resultType, Function<DurableContext, T> func) {
        return runInChildContextAsync(name, TypeToken.get(resultType), func);
    }

    /**
     * Starts {@code func} in a child context (sub-type {@link OperationSubType#RUN_IN_CHILD_CONTEXT}) without
     * blocking. The result is serialized with the durable config's SerDes.
     *
     * @param name the operation name
     * @param typeToken the result type
     * @param func the body, given the child {@link DurableContext}
     * @return a future for the child context's result
     * @throws NullPointerException if {@code typeToken} is null
     */
    public <T> DurableFuture<T> runInChildContextAsync(
            String name, TypeToken<T> typeToken, Function<DurableContext, T> func) {
        return runInChildContextAsync(name, typeToken, func, OperationSubType.RUN_IN_CHILD_CONTEXT);
    }

    /** Shared implementation for child-context operations; {@code subType} tags the CONTEXT operation. */
    private <T> DurableFuture<T> runInChildContextAsync(
            String name, TypeToken<T> typeToken, Function<DurableContext, T> func, OperationSubType subType) {
        Objects.requireNonNull(typeToken, "typeToken cannot be null");
        ParameterValidator.validateOperationName(name);
        var operationId = nextOperationId();
        var operation = new ChildContextOperation<>(
                operationId, name, func, subType, typeToken, getDurableConfig().getSerDes(), this);
        operation.execute();
        return operation;
    }

    // ========= waitForCallback methods =============

    /**
     * Creates a callback, submits its ID via {@code func}, and blocks on the callback result. Uses a default
     * {@link WaitForCallbackConfig}.
     *
     * @see #waitForCallbackAsync(String, TypeToken, BiConsumer, WaitForCallbackConfig)
     */
    public <T> T waitForCallback(String name, Class<T> resultType, BiConsumer<String, StepContext> func) {
        return waitForCallbackAsync(
                        name,
                        TypeToken.get(resultType),
                        func,
                        WaitForCallbackConfig.builder().build())
                .get();
    }

    /**
     * Blocking waitForCallback with a default {@link WaitForCallbackConfig}.
     *
     * @see #waitForCallbackAsync(String, TypeToken, BiConsumer, WaitForCallbackConfig)
     */
    public <T> T waitForCallback(String name, TypeToken<T> typeToken, BiConsumer<String, StepContext> func) {
        return waitForCallbackAsync(name, typeToken, func, WaitForCallbackConfig.builder().build())
                .get();
    }

    /**
     * Blocking waitForCallback with the given config.
     *
     * @see #waitForCallbackAsync(String, TypeToken, BiConsumer, WaitForCallbackConfig)
     */
    public <T> T waitForCallback(
            String name,
            Class<T> resultType,
            BiConsumer<String, StepContext> func,
            WaitForCallbackConfig waitForCallbackConfig) {
        return waitForCallbackAsync(name, TypeToken.get(resultType), func, waitForCallbackConfig)
                .get();
    }

    /**
     * Blocking waitForCallback with the given config.
     *
     * @see #waitForCallbackAsync(String, TypeToken, BiConsumer, WaitForCallbackConfig)
     */
    public <T> T waitForCallback(
            String name,
            TypeToken<T> typeToken,
            BiConsumer<String, StepContext> func,
            WaitForCallbackConfig waitForCallbackConfig) {
        return waitForCallbackAsync(name, typeToken, func, waitForCallbackConfig)
                .get();
    }

    /**
     * Non-blocking waitForCallback with a default {@link WaitForCallbackConfig}.
     *
     * @see #waitForCallbackAsync(String, TypeToken, BiConsumer, WaitForCallbackConfig)
     */
    public <T> DurableFuture<T> waitForCallbackAsync(
            String name, Class<T> resultType, BiConsumer<String, StepContext> func) {
        return waitForCallbackAsync(
                name,
                TypeToken.get(resultType),
                func,
                WaitForCallbackConfig.builder().build());
    }

    /**
     * Non-blocking waitForCallback with a default {@link WaitForCallbackConfig}.
     *
     * @see #waitForCallbackAsync(String, TypeToken, BiConsumer, WaitForCallbackConfig)
     */
    public <T> DurableFuture<T> waitForCallbackAsync(
            String name, TypeToken<T> typeToken, BiConsumer<String, StepContext> func) {
        return waitForCallbackAsync(name, typeToken, func, WaitForCallbackConfig.builder().build());
    }

    /**
     * Non-blocking waitForCallback with the given config.
     *
     * @see #waitForCallbackAsync(String, TypeToken, BiConsumer, WaitForCallbackConfig)
     */
    public <T> DurableFuture<T> waitForCallbackAsync(
            String name,
            Class<T> resultType,
            BiConsumer<String, StepContext> func,
            WaitForCallbackConfig waitForCallbackConfig) {
        return waitForCallbackAsync(name, TypeToken.get(resultType), func, waitForCallbackConfig);
    }

    /**
     * Starts a waitForCallback: a child context (sub-type {@link OperationSubType#WAIT_FOR_CALLBACK}). Inside it, the
     * method creates a callback named {@code name + "-callback"} and runs a step named {@code name + "-submitter"}
     * that passes the callback ID to {@code func}. The child context then blocks on the callback result.
     *
     * <p>If the config's step config has no SerDes, the durable config's SerDes is used for the submitter step.
     *
     * @param name the operation name; its maximum length is reduced by the longest suffix
     * @param typeToken the callback result type
     * @param func receives the callback ID and the submitter step's {@link StepContext}
     * @param waitForCallbackConfig callback and submitter-step configuration
     * @return a future for the callback result
     * @throws NullPointerException if {@code typeToken} or {@code waitForCallbackConfig} is null
     */
    public <T> DurableFuture<T> waitForCallbackAsync(
            String name,
            TypeToken<T> typeToken,
            BiConsumer<String, StepContext> func,
            WaitForCallbackConfig waitForCallbackConfig) {
        Objects.requireNonNull(typeToken, "typeToken cannot be null");
        Objects.requireNonNull(waitForCallbackConfig, "waitForCallbackConfig cannot be null");
        // waitForCallback adds a suffix for the callback operation name and the submitter operation name, so the
        // length restriction for a waitForCallback name differs from the other operations.
        ParameterValidator.validateOperationName(name, MAX_WAIT_FOR_CALLBACK_NAME_LENGTH);
        var finalWaitForCallbackConfig = waitForCallbackConfig.stepConfig().serDes() == null
                ? waitForCallbackConfig.toBuilder()
                        .stepConfig(waitForCallbackConfig.stepConfig().toBuilder()
                                .serDes(getDurableConfig().getSerDes())
                                .build())
                        .build()
                : waitForCallbackConfig;
        return runInChildContextAsync(
                name,
                typeToken,
                childCtx -> {
                    var callback = childCtx.createCallback(
                            name + WAIT_FOR_CALLBACK_CALLBACK_SUFFIX,
                            typeToken,
                            finalWaitForCallbackConfig.callbackConfig());
                    childCtx.step(
                            name + WAIT_FOR_CALLBACK_SUBMITTER_SUFFIX,
                            Void.class,
                            stepCtx -> {
                                func.accept(callback.callbackId(), stepCtx);
                                return null;
                            },
                            finalWaitForCallbackConfig.stepConfig());
                    return callback.get();
                },
                OperationSubType.WAIT_FOR_CALLBACK);
    }

    // =============== accessors ================

    /**
     * Returns a logger with execution context information for replay-aware logging. The logger is created on first
     * use.
     *
     * @return the durable logger
     */
    public DurableLogger getLogger() {
        // Double-checked lazy initialization; the local copy avoids repeated volatile reads.
        var result = logger;
        if (result == null) {
            synchronized (this) {
                result = logger;
                if (result == null) {
                    result = new DurableLogger(LoggerFactory.getLogger(DurableContext.class), this);
                    logger = result;
                }
            }
        }
        return result;
    }

    /**
     * Closes the logger, if one was ever created, and then the base context. Called during context destruction so
     * the logger's thread properties are cleared before subsequent executions.
     */
    @Override
    public void close() {
        var currentLogger = logger; // single volatile read
        try {
            if (currentLogger != null) {
                currentLogger.close();
            }
        } finally {
            // Always release the base context, even if closing the logger fails.
            super.close();
        }
    }

    /**
     * Returns the ID for the next operation started in this context.
     *
     * <p>The raw ID is the next value of a per-context sequential counter ({@code "1"}, {@code "2"}, ...). In a child
     * context, the raw ID is prefixed with this context's ID ({@code "<contextId>-1"}, {@code "<contextId>-2"}, ...).
     * Per {@link #createChildContext}, that context ID is the CONTEXT operation's operation ID. The raw ID is then
     * hashed with SHA-256 and returned as lowercase hex, i.e. a 64-character string. This matches the Python SDK's
     * stepPrefix convention and prevents ID collisions in checkpoint batches.
     *
     * @return the hex-encoded SHA-256 hash of the raw operation ID
     */
    private String nextOperationId() {
        var counter = String.valueOf(operationCounter.incrementAndGet());
        var rawId = getContextId() != null ? getContextId() + "-" + counter : counter;
        try {
            // A fresh MessageDigest per call: MessageDigest instances are stateful and not shareable across threads.
            var messageDigest = MessageDigest.getInstance("SHA-256");
            var hash = messageDigest.digest(rawId.getBytes(StandardCharsets.UTF_8));
            return HexFormat.of().formatHex(hash);
        } catch (NoSuchAlgorithmException e) {
            // SHA-256 is mandatory on every Java platform, so this indicates a broken runtime.
            throw new IllegalStateException("failed to get next operation id, SHA-256 not available", e);
        }
    }
}