BaseDurableOperation.java

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package software.amazon.lambda.durable.operation;

import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.lambda.model.ErrorObject;
import software.amazon.awssdk.services.lambda.model.Operation;
import software.amazon.awssdk.services.lambda.model.OperationType;
import software.amazon.awssdk.services.lambda.model.OperationUpdate;
import software.amazon.lambda.durable.DurableContext;
import software.amazon.lambda.durable.DurableFuture;
import software.amazon.lambda.durable.TypeToken;
import software.amazon.lambda.durable.exception.IllegalDurableOperationException;
import software.amazon.lambda.durable.exception.NonDeterministicExecutionException;
import software.amazon.lambda.durable.exception.SerDesException;
import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException;
import software.amazon.lambda.durable.execution.ExecutionManager;
import software.amazon.lambda.durable.execution.ThreadContext;
import software.amazon.lambda.durable.execution.ThreadType;
import software.amazon.lambda.durable.serde.SerDes;
import software.amazon.lambda.durable.util.ExceptionHelper;

/**
 * Base class for all durable operations (STEP, WAIT, etc.).
 *
 * <p>Key methods:
 *
 * <ul>
 *   <li>{@code execute()} starts the operation (returns immediately)
 *   <li>{@code get()} blocks until complete and returns the result
 * </ul>
 *
 * <p>The separation allows:
 *
 * <ul>
 *   <li>Starting multiple async operations quickly
 *   <li>Blocking on results later when needed
 *   <li>Proper thread coordination via future
 * </ul>
 *
 * <p>Each instance registers itself with the context's {@link ExecutionManager} in its constructor so that it
 * receives checkpoint updates through {@link #onCheckpointComplete(Operation)}. Completion is tracked by
 * {@link #completionFuture}; the code in this class synchronizes on that future whenever it completes it or attaches
 * a continuation to it.
 *
 * @param <T> the result type of the operation
 */
public abstract class BaseDurableOperation<T> implements DurableFuture<T> {
    private static final Logger logger = LoggerFactory.getLogger(BaseDurableOperation.class);

    private final String operationId;
    private final String name;
    private final OperationType operationType;
    private final ExecutionManager executionManager;
    private final TypeToken<T> resultTypeToken;
    private final SerDes resultSerDes;
    // Completed with null once the operation is known to be in a terminal state (see onCheckpointComplete and
    // markAlreadyCompleted). Also used as the monitor for check-then-act sequences on itself.
    protected final CompletableFuture<Void> completionFuture;
    private final DurableContext durableContext;

    /**
     * Creates the operation and registers it with the parent context's {@link ExecutionManager}.
     *
     * <p>NOTE(review): {@code this} is handed to the ExecutionManager before subclass constructors have run;
     * subclasses presumably must tolerate callbacks before their own fields are initialized — confirm.
     *
     * @param operationId globally unique identifier of this operation
     * @param name optional user-facing operation name (may be null)
     * @param operationType the type recorded for this operation in checkpoints
     * @param resultTypeToken type information used to deserialize the result
     * @param resultSerDes serializer used for results and errors
     * @param durableContext the parent context; supplies the ExecutionManager and the parent id for updates
     */
    protected BaseDurableOperation(
            String operationId,
            String name,
            OperationType operationType,
            TypeToken<T> resultTypeToken,
            SerDes resultSerDes,
            DurableContext durableContext) {
        this.operationId = operationId;
        this.name = name;
        this.durableContext = durableContext;
        this.operationType = operationType;
        this.executionManager = durableContext.getExecutionManager();
        this.resultTypeToken = resultTypeToken;
        this.resultSerDes = resultSerDes;

        this.completionFuture = new CompletableFuture<>();

        // register this operation in ExecutionManager so that the operation can receive updates from ExecutionManager
        executionManager.registerOperation(this);
    }

    /** Gets the unique identifier for this operation. */
    public String getOperationId() {
        return operationId;
    }

    /** Gets the operation name (may be null). */
    public String getName() {
        return name;
    }

    /** Gets the parent context. */
    protected DurableContext getContext() {
        return durableContext;
    }

    /** Gets the operation type. */
    public OperationType getType() {
        return operationType;
    }

    /**
     * Starts the operation and processes the operation updates from the backend. Does not block.
     *
     * <p>If a checkpointed operation with this id already exists, it is first validated against this operation (see
     * {@link #validateReplay(Operation)}) and then replayed; otherwise the operation is started fresh.
     */
    public void execute() {
        var existing = getOperation();

        if (existing != null) {
            validateReplay(existing);
            replay(existing);
        } else {
            start();
        }
    }

    /** Starts the operation (no checkpointed state exists for it yet). */
    protected abstract void start();

    /**
     * Replays the operation from its checkpointed state.
     *
     * @param existing the checkpointed operation with the same id as this one
     */
    protected abstract void replay(Operation existing);

    /**
     * Gets the Operation from the ExecutionManager and updates the replay state from REPLAY to EXECUTE if the
     * operation is not found. Operation IDs are globally unique (prefixed for child contexts), so no parentId is
     * needed for lookups.
     *
     * <p>Because of the replay-state side effect, do not call this purely for diagnostics such as logging.
     *
     * @return the operation if found, otherwise null
     */
    protected Operation getOperation() {
        return executionManager.getOperationAndUpdateReplayState(getOperationId());
    }

    /**
     * Gets the direct child Operations of a given context operation.
     *
     * @param operationId the operation id of the context
     * @return list of the child Operations
     */
    protected List<Operation> getChildOperations(String operationId) {
        return executionManager.getChildOperations(operationId);
    }

    /**
     * Checks whether it is called from a Step; durable operations cannot be nested inside a step.
     *
     * @throws IllegalDurableOperationException if it is called from a step thread (the execution is terminated)
     */
    private void validateCurrentThreadType() {
        ThreadType current = getCurrentThreadContext().threadType();
        if (current == ThreadType.STEP) {
            var message = String.format(
                    "Nested %s operation is not supported on %s from within a %s execution.",
                    getType(), getName(), current);
            // terminate execution and throw the exception
            terminateExecutionWithIllegalDurableOperationException(message);
        }
    }

    /** Checks whether this operation is completed, i.e. whether {@link #completionFuture} is done. */
    protected boolean isOperationCompleted() {
        return completionFuture.isDone();
    }

    /**
     * Waits for the operation to complete and suspends the execution if no active thread is running.
     *
     * <p>If the operation is not yet complete, the calling thread is deregistered from the ExecutionManager (allowing
     * suspension), and a continuation is attached that re-registers it when {@link #completionFuture} completes.
     *
     * @return the checkpointed operation after completion (never null)
     * @throws IllegalDurableOperationException if called from a step thread, or if the operation cannot be found after
     *     completion; in both cases the execution is terminated
     */
    protected Operation waitForOperationCompletion() {

        validateCurrentThreadType();

        var threadContext = getCurrentThreadContext();

        // It's important that we synchronize access to the future. Otherwise, a race condition could happen if the
        // completionFuture is completed by a user thread (a step or child context thread) when the execution here
        // is between `isOperationCompleted` and `thenRun`.
        synchronized (completionFuture) {
            if (!isOperationCompleted()) {
                // Operation not done yet. Log the id rather than calling getOperation(): that lookup may change the
                // replay state and would be evaluated even when trace logging is disabled.
                logger.trace(
                        "deregistering thread {} when waiting for operation {} ({}) to complete ({})",
                        threadContext.threadId(),
                        getOperationId(),
                        getType(),
                        completionFuture);

                // Add a completion stage to completionFuture so that when the completionFuture is completed,
                // it will register the current Context thread synchronously to make sure it is always registered
                // strictly before the execution thread (Step or child context) is deregistered.
                completionFuture.thenRun(() -> registerActiveThread(threadContext.threadId()));

                // Deregister the current thread to allow suspension
                executionManager.deregisterActiveThread(threadContext.threadId());
            }
        }

        // Block until operation completes. No-op if the future is already completed.
        completionFuture.join();

        // Get result based on status
        var op = getOperation();
        if (op == null) {
            terminateExecutionWithIllegalDurableOperationException(
                    String.format("%s operation not found: %s", getType(), getOperationId()));
        }
        return op;
    }

    /**
     * Receives operation updates from the ExecutionManager and updates the internal state of the operation.
     *
     * <p>Only terminal statuses are handled here: they complete {@link #completionFuture}. Override this method if a
     * DurableOperation needs to react to other updates.
     *
     * @param operation the updated operation as returned by a checkpoint
     */
    public void onCheckpointComplete(Operation operation) {
        if (ExecutionManager.isTerminalStatus(operation.status())) {
            logger.trace("In onCheckpointComplete, completing operation {} ({})", operationId, completionFuture);
            // It's important that we synchronize access to the future, otherwise the processing could happen
            // on someone else's thread and cause a race condition.
            synchronized (completionFuture) {
                // Completing the future here will also run any other completion stages that have been attached
                // to the future. In our case, other contexts may have attached a function to reactivate themselves,
                // so they will definitely have a chance to reactivate before we finish completing and deactivating
                // whatever operations were just checkpointed.
                completionFuture.complete(null);
            }
        }
    }

    /**
     * Marks the operation as already completed (in replay).
     *
     * <p>Completing {@link #completionFuture} immediately unblocks waiters and runs any attached continuations, such
     * as the one that re-registers a waiting context thread.
     */
    protected void markAlreadyCompleted() {
        logger.trace("In markAlreadyCompleted, completing operation: {} ({}).", operationId, completionFuture);

        // It's important that we synchronize access to the future, otherwise the processing could happen
        // on someone else's thread and cause a race condition.
        synchronized (completionFuture) {
            completionFuture.complete(null);
        }
    }

    /**
     * Terminates the execution with the given exception.
     *
     * <p>This method never returns normally. The generic return type lets callers write
     * {@code return terminateExecution(...)} in methods that must produce a {@code T}.
     *
     * @param exception the unrecoverable error that ends the execution
     * @return never returns normally
     * @throws UnrecoverableDurableExecutionException always
     */
    protected T terminateExecution(UnrecoverableDurableExecutionException exception) {
        executionManager.terminateExecution(exception);
        // ExecutionManager.terminateExecution is expected to throw already; this throw guarantees the method never
        // returns normally (tests rely on it).
        throw exception;
    }

    /**
     * Terminates the execution with an {@link IllegalDurableOperationException}.
     *
     * @param message the exception message
     * @return never returns normally
     * @throws IllegalDurableOperationException always
     */
    protected T terminateExecutionWithIllegalDurableOperationException(String message) {
        return terminateExecution(new IllegalDurableOperationException(message));
    }

    // advanced thread and context control

    /** Registers the given thread as active with the ExecutionManager. */
    protected void registerActiveThread(String threadId) {
        executionManager.registerActiveThread(threadId);
    }

    /** Returns the ExecutionManager's context for the calling thread. */
    protected ThreadContext getCurrentThreadContext() {
        return executionManager.getCurrentThreadContext();
    }

    // polling and checkpointing

    /** Delegates to {@code ExecutionManager.pollForOperationUpdates} for this operation's id. */
    protected CompletableFuture<Operation> pollForOperationUpdates() {
        return executionManager.pollForOperationUpdates(operationId);
    }

    /**
     * Delegates to {@code ExecutionManager.pollForOperationUpdates} for this operation's id with the given delay.
     *
     * @param delay delay passed through to the ExecutionManager
     */
    protected CompletableFuture<Operation> pollForOperationUpdates(Duration delay) {
        return executionManager.pollForOperationUpdates(operationId, delay);
    }

    /**
     * Sends an operation update and blocks until the returned future completes.
     *
     * @param builder partially populated update; id, name, type and parentId are filled in here
     */
    protected void sendOperationUpdate(OperationUpdate.Builder builder) {
        sendOperationUpdateAsync(builder).join();
    }

    /**
     * Sends an operation update without blocking.
     *
     * <p>The builder is populated with this operation's id, name and type, and with the parent context's id as
     * parentId, before being built and handed to the ExecutionManager.
     *
     * @param builder partially populated update (mutated by this method)
     * @return the future returned by the ExecutionManager for this update
     */
    protected CompletableFuture<Void> sendOperationUpdateAsync(OperationUpdate.Builder builder) {
        return executionManager.sendOperationUpdate(builder.id(operationId)
                .name(name)
                .type(operationType)
                .parentId(durableContext.getContextId())
                .build());
    }

    // serialization/deserialization utilities

    /**
     * Deserializes a checkpointed result using the result SerDes and type token.
     *
     * @param result the serialized result
     * @return the deserialized result
     * @throws SerDesException if deserialization fails (logged, then rethrown)
     */
    protected T deserializeResult(String result) {
        try {
            return resultSerDes.deserialize(result, resultTypeToken);
        } catch (SerDesException e) {
            logger.warn(
                    "Failed to deserialize {} result for operation name '{}'. Ensure the result is properly encoded.",
                    getType(),
                    getName());
            throw e;
        }
    }

    /** Serializes a result using the result SerDes. */
    protected String serializeResult(T result) {
        return resultSerDes.serialize(result);
    }

    /** Builds a checkpointable {@link ErrorObject} for the given throwable. */
    protected ErrorObject serializeException(Throwable throwable) {
        return ExceptionHelper.buildErrorObject(throwable, resultSerDes);
    }

    /**
     * Attempts to reconstruct the original exception from a checkpointed {@link ErrorObject}.
     *
     * @param errorObject the checkpointed error (may be null)
     * @return the reconstructed throwable with its stack trace restored, or null if {@code errorObject} or its error
     *     type is null, the type is not a loadable {@link Throwable} subclass, or the data cannot be deserialized
     */
    protected Throwable deserializeException(ErrorObject errorObject) {
        if (errorObject == null) {
            return null;
        }
        var errorType = errorObject.errorType();
        var errorData = errorObject.errorData();
        if (errorType == null) {
            return null;
        }

        Throwable original = null;
        try {
            Class<?> exceptionClass = Class.forName(errorType);
            if (Throwable.class.isAssignableFrom(exceptionClass)) {
                original =
                        resultSerDes.deserialize(errorData, TypeToken.get(exceptionClass.asSubclass(Throwable.class)));

                if (original != null) {
                    original.setStackTrace(ExceptionHelper.deserializeStackTrace(errorObject.stackTrace()));
                }
            }
        } catch (ClassNotFoundException e) {
            // Pass the cause so the missing class name appears in the log.
            logger.warn("Cannot re-construct original exception type. Falling back to generic StepFailedException.", e);
        } catch (SerDesException e) {
            logger.warn("Cannot deserialize original exception data. Falling back to generic StepFailedException.", e);
        }
        return original;
    }

    /**
     * Validates that the current operation matches the checkpointed operation during replay.
     *
     * @param checkpointed the checkpointed operation; null or untyped operations are not validated
     * @throws NonDeterministicExecutionException if the type or name differs (the execution is terminated)
     */
    protected void validateReplay(Operation checkpointed) {
        if (checkpointed == null || checkpointed.type() == null) {
            return; // First execution, no validation needed
        }

        if (!checkpointed.type().equals(getType())) {
            terminateExecution(new NonDeterministicExecutionException(String.format(
                    "Operation type mismatch for \"%s\". Expected %s, got %s",
                    operationId, checkpointed.type(), getType())));
        }

        if (!Objects.equals(checkpointed.name(), getName())) {
            terminateExecution(new NonDeterministicExecutionException(String.format(
                    "Operation name mismatch for \"%s\". Expected \"%s\", got \"%s\"",
                    operationId, checkpointed.name(), getName())));
        }
    }
}