pymllm.layers.rms_norm

Classes

RMSNorm

RMSNorm layer implemented with a FlashInfer kernel.

GemmaRMSNorm

Gemma-style RMSNorm layer implemented with a FlashInfer kernel.

Module Contents

class pymllm.layers.rms_norm.RMSNorm(hidden_size, eps=1e-06)

Bases: pymllm.layers.base.MllmBaseLayer

RMSNorm layer implemented with a FlashInfer kernel.

Parameters:
  • hidden_size (int)

  • eps (float)

hidden_size
eps = 1e-06
weight
forward(x, residual=None)
Parameters:
  • x (torch.Tensor)

  • residual (Optional[torch.Tensor])

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
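
The kernel itself is not documented here, but the arithmetic an RMSNorm layer performs can be sketched in plain Python for a single vector. This is an illustration only, not pymllm's implementation: the helper name `rms_norm_reference` is hypothetical, and the handling of `residual` (add it to `x` first, then return both the normalized output and the summed residual) is an assumption inferred from the `Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]` return type and from how fused add-and-normalize kernels commonly behave.

```python
import math
from typing import List, Optional, Tuple, Union

Vector = List[float]


def rms_norm_reference(
    x: Vector,
    weight: Vector,
    eps: float = 1e-6,
    residual: Optional[Vector] = None,
) -> Union[Vector, Tuple[Vector, Vector]]:
    """Hypothetical reference for RMSNorm over one hidden vector."""
    # Assumed residual semantics: fold the residual into x before normalizing.
    if residual is not None:
        x = [a + b for a, b in zip(x, residual)]

    # Root mean square over the hidden dimension, stabilized by eps.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)

    # Scale each normalized element by its learned weight.
    out = [v * inv_rms * w for v, w in zip(x, weight)]

    # Assumed: with a residual, return (normalized output, updated residual).
    if residual is not None:
        return out, x
    return out
```

Under these assumptions, calling the helper without `residual` yields a single vector, and calling it with `residual` yields a pair, which mirrors the two branches of the documented return type.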

class pymllm.layers.rms_norm.GemmaRMSNorm(hidden_size, eps=1e-06)

Bases: pymllm.layers.base.MllmBaseLayer

Gemma-style RMSNorm layer implemented with a FlashInfer kernel.

Parameters:
  • hidden_size (int)

  • eps (float)

hidden_size
eps = 1e-06
weight
forward(x, residual=None)
Parameters:
  • x (torch.Tensor)

  • residual (Optional[torch.Tensor])

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
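
In the Gemma models, RMSNorm scales the normalized values by `1 + weight` rather than by `weight`, so a zero-valued weight acts as the identity. The plain-Python sketch below illustrates that difference for a single vector. It assumes that GemmaRMSNorm follows this convention and that `residual` is added to `x` before normalization; neither point is stated in this reference, and the helper name `gemma_rms_norm_reference` is hypothetical.

```python
import math
from typing import List, Optional, Tuple, Union

Vector = List[float]


def gemma_rms_norm_reference(
    x: Vector,
    weight: Vector,
    eps: float = 1e-6,
    residual: Optional[Vector] = None,
) -> Union[Vector, Tuple[Vector, Vector]]:
    """Hypothetical reference for Gemma-style RMSNorm over one hidden vector."""
    # Assumed residual semantics: fold the residual into x before normalizing.
    if residual is not None:
        x = [a + b for a, b in zip(x, residual)]

    # Root mean square over the hidden dimension, stabilized by eps.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)

    # Gemma convention: the scale is (1 + weight), not weight.
    out = [v * inv_rms * (1.0 + w) for v, w in zip(x, weight)]

    # Assumed: with a residual, return (normalized output, updated residual).
    if residual is not None:
        return out, x
    return out
```

With an all-zero `weight`, the output is the normalized input unchanged; a `weight` of ones doubles it.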