test_vocab_parallel_embedding
Tests for the VocabParallelEmbedding layer.
This module tests the VocabParallelEmbedding layer both without tensor parallelism (TP=1) and with it (TP=8).
Classes
- TestVocabParallelEmbeddingRealTP8: Real distributed tests with world_size=8 and TP=8 on CUDA.
- Tests for non-parallel TP=1 mode on CUDA.
Functions
- load_weight: Load a weight using the weight_loader attached to the parameter as an attribute.
- run_worker_tp8_cuda: Worker function for multi-process testing with TP=8 on CUDA.
- embedding_forward_tp8_worker_cuda: Test the forward pass with real TP=8 on CUDA.
- weight_loading_tp8_worker_cuda: Test weight loading with real TP=8 on CUDA.
Module Contents
- test_vocab_parallel_embedding.load_weight(param, loaded_weight)
Load loaded_weight into param using the weight_loader function attached to param as an attribute.
- Parameters:
param (torch.nn.Parameter)
loaded_weight (torch.Tensor)
- Return type:
None
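The attribute-based dispatch that load_weight describes can be sketched without torch. This is a minimal sketch, assuming the loader is a callable stored on the parameter and invoked as weight_loader(param, loaded_weight); the Param class and copy_loader are hypothetical stand-ins, and plain lists of rows replace tensors.

```python
from typing import Callable, List

Matrix = List[List[float]]


class Param:
    """Hypothetical stand-in for torch.nn.Parameter; `data` is a list of rows."""

    def __init__(self, rows: int, cols: int) -> None:
        self.data: Matrix = [[0.0] * cols for _ in range(rows)]


def load_weight(param: Param, loaded_weight: Matrix) -> None:
    # Look up the loader the layer attached to the parameter and delegate to it.
    weight_loader: Callable[[Param, Matrix], None] = getattr(param, "weight_loader")
    weight_loader(param, loaded_weight)


def copy_loader(param: Param, loaded_weight: Matrix) -> None:
    # Hypothetical loader: row-by-row copy with shape checks.
    assert len(loaded_weight) == len(param.data), "row count mismatch"
    for dst, src in zip(param.data, loaded_weight):
        assert len(dst) == len(src), "column count mismatch"
        dst[:] = src  # copy in place so the parameter object keeps its identity


param = Param(rows=2, cols=3)
param.weight_loader = copy_loader  # attach the loader as an attribute
load_weight(param, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(param.data)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
```

Keeping the loader on the parameter presumably lets a sharded layer decide which slice of a checkpoint tensor it keeps, while the calling code stays identical for every layer.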
- test_vocab_parallel_embedding.run_worker_tp8_cuda(rank, local_rank, world_size, local_world_size, test_func, return_dict)
Worker function for multi-process testing with TP=8 on CUDA.
- Parameters:
rank (int) – Global rank across all nodes
local_rank (int) – Local rank within this node (used for GPU binding)
world_size (int) – Total number of processes across all nodes
local_world_size (int) – Number of processes on this node
test_func (Callable) – Test function to run
return_dict (dict) – Shared dict for returning results
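The worker's role (use local_rank for the GPU binding, run test_func, report through return_dict) can be sketched in a single process. This is a minimal sketch, assuming ranks are assigned to nodes in contiguous blocks of local_world_size and that each rank reports a (status, payload) tuple; node_and_local_rank, fake_test, and that tuple format are hypothetical. The GPU binding and process-group setup are only marked by a comment, and a real launcher would run each call in its own process with a shared dict (for example, one from multiprocessing.Manager) rather than this sequential loop.

```python
import traceback
from typing import Any, Callable, Dict, Tuple


def node_and_local_rank(rank: int, local_world_size: int) -> Tuple[int, int]:
    # Assumes contiguous assignment: node 0 gets ranks 0..local_world_size-1, and so on.
    return rank // local_world_size, rank % local_world_size


def run_worker(
    rank: int,
    local_rank: int,
    world_size: int,
    local_world_size: int,
    test_func: Callable[[int, int, int], Any],
    return_dict: Dict[int, Any],
) -> None:
    # A real worker would bind to GPU `local_rank` and join the process group here.
    assert 0 <= local_rank < local_world_size <= world_size
    try:
        return_dict[rank] = ("ok", test_func(rank, local_rank, world_size))
    except Exception:
        # Record the failure instead of crashing so the parent can report it.
        return_dict[rank] = ("error", traceback.format_exc())


def fake_test(rank: int, local_rank: int, world_size: int) -> int:
    # Hypothetical test body that fails on one rank to show error reporting.
    if rank == 5:
        raise RuntimeError("simulated failure")
    return rank * 10


world_size, local_world_size = 8, 8
results: Dict[int, Any] = {}
for rank in range(world_size):  # a real launcher starts one process per rank
    _, local_rank = node_and_local_rank(rank, local_world_size)
    run_worker(rank, local_rank, world_size, local_world_size, fake_test, results)

failed = sorted(r for r, (status, _) in results.items() if status == "error")
print(failed)  # [5]
```

Recording the traceback per rank, instead of letting the worker raise, lets the parent test assert on every rank and report which ones failed.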
- test_vocab_parallel_embedding.embedding_forward_tp8_worker_cuda(rank, local_rank, world_size)
Test the forward pass with real TP=8 on CUDA.
- Parameters:
rank (int) – Global rank
local_rank (int) – Local rank within this node (for logging/debugging)
world_size (int) – Total world size
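The property a TP=8 forward test would typically check can be simulated in one process. This is a minimal sketch, assuming the layer follows the common Megatron-style scheme: each rank holds a contiguous slice of vocabulary rows, looks up only the token IDs in its slice (emitting zeros for the rest), and the partial outputs are summed with an all-reduce. The even vocabulary split and the helper names shard_range, local_forward, and all_reduce_sum are hypothetical, and plain lists stand in for tensors.

```python
from typing import List, Tuple

Vector = List[float]


def shard_range(rank: int, vocab_size: int, tp_size: int) -> Tuple[int, int]:
    # Assumes the vocabulary divides evenly; rank r owns rows [start, end).
    per_rank = vocab_size // tp_size
    return rank * per_rank, (rank + 1) * per_rank


def local_forward(token_ids: List[int], shard: List[Vector],
                  start: int, end: int, hidden: int) -> List[Vector]:
    # Look up owned tokens at their local row index; other tokens map to zeros.
    out: List[Vector] = []
    for t in token_ids:
        if start <= t < end:
            out.append(list(shard[t - start]))
        else:
            out.append([0.0] * hidden)
    return out


def all_reduce_sum(partials: List[List[Vector]]) -> List[Vector]:
    # Stand-in for an all-reduce with SUM: element-wise sum over ranks, per token.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


vocab_size, hidden, tp_size = 16, 4, 8
full = [[float(i * hidden + j) for j in range(hidden)] for i in range(vocab_size)]
token_ids = [0, 3, 7, 15, 8]

partials = []
for rank in range(tp_size):
    start, end = shard_range(rank, vocab_size, tp_size)
    partials.append(local_forward(token_ids, full[start:end], start, end, hidden))

output = all_reduce_sum(partials)
expected = [full[t] for t in token_ids]
print(output == expected)  # True
```

Because exactly one rank owns each token ID, every output position receives a single non-zero contribution, which is why the sum reproduces the unsharded lookup.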
- test_vocab_parallel_embedding.weight_loading_tp8_worker_cuda(rank, local_rank, world_size)
Test weight loading with real TP=8 on CUDA.
- Parameters:
rank (int) – Global rank
local_rank (int) – Local rank within this node (for GPU binding verification)
world_size (int) – Total world size
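Under TP, weight loading means each rank keeps only its rows of the full vocabulary-by-hidden checkpoint weight. This is a minimal sketch, assuming a row-wise split in which every shard is padded to ceil(vocab_size / tp_size) rows so all ranks hold the same shape; the padding policy and the helpers padded_shard_bounds and load_vocab_shard are hypothetical, and a real layer may pad differently (for example, to a larger alignment).

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def padded_shard_bounds(rank: int, vocab_size: int, tp_size: int) -> Tuple[int, int, int]:
    # Each rank gets ceil(vocab_size / tp_size) rows; trailing ranks may own
    # fewer real rows (or none), and the rest of their shard is padding.
    per_rank = math.ceil(vocab_size / tp_size)
    start = min(rank * per_rank, vocab_size)
    end = min(start + per_rank, vocab_size)
    return start, end, per_rank


def load_vocab_shard(full: Matrix, rank: int, tp_size: int) -> Matrix:
    # Copy this rank's rows of the checkpoint weight, zero-filling padding rows.
    vocab_size, hidden = len(full), len(full[0])
    start, end, per_rank = padded_shard_bounds(rank, vocab_size, tp_size)
    shard = [list(row) for row in full[start:end]]
    shard += [[0.0] * hidden for _ in range(per_rank - len(shard))]
    return shard


vocab_size, hidden, tp_size = 21, 2, 8  # 21 rows do not divide evenly by 8
full = [[float(i), float(-i)] for i in range(vocab_size)]
shards = [load_vocab_shard(full, r, tp_size) for r in range(tp_size)]
print({len(s) for s in shards})  # {3}

# Dropping the padding rows and concatenating must reproduce the original weight.
reassembled: Matrix = []
for r, shard in enumerate(shards):
    start, end, _ = padded_shard_bounds(r, vocab_size, tp_size)
    reassembled.extend(shard[: end - start])
print(reassembled == full)  # True
```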
- class test_vocab_parallel_embedding.TestVocabParallelEmbeddingRealTP8
Real distributed tests with world_size=8 and TP=8 on CUDA.
- test_forward_pass_tp8_real()¶
Test the forward pass with real TP=8, using 8 processes on CUDA.
- test_weight_loading_tp8_real()¶
Test weight loading with real TP=8 using 8 processes on CUDA.