pymllm.layers.attention.attention_backend¶
Abstract base class for pymllm attention backends.
Every concrete backend (FlashInfer, Triton, torch-native, …) must implement at minimum:
- init_forward_metadata – called once per batch before the model forward.
- forward_extend – prefill / extend attention.
- forward_decode – single-token decode attention.
The public forward method dispatches to the correct variant based on forward_batch.forward_mode.
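The contract above can be sketched as a minimal, self-contained skeleton. This is an illustration only, not the real pymllm code: the `ForwardMode` / `ForwardBatch` stand-ins are simplified assumptions, and the class is renamed `AttentionBackendSketch` to make that explicit.

```python
from abc import ABC, abstractmethod


# Simplified stand-ins for pymllm.engine.forward_batch types (assumptions).
class ForwardMode:
    EXTEND = "extend"
    DECODE = "decode"


class ForwardBatch:
    def __init__(self, forward_mode):
        self.forward_mode = forward_mode


class AttentionBackendSketch(ABC):
    """Sketch of the AttentionBackend contract described above."""

    @abstractmethod
    def init_forward_metadata(self, forward_batch):
        """Called once per batch before the model forward."""

    @abstractmethod
    def forward_extend(self, q, k, v, layer, forward_batch,
                       save_kv_cache=True, **kwargs):
        """Prefill / extend attention."""

    @abstractmethod
    def forward_decode(self, q, k, v, layer, forward_batch,
                       save_kv_cache=True, **kwargs):
        """Single-token decode attention."""

    def forward(self, q, k, v, layer, forward_batch,
                save_kv_cache=True, **kwargs):
        # Dispatch to the correct variant based on forward_batch.forward_mode.
        if forward_batch.forward_mode == ForwardMode.DECODE:
            return self.forward_decode(q, k, v, layer, forward_batch,
                                       save_kv_cache, **kwargs)
        return self.forward_extend(q, k, v, layer, forward_batch,
                                   save_kv_cache, **kwargs)
```

A concrete backend only has to fill in the three abstract methods; the dispatch logic in `forward` is inherited.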
Classes¶
AttentionBackend | Abstract base class for attention backends.
Module Contents¶
- class pymllm.layers.attention.attention_backend.AttentionBackend¶
Bases: abc.ABC
Abstract base class for attention backends.
All concrete backends inherit from this class and implement the abstract methods below.
- abstractmethod init_forward_metadata(forward_batch)¶
Prepare per-batch metadata before the model’s attention layers run.
For FlashInfer this plans the KV-index arrays and calls
wrapper.begin_forward; for Triton / torch-native this is a no-op.
Must be called once per batch before model.forward.
- Parameters:
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
- Return type:
None
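To illustrate the two styles of init_forward_metadata mentioned above, here is a hedged, torch-free sketch: a planning backend precomputes cumulative KV index offsets per batch (analogous to what wrapper.begin_forward consumes), while a kernel-launch backend does nothing. The `seq_lens` field and the `kv_indptr` name are assumptions for this sketch, not guaranteed pymllm internals.

```python
class PlanBatch:
    """Stand-in batch object (assumption): seq_lens holds cached tokens per request."""
    def __init__(self, seq_lens):
        self.seq_lens = seq_lens


class PlanningBackend:
    """FlashInfer-style: plan per-batch KV metadata up front."""
    def init_forward_metadata(self, forward_batch):
        # Build cumulative offsets into the flat KV pool, one entry per request.
        self.kv_indptr = [0]
        for n in forward_batch.seq_lens:
            self.kv_indptr.append(self.kv_indptr[-1] + n)


class NoOpBackend:
    """Triton / torch-native style: metadata is derived inside the kernel launch."""
    def init_forward_metadata(self, forward_batch):
        pass
```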
- abstractmethod forward_decode(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)¶
Run attention for a decode step (one new token per sequence).
- Parameters:
q (torch.Tensor)
k (Optional[torch.Tensor])
v (Optional[torch.Tensor])
layer (pymllm.layers.attention.radix_attention.RadixAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
save_kv_cache (bool)
- Return type:
torch.Tensor
- abstractmethod forward_extend(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)¶
Run attention for a prefill / extend step.
- Parameters:
q (torch.Tensor)
k (Optional[torch.Tensor])
v (Optional[torch.Tensor])
layer (pymllm.layers.attention.radix_attention.RadixAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
save_kv_cache (bool)
- Return type:
torch.Tensor
- forward(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)¶
Dispatch to forward_decode or forward_extend based on mode.
For IDLE batches a zero-filled output tensor is returned without any compute.
- Parameters:
q (torch.Tensor)
k (Optional[torch.Tensor])
v (Optional[torch.Tensor])
layer (pymllm.layers.attention.radix_attention.RadixAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
save_kv_cache (bool)
- Return type:
torch.Tensor
- abstractmethod forward_gdn(layer, forward_batch, mixed_qkv, a, b)¶
Run GDN linear-attention for one layer.
Only implemented by backends that support hybrid (full + GDN) architectures.
The default raises NotImplementedError.
- Parameters:
layer (RadixLinearAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
mixed_qkv (torch.Tensor)
a (torch.Tensor)
b (torch.Tensor)
- Return type:
torch.Tensor
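The default behavior for full-attention-only backends can be sketched as below. The class name is illustrative; the point is that a backend without GDN support fails loudly when a hybrid (full + GDN) model routes a linear-attention layer to it.

```python
class FullAttentionOnlyBackend:
    """Sketch: a backend that does not implement GDN linear attention."""

    def forward_gdn(self, layer, forward_batch, mixed_qkv, a, b):
        # Default: refuse hybrid architectures rather than silently
        # producing wrong attention output.
        raise NotImplementedError(
            f"{type(self).__name__} does not support GDN linear attention"
        )
```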
- abstractmethod get_cuda_graph_seq_len_fill_value()¶
Fill value used to pad seq_lens tensors for CUDA-graph capture.
Most backends use 1 (not 0) to avoid division-by-zero in attention kernels.
- Return type:
int
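A CUDA graph is captured at a fixed batch size, so smaller batches must pad seq_lens up to that size; padding with 1 rather than 0 keeps every padded slot nonzero, so kernels that divide by sequence length never hit a zero on the dummy entries. A minimal torch-free sketch (the helper name is an assumption):

```python
def pad_seq_lens(seq_lens, max_bs, fill_value=1):
    """Pad a seq_lens list to the captured CUDA-graph batch size.

    fill_value=1 mirrors get_cuda_graph_seq_len_fill_value: padded slots
    stay nonzero so attention kernels cannot divide by zero on them.
    """
    return seq_lens + [fill_value] * (max_bs - len(seq_lens))
```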
- abstractmethod init_cuda_graph_state(max_bs, max_num_tokens)¶
Allocate shared CUDA-graph state (buffers reused across captures).
- Parameters:
max_bs (int)
max_num_tokens (int)
- Return type:
None
- abstractmethod init_forward_metadata_capture_cuda_graph(bs, num_tokens, req_pool_indices, seq_lens, forward_mode)¶
Set up per-batch metadata for capturing a CUDA graph.
- Parameters:
bs (int)
num_tokens (int)
req_pool_indices (torch.Tensor)
seq_lens (torch.Tensor)
forward_mode (pymllm.engine.forward_batch.ForwardMode)
- Return type:
None
- abstractmethod init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens, seq_lens_sum, forward_mode, seq_lens_cpu)¶
Update metadata when replaying a captured CUDA graph.
- Parameters:
bs (int)
req_pool_indices (torch.Tensor)
seq_lens (torch.Tensor)
seq_lens_sum (int)
forward_mode (pymllm.engine.forward_batch.ForwardMode)
seq_lens_cpu (Optional[torch.Tensor])
- Return type:
None
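The three CUDA-graph methods form a lifecycle: allocate shared buffers once, write them while capturing, then only overwrite them in place on every replay (a replay must never allocate, since the graph was recorded against fixed buffer addresses). A hedged, torch-free sketch of that flow; all buffer names and internals here are assumptions, not pymllm's actual state layout:

```python
class GraphStateBackend:
    """Sketch of the CUDA-graph state lifecycle implied by the API above."""

    def init_cuda_graph_state(self, max_bs, max_num_tokens):
        # Shared buffers, allocated once and reused across all captures.
        self.req_pool_buf = [0] * max_bs
        self.seq_lens_buf = [0] * max_bs

    def init_forward_metadata_capture_cuda_graph(self, bs, num_tokens,
                                                 req_pool_indices, seq_lens,
                                                 forward_mode):
        # Capture records kernels against the shared buffers.
        self.req_pool_buf[:bs] = req_pool_indices
        self.seq_lens_buf[:bs] = seq_lens

    def init_forward_metadata_replay_cuda_graph(self, bs, req_pool_indices,
                                                seq_lens, seq_lens_sum,
                                                forward_mode,
                                                seq_lens_cpu=None):
        # Replay: overwrite the captured buffers in place, no new allocation.
        self.req_pool_buf[:bs] = req_pool_indices
        self.seq_lens_buf[:bs] = seq_lens
```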