pymllm.layers.attention.radix_attention¶
RadixAttention – the attention layer used by pymllm models.
This module is kept small intentionally: all heavy computation is delegated
to the pluggable AttentionBackend that is attached to the ForwardBatch.
Classes¶
| AttentionType | Attention variant used by a RadixAttention layer. |
| RadixAttention | Attention layer that delegates computation to a pluggable backend. |
Module Contents¶
- class pymllm.layers.attention.radix_attention.AttentionType(*args, **kwds)¶
Bases: enum.Enum
Attention variant used by a RadixAttention layer.
Uses string values so that torch.compile can treat them as constants.
- DECODER = 'decoder'¶
- DECODER_BIDIRECTIONAL = 'decoder_bidirectional'¶
- ENCODER_ONLY = 'encoder_only'¶
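A minimal sketch of the enum as documented above (whether the real class also mixes in `str` is an assumption the docs do not confirm):

```python
from enum import Enum

class AttentionType(Enum):
    DECODER = "decoder"
    DECODER_BIDIRECTIONAL = "decoder_bidirectional"
    ENCODER_ONLY = "encoder_only"

# The plain-string .value is what lets torch.compile treat the
# variant as a constant rather than an opaque Python object.
print(AttentionType.DECODER.value)  # decoder
```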
- class pymllm.layers.attention.radix_attention.RadixAttention(num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=0.0, v_head_dim=-1, sliding_window_size=-1, is_cross_attention=False, attn_type=AttentionType.DECODER)¶
Bases: torch.nn.Module
Attention layer that delegates computation to a pluggable backend.
Each transformer attention layer in a pymllm model creates exactly one RadixAttention with a unique layer_id. During the forward pass the layer looks up the correct KV buffer via layer_id and calls the backend attached to the current ForwardBatch.
- Parameters:
num_heads (int) – Number of query attention heads (after any tensor-parallelism sharding; pass the full count if not using TP).
head_dim (int) – Per-head dimension for query and key projections.
scaling (float) – Softmax pre-scale, typically 1 / sqrt(head_dim).
num_kv_heads (int) – Number of key / value heads (supports GQA / MQA).
layer_id (int) – Zero-based index of this layer within the model. Used to index into KVPool.k_buffer / v_buffer.
logit_cap (float) – If > 0, attention logits are soft-capped to this value via a tanh gate (used by Gemma2 / Gemma3 style models). Set to 0.0 to disable.
v_head_dim (int) – Per-head dimension of the value projection. Defaults to head_dim (i.e. standard square QKV).
sliding_window_size (int) – Sliding-window attention span. -1 means full context (no window).
is_cross_attention (bool) – True for cross-attention layers in encoder-decoder models.
attn_type (AttentionType) – One of AttentionType.
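The tanh soft-cap behind logit_cap can be sketched elementwise; this is an illustrative scalar version, not pymllm's implementation (the backend applies it to whole logit tensors):

```python
import math

def soft_cap(logit: float, logit_cap: float) -> float:
    if logit_cap <= 0.0:
        # 0.0 disables capping: logits pass through unchanged.
        return logit
    # Smoothly squashes the logit into the open interval
    # (-logit_cap, logit_cap) while staying monotonic.
    return logit_cap * math.tanh(logit / logit_cap)
```

Unlike hard clipping, the tanh gate keeps gradients nonzero for large logits, which is why Gemma2/Gemma3 style models use it.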
- tp_q_head_num: int¶
- tp_k_head_num: int¶
- tp_v_head_num: int¶
- head_dim: int¶
- qk_head_dim: int¶
- v_head_dim: int¶
- scaling: float¶
- layer_id: int¶
- logit_cap: float = 0.0¶
- sliding_window_size: int = -1¶
- is_cross_attention: bool = False¶
- attn_type: AttentionType¶
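The tp_q_head_num / tp_k_head_num split implies the usual GQA/MQA head mapping: contiguous groups of query heads share one KV head. A hypothetical helper (not part of pymllm) makes the grouping explicit:

```python
def kv_head_for(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    # GQA/MQA mapping: each KV head serves a contiguous block of
    # num_q_heads // num_kv_heads query heads.
    assert num_q_heads % num_kv_heads == 0
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size
```

With num_kv_heads == 1 this degenerates to MQA (every query head reads the same KV head); with num_kv_heads == num_q_heads it is standard multi-head attention.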
- forward(q, k, v, forward_batch, save_kv_cache=True, **kwargs)¶
Run attention for one batch.
- Parameters:
q (torch.Tensor) – Query tensor, shape [num_tokens, tp_q_head_num * head_dim] (or already reshaped to [num_tokens, tp_q_head_num, head_dim]).
k (Optional[torch.Tensor]) – Key tensor, same leading dimension as q, shape [num_tokens, tp_k_head_num * qk_head_dim]. Pass None for cross-layer KV sharing (v must also be None in this case).
v (Optional[torch.Tensor]) – Value tensor, shape [num_tokens, tp_v_head_num * v_head_dim].
forward_batch (pymllm.engine.forward_batch.ForwardBatch) – Batch metadata and references to memory pools / backend.
save_kv_cache (bool) – When False, skip writing K/V into the pool (useful for draft models in speculative decoding).
**kwargs – Passed through to the backend (e.g. q_rope, k_rope).
- Return type:
torch.Tensor
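The two accepted q layouts differ only by a reshape from a flat per-token row of length tp_q_head_num * head_dim to per-head vectors. A hypothetical pure-Python helper (the real code would use tensor.view) illustrates the correspondence:

```python
def split_heads(flat_row, num_heads, head_dim):
    # Split one token's flat projection [num_heads * head_dim]
    # into per-head vectors, i.e. the [num_heads, head_dim] layout
    # that forward() also accepts for q.
    assert len(flat_row) == num_heads * head_dim
    return [flat_row[h * head_dim:(h + 1) * head_dim]
            for h in range(num_heads)]
```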
- extra_repr()¶
- Return type:
str