pymllm.executor.cuda_graph_runner¶
CUDA-graph accelerated forward pass for decode steps.
Captures CUDA graphs for a set of discrete batch sizes so that the decode forward pass can be replayed without CPU-side kernel-launch overhead.
CudaGraphRunner for pymllm’s single-GPU architecture. Handles:
Pre-allocated input buffers (avoids per-step allocations)
CUDA-graph capture for each batch size
Optional torch.compile integration
Graph replay with padding to the nearest captured batch size
Typical lifecycle:
runner = CudaGraphRunner(model_runner) # captures all batch sizes
# --- inside the inference loop ---
if runner.can_run(forward_batch):
logits_output = runner.replay(forward_batch)
else:
logits_output = model_runner.forward(forward_batch)
Integration with ModelRunner¶
The ModelRunner owns the CudaGraphRunner and delegates decode
batches to it when the batch size is within the captured range. The
CudaGraphRunner calls attn_backend.init_forward_metadata_*_cuda_graph
directly (bypassing the normal init_forward_metadata path) so that
FlashInfer’s per-batch planning is recorded inside the graph.
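The delegation described above can be sketched as follows. This is an illustrative stub, not pymllm's actual implementation; the `forward_eager` method name and the constructor shape are assumptions.

```python
# Hedged sketch of how ModelRunner might delegate decode batches to its
# CudaGraphRunner. Only the control flow is taken from the docs above;
# class internals are illustrative.

class ModelRunnerSketch:
    def __init__(self, cuda_graph_runner):
        # The ModelRunner owns the CudaGraphRunner (may be None if disabled).
        self.cuda_graph_runner = cuda_graph_runner

    def forward(self, forward_batch):
        # Replay a captured graph when the batch qualifies; otherwise fall
        # back to the normal eager forward pass.
        runner = self.cuda_graph_runner
        if runner is not None and runner.can_run(forward_batch):
            return runner.replay(forward_batch)
        return self.forward_eager(forward_batch)

    def forward_eager(self, forward_batch):
        # Placeholder for the regular (non-graph) forward pass.
        raise NotImplementedError
```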
Attributes¶
logger
Classes¶
CudaGraphRunner – Captures and replays CUDA graphs for decode-step forward passes.
Functions¶
get_global_graph_memory_pool – Return the shared CUDA graph memory pool handle.
set_global_graph_memory_pool – Set the shared CUDA graph memory pool handle.
is_capture_mode – Return True if a CUDA-graph capture is in progress.
model_capture_mode – Context manager that sets the global capture-mode flag.
freeze_gc – Freeze the garbage collector during CUDA-graph capture.
Module Contents¶
- pymllm.executor.cuda_graph_runner.logger¶
- pymllm.executor.cuda_graph_runner.get_global_graph_memory_pool()¶
Return the shared CUDA graph memory pool handle.
- Return type:
Optional[tuple]
- pymllm.executor.cuda_graph_runner.set_global_graph_memory_pool(pool)¶
Set the shared CUDA graph memory pool handle.
- Parameters:
pool (tuple)
- Return type:
None
- pymllm.executor.cuda_graph_runner.is_capture_mode()¶
Return True if a CUDA-graph capture is in progress.
- Return type:
bool
- pymllm.executor.cuda_graph_runner.model_capture_mode()¶
Context manager that sets the global capture-mode flag.
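A minimal sketch of how such a flag and context manager pair might look, assuming `is_capture_mode()` simply reads a module-level boolean (the actual pymllm internals may differ):

```python
import contextlib

# Module-level capture flag; is_capture_mode() reads it, and
# model_capture_mode() toggles it for the duration of a capture.
_capture_mode = False

def is_capture_mode() -> bool:
    return _capture_mode

@contextlib.contextmanager
def model_capture_mode():
    global _capture_mode
    _capture_mode = True
    try:
        yield
    finally:
        # Always clear the flag, even if capture raised.
        _capture_mode = False
```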
- pymllm.executor.cuda_graph_runner.freeze_gc()¶
Freeze the garbage collector during CUDA-graph capture.
GC activity during capture can interfere with the recorded stream ordering. This context manager collects garbage before capture, freezes all surviving objects, and unfreezes + re-collects afterwards.
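The collect/freeze/unfreeze sequence described above maps directly onto the standard-library `gc.freeze()` and `gc.unfreeze()` APIs (Python 3.7+). A sketch, not necessarily pymllm's exact implementation:

```python
import contextlib
import gc

@contextlib.contextmanager
def freeze_gc():
    # Collect first so only genuinely long-lived objects survive, ...
    gc.collect()
    # ... then move those survivors into the permanent generation so the
    # collector will not touch them while the CUDA graph is being recorded.
    gc.freeze()
    try:
        yield
    finally:
        # Make the frozen objects collectable again and re-collect.
        gc.unfreeze()
        gc.collect()
```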
- class pymllm.executor.cuda_graph_runner.CudaGraphRunner(model_runner)¶
Captures and replays CUDA graphs for decode-step forward passes.
This class is the pymllm equivalent of sglang’s CudaGraphRunner, stripped of the distributed, speculative-decoding, LoRA, mamba, TBO, and piecewise-graph complexities.
- Parameters:
model_runner (pymllm.executor.model_runner.ModelRunner) – The owning ModelRunner. Must have been fully initialised before the CudaGraphRunner is constructed.
- model_runner¶
- device¶
- graphs: Dict[int, torch.cuda.CUDAGraph]¶
- output_buffers: Dict[int, pymllm.executor.model_runner.LogitsProcessorOutput]¶
- enable_torch_compile: bool¶
- torch_compile_max_bs: int¶
- capture_bs: List[int]¶
- compile_bs: List[int]¶
- max_bs: int¶
- seq_len_fill_value: int¶
- buffers: _InputBuffers¶
- can_run(forward_batch)¶
Return True if the batch can be run via CUDA graph replay.
The batch must be a decode (or idle) batch whose size does not exceed the largest captured batch size.
- Parameters:
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
- Return type:
bool
- capture()¶
Capture CUDA graphs for every batch size in capture_bs.
Iterates in reverse order (largest first) so that the GPU memory pool allocated for the largest graph is reused by smaller ones.
- Return type:
None
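The largest-first ordering and pool reuse can be sketched with the actual graph capture stubbed out; `capture_one` is a hypothetical per-size helper, not a real pymllm function:

```python
# Sketch of the capture loop's ordering. Capturing the largest batch size
# first means its memory pool is sized for the worst case, so every
# smaller graph captured afterwards can share it.

def capture_all(capture_bs, capture_one):
    # capture_one(bs, pool) -> pool: hypothetical per-size capture helper.
    pool = None
    order = sorted(capture_bs, reverse=True)  # largest first
    for bs in order:
        # First call allocates the pool; later calls reuse it.
        pool = capture_one(bs, pool)
    return order
```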
- replay(forward_batch)¶
Replay a captured CUDA graph for the given decode batch.
The batch is padded to the nearest captured size, inputs are copied into the pre-allocated buffers, the graph is replayed, and the output is sliced back to the real batch size.
- Parameters:
forward_batch (pymllm.engine.forward_batch.ForwardBatch) – The decode batch from the scheduler.
- Returns:
The logits for the real (un-padded) sequences.
- Return type:
pymllm.executor.model_runner.LogitsProcessorOutput
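The padding step ("nearest captured size") amounts to picking the smallest captured batch size that is at least the real batch size. A sketch assuming `capture_bs` is kept sorted ascending:

```python
import bisect

# Sketch of the padding lookup used during replay: find the smallest
# captured batch size >= real_bs. can_run() guarantees real_bs does not
# exceed the largest captured size, but we guard anyway.

def padded_batch_size(capture_bs, real_bs):
    idx = bisect.bisect_left(capture_bs, real_bs)
    if idx == len(capture_bs):
        raise ValueError(f"batch size {real_bs} exceeds largest captured size")
    return capture_bs[idx]
```

Inputs are then copied into the pre-allocated buffers at the padded size, and the output is sliced back to `real_bs` rows after replay.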
- shutdown()¶
Release all captured CUDA graphs and associated buffers.
- Return type:
None