pymllm.quantization.methods.awq_marlin
======================================

.. py:module:: pymllm.quantization.methods.awq_marlin

.. autoapi-nested-parse::

   AWQ quantization with Marlin kernel acceleration.

   This module implements the AWQ Marlin quantization plugin for pymllm,
   providing high-performance W4A16 inference via the Marlin GEMM kernel.

   Classes
   -------
   AWQMarlinConfig
       Quantization configuration parsed from ``quantize_config.json``.
   AWQMarlinLinearMethod
       Linear method that uses AWQ weight format with Marlin kernel dispatch.

Attributes
----------

.. autoapisummary::

   pymllm.quantization.methods.awq_marlin.logger
   pymllm.quantization.methods.awq_marlin.MARLIN_SUPPORTED_GROUP_SIZES
   pymllm.quantization.methods.awq_marlin.GPTQ_MARLIN_MIN_THREAD_N
   pymllm.quantization.methods.awq_marlin.GPTQ_MARLIN_MIN_THREAD_K
   pymllm.quantization.methods.awq_marlin.GPTQ_MARLIN_TILE
   pymllm.quantization.methods.awq_marlin.SCALAR_TYPE_UINT4
   pymllm.quantization.methods.awq_marlin.SCALAR_TYPE_UINT8

Classes
-------

.. autoapisummary::

   pymllm.quantization.methods.awq_marlin.AWQMarlinLinearMethod
   pymllm.quantization.methods.awq_marlin.AWQMarlinConfig

Functions
---------

.. autoapisummary::

   pymllm.quantization.methods.awq_marlin.verify_marlin_supported
   pymllm.quantization.methods.awq_marlin.verify_marlin_supports_shape
   pymllm.quantization.methods.awq_marlin.marlin_make_workspace
   pymllm.quantization.methods.awq_marlin.marlin_make_empty_g_idx
   pymllm.quantization.methods.awq_marlin.get_scale_perms
   pymllm.quantization.methods.awq_marlin.marlin_permute_scales
   pymllm.quantization.methods.awq_marlin.pack_cols
   pymllm.quantization.methods.awq_marlin.unpack_cols
   pymllm.quantization.methods.awq_marlin.marlin_zero_points
   pymllm.quantization.methods.awq_marlin.awq_to_marlin_zero_points
   pymllm.quantization.methods.awq_marlin.replace_parameter

Module Contents
---------------

.. py:data:: logger

.. py:data:: MARLIN_SUPPORTED_GROUP_SIZES

.. py:data:: GPTQ_MARLIN_MIN_THREAD_N
   :value: 64
.. py:data:: GPTQ_MARLIN_MIN_THREAD_K
   :value: 128

.. py:data:: GPTQ_MARLIN_TILE
   :value: 16

.. py:data:: SCALAR_TYPE_UINT4

.. py:data:: SCALAR_TYPE_UINT8

.. py:function:: verify_marlin_supported(quant_type, group_size, has_zp)

   Verify that the Marlin kernel supports this configuration.

.. py:function:: verify_marlin_supports_shape(output_size_per_partition, input_size_per_partition, input_size, group_size)

   Verify that tensor dimensions are compatible with Marlin.

.. py:function:: marlin_make_workspace(device)

   Create the Marlin workspace buffer for threadblock synchronization.

.. py:function:: marlin_make_empty_g_idx(device)

   Create an empty ``g_idx`` tensor (AWQ doesn't use activation reordering).

.. py:function:: get_scale_perms()

   Get the scale permutation indices for the Marlin format.

.. py:function:: marlin_permute_scales(s, size_k, size_n, group_size)

   Permute quantization scales from the standard layout to the Marlin layout.

.. py:function:: pack_cols(q_w, num_bits, size_k, size_n)

   Pack quantized columns into int32 values.

.. py:function:: unpack_cols(packed, num_bits, size_k, size_n)

   Unpack int32-packed columns into individual quantized values.

.. py:function:: marlin_zero_points(zp, size_k, size_n, num_bits)

   Permute and pack zero points into the Marlin format.

.. py:function:: awq_to_marlin_zero_points(q_zp_packed, size_k, size_n, num_bits)

   Convert AWQ-format zero points to the Marlin format.

   AWQ zero points are quantized and packed on the column dimension with a
   specific interleaving. This function undoes the AWQ interleaving, then
   applies the Marlin permutation and repacks the result.

.. py:function:: replace_parameter(layer, name, new_data)

   Replace a parameter on a layer with new data.

.. py:class:: AWQMarlinLinearMethod(quant_config)

   Bases: :py:obj:`pymllm.layers.quantize_base.LinearMethodBase`

   Linear method for AWQ with Marlin kernel acceleration.

   Uses the Marlin W4A16 GEMM kernel for high-performance inference.
   Weights are repacked from the AWQ format to the Marlin format after
   loading.

   .. py:attribute:: quant_config
   .. py:method:: create_weights(layer, input_size_per_partition, output_partition_sizes, input_size, output_size, params_dtype, **extra_weight_attrs)

      Create quantized weight tensors on *layer*.

      :param layer: The linear module that will own the parameters.
      :param input_size_per_partition: Number of input features on this TP rank.
      :param output_partition_sizes: Output sizes of each logical weight on
          this TP rank. For a standard linear layer this is
          ``[out_features_per_partition]``; for a merged QKV layer it might be
          ``[q_size, k_size, v_size]``.
      :param input_size: Full (un-sharded) input dimension.
      :param output_size: Full (un-sharded) output dimension.
      :param params_dtype: Data type for full-precision parameters
          (e.g. ``torch.float16``).
      :param \*\*extra_weight_attrs: Additional metadata to attach to created
          parameters (e.g. ``weight_loader``, ``packed_dim``,
          ``packed_factor``).

      Example (AWQ W4A16)::

          # Register packed 4-bit weights, scales, and zero-points
          qweight = Parameter(torch.empty(..., dtype=torch.int32))
          layer.register_parameter("qweight", qweight)

          scales = Parameter(torch.empty(..., dtype=params_dtype))
          layer.register_parameter("scales", scales)

          qzeros = Parameter(torch.empty(..., dtype=torch.int32))
          layer.register_parameter("qzeros", qzeros)

   .. py:method:: process_weights_after_loading(layer)

      Repack AWQ weights to the Marlin format after checkpoint loading.

   .. py:method:: apply(layer, x, bias=None)

      Perform the quantized matmul using the Marlin GEMM kernel.

.. py:class:: AWQMarlinConfig(weight_bits, group_size, zero_point)

   Bases: :py:obj:`pymllm.quantization.quant_config.QuantizationConfig`

   Configuration for AWQ quantization with Marlin kernel acceleration.

   This config is used when loading models quantized with AutoAWQ and running
   inference with the high-performance Marlin W4A16 GEMM kernel. Registered
   as ``"awq_marlin"`` in the quantization registry.

   .. py:attribute:: weight_bits

   .. py:attribute:: group_size

   .. py:attribute:: zero_point
   .. py:attribute:: pack_factor

   .. py:attribute:: quant_type

   .. py:method:: __repr__()

   .. py:method:: get_name()

      Return the canonical name of this quantization method.

      Examples: ``"awq"``, ``"gptq"``, ``"fp8"``, ``"w8a8"``.

   .. py:method:: get_supported_act_dtypes()

      Activation dtypes supported by this method.

      Override to restrict the set (e.g. FP8 only supports ``float16``).
      Default: no restriction.

   .. py:method:: get_min_capability()
      :classmethod:

      Minimum CUDA compute capability (e.g. 75 for Turing).

      Default: 0 (no restriction).

   .. py:method:: get_config_filenames()
      :staticmethod:

      File names to look for in the checkpoint directory.

      Default: ``["quantize_config.json"]``.

   .. py:method:: from_config(config)
      :classmethod:

      Create an instance from a checkpoint's quantization config dict.

      :param config: Parsed JSON from the checkpoint's
          ``quantize_config.json`` or the ``quantization_config`` section of
          ``config.json``.

      Example config dict (AWQ)::

          {
              "quant_method": "awq",
              "bits": 4,
              "group_size": 128,
              "zero_point": true
          }

   .. py:method:: get_quant_method(layer, prefix='')

      Return the quantization method for *layer*, or ``None`` to skip.

      :param layer: The ``nn.Module`` being constructed
          (e.g. ``ColumnParallelLinear``).
      :param prefix: The layer's full dotted name in the model
          (e.g. ``"model.layers.0.self_attn.q_proj"``). Can be used to
          selectively skip quantization for certain layers.
      :returns: The method instance. ``None`` means this layer should fall
          back to the default
          :class:`~pymllm.layers.quantize_base.UnquantizedLinearMethod`.
      :rtype: QuantizeMethodBase or None
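The column-packing helpers ``pack_cols`` / ``unpack_cols`` compress
``32 // num_bits`` quantized values into each ``int32`` word. The exact bit
layout used by the real kernels is not documented here, so the sketch below is
illustrative only (consecutive columns per word, least-significant bits first);
``pack_cols_ref`` and ``unpack_cols_ref`` are hypothetical names, not the
module's functions:

```python
import numpy as np


def pack_cols_ref(q_w: np.ndarray, num_bits: int) -> np.ndarray:
    """Pack quantized values along columns into int32 words (sketch).

    Word ``w`` holds columns ``w*pack_factor .. w*pack_factor + pack_factor-1``,
    with column ``i`` within a word stored at bit offset ``num_bits * i``.
    """
    pack_factor = 32 // num_bits
    size_k, size_n = q_w.shape
    assert size_n % pack_factor == 0, "column count must be divisible"
    packed = np.zeros((size_k, size_n // pack_factor), dtype=np.uint32)
    for i in range(pack_factor):
        # Columns i, i+pack_factor, ... land at the same bit offset of each word.
        packed |= q_w[:, i::pack_factor].astype(np.uint32) << (num_bits * i)
    return packed.astype(np.int32)


def unpack_cols_ref(packed: np.ndarray, num_bits: int) -> np.ndarray:
    """Invert pack_cols_ref: recover individual quantized column values."""
    pack_factor = 32 // num_bits
    mask = (1 << num_bits) - 1
    size_k, n_words = packed.shape
    out = np.zeros((size_k, n_words * pack_factor), dtype=np.uint32)
    p = packed.astype(np.uint32)  # reinterpret int32 bits as unsigned
    for i in range(pack_factor):
        out[:, i::pack_factor] = (p >> (num_bits * i)) & mask
    return out
```

For 4-bit weights the pack factor is 8, so a ``(size_k, size_n)`` quantized
matrix packs into ``(size_k, size_n // 8)`` int32 words, and unpacking is an
exact round trip.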
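Numerically, the W4A16 path that ``AWQMarlinLinearMethod.apply`` accelerates
amounts to a group-wise dequantize followed by a matmul: within each group of
``group_size`` input rows, the weight is ``(q - zero_point) * scale``. A slow
reference sketch of those semantics, assuming already-unpacked (pre-packing)
tensors and illustrative names that are not the module's API:

```python
import numpy as np


def awq_w4a16_reference(x, qweight_unpacked, scales, zeros_unpacked, group_size):
    """Reference semantics of group-quantized W4A16 matmul (sketch).

    x:                (batch, in_features) activations
    qweight_unpacked: (in_features, out_features) unsigned 4-bit values
    scales:           (in_features // group_size, out_features)
    zeros_unpacked:   (in_features // group_size, out_features) zero points
    """
    size_k, _ = qweight_unpacked.shape
    g = np.arange(size_k) // group_size  # group index of each input row
    # Dequantize: subtract the group's zero point, multiply by its scale.
    w = (qweight_unpacked.astype(np.float32)
         - zeros_unpacked[g].astype(np.float32)) * scales[g]
    return x @ w
```

The fused kernel never materializes the dequantized ``w``; it dequantizes
tiles on the fly inside the GEMM, which is what makes the format fast.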
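As a usage illustration of ``from_config``, an override typically just pulls
the fields it needs out of the parsed config dict and derives anything else
(such as the int32 pack factor) from them. The class below is a hypothetical
stand-in, not the real ``AWQMarlinConfig``:

```python
class AWQMarlinConfigSketch:
    """Hypothetical sketch of an AWQ-style quantization config."""

    def __init__(self, weight_bits: int, group_size: int, zero_point: bool):
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        # How many quantized values fit in one int32 word.
        self.pack_factor = 32 // weight_bits

    @classmethod
    def from_config(cls, config: dict) -> "AWQMarlinConfigSketch":
        # Field names follow the AWQ example dict shown above.
        return cls(
            weight_bits=config["bits"],
            group_size=config["group_size"],
            zero_point=config["zero_point"],
        )


cfg = AWQMarlinConfigSketch.from_config(
    {"quant_method": "awq", "bits": 4, "group_size": 128, "zero_point": True}
)
```

With 4-bit weights this yields a pack factor of 8, matching the int32-packed
``qweight`` / ``qzeros`` tensors registered in ``create_weights``.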