pymllm.layers.attention.radix_linear_attention

RadixLinearAttention – GDN linear-attention layer for hybrid models.

Analogous to RadixAttention, but for GDN (Gated Delta Net) layers. Stores per-layer GDN parameters and delegates computation to forward_gdn() on the attention backend attached to the current ForwardBatch.

Classes

RadixLinearAttention

GDN linear-attention layer that delegates to the attention backend.

Module Contents

class pymllm.layers.attention.radix_linear_attention.RadixLinearAttention(layer_id, gdn_layer_idx, num_k_heads, num_v_heads, head_k_dim, head_v_dim, conv_weight, A_log, dt_bias)

Bases: torch.nn.Module

GDN linear-attention layer that delegates to the attention backend.

Each GDN layer in a pymllm model creates one RadixLinearAttention with a unique layer_id and gdn_layer_idx. During forward, it calls forward_batch.attn_backend.forward_gdn(...), which routes to the appropriate GDN backend implementation.

Parameters:
  • layer_id (int) – Global zero-based layer index within the model.

  • gdn_layer_idx (int) – Sequential zero-based index among GDN layers only (not global). Used to index into GDNPool.

  • num_k_heads (int) – Number of key heads.

  • num_v_heads (int) – Number of value heads.

  • head_k_dim (int) – Per-head key dimension.

  • head_v_dim (int) – Per-head value dimension.

  • conv_weight (nn.Parameter) – Reference to the GDNConv1d weight parameter.

  • A_log (nn.Parameter) – Log-space decay parameter.

  • dt_bias (nn.Parameter) – Bias for the decay gate.

layer_id
gdn_layer_idx
num_k_heads
num_v_heads
head_k_dim
head_v_dim
conv_weight
A_log
dt_bias
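The delegation pattern described above can be sketched without torch. The stub classes below (GDNBackendStub, ForwardBatchStub, RadixLinearAttentionSketch) are illustrative assumptions standing in for the real pymllm classes; the point is only that the layer stores its parameters and forwards all computation to the backend attached to the batch.

```python
class GDNBackendStub:
    """Stands in for the attention backend's forward_gdn() (hypothetical stub)."""

    def forward_gdn(self, layer, mixed_qkv, a, b):
        # A real backend would run the causal conv1d and gated delta rule here;
        # this stub only records which layer requested the computation.
        return f"gdn-output(layer={layer.layer_id}, gdn_idx={layer.gdn_layer_idx})"


class ForwardBatchStub:
    """Stands in for ForwardBatch: carries the backend reference (hypothetical stub)."""

    def __init__(self, attn_backend):
        self.attn_backend = attn_backend


class RadixLinearAttentionSketch:
    """Simplified stand-in for RadixLinearAttention (not an nn.Module)."""

    def __init__(self, layer_id, gdn_layer_idx, num_k_heads, num_v_heads,
                 head_k_dim, head_v_dim, conv_weight, A_log, dt_bias):
        # The layer only stores configuration and parameter references.
        self.layer_id = layer_id
        self.gdn_layer_idx = gdn_layer_idx
        self.num_k_heads = num_k_heads
        self.num_v_heads = num_v_heads
        self.head_k_dim = head_k_dim
        self.head_v_dim = head_v_dim
        self.conv_weight = conv_weight
        self.A_log = A_log
        self.dt_bias = dt_bias

    def forward(self, forward_batch, mixed_qkv, a, b):
        # No compute logic of its own: everything is delegated to the backend.
        return forward_batch.attn_backend.forward_gdn(self, mixed_qkv, a, b)


layer = RadixLinearAttentionSketch(
    layer_id=3, gdn_layer_idx=1,
    num_k_heads=4, num_v_heads=8, head_k_dim=128, head_v_dim=128,
    conv_weight=None, A_log=None, dt_bias=None,
)
batch = ForwardBatchStub(GDNBackendStub())
print(layer.forward(batch, mixed_qkv=None, a=None, b=None))
# prints gdn-output(layer=3, gdn_idx=1)
```

Note the two indices: layer_id=3 is the global position in the model, while gdn_layer_idx=1 counts GDN layers only, so the backend can index compact per-GDN state (e.g. GDNPool) without gaps from interleaved full-attention layers.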
forward(forward_batch, mixed_qkv, a, b)

Delegate GDN computation to the attention backend.

Parameters:
  • forward_batch (pymllm.engine.forward_batch.ForwardBatch) – Batch metadata with attn_backend attached.

  • mixed_qkv (torch.Tensor) – Concatenated Q/K/V projection output before conv1d.

  • a (torch.Tensor) – Decay gate input, shape [num_tokens, num_v_heads].

  • b (torch.Tensor) – Update gate input, shape [num_tokens, num_v_heads].

Returns:

GDN attention output, shape [num_tokens, num_v_heads * head_v_dim].

Return type:

torch.Tensor
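The shape bookkeeping for forward() can be checked with plain arithmetic. The packing of mixed_qkv as [Q | K | V], with Q and K each of width num_k_heads * head_k_dim and V of width num_v_heads * head_v_dim, is an assumption for illustration (the source only says "concatenated Q/K/V projection output"); the a, b, and output shapes are as documented above.

```python
# Example head configuration (illustrative values, not from the source).
num_tokens = 16
num_k_heads, num_v_heads = 4, 8
head_k_dim, head_v_dim = 128, 128

# ASSUMED layout: mixed_qkv = [Q | K | V] along the feature dimension.
qk_width = num_k_heads * head_k_dim            # width of Q (and of K)
v_width = num_v_heads * head_v_dim             # width of V
mixed_qkv_shape = (num_tokens, 2 * qk_width + v_width)

# Documented shapes from the forward() signature above.
a_shape = b_shape = (num_tokens, num_v_heads)            # decay / update gates
out_shape = (num_tokens, num_v_heads * head_v_dim)       # returned tensor

print(mixed_qkv_shape, a_shape, out_shape)
# prints (16, 2048) (16, 8) (16, 1024)
```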

extra_repr()

Return extra information shown in the module's printed representation, per the standard torch.nn.Module convention.

Return type:

str