pymllm.layers.attention.radix_attention
=======================================

.. py:module:: pymllm.layers.attention.radix_attention

.. autoapi-nested-parse::

   RadixAttention -- the attention layer used by pymllm models.

   This module is kept small intentionally: all heavy computation is
   delegated to the pluggable ``AttentionBackend`` that is attached to
   the ``ForwardBatch``.

Classes
-------

.. autoapisummary::

   pymllm.layers.attention.radix_attention.AttentionType
   pymllm.layers.attention.radix_attention.RadixAttention

Module Contents
---------------

.. py:class:: AttentionType(*args, **kwds)

   Bases: :py:obj:`enum.Enum`

   Attention variant used by a :class:`RadixAttention` layer.

   Uses string values so that ``torch.compile`` can treat them as
   constants.

   .. py:attribute:: DECODER
      :value: 'decoder'

   .. py:attribute:: DECODER_BIDIRECTIONAL
      :value: 'decoder_bidirectional'

   .. py:attribute:: ENCODER_ONLY
      :value: 'encoder_only'

.. py:class:: RadixAttention(num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap = 0.0, v_head_dim = -1, sliding_window_size = -1, is_cross_attention = False, attn_type = AttentionType.DECODER)

   Bases: :py:obj:`torch.nn.Module`

   Attention layer that delegates computation to a pluggable backend.

   Each transformer attention layer in a pymllm model creates exactly one
   ``RadixAttention`` with a unique ``layer_id``. During the forward pass
   the layer looks up the correct KV buffer via ``layer_id`` and calls
   the backend attached to the current
   :class:`~pymllm.engine.forward_batch.ForwardBatch`.

   :param num_heads: Number of query attention heads (after any
       tensor-parallelism sharding; pass the full count if not using TP).
   :param head_dim: Per-head dimension for query and key projections.
   :param scaling: Softmax pre-scale, typically ``1 / sqrt(head_dim)``.
   :param num_kv_heads: Number of key / value heads (supports GQA / MQA).
   :param layer_id: Zero-based index of this layer within the model.
       Used to index into ``KVPool.k_buffer`` / ``v_buffer``.
   :param logit_cap: If > 0, attention logits are soft-capped to this
       value via a ``tanh`` gate (used by Gemma2 / Gemma3 style models).
       Set to ``0.0`` to disable.
   :param v_head_dim: Per-head dimension of the value projection.
       Defaults to ``head_dim`` (i.e. standard square QKV).
   :param sliding_window_size: Sliding-window attention span. ``-1``
       means full context (no window).
   :param is_cross_attention: ``True`` for cross-attention layers in
       encoder-decoder models.
   :param attn_type: One of :class:`AttentionType`.

   .. py:attribute:: tp_q_head_num
      :type: int

   .. py:attribute:: tp_k_head_num
      :type: int

   .. py:attribute:: tp_v_head_num
      :type: int

   .. py:attribute:: head_dim
      :type: int

   .. py:attribute:: qk_head_dim
      :type: int

   .. py:attribute:: v_head_dim
      :type: int

   .. py:attribute:: scaling
      :type: float

   .. py:attribute:: layer_id
      :type: int

   .. py:attribute:: logit_cap
      :type: float
      :value: 0.0

   .. py:attribute:: sliding_window_size
      :type: int
      :value: -1

   .. py:attribute:: is_cross_attention
      :type: bool
      :value: False

   .. py:attribute:: attn_type
      :type: AttentionType

   .. py:method:: forward(q, k, v, forward_batch, save_kv_cache = True, **kwargs)

      Run attention for one batch.

      :param q: Query tensor, shape
          ``[num_tokens, tp_q_head_num * head_dim]`` (or already reshaped
          to ``[num_tokens, tp_q_head_num, head_dim]``).
      :param k: Key tensor, same leading dimension as ``q``, shape
          ``[num_tokens, tp_k_head_num * qk_head_dim]``. Pass ``None``
          for cross-layer KV sharing (``v`` must also be ``None`` in this
          case).
      :param v: Value tensor, shape
          ``[num_tokens, tp_v_head_num * v_head_dim]``.
      :param forward_batch: Batch metadata and references to memory
          pools / backend.
      :param save_kv_cache: When ``False``, skip writing K/V into the
          pool (useful for draft models in speculative decoding).
      :param \*\*kwargs: Passed through to the backend (e.g. ``q_rope``,
          ``k_rope``).

   .. py:method:: extra_repr()
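The ``logit_cap`` parameter above describes a ``tanh`` soft-cap on attention logits. The snippet below is a minimal, standalone numeric sketch of that gate in the common ``cap * tanh(logit / cap)`` form, not pymllm code: the ``soft_cap`` helper is hypothetical, and the actual backend applies the cap on whole tensors inside its attention kernel.

```python
import math

def soft_cap(logit: float, logit_cap: float) -> float:
    # Hypothetical scalar illustration of a tanh logit soft-cap:
    # large logits are squashed smoothly into (-logit_cap, logit_cap)
    # instead of being clipped hard, keeping the gate differentiable.
    if logit_cap <= 0.0:  # 0.0 disables capping, matching the docstring
        return logit
    return logit_cap * math.tanh(logit / logit_cap)

# Small logits pass through almost unchanged; huge ones saturate near the cap.
print(round(soft_cap(1.0, 50.0), 3))    # ~1.0 (nearly identity)
print(round(soft_cap(500.0, 50.0), 3))  # 50.0 (saturated at the cap)
print(soft_cap(500.0, 0.0))             # 500.0 (capping disabled)
```

Because ``tanh(x) ≈ x`` for small ``x``, typical logits are barely affected; the cap only tames outliers, which is why models such as Gemma2 use it for training stability.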