pymllm.layers.gated_delta_net¶
Gated Delta Network (GDN) linear attention for Qwen3.5.
This implements the linear attention mechanism used in Qwen3.5’s hybrid architecture. GDN alternates with standard full-attention layers.
- Core formulation (decode, per-head):

  g_t = -exp(A_log) * softplus(a_t + dt_bias)
  beta_t = sigmoid(b_t)
  state_t = exp(g_t) * state_{t-1} + beta_t * (k_t outer v_t)
  output_t = q_t @ state_t
State is externalized into a GDNPool and computation is delegated to the attention backend via RadixLinearAttention.
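The recurrence above can be sketched directly in PyTorch. This is an illustrative single-head decode step following the formulas, not the library's fused backend kernel; the function name and argument layout are assumptions for the sketch.

```python
import torch
import torch.nn.functional as F

def gdn_decode_step(state, q_t, k_t, v_t, a_t, b_t, A_log, dt_bias):
    """One GDN decode step for a single head (illustrative sketch).

    state: (d_k, d_v) recurrent state; q_t, k_t: (d_k,); v_t: (d_v,);
    a_t, b_t, A_log, dt_bias: scalar tensors.
    """
    # Decay gate: always <= 0, so exp(g_t) in (0, 1] decays the state.
    g_t = -torch.exp(A_log) * F.softplus(a_t + dt_bias)
    # Write strength in (0, 1).
    beta_t = torch.sigmoid(b_t)
    # Decay old state, then write the rank-1 update k_t outer v_t.
    state = torch.exp(g_t) * state + beta_t * torch.outer(k_t, v_t)
    # Read out with the query: (d_v,).
    out_t = q_t @ state
    return out_t, state
```

In the actual layer this loop never materializes in Python: the state lives in the GDNPool and the backend processes the whole batch in one kernel call.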
Attributes¶

logger

Classes¶
GDNConv1d – Causal 1D convolution weight holder for GDN sequence mixing.

GatedDeltaNet – Gated Delta Network linear attention layer for Qwen3.5.
Module Contents¶
- pymllm.layers.gated_delta_net.logger¶
- class pymllm.layers.gated_delta_net.GDNConv1d(channels, kernel_size)¶
Bases: torch.nn.Module

Causal 1D convolution weight holder for GDN sequence mixing.
The actual convolution computation is performed by the GDN backend using pooled conv states. This module only holds the learnable weight.
- Parameters:
channels (int)
kernel_size (int)
- channels¶
- kernel_size¶
- weight¶
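Since this module only holds the weight, the backend's per-step causal convolution with a pooled state can be pictured as a depthwise dot product over a rolling window. The sketch below is an assumption about the mechanism, not the backend's actual kernel; names are illustrative.

```python
import torch

def causal_conv1d_step(conv_state, x_t, weight):
    """One decode step of a depthwise causal conv1d (illustrative sketch).

    conv_state: (channels, kernel_size) rolling window of past inputs
                (this is what a pooled conv state holds per sequence);
    x_t: (channels,) newest input; weight: (channels, kernel_size) taps.
    """
    # Shift the window left by one position and append the new input.
    conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
    conv_state[:, -1] = x_t
    # Depthwise causal conv: each channel dots its window with its own taps.
    y_t = (conv_state * weight).sum(dim=-1)
    return y_t, conv_state
```

Keeping the window in a pool rather than in the module lets decode steps for many sequences proceed without re-running the full convolution over the prefix.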
- class pymllm.layers.gated_delta_net.GatedDeltaNet(hidden_size, num_k_heads=16, num_v_heads=32, head_k_dim=128, head_v_dim=128, conv_kernel_size=4, layer_id=0, gdn_layer_idx=0, rms_norm_eps=1e-06, quant_config=None, prefix='')¶
Bases: pymllm.layers.base.MllmBaseLayer

Gated Delta Network linear attention layer for Qwen3.5.
State is externalized into a GDNPool and computation is delegated to the attention backend via RadixLinearAttention.
- Parameters:
hidden_size (int) – Model hidden dimension.
num_k_heads (int) – Number of key heads.
num_v_heads (int) – Number of value heads.
head_k_dim (int) – Per-head key dimension.
head_v_dim (int) – Per-head value dimension.
conv_kernel_size (int) – Causal conv1d kernel width.
layer_id (int) – Global layer index.
gdn_layer_idx (int) – Sequential index among GDN layers (0-based).
rms_norm_eps (float) – Epsilon for gated RMS normalization.
prefix (str)
- num_k_heads = 16¶
- num_v_heads = 32¶
- head_k_dim = 128¶
- head_v_dim = 128¶
- key_dim = 2048¶
- value_dim = 4096¶
- conv_kernel_size = 4¶
- layer_id = 0¶
- gdn_layer_idx = 0¶
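The key_dim and value_dim defaults are consistent with the per-head dimensions times the head counts; the multiplicative relationship is inferred from the defaults shown, not stated by the source.

```python
# Default configuration from the signature above.
num_k_heads, num_v_heads = 16, 32
head_k_dim, head_v_dim = 128, 128

# Inferred relationship: total projection widths are heads * per-head dim.
key_dim = num_k_heads * head_k_dim    # 16 * 128 = 2048
value_dim = num_v_heads * head_v_dim  # 32 * 128 = 4096
```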
- in_proj_qkv¶
- in_proj_z¶
- in_proj_a¶
- in_proj_b¶
- conv1d¶
- A_log¶
- dt_bias¶
- norm¶
- out_proj¶
- attn¶
- forward(hidden_states, forward_batch=None)¶
- Parameters:
hidden_states (torch.Tensor)
forward_batch (Any)
- Return type:
torch.Tensor