ExecutionManager.java
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package software.amazon.lambda.durable.execution;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.lambda.model.CheckpointUpdatedExecutionState;
import software.amazon.awssdk.services.lambda.model.Operation;
import software.amazon.awssdk.services.lambda.model.OperationStatus;
import software.amazon.awssdk.services.lambda.model.OperationType;
import software.amazon.awssdk.services.lambda.model.OperationUpdate;
import software.amazon.lambda.durable.DurableConfig;
import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException;
import software.amazon.lambda.durable.model.DurableExecutionInput;
import software.amazon.lambda.durable.operation.BaseDurableOperation;
/**
 * Central manager for durable execution coordination.
 *
 * <p>Consolidates:
 *
 * <ul>
 *   <li>Execution state (operations, checkpoint token)
 *   <li>Thread lifecycle (registration/deregistration)
 *   <li>Checkpoint batching (via CheckpointBatcher)
 *   <li>Checkpoint result handling (CheckpointBatcher callback)
 *   <li>Polling (for waits and retries)
 * </ul>
 *
 * <p>This is the single entry point for all execution coordination. Internal coordination (polling, checkpointing) uses
 * a dedicated SDK thread pool, while user-defined operations run on a customer-configured executor.
 *
 * <p>Operations are keyed by their globally unique operation ID. Child context operations use prefixed IDs (e.g.,
 * "1-1", "1-2") to avoid collisions with root-level operations.
 *
 * <p>Each instance owns a {@link CheckpointBatcher}; call {@link #close()} (e.g. via try-with-resources) to shut it
 * down.
 *
 * @see InternalExecutor
 */
public class ExecutionManager implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(ExecutionManager.class);

    // ===== Execution State =====
    /** Known operations keyed by operation ID; a concurrent map because checkpoint callbacks update it. */
    private final Map<String, Operation> operationStorage;

    private final Operation executionOp;
    private final String durableExecutionArn;
    private final AtomicReference<ExecutionMode> executionMode;

    // ===== Thread Coordination =====
    private final Map<String, BaseDurableOperation<?>> registeredOperations =
            Collections.synchronizedMap(new HashMap<>());
    /** IDs of currently active threads. The synchronized set uses itself as its monitor. */
    private final Set<String> activeThreads = Collections.synchronizedSet(new HashSet<>());
    /** Static: shared by every ExecutionManager in the JVM; each thread sees only its own value. */
    private static final ThreadLocal<ThreadContext> currentThreadContext = new ThreadLocal<>();
    /** Completed exceptionally (only) when the execution is suspended or terminated. */
    private final CompletableFuture<Void> executionExceptionFuture = new CompletableFuture<>();

    // ===== Checkpoint Batching =====
    private final CheckpointBatcher checkpointBatcher;

    /**
     * Creates a manager for one durable execution, loading all operation pages from the initial state.
     *
     * @param input the durable execution input (ARN, checkpoint token, initial execution state)
     * @param config SDK configuration passed to the checkpoint batcher
     * @throws IllegalStateException if no EXECUTION operation is present, or the first initial operation is not of
     *     type EXECUTION
     */
    public ExecutionManager(DurableExecutionInput input, DurableConfig config) {
        this.durableExecutionArn = input.durableExecutionArn();
        // Create checkpoint batcher for internal coordination
        this.checkpointBatcher =
                new CheckpointBatcher(config, durableExecutionArn, input.checkpointToken(), this::onCheckpointComplete);
        try {
            this.operationStorage = checkpointBatcher.fetchAllPages(input.initialExecutionState()).stream()
                    .collect(Collectors.toConcurrentMap(Operation::id, op -> op));
            // Start in REPLAY mode if we have more than just the initial EXECUTION operation
            this.executionMode = new AtomicReference<>(
                    operationStorage.size() > 1 ? ExecutionMode.REPLAY : ExecutionMode.EXECUTION);
            this.executionOp = findExecutionOp(input.initialExecutionState());
            // Validate initial operation is an EXECUTION operation
            if (executionOp == null) {
                throw new IllegalStateException("First operation must be EXECUTION");
            }
        } catch (RuntimeException e) {
            // The constructor is failing, so no caller will ever hold a reference on which to call close().
            // Release the batcher here instead of leaking it; keep any shutdown failure as a suppressed exception.
            try {
                checkpointBatcher.shutdown();
            } catch (RuntimeException shutdownFailure) {
                e.addSuppressed(shutdownFailure);
            }
            throw e;
        }
        logger.debug("DurableExecution.execute() called");
        logger.debug("DurableExecutionArn: {}", durableExecutionArn);
        logger.debug("Initial operations count: {}", operationStorage.size());
        logger.debug("EXECUTION operation found: {}", executionOp.id());
    }

    // ===== State Management =====

    /** @return the ARN of the durable execution managed by this instance */
    public String getDurableExecutionArn() {
        return durableExecutionArn;
    }

    /** @return {@code true} while the execution is still replaying previously checkpointed operations */
    public boolean isReplaying() {
        return executionMode.get() == ExecutionMode.REPLAY;
    }

    /**
     * Registers an operation so it is notified of checkpoint updates for its operation ID. A later registration
     * with the same ID replaces the earlier one.
     *
     * @param operation the operation to register
     */
    public void registerOperation(BaseDurableOperation<?> operation) {
        registeredOperations.put(operation.getOperationId(), operation);
    }

    // ===== Checkpoint Completion Handler =====

    /**
     * Callback handed to the {@link CheckpointBatcher}; invoked when a checkpoint completes. Stores every returned
     * operation in {@code operationStorage} and forwards each update to the registered operation with the same ID,
     * if any.
     *
     * @param newOperations operations returned by the completed checkpoint
     */
    private void onCheckpointComplete(List<Operation> newOperations) {
        for (Operation op : newOperations) {
            // Update operation storage
            operationStorage.put(op.id(), op);
            // Look the operation up and invoke it outside the synchronized map's monitor (computeIfPresent would
            // hold the lock while running the operation's callback).
            BaseDurableOperation<?> registered = registeredOperations.get(op.id());
            if (registered != null) {
                registered.onCheckpointComplete(op);
            }
        }
    }

    /**
     * Gets all child operations for a given operationId.
     *
     * @param operationId the operationId to get children for
     * @return a new mutable list of the operations whose parentId equals {@code operationId}; empty if none
     */
    public List<Operation> getChildOperations(String operationId) {
        // todo: this is O(n) - consider an improvement if performance becomes an issue
        var children = new ArrayList<Operation>();
        for (Operation op : operationStorage.values()) {
            if (Objects.equals(op.parentId(), operationId)) {
                children.add(op);
            }
        }
        return children;
    }

    /**
     * Gets an operation by its globally unique operationId, and updates replay state. Transitions from REPLAY to
     * EXECUTION mode if the operation is not found or is not in a terminal state (still in progress).
     *
     * @param operationId the globally unique operation ID (e.g., "1" for root, "1-1" for child context)
     * @return the existing operation, or null if not found (first execution)
     */
    public Operation getOperationAndUpdateReplayState(String operationId) {
        var existing = operationStorage.get(operationId);
        if (executionMode.get() == ExecutionMode.REPLAY && (existing == null || !isTerminalStatus(existing.status()))) {
            // compareAndSet ensures the transition is logged only once even if several threads race here.
            if (executionMode.compareAndSet(ExecutionMode.REPLAY, ExecutionMode.EXECUTION)) {
                logger.debug("Transitioned to EXECUTION mode at operation '{}'", operationId);
            }
        }
        return existing;
    }

    /** @return the EXECUTION operation found during construction (never null) */
    public Operation getExecutionOperation() {
        return executionOp;
    }

    /**
     * Locates the EXECUTION operation. If the initial state carries operations, the first one must be the
     * EXECUTION operation; otherwise the loaded operation storage is scanned.
     *
     * @return the EXECUTION operation, or null if none was found
     * @throws IllegalStateException if the first initial operation is not of type EXECUTION
     */
    private Operation findExecutionOp(CheckpointUpdatedExecutionState initialExecutionState) {
        // find execution OP in the input
        if (initialExecutionState != null
                && initialExecutionState.operations() != null
                && !initialExecutionState.operations().isEmpty()) {
            var op = initialExecutionState.operations().get(0);
            if (op.type() != OperationType.EXECUTION) {
                throw new IllegalStateException("First operation must be EXECUTION");
            }
            return op;
        }
        // find execution OP in the checkpoint result
        for (Operation op : operationStorage.values()) {
            if (op.type() == OperationType.EXECUTION) {
                return op;
            }
        }
        return null;
    }

    /**
     * Checks whether there are any cached operations for the given parent context ID. Used to initialize per-context
     * replay state — a context starts in replay mode if the ExecutionManager has cached operations belonging to it.
     *
     * @param parentId the context ID to check (null for root context)
     * @return true if at least one operation exists with the given parentId
     */
    public boolean hasOperationsForContext(String parentId) {
        return operationStorage.values().stream().anyMatch(op -> Objects.equals(op.parentId(), parentId));
    }

    // ===== Thread Coordination =====

    /**
     * Sets the current thread's ThreadContext (threadId and threadType). Called when a user thread is started.
     *
     * @param threadContext the context to associate with the calling thread
     */
    public void setCurrentThreadContext(ThreadContext threadContext) {
        currentThreadContext.set(threadContext);
    }

    /** @return the calling thread's ThreadContext (threadId and threadType), or null if not set */
    public ThreadContext getCurrentThreadContext() {
        return currentThreadContext.get();
    }

    /**
     * Registers a thread as active. Registering an already-active thread is a no-op.
     *
     * @param threadId the thread ID to register
     * @see ThreadContext
     */
    public void registerActiveThread(String threadId) {
        // add() is atomic on the synchronized set and reports whether the ID was new (no contains-then-add race).
        if (!activeThreads.add(threadId)) {
            logger.trace("Thread '{}' already registered as active", threadId);
            return;
        }
        logger.trace("Registered thread '{}' as active. Active threads: {}", threadId, activeThreads.size());
    }

    /**
     * Marks a thread as inactive. If no threads remain, suspends the execution. Does nothing once the execution has
     * already been suspended or terminated.
     *
     * @param threadId the thread ID to deregister
     * @throws SuspendExecutionException if this call leaves no active threads (see {@link #suspendExecution()})
     */
    public void deregisterActiveThread(String threadId) {
        // Skip if already suspended
        if (executionExceptionFuture.isDone()) {
            return;
        }
        boolean removed;
        int remaining;
        // Remove and read the size under the set's own monitor, so the emptiness check reflects this thread's
        // removal rather than interleaving with another thread's concurrent removal.
        synchronized (activeThreads) {
            removed = activeThreads.remove(threadId);
            remaining = activeThreads.size();
        }
        if (removed) {
            logger.trace("Deregistered thread '{}' Active threads: {}", threadId, remaining);
        } else {
            logger.warn("Thread '{}' not active, cannot deregister", threadId);
        }
        if (remaining == 0) {
            logger.info("No active threads remaining - suspending execution");
            suspendExecution();
        }
    }

    // ===== Checkpointing =====

    /**
     * Queues an operation update for checkpointing to the durable backend.
     *
     * @param update the operation update to checkpoint
     * @return a future that completes when the checkpoint containing this update completes
     */
    public CompletableFuture<Void> sendOperationUpdate(OperationUpdate update) {
        return checkpointBatcher.checkpoint(update);
    }

    // ===== Polling =====

    /**
     * Polls the durable backend for updates to an operation. Useful for in-process waits, e.g. waiting while another
     * thread is still running, in which case the function is not re-invoked because it never suspended.
     *
     * @param operationId the operation to poll for
     * @return a future that completes when an update of the operation is received
     */
    public CompletableFuture<Operation> pollForOperationUpdates(String operationId) {
        return checkpointBatcher.pollForUpdate(operationId);
    }

    /**
     * Same as {@link #pollForOperationUpdates(String)}, passing {@code delay} through to the checkpoint batcher.
     *
     * @param operationId the operation to poll for
     * @param delay delay forwarded to {@code CheckpointBatcher.pollForUpdate}
     * @return a future that completes when an update of the operation is received
     */
    public CompletableFuture<Operation> pollForOperationUpdates(String operationId, Duration delay) {
        return checkpointBatcher.pollForUpdate(operationId, delay);
    }

    // ===== Utilities =====

    /** Shutdown the checkpoint batcher. */
    @Override
    public void close() {
        checkpointBatcher.shutdown();
    }

    /**
     * @param status an operation status (may be null)
     * @return true for SUCCEEDED, FAILED, CANCELLED, TIMED_OUT and STOPPED; false otherwise (including null)
     */
    public static boolean isTerminalStatus(OperationStatus status) {
        return status == OperationStatus.SUCCEEDED
                || status == OperationStatus.FAILED
                || status == OperationStatus.CANCELLED
                || status == OperationStatus.TIMED_OUT
                || status == OperationStatus.STOPPED;
    }

    /**
     * Terminates the execution: fails any future returned by {@link #runUntilCompleteOrSuspend} with
     * {@code exception}, then rethrows it on the calling thread.
     *
     * @param exception the unrecoverable failure
     * @throws UnrecoverableDurableExecutionException always ({@code exception} itself)
     */
    public void terminateExecution(UnrecoverableDurableExecutionException exception) {
        executionExceptionFuture.completeExceptionally(exception);
        throw exception;
    }

    /**
     * Suspends the execution: fails any future returned by {@link #runUntilCompleteOrSuspend} with a new
     * {@link SuspendExecutionException}, then throws that exception on the calling thread. If the execution was
     * already suspended or terminated, the first recorded outcome is kept.
     *
     * @throws SuspendExecutionException always
     */
    public void suspendExecution() {
        var ex = new SuspendExecutionException();
        executionExceptionFuture.completeExceptionally(ex);
        throw ex;
    }

    /**
     * Returns a future that completes when userFuture completes or the execution is terminated or suspended,
     * whichever happens first.
     *
     * <p>Exceptional outcomes are propagated through {@link CompletableFuture#anyOf}, so they surface wrapped in a
     * {@link java.util.concurrent.CompletionException}.
     *
     * @param userFuture future of the user-provided function
     * @return a future of userFuture's result if userFuture completes successfully; otherwise completed exceptionally
     *     with the user exception, a SuspendExecutionException if the execution is suspended, or an
     *     UnrecoverableDurableExecutionException if the execution is terminated
     */
    public <T> CompletableFuture<T> runUntilCompleteOrSuspend(CompletableFuture<T> userFuture) {
        return CompletableFuture.anyOf(userFuture, executionExceptionFuture).thenApply(v -> {
            // executionExceptionFuture only ever completes exceptionally, so this stage runs only after userFuture
            // has completed successfully.
            if (userFuture.isDone()) {
                return userFuture.join();
            }
            return null;
        });
    }
}