test_vocab_parallel_embedding
=============================

.. py:module:: test_vocab_parallel_embedding

.. autoapi-nested-parse::

   Tests for VocabParallelEmbedding layer.

   This module tests the VocabParallelEmbedding layer with and without tensor parallelism.


Classes
-------

.. autoapisummary::

   test_vocab_parallel_embedding.TestVocabParallelEmbeddingRealTP8
   test_vocab_parallel_embedding.TestVocabParallelEmbeddingCUDA


Functions
---------

.. autoapisummary::

   test_vocab_parallel_embedding.load_weight
   test_vocab_parallel_embedding.run_worker_tp8_cuda
   test_vocab_parallel_embedding.embedding_forward_tp8_worker_cuda
   test_vocab_parallel_embedding.weight_loading_tp8_worker_cuda


Module Contents
---------------

.. py:function:: load_weight(param, loaded_weight)

   Load weight using the weight_loader attached to param attribute.


.. py:function:: run_worker_tp8_cuda(rank, local_rank, world_size, local_world_size, test_func, return_dict)

   Worker function for multi-process testing with TP=8 on CUDA.

   :param rank: Global rank across all nodes
   :param local_rank: Local rank within this node (used for GPU binding)
   :param world_size: Total number of processes across all nodes
   :param local_world_size: Number of processes on this node
   :param test_func: Test function to run
   :param return_dict: Shared dict for returning results


.. py:function:: embedding_forward_tp8_worker_cuda(rank, local_rank, world_size)

   Test forward pass with real TP=8 on CUDA.

   :param rank: Global rank
   :param local_rank: Local rank within this node (for logging/debugging)
   :param world_size: Total world size


.. py:function:: weight_loading_tp8_worker_cuda(rank, local_rank, world_size)

   Test weight loading with real TP=8 on CUDA.

   :param rank: Global rank
   :param local_rank: Local rank within this node (for GPU binding verification)
   :param world_size: Total world size


.. py:class:: TestVocabParallelEmbeddingRealTP8

   Real distributed tests with world_size=8 and TP=8 on CUDA.


   .. py:method:: test_forward_pass_tp8_real()

      Test forward pass with real TP=8 using 8 processes on CUDA.



   .. py:method:: test_weight_loading_tp8_real()

      Test weight loading with real TP=8 using 8 processes on CUDA.



.. py:class:: TestVocabParallelEmbeddingCUDA

   Tests for non-parallel TP=1 mode on CUDA.


   .. py:method:: setup_config()


   .. py:method:: test_cuda_forward()


   .. py:method:: test_cuda_weight_loader()
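As a rough illustration of the two properties these tests exercise (weight loading and the forward pass under TP=8), the sketch below simulates a vocabulary-sharded embedding in pure Python, with no GPUs or process groups. It is not the code under test. The helpers ``shard_range``, ``make_param``, ``shard_forward`` and ``all_reduce_sum`` are hypothetical names, and both the contiguous row-wise sharding and the zero-masking of out-of-shard tokens are assumptions about how such a layer commonly works, not details stated on this page. Only the ``load_weight(param, loaded_weight)`` signature is taken from the module.

```python
from types import SimpleNamespace

VOCAB_SIZE = 16   # toy sizes, chosen so the vocabulary divides evenly
EMBED_DIM = 4
TP_SIZE = 8       # mirrors the TP=8 setting used by the tests


def shard_range(rank, tp_size, vocab_size):
    """Return the [start, end) vocabulary rows owned by `rank`.

    Assumption: contiguous, evenly sized shards.
    """
    per_rank = vocab_size // tp_size
    return rank * per_rank, (rank + 1) * per_rank


def make_param(rank):
    """Build a hypothetical parameter with a weight_loader attached to it."""
    start, end = shard_range(rank, TP_SIZE, VOCAB_SIZE)
    param = SimpleNamespace(data=None, start=start, end=end)

    def weight_loader(p, loaded_weight):
        # Keep only the rows of the full table that this rank owns.
        p.data = [row[:] for row in loaded_weight[p.start:p.end]]

    param.weight_loader = weight_loader
    return param


def load_weight(param, loaded_weight):
    """Dispatch to the loader attached to the parameter."""
    param.weight_loader(param, loaded_weight)


def shard_forward(param, token_ids):
    """Look up tokens on one shard; tokens owned elsewhere yield zeros."""
    out = []
    for tok in token_ids:
        if param.start <= tok < param.end:
            out.append(param.data[tok - param.start][:])
        else:
            out.append([0.0] * EMBED_DIM)
    return out


def all_reduce_sum(partials):
    """Element-wise sum over ranks, standing in for a real all-reduce."""
    n_tokens = len(partials[0])
    return [
        [sum(p[t][d] for p in partials) for d in range(EMBED_DIM)]
        for t in range(n_tokens)
    ]


# Full (unsharded) embedding table with distinct values per entry.
full_weight = [
    [float(r * EMBED_DIM + c) for c in range(EMBED_DIM)]
    for r in range(VOCAB_SIZE)
]

# "Weight loading": every simulated rank loads its slice of the full table.
params = [make_param(rank) for rank in range(TP_SIZE)]
for p in params:
    load_weight(p, full_weight)

# "Forward pass": per-rank lookups are summed and compared to a plain lookup.
token_ids = [0, 5, 15, 5]
partials = [shard_forward(p, token_ids) for p in params]
combined = all_reduce_sum(partials)
reference = [full_weight[t] for t in token_ids]
assert combined == reference
```

Because each token is owned by exactly one shard, the sum over ranks reproduces the unsharded lookup. A real multi-process test would replace ``all_reduce_sum`` with a collective over a process group and run one worker per GPU.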