// CheckpointBatcher.java

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package software.amazon.lambda.durable.execution;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.lambda.model.CheckpointUpdatedExecutionState;
import software.amazon.awssdk.services.lambda.model.Operation;
import software.amazon.awssdk.services.lambda.model.OperationUpdate;
import software.amazon.lambda.durable.DurableConfig;
import software.amazon.lambda.durable.retry.PollingStrategies;
import software.amazon.lambda.durable.retry.PollingStrategy;

/**
 * Package-private checkpoint manager for batching and queueing checkpoint API calls.
 *
 * <p>Single responsibility: queue and batch checkpoint requests efficiently. Uses a {@link Consumer} to notify when
 * checkpoints complete, avoiding a cyclic dependency.
 *
 * <p>Two kinds of entries are submitted to the underlying {@code ApiRequestBatcher}: real {@link OperationUpdate}s
 * (from {@link #checkpoint}) and {@code null} placeholders (from the {@code pollForUpdate} overloads). Placeholders
 * carry no data; they only cause a batch to be flushed so that the checkpoint response can be inspected for updated
 * operations. The {@code pollingFutures} monitor guards the poller registry and serializes {@link #checkpointBatch}.
 */
class CheckpointBatcher {
    private static final int MAX_BATCH_SIZE_BYTES = 750 * 1024; // 750KB
    private static final int MAX_ITEM_COUNT = 100; // max updates in one batch
    private static final Logger logger = LoggerFactory.getLogger(CheckpointBatcher.class);

    private final Consumer<List<Operation>> callback;
    private final String durableExecutionArn;
    // operation id -> futures waiting for an update of that operation
    private final Map<String, List<CompletableFuture<Operation>>> pollingFutures = new ConcurrentHashMap<>();
    private final ApiRequestBatcher<OperationUpdate> checkpointApiRequestBatcher;
    private final DurableConfig config;
    // replaced with the token returned by every checkpoint response (see checkpointBatch)
    private String checkpointToken;

    /**
     * Creates a batcher bound to one durable execution.
     *
     * @param config supplies the durable execution client, the checkpoint delay and the default polling strategy
     * @param durableExecutionArn ARN passed to every checkpoint / getExecutionState call
     * @param checkpointToken initial token used for the first checkpoint call
     * @param callback receives the operations returned by each checkpoint response that carries a new execution state
     */
    CheckpointBatcher(
            DurableConfig config,
            String durableExecutionArn,
            String checkpointToken,
            Consumer<List<Operation>> callback) {
        this.config = config;
        this.durableExecutionArn = durableExecutionArn;
        this.callback = callback;
        this.checkpointToken = checkpointToken;
        this.checkpointApiRequestBatcher = new ApiRequestBatcher<>(
                MAX_ITEM_COUNT, MAX_BATCH_SIZE_BYTES, CheckpointBatcher::estimateSize, this::checkpointBatch);
    }

    /**
     * Queues a checkpoint request for batched execution.
     *
     * @param update the operation update to checkpoint
     * @return the future handed back by the request batcher for this submission
     */
    CompletableFuture<Void> checkpoint(OperationUpdate update) {
        logger.debug("Checkpoint request received: Action {}", update.action());
        return checkpointApiRequestBatcher.submit(update, config.getCheckpointDelay());
    }

    /** Polls for updates of the specified operation with the preconfigured polling strategy. */
    CompletableFuture<Operation> pollForUpdate(String operationId) {
        return pollForUpdate(operationId, config.getPollingStrategy());
    }

    /** Polls for updates of the specified operation with a fixed delay between attempts. */
    CompletableFuture<Operation> pollForUpdate(String operationId, Duration delay) {
        return pollForUpdate(operationId, PollingStrategies.fixedDelay(delay));
    }

    /**
     * Polls for updates of the specified operation with the specified polling strategy.
     *
     * @param operationId id of the operation to wait for
     * @param pollingStrategy computes the delay before each polling attempt
     * @return a future completed with the operation once a checkpoint response contains it; completed exceptionally
     *     on {@link #shutdown()} or when the polling loop itself fails
     */
    CompletableFuture<Operation> pollForUpdate(String operationId, PollingStrategy pollingStrategy) {
        logger.debug("Polling request received: operation id {}", operationId);
        var future = new CompletableFuture<Operation>();
        synchronized (pollingFutures) {
            // register the future in pollingFutures, which will be completed by the polling thread
            pollingFutures
                    .computeIfAbsent(operationId, k -> Collections.synchronizedList(new ArrayList<>()))
                    .add(future);
        }
        // If a batched submission (or the strategy's delay computation) fails, the polling loop stops.
        // Surface that failure on the returned future instead of leaving the caller waiting indefinitely.
        // The registry entry is intentionally left alone: completing an already-done future later is a no-op,
        // and the entry is dropped by checkpointBatch or shutdown.
        pollForUpdateInternal(future, 0, pollingStrategy).whenComplete((ignored, error) -> {
            if (error != null) {
                future.completeExceptionally(unwrapCompletionException(error));
            }
        });
        return future;
    }

    /**
     * Submits one {@code null} placeholder to the batcher and, once that submission completes, re-submits with the
     * next attempt number until {@code future} is done.
     *
     * @return a future that completes normally when polling stopped because {@code future} is done, or
     *     exceptionally when a submission failed
     */
    private CompletableFuture<Void> pollForUpdateInternal(
            CompletableFuture<Operation> future, int attempt, PollingStrategy pollingStrategy) {
        return checkpointApiRequestBatcher
                .submit(null, pollingStrategy.computeDelay(attempt))
                .thenCompose(v -> {
                    if (future.isDone()) {
                        return CompletableFuture.completedFuture(null);
                    }
                    return pollForUpdateInternal(future, attempt + 1, pollingStrategy);
                });
    }

    /** Strips the {@link CompletionException} wrapper that dependent stages add around the original failure. */
    private static Throwable unwrapCompletionException(Throwable error) {
        if (error instanceof CompletionException && error.getCause() != null) {
            return error.getCause();
        }
        return error;
    }

    /** Cancels all polling futures and waits for all pending checkpoint requests to complete */
    void shutdown() {
        // complete all polling futures with an exception
        List<List<CompletableFuture<Operation>>> allFutures;
        synchronized (pollingFutures) {
            // snapshot under the lock, fail the futures outside of it
            allFutures = new ArrayList<>(pollingFutures.values());
            pollingFutures.clear();
        }

        for (var futures : allFutures) {
            futures.forEach(f -> f.completeExceptionally(new IllegalStateException("CheckpointManager shutdown")));
        }

        // wait for all non-polling checkpoint requests to complete
        checkpointApiRequestBatcher.shutdown();
    }

    /**
     * Calling GetExecutionState API to get all pages of operations given CheckpointUpdatedExecutionState(operations,
     * nextMarker)
     *
     * @param checkpointUpdatedExecutionState first page as returned by the checkpoint API; may be {@code null}
     * @return a new mutable list with the operations of the first page followed by those of every further page;
     *     empty when the argument is {@code null}
     */
    List<Operation> fetchAllPages(CheckpointUpdatedExecutionState checkpointUpdatedExecutionState) {
        List<Operation> operations = new ArrayList<>();
        if (checkpointUpdatedExecutionState == null) {
            return operations;
        }
        if (checkpointUpdatedExecutionState.operations() != null) {
            operations.addAll(checkpointUpdatedExecutionState.operations());
        }
        var nextMarker = checkpointUpdatedExecutionState.nextMarker();
        // a null or empty marker means there are no further pages
        while (nextMarker != null && !nextMarker.isEmpty()) {
            var startTime = System.nanoTime();
            var response = config.getDurableExecutionClient()
                    .getExecutionState(durableExecutionArn, checkpointToken, nextMarker);
            logger.debug(
                    "Durable getExecutionState API called (latency={}ns): {}.",
                    System.nanoTime() - startTime,
                    response);
            // same null tolerance as for the first page
            var pageOperations = response.operations();
            if (pageOperations != null) {
                operations.addAll(pageOperations);
            }
            nextMarker = response.nextMarker();
        }
        return operations;
    }

    /**
     * Batch handler passed to the request batcher: sends the non-null updates in one checkpoint call, stores the
     * returned token, and, when the response carries a new execution state, hands its operations to the callback
     * and completes the pollers registered for those operations.
     *
     * <p>The whole method runs while holding the {@code pollingFutures} monitor, including the API calls and the
     * callback invocation.
     *
     * @param updates batch contents; {@code null} entries are polling placeholders
     */
    private void checkpointBatch(List<OperationUpdate> updates) {
        synchronized (pollingFutures) {
            // filter the null values from pollers
            var request = updates.stream().filter(Objects::nonNull).toList();

            if (pollingFutures.isEmpty() && request.isEmpty()) {
                // ignore the batch if no pollers and no data to checkpoint
                return;
            }

            // with pollers waiting, an empty request is still sent so the response can report updated operations
            var startTime = System.nanoTime();
            logger.debug("Calling durable checkpoint API with {} updates: {}", updates.size(), request);
            var response = config.getDurableExecutionClient().checkpoint(durableExecutionArn, checkpointToken, request);
            logger.debug("Durable checkpoint API called (latency={}ns): {}.", System.nanoTime() - startTime, response);

            // the next call (including the page fetches below) must use the token from this response
            checkpointToken = response.checkpointToken();
            if (response.newExecutionState() != null) {
                // fetch all pages of operations
                var operations = fetchAllPages(response.newExecutionState());

                var processStartTime = System.nanoTime();
                int completedFutures = 0;
                logger.debug(
                        "Processing {} operations. ({} pending pollers)", operations.size(), pollingFutures.size());
                // call the callback
                callback.accept(operations);

                // complete the registered pollingFutures
                for (var operation : operations) {
                    var pollers = pollingFutures.remove(operation.id());
                    if (pollers != null) {
                        completedFutures += pollers.size();
                        pollers.forEach(poller -> poller.complete(operation));
                    }
                }
                logger.debug(
                        "{} operations processed and {} pollers completed (latency={}ns). ",
                        operations.size(),
                        completedFutures,
                        System.nanoTime() - processStartTime);
            }
        }
    }

    /**
     * Estimates the size in bytes that an update contributes to a batch.
     *
     * <p>Id and payload are measured in UTF-8 bytes rather than UTF-16 code units so that non-ASCII content is not
     * under-counted against {@code MAX_BATCH_SIZE_BYTES} (assumes the request is encoded as UTF-8 on the wire). A
     * flat 100 is added per update on top of the field lengths.
     *
     * @param update the update to measure, or {@code null} for a polling placeholder
     * @return the estimated size; {@code 0} for {@code null}
     */
    private static int estimateSize(OperationUpdate update) {
        if (update == null) {
            return 0;
        }
        return utf8Length(update.id())
                + update.type().toString().length()
                + update.action().toString().length()
                + (update.payload() != null ? utf8Length(update.payload()) : 0)
                + 100;
    }

    /**
     * Counts the UTF-8 encoded length of {@code text} without allocating a byte array. Unpaired surrogates are
     * counted as 3 bytes, which can only over-estimate.
     */
    private static int utf8Length(String text) {
        int bytes = 0;
        int length = text.length();
        for (int i = 0; i < length; i++) {
            char c = text.charAt(i);
            if (c < 0x80) {
                bytes += 1;
            } else if (c < 0x800) {
                bytes += 2;
            } else if (Character.isHighSurrogate(c)
                    && i + 1 < length
                    && Character.isLowSurrogate(text.charAt(i + 1))) {
                // a surrogate pair is one supplementary code point: 4 bytes in UTF-8
                bytes += 4;
                i++;
            } else {
                bytes += 3;
            }
        }
        return bytes;
    }
}