pymllm.layers.rms_norm_gated¶
Gated RMSNorm layer for Qwen3.5 GDN attention.
Computes rmsnorm(x, weight, eps) * silu(z) using a fused CUDA kernel
from mllm-kernel. Falls back to PyTorch when the kernel is unavailable.
Attributes¶
Classes¶
RMSNormGated — Gated RMS Normalization layer for Qwen3.5 GDN attention.
Functions¶
rms_norm_gated — Compute (optionally gated) RMS normalization.
Module Contents¶
- pymllm.layers.rms_norm_gated.logger¶
- pymllm.layers.rms_norm_gated.rms_norm_gated(x, weight, z=None, eps=1e-06, norm_before_gate=True)¶
Compute (optionally gated) RMS normalization.
Uses the fused mllm-kernel CUDA implementation when available, otherwise falls back to a pure-PyTorch implementation.
- Parameters:
x (torch.Tensor)
weight (torch.Tensor)
z (Optional[torch.Tensor])
eps (float)
norm_before_gate (bool)
- Return type:
torch.Tensor
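The pure-PyTorch fallback path can be sketched as follows. This is a minimal reference implementation of the documented semantics, not the fused mllm-kernel CUDA code; the function name `rms_norm_gated_ref` is hypothetical.

```python
import torch
import torch.nn.functional as F

def rms_norm_gated_ref(x, weight, z=None, eps=1e-6, norm_before_gate=True):
    """Pure-PyTorch sketch of the documented fallback behaviour."""
    def _rmsnorm(t):
        # Normalize by the root-mean-square over the last dimension.
        var = t.pow(2).mean(dim=-1, keepdim=True)
        return t * torch.rsqrt(var + eps) * weight

    if z is None:
        return _rmsnorm(x)
    if norm_before_gate:
        return _rmsnorm(x) * F.silu(z)   # rmsnorm(x) * silu(z)
    return _rmsnorm(x * F.silu(z))       # rmsnorm(x * silu(z))
```

Calling it with `z=None` reduces to plain RMSNorm, matching the function's optional-gating contract.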
- class pymllm.layers.rms_norm_gated.RMSNormGated(hidden_size, eps=1e-06, group_size=None, norm_before_gate=True, device=None, dtype=None)¶
Bases: pymllm.layers.base.MllmBaseLayer
Gated RMS Normalization layer for Qwen3.5 GDN attention.
Computes:
output = rmsnorm(x, weight) * silu(z)  # z is not None
output = rmsnorm(x, weight)            # z is None
Uses a fused CUDA kernel from mllm-kernel for maximum throughput.
- Parameters:
hidden_size (int) – Dimensionality of the input (and weight vector).
eps (float) – Small constant for numerical stability.
norm_before_gate (bool) – If True (default): rmsnorm(x) * silu(z). If False: rmsnorm(x * silu(z)).
group_size (Optional[int])
device (Optional[torch.device])
dtype (Optional[torch.dtype])
- eps = 1e-06¶
- norm_before_gate = True¶
- weight¶
- forward(x, z=None)¶
- Parameters:
x (torch.Tensor)
z (Optional[torch.Tensor])
- Return type:
torch.Tensor
- extra_repr()¶
- Return type:
str
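A minimal PyTorch sketch of the layer's documented interface is shown below. The class name `RMSNormGatedSketch` is hypothetical, it omits the `group_size` option, and it implements only the fallback math, not the fused CUDA kernel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNormGatedSketch(nn.Module):
    """Sketch of the documented RMSNormGated interface (fallback math only)."""
    def __init__(self, hidden_size, eps=1e-6, norm_before_gate=True,
                 device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.norm_before_gate = norm_before_gate
        # Learnable scale, initialized to ones as is conventional for RMSNorm.
        self.weight = nn.Parameter(
            torch.ones(hidden_size, device=device, dtype=dtype))

    def forward(self, x, z=None):
        def _norm(t):
            return t * torch.rsqrt(
                t.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        if z is None:
            return _norm(x)
        if self.norm_before_gate:
            return _norm(x) * F.silu(z)   # rmsnorm(x) * silu(z)
        return _norm(x * F.silu(z))       # rmsnorm(x * silu(z))

    def extra_repr(self):
        return (f"{self.weight.numel()}, eps={self.eps}, "
                f"norm_before_gate={self.norm_before_gate}")
```

Usage mirrors the documented signature: `layer = RMSNormGatedSketch(hidden_size); out = layer(x, z)`.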