pymllm.quantization.methods.awq_marlin¶
AWQ quantization with Marlin kernel acceleration.
This module implements the AWQ Marlin quantization plugin for pymllm, providing high-performance W4A16 inference via the Marlin GEMM kernel.
Classes¶
- AWQMarlinConfig — Quantization configuration parsed from quantize_config.json.
- AWQMarlinLinearMethod — Linear method that uses AWQ weight format with Marlin kernel dispatch.
Attributes¶
- logger
- MARLIN_SUPPORTED_GROUP_SIZES
- GPTQ_MARLIN_MIN_THREAD_N
- GPTQ_MARLIN_MIN_THREAD_K
- GPTQ_MARLIN_TILE
- SCALAR_TYPE_UINT4
- SCALAR_TYPE_UINT8

Classes¶

| Class | Summary |
|---|---|
| AWQMarlinLinearMethod | Linear method for AWQ with Marlin kernel acceleration. |
| AWQMarlinConfig | Configuration for AWQ quantization with Marlin kernel acceleration. |

Functions¶

| Function | Summary |
|---|---|
| verify_marlin_supported | Verify that the Marlin kernel supports this configuration. |
| verify_marlin_supports_shape | Verify that tensor dimensions are compatible with Marlin. |
| marlin_make_workspace | Create Marlin workspace buffer for threadblock synchronization. |
| marlin_make_empty_g_idx | Create empty g_idx tensor (AWQ doesn't use activation reordering). |
| get_scale_perms | Get the scale permutation indices for Marlin format. |
| marlin_permute_scales | Permute quantization scales from standard to Marlin layout. |
| pack_cols | Pack quantized columns into int32 values. |
| unpack_cols | Unpack int32 packed columns into individual quantized values. |
| marlin_zero_points | Permute and pack zero points into Marlin format. |
| awq_to_marlin_zero_points | Convert AWQ-format zero points to Marlin format. |
| replace_parameter | Replace a parameter on a layer with new data. |
Module Contents¶
- pymllm.quantization.methods.awq_marlin.logger¶
- pymllm.quantization.methods.awq_marlin.MARLIN_SUPPORTED_GROUP_SIZES¶
- pymllm.quantization.methods.awq_marlin.GPTQ_MARLIN_MIN_THREAD_N = 64¶
- pymllm.quantization.methods.awq_marlin.GPTQ_MARLIN_MIN_THREAD_K = 128¶
- pymllm.quantization.methods.awq_marlin.GPTQ_MARLIN_TILE = 16¶
- pymllm.quantization.methods.awq_marlin.SCALAR_TYPE_UINT4¶
- pymllm.quantization.methods.awq_marlin.SCALAR_TYPE_UINT8¶
- pymllm.quantization.methods.awq_marlin.verify_marlin_supported(quant_type, group_size, has_zp)¶
Verify that the Marlin kernel supports this configuration.
- Parameters:
quant_type (_ScalarTypeInfo)
group_size (int)
has_zp (bool)
- Return type:
None
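As an illustration, a configuration check of this kind might look like the sketch below. This is not the pymllm source: the supported group sizes and bit widths shown are assumptions for demonstration (the real values come from MARLIN_SUPPORTED_GROUP_SIZES and the kernel's scalar-type table), and the sketch takes a plain bit count instead of a `_ScalarTypeInfo` object.

```python
# Illustrative sketch of a Marlin support check (not the pymllm source).
# SUPPORTED_GROUP_SIZES and SUPPORTED_NUM_BITS are assumed values here;
# in pymllm they come from MARLIN_SUPPORTED_GROUP_SIZES and the kernel.
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]   # -1 means per-channel scales
SUPPORTED_NUM_BITS = {4, 8}                 # uint4 / uint8 weights

def verify_marlin_supported(num_bits: int, group_size: int, has_zp: bool) -> None:
    if num_bits not in SUPPORTED_NUM_BITS:
        raise ValueError(f"Marlin does not support {num_bits}-bit weights")
    if group_size not in SUPPORTED_GROUP_SIZES:
        raise ValueError(f"Marlin does not support group_size={group_size}")
    if not has_zp:
        raise ValueError("AWQ-Marlin expects asymmetric quantization (zero points)")

verify_marlin_supported(4, 128, True)  # a typical AWQ config passes silently
```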
- pymllm.quantization.methods.awq_marlin.verify_marlin_supports_shape(output_size_per_partition, input_size_per_partition, input_size, group_size)¶
Verify that tensor dimensions are compatible with Marlin.
- Parameters:
output_size_per_partition (int)
input_size_per_partition (int)
input_size (int)
group_size (int)
- Return type:
None
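The checks plausibly boil down to divisibility against the module constants documented below (GPTQ_MARLIN_MIN_THREAD_N = 64, GPTQ_MARLIN_MIN_THREAD_K = 128). The exact rule set in this sketch is an assumption based on those constants:

```python
# Sketch of the Marlin shape checks, using the documented module constants.
# The precise set of rules is an assumption for illustration.
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128

def verify_marlin_supports_shape(output_size_per_partition: int,
                                 input_size_per_partition: int,
                                 input_size: int,
                                 group_size: int) -> None:
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(
            f"output size {output_size_per_partition} is not divisible by "
            f"min thread N ({GPTQ_MARLIN_MIN_THREAD_N})")
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(
            f"input size {input_size_per_partition} is not divisible by "
            f"min thread K ({GPTQ_MARLIN_MIN_THREAD_K})")
    if group_size < input_size and input_size_per_partition % group_size != 0:
        raise ValueError(
            f"input size {input_size_per_partition} is not divisible by "
            f"group_size {group_size}")

verify_marlin_supports_shape(4096, 4096, 4096, 128)  # OK for a typical layer
```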
- pymllm.quantization.methods.awq_marlin.marlin_make_workspace(device)¶
Create Marlin workspace buffer for threadblock synchronization.
- Parameters:
device (torch.device)
- Return type:
torch.Tensor
- pymllm.quantization.methods.awq_marlin.marlin_make_empty_g_idx(device)¶
Create empty g_idx tensor (AWQ doesn’t use activation reordering).
- Parameters:
device (torch.device)
- Return type:
torch.Tensor
- pymllm.quantization.methods.awq_marlin.get_scale_perms()¶
Get the scale permutation indices for Marlin format.
- pymllm.quantization.methods.awq_marlin.marlin_permute_scales(s, size_k, size_n, group_size)¶
Permute quantization scales from standard to Marlin layout.
- Parameters:
s (torch.Tensor)
size_k (int)
size_n (int)
group_size (int)
- Return type:
torch.Tensor
- pymllm.quantization.methods.awq_marlin.pack_cols(q_w, num_bits, size_k, size_n)¶
Pack quantized columns into int32 values.
- Parameters:
q_w (torch.Tensor)
num_bits (int)
size_k (int)
size_n (int)
- Return type:
torch.Tensor
- pymllm.quantization.methods.awq_marlin.unpack_cols(packed, num_bits, size_k, size_n)¶
Unpack int32 packed columns into individual quantized values.
- Parameters:
packed (torch.Tensor)
num_bits (int)
size_k (int)
size_n (int)
- Return type:
torch.Tensor
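The pack/unpack pair stores `32 // num_bits` quantized values per int32 word along the column (N) dimension. A pure-NumPy round-trip sketch of the idea (the real functions operate on torch tensors, and the exact bit layout is an assumption):

```python
import numpy as np

# Pack 32 // num_bits quantized columns into each int32 word (sketch).
def pack_cols(q_w: np.ndarray, num_bits: int) -> np.ndarray:
    pack_factor = 32 // num_bits
    size_k, size_n = q_w.shape
    assert size_n % pack_factor == 0
    packed = np.zeros((size_k, size_n // pack_factor), dtype=np.uint32)
    for i in range(pack_factor):
        # Interleave: word j holds original columns j, j + n//pf, ...
        packed |= q_w[:, i::pack_factor].astype(np.uint32) << (num_bits * i)
    return packed.astype(np.int32)

# Invert pack_cols, recovering the individual quantized values.
def unpack_cols(packed: np.ndarray, num_bits: int) -> np.ndarray:
    pack_factor = 32 // num_bits
    mask = (1 << num_bits) - 1
    size_k, packed_n = packed.shape
    out = np.zeros((size_k, packed_n * pack_factor), dtype=np.uint32)
    u = packed.astype(np.uint32)
    for i in range(pack_factor):
        out[:, i::pack_factor] = (u >> (num_bits * i)) & mask
    return out

w = np.random.default_rng(0).integers(0, 16, size=(4, 16), dtype=np.uint32)
assert np.array_equal(unpack_cols(pack_cols(w, 4), 4), w)  # lossless round trip
```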
- pymllm.quantization.methods.awq_marlin.marlin_zero_points(zp, size_k, size_n, num_bits)¶
Permute and pack zero points into Marlin format.
- Parameters:
zp (torch.Tensor)
size_k (int)
size_n (int)
num_bits (int)
- Return type:
torch.Tensor
- pymllm.quantization.methods.awq_marlin.awq_to_marlin_zero_points(q_zp_packed, size_k, size_n, num_bits)¶
Convert AWQ-format zero points to Marlin format.
AWQ zero-points are quantized and packed on the column dim with a specific interleaving. This function undoes the AWQ interleaving, then applies Marlin permutation and repacks.
- Parameters:
q_zp_packed (torch.Tensor)
size_k (int)
size_n (int)
num_bits (int)
- Return type:
torch.Tensor
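The de-interleave step can be sketched as a column permutation within each packed group. The interleave order used here ([0, 2, 4, 6, 1, 3, 5, 7] for 4-bit) is an assumption for illustration; the round trip holds for any order because the inverse is computed with argsort:

```python
import numpy as np

# Assumed 4-bit AWQ per-word column interleave order (illustrative only).
AWQ_INTERLEAVE_4BIT = [0, 2, 4, 6, 1, 3, 5, 7]

def awq_interleave(zp: np.ndarray) -> np.ndarray:
    """Apply the AWQ column interleave within each group of 8 columns."""
    return zp.reshape(-1, 8)[:, AWQ_INTERLEAVE_4BIT].reshape(zp.shape)

def undo_awq_interleave(zp: np.ndarray) -> np.ndarray:
    """Invert the interleave, restoring the original column order."""
    inv = np.argsort(AWQ_INTERLEAVE_4BIT)  # inverse permutation
    return zp.reshape(-1, 8)[:, inv].reshape(zp.shape)

zp = np.arange(16, dtype=np.uint32).reshape(2, 8)
assert np.array_equal(undo_awq_interleave(awq_interleave(zp)), zp)
```

After this step the zero points are in natural column order and can go through the same Marlin permutation and repacking as freshly quantized values.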
- pymllm.quantization.methods.awq_marlin.replace_parameter(layer, name, new_data)¶
Replace a parameter on a layer with new data.
- Parameters:
layer (torch.nn.Module)
name (str)
new_data (torch.Tensor)
- Return type:
None
- class pymllm.quantization.methods.awq_marlin.AWQMarlinLinearMethod(quant_config)¶
Bases: pymllm.layers.quantize_base.LinearMethodBase
Linear method for AWQ with Marlin kernel acceleration.
Uses the Marlin W4A16 GEMM kernel for high-performance inference. Weights are repacked from AWQ format to Marlin format after loading.
- Parameters:
quant_config (AWQMarlinConfig)
- quant_config¶
- create_weights(layer, input_size_per_partition, output_partition_sizes, input_size, output_size, params_dtype, **extra_weight_attrs)¶
Create quantized weight tensors on layer.
- Parameters:
layer (torch.nn.Module) – The linear module that will own the parameters.
input_size_per_partition (int) – Number of input features on this TP rank.
output_partition_sizes (List[int]) – Output sizes of each logical weight on this TP rank. For a standard linear layer this is [out_features_per_partition]; for a merged QKV layer it might be [q_size, k_size, v_size].
input_size (int) – Full (un-sharded) input dimension.
output_size (int) – Full (un-sharded) output dimension.
params_dtype (torch.dtype) – Data type for full-precision parameters (e.g. torch.float16).
**extra_weight_attrs (Any) – Additional metadata to attach to created parameters (e.g. weight_loader, packed_dim, packed_factor).

Example (AWQ W4A16):

```python
# Register packed 4-bit weights, scales, and zero-points
qweight = Parameter(torch.empty(…, dtype=torch.int32))
layer.register_parameter("qweight", qweight)

scales = Parameter(torch.empty(…, dtype=params_dtype))
layer.register_parameter("scales", scales)

qzeros = Parameter(torch.empty(…, dtype=torch.int32))
layer.register_parameter("qzeros", qzeros)
```
- Return type:
None
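The shape arithmetic behind those parameters can be sketched as below. The layout shown (qweight packed along the output dimension, per-group scales and packed zero points) follows the standard AutoAWQ checkpoint format and is an assumption about pymllm's exact layout:

```python
# Shape arithmetic for an AWQ parameter layout (assumed AutoAWQ convention).
def awq_param_shapes(input_size_per_partition: int,
                     output_size_per_partition: int,
                     group_size: int,
                     weight_bits: int = 4) -> dict:
    pack_factor = 32 // weight_bits  # int4 -> 8 values per int32 word
    return {
        "qweight": (input_size_per_partition,
                    output_size_per_partition // pack_factor),
        "scales": (input_size_per_partition // group_size,
                   output_size_per_partition),
        "qzeros": (input_size_per_partition // group_size,
                   output_size_per_partition // pack_factor),
    }

shapes = awq_param_shapes(4096, 4096, 128)
assert shapes["qweight"] == (4096, 512)   # 4096 // 8 packed columns
assert shapes["scales"] == (32, 4096)     # one scale row per 128-row group
```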
- process_weights_after_loading(layer)¶
Repack AWQ weights to Marlin format after checkpoint loading.
- Parameters:
layer (torch.nn.Module)
- Return type:
None
- apply(layer, x, bias=None)¶
Perform quantized matmul using the Marlin GEMM kernel.
- Parameters:
layer (torch.nn.Module)
x (torch.Tensor)
bias (Optional[torch.Tensor])
- Return type:
torch.Tensor
- class pymllm.quantization.methods.awq_marlin.AWQMarlinConfig(weight_bits, group_size, zero_point)¶
Bases:
pymllm.quantization.quant_config.QuantizationConfigConfiguration for AWQ quantization with Marlin kernel acceleration.
This config is used when loading models quantized with AutoAWQ and running inference with the high-performance Marlin W4A16 GEMM kernel.
Registered as "awq_marlin" in the quantization registry.
- Parameters:
weight_bits (int)
group_size (int)
zero_point (bool)
- weight_bits¶
- group_size¶
- zero_point¶
- pack_factor¶
- quant_type¶
- __repr__()¶
- Return type:
str
- get_name()¶
Return the canonical name of this quantization method.
Examples: "awq", "gptq", "fp8", "w8a8".
- Return type:
str
- get_supported_act_dtypes()¶
Activation dtypes supported by this method.
Override to restrict (e.g. FP8 only supports float16). Default: no restriction.
- Return type:
List[torch.dtype]
- classmethod get_min_capability()¶
Minimum CUDA compute capability (e.g. 75 for Turing).
Default: 0 (no restriction).
- Return type:
int
- static get_config_filenames()¶
File names to look for in the checkpoint directory.
Default: ["quantize_config.json"].
- Return type:
List[str]
- classmethod from_config(config)¶
Create an instance from a checkpoint’s quantization config dict.
- Parameters:
config (Dict[str, Any]) – Parsed JSON from the checkpoint’s quantize_config.json or the quantization_config section of config.json.

Example config dict (AWQ):

```json
{
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": true
}
```

- Return type:
AWQMarlinConfig
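A hypothetical sketch of the key mapping this method performs: the input key names follow the AutoAWQ quantize_config.json shown above, while the helper name and the absence of fallback handling are illustrative assumptions.

```python
# Hypothetical mapping from an AutoAWQ config dict to constructor kwargs.
def parse_awq_config(config: dict) -> dict:
    return {
        "weight_bits": config["bits"],
        "group_size": config["group_size"],
        "zero_point": config["zero_point"],
    }

kwargs = parse_awq_config(
    {"quant_method": "awq", "bits": 4, "group_size": 128, "zero_point": True})
assert kwargs == {"weight_bits": 4, "group_size": 128, "zero_point": True}
```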
- get_quant_method(layer, prefix='')¶
Return the quantization method for layer, or None to skip.
- Parameters:
layer (torch.nn.Module) – The nn.Module being constructed (e.g. ColumnParallelLinear).
prefix (str) – The layer’s full dotted name in the model (e.g. "model.layers.0.self_attn.q_proj"). Can be used to selectively skip quantization for certain layers.
- Returns:
The method instance.
None means this layer should fall back to the default UnquantizedLinearMethod.
- Return type:
QuantizeMethodBase or None
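Prefix-based skipping, as described above, can be sketched like this; the skip list and prefix-matching rule are illustrative assumptions, not pymllm's actual policy:

```python
# Illustrative prefix filter for selectively skipping quantization.
def should_quantize(prefix: str,
                    skipped_prefixes: tuple = ("lm_head",)) -> bool:
    """Return False for layers whose dotted name starts with a skipped prefix."""
    return not any(prefix.startswith(p) for p in skipped_prefixes)

assert should_quantize("model.layers.0.self_attn.q_proj")
assert not should_quantize("lm_head")
```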