pymllm.layers.embedding¶
Classes¶

| Class | Description |
| --- | --- |
| VocabParallelEmbedding | Embedding layer with vocabulary parallelism. |
Module Contents¶
- class pymllm.layers.embedding.VocabParallelEmbedding(num_embeddings, embedding_dim, padding_idx=None)¶
Bases: pymllm.layers.base.MllmBaseLayer

Embedding layer with vocabulary parallelism.
This layer shards the embedding table along the vocabulary dimension for tensor parallelism.
- Parameters:
num_embeddings (int) – Size of the vocabulary.
embedding_dim (int) – Size of the embedding vector.
padding_idx (int, optional) – Index of the padding token.
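The vocabulary-dimension sharding described above determines how much of the embedding table each tensor-parallel rank owns. A minimal sketch of that arithmetic (the concrete `tp_size`, `tp_rank`, and vocabulary size here are assumed for illustration, not taken from pymllm):

```python
# Hypothetical sharding arithmetic for vocabulary parallelism.
num_embeddings = 32000   # full vocabulary size (assumed)
tp_size = 4              # tensor-parallel world size (assumed)
tp_rank = 1              # this worker's rank (assumed)

# Each rank owns one contiguous slice of the vocabulary.
num_embeddings_per_partition = num_embeddings // tp_size
vocab_start_index = tp_rank * num_embeddings_per_partition
vocab_end_index = vocab_start_index + num_embeddings_per_partition

print(num_embeddings_per_partition, vocab_start_index, vocab_end_index)
# 8000 8000 16000
```

Each rank therefore allocates a weight of shape `(num_embeddings_per_partition, embedding_dim)` rather than the full table.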
- tp_rank = 0¶
- tp_size = 1¶
- num_embeddings¶
- embedding_dim¶
- padding_idx = None¶
- num_embeddings_per_partition¶
- weight¶
- vocab_start_index¶
- vocab_end_index¶
- weight_loader(param, loaded_weight)¶
Load sharded weights into the parameter.
- Parameters:
param (torch.nn.Parameter) – The parameter to load weights into.
loaded_weight (torch.Tensor) – The weight tensor loaded from checkpoint (full size).
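Since the checkpoint stores the full-size table while each rank holds only a slice, the loader must copy just this rank's rows. A sketch of that idea, assuming the shard boundaries are available as indices (the function name and signature here are hypothetical, not pymllm's actual implementation):

```python
import torch

def weight_loader_sketch(param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         vocab_start_index: int,
                         vocab_end_index: int) -> None:
    # Copy only this rank's vocabulary slice out of the full checkpoint tensor.
    shard = loaded_weight[vocab_start_index:vocab_end_index]
    assert shard.shape == param.shape, "shard must match the local partition"
    param.data.copy_(shard)

# Usage: a full table of 8 tokens with 4-dim embeddings; this rank owns rows 2..6.
full = torch.arange(32, dtype=torch.float32).reshape(8, 4)
p = torch.nn.Parameter(torch.empty(4, 4))
weight_loader_sketch(p, full, 2, 6)
```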
- forward(x)¶
Forward pass of the embedding layer with TP support.
- Parameters:
x (torch.Tensor) – Input tensor of token ids.
- Returns:
Embedded representation (all-reduced across TP group if needed).
- Return type:
torch.Tensor
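The forward pass has to handle token ids that fall outside this rank's vocabulary slice: out-of-range ids are masked to zero locally so that the all-reduce across the tensor-parallel group sums the partial results into the full embedding. A standalone sketch of that pattern (hypothetical function, with the distributed all-reduce left as a comment so it runs on a single process):

```python
import torch

def forward_sketch(x: torch.Tensor, weight: torch.Tensor,
                   vocab_start_index: int, vocab_end_index: int) -> torch.Tensor:
    # Keep only ids that fall inside this rank's vocabulary slice.
    mask = (x >= vocab_start_index) & (x < vocab_end_index)
    # Shift into local shard coordinates; out-of-range ids are clamped to 0.
    local_ids = (x - vocab_start_index) * mask
    out = torch.nn.functional.embedding(local_ids, weight)
    # Zero out rows for ids this rank does not own.
    out = out * mask.unsqueeze(-1)
    # In a real TP setup: torch.distributed.all_reduce(out) across the TP group.
    return out

# Usage: with tp_size == 1 the rank owns the whole vocabulary,
# so the result matches an ordinary embedding lookup.
weight = torch.arange(12, dtype=torch.float32).reshape(4, 3)
x = torch.tensor([1, 3])
out = forward_sketch(x, weight, 0, 4)
```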