pymllm.executor.cuda_graph_runner

CUDA-graph accelerated forward pass for decode steps.

Captures CUDA graphs for a set of discrete batch sizes so that the decode forward pass can be replayed without CPU-side kernel-launch overhead.

CudaGraphRunner for pymllm’s single-GPU architecture. Handles:

  • Pre-allocated input buffers (avoids per-step allocations)

  • CUDA-graph capture for each batch size

  • Optional torch.compile integration

  • Graph replay with padding to the nearest captured batch size

Typical lifecycle:

runner = CudaGraphRunner(model_runner)   # captures all batch sizes

# --- inside the inference loop ---
if runner.can_run(forward_batch):
    logits_output = runner.replay(forward_batch)
else:
    logits_output = model_runner.forward(forward_batch)

Integration with ModelRunner

The ModelRunner owns the CudaGraphRunner and delegates decode batches to it when the batch size is within the captured range. The CudaGraphRunner calls attn_backend.init_forward_metadata_*_cuda_graph directly (bypassing the normal init_forward_metadata path) so that FlashInfer’s per-batch planning is recorded inside the graph.
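The delegation described above can be sketched as follows. `ModelRunnerSketch`, `forward_decode`, and the injected `eager_forward` fallback are illustrative names, not the actual ModelRunner API:

```python
class ModelRunnerSketch:
    """Hypothetical sketch of the try-replay-then-fall-back dispatch."""

    def __init__(self, graph_runner, eager_forward):
        self.graph_runner = graph_runner    # None when CUDA graphs are disabled
        self.eager_forward = eager_forward  # the normal forward path

    def forward_decode(self, forward_batch):
        # Replay a captured graph when the batch qualifies; otherwise
        # fall through to the ordinary eager forward pass.
        if self.graph_runner is not None and self.graph_runner.can_run(forward_batch):
            return self.graph_runner.replay(forward_batch)
        return self.eager_forward(forward_batch)
```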

Attributes

logger

Classes

CudaGraphRunner

Captures and replays CUDA graphs for decode-step forward passes.

Functions

get_global_graph_memory_pool()

Return the shared CUDA graph memory pool handle.

set_global_graph_memory_pool(pool)

Set the shared CUDA graph memory pool handle.

is_capture_mode()

Return True if a CUDA-graph capture is in progress.

model_capture_mode()

Context manager that sets the global capture-mode flag.

freeze_gc()

Freeze the garbage collector during CUDA-graph capture.

Module Contents

pymllm.executor.cuda_graph_runner.logger

Module-level logger.
pymllm.executor.cuda_graph_runner.get_global_graph_memory_pool()

Return the shared CUDA graph memory pool handle.

Return type:

Optional[tuple]

pymllm.executor.cuda_graph_runner.set_global_graph_memory_pool(pool)

Set the shared CUDA graph memory pool handle.

Parameters:

pool (tuple)

Return type:

None
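A minimal sketch of the module-global pool pattern these two functions describe. In the real module the handle would typically come from `torch.cuda.graph_pool_handle()`; here it is simply stored and returned:

```python
from typing import Optional

# Hypothetical module-global holding the shared graph memory pool handle
# (the opaque tuple returned by torch.cuda.graph_pool_handle()).
_global_graph_memory_pool: Optional[tuple] = None


def get_global_graph_memory_pool() -> Optional[tuple]:
    """Return the shared CUDA graph memory pool handle, or None if unset."""
    return _global_graph_memory_pool


def set_global_graph_memory_pool(pool: tuple) -> None:
    """Install a pool handle so later captures reuse the same memory pool."""
    global _global_graph_memory_pool
    _global_graph_memory_pool = pool
```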

pymllm.executor.cuda_graph_runner.is_capture_mode()

Return True if a CUDA-graph capture is in progress.

Return type:

bool

pymllm.executor.cuda_graph_runner.model_capture_mode()

Context manager that sets the global capture-mode flag.
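A plausible implementation of this flag pair, assuming a simple module-global boolean; the real module may track the state differently:

```python
import contextlib

# Hypothetical module-global capture-mode flag.
_is_capture_mode = False


def is_capture_mode() -> bool:
    """Return True while a CUDA-graph capture is in progress."""
    return _is_capture_mode


@contextlib.contextmanager
def model_capture_mode():
    """Set the capture-mode flag for the duration of the with-block."""
    global _is_capture_mode
    _is_capture_mode = True
    try:
        yield
    finally:
        # Always clear the flag, even if capture raises.
        _is_capture_mode = False
```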

pymllm.executor.cuda_graph_runner.freeze_gc()

Freeze the garbage collector during CUDA-graph capture.

GC activity during capture can interfere with the recorded stream ordering. This context manager collects garbage before capture, freezes all surviving objects, and unfreezes and re-collects afterwards.
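The behaviour described above maps directly onto the standard-library `gc.freeze()`/`gc.unfreeze()` API; a sketch:

```python
import contextlib
import gc


@contextlib.contextmanager
def freeze_gc():
    """Keep the garbage collector quiet while a CUDA graph is captured."""
    # Collect first so that only genuinely long-lived objects get frozen.
    gc.collect()
    # Move all surviving objects into the permanent generation, so the
    # collector will not touch them during capture.
    gc.freeze()
    try:
        yield
    finally:
        # Restore normal collection and clean up anything created meanwhile.
        gc.unfreeze()
        gc.collect()
```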

class pymllm.executor.cuda_graph_runner.CudaGraphRunner(model_runner)

Captures and replays CUDA graphs for decode-step forward passes.

This class is the pymllm equivalent of sglang’s CudaGraphRunner, stripped of distributed, speculative-decoding, LoRA, mamba, TBO, and piecewise-graph complexities.

Parameters:

model_runner (pymllm.executor.model_runner.ModelRunner) – The owning ModelRunner. Must have been fully initialised before the CudaGraphRunner is constructed.

model_runner
device
graphs: Dict[int, torch.cuda.CUDAGraph]
output_buffers: Dict[int, pymllm.executor.model_runner.LogitsProcessorOutput]
enable_torch_compile: bool
torch_compile_max_bs: int
capture_bs: List[int]
compile_bs: List[int]
max_bs: int
seq_len_fill_value: int
buffers: _InputBuffers
can_run(forward_batch)

Return True if the batch can be run via CUDA graph replay.

The batch must be a decode (or idle) batch whose size does not exceed the largest captured batch size.

Parameters:

forward_batch (pymllm.engine.forward_batch.ForwardBatch)

Return type:

bool
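The eligibility check can be sketched as below; the `forward_mode` string values and the `batch_size` attribute are assumptions about ForwardBatch, used only for illustration:

```python
def can_run(forward_batch, max_bs: int) -> bool:
    """Sketch: replayable only for decode/idle batches within the captured range."""
    # Hypothetical attribute names -- the real ForwardBatch may differ.
    is_decode_or_idle = forward_batch.forward_mode in ("decode", "idle")
    return is_decode_or_idle and forward_batch.batch_size <= max_bs
```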

capture()

Capture CUDA graphs for every batch size in capture_bs.

Iterates in reverse order (largest first) so that the GPU memory pool allocated for the largest graph is reused by smaller ones.

Return type:

None
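The reverse-order iteration can be illustrated independently of CUDA; here `capture_one` is a stand-in for the per-batch-size graph capture, and the dict mirrors the `graphs` attribute:

```python
def capture_all(capture_bs, capture_one):
    """Sketch: capture the largest batch size first so its memory pool
    can be reused by the smaller graphs captured afterwards."""
    graphs = {}
    for bs in sorted(capture_bs, reverse=True):
        graphs[bs] = capture_one(bs)
    return graphs
```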

replay(forward_batch)

Replay a captured CUDA graph for the given decode batch.

The batch is padded to the nearest captured size, inputs are copied into the pre-allocated buffers, the graph is replayed, and the output is sliced back to the real batch size.

Parameters:

forward_batch (pymllm.engine.forward_batch.ForwardBatch) – The decode batch from the scheduler.

Returns:

The logits for the real (un-padded) sequences.

Return type:

LogitsProcessorOutput
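Rounding a batch up to the nearest captured size can be sketched with `bisect`, assuming the list of captured sizes is kept sorted ascending; the replayed output would then be sliced back down to `raw_bs` rows:

```python
import bisect


def nearest_captured_bs(raw_bs: int, capture_bs: list) -> int:
    """Smallest captured batch size that can hold raw_bs sequences.

    capture_bs must be sorted ascending, and raw_bs must not exceed
    its largest element (can_run() guarantees this before replay).
    """
    idx = bisect.bisect_left(capture_bs, raw_bs)
    return capture_bs[idx]
```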

shutdown()

Release all captured CUDA graphs and associated buffers.

Return type:

None