pymllm.layers.attention.flashinfer_backend

FlashInfer attention backend for pymllm.

  • No model-runner object – constructor takes explicit scalar / tensor params.

  • No tensor-parallelism head splitting (handled at the model layer level).

  • No speculative decoding support.

  • KVPool API:
    • get_kv_buffer(layer_id) returns (k_buf, v_buf) each shaped [buf_len, num_heads, head_dim].

    • set_kv_buffer(layer_id, indices, k, v) – no scale arguments.
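
The KVPool API above can be illustrated with a minimal in-memory stand-in. `SimpleKVPool` is a hypothetical class for illustration only; the real pool stores paged GPU tensors, but the method names, argument order, and the `[buf_len, num_heads, head_dim]` layout follow the description above.

```python
from typing import Dict, List, Tuple


class SimpleKVPool:
    """Hypothetical in-memory stand-in for the KVPool API described above."""

    def __init__(self, num_layers: int, buf_len: int, num_heads: int, head_dim: int):
        def zeros() -> List:
            # One buffer shaped [buf_len, num_heads, head_dim], as nested lists.
            return [[[0.0] * head_dim for _ in range(num_heads)] for _ in range(buf_len)]

        self._k: Dict[int, List] = {i: zeros() for i in range(num_layers)}
        self._v: Dict[int, List] = {i: zeros() for i in range(num_layers)}

    def get_kv_buffer(self, layer_id: int) -> Tuple[List, List]:
        # Returns (k_buf, v_buf), each shaped [buf_len, num_heads, head_dim].
        return self._k[layer_id], self._v[layer_id]

    def set_kv_buffer(self, layer_id: int, indices: List[int], k: List, v: List) -> None:
        # Scatter new K/V rows into the pool at the given slot indices.
        # Note: no scale arguments, per the API described above.
        k_buf, v_buf = self._k[layer_id], self._v[layer_id]
        for row, idx in enumerate(indices):
            k_buf[idx] = k[row]
            v_buf[idx] = v[row]
```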

Supports:
  • Single-wrapper mode (full context, no sliding window)

  • Sliding-window mode (two wrappers: window + full)

  • CUDA-graph capture / replay for decode and target-verify passes.

Attributes

logger

Classes

WrapperDispatch

Indicates which wrapper to use for a given attention layer.

DecodeMetadata

Per-batch metadata for a decode step.

PrefillMetadata

Per-batch metadata for a prefill / extend step.

FlashInferAttnBackend

FlashInfer-based attention backend for pymllm.

Functions

should_use_tensor_core(kv_cache_dtype, ...)

Return whether FlashInfer decode should use tensor cores.

Module Contents

pymllm.layers.attention.flashinfer_backend.logger
class pymllm.layers.attention.flashinfer_backend.WrapperDispatch(*args, **kwds)

Bases: enum.Enum

Indicates which wrapper to use for a given attention layer.

SLIDING_WINDOW
CROSS_ATTENTION
class pymllm.layers.attention.flashinfer_backend.DecodeMetadata

Per-batch metadata for a decode step.

decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
class pymllm.layers.attention.flashinfer_backend.PrefillMetadata

Per-batch metadata for a prefill / extend step.

prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
use_ragged: bool
extend_no_prefix: bool
pymllm.layers.attention.flashinfer_backend.should_use_tensor_core(kv_cache_dtype, num_attention_heads, num_kv_heads)

Return whether FlashInfer decode should use tensor cores.

For FP8 we always use tensor cores. For fp16 / bf16 we use them when the GQA group size (num_attention_heads / num_kv_heads) is ≥ 4, which fuses the head group with the token dimension in the MMA instruction.

Parameters:
  • kv_cache_dtype (torch.dtype)

  • num_attention_heads (int)

  • num_kv_heads (int)

Return type:

bool
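
A dependency-free sketch of the decision rule described above. The real function takes `torch.dtype` objects; dtype name strings are used here as a simplification, and the FP8 dtype names listed are assumptions.

```python
def should_use_tensor_core_sketch(
    kv_cache_dtype: str, num_attention_heads: int, num_kv_heads: int
) -> bool:
    """Sketch of the tensor-core decision rule, with dtypes as strings."""
    if kv_cache_dtype in ("float8_e4m3fn", "float8_e5m2"):
        # FP8 decode always uses tensor cores.
        return True
    # For fp16 / bf16, use tensor cores once the GQA group size reaches 4,
    # so the head group can be fused with the token dimension in the MMA.
    gqa_group_size = num_attention_heads // num_kv_heads
    return gqa_group_size >= 4
```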

class pymllm.layers.attention.flashinfer_backend.FlashInferAttnBackend(num_heads, num_kv_heads, head_dim, kv_cache_dtype, q_dtype, max_context_len, req_to_token, device, max_req_pool_size, sliding_window_size=None, skip_prefill=False, kv_indptr_buf=None, kv_last_page_len_buf=None, init_new_workspace=False)

Bases: pymllm.layers.attention.attention_backend.AttentionBackend

FlashInfer-based attention backend for pymllm.

This class does not depend on a ModelRunner object. Instead it takes all required configuration explicitly so that it can be constructed independently of any particular model runner.

Parameters:
  • num_heads (int) – Number of query heads per device (after any TP sharding).

  • num_kv_heads (int) – Number of KV heads per device.

  • head_dim (int) – Per-head dimension for Q and K.

  • kv_cache_dtype (torch.dtype) – torch.dtype of the KV cache (e.g. torch.float16).

  • q_dtype (torch.dtype) – torch.dtype of the query tensor.

  • max_context_len (int) – Maximum sequence length the model supports.

  • req_to_token (torch.Tensor) – The [max_reqs, max_context_len] int32 tensor from ReqToTokenPool.req_to_token.

  • device (torch.device) – Target device (e.g. torch.device("cuda")).

  • max_req_pool_size (int) – Maximum number of concurrent requests (= ReqToTokenPool.size). Used to pre-allocate kv_indptr / kv_last_page_len buffers.

  • sliding_window_size (Optional[int]) – When not None, enables sliding-window attention mode which allocates two wrapper sets (window + full context).

  • skip_prefill (bool) – When True, skip creating prefill wrappers (for backends that only perform decode, e.g. multi-step draft backends).

  • kv_indptr_buf (Optional[torch.Tensor]) – Optional pre-allocated kv_indptr buffer. Used when sharing buffers across multiple backend instances (e.g. multi-step draft).

  • kv_last_page_len_buf (Optional[torch.Tensor]) – Optional pre-allocated kv_last_page_len buffer.

  • init_new_workspace (bool) – When True, allocate a fresh workspace buffer instead of reusing the global one.
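
To make the role of req_to_token and the kv_indptr / kv_last_page_len buffers concrete, here is a hedged sketch of how per-batch FlashInfer indexing arrays can be built from them, assuming a page size of 1. The helper name `build_kv_indices` and the exact layout are illustrative assumptions, not the backend's actual internals.

```python
from typing import List, Tuple


def build_kv_indices(
    req_to_token: List[List[int]],   # [max_reqs, max_context_len] slot table
    req_pool_indices: List[int],     # requests participating in this batch
    seq_lens: List[int],             # current length of each request
) -> Tuple[List[int], List[int], List[int]]:
    """Illustrative construction of FlashInfer paged-KV indexing arrays."""
    kv_indptr = [0]
    kv_indices: List[int] = []
    for req, seq_len in zip(req_pool_indices, seq_lens):
        # Gather the KV-cache slot ids for this request's tokens.
        kv_indices.extend(req_to_token[req][:seq_len])
        kv_indptr.append(len(kv_indices))
    # With page_size == 1, the last page of every sequence holds one token.
    kv_last_page_len = [1] * len(req_pool_indices)
    return kv_indptr, kv_indices, kv_last_page_len
```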

num_heads
num_kv_heads
head_dim
kv_cache_dtype
q_dtype
max_context_len
req_to_token
device
skip_prefill = False
decode_use_tensor_cores
prefill_wrapper_ragged: flashinfer.BatchPrefillWithRaggedKVCacheWrapper | None = None
prefill_wrappers_paged: List[flashinfer.BatchPrefillWithPagedKVCacheWrapper] = []
decode_wrappers: List[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = []
indices_updater_decode
forward_metadata: DecodeMetadata | PrefillMetadata | None = None
decode_cuda_graph_metadata: dict
prefill_cuda_graph_metadata: dict
init_forward_metadata(forward_batch)

Prepare FlashInfer wrappers for the current batch.

Must be called once per batch before the model’s forward method.

Parameters:

forward_batch (pymllm.engine.forward_batch.ForwardBatch)

Return type:

None

forward_extend(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)

Run attention for a prefill / extend step.

Return type:

torch.Tensor

forward_decode(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)

Run attention for a decode step (one new token per sequence).

Return type:

torch.Tensor

get_cuda_graph_seq_len_fill_value()

Fill value used to pad seq_lens tensors for CUDA-graph capture.

Most backends use 1 (not 0) to avoid division-by-zero in attention kernels.

Return type:

int
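
A small sketch of why the fill value matters: when padding seq_lens up to the captured batch size for CUDA-graph replay, padded slots filled with 1 never feed a zero length into the attention kernels. The helper name `pad_seq_lens` is illustrative.

```python
from typing import List


def pad_seq_lens(seq_lens: List[int], max_bs: int, fill_value: int = 1) -> List[int]:
    # Pad per-request sequence lengths up to the CUDA-graph batch size.
    # A fill value of 1 (not 0) keeps padded slots from producing a zero
    # length, which could trigger division-by-zero in attention kernels.
    return seq_lens + [fill_value] * (max_bs - len(seq_lens))
```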

init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf=None)

Allocate CUDA-graph shared state buffers.

Parameters:
  • max_bs (int)

  • max_num_tokens (int)

  • kv_indices_buf (Optional[torch.Tensor])

Return type:

None

init_forward_metadata_capture_cuda_graph(bs, num_tokens, req_pool_indices, seq_lens, forward_mode)

Set up metadata for CUDA-graph capture of a decode step.

Return type:

None

init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens, seq_lens_sum, forward_mode, seq_lens_cpu)

Update metadata when replaying a CUDA graph for decode.

Return type:

None