pymllm.layers.attention.radix_linear_attention
==============================================

.. py:module:: pymllm.layers.attention.radix_linear_attention

.. autoapi-nested-parse::

   RadixLinearAttention -- GDN linear-attention layer for hybrid models.

   Analogous to :class:`RadixAttention` but for GDN (Gated Delta Net) layers.
   Stores per-layer GDN parameters and delegates computation to the
   :meth:`AttentionBackend.forward_gdn` method on the current
   :class:`~pymllm.engine.forward_batch.ForwardBatch`.


Classes
-------

.. autoapisummary::

   pymllm.layers.attention.radix_linear_attention.RadixLinearAttention


Module Contents
---------------

.. py:class:: RadixLinearAttention(layer_id, gdn_layer_idx, num_k_heads, num_v_heads, head_k_dim, head_v_dim, conv_weight, A_log, dt_bias)

   Bases: :py:obj:`torch.nn.Module`

   GDN linear-attention layer that delegates to the attention backend.

   Each GDN layer in a pymllm model creates one ``RadixLinearAttention``
   with a unique ``layer_id`` and ``gdn_layer_idx``. During forward, it
   calls ``forward_batch.attn_backend.forward_gdn(...)``, which routes to
   the appropriate GDN backend implementation.

   :param layer_id: Global zero-based layer index within the model.
   :type layer_id: int
   :param gdn_layer_idx: Sequential zero-based index among GDN layers only
       (not global). Used to index into
       :class:`~pymllm.mem_cache.memory_pool.GDNPool`.
   :type gdn_layer_idx: int
   :param num_k_heads: Number of key heads.
   :type num_k_heads: int
   :param num_v_heads: Number of value heads.
   :type num_v_heads: int
   :param head_k_dim: Per-head key dimension.
   :type head_k_dim: int
   :param head_v_dim: Per-head value dimension.
   :type head_v_dim: int
   :param conv_weight: Reference to the GDNConv1d weight parameter.
   :type conv_weight: nn.Parameter
   :param A_log: Log-space decay parameter.
   :type A_log: nn.Parameter
   :param dt_bias: Bias for the decay gate.
   :type dt_bias: nn.Parameter

   .. py:attribute:: layer_id

   .. py:attribute:: gdn_layer_idx

   .. py:attribute:: num_k_heads

   .. py:attribute:: num_v_heads

   .. py:attribute:: head_k_dim

   .. py:attribute:: head_v_dim

   .. py:attribute:: conv_weight

   .. py:attribute:: A_log

   .. py:attribute:: dt_bias

   .. py:method:: forward(forward_batch, mixed_qkv, a, b)

      Delegate GDN computation to the attention backend.

      :param forward_batch: Batch metadata with ``attn_backend`` attached.
      :param mixed_qkv: Concatenated Q/K/V projection output before conv1d.
      :param a: Decay gate input, shape ``[num_tokens, num_v_heads]``.
      :param b: Update gate input, shape ``[num_tokens, num_v_heads]``.
      :returns: GDN attention output, shape
          ``[num_tokens, num_v_heads * head_v_dim]``.
      :rtype: torch.Tensor

   .. py:method:: extra_repr()
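The delegation pattern described above can be sketched as follows. This is a
minimal, dependency-free illustration, not the real pymllm implementation: the
stub classes, the exact ``forward_gdn`` call signature, and the omission of the
``torch.nn.Module`` machinery and the ``conv_weight``/``A_log``/``dt_bias``
parameters are all assumptions made for brevity.

.. code-block:: python

   # Sketch only: StubBackend and StubForwardBatch are hypothetical
   # stand-ins for pymllm's AttentionBackend and ForwardBatch.

   class StubBackend:
       def forward_gdn(self, layer, forward_batch, mixed_qkv, a, b):
           # Produce a placeholder result of the documented output
           # shape [num_tokens, num_v_heads * head_v_dim].
           num_tokens = len(a)
           return [[0.0] * (layer.num_v_heads * layer.head_v_dim)
                   for _ in range(num_tokens)]

   class StubForwardBatch:
       def __init__(self):
           # The layer reads the backend off the batch at forward time.
           self.attn_backend = StubBackend()

   class RadixLinearAttentionSketch:
       def __init__(self, layer_id, gdn_layer_idx,
                    num_k_heads, num_v_heads, head_k_dim, head_v_dim):
           # The layer only stores per-layer metadata; it performs no
           # computation of its own.
           self.layer_id = layer_id
           self.gdn_layer_idx = gdn_layer_idx
           self.num_k_heads = num_k_heads
           self.num_v_heads = num_v_heads
           self.head_k_dim = head_k_dim
           self.head_v_dim = head_v_dim

       def forward(self, forward_batch, mixed_qkv, a, b):
           # Delegate to whichever backend the current batch carries.
           return forward_batch.attn_backend.forward_gdn(
               self, forward_batch, mixed_qkv, a, b)

   # Usage: 2 value heads with head_v_dim=4 gives an output width of 8.
   layer = RadixLinearAttentionSketch(
       layer_id=0, gdn_layer_idx=0,
       num_k_heads=2, num_v_heads=2, head_k_dim=4, head_v_dim=4)
   batch = StubForwardBatch()
   out = layer.forward(batch, mixed_qkv=None,
                       a=[[0.0, 0.0]] * 3, b=[[0.0, 0.0]] * 3)
   print(len(out), len(out[0]))  # 3 tokens, width num_v_heads * head_v_dim = 8

Keeping the layer stateless with respect to the algorithm lets one model
definition run against different backend implementations, selected per batch.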