pymllm.layers.quantize_base¶
Quantization method base classes for pymllm layers.
This module defines the plugin interface that all quantization methods must
implement. The pattern follows sglang / vLLM’s LinearMethodBase design:
- Each quantization algorithm (AWQ, GPTQ, FP8, …) provides a concrete subclass of LinearMethodBase.
- Linear layers hold a quant_method attribute (an instance of LinearMethodBase).
- During __init__, the linear layer calls quant_method.create_weights(layer, ...) to register the appropriate parameters (packed int weights, scales, zero-points, etc.) on itself.
- During forward, the linear layer calls quant_method.apply(layer, x, bias) instead of F.linear.
- After checkpoint loading, ModelRunner iterates all modules and calls quant_method.process_weights_after_loading(layer) for format conversion, repacking (e.g. AWQ → Marlin), or calibration.
Typical lifecycle:
# ---- model construction ----
quant_method = SomeLinearMethod(bits=4, group_size=128)
layer = ColumnParallelLinear(4096, 4096, quant_method=quant_method)
# → calls quant_method.create_weights(layer, ...)
# → layer now has .qweight, .scales, .qzeros, etc.
# ---- weight loading ----
model.load_weights(iter_weights(...))
# → checkpoint tensors are loaded into the parameters created above,
# using each parameter's ``weight_loader`` attribute.
# ---- post-load processing ----
for module in model.modules():
qm = getattr(module, "quant_method", None)
if qm is not None:
qm.process_weights_after_loading(module)
# → AWQ repacks int4 → Marlin layout, GPTQ shuffles by g_idx, etc.
# ---- inference ----
output = layer(x)
# → calls quant_method.apply(layer, x, bias)
# → dequant + matmul (or fused kernel)
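The post-load traversal in the lifecycle above can be run standalone. The sketch below is illustrative, not pymllm code: DummyMethod is a hypothetical stand-in for a real quantization method, and a plain torch.nn.Sequential plays the role of the model.

```python
import torch

# Hedged sketch of the post-load traversal, assuming only torch.
# DummyMethod is a hypothetical stand-in that records which layers
# it was asked to post-process.
class DummyMethod:
    def __init__(self):
        self.processed = []

    def process_weights_after_loading(self, layer):
        self.processed.append(layer)

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
qm = DummyMethod()
model[0].quant_method = qm  # quantized layers carry their quant_method

# ModelRunner-style sweep: visit every module, post-process if quantized
for module in model.modules():
    m = getattr(module, "quant_method", None)
    if m is not None:
        m.process_weights_after_loading(module)
```

Only the linear layer carries a quant_method here, so only it is post-processed; modules without the attribute are skipped by the getattr default.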
Classes¶
- QuantizeMethodBase – Base class for all quantization methods (linear, embedding, MoE, ...).
- LinearMethodBase – Base class for quantization methods applied to linear layers.
- UnquantizedLinearMethod – Default pass-through for non-quantized linear layers.
Module Contents¶
- class pymllm.layers.quantize_base.QuantizeMethodBase¶
Bases: abc.ABC
Base class for all quantization methods (linear, embedding, MoE, …).
Every concrete quantization algorithm must implement at least create_weights() and apply().
How to implement a new quantization method¶
- Subclass LinearMethodBase (for linear layers).
- Override create_weights() to register quantized parameters (qweight, scales, qzeros, etc.) on the layer via layer.register_parameter().
- Override apply() to perform the quantized forward computation.
- Optionally override process_weights_after_loading() if the checkpoint format differs from the runtime format (e.g. repacking, transposing, or calibrating scales).
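The steps above can be sketched as a minimal method. This is a hedged, self-contained illustration, not pymllm code: LinearMethodBase here is a local stand-in for the real base class, and Int8LinearMethod with its W8A16 scheme (int8 weights, full-precision activations) is invented for the example.

```python
import torch
from torch.nn import Parameter


class LinearMethodBase:
    """Local stand-in for pymllm.layers.quantize_base.LinearMethodBase."""

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        raise NotImplementedError

    def apply(self, layer, x, bias=None):
        raise NotImplementedError

    def process_weights_after_loading(self, layer):
        pass  # default no-op


class Int8LinearMethod(LinearMethodBase):
    """Illustrative W8A16 scheme: int8 weights, one fp scale per output row."""

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        out_features = sum(output_partition_sizes)
        qweight = Parameter(
            torch.zeros(out_features, input_size_per_partition,
                        dtype=torch.int8),
            requires_grad=False)
        scales = Parameter(
            torch.ones(out_features, 1, dtype=params_dtype),
            requires_grad=False)
        # Register on the layer so the weight loader can find them by name.
        layer.register_parameter("qweight", qweight)
        layer.register_parameter("scales", scales)

    def apply(self, layer, x, bias=None):
        # Naive dequant + matmul; a real method would invoke a fused kernel.
        weight = layer.qweight.to(x.dtype) * layer.scales
        return torch.nn.functional.linear(x, weight, bias)
```

A layer constructed with this method ends up owning layer.qweight and layer.scales, and its forward pass routes through apply() instead of F.linear.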
- abstractmethod create_weights(layer, *args, **kwargs)¶
Create and register quantized weight parameters on layer.
Called once during layer construction (__init__). Implementations should call layer.register_parameter(name, param) and attach metadata via set_weight_attrs() so that the weight-loading infrastructure knows how to shard and load them.
- Parameters:
layer (torch.nn.Module) – The nn.Module (e.g. ColumnParallelLinear) that will own the parameters.
args (Any)
kwargs (Any)
- Return type:
None
- abstractmethod apply(layer, *args, **kwargs)¶
Execute the quantized forward pass.
Called by layer.forward() every inference step. The method should read the parameters previously created by create_weights() from layer (e.g. layer.qweight, layer.scales), dequantize or invoke a fused kernel, and return the output tensor.
- Parameters:
layer (torch.nn.Module) – The module that owns the quantized parameters.
args (Any)
kwargs (Any)
- Return type:
torch.Tensor
- process_weights_after_loading(layer)¶
Post-process parameters after checkpoint loading.
Called once by ModelRunner after all checkpoint tensors have been loaded into the layer’s parameters. Use this for:
- Repacking: converting checkpoint layout to kernel-native layout (e.g. AutoAWQ int4 → Marlin packed format).
- Transposing: rearranging dimensions for optimised GEMM kernels.
- Calibration: computing per-tensor or per-channel scales from the loaded FP weights (e.g. dynamic FP8 quantisation).
- Cleanup: replacing custom parameter wrappers with plain torch.nn.Parameter to avoid overhead during inference.
The default implementation is a no-op.
- Parameters:
layer (torch.nn.Module)
- Return type:
None
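As one concrete illustration of the calibration and cleanup uses listed above, here is a hedged sketch (not pymllm code) of a post-load hook that derives per-output-channel int8 scales from a loaded full-precision weight and replaces the original parameter with plain Parameters:

```python
import torch
from torch.nn import Parameter


def process_weights_after_loading(layer):
    """Sketch: dynamic per-output-channel int8 calibration after load.

    Assumes ``layer.weight`` holds the loaded full-precision weight; the
    symmetric int8 scheme here is illustrative, not pymllm's actual format.
    """
    w = layer.weight.data
    # Calibration: one symmetric scale per output channel, derived from
    # the loaded weights (clamp guards against all-zero rows).
    scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    qweight = torch.round(w / scales).to(torch.int8)
    # Cleanup: drop the fp parameter and keep plain Parameters only.
    del layer._parameters["weight"]
    layer.register_parameter(
        "qweight", Parameter(qweight, requires_grad=False))
    layer.register_parameter(
        "scales", Parameter(scales, requires_grad=False))
```

After this hook runs, the layer exposes qweight and scales instead of weight, and the quantization error per element is bounded by half a scale step.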
- class pymllm.layers.quantize_base.LinearMethodBase¶
Bases: QuantizeMethodBase
Base class for quantization methods applied to linear layers.
Narrows the QuantizeMethodBase interface with concrete signatures tailored to linear (matmul) operations. Subclasses must implement create_weights() and apply().
- abstractmethod create_weights(layer, input_size_per_partition, output_partition_sizes, input_size, output_size, params_dtype, **extra_weight_attrs)¶
Create quantized weight tensors on layer.
- Parameters:
layer (torch.nn.Module) – The linear module that will own the parameters.
input_size_per_partition (int) – Number of input features on this TP rank.
output_partition_sizes (List[int]) – Output sizes of each logical weight on this TP rank. For a standard linear layer this is [out_features_per_partition]. For a merged QKV layer it might be [q_size, k_size, v_size].
input_size (int) – Full (un-sharded) input dimension.
output_size (int) – Full (un-sharded) output dimension.
params_dtype (torch.dtype) – Data type for full-precision parameters (e.g. torch.float16).
**extra_weight_attrs (Any) – Additional metadata to attach to created parameters (e.g. weight_loader, packed_dim, packed_factor).
Example (AWQ W4A16)::
# Register packed 4-bit weights, scales, and zero-points
qweight = Parameter(torch.empty(…, dtype=torch.int32))
layer.register_parameter("qweight", qweight)
scales = Parameter(torch.empty(…, dtype=params_dtype))
layer.register_parameter("scales", scales)
qzeros = Parameter(torch.empty(…, dtype=torch.int32))
layer.register_parameter("qzeros", qzeros)
- Return type:
None
- abstractmethod apply(layer, x, bias=None)¶
Compute the quantized linear forward.
- Parameters:
layer (torch.nn.Module) – The module that owns quantized parameters (set by create_weights()).
x (torch.Tensor) – Input activation tensor, shape (*, input_size_per_partition).
bias (Optional[torch.Tensor]) – Optional bias vector.
- Returns:
torch.Tensor – Output tensor, shape (*, sum(output_partition_sizes)).
Example (AWQ W4A16)::
qweight = layer.qweight  # packed int32
scales = layer.scales  # fp16 per-group scales
qzeros = layer.qzeros  # packed int32 zero-points
# → invoke dequant + matmul kernel
- Return type:
torch.Tensor
- class pymllm.layers.quantize_base.UnquantizedLinearMethod¶
Bases: LinearMethodBase
Default pass-through for non-quantized linear layers.
Creates a standard FP weight (out_features, in_features) and forwards via F.linear. This is used when no quantization config is specified, so that every linear layer always has a quant_method attribute with a uniform interface.
- create_weights(layer, input_size_per_partition, output_partition_sizes, input_size, output_size, params_dtype, **extra_weight_attrs)¶
Create a standard full-precision weight parameter.
- Parameters:
layer (torch.nn.Module)
input_size_per_partition (int)
output_partition_sizes (List[int])
input_size (int)
output_size (int)
params_dtype (torch.dtype)
extra_weight_attrs (Any)
- Return type:
None
- apply(layer, x, bias=None)¶
Standard F.linear forward.
- Parameters:
layer (torch.nn.Module)
x (torch.Tensor)
bias (Optional[torch.Tensor])
- Return type:
torch.Tensor
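The pass-through behaviour described above can be approximated by the following self-contained sketch. It assumes plain torch and omits the TP-sharding metadata and weight_loader attributes the real class attaches, so it is an illustration of the interface, not pymllm's implementation.

```python
import torch
import torch.nn.functional as F
from torch.nn import Parameter


class UnquantizedLinearMethod:
    """Sketch of the documented pass-through method (metadata omitted)."""

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        # A standard FP weight of shape (out_features, in_features).
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype))
        layer.register_parameter("weight", weight)

    def apply(self, layer, x, bias=None):
        # No dequantization needed: forward straight through F.linear.
        return F.linear(x, layer.weight, bias)

    def process_weights_after_loading(self, layer):
        pass  # nothing to repack for full-precision weights
```

Because this class mirrors the quantized methods' interface, the surrounding linear-layer code never needs to special-case the unquantized path.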