pymllm.layers.embedding

Classes

VocabParallelEmbedding

Embedding layer with vocabulary parallelism.

Module Contents

class pymllm.layers.embedding.VocabParallelEmbedding(num_embeddings, embedding_dim, padding_idx=None)

Bases: pymllm.layers.base.MllmBaseLayer

Embedding layer with vocabulary parallelism.

This layer shards the embedding table along the vocabulary dimension for tensor parallelism.

Parameters:
  • num_embeddings (int) – Size of the vocabulary.

  • embedding_dim (int) – Size of the embedding vector.

  • padding_idx (int, optional) – Index of the padding token, if any. Defaults to None.

tp_rank = 0
  Tensor-parallel rank of the current process.
tp_size = 1
  Number of processes in the tensor-parallel group.
num_embeddings
  Total vocabulary size.
embedding_dim
  Size of each embedding vector.
padding_idx = None
  Index of the padding token, if any.
num_embeddings_per_partition
  Number of vocabulary entries held in this rank's shard.
weight
  This rank's shard of the embedding table.
vocab_start_index
  First vocabulary index (inclusive) covered by this rank.
vocab_end_index
  Last vocabulary index (exclusive) covered by this rank.
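The partition attributes above follow directly from the rank and group size. A minimal sketch of how they relate (assuming the vocabulary divides evenly across ranks; the real layer may pad the vocabulary instead, which this doc does not specify):

```python
def partition_bounds(num_embeddings: int, tp_rank: int, tp_size: int):
    """Compute this rank's slice of the vocabulary.

    Simplified sketch: assumes num_embeddings % tp_size == 0.
    """
    num_embeddings_per_partition = num_embeddings // tp_size
    vocab_start_index = tp_rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition
    return vocab_start_index, vocab_end_index

# Rank 1 of 4 over a (padded) vocabulary of 50304 entries.
bounds = partition_bounds(50304, tp_rank=1, tp_size=4)
print(bounds)  # (12576, 25152)
```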
weight_loader(param, loaded_weight)

Slice this rank's vocabulary shard out of the full checkpoint tensor and copy it into the parameter.

Parameters:
  • param (torch.nn.Parameter) – The parameter to load weights into.

  • loaded_weight (torch.Tensor) – The weight tensor loaded from checkpoint (full size).
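The key point is that `loaded_weight` is the full-vocabulary table while `param` only holds this rank's rows. A plain-Python sketch of the copy (lists stand in for torch tensors; the real method would presumably slice `loaded_weight` and use something like `param.data.copy_(...)` — an assumption, since the doc does not show the body):

```python
def weight_loader(param_rows, loaded_rows, tp_rank, tp_size):
    """Copy this rank's vocabulary shard out of the full weight.

    Sketch with plain lists; `param_rows` is mutated in place, mirroring
    an in-place copy into param.data.
    """
    shard = len(loaded_rows) // tp_size          # rows per partition
    start = tp_rank * shard                      # vocab_start_index
    param_rows[:] = loaded_rows[start:start + shard]

full = [[float(i)] * 2 for i in range(8)]  # full 8-row checkpoint table
mine = [None] * 4                          # this rank's 4-row parameter
weight_loader(mine, full, tp_rank=1, tp_size=2)
print(mine[0])  # [4.0, 4.0]
```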

forward(x)

Forward pass of the embedding layer with tensor-parallel support.

Parameters:
  • x (torch.Tensor) – Input tensor of token ids.

Returns:
  Embedded representation (all-reduced across the TP group when tp_size > 1).

Return type:
  torch.Tensor
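Why the all-reduce recovers the full lookup: each rank embeds only the token ids that fall inside its [vocab_start_index, vocab_end_index) range and emits zeros for the rest, so summing the partial outputs across ranks yields the same result as a lookup in the unsharded table. A plain-Python sketch of this scheme (the real layer would use torch indexing and a distributed all-reduce; the helper name below is hypothetical):

```python
def shard_forward(token_ids, shard, start, end):
    """One rank's lookup: tokens outside [start, end) yield zero rows,
    so an element-wise sum over ranks reconstructs the full result."""
    dim = len(shard[0])
    out = []
    for t in token_ids:
        if start <= t < end:
            out.append(list(shard[t - start]))  # offset into local shard
        else:
            out.append([0.0] * dim)             # masked: not in this shard
    return out

# Two ranks, vocabulary of 4, embedding_dim of 2.
table = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
rank0 = shard_forward([0, 3], table[0:2], 0, 2)
rank1 = shard_forward([0, 3], table[2:4], 2, 4)
# "All-reduce": element-wise sum of the partial results.
full_out = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(rank0, rank1)]
print(full_out)  # [[1.0, 1.0], [4.0, 4.0]]
```

The summed result matches a direct lookup of tokens 0 and 3 in the unsharded table, which is the invariant the all-reduce in `forward` relies on.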