LocalMemoryExecutionClient.java

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package software.amazon.lambda.durable.testing;

import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import software.amazon.awssdk.services.lambda.model.*;
import software.amazon.lambda.durable.client.DurableExecutionClient;
import software.amazon.lambda.durable.model.DurableExecutionOutput;
import software.amazon.lambda.durable.serde.JacksonSerDes;
import software.amazon.lambda.durable.serde.SerDes;

/**
 * In-memory implementation of {@link DurableExecutionClient} for local testing. Stores operations and checkpoint state
 * in memory, simulating the durable execution backend without AWS infrastructure.
 *
 * <p>Operations are keyed by {@code parentId + ":" + id} (a missing parent id is treated as the empty string), so
 * operations sharing an id under different parents are stored separately. Updates applied through {@link #checkpoint},
 * wait completions produced by {@link #advanceReadyOperations()}, and completions from
 * {@link #completeChainedInvoke} are recorded as {@link Event}s. The callback and checkpoint-fault simulation helpers
 * change operation state without recording events.
 */
public class LocalMemoryExecutionClient implements DurableExecutionClient {
    /** Current state of every operation, keyed by {@link #compositeKey(String, String)}. */
    private final Map<String, Operation> operations = new ConcurrentHashMap<>();

    /** Events in the order they were recorded. */
    private final List<Event> allEvents = new CopyOnWriteArrayList<>();

    private final EventProcessor eventProcessor = new EventProcessor();
    private final SerDes serDes = new JacksonSerDes();

    /** Token issued by the most recent checkpoint; replaced on every {@link #checkpoint} call and on {@link #reset()}. */
    private final AtomicReference<String> checkpointToken =
            new AtomicReference<>(UUID.randomUUID().toString());

    /** Every update received through {@link #checkpoint}, in arrival order. */
    private final List<OperationUpdate> operationUpdates = new CopyOnWriteArrayList<>();

    /**
     * Records and applies a batch of operation updates, then issues a fresh checkpoint token.
     *
     * <p>The supplied ARN and token are not validated; any values are accepted.
     *
     * @param arn durable execution ARN (ignored by this in-memory implementation)
     * @param token checkpoint token supplied by the caller (ignored)
     * @param updates updates to record and apply, in order
     * @return a response carrying the newly generated token and all currently stored operations
     */
    @Override
    public CheckpointDurableExecutionResponse checkpoint(String arn, String token, List<OperationUpdate> updates) {
        operationUpdates.addAll(updates);
        updates.forEach(this::applyUpdate);

        var newToken = UUID.randomUUID().toString();
        checkpointToken.set(newToken);

        return CheckpointDurableExecutionResponse.builder()
                .checkpointToken(newToken)
                .newExecutionState(CheckpointUpdatedExecutionState.builder()
                        .operations(operations.values())
                        .build())
                .build();
    }

    /**
     * Returns every stored operation in a single response; no pagination marker is ever set.
     *
     * @param arn durable execution ARN (ignored)
     * @param checkpointToken checkpoint token (ignored)
     * @param marker pagination marker (ignored)
     * @return a response containing all currently stored operations
     */
    @Override
    public GetDurableExecutionStateResponse getExecutionState(String arn, String checkpointToken, String marker) {
        return GetDurableExecutionStateResponse.builder()
                .operations(operations.values())
                .build();
    }

    /**
     * Get all operation updates that have been sent to this client. Useful for testing and verification.
     *
     * @return an immutable copy of every update received through {@link #checkpoint}, in arrival order
     */
    public List<OperationUpdate> getOperationUpdates() {
        return List.copyOf(operationUpdates);
    }

    /**
     * Get all events in order.
     *
     * @return an immutable copy of all recorded events
     */
    public List<Event> getAllEvents() {
        return List.copyOf(allEvents);
    }

    /**
     * Get events for a specific operation.
     *
     * @param operationId operation id to match against {@code Event.id()}
     * @return recorded events whose id equals {@code operationId}, in recording order
     */
    public List<Event> getEventsForOperation(String operationId) {
        return allEvents.stream().filter(e -> operationId.equals(e.id())).toList();
    }

    /**
     * Advance all operations (simulates time passing for retries/waits).
     *
     * <p>PENDING operations (the status assigned on a RETRY update) move to READY. STARTED WAIT operations move to
     * SUCCEEDED and a WaitSucceeded event is recorded for them, regardless of their scheduled end timestamp.
     *
     * @return true if any operations were advanced, false otherwise
     */
    public boolean advanceReadyOperations() {
        var advanced = false;
        for (var entry : operations.entrySet()) {
            var current = entry.getValue();
            var next = nextState(current);
            // Conditional replace instead of replaceAll: ConcurrentHashMap may re-run a replaceAll function when a
            // value changes concurrently, which would record duplicate events. Here the event is recorded only once
            // the new state has actually been stored.
            if (next == current || !operations.replace(entry.getKey(), current, next)) {
                continue;
            }
            if (current.type() == OperationType.WAIT && current.status() == OperationStatus.STARTED) {
                // Generate WaitSucceeded event
                allEvents.add(eventProcessor.processUpdate(waitSucceededUpdate(current), next));
            }
            advanced = true;
        }
        return advanced;
    }

    /** Returns the advanced state of {@code op}, or {@code op} itself when it has nothing to advance. */
    private static Operation nextState(Operation op) {
        if (op.status() == OperationStatus.PENDING) {
            return op.toBuilder().status(OperationStatus.READY).build();
        }
        if (op.status() == OperationStatus.STARTED && op.type() == OperationType.WAIT) {
            return op.toBuilder().status(OperationStatus.SUCCEEDED).build();
        }
        return op;
    }

    /** Builds the synthetic SUCCEED update used to record the WaitSucceeded event for {@code op}. */
    private static OperationUpdate waitSucceededUpdate(Operation op) {
        return OperationUpdate.builder()
                .id(op.id())
                .name(op.name())
                .type(OperationType.WAIT)
                .action(OperationAction.SUCCEED)
                .build();
    }

    /**
     * Completes a chained invoke operation with the given result, simulating a child Lambda response.
     *
     * <p>Only a CHAINED_INVOKE operation in STARTED status is updated. For any other matching operation this method
     * does nothing. A SUCCEEDED result is recorded as a SUCCEED event; any other status is recorded as a FAIL event.
     *
     * @param name name of the chained invoke operation
     * @param result status, payload and error to apply to the operation
     * @throws IllegalStateException if no operation with {@code name} exists
     */
    public void completeChainedInvoke(String name, OperationResult result) {
        var op = getOperationByName(name);
        if (op == null) {
            throw new IllegalStateException("Operation not found: " + name);
        }
        if (op.type() != OperationType.CHAINED_INVOKE || op.status() != OperationStatus.STARTED) {
            return;
        }
        var newOp = op.toBuilder()
                .status(result.operationStatus())
                .chainedInvokeDetails(ChainedInvokeDetails.builder()
                        .result(result.result())
                        .error(result.error())
                        .build())
                .build();
        var update = OperationUpdate.builder()
                .id(op.id())
                .name(op.name())
                .type(OperationType.CHAINED_INVOKE)
                .action(
                        result.operationStatus() == OperationStatus.SUCCEEDED
                                ? OperationAction.SUCCEED
                                : OperationAction.FAIL)
                .build();
        var event = eventProcessor.processUpdate(update, newOp);
        allEvents.add(event);
        operations.put(compositeKey(op.parentId(), op.id()), newOp);
    }

    /**
     * Returns the operation with the given name, or null if not found.
     *
     * <p>NOTE(review): if several operations share a name, which one is returned is unspecified, because iteration
     * order over the backing {@link ConcurrentHashMap} is not defined.
     *
     * @param name operation name to look up
     * @return a matching operation, or {@code null}
     */
    public Operation getOperationByName(String name) {
        return operations.values().stream()
                .filter(op -> name.equals(op.name()))
                .findFirst()
                .orElse(null);
    }

    /**
     * Returns all operations currently stored.
     *
     * @return an unmodifiable list of the stored operations
     */
    public List<Operation> getAllOperations() {
        return operations.values().stream().toList();
    }

    /**
     * Resets the client to its initial state. Clears all operations, events and recorded operation updates, and
     * issues a new checkpoint token.
     */
    public void reset() {
        operations.clear();
        allEvents.clear();
        operationUpdates.clear();
        checkpointToken.set(UUID.randomUUID().toString());
    }

    /**
     * Build TestResult from current state.
     *
     * <p>EXECUTION operations are excluded from the returned operations. Each remaining operation is paired with the
     * events recorded under its id.
     *
     * @param output execution output supplying status, result and error
     * @param <O> result type of the test result
     * @return a new test result snapshot
     */
    public <O> TestResult<O> toTestResult(DurableExecutionOutput output) {
        var testOperations = operations.values().stream()
                .filter(op -> op.type() != OperationType.EXECUTION)
                .map(op -> new TestOperation(op, getEventsForOperation(op.id()), serDes))
                .toList();
        return new TestResult<>(
                output.status(), output.result(), output.error(), testOperations, new ArrayList<>(allEvents), serDes);
    }

    /**
     * Simulate checkpoint failure by forcing an operation into STARTED state.
     *
     * @param stepName name of the operation to reset
     * @throws IllegalStateException if no operation with {@code stepName} exists
     */
    public void resetCheckpointToStarted(String stepName) {
        var op = getOperationByName(stepName);
        if (op == null) {
            throw new IllegalStateException("Operation not found: " + stepName);
        }
        var startedOp = op.toBuilder().status(OperationStatus.STARTED).build();
        operations.put(compositeKey(op.parentId(), op.id()), startedOp);
    }

    /**
     * Simulate fire-and-forget checkpoint loss by removing the operation entirely. Recorded events are kept.
     *
     * @param stepName name of the operation to remove
     * @throws IllegalStateException if no operation with {@code stepName} exists
     */
    public void simulateFireAndForgetCheckpointLoss(String stepName) {
        var op = getOperationByName(stepName);
        if (op == null) {
            throw new IllegalStateException("Operation not found: " + stepName);
        }
        operations.remove(compositeKey(op.parentId(), op.id()));
    }

    /** Converts an update to an operation, stores it (replacing any prior state) and records the resulting event. */
    private void applyUpdate(OperationUpdate update) {
        var operation = toOperation(update);
        var key = compositeKey(update.parentId(), update.id());
        operations.put(key, operation);

        var event = eventProcessor.processUpdate(update, operation);
        allEvents.add(event);
    }

    /** Builds the storage key {@code "<parentId>:<operationId>"}, using an empty parent segment for top-level ops. */
    private static String compositeKey(String parentId, String operationId) {
        return (parentId != null ? parentId : "") + ":" + operationId;
    }

    /**
     * Builds the stored operation for an update, including type-specific details.
     *
     * @throws UnsupportedOperationException if the update's type or action is {@code UNKNOWN_TO_SDK_VERSION}
     */
    private Operation toOperation(OperationUpdate update) {
        var builder = Operation.builder()
                .id(update.id())
                .name(update.name())
                .type(update.type())
                .subType(update.subType())
                .parentId(update.parentId())
                .status(deriveStatus(update.action()));

        switch (update.type()) {
            case WAIT -> builder.waitDetails(buildWaitDetails(update));
            case STEP -> builder.stepDetails(buildStepDetails(update));
            case CALLBACK -> builder.callbackDetails(buildCallbackDetails(update));
            case EXECUTION -> {} // No details needed for EXECUTION operations
            case CHAINED_INVOKE -> builder.chainedInvokeDetails(buildChainedInvokeDetails(update));
            case CONTEXT -> builder.contextDetails(buildContextDetails(update));
            case UNKNOWN_TO_SDK_VERSION ->
                throw new UnsupportedOperationException("UNKNOWN_TO_SDK_VERSION not supported");
        }

        return builder.build();
    }

    /** Chained-invoke details from the update's payload/error; null when the update carries no invoke options. */
    private ChainedInvokeDetails buildChainedInvokeDetails(OperationUpdate update) {
        if (update.chainedInvokeOptions() == null) {
            return null;
        }
        return ChainedInvokeDetails.builder()
                .result(update.payload())
                .error(update.error())
                .build();
    }

    /** Context details from the update's payload/error, setting replayChildren only when explicitly requested. */
    private ContextDetails buildContextDetails(OperationUpdate update) {
        var detailsBuilder = ContextDetails.builder().result(update.payload()).error(update.error());

        if (update.contextOptions() != null
                && Boolean.TRUE.equals(update.contextOptions().replayChildren())) {
            detailsBuilder.replayChildren(true);
        }

        return detailsBuilder.build();
    }

    /** Wait details with a scheduled end of now + waitSeconds; null when the update carries no wait options. */
    private WaitDetails buildWaitDetails(OperationUpdate update) {
        if (update.waitOptions() == null) {
            return null;
        }

        var scheduledEnd = Instant.now().plusSeconds(update.waitOptions().waitSeconds());
        return WaitDetails.builder().scheduledEndTimestamp(scheduledEnd).build();
    }

    /**
     * Merges the update into any existing step details. RETRY and FAIL increment the attempt counter (starting at 1)
     * and record the error; a non-null payload replaces the stored result.
     */
    private StepDetails buildStepDetails(OperationUpdate update) {
        var key = compositeKey(update.parentId(), update.id());
        var existingOp = operations.get(key);
        var existing = existingOp != null ? existingOp.stepDetails() : null;

        var detailsBuilder = existing != null ? existing.toBuilder() : StepDetails.builder();

        if (update.action() == OperationAction.RETRY || update.action() == OperationAction.FAIL) {
            var attempt = existing != null && existing.attempt() != null ? existing.attempt() + 1 : 1;
            detailsBuilder.attempt(attempt).error(update.error());
        }

        if (update.payload() != null) {
            detailsBuilder.result(update.payload());
        }

        return detailsBuilder.build();
    }

    /**
     * Returns the existing callback details unchanged (callbackId, result and error), or creates details with a fresh
     * random callbackId when the callback has not been seen before.
     */
    private CallbackDetails buildCallbackDetails(OperationUpdate update) {
        var key = compositeKey(update.parentId(), update.id());
        var existingOp = operations.get(key);
        var existing = existingOp != null ? existingOp.callbackDetails() : null;

        // Preserve everything already recorded (including an error set by failCallback), or generate an id on START.
        if (existing != null) {
            return existing;
        }
        return CallbackDetails.builder()
                .callbackId(UUID.randomUUID().toString())
                .build();
    }

    /**
     * Get callback ID for a named callback operation.
     *
     * @param operationName name of the callback operation
     * @return the callback id, or {@code null} if the operation or its callback details are absent
     */
    public String getCallbackId(String operationName) {
        var op = getOperationByName(operationName);
        if (op == null || op.callbackDetails() == null) {
            return null;
        }
        return op.callbackDetails().callbackId();
    }

    /**
     * Simulate external system completing callback successfully. Sets status SUCCEEDED and stores {@code result}. No
     * event is recorded.
     *
     * @param callbackId id returned by {@link #getCallbackId(String)}
     * @param result result payload to store
     * @throws IllegalStateException if no operation has this callback id
     */
    public void completeCallback(String callbackId, String result) {
        var op = findOperationByCallbackId(callbackId);
        if (op == null) {
            throw new IllegalStateException("Callback not found: " + callbackId);
        }
        var updated = op.toBuilder()
                .status(OperationStatus.SUCCEEDED)
                .callbackDetails(op.callbackDetails().toBuilder().result(result).build())
                .build();
        operations.put(compositeKey(op.parentId(), op.id()), updated);
    }

    /**
     * Simulate external system failing callback. Sets status FAILED and stores {@code error}. No event is recorded.
     *
     * @param callbackId id returned by {@link #getCallbackId(String)}
     * @param error error to store on the callback details
     * @throws IllegalStateException if no operation has this callback id
     */
    public void failCallback(String callbackId, ErrorObject error) {
        var op = findOperationByCallbackId(callbackId);
        if (op == null) {
            throw new IllegalStateException("Callback not found: " + callbackId);
        }
        var updated = op.toBuilder()
                .status(OperationStatus.FAILED)
                .callbackDetails(op.callbackDetails().toBuilder().error(error).build())
                .build();
        operations.put(compositeKey(op.parentId(), op.id()), updated);
    }

    /**
     * Simulate callback timeout. Sets status TIMED_OUT. No event is recorded.
     *
     * @param callbackId id returned by {@link #getCallbackId(String)}
     * @throws IllegalStateException if no operation has this callback id
     */
    public void timeoutCallback(String callbackId) {
        var op = findOperationByCallbackId(callbackId);
        if (op == null) {
            throw new IllegalStateException("Callback not found: " + callbackId);
        }
        var updated = op.toBuilder().status(OperationStatus.TIMED_OUT).build();
        operations.put(compositeKey(op.parentId(), op.id()), updated);
    }

    /** Returns the operation whose callback details carry {@code callbackId}, or null. */
    private Operation findOperationByCallbackId(String callbackId) {
        return operations.values().stream()
                .filter(op -> op.callbackDetails() != null
                        && callbackId.equals(op.callbackDetails().callbackId()))
                .findFirst()
                .orElse(null);
    }

    /**
     * Maps an update action to the resulting operation status.
     *
     * @throws UnsupportedOperationException for {@code UNKNOWN_TO_SDK_VERSION}, mirroring how {@link #toOperation}
     *     rejects unknown operation types, rather than storing an operation with a null status
     */
    private OperationStatus deriveStatus(OperationAction action) {
        return switch (action) {
            case START -> OperationStatus.STARTED;
            case SUCCEED -> OperationStatus.SUCCEEDED;
            case FAIL -> OperationStatus.FAILED;
            case RETRY -> OperationStatus.PENDING;
            case CANCEL -> OperationStatus.CANCELLED;
            case UNKNOWN_TO_SDK_VERSION ->
                throw new UnsupportedOperationException("UNKNOWN_TO_SDK_VERSION action not supported");
        };
    }
}