pymllm.layers.linear¶
Linear layers with quantization method dispatch.
Every linear layer holds a quant_method attribute (an instance of
LinearMethodBase). When no
quantization is configured, UnquantizedLinearMethod is used as the
default — it creates a standard FP weight and forwards via F.linear.
Quantized checkpoints plug in a different LinearMethodBase (e.g.
AWQLinearMethod) which creates packed int4 weights, scales, and
zero-points, and overrides apply() with a fused dequant+matmul kernel.
Usage in model definitions:
# Non-quantized (default)
layer = ColumnParallelLinear(4096, 4096)
# Quantized — pass a quant_method from QuantizationConfig
qm = awq_config.get_quant_method(layer, prefix="model.layers.0.q_proj")
layer = ColumnParallelLinear(4096, 4096, quant_method=qm)
Classes¶
ColumnParallelLinear: Linear layer with column parallelism (output-dimension sharding).
RowParallelLinear: Linear layer with row parallelism (input-dimension sharding).
Linear: Non-parallel linear layer with quantization dispatch.
Module Contents¶
- class pymllm.layers.linear.ColumnParallelLinear(in_features, out_features, bias=True, gather_output=True, quant_method=None)¶
Bases: pymllm.layers.base.MllmBaseLayer
Linear layer with column parallelism (output-dimension sharding).
The weight matrix is split along the output dimension across TP ranks. Each rank holds out_features / tp_size rows of the weight.
- Parameters:
in_features (int) – Size of each input sample.
out_features (int) – Size of each output sample (before sharding).
bias (bool) – If True, adds a learnable bias.
gather_output (bool) – If True, all-gather the output across TP ranks so every rank gets the full out_features. Set to False when the next layer is a RowParallelLinear that expects a split input.
quant_method (Optional[pymllm.layers.quantize_base.LinearMethodBase]) – Quantization method instance. None → UnquantizedLinearMethod.
- tp_rank = 0¶
- tp_size = 1¶
- in_features¶
- out_features¶
- gather_output = True¶
- out_features_per_partition¶
- output_start_index¶
- output_end_index¶
- quant_method¶
- weight_loader(param, loaded_weight)¶
Load sharded weights into the parameter.
- Parameters:
param (torch.nn.Parameter) – The parameter to load weights into.
loaded_weight (torch.Tensor) – The weight tensor loaded from checkpoint (full size).
- forward(x)¶
- Parameters:
x (torch.Tensor)
- Return type:
torch.Tensor
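The sharding arithmetic behind weight_loader can be sketched as follows. This is a standalone illustration of slicing a full checkpoint weight by the output dimension, not the actual pymllm code; the tensor sizes and rank values are made up:

```python
import torch

# Full checkpoint weight: (out_features, in_features) = (8, 4)
full_weight = torch.arange(32, dtype=torch.float32).reshape(8, 4)

tp_size, tp_rank = 2, 1                      # pretend we are rank 1 of 2
out_features = full_weight.shape[0]
out_per_partition = out_features // tp_size  # 4 rows per rank

# Column parallelism shards the *output* dimension, i.e. the weight rows.
start = tp_rank * out_per_partition          # output_start_index
end = start + out_per_partition              # output_end_index
shard = full_weight[start:end, :]            # rows 4..7 land on rank 1

print(shard.shape)  # torch.Size([4, 4])
```

Each rank's parameter is allocated at the sharded size up front, and weight_loader copies only its slice of the full-size checkpoint tensor into it.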
- class pymllm.layers.linear.RowParallelLinear(in_features, out_features, bias=True, reduce_output=True, quant_method=None)¶
Bases: pymllm.layers.base.MllmBaseLayer
Linear layer with row parallelism (input-dimension sharding).
The weight matrix is split along the input dimension across TP ranks. Each rank holds all out_features rows but only in_features / tp_size columns.
Typically placed after a ColumnParallelLinear whose gather_output=False, so the input is already split.
- Parameters:
in_features (int) – Size of each input sample (before sharding).
out_features (int) – Size of each output sample.
bias (bool) – If True, adds a learnable bias (applied after all-reduce).
reduce_output (bool) – If True, all-reduce the output across TP ranks.
quant_method (Optional[pymllm.layers.quantize_base.LinearMethodBase]) – Quantization method instance. None → UnquantizedLinearMethod.
- tp_rank = 0¶
- tp_size = 1¶
- in_features¶
- out_features¶
- reduce_output = True¶
- in_features_per_partition¶
- input_start_index¶
- input_end_index¶
- quant_method¶
- weight_loader(param, loaded_weight)¶
Load sharded weights into the parameter.
- Parameters:
param (torch.nn.Parameter) – The parameter to load weights into.
loaded_weight (torch.Tensor) – The weight tensor loaded from checkpoint (full size).
- forward(x)¶
- Parameters:
x (torch.Tensor)
- Return type:
torch.Tensor
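Why gather_output=False on the first layer pairs with reduce_output=True on the second: sharding the first GEMM's output dimension and the second GEMM's input dimension along the same axis means each rank's partial products sum to the full result, so one all-reduce replaces an all-gather. A minimal sketch in plain torch with illustrative shapes (the sum below stands in for the all-reduce):

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 8)        # (batch, in_features)
A = torch.randn(6, 8)        # first weight:  y = x @ A.T -> (2, 6)
B = torch.randn(4, 6)        # second weight: z = y @ B.T -> (2, 4)

# Reference: no parallelism.
z_ref = (x @ A.T) @ B.T

# Two-way TP: shard A by output rows, B by input columns.
A0, A1 = A[:3], A[3:]        # ColumnParallelLinear, gather_output=False
B0, B1 = B[:, :3], B[:, 3:]  # RowParallelLinear on the already-split input

# Each "rank" computes a partial result; the all-reduce is just the sum.
z_tp = (x @ A0.T) @ B0.T + (x @ A1.T) @ B1.T

print(torch.allclose(z_ref, z_tp, atol=1e-5))  # True
```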
- class pymllm.layers.linear.Linear(in_features, out_features, bias=True, quant_method=None)¶
Bases: pymllm.layers.base.MllmBaseLayer
Non-parallel linear layer with quantization dispatch.
- Parameters:
in_features (int) – Size of each input sample.
out_features (int) – Size of each output sample.
bias (bool) – If True, adds a learnable bias.
quant_method (Optional[pymllm.layers.quantize_base.LinearMethodBase]) – Quantization method instance. None → UnquantizedLinearMethod.
- in_features¶
- out_features¶
- quant_method¶
- forward(x)¶
- Parameters:
x (torch.Tensor)
- Return type:
torch.Tensor
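The dispatch pattern shared by all three classes can be sketched as below. This is a hypothetical, simplified stand-in for LinearMethodBase / UnquantizedLinearMethod, not the pymllm source; the class and method names inside it are illustrative only:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UnquantizedMethod:
    """Default method: plain FP weight, forward via F.linear."""

    def create_weights(self, layer, in_features, out_features):
        layer.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(layer.weight)

    def apply(self, layer, x, bias=None):
        return F.linear(x, layer.weight, bias)


class TinyLinear(nn.Module):
    """Minimal dispatching linear: all math lives in quant_method."""

    def __init__(self, in_features, out_features, quant_method=None):
        super().__init__()
        # None -> unquantized default, mirroring the docs above.
        self.quant_method = quant_method or UnquantizedMethod()
        self.quant_method.create_weights(self, in_features, out_features)

    def forward(self, x):
        return self.quant_method.apply(self, x)


layer = TinyLinear(8, 4)
out = layer(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 4])
```

A quantized method would instead register packed weights, scales, and zero-points in create_weights and run a fused dequant+matmul in apply; the layer's own forward never changes.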