pymllm.layers.gated_delta_net

Gated Delta Network (GDN) linear attention for Qwen3.5.

This implements the linear attention mechanism used in Qwen3.5’s hybrid architecture. GDN alternates with standard full-attention layers.

Core formulation (decode, per-head):

g_t      = -exp(A_log) * softplus(a_t + dt_bias)
beta_t   = sigmoid(b_t)
state_t  = exp(g_t) * state_{t-1} + beta_t * outer(k_t, v_t)
output_t = q_t @ state_t
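The decode-step recurrence above can be sketched per-head as follows. This is a reference implementation of the math only, not the pooled backend path; the function name and the scalar per-head gate parameters are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def gdn_decode_step(state, q, k, v, a, b, A_log, dt_bias):
    """One per-head GDN decode step (reference sketch).

    state: (d_k, d_v) recurrent state; q, k: (d_k,); v: (d_v,);
    a, b, A_log, dt_bias: scalar per-head gate parameters.
    """
    g = -torch.exp(A_log) * F.softplus(a + dt_bias)  # decay gate g_t (<= 0)
    beta = torch.sigmoid(b)                          # write strength beta_t
    state = torch.exp(g) * state + beta * torch.outer(k, v)
    out = q @ state                                  # (d_v,)
    return state, out

d_k, d_v = 4, 3
state = torch.zeros(d_k, d_v)
q, k, v = torch.randn(d_k), torch.randn(d_k), torch.randn(d_v)
state, out = gdn_decode_step(state, q, k, v,
                             torch.tensor(0.1), torch.tensor(0.2),
                             torch.tensor(-0.5), torch.tensor(0.0))
print(out.shape)  # torch.Size([3])
```

Because exp(g_t) multiplies the previous state, the recurrence decays old associations while beta_t controls how strongly the new (k_t, v_t) pair is written.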

State is externalized into a GDNPool and computation is delegated to the attention backend via RadixLinearAttention.

Attributes

Classes

GDNConv1d

Causal 1D convolution weight holder for GDN sequence mixing.

GatedDeltaNet

Gated Delta Network linear attention layer for Qwen3.5.

Module Contents

pymllm.layers.gated_delta_net.logger
class pymllm.layers.gated_delta_net.GDNConv1d(channels, kernel_size)

Bases: torch.nn.Module

Causal 1D convolution weight holder for GDN sequence mixing.

The actual convolution computation is performed by the GDN backend using pooled conv states. This module only holds the learnable weight.

Parameters:
  • channels (int)

  • kernel_size (int)

channels
kernel_size
weight
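Since GDNConv1d only holds the weight, the causal convolution itself lives in the backend. As a reference for what "causal" means here, a depthwise causal conv over a full sequence can be sketched as follows; the (channels, kernel_size) weight layout and the depthwise grouping are assumptions, not confirmed by this reference:

```python
import torch
import torch.nn.functional as F

def causal_conv1d(x, weight):
    """Depthwise causal 1D convolution (reference sketch).

    x: (batch, channels, seq_len); weight: (channels, kernel_size).
    Left-pads by kernel_size - 1 so position t sees only inputs <= t.
    """
    channels, kernel_size = weight.shape
    x = F.pad(x, (kernel_size - 1, 0))  # causal left padding
    # One filter per channel (groups=channels): no cross-channel mixing.
    return F.conv1d(x, weight.unsqueeze(1), groups=channels)

x = torch.randn(2, 8, 16)  # (batch, channels, seq_len)
w = torch.randn(8, 4)      # kernel_size = 4
y = causal_conv1d(x, w)
print(y.shape)  # torch.Size([2, 8, 16])
```

The left-only padding is what makes the mixing causal: perturbing future inputs cannot change earlier outputs.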
class pymllm.layers.gated_delta_net.GatedDeltaNet(hidden_size, num_k_heads=16, num_v_heads=32, head_k_dim=128, head_v_dim=128, conv_kernel_size=4, layer_id=0, gdn_layer_idx=0, rms_norm_eps=1e-06, quant_config=None, prefix='')

Bases: pymllm.layers.base.MllmBaseLayer

Gated Delta Network linear attention layer for Qwen3.5.

State is externalized into a GDNPool and computation is delegated to the attention backend via RadixLinearAttention.

Parameters:
  • hidden_size (int) – Model hidden dimension.

  • num_k_heads (int) – Number of key heads.

  • num_v_heads (int) – Number of value heads.

  • head_k_dim (int) – Per-head key dimension.

  • head_v_dim (int) – Per-head value dimension.

  • conv_kernel_size (int) – Causal conv1d kernel width.

  • layer_id (int) – Global layer index.

  • gdn_layer_idx (int) – Sequential index among GDN layers (0-based).

  • rms_norm_eps (float) – Epsilon for gated RMS normalization.

  • quant_config – Optional quantization configuration (defaults to None).

  • prefix (str) – Name prefix for this layer's parameters and submodules.

hidden_size
num_k_heads = 16
num_v_heads = 32
head_k_dim = 128
head_v_dim = 128
key_dim = 2048
value_dim = 4096
conv_kernel_size = 4
layer_id = 0
gdn_layer_idx = 0
in_proj_qkv
in_proj_z
in_proj_a
in_proj_b
conv1d
A_log
dt_bias
norm
out_proj
attn
forward(hidden_states, forward_batch=None)
Parameters:
  • hidden_states (torch.Tensor)

  • forward_batch (Any)

Return type:

torch.Tensor
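The default attribute values listed above are consistent with key_dim = num_k_heads * head_k_dim and value_dim = num_v_heads * head_v_dim. A quick sanity check; the fused qkv projection width and its packing order are assumptions based on the attribute names, not confirmed by this reference:

```python
# Derived projection widths for the default GatedDeltaNet configuration.
num_k_heads, num_v_heads = 16, 32
head_k_dim, head_v_dim = 128, 128

key_dim = num_k_heads * head_k_dim        # 2048, matches the key_dim attribute
value_dim = num_v_heads * head_v_dim      # 4096, matches the value_dim attribute
heads_per_k = num_v_heads // num_k_heads  # 2 value heads share each key head (GQA-style)

# Hypothetical fused width if in_proj_qkv packs q, k, and v together
# (exact packing is implementation-defined).
qkv_width = 2 * key_dim + value_dim
print(key_dim, value_dim, heads_per_k, qkv_width)
```

Note that num_v_heads is twice num_k_heads in the defaults, so each key head's state is read by two value heads.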