pymllm.executor.cuda_graph_runner
=================================

.. py:module:: pymllm.executor.cuda_graph_runner

.. autoapi-nested-parse::

   CUDA-graph accelerated forward pass for decode steps.

   Captures CUDA graphs for a set of discrete batch sizes so that the decode
   forward pass can be replayed without CPU-side kernel-launch overhead.

   This module provides ``CudaGraphRunner`` for pymllm's single-GPU
   architecture. It handles:

   * Pre-allocated input buffers (avoids per-step allocations)
   * CUDA-graph capture for each batch size
   * Optional ``torch.compile`` integration
   * Graph replay with padding to the nearest captured batch size

   Typical lifecycle::

       runner = CudaGraphRunner(model_runner)  # captures all batch sizes

       # --- inside the inference loop ---
       if runner.can_run(forward_batch):
           logits_output = runner.replay(forward_batch)
       else:
           logits_output = model_runner.forward(forward_batch)

   Integration with :class:`~pymllm.executor.model_runner.ModelRunner`
   -------------------------------------------------------------------

   The ``ModelRunner`` owns the ``CudaGraphRunner`` and delegates decode
   batches to it when the batch size is within the captured range. The
   ``CudaGraphRunner`` calls ``attn_backend.init_forward_metadata_*_cuda_graph``
   directly (bypassing the normal ``init_forward_metadata`` path) so that
   FlashInfer's per-batch planning is recorded inside the graph.

Attributes
----------

.. autoapisummary::

   pymllm.executor.cuda_graph_runner.logger

Classes
-------

.. autoapisummary::

   pymllm.executor.cuda_graph_runner.CudaGraphRunner

Functions
---------

.. autoapisummary::

   pymllm.executor.cuda_graph_runner.get_global_graph_memory_pool
   pymllm.executor.cuda_graph_runner.set_global_graph_memory_pool
   pymllm.executor.cuda_graph_runner.is_capture_mode
   pymllm.executor.cuda_graph_runner.model_capture_mode
   pymllm.executor.cuda_graph_runner.freeze_gc

Module Contents
---------------

.. py:data:: logger

.. py:function:: get_global_graph_memory_pool()

   Return the shared CUDA graph memory pool handle.
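The getter/setter pair for the shared pool amounts to a module-level handle cache. A minimal sketch, assuming a simple module global (the ``_GLOBAL_GRAPH_MEMORY_POOL`` name is an illustrative assumption, not pymllm's actual internals; in a real runner the handle would come from ``torch.cuda.graph_pool_handle()`` so that all captured graphs draw from one allocation pool):

```python
# Sketch of a shared graph-memory-pool handle. Names are assumptions for
# illustration; the real module may differ.
from typing import Any, Optional

_GLOBAL_GRAPH_MEMORY_POOL: Optional[Any] = None


def get_global_graph_memory_pool() -> Optional[Any]:
    """Return the shared pool handle (``None`` until one is installed)."""
    return _GLOBAL_GRAPH_MEMORY_POOL


def set_global_graph_memory_pool(pool: Any) -> None:
    """Install a pool handle so subsequent captures reuse the same memory."""
    global _GLOBAL_GRAPH_MEMORY_POOL
    _GLOBAL_GRAPH_MEMORY_POOL = pool
```

Sharing one pool across captures is what lets the smaller graphs reuse the memory allocated for the largest graph during ``capture()``.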
.. py:function:: set_global_graph_memory_pool(pool)

   Set the shared CUDA graph memory pool handle.

.. py:function:: is_capture_mode()

   Return ``True`` if a CUDA-graph capture is in progress.

.. py:function:: model_capture_mode()

   Context manager that sets the global capture-mode flag.

.. py:function:: freeze_gc()

   Freeze the garbage collector during CUDA-graph capture.

   GC activity during capture can interfere with the recorded stream
   ordering. This context manager collects garbage before capture, freezes
   all surviving objects, and unfreezes + re-collects afterwards.

.. py:class:: CudaGraphRunner(model_runner)

   Captures and replays CUDA graphs for decode-step forward passes.

   This class is the pymllm equivalent of sglang's ``CudaGraphRunner``,
   stripped of distributed, speculative-decoding, LoRA, mamba, TBO, and
   piecewise-graph complexities.

   :param model_runner: The owning
       :class:`~pymllm.executor.model_runner.ModelRunner`. Must have been
       fully initialised before the ``CudaGraphRunner`` is constructed.

   .. py:attribute:: model_runner

   .. py:attribute:: device

   .. py:attribute:: graphs
      :type: Dict[int, torch.cuda.CUDAGraph]

   .. py:attribute:: output_buffers
      :type: Dict[int, pymllm.executor.model_runner.LogitsProcessorOutput]

   .. py:attribute:: enable_torch_compile
      :type: bool

   .. py:attribute:: torch_compile_max_bs
      :type: int

   .. py:attribute:: capture_bs
      :type: List[int]

   .. py:attribute:: compile_bs
      :type: List[int]

   .. py:attribute:: max_bs
      :type: int

   .. py:attribute:: seq_len_fill_value
      :type: int

   .. py:attribute:: buffers
      :type: _InputBuffers

   .. py:method:: can_run(forward_batch)

      Return ``True`` if the batch can be run via CUDA graph replay.

      The batch must be a decode (or idle) batch whose size does not exceed
      the largest captured batch size.

   .. py:method:: capture()

      Capture CUDA graphs for every batch size in ``capture_bs``.

      Iterates in reverse order (largest first) so that the GPU memory pool
      allocated for the largest graph is reused by smaller ones.
   .. py:method:: replay(forward_batch)

      Replay a captured CUDA graph for the given decode batch.

      The batch is padded to the nearest captured size, inputs are copied
      into the pre-allocated buffers, the graph is replayed, and the output
      is sliced back to the real batch size.

      :param forward_batch: The decode batch from the scheduler.
      :returns: The logits for the real (un-padded) sequences.
      :rtype: LogitsProcessorOutput

   .. py:method:: shutdown()

      Release all captured CUDA graphs and associated buffers.
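The pad-then-slice step that ``replay`` performs can be illustrated in plain Python. This is a minimal sketch of the selection logic only (function names are illustrative, not pymllm's actual internals; the real method copies into pre-allocated GPU buffers and calls ``graph.replay()``):

```python
# Sketch of replay's batch-size padding: pick the smallest captured batch
# size that fits, pad inputs up to it, then slice outputs back down.
# Names are illustrative assumptions, not pymllm's real API.
import bisect
from typing import List, Sequence


def nearest_captured_bs(capture_bs: List[int], raw_bs: int) -> int:
    """Smallest captured batch size >= raw_bs.

    ``capture_bs`` must be sorted ascending, and ``raw_bs`` must not exceed
    its largest entry (``can_run`` guarantees this before replay is called).
    """
    return capture_bs[bisect.bisect_left(capture_bs, raw_bs)]


def pad_and_slice(capture_bs: List[int], token_ids: Sequence[int]) -> List[int]:
    """Pad inputs to the captured size, 'replay', slice back to raw size."""
    raw_bs = len(token_ids)
    padded_bs = nearest_captured_bs(capture_bs, raw_bs)
    padded = list(token_ids) + [0] * (padded_bs - raw_bs)  # fill static buffers
    outputs = padded            # stand-in for graph.replay() output (padded rows)
    return outputs[:raw_bs]     # keep only the real (un-padded) sequences
```

For example, with ``capture_bs = [1, 2, 4, 8]`` a 3-sequence decode batch is padded to 4 and the graph captured for batch size 4 is replayed; rows 3 and beyond of the output are discarded.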