pymllm.layers.attention.radix_linear_attention¶
RadixLinearAttention – GDN linear-attention layer for hybrid models.
Analogous to RadixAttention but for GDN (Gated Delta Net) layers.
Stores per-layer GDN parameters and delegates computation to the
AttentionBackend.forward_gdn() method on the current
ForwardBatch.
Classes¶
RadixLinearAttention – GDN linear-attention layer that delegates to the attention backend.
Module Contents¶
- class pymllm.layers.attention.radix_linear_attention.RadixLinearAttention(layer_id, gdn_layer_idx, num_k_heads, num_v_heads, head_k_dim, head_v_dim, conv_weight, A_log, dt_bias)¶
Bases: torch.nn.Module

GDN linear-attention layer that delegates to the attention backend.

Each GDN layer in a pymllm model creates one RadixLinearAttention with a unique layer_id and gdn_layer_idx. During forward, it calls forward_batch.attn_backend.forward_gdn(...), which routes to the appropriate GDN backend implementation.

- Parameters:
layer_id (int) – Global zero-based layer index within the model.
gdn_layer_idx (int) – Sequential zero-based index among GDN layers only (not global); used to index into GDNPool.
num_k_heads (int) – Number of key heads.
num_v_heads (int) – Number of value heads.
head_k_dim (int) – Per-head key dimension.
head_v_dim (int) – Per-head value dimension.
conv_weight (nn.Parameter) – Reference to the GDNConv1d weight parameter.
A_log (nn.Parameter) – Log-space decay parameter.
dt_bias (nn.Parameter) – Bias for the decay gate.
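The delegation pattern described above can be sketched in plain Python. This is an illustrative stand-in only: the real layer subclasses torch.nn.Module, holds tensor parameters (conv_weight, A_log, dt_bias), and routes to a real AttentionBackend; the stub classes and their return value here are assumptions for demonstration.

```python
class FakeGDNBackend:
    """Stand-in for an AttentionBackend exposing forward_gdn (illustrative)."""

    def forward_gdn(self, layer, mixed_qkv, a, b):
        # A real backend would run the gated-delta-net recurrence here;
        # this stub just records which layer requested the computation.
        return f"gdn output for layer {layer.layer_id}"


class ForwardBatchSketch:
    """Stand-in for ForwardBatch carrying an attn_backend (illustrative)."""

    def __init__(self, attn_backend):
        self.attn_backend = attn_backend


class RadixLinearAttentionSketch:
    """Simplified sketch: store per-layer GDN metadata, delegate compute."""

    def __init__(self, layer_id, gdn_layer_idx, num_k_heads, num_v_heads,
                 head_k_dim, head_v_dim):
        self.layer_id = layer_id            # global zero-based layer index
        self.gdn_layer_idx = gdn_layer_idx  # index among GDN layers only
        self.num_k_heads = num_k_heads
        self.num_v_heads = num_v_heads
        self.head_k_dim = head_k_dim
        self.head_v_dim = head_v_dim

    def forward(self, forward_batch, mixed_qkv, a, b):
        # The layer itself does no attention math; it routes everything
        # to the backend attached to the current batch.
        return forward_batch.attn_backend.forward_gdn(self, mixed_qkv, a, b)
```

The key design point carried over from the real class: the layer is a thin parameter holder, so backends can be swapped per batch without touching model code.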
- layer_id¶
- gdn_layer_idx¶
- num_k_heads¶
- num_v_heads¶
- head_k_dim¶
- head_v_dim¶
- conv_weight¶
- A_log¶
- dt_bias¶
- forward(forward_batch, mixed_qkv, a, b)¶
Delegate GDN computation to the attention backend.
- Parameters:
forward_batch (pymllm.engine.forward_batch.ForwardBatch) – Batch metadata with attn_backend attached.
mixed_qkv (torch.Tensor) – Concatenated Q/K/V projection output before conv1d.
a (torch.Tensor) – Decay gate input, shape [num_tokens, num_v_heads].
b (torch.Tensor) – Update gate input, shape [num_tokens, num_v_heads].
- Returns:
GDN attention output, shape [num_tokens, num_v_heads * head_v_dim].
- Return type:
torch.Tensor
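The gate inputs and the output relate by simple shape bookkeeping: both gates are per-token, per-value-head scalars, while the output flattens the per-head value dimension. A quick check with illustrative sizes (the numbers below are examples, not defaults from pymllm):

```python
# Illustrative shape bookkeeping for forward(); all sizes are examples.
num_tokens = 16
num_v_heads = 8
head_v_dim = 64

a_shape = (num_tokens, num_v_heads)                  # decay gate input
b_shape = (num_tokens, num_v_heads)                  # update gate input
out_shape = (num_tokens, num_v_heads * head_v_dim)   # flattened head outputs
```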
- extra_repr()¶
- Return type:
str
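By torch.nn.Module convention, extra_repr() returns a short comma-separated summary of a layer's hyperparameters for use in printed module trees. The exact fields RadixLinearAttention reports are not shown in this page, so the sketch below is an assumption modeled on that convention:

```python
def extra_repr_sketch(num_k_heads, num_v_heads, head_k_dim, head_v_dim):
    # Follows the torch.nn.Module.extra_repr convention of returning a
    # concise hyperparameter summary. Which fields pymllm actually
    # includes is an assumption, not confirmed by this page.
    return (f"num_k_heads={num_k_heads}, num_v_heads={num_v_heads}, "
            f"head_k_dim={head_k_dim}, head_v_dim={head_v_dim}")
```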