pymllm.layers.attention.radix_attention

RadixAttention – the attention layer used by pymllm models.

This module is intentionally kept small: all heavy computation is delegated to the pluggable AttentionBackend attached to the ForwardBatch.

Classes

AttentionType

Attention variant used by a RadixAttention layer.

RadixAttention

Attention layer that delegates computation to a pluggable backend.

Module Contents

class pymllm.layers.attention.radix_attention.AttentionType(*args, **kwds)

Bases: enum.Enum

Attention variant used by a RadixAttention layer.

Uses string values so that torch.compile can treat them as constants.

DECODER = 'decoder'
DECODER_BIDIRECTIONAL = 'decoder_bidirectional'
ENCODER_ONLY = 'encoder_only'
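Because the members carry plain string values, branching on an AttentionType reduces to a comparison against a string constant, which torch.compile can fold away. A standalone sketch mirroring the members above (not the pymllm source itself):

```python
from enum import Enum

# Standalone mirror of the members listed above.
class AttentionType(Enum):
    DECODER = "decoder"
    DECODER_BIDIRECTIONAL = "decoder_bidirectional"
    ENCODER_ONLY = "encoder_only"

# The string payload is reachable via .value, so a dispatch like this
# compares against a constant rather than relying on object identity.
attn_type = AttentionType.DECODER
is_causal = attn_type.value == "decoder"
```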
class pymllm.layers.attention.radix_attention.RadixAttention(num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=0.0, v_head_dim=-1, sliding_window_size=-1, is_cross_attention=False, attn_type=AttentionType.DECODER)

Bases: torch.nn.Module

Attention layer that delegates computation to a pluggable backend.

Each transformer attention layer in a pymllm model creates exactly one RadixAttention with a unique layer_id. During the forward pass the layer looks up the correct KV buffer via layer_id and calls the backend attached to the current ForwardBatch.

Parameters:
  • num_heads (int) – Number of query attention heads (after any tensor-parallelism sharding; pass the full count if not using TP).

  • head_dim (int) – Per-head dimension for query and key projections.

  • scaling (float) – Softmax pre-scale, typically 1 / sqrt(head_dim).

  • num_kv_heads (int) – Number of key / value heads (supports GQA / MQA).

  • layer_id (int) – Zero-based index of this layer within the model. Used to index into KVPool.k_buffer / v_buffer.

  • logit_cap (float) – If > 0, attention logits are soft-capped to this value via a tanh gate (used by Gemma2 / Gemma3 style models). Set to 0.0 to disable.

  • v_head_dim (int) – Per-head dimension of the value projection. -1 (the default) means use head_dim (i.e. a standard square QKV).

  • sliding_window_size (int) – Sliding-window attention span. -1 means full context (no window).

  • is_cross_attention (bool) – True for cross-attention layers in encoder-decoder models.

  • attn_type (AttentionType) – Attention variant of this layer; see AttentionType.
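As a concrete illustration of these parameters, a grouped-query (GQA) layer might be configured as below. The head counts are hypothetical examples, not values taken from pymllm; only the scaling formula follows the convention stated above:

```python
import math

# Illustrative GQA configuration (hypothetical numbers):
num_heads = 32      # query heads
num_kv_heads = 8    # key/value heads; each KV head serves num_heads // num_kv_heads queries
head_dim = 128

# GQA requires the query heads to split evenly across the KV heads.
assert num_heads % num_kv_heads == 0

# The typical softmax pre-scale mentioned for `scaling`.
scaling = 1.0 / math.sqrt(head_dim)

# Construction would then look like (sketch only; requires pymllm):
# attn = RadixAttention(num_heads, head_dim, scaling, num_kv_heads, layer_id=0)
```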

tp_q_head_num: int
tp_k_head_num: int
tp_v_head_num: int
head_dim: int
qk_head_dim: int
v_head_dim: int
scaling: float
layer_id: int
logit_cap: float = 0.0
sliding_window_size: int = -1
is_cross_attention: bool = False
attn_type: AttentionType
forward(q, k, v, forward_batch, save_kv_cache=True, **kwargs)

Run attention for one batch.

Parameters:
  • q (torch.Tensor) – Query tensor, shape [num_tokens, tp_q_head_num * head_dim] (or already reshaped to [num_tokens, tp_q_head_num, head_dim]).

  • k (Optional[torch.Tensor]) – Key tensor, same leading dimension as q, shape [num_tokens, tp_k_head_num * qk_head_dim]. Pass None for cross-layer KV sharing (v must also be None in this case).

  • v (Optional[torch.Tensor]) – Value tensor, shape [num_tokens, tp_v_head_num * v_head_dim].

  • forward_batch (pymllm.engine.forward_batch.ForwardBatch) – Batch metadata and references to memory pools / backend.

  • save_kv_cache (bool) – When False, skip writing K/V into the pool (useful for draft models in speculative decoding).

  • **kwargs – Passed through to the backend (e.g. q_rope, k_rope).

Return type:

torch.Tensor
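The logit_cap behaviour described above can be sketched in isolation. This assumes the common Gemma2-style convention cap * tanh(logit / cap); check the attached backend for the exact variant it applies before softmax:

```python
import math

def soft_cap(logit: float, cap: float) -> float:
    """Tanh soft cap: smoothly bounds a logit to (-cap, cap).

    Assumed formula (a common convention, not confirmed against the
    backend): cap * tanh(logit / cap). cap <= 0 disables capping,
    matching the documented logit_cap=0.0 default.
    """
    if cap <= 0.0:
        return logit
    return cap * math.tanh(logit / cap)

# Small logits pass through almost unchanged; large ones saturate near +/- cap.
small = soft_cap(1.0, 50.0)    # ~1.0
large = soft_cap(500.0, 50.0)  # just under 50.0
```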

extra_repr()

Extra information appended to this module's repr() (standard torch.nn.Module hook).
Return type:

str