// LocalMemoryExecutionClient.java

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package software.amazon.lambda.durable.testing.local;

import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
import software.amazon.awssdk.services.lambda.model.CheckpointDurableExecutionResponse;
import software.amazon.awssdk.services.lambda.model.CheckpointUpdatedExecutionState;
import software.amazon.awssdk.services.lambda.model.GetDurableExecutionStateResponse;
import software.amazon.awssdk.services.lambda.model.Operation;
import software.amazon.awssdk.services.lambda.model.OperationAction;
import software.amazon.awssdk.services.lambda.model.OperationStatus;
import software.amazon.awssdk.services.lambda.model.OperationType;
import software.amazon.awssdk.services.lambda.model.OperationUpdate;
import software.amazon.lambda.durable.TypeToken;
import software.amazon.lambda.durable.client.DurableExecutionClient;
import software.amazon.lambda.durable.model.DurableExecutionOutput;
import software.amazon.lambda.durable.serde.SerDes;
import software.amazon.lambda.durable.testing.TestOperation;
import software.amazon.lambda.durable.testing.TestResult;

/**
 * In-memory implementation of {@link DurableExecutionClient} for local testing. Stores operations and checkpoint state
 * in memory, simulating the durable execution backend without AWS infrastructure.
 *
 * <p>Locking: {@code existingOperations} is a synchronized map, and every traversal of its contents happens while
 * holding the map's own monitor (either through the wrapper's {@code forEach} or through
 * {@link #snapshotOperations()}). {@code updatedOperations} is a plain map guarded by its own monitor. When both
 * monitors are needed, the {@code existingOperations} monitor is taken first.
 */
public class LocalMemoryExecutionClient implements DurableExecutionClient {
    // use LinkedHashMap to keep insertion order
    private final Map<String, Operation> existingOperations = Collections.synchronizedMap(new LinkedHashMap<>());
    private final EventProcessor eventProcessor = new EventProcessor();
    // every update ever passed to checkpoint(), in arrival order
    private final List<OperationUpdate> operationUpdates = new CopyOnWriteArrayList<>();
    // operations changed since the last checkpoint() response; guarded by synchronized (updatedOperations)
    private final Map<String, Operation> updatedOperations = new HashMap<>();

    /**
     * Records and applies the given updates, then returns every operation changed since the previous checkpoint
     * response under a freshly generated checkpoint token.
     *
     * @param arn execution ARN; not used by this in-memory implementation
     * @param token previous checkpoint token; not used by this in-memory implementation
     * @param updates operation updates to apply, in order
     * @return response carrying a new random token and the operations changed since the last response
     */
    @Override
    public CheckpointDurableExecutionResponse checkpoint(String arn, String token, List<OperationUpdate> updates) {
        operationUpdates.addAll(updates);
        updates.forEach(this::applyUpdate);

        var newToken = UUID.randomUUID().toString();

        List<Operation> changedOperations;
        synchronized (updatedOperations) {
            // Take an explicit copy before clearing so the response never depends on the live map.
            changedOperations = List.copyOf(updatedOperations.values());
            updatedOperations.clear();
        }

        return CheckpointDurableExecutionResponse.builder()
                .checkpointToken(newToken)
                .newExecutionState(CheckpointUpdatedExecutionState.builder()
                        .operations(changedOperations)
                        .build())
                .build();
    }

    /**
     * Not supported by the local client.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public GetDurableExecutionStateResponse getExecutionState(String arn, String checkpointToken, String marker) {
        // local runner doesn't use this API at all
        throw new UnsupportedOperationException("getExecutionState is not supported");
    }

    /**
     * Get all operation updates that have been sent to this client. Useful for testing and verification.
     *
     * @return an immutable copy of the updates received so far, in arrival order
     */
    public List<OperationUpdate> getOperationUpdates() {
        return List.copyOf(operationUpdates);
    }

    /**
     * Advance all operations (simulates time passing for retries/waits). PENDING steps are moved to READY and STARTED
     * waits are completed successfully with a null result.
     *
     * @return true if any operations were advanced, false otherwise
     */
    public boolean advanceTime() {
        var hasOperationsAdvanced = new AtomicBoolean(false);
        // The synchronized wrapper holds the map's monitor for the whole forEach. Writing back a value for a key that
        // is already present is not a structural modification, so the traversal stays valid.
        // NOTE(review): this relies on the processed operation keeping the same id - confirm in OperationProcessor.
        existingOperations.forEach((key, op) -> {
            if (op.type() == OperationType.STEP && op.status() == OperationStatus.PENDING) {
                applyResult(op, OperationResult.ready());
                hasOperationsAdvanced.set(true);
            }

            if (op.type() == OperationType.WAIT && op.status() == OperationStatus.STARTED) {
                applyResult(op, OperationResult.succeeded(null));
                hasOperationsAdvanced.set(true);
            }
        });
        return hasOperationsAdvanced.get();
    }

    /**
     * Completes a chained invoke operation with the given result, simulating a child Lambda response.
     *
     * @param name name of the chained invoke operation
     * @param result result to apply to the operation
     * @throws IllegalStateException if no operation has that name, or it is not a STARTED CHAINED_INVOKE
     */
    public void completeChainedInvoke(String name, OperationResult result) {
        var op = getOperationByName(name);
        if (op == null) {
            throw new IllegalStateException("Operation not found: " + name);
        }
        if (op.type() != OperationType.CHAINED_INVOKE || op.status() != OperationStatus.STARTED) {
            throw new IllegalStateException("Operation is not a CHAINED_INVOKE or not in STARTED state");
        }
        applyResult(op, result);
    }

    /**
     * Returns the first stored operation (in insertion order) with the given name, or null if not found.
     *
     * @param name operation name to look for; must not be null
     * @return the matching operation, or null when none matches
     */
    public Operation getOperationByName(String name) {
        return snapshotOperations().stream()
                .filter(op -> name.equals(op.name()))
                .findFirst()
                .orElse(null);
    }

    /**
     * Returns all operations currently stored.
     *
     * @return an immutable point-in-time copy of the stored operations, in insertion order
     */
    public List<Operation> getAllOperations() {
        return snapshotOperations();
    }

    /**
     * Build TestResult from current state. The EXECUTION operation is excluded from the per-operation list.
     *
     * @param output execution output supplying status, result and error
     * @param resultType type token for the execution result
     * @param serDes serializer/deserializer handed to the test result and each test operation
     * @return a test result describing the stored operations and all recorded events
     */
    public <O> TestResult<O> toTestResult(DurableExecutionOutput output, TypeToken<O> resultType, SerDes serDes) {
        // Work on a snapshot so the event lookups below do not run while holding the map's monitor.
        var testOperations = snapshotOperations().stream()
                .filter(op -> op.type() != OperationType.EXECUTION)
                .map(op -> new TestOperation(op, eventProcessor.getEventsForOperation(op.id()), serDes))
                .toList();
        return new TestResult<>(
                output.status(),
                output.result(),
                output.error(),
                testOperations,
                eventProcessor.getAllEvents(),
                resultType,
                serDes);
    }

    /**
     * Simulate checkpoint failure by forcing an operation into STARTED state. The change is stored like any other
     * operation change, so it is also reported by the next checkpoint response.
     *
     * @param stepName name of the operation to reset
     * @throws IllegalStateException if no operation has that name
     */
    public void resetCheckpointToStarted(String stepName) {
        var op = getOperationByName(stepName);
        if (op == null) {
            throw new IllegalStateException("Operation not found: " + stepName);
        }
        var startedOp = op.toBuilder().status(OperationStatus.STARTED).build();
        updateOperation(null, startedOp);
    }

    /**
     * Simulate fire-and-forget checkpoint loss by removing the operation entirely, both from the stored operations
     * and from the set of changes pending for the next checkpoint response.
     *
     * @param stepName name of the operation to remove
     * @throws IllegalStateException if no operation has that name
     */
    public void simulateFireAndForgetCheckpointLoss(String stepName) {
        var op = getOperationByName(stepName);
        if (op == null) {
            throw new IllegalStateException("Operation not found: " + stepName);
        }
        existingOperations.remove(op.id());
        synchronized (updatedOperations) {
            updatedOperations.remove(op.id());
        }
    }

    /** Applies one update on top of the currently stored operation with the same id (if any) and stores the result. */
    private void applyUpdate(OperationUpdate update) {
        var existingOp = existingOperations.get(update.id());
        var updatedOp = OperationProcessor.applyUpdate(update, existingOp);
        updateOperation(update, updatedOp);
    }

    /**
     * Get callback ID for a named callback operation.
     *
     * @param operationName name of the callback operation
     * @return the callback id, or null if the operation does not exist or has no callback details
     */
    public String getCallbackId(String operationName) {
        var op = getOperationByName(operationName);
        if (op == null || op.callbackDetails() == null) {
            return null;
        }
        return op.callbackDetails().callbackId();
    }

    /**
     * Simulate external system completing callback.
     *
     * @param callbackId id of the callback to complete
     * @param result result to apply to the callback operation
     * @throws IllegalStateException if no operation carries that callback id, or it is not a STARTED CALLBACK
     */
    public void completeCallback(String callbackId, OperationResult result) {
        var op = findOperationByCallbackId(callbackId);
        if (op == null) {
            throw new IllegalStateException("Callback not found: " + callbackId);
        }
        if (op.type() != OperationType.CALLBACK || op.status() != OperationStatus.STARTED) {
            throw new IllegalStateException("Operation is not a CALLBACK or not in STARTED state");
        }

        applyResult(op, result);
    }

    /**
     * Moves an operation to the status carried by {@code result}. Statuses that map to an {@link OperationAction} go
     * through the regular update path; TIMED_OUT, STOPPED and READY have no action and are applied directly.
     */
    private void applyResult(Operation op, OperationResult result) {
        // derive a possible action from the target status
        OperationAction action = deriveAction(result.operationStatus());
        if (action != null) {
            var update = OperationUpdate.builder()
                    .id(op.id())
                    .name(op.name())
                    .type(op.type())
                    .action(action)
                    .parentId(op.parentId())
                    .payload(result.result())
                    .error(result.error())
                    .build();
            applyUpdate(update);
        } else if (result.operationStatus() == OperationStatus.TIMED_OUT
                || result.operationStatus() == OperationStatus.STOPPED
                || result.operationStatus() == OperationStatus.READY) {
            var newOp = OperationProcessor.applyResult(op, result);
            updateOperation(null, newOp);
        } else {
            throw new IllegalStateException("Unsupported OperationStatus in result: " + result.operationStatus());
        }
    }

    /** Maps a target status to the action that produces it, or null when no action exists for that status. */
    private static OperationAction deriveAction(OperationStatus status) {
        return switch (status) {
            case STARTED -> OperationAction.START;
            case SUCCEEDED -> OperationAction.SUCCEED;
            case FAILED -> OperationAction.FAIL;
            case PENDING -> OperationAction.RETRY;
            case CANCELLED -> OperationAction.CANCEL;
            case READY, TIMED_OUT, STOPPED -> null; // no action for these operation statuses
            // NOTE(review): an unknown status is forwarded as an unknown action rather than rejected - confirm
            case UNKNOWN_TO_SDK_VERSION -> OperationAction.UNKNOWN_TO_SDK_VERSION; // Todo: Check this
        };
    }

    /** Returns the first stored operation whose callback details carry the given callback id, or null. */
    private Operation findOperationByCallbackId(String callbackId) {
        return snapshotOperations().stream()
                .filter(op -> op.callbackDetails() != null
                        && callbackId.equals(op.callbackDetails().callbackId()))
                .findFirst()
                .orElse(null);
    }

    /**
     * Returns an immutable point-in-time copy of the stored operations in insertion order. The copy is taken while
     * holding the synchronized map's monitor, which is required for any traversal of its collection views.
     */
    private List<Operation> snapshotOperations() {
        synchronized (existingOperations) {
            return List.copyOf(existingOperations.values());
        }
    }

    /** Records the change with the event processor, stores the operation and marks it as changed. */
    private void updateOperation(OperationUpdate update, Operation op) {
        // update can be null when an operation is updated without an OperationUpdate
        if (update == null) {
            eventProcessor.processUpdate(op);
        } else {
            eventProcessor.processUpdate(update, op);
        }
        existingOperations.put(op.id(), op);
        synchronized (updatedOperations) {
            updatedOperations.put(op.id(), op);
        }
    }
}