test_vocab_parallel_embedding

Tests for the VocabParallelEmbedding layer.

This module tests the VocabParallelEmbedding layer with and without tensor parallelism.

Classes

TestVocabParallelEmbeddingRealTP8

Real distributed tests with world_size=8 and TP=8 on CUDA.

TestVocabParallelEmbeddingCUDA

Tests for non-parallel TP=1 mode on CUDA.

Functions

load_weight(param, loaded_weight)

Load a weight into param using the weight_loader attached to it.

run_worker_tp8_cuda(rank, local_rank, world_size, ...)

Worker function for multi-process testing with TP=8 on CUDA.

embedding_forward_tp8_worker_cuda(rank, local_rank, ...)

Test forward pass with real TP=8 on CUDA.

weight_loading_tp8_worker_cuda(rank, local_rank, ...)

Test weight loading with real TP=8 on CUDA.

Module Contents

test_vocab_parallel_embedding.load_weight(param, loaded_weight)

Load a weight into param using the weight_loader attached to it.

Parameters:
  • param (torch.nn.Parameter)

  • loaded_weight (torch.Tensor)

Return type:

None
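
As a rough illustration of the pattern described above, a helper of this kind might look up a `weight_loader` callable on the parameter and delegate to it. This is a minimal sketch, not the module's actual code: the `weight_loader(param, loaded_weight)` call signature is an assumption, and `copy_loader` and the `SimpleNamespace` stand-in for a parameter are hypothetical so that the example runs without torch.

```python
from types import SimpleNamespace


def load_weight(param, loaded_weight):
    """Load `loaded_weight` into `param` via the loader attached to it."""
    # Assumption: the layer attaches a callable
    # `weight_loader(param, loaded_weight)` to each parameter it creates.
    loader = getattr(param, "weight_loader", None)
    if loader is None:
        raise AttributeError("param has no weight_loader attribute")
    loader(param, loaded_weight)


# Hypothetical stand-ins, used only to exercise the helper without torch.
def copy_loader(param, loaded_weight):
    param.data = list(loaded_weight)


param = SimpleNamespace(data=[0.0, 0.0], weight_loader=copy_loader)
load_weight(param, [1.0, 2.0])
```

Routing every load through the attached loader lets a test exercise whatever sharding logic the layer registered, rather than copying tensors directly.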

test_vocab_parallel_embedding.run_worker_tp8_cuda(rank, local_rank, world_size, local_world_size, test_func, return_dict)

Worker function for multi-process testing with TP=8 on CUDA.

Parameters:
  • rank (int) – Global rank across all nodes

  • local_rank (int) – Local rank within this node (used for GPU binding)

  • world_size (int) – Total number of processes across all nodes

  • local_world_size (int) – Number of processes on this node

  • test_func (Callable) – Test function to run

  • return_dict (dict) – Shared dict for returning results
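
The parameter list above suggests a common worker pattern: publish the rank information, call the test function, and record the outcome in the shared dict. The following is a sketch of that pattern under stated assumptions, not the module's implementation. The name `run_worker_sketch`, the use of the `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, and `LOCAL_WORLD_SIZE` environment variables, and the `"PASSED"`/`"FAILED"` result strings are all assumptions. Process group setup and GPU binding are omitted.

```python
import os
import traceback


def run_worker_sketch(rank, local_rank, world_size, local_world_size,
                      test_func, return_dict):
    """Run test_func for one rank and record the outcome in return_dict."""
    # Assumption: rank information is published through the environment
    # variables that distributed launchers conventionally set.
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(local_world_size)
    try:
        # The worker functions documented here take (rank, local_rank, world_size).
        test_func(rank, local_rank, world_size)
        return_dict[rank] = "PASSED"
    except Exception:
        # Keep the traceback so the parent process can report which rank failed.
        return_dict[rank] = "FAILED:\n" + traceback.format_exc()


def failing(rank, local_rank, world_size):
    raise ValueError("boom")


# A plain dict stands in for the shared dict when running in one process.
results = {}
run_worker_sketch(0, 0, 8, 8, lambda rank, local_rank, world_size: None, results)
run_worker_sketch(1, 1, 8, 8, failing, results)
```

Catching the exception inside the worker matters in multi-process tests, because an uncaught error in a child process would otherwise surface only as a non-zero exit code.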

test_vocab_parallel_embedding.embedding_forward_tp8_worker_cuda(rank, local_rank, world_size)

Test forward pass with real TP=8 on CUDA.

Parameters:
  • rank (int) – Global rank

  • local_rank (int) – Local rank within this node (for logging/debugging)

  • world_size (int) – Total world size
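
For context on what a forward-pass test with TP=8 typically checks, the sketch below simulates a vocabulary-parallel lookup in plain Python. Each rank owns a contiguous slice of the vocabulary, returns zeros for tokens it does not own, and the per-rank outputs are summed as an all-reduce would do. This is an illustration of the general technique only. The actual layer may pad the vocabulary or partition it differently, and every name here (`shard_range`, `local_forward`, `all_reduce_sum`) is hypothetical.

```python
def shard_range(rank, world_size, vocab_size):
    """Return the [start, end) vocabulary rows owned by `rank`."""
    per_rank = vocab_size // world_size  # assumes vocab_size divides evenly
    return rank * per_rank, (rank + 1) * per_rank


def local_forward(token_ids, shard, start, end, dim):
    """Look up tokens owned by this rank; emit zeros for all others."""
    out = []
    for tok in token_ids:
        if start <= tok < end:
            out.append(list(shard[tok - start]))
        else:
            out.append([0.0] * dim)
    return out


def all_reduce_sum(per_rank_outputs):
    """Simulate an all-reduce by summing the outputs of every rank."""
    first = per_rank_outputs[0]
    return [[sum(o[i][j] for o in per_rank_outputs)
             for j in range(len(first[0]))]
            for i in range(len(first))]


vocab_size, dim, world_size = 16, 2, 8
full = [[float(v), float(v) * 10] for v in range(vocab_size)]
token_ids = [0, 5, 15, 5]

outputs = []
for rank in range(world_size):
    start, end = shard_range(rank, world_size, vocab_size)
    outputs.append(local_forward(token_ids, full[start:end], start, end, dim))

combined = all_reduce_sum(outputs)
reference = [full[t] for t in token_ids]  # unsharded lookup for comparison
```

Comparing `combined` with `reference` mirrors the usual assertion in such a test: the reduced output of the sharded layer should match an unsharded embedding lookup.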

test_vocab_parallel_embedding.weight_loading_tp8_worker_cuda(rank, local_rank, world_size)

Test weight loading with real TP=8 on CUDA.

Parameters:
  • rank (int) – Global rank

  • local_rank (int) – Local rank within this node (for GPU binding verification)

  • world_size (int) – Total world size
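
A weight-loading test under tensor parallelism generally verifies that each rank ends up with its own slice of the full weight. The sketch below shows one way a per-rank loader could do that, using nested lists and a dict in place of tensors and parameters. It rests on assumptions: sharding along the vocabulary dimension in equal, unpadded slices, and a hypothetical `make_vocab_shard_loader` factory. The layer's real loader may behave differently.

```python
def make_vocab_shard_loader(rank, world_size):
    """Build a loader that copies this rank's vocabulary rows from a full weight."""
    def weight_loader(param, loaded_weight):
        rows_per_rank = len(loaded_weight) // world_size  # assumes even division
        start = rank * rows_per_rank
        shard = loaded_weight[start:start + rows_per_rank]
        # Guard against a shard whose size differs from the local parameter.
        assert len(shard) == len(param["data"]), "shard size mismatch"
        param["data"] = [list(row) for row in shard]
    return weight_loader


world_size, vocab_size, dim = 8, 16, 2
full_weight = [[float(v)] * dim for v in range(vocab_size)]

params = []
for rank in range(world_size):
    # Each rank starts with a zero-filled local shard of the right size.
    param = {"data": [[0.0] * dim for _ in range(vocab_size // world_size)]}
    make_vocab_shard_loader(rank, world_size)(param, full_weight)
    params.append(param)

# Concatenating the shards in rank order should reproduce the full weight.
rebuilt = [row for p in params for row in p["data"]]
```

Checking both the per-rank slice and the concatenation catches off-by-one errors in the start offset as well as gaps or overlaps between ranks.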

class test_vocab_parallel_embedding.TestVocabParallelEmbeddingRealTP8

Real distributed tests with world_size=8 and TP=8 on CUDA.

test_forward_pass_tp8_real()

Test forward pass with real TP=8 using 8 processes on CUDA.

test_weight_loading_tp8_real()

Test weight loading with real TP=8 using 8 processes on CUDA.

class test_vocab_parallel_embedding.TestVocabParallelEmbeddingCUDA

Tests for non-parallel TP=1 mode on CUDA.

setup_config()

test_cuda_forward()

test_cuda_weight_loader()