pymllm.layers.embedding
=======================

.. py:module:: pymllm.layers.embedding


Classes
-------

.. autoapisummary::

   pymllm.layers.embedding.VocabParallelEmbedding


Module Contents
---------------

.. py:class:: VocabParallelEmbedding(num_embeddings, embedding_dim, padding_idx = None)

   Bases: :py:obj:`pymllm.layers.base.MllmBaseLayer`

   Embedding layer with vocabulary parallelism.

   This layer shards the embedding table along the vocabulary dimension
   for tensor parallelism.

   :param num_embeddings: Size of the vocabulary.
   :param embedding_dim: Size of the embedding vector.
   :param padding_idx: Index of the padding token (optional).


   .. py:attribute:: tp_rank
      :value: 0


   .. py:attribute:: tp_size
      :value: 1


   .. py:attribute:: num_embeddings


   .. py:attribute:: embedding_dim


   .. py:attribute:: padding_idx
      :value: None


   .. py:attribute:: num_embeddings_per_partition


   .. py:attribute:: weight


   .. py:attribute:: vocab_start_index


   .. py:attribute:: vocab_end_index


   .. py:method:: weight_loader(param, loaded_weight)

      Load sharded weights into the parameter.

      :param param: The parameter to load weights into.
      :param loaded_weight: The weight tensor loaded from the checkpoint (full size).


   .. py:method:: forward(x)

      Forward pass of the embedding layer with TP support.

      :param x: Input tensor of token ids.
      :returns: Embedded representation (all-reduced across the TP group if needed).
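The sharding scheme described above can be sketched in plain NumPy. This is a hypothetical illustration, not the `pymllm` implementation: each rank keeps only rows ``[vocab_start_index, vocab_end_index)`` of the full table, maps token ids outside its shard to a dummy local index, zeroes the corresponding output rows, and the per-rank partial results are summed (a stand-in for the TP all-reduce mentioned in ``forward``).

```python
import numpy as np

def vocab_parallel_embed(token_ids, full_weight, tp_rank, tp_size):
    """Embedding lookup for one rank's vocabulary shard (illustrative only)."""
    vocab_size = full_weight.shape[0]
    per_partition = vocab_size // tp_size        # num_embeddings_per_partition
    start = tp_rank * per_partition              # vocab_start_index
    end = start + per_partition                  # vocab_end_index
    shard = full_weight[start:end]               # rows this rank owns

    in_range = (token_ids >= start) & (token_ids < end)
    local_ids = np.where(in_range, token_ids - start, 0)
    out = shard[local_ids]                       # fancy indexing returns a copy
    out[~in_range] = 0.0                         # mask tokens owned by other ranks
    return out

# Simulate a TP group of size 2 over an 8-token vocabulary.
rng = np.random.default_rng(0)
weight = rng.standard_normal((8, 4))
ids = np.array([1, 5, 7, 2])

# Summing the two partial results recovers a plain full-table lookup.
summed = sum(vocab_parallel_embed(ids, weight, r, 2) for r in range(2))
assert np.allclose(summed, weight[ids])
```

The same row-slicing (``full_weight[start:end]``) is what a ``weight_loader`` must perform when copying a full-size checkpoint tensor into a rank's sharded parameter.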