// ChildContextOperation.java

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package software.amazon.lambda.durable.operation;

import static software.amazon.lambda.durable.execution.ExecutionManager.isTerminalStatus;

import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import software.amazon.awssdk.services.lambda.model.ContextOptions;
import software.amazon.awssdk.services.lambda.model.ErrorObject;
import software.amazon.awssdk.services.lambda.model.Operation;
import software.amazon.awssdk.services.lambda.model.OperationAction;
import software.amazon.awssdk.services.lambda.model.OperationStatus;
import software.amazon.awssdk.services.lambda.model.OperationType;
import software.amazon.awssdk.services.lambda.model.OperationUpdate;
import software.amazon.lambda.durable.DurableContext;
import software.amazon.lambda.durable.TypeToken;
import software.amazon.lambda.durable.exception.CallbackFailedException;
import software.amazon.lambda.durable.exception.CallbackSubmitterException;
import software.amazon.lambda.durable.exception.CallbackTimeoutException;
import software.amazon.lambda.durable.exception.ChildContextFailedException;
import software.amazon.lambda.durable.exception.DurableOperationException;
import software.amazon.lambda.durable.exception.StepFailedException;
import software.amazon.lambda.durable.exception.StepInterruptedException;
import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException;
import software.amazon.lambda.durable.execution.SuspendExecutionException;
import software.amazon.lambda.durable.model.OperationSubType;
import software.amazon.lambda.durable.serde.SerDes;
import software.amazon.lambda.durable.util.ExceptionHelper;

/**
 * Manages the lifecycle of a child execution context.
 *
 * <p>A child context runs a user function in a separate thread with its own operation counter and checkpoint log.
 * Operations within the child context use the child's context ID as their parentId.
 *
 * <p>A result whose serialized form is {@code LARGE_RESULT_THRESHOLD} bytes or larger is checkpointed with an empty
 * payload and the {@code replayChildren} flag set. When such an operation is replayed, the child context is executed
 * again to rebuild the result instead of reading it from the checkpoint.
 *
 * @param <T> type of the value produced by the child context function
 */
public class ChildContextOperation<T> extends BaseDurableOperation<T> {

    /** Serialized results of this many UTF-8 bytes or more are checkpointed without a payload (256 KiB). */
    private static final int LARGE_RESULT_THRESHOLD = 256 * 1024;

    /** User code executed inside the child context. */
    private final Function<DurableContext, T> function;

    /** Executor that runs the user function; obtained from the durable config of the owning context. */
    private final ExecutorService userExecutor;

    /** Sub-type attached to every checkpoint update; also selects the fallback exception in {@link #get()}. */
    private final OperationSubType subType;

    /** Set when a SUCCEEDED operation is re-executed only to rebuild its result; success is then not checkpointed. */
    private boolean replayChildContext;

    /** Result kept in memory for the cases where the checkpoint carries an empty payload. */
    private T reconstructedResult;

    /**
     * Creates a child context operation.
     *
     * @param operationId id of this operation, forwarded to the base operation
     * @param name name of this operation, forwarded to the base operation
     * @param function user code to run inside the child context
     * @param subType sub-type attached to every checkpoint update of this operation
     * @param resultTypeToken type token of the result, forwarded to the base operation
     * @param resultSerDes serializer of the result, forwarded to the base operation
     * @param durableContext context that owns this operation; its durable config supplies the executor
     */
    public ChildContextOperation(
            String operationId,
            String name,
            Function<DurableContext, T> function,
            OperationSubType subType,
            TypeToken<T> resultTypeToken,
            SerDes resultSerDes,
            DurableContext durableContext) {
        super(operationId, name, OperationType.CONTEXT, resultTypeToken, resultSerDes, durableContext);
        this.function = function;
        this.userExecutor = getContext().getDurableConfig().getExecutorService();
        this.subType = subType;
    }

    /** Starts the operation on first execution: sends a START update without waiting, then runs the child context. */
    @Override
    protected void start() {
        // First execution: fire-and-forget START checkpoint, then run
        sendOperationUpdateAsync(
                OperationUpdate.builder().action(OperationAction.START).subType(subType.getValue()));
        executeChildContext();
    }

    /**
     * Replays the operation from an existing checkpoint record.
     *
     * <ul>
     *   <li>SUCCEEDED with {@code replayChildren} set: the child context is executed again to rebuild the result.
     *   <li>SUCCEEDED otherwise, or FAILED: the operation is marked as already completed.
     *   <li>STARTED: the child context is executed.
     *   <li>Any other status: the execution is terminated as an illegal durable operation.
     * </ul>
     *
     * @param existing the checkpointed operation record
     */
    @Override
    protected void replay(Operation existing) {
        switch (existing.status()) {
            case SUCCEEDED -> {
                if (existing.contextDetails() != null
                        && Boolean.TRUE.equals(existing.contextDetails().replayChildren())) {
                    // Large result: re-execute child context to reconstruct result
                    replayChildContext = true;
                    executeChildContext();
                } else {
                    markAlreadyCompleted();
                }
            }
            case FAILED -> markAlreadyCompleted();
            case STARTED -> executeChildContext();
            default ->
                terminateExecutionWithIllegalDurableOperationException(
                        "Unexpected child context status: " + existing.status());
        }
    }

    /** Registers the child context thread and submits the user function to the user-configured executor. */
    private void executeChildContext() {
        // The operationId is already globally unique (prefixed by parent context path via
        // DurableContext.nextOperationId), so we use it directly as the contextId.
        // E.g., first level child context "hash(1)",
        //       second level child context "hash(hash(1)-2)",
        //       third level child context "hash(hash(hash(1)-2)-1)".
        var contextId = getOperationId();

        // Thread registration is intentionally split across two threads:
        // 1. registerActiveThread on the PARENT thread — ensures the child is tracked before the
        //    parent can deregister and trigger suspension (race prevention).
        // 2. setCurrentContext on the CHILD thread — sets the ThreadLocal so operations inside
        //    the child context know which context they belong to.
        // registerActiveThread is idempotent (no-op if already registered).
        registerActiveThread(contextId);

        Runnable userHandler = () -> {
            // use a try-with-resources to
            // - add thread id/type to thread local when the step starts
            // - clear logger properties when the step finishes
            try (var childContext = getContext().createChildContext(contextId, getName())) {
                // The inner try is deliberate: a catch clause on the try-with-resources itself would only
                // run after the child context had been closed, and would also catch failures of
                // createChildContext/close. Here failures are handled while the child context is still open.
                try {
                    T result = function.apply(childContext);

                    handleChildContextSuccess(result);
                } catch (Throwable e) {
                    handleChildContextFailure(e);
                }
            }
        };

        // Execute user provided child context code in user-configured executor.
        // The returned future is not retained, so anything thrown out of userHandler is not observed here.
        CompletableFuture.runAsync(userHandler, userExecutor);
    }

    /**
     * Handles a normal return of the user function.
     *
     * @param result value returned by the user function
     */
    private void handleChildContextSuccess(T result) {
        if (replayChildContext) {
            // Replaying a SUCCEEDED child with replayChildren=true — skip checkpointing.
            // Mark the completableFuture completed so get() doesn't block waiting for a checkpoint response.
            this.reconstructedResult = result;
            markAlreadyCompleted();
        } else {
            checkpointSuccess(result);
        }
    }

    /**
     * Sends the SUCCEED update for the given result.
     *
     * <p>The size check uses the UTF-8 byte length of the serialized result, not its character count.
     *
     * @param result value returned by the user function
     */
    private void checkpointSuccess(T result) {
        var serialized = serializeResult(result);
        var serializedBytes = serialized.getBytes(StandardCharsets.UTF_8);

        if (serializedBytes.length < LARGE_RESULT_THRESHOLD) {
            sendOperationUpdate(OperationUpdate.builder()
                    .action(OperationAction.SUCCEED)
                    .subType(subType.getValue())
                    .payload(serialized));
        } else {
            // Large result: checkpoint with empty payload + ReplayChildren flag.
            // Store the result so get() can return it directly without deserializing the empty payload.
            this.reconstructedResult = result;
            sendOperationUpdate(OperationUpdate.builder()
                    .action(OperationAction.SUCCEED)
                    .subType(subType.getValue())
                    .payload("")
                    .contextOptions(
                            ContextOptions.builder().replayChildren(true).build()));
        }
    }

    /**
     * Handles anything thrown while running the child context, including failures raised by
     * {@link #handleChildContextSuccess(Object)}.
     *
     * @param exception the throwable caught around the user function; unwrapped before it is inspected
     */
    private void handleChildContextFailure(Throwable exception) {
        final Throwable cause = ExceptionHelper.unwrapCompletableFuture(exception);
        if (cause instanceof SuspendExecutionException) {
            // Rethrow immediately — do not checkpoint
            ExceptionHelper.sneakyThrow(cause);
        }
        if (cause instanceof UnrecoverableDurableExecutionException unrecoverable) {
            // NOTE(review): presumably terminateExecution does not return normally; if it can, control
            // falls through and a FAIL update is sent below — confirm.
            terminateExecution(unrecoverable);
        }

        // Durable operation exceptions already carry an error object; anything else is serialized here.
        final ErrorObject errorObject;
        if (cause instanceof DurableOperationException opEx) {
            errorObject = opEx.getErrorObject();
        } else {
            errorObject = serializeException(cause);
        }

        // NOTE(review): this is also reached when the function throws while replayChildContext is true,
        // i.e. a FAIL update is sent for an operation whose stored status is SUCCEEDED — confirm intended.
        sendOperationUpdate(OperationUpdate.builder()
                .action(OperationAction.FAIL)
                .subType(subType.getValue())
                .error(errorObject));
    }

    /**
     * Waits for the operation to complete and returns its result or throws its failure.
     *
     * <p>On success, a non-null in-memory result takes precedence over the checkpointed payload. On failure, the
     * original exception is rethrown when it can be rebuilt from the stored error; otherwise an exception chosen
     * by the operation sub-type is thrown.
     *
     * @return the result of the child context
     */
    @Override
    public T get() {
        var op = waitForOperationCompletion();

        if (op.status() == OperationStatus.SUCCEEDED) {
            if (reconstructedResult != null) {
                return reconstructedResult;
            }
            var contextDetails = op.contextDetails();
            var result = (contextDetails != null) ? contextDetails.result() : null;
            return deserializeResult(result);
        } else {
            var contextDetails = op.contextDetails();
            var errorObject = (contextDetails != null) ? contextDetails.error() : null;

            // Attempt to reconstruct and throw the original exception
            Throwable original = deserializeException(errorObject);
            if (original != null) {
                ExceptionHelper.sneakyThrow(original);
            }

            // throw a general failed exception if a user exception is not reconstructed
            return switch (subType) {
                case WAIT_FOR_CALLBACK -> handleWaitForCallbackFailure(op);
                // todo: handle MAP/PARALLEL
                case MAP, PARALLEL, RUN_IN_CHILD_CONTEXT -> throw new ChildContextFailedException(op);
            };
        }
    }

    /**
     * Translates a failed WAIT_FOR_CALLBACK child context into the matching callback exception.
     *
     * <p>Looks at the first CALLBACK child and the first STEP child (the submitter) of the operation. A failed or
     * timed-out callback wins over a failed submitter. This method never returns normally; the return type only
     * lets it be used as an arm of the switch expression in {@link #get()}.
     *
     * @param op the failed child context operation
     * @return never returns
     * @throws CallbackFailedException if the callback child has status FAILED
     * @throws CallbackTimeoutException if the callback child has status TIMED_OUT
     * @throws CallbackSubmitterException if a callback child exists and the submitter step ended unsuccessfully
     * @throws IllegalStateException if none of the above applies
     */
    private T handleWaitForCallbackFailure(Operation op) {
        var childrenOps = getChildOperations(op.id());
        var callbackOp = childrenOps.stream()
                .filter(o -> o.type() == OperationType.CALLBACK)
                .findFirst()
                .orElse(null);
        var submitterOp = childrenOps.stream()
                .filter(o -> o.type() == OperationType.STEP)
                .findFirst()
                .orElse(null);
        if (callbackOp != null) {
            // if callback failed
            if (isTerminalStatus(callbackOp.status())) {
                // Only FAILED and TIMED_OUT are mapped; any other terminal status falls through.
                switch (callbackOp.status()) {
                    case FAILED -> throw new CallbackFailedException(callbackOp);
                    case TIMED_OUT -> throw new CallbackTimeoutException(callbackOp);
                }
            }

            // if submitter failed
            if (submitterOp != null
                    && isTerminalStatus(submitterOp.status())
                    && submitterOp.status() != OperationStatus.SUCCEEDED) {
                // SDK model getters return null for unset members, so guard stepDetails the same way
                // get() guards contextDetails instead of dereferencing it blindly.
                var stepDetails = submitterOp.stepDetails();
                var stepError = (stepDetails != null) ? stepDetails.error() : null;
                if (stepError != null && StepInterruptedException.isStepInterruptedException(stepError)) {
                    throw new CallbackSubmitterException(callbackOp, new StepInterruptedException(submitterOp));
                } else {
                    throw new CallbackSubmitterException(callbackOp, new StepFailedException(submitterOp));
                }
            }
        }

        throw new IllegalStateException("Unknown waitForCallback status");
    }
}