pymllm.layers.utils

Utility functions for layers.

Functions

set_weight_attrs(weight, weight_attrs)

Set attributes on a weight tensor.

Module Contents

pymllm.layers.utils.set_weight_attrs(weight, weight_attrs)

Set attributes on a weight tensor.

This function attaches each entry of weight_attrs to the weight tensor as an attribute. It will not overwrite attributes that already exist on the tensor.

Parameters:
  • weight (torch.Tensor) – The weight tensor or parameter.

  • weight_attrs (Dict[str, Any] | None) – A dictionary of attributes to set on the weight tensor. Common attributes include:

      - output_dim: The dimension along which to shard the weight (typically 0 for output dim)
      - input_dim: The input dimension (typically 1 for input dim)
      - weight_loader: A callable to load weights into this parameter
      - packed_dim: The dimension along which the weight is packed (for quantization)
      - packed_factor: The packing factor (for quantization)

Return type:

None

Example

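>>> import torch
>>> from torch import nn
>>> from pymllm.layers.utils import set_weight_attrs
>>> # my_loader_func is a placeholder for a user-supplied loader callable.
>>> weight = nn.Parameter(torch.empty(100, 64))
>>> set_weight_attrs(weight, {
...     "output_dim": 0,
...     "input_dim": 1,
...     "weight_loader": my_loader_func,
... })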
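
The sketch below illustrates how attributes attached this way might later be consumed, for example by looking up weight_loader with a fallback. It is a stdlib-only illustration and not the pymllm implementation: set_weight_attrs_sketch, default_loader, and shard_loader are hypothetical names, a SimpleNamespace stands in for the tensor, and two behaviors are assumptions the reference above does not confirm (that passing None is a no-op, and that an existing attribute is rejected with an error rather than silently skipped).

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional


def set_weight_attrs_sketch(weight: Any, weight_attrs: Optional[Dict[str, Any]]) -> None:
    """Hypothetical stand-in for set_weight_attrs, NOT the pymllm code."""
    if weight_attrs is None:
        # Assumption: None means "nothing to attach".
        return
    for key, value in weight_attrs.items():
        # Assumption: "will not overwrite" is enforced by refusing duplicates.
        if hasattr(weight, key):
            raise AttributeError(f"attribute {key!r} is already set on the weight")
        setattr(weight, key, value)


def default_loader(param: Any, loaded: Any) -> None:
    """Fallback used when no weight_loader attribute is attached."""
    param.data = loaded


def shard_loader(param: Any, loaded: Any) -> None:
    """Hypothetical loader that keeps only the first half of the loaded values."""
    param.data = loaded[: len(loaded) // 2]


# SimpleNamespace stands in for a torch.Tensor / nn.Parameter here.
weight = SimpleNamespace(data=None)
set_weight_attrs_sketch(weight, {"output_dim": 0, "weight_loader": shard_loader})

# A loading loop can pick up the custom loader, or fall back to the default.
loader = getattr(weight, "weight_loader", default_loader)
loader(weight, [1, 2, 3, 4])
print(weight.data)  # [1, 2]
```

Storing the loader on the weight itself keeps per-parameter loading logic next to the parameter, so generic loading code only needs a getattr lookup with a default.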