pymllm.layers.attention.flashinfer_backend¶
FlashInfer attention backend for pymllm.
- No model-runner object – the constructor takes explicit scalar / tensor parameters.
- No tensor-parallel head splitting (handled at the model layer level).
- No speculative-decoding support.
Expected KVPool API:
- get_kv_buffer(layer_id) returns (k_buf, v_buf), each shaped [buf_len, num_heads, head_dim].
- set_kv_buffer(layer_id, indices, k, v) – no scale arguments.
- Supports:
  - Single-wrapper mode (full context, no sliding window)
  - Sliding-window mode (two wrappers: window + full)
  - CUDA-graph capture / replay for decode and target-verify passes.
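The KVPool contract above can be sketched with a torch-free stand-in. `NaiveKVPool` below is purely illustrative (plain nested Python lists instead of CUDA tensors), not the real pymllm pool:

```python
# Minimal stand-in for the KV pool contract the backend expects.
# Real buffers are torch tensors of shape [buf_len, num_heads, head_dim];
# plain Python lists are used here so the sketch runs anywhere.
class NaiveKVPool:
    def __init__(self, num_layers, buf_len, num_heads, head_dim):
        zeros = lambda: [[[0.0] * head_dim for _ in range(num_heads)]
                         for _ in range(buf_len)]
        self._k = [zeros() for _ in range(num_layers)]
        self._v = [zeros() for _ in range(num_layers)]

    def get_kv_buffer(self, layer_id):
        # Returns (k_buf, v_buf), each shaped [buf_len, num_heads, head_dim].
        return self._k[layer_id], self._v[layer_id]

    def set_kv_buffer(self, layer_id, indices, k, v):
        # Scatter new K/V rows into the pool at `indices`; no scale arguments.
        for row, ki, vi in zip(indices, k, v):
            self._k[layer_id][row] = ki
            self._v[layer_id][row] = vi


pool = NaiveKVPool(num_layers=2, buf_len=8, num_heads=1, head_dim=4)
pool.set_kv_buffer(0, [3], k=[[[1.0, 2.0, 3.0, 4.0]]], v=[[[5.0, 6.0, 7.0, 8.0]]])
k_buf, _ = pool.get_kv_buffer(0)
print(k_buf[3][0])  # [1.0, 2.0, 3.0, 4.0]
```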
Attributes¶
| logger | |

Classes¶
| WrapperDispatch | Indicates which wrapper to use for a given attention layer. |
| DecodeMetadata | Per-batch metadata for a decode step. |
| PrefillMetadata | Per-batch metadata for a prefill / extend step. |
| FlashInferAttnBackend | FlashInfer-based attention backend for pymllm. |

Functions¶
| should_use_tensor_core | Return whether FlashInfer decode should use tensor cores. |
Module Contents¶
- pymllm.layers.attention.flashinfer_backend.logger¶
- class pymllm.layers.attention.flashinfer_backend.WrapperDispatch(*args, **kwds)¶
Bases: enum.Enum
Indicates which wrapper to use for a given attention layer.
- SLIDING_WINDOW¶
- CROSS_ATTENTION¶
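How this enum might drive wrapper selection can be sketched as follows. The `wrapper_index` helper and the per-layer `layer_sliding_window_size` argument are hypothetical, assumed only for illustration of the two-wrapper (window + full) mode described above:

```python
import enum


class WrapperDispatch(enum.Enum):
    SLIDING_WINDOW = enum.auto()
    CROSS_ATTENTION = enum.auto()


# Hypothetical dispatch: in sliding-window mode the backend keeps two wrapper
# sets, and each layer picks wrapper 0 (window) or 1 (full context) depending
# on whether the layer itself uses a sliding window.
def wrapper_index(dispatch, layer_sliding_window_size):
    if dispatch is WrapperDispatch.SLIDING_WINDOW:
        return 0 if layer_sliding_window_size is not None else 1
    return 0  # single-wrapper mode: everything uses wrapper 0


print(wrapper_index(WrapperDispatch.SLIDING_WINDOW, 4096))  # window wrapper
print(wrapper_index(WrapperDispatch.SLIDING_WINDOW, None))  # full-context wrapper
```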
- class pymllm.layers.attention.flashinfer_backend.DecodeMetadata¶
Per-batch metadata for a decode step.
- decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]¶
- class pymllm.layers.attention.flashinfer_backend.PrefillMetadata¶
Per-batch metadata for a prefill / extend step.
- prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]¶
- use_ragged: bool¶
- extend_no_prefix: bool¶
- pymllm.layers.attention.flashinfer_backend.should_use_tensor_core(kv_cache_dtype, num_attention_heads, num_kv_heads)¶
Return whether FlashInfer decode should use tensor cores.
For FP8 we always use tensor cores. For fp16 / bf16 we use them when the GQA group size (num_attention_heads / num_kv_heads) is ≥ 4, which fuses the head group with the token dimension in the MMA instruction.
- Parameters:
kv_cache_dtype (torch.dtype)
num_attention_heads (int)
num_kv_heads (int)
- Return type:
bool
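The heuristic above can be sketched without torch by spelling dtypes as strings (the string names stand in for `torch.dtype` values; the FP8 names listed are an assumption about which dtypes the real check treats as FP8):

```python
# Torch-free sketch of the tensor-core heuristic described above.
def should_use_tensor_core(kv_cache_dtype, num_attention_heads, num_kv_heads):
    if kv_cache_dtype in ("float8_e4m3fn", "float8_e5m2"):
        return True  # FP8 decode always uses tensor cores
    if kv_cache_dtype in ("float16", "bfloat16"):
        # GQA group size >= 4 fuses the head group with the token
        # dimension in the MMA instruction.
        return num_attention_heads // num_kv_heads >= 4
    return False


print(should_use_tensor_core("float16", 32, 8))    # group size 4 -> True
print(should_use_tensor_core("bfloat16", 32, 16))  # group size 2 -> False
```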
- class pymllm.layers.attention.flashinfer_backend.FlashInferAttnBackend(num_heads, num_kv_heads, head_dim, kv_cache_dtype, q_dtype, max_context_len, req_to_token, device, max_req_pool_size, sliding_window_size=None, skip_prefill=False, kv_indptr_buf=None, kv_last_page_len_buf=None, init_new_workspace=False)¶
Bases: pymllm.layers.attention.attention_backend.AttentionBackend
FlashInfer-based attention backend for pymllm.
This class does not depend on a ModelRunner object. Instead it takes all required configuration explicitly, so that it can be constructed independently of any particular model runner.
- Parameters:
num_heads (int) – Number of query heads per device (after any TP sharding).
num_kv_heads (int) – Number of KV heads per device.
head_dim (int) – Per-head dimension for Q and K.
kv_cache_dtype (torch.dtype) – torch.dtype of the KV cache (e.g. torch.float16).
q_dtype (torch.dtype) – torch.dtype of the query tensor.
max_context_len (int) – Maximum sequence length the model supports.
req_to_token (torch.Tensor) – The [max_reqs, max_context_len] int32 tensor from ReqToTokenPool.req_to_token.
device (torch.device) – Target device (e.g. torch.device("cuda")).
max_req_pool_size (int) – Maximum number of concurrent requests (= ReqToTokenPool.size). Used to pre-allocate kv_indptr / kv_last_page_len buffers.
sliding_window_size (Optional[int]) – When not None, enables sliding-window attention mode, which allocates two wrapper sets (window + full context).
skip_prefill (bool) – When True, skip creating prefill wrappers (for backends that only perform decode, e.g. multi-step draft backends).
kv_indptr_buf (Optional[torch.Tensor]) – Optional pre-allocated kv_indptr buffer. Used when sharing buffers across multiple backend instances (e.g. multi-step draft).
kv_last_page_len_buf (Optional[torch.Tensor]) – Optional pre-allocated kv_last_page_len buffer.
init_new_workspace (bool) – When True, allocate a fresh workspace buffer instead of reusing the global one.
- num_heads¶
- num_kv_heads¶
- head_dim¶
- kv_cache_dtype¶
- q_dtype¶
- max_context_len¶
- req_to_token¶
- device¶
- skip_prefill = False¶
- decode_use_tensor_cores¶
- prefill_wrapper_ragged: flashinfer.BatchPrefillWithRaggedKVCacheWrapper | None = None¶
- prefill_wrappers_paged: List[flashinfer.BatchPrefillWithPagedKVCacheWrapper] = []¶
- decode_wrappers: List[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = []¶
- indices_updater_decode¶
- forward_metadata: DecodeMetadata | PrefillMetadata | None = None¶
- decode_cuda_graph_metadata: dict¶
- prefill_cuda_graph_metadata: dict¶
- init_forward_metadata(forward_batch)¶
Prepare FlashInfer wrappers for the current batch.
Must be called once per batch before the model’s forward method.
- Parameters:
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
- Return type:
None
- forward_extend(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)¶
Run attention for a prefill / extend step.
- Parameters:
q (torch.Tensor)
k (Optional[torch.Tensor])
v (Optional[torch.Tensor])
layer (pymllm.layers.attention.radix_attention.RadixAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
save_kv_cache (bool)
- Return type:
torch.Tensor
- forward_decode(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)¶
Run attention for a decode step (one new token per sequence).
- Parameters:
q (torch.Tensor)
k (Optional[torch.Tensor])
v (Optional[torch.Tensor])
layer (pymllm.layers.attention.radix_attention.RadixAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
save_kv_cache (bool)
- Return type:
torch.Tensor
- get_cuda_graph_seq_len_fill_value()¶
Fill value used to pad seq_lens tensors for CUDA-graph capture.
Most backends use 1 (not 0) to avoid division-by-zero in attention kernels.
- Return type:
int
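The rationale can be illustrated with plain-Python padding arithmetic (the `pad_seq_lens` helper is hypothetical, not part of the backend):

```python
# Why pad `seq_lens` with 1 rather than 0 for CUDA-graph capture: padded
# (inactive) batch slots still flow through the captured attention kernel,
# and a zero length would divide by zero in any per-length normalization.
def pad_seq_lens(seq_lens, max_bs, fill_value):
    return seq_lens + [fill_value] * (max_bs - len(seq_lens))


padded = pad_seq_lens([17, 5], max_bs=4, fill_value=1)
scales = [1.0 / n for n in padded]  # safe: no zero-length entries
print(padded)  # [17, 5, 1, 1]
```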
- init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf=None)¶
Allocate CUDA-graph shared state buffers.
- Parameters:
max_bs (int)
max_num_tokens (int)
kv_indices_buf (Optional[torch.Tensor])
- Return type:
None
- init_forward_metadata_capture_cuda_graph(bs, num_tokens, req_pool_indices, seq_lens, forward_mode)¶
Set up metadata for CUDA-graph capture of a decode step.
- Parameters:
bs (int)
num_tokens (int)
req_pool_indices (torch.Tensor)
seq_lens (torch.Tensor)
forward_mode (pymllm.engine.forward_batch.ForwardMode)
- Return type:
None
- init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens, seq_lens_sum, forward_mode, seq_lens_cpu)¶
Update metadata when replaying a CUDA graph for decode.
- Parameters:
bs (int)
req_pool_indices (torch.Tensor)
seq_lens (torch.Tensor)
seq_lens_sum (int)
forward_mode (pymllm.engine.forward_batch.ForwardMode)
seq_lens_cpu (Optional[torch.Tensor])
- Return type:
None
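The three CUDA-graph methods above follow a common pattern: allocate fixed-size buffers once, capture the graph against exactly those buffers, then on every replay overwrite their contents in place (never reallocate, or the captured graph would read stale memory). A torch-free sketch of that contract, with hypothetical names mirroring the real methods:

```python
# Torch-free sketch of the CUDA-graph buffer contract.
class GraphStateSketch:
    def __init__(self, max_bs, fill_value=1):
        # init_cuda_graph_state analogue: fixed-size shared buffers,
        # padded with the seq-len fill value.
        self.seq_lens = [fill_value] * max_bs
        self.req_pool_indices = [0] * max_bs

    def capture(self, bs, req_pool_indices, seq_lens):
        # init_forward_metadata_capture_cuda_graph analogue.
        self._write(bs, req_pool_indices, seq_lens)

    def replay(self, bs, req_pool_indices, seq_lens):
        # init_forward_metadata_replay_cuda_graph analogue: same buffer
        # objects, new contents -- an in-place update, never a reallocation.
        self._write(bs, req_pool_indices, seq_lens)

    def _write(self, bs, req_pool_indices, seq_lens):
        self.seq_lens[:bs] = seq_lens
        self.req_pool_indices[:bs] = req_pool_indices


state = GraphStateSketch(max_bs=4)
state.capture(bs=2, req_pool_indices=[7, 3], seq_lens=[10, 6])
buf_id = id(state.seq_lens)
state.replay(bs=2, req_pool_indices=[7, 3], seq_lens=[11, 7])
print(state.seq_lens)                # [11, 7, 1, 1]
print(id(state.seq_lens) == buf_id)  # True: same buffer, updated in place
```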