pymllm.quantization.methods.awq_marlin

AWQ quantization with Marlin kernel acceleration.

This module implements the AWQ Marlin quantization plugin for pymllm, providing high-performance W4A16 inference via the Marlin GEMM kernel.

Classes

AWQMarlinConfig

Quantization configuration parsed from quantize_config.json.

AWQMarlinLinearMethod

Linear method that uses AWQ weight format with Marlin kernel dispatch.

Attributes

logger
MARLIN_SUPPORTED_GROUP_SIZES
GPTQ_MARLIN_MIN_THREAD_N
GPTQ_MARLIN_MIN_THREAD_K
GPTQ_MARLIN_TILE
SCALAR_TYPE_UINT4
SCALAR_TYPE_UINT8

Functions

verify_marlin_supported(quant_type, group_size, has_zp)

Verify that the Marlin kernel supports this configuration.

verify_marlin_supports_shape(...)

Verify that tensor dimensions are compatible with Marlin.

marlin_make_workspace(device)

Create Marlin workspace buffer for threadblock synchronization.

marlin_make_empty_g_idx(device)

Create empty g_idx tensor (AWQ doesn't use activation reordering).

get_scale_perms()

Get the scale permutation indices for Marlin format.

marlin_permute_scales(s, size_k, size_n, group_size)

Permute quantization scales from standard to Marlin layout.

pack_cols(q_w, num_bits, size_k, size_n)

Pack quantized columns into int32 values.

unpack_cols(packed, num_bits, size_k, size_n)

Unpack int32 packed columns into individual quantized values.

marlin_zero_points(zp, size_k, size_n, num_bits)

Permute and pack zero points into Marlin format.

awq_to_marlin_zero_points(q_zp_packed, size_k, size_n, ...)

Convert AWQ-format zero points to Marlin format.

replace_parameter(layer, name, new_data)

Replace a parameter on a layer with new data.

Module Contents

pymllm.quantization.methods.awq_marlin.logger
pymllm.quantization.methods.awq_marlin.MARLIN_SUPPORTED_GROUP_SIZES
pymllm.quantization.methods.awq_marlin.GPTQ_MARLIN_MIN_THREAD_N = 64
pymllm.quantization.methods.awq_marlin.GPTQ_MARLIN_MIN_THREAD_K = 128
pymllm.quantization.methods.awq_marlin.GPTQ_MARLIN_TILE = 16
pymllm.quantization.methods.awq_marlin.SCALAR_TYPE_UINT4
pymllm.quantization.methods.awq_marlin.SCALAR_TYPE_UINT8
pymllm.quantization.methods.awq_marlin.verify_marlin_supported(quant_type, group_size, has_zp)

Verify that the Marlin kernel supports this configuration.

Parameters:
  • quant_type (_ScalarTypeInfo)

  • group_size (int)

  • has_zp (bool)

Return type:

None
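The check is presumably a set of straightforward guards. The sketch below illustrates the shape of such a validator; the supported group sizes and zero-point-capable types shown here are assumptions (the real values come from the module constants such as MARLIN_SUPPORTED_GROUP_SIZES), not the actual implementation:

```python
# Hedged sketch of the kind of guards verify_marlin_supported performs.
# The supported values below are assumptions for illustration only.
ASSUMED_GROUP_SIZES = [-1, 32, 64, 128]
ASSUMED_ZP_TYPES = {"uint4", "uint8"}

def verify_marlin_supported(quant_type, group_size, has_zp):
    if group_size not in ASSUMED_GROUP_SIZES:
        raise ValueError(f"Marlin does not support group_size={group_size}")
    if has_zp and quant_type not in ASSUMED_ZP_TYPES:
        raise ValueError(f"Marlin does not support zero points for {quant_type}")

verify_marlin_supported("uint4", 128, has_zp=True)  # passes silently
```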

pymllm.quantization.methods.awq_marlin.verify_marlin_supports_shape(output_size_per_partition, input_size_per_partition, input_size, group_size)

Verify that tensor dimensions are compatible with Marlin.

Parameters:
  • output_size_per_partition (int)

  • input_size_per_partition (int)

  • input_size (int)

  • group_size (int)

Return type:

None
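Using the module constants documented below (GPTQ_MARLIN_MIN_THREAD_N = 64, GPTQ_MARLIN_MIN_THREAD_K = 128), the divisibility checks can be sketched as follows; this is an illustration of the constraints, not the exact implementation:

```python
GPTQ_MARLIN_MIN_THREAD_N = 64   # minimum N tile, from the module constants
GPTQ_MARLIN_MIN_THREAD_K = 128  # minimum K tile, from the module constants

def verify_marlin_supports_shape(output_size_per_partition,
                                 input_size_per_partition,
                                 input_size, group_size):
    # The kernel tiles the N and K dimensions, so each TP partition must
    # be a multiple of the minimum threadblock dimensions.
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(
            f"output_size_per_partition={output_size_per_partition} must be "
            f"divisible by {GPTQ_MARLIN_MIN_THREAD_N}")
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(
            f"input_size_per_partition={input_size_per_partition} must be "
            f"divisible by {GPTQ_MARLIN_MIN_THREAD_K}")
    # group_size == -1 means channelwise scales (no further constraint);
    # otherwise K on this rank must split evenly into quantization groups.
    if group_size != -1 and input_size_per_partition % group_size != 0:
        raise ValueError(
            f"input_size_per_partition must be divisible by "
            f"group_size={group_size}")
```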

pymllm.quantization.methods.awq_marlin.marlin_make_workspace(device)

Create Marlin workspace buffer for threadblock synchronization.

Parameters:

device (torch.device)

Return type:

torch.Tensor

pymllm.quantization.methods.awq_marlin.marlin_make_empty_g_idx(device)

Create empty g_idx tensor (AWQ doesn’t use activation reordering).

Parameters:

device (torch.device)

Return type:

torch.Tensor

pymllm.quantization.methods.awq_marlin.get_scale_perms()

Get the scale permutation indices for Marlin format.

pymllm.quantization.methods.awq_marlin.marlin_permute_scales(s, size_k, size_n, group_size)

Permute quantization scales from standard to Marlin layout.

Parameters:
  • s (torch.Tensor)

  • size_k (int)

  • size_n (int)

  • group_size (int)

Return type:

torch.Tensor

pymllm.quantization.methods.awq_marlin.pack_cols(q_w, num_bits, size_k, size_n)

Pack quantized columns into int32 values.

Parameters:
  • q_w (torch.Tensor)

  • num_bits (int)

  • size_k (int)

  • size_n (int)

Return type:

torch.Tensor

pymllm.quantization.methods.awq_marlin.unpack_cols(packed, num_bits, size_k, size_n)

Unpack int32 packed columns into individual quantized values.

Parameters:
  • packed (torch.Tensor)

  • num_bits (int)

  • size_k (int)

  • size_n (int)

Return type:

torch.Tensor

pymllm.quantization.methods.awq_marlin.marlin_zero_points(zp, size_k, size_n, num_bits)

Permute and pack zero points into Marlin format.

Parameters:
  • zp (torch.Tensor)

  • size_k (int)

  • size_n (int)

  • num_bits (int)

Return type:

torch.Tensor

pymllm.quantization.methods.awq_marlin.awq_to_marlin_zero_points(q_zp_packed, size_k, size_n, num_bits)

Convert AWQ-format zero points to Marlin format.

AWQ zero-points are quantized and packed on the column dim with a specific interleaving. This function undoes the AWQ interleaving, then applies Marlin permutation and repacks.

Parameters:
  • q_zp_packed (torch.Tensor)

  • size_k (int)

  • size_n (int)

  • num_bits (int)

Return type:

torch.Tensor
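The interleaving step can be modeled as a per-word column permutation plus its inverse. The sketch below uses the commonly cited AWQ 4-bit order [0, 2, 4, 6, 1, 3, 5, 7] as an assumption; the point is only that undoing the interleave is applying the inverse permutation before the Marlin permutation and repack:

```python
AWQ_4BIT_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]  # assumed interleave pattern

def interleave(row, order=AWQ_4BIT_ORDER):
    # Reorder each group of len(order) columns into stored (interleaved) order.
    g = len(order)
    return [row[base + j] for base in range(0, len(row), g) for j in order]

def undo_interleave(row, order=AWQ_4BIT_ORDER):
    # Inverse permutation: stored value i goes back to logical slot order[i].
    g = len(order)
    out = [0] * len(row)
    for base in range(0, len(row), g):
        for i, j in enumerate(order):
            out[base + j] = row[base + i]
    return out

vals = list(range(8))
assert undo_interleave(interleave(vals)) == vals
```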

pymllm.quantization.methods.awq_marlin.replace_parameter(layer, name, new_data)

Replace a parameter on a layer with new data.

Parameters:
  • layer (torch.nn.Module)

  • name (str)

  • new_data (torch.Tensor)

Return type:

None
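This is needed because repacking changes a parameter's shape and dtype, which an in-place copy cannot do. A minimal sketch, assuming the standard torch.nn parameter registry (not the actual implementation):

```python
import torch

def replace_parameter(layer, name, new_data):
    # Deregister the old parameter, then register the repacked tensor as a
    # non-trainable parameter under the same name so state_dict stays valid.
    delattr(layer, name)
    layer.register_parameter(
        name, torch.nn.Parameter(new_data, requires_grad=False))

layer = torch.nn.Linear(8, 8)
replace_parameter(layer, "weight", torch.zeros(8, 8))
```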

class pymllm.quantization.methods.awq_marlin.AWQMarlinLinearMethod(quant_config)

Bases: pymllm.layers.quantize_base.LinearMethodBase

Linear method for AWQ with Marlin kernel acceleration.

Uses the Marlin W4A16 GEMM kernel for high-performance inference. Weights are repacked from AWQ format to Marlin format after loading.

Parameters:

quant_config (AWQMarlinConfig)

quant_config
create_weights(layer, input_size_per_partition, output_partition_sizes, input_size, output_size, params_dtype, **extra_weight_attrs)

Create quantized weight tensors on layer.

Parameters:
  • layer (torch.nn.Module) – The linear module that will own the parameters.

  • input_size_per_partition (int) – Number of input features on this TP rank.

  • output_partition_sizes (List[int]) – Output sizes of each logical weight on this TP rank. For a standard linear layer this is [out_features_per_partition]. For a merged QKV layer it might be [q_size, k_size, v_size].

  • input_size (int) – Full (un-sharded) input dimension.

  • output_size (int) – Full (un-sharded) output dimension.

  • params_dtype (torch.dtype) – Data type for full-precision parameters (e.g. torch.float16).

  • **extra_weight_attrs (Any) – Additional metadata to attach to created parameters (e.g. weight_loader, packed_dim, packed_factor).

Example (AWQ W4A16):

    # Register packed 4-bit weights, scales, and zero-points
    qweight = Parameter(torch.empty(..., dtype=torch.int32))
    layer.register_parameter("qweight", qweight)

    scales = Parameter(torch.empty(..., dtype=params_dtype))
    layer.register_parameter("scales", scales)

    qzeros = Parameter(torch.empty(..., dtype=torch.int32))
    layer.register_parameter("qzeros", qzeros)

Return type:

None

process_weights_after_loading(layer)

Repack AWQ weights to Marlin format after checkpoint loading.

Parameters:

layer (torch.nn.Module)

Return type:

None

apply(layer, x, bias=None)

Perform quantized matmul using the Marlin GEMM kernel.

Parameters:
  • layer (torch.nn.Module)

  • x (torch.Tensor)

  • bias (Optional[torch.Tensor])

Return type:

torch.Tensor

class pymllm.quantization.methods.awq_marlin.AWQMarlinConfig(weight_bits, group_size, zero_point)

Bases: pymllm.quantization.quant_config.QuantizationConfig

Configuration for AWQ quantization with Marlin kernel acceleration.

This config is used when loading models quantized with AutoAWQ and running inference with the high-performance Marlin W4A16 GEMM kernel.

Registered as "awq_marlin" in the quantization registry.

Parameters:
  • weight_bits (int)

  • group_size (int)

  • zero_point (bool)

weight_bits
group_size
zero_point
pack_factor
quant_type
__repr__()

Return type:

str

get_name()

Return the canonical name of this quantization method.

Examples: "awq", "gptq", "fp8", "w8a8".

Return type:

str

get_supported_act_dtypes()

Activation dtypes supported by this method.

Override to restrict (e.g. FP8 only supports float16). Default: no restriction.

Return type:

List[torch.dtype]

classmethod get_min_capability()

Minimum CUDA compute capability (e.g. 75 for Turing).

Default: 0 (no restriction).

Return type:

int

static get_config_filenames()

File names to look for in the checkpoint directory.

Default: ["quantize_config.json"].

Return type:

List[str]

classmethod from_config(config)

Create an instance from a checkpoint’s quantization config dict.

Parameters:
  • config (Dict[str, Any]) – Parsed JSON from the checkpoint’s quantize_config.json or the quantization_config section of config.json.

Example config dict (AWQ):

    {
      "quant_method": "awq",
      "bits": 4,
      "group_size": 128,
      "zero_point": true
    }

Return type:

AWQMarlinConfig
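The mapping from the checkpoint dict onto constructor arguments is presumably direct. A hedged sketch using the field names from the example dict above (parse_awq_config is a hypothetical helper, not part of the module):

```python
# Hypothetical helper mirroring what from_config likely does: pull the AWQ
# fields out of the checkpoint dict. Key names follow the example above.
def parse_awq_config(config):
    return {
        "weight_bits": config["bits"],
        "group_size": config["group_size"],
        "zero_point": config["zero_point"],
    }

kwargs = parse_awq_config(
    {"quant_method": "awq", "bits": 4, "group_size": 128, "zero_point": True})
# AWQMarlinConfig(**kwargs) would then build the config object.
```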

get_quant_method(layer, prefix='')

Return the quantization method for layer, or None to skip.

Parameters:
  • layer (torch.nn.Module) – The nn.Module being constructed (e.g. ColumnParallelLinear).

  • prefix (str) – The layer’s full dotted name in the model (e.g. "model.layers.0.self_attn.q_proj"). Can be used to selectively skip quantization for certain layers.

Returns:

The method instance. None means this layer should fall back to the default UnquantizedLinearMethod.

Return type:

QuantizeMethodBase or None
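The prefix argument is what enables selective quantization. A minimal sketch of such a policy, with a hypothetical skip list (the actual decision logic lives inside get_quant_method):

```python
def should_quantize(prefix, skipped_prefixes=("lm_head",)):
    # Hypothetical policy: returning False corresponds to get_quant_method
    # returning None, i.e. the layer falls back to UnquantizedLinearMethod.
    return not any(prefix.startswith(p) for p in skipped_prefixes)

should_quantize("model.layers.0.self_attn.q_proj")  # True
should_quantize("lm_head")                          # False
```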