pymllm.layers.rms_norm
Classes
- RMSNorm: RMSNorm layer implemented with a FlashInfer kernel.
- GemmaRMSNorm: Gemma-style RMSNorm layer implemented with a FlashInfer kernel.
Module Contents
- class pymllm.layers.rms_norm.RMSNorm(hidden_size, eps=1e-06)
Bases: pymllm.layers.base.MllmBaseLayer
RMSNorm layer implemented with a FlashInfer kernel.
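- Parameters:
hidden_size (int): size of the last (feature) dimension of the input.
eps (float): small constant added to the mean square for numerical stability.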
- eps = 1e-06
- weight
- forward(x, residual=None)
- Parameters:
x (torch.Tensor)
residual (Optional[torch.Tensor])
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
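The entry above does not spell out the computation. For reference, an RMSNorm layer conventionally computes `x / sqrt(mean(x**2) + eps) * weight` over the last dimension. The following is a minimal pure-PyTorch sketch of that computation, not pymllm's implementation, and it does not call FlashInfer. The class name `RMSNormReference`, the ones-initialised `weight`, and the residual behaviour (add `residual` to `x`, normalise the sum, return `(normalised, new_residual)`) are assumptions inferred from the `forward` signature and its `Union[Tensor, Tuple[Tensor, Tensor]]` return type.

```python
from typing import Optional, Tuple, Union

import torch
from torch import nn


class RMSNormReference(nn.Module):
    """Hypothetical pure-PyTorch RMSNorm reference (not pymllm or FlashInfer code)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Assumption: a per-channel scale initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Accumulate in float32 for stability, then cast back to the input dtype.
        x32 = x.float()
        rms_inv = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x32 * rms_inv * self.weight.float()).to(x.dtype)

    def forward(
        self, x: torch.Tensor, residual: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is None:
            return self._norm(x)
        # Assumed fused add-and-normalise semantics: the sum becomes the new
        # residual stream and is also the tensor that gets normalised.
        residual = x + residual
        return self._norm(residual), residual
```

Usage: `RMSNormReference(4)(torch.tensor([[1.0, 2.0, 3.0, 4.0]]))` divides each element by `sqrt(7.5 + 1e-6)`, since the mean of the squares is 7.5. Returning the updated residual alongside the output lets a transformer block thread the residual stream through consecutive norm layers without a separate addition.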
- class pymllm.layers.rms_norm.GemmaRMSNorm(hidden_size, eps=1e-06)
Bases: pymllm.layers.base.MllmBaseLayer
Gemma-style RMSNorm layer implemented with a FlashInfer kernel.
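- Parameters:
hidden_size (int): size of the last (feature) dimension of the input.
eps (float): small constant added to the mean square for numerical stability.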
- eps = 1e-06
- weight
- forward(x, residual=None)
- Parameters:
x (torch.Tensor)
residual (Optional[torch.Tensor])
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
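"Gemma-style" RMSNorm conventionally differs from plain RMSNorm only in the scale: the normalised input is multiplied by `(1 + weight)` instead of `weight`, so a zero-initialised weight leaves the normalised values unchanged. The following is a minimal pure-PyTorch sketch of that convention. It is not pymllm's implementation and does not call FlashInfer. The class name `GemmaRMSNormReference`, the zero-initialised `weight`, and the residual behaviour (add `residual` to `x`, normalise the sum, return `(normalised, new_residual)`) are assumptions inferred from the documented signature and return type.

```python
from typing import Optional, Tuple, Union

import torch
from torch import nn


class GemmaRMSNormReference(nn.Module):
    """Hypothetical pure-PyTorch Gemma-style RMSNorm (not pymllm or FlashInfer code)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Assumption: zero init, so the effective scale (1 + weight) starts at 1.
        self.weight = nn.Parameter(torch.zeros(hidden_size))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Accumulate in float32, then cast back to the input dtype.
        x32 = x.float()
        rms_inv = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # Gemma-style scaling: (1 + weight) rather than weight.
        return (x32 * rms_inv * (1.0 + self.weight.float())).to(x.dtype)

    def forward(
        self, x: torch.Tensor, residual: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is None:
            return self._norm(x)
        # Assumed fused add-and-normalise semantics, as for plain RMSNorm.
        residual = x + residual
        return self._norm(residual), residual
```

With the zero-initialised weight this sketch returns exactly the plain RMS-normalised input. Setting every weight to 1 doubles the output, which illustrates that a checkpoint's weights must be interpreted with the matching convention.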