pymllm.quantization.methods.compressed_tensors
==============================================

.. py:module:: pymllm.quantization.methods.compressed_tensors


Attributes
----------

.. autoapisummary::

   pymllm.quantization.methods.compressed_tensors.MARLIN_SUPPORTED_GROUP_SIZES
   pymllm.quantization.methods.compressed_tensors.GPTQ_MARLIN_MIN_THREAD_N
   pymllm.quantization.methods.compressed_tensors.GPTQ_MARLIN_MIN_THREAD_K
   pymllm.quantization.methods.compressed_tensors.GPTQ_MARLIN_TILE
   pymllm.quantization.methods.compressed_tensors.SCALAR_TYPE_UINT4
   pymllm.quantization.methods.compressed_tensors.SCALAR_TYPE_UINT4B8


Classes
-------

.. autoapisummary::

   pymllm.quantization.methods.compressed_tensors.CompressedTensorsWNA16Scheme
   pymllm.quantization.methods.compressed_tensors.CompressedTensorsW8A8Int8Scheme
   pymllm.quantization.methods.compressed_tensors.CompressedTensorsLinearMethod
   pymllm.quantization.methods.compressed_tensors.CompressedTensorsConfig


Functions
---------

.. autoapisummary::

   pymllm.quantization.methods.compressed_tensors.verify_marlin_supported
   pymllm.quantization.methods.compressed_tensors.verify_marlin_supports_shape
   pymllm.quantization.methods.compressed_tensors.marlin_make_workspace
   pymllm.quantization.methods.compressed_tensors.marlin_make_empty_g_idx
   pymllm.quantization.methods.compressed_tensors.get_scale_perms
   pymllm.quantization.methods.compressed_tensors.marlin_permute_scales
   pymllm.quantization.methods.compressed_tensors.replace_parameter


Module Contents
---------------

.. py:data:: MARLIN_SUPPORTED_GROUP_SIZES

.. py:data:: GPTQ_MARLIN_MIN_THREAD_N
   :value: 64

.. py:data:: GPTQ_MARLIN_MIN_THREAD_K
   :value: 128

.. py:data:: GPTQ_MARLIN_TILE
   :value: 16

.. py:data:: SCALAR_TYPE_UINT4

.. py:data:: SCALAR_TYPE_UINT4B8

.. py:function:: verify_marlin_supported(group_size)

.. py:function:: verify_marlin_supports_shape(output_size_per_partition, input_size_per_partition, input_size, group_size)

.. py:function:: marlin_make_workspace(device)

.. py:function:: marlin_make_empty_g_idx(device)

.. py:function:: get_scale_perms()

.. py:function:: marlin_permute_scales(s, size_k, size_n, group_size)

.. py:function:: replace_parameter(layer, name, new_data)

.. py:class:: CompressedTensorsWNA16Scheme(*, weight_bits, group_size, symmetric, actorder)

   .. py:attribute:: weight_bits

   .. py:attribute:: group_size

   .. py:attribute:: symmetric

   .. py:attribute:: actorder

   .. py:attribute:: pack_factor

   .. py:attribute:: quant_type

   .. py:method:: create_weights(layer, input_size_per_partition, output_partition_sizes, input_size, output_size, params_dtype, **extra_weight_attrs)

   .. py:method:: process_weights_after_loading(layer)

   .. py:method:: apply(layer, x, bias = None)


.. py:class:: CompressedTensorsW8A8Int8Scheme(*, weight_bits)

   .. py:attribute:: weight_bits

   .. py:method:: create_weights(layer, input_size_per_partition, output_partition_sizes, input_size, output_size, params_dtype, **extra_weight_attrs)

   .. py:method:: process_weights_after_loading(layer)

   .. py:method:: apply(layer, x, bias = None)

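The Marlin helpers above are the building blocks a WNA16-style scheme typically
combines when it repacks checkpoint weights for the fused kernel in
:meth:`process_weights_after_loading`. The sketch below shows one plausible way
they fit together; the ``size_k``/``size_n`` arguments and the attribute names
on ``layer`` (``scales``, ``workspace``, ``g_idx``) are illustrative
assumptions, not the scheme's actual implementation::

   from pymllm.quantization.methods.compressed_tensors import (
       marlin_make_empty_g_idx,
       marlin_make_workspace,
       marlin_permute_scales,
       replace_parameter,
       verify_marlin_supported,
   )

   def prepare_marlin_scales(layer, size_k, size_n, group_size):
       # Reject group sizes the kernel cannot handle
       # (checked against MARLIN_SUPPORTED_GROUP_SIZES).
       verify_marlin_supported(group_size)

       device = layer.scales.device  # assumed: the scheme registered a "scales" parameter

       # Scratch workspace and an empty activation-order index for the kernel.
       layer.workspace = marlin_make_workspace(device)
       layer.g_idx = marlin_make_empty_g_idx(device)

       # Re-order per-group scales into the Marlin tile layout, then swap the
       # parameter in place so downstream code only sees the repacked tensor.
       marlin_scales = marlin_permute_scales(layer.scales, size_k, size_n, group_size)
       replace_parameter(layer, "scales", marlin_scales)
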
.. py:class:: CompressedTensorsLinearMethod(quant_config, signature)

   Bases: :py:obj:`pymllm.layers.quantize_base.LinearMethodBase`

   Base class for quantization methods applied to linear layers.

   Narrows the :class:`QuantizeMethodBase` interface with concrete signatures
   tailored to linear (matmul) operations. Subclasses must implement
   :meth:`create_weights` and :meth:`apply`.

   .. py:attribute:: quant_config

   .. py:attribute:: scheme

   .. py:method:: create_weights(*args, **kwargs)

      Create quantized weight tensors on *layer*.

      :param layer: The linear module that will own the parameters.
      :param input_size_per_partition: Number of input features on this TP rank.
      :param output_partition_sizes: Output sizes of each logical weight on this
          TP rank. For a standard linear layer this is
          ``[out_features_per_partition]``. For a merged QKV layer it might be
          ``[q_size, k_size, v_size]``.
      :param input_size: Full (un-sharded) input dimension.
      :param output_size: Full (un-sharded) output dimension.
      :param params_dtype: Data type for full-precision parameters
          (e.g. ``torch.float16``).
      :param \*\*extra_weight_attrs: Additional metadata to attach to created
          parameters (e.g. ``weight_loader``, ``packed_dim``, ``packed_factor``).

      Example (AWQ W4A16)::

          # Register packed 4-bit weights, scales, and zero-points
          qweight = Parameter(torch.empty(..., dtype=torch.int32))
          layer.register_parameter("qweight", qweight)
          scales = Parameter(torch.empty(..., dtype=params_dtype))
          layer.register_parameter("scales", scales)
          qzeros = Parameter(torch.empty(..., dtype=torch.int32))
          layer.register_parameter("qzeros", qzeros)

   .. py:method:: process_weights_after_loading(layer)

      Post-process parameters after checkpoint loading.

      Called once by ``ModelRunner`` after all checkpoint tensors have been
      loaded into the layer's parameters. Use this for:

      * **Repacking**: converting checkpoint layout to kernel-native layout
        (e.g. AutoAWQ int4 → Marlin packed format).
      * **Transposing**: rearranging dimensions for optimised GEMM kernels.
      * **Calibration**: computing per-tensor or per-channel scales from the
        loaded FP weights (e.g. dynamic FP8 quantisation).
      * **Cleanup**: replacing custom parameter wrappers with plain
        ``torch.nn.Parameter`` to avoid overhead during inference.

      The default implementation is a no-op.

   .. py:method:: apply(layer, x, bias = None)

      Compute the quantized linear forward.

      :param layer: The module that owns quantized parameters (set by
          :meth:`create_weights`).
      :param x: Input activation tensor, shape ``(*, input_size_per_partition)``.
      :param bias: Optional bias vector.
      :returns: Output tensor, shape ``(*, sum(output_partition_sizes))``.
      :rtype: torch.Tensor

      Example (AWQ W4A16)::

          qweight = layer.qweight   # packed int32
          scales = layer.scales     # fp16 per-group scales
          qzeros = layer.qzeros     # packed int32 zero-points
          # → invoke dequant + matmul kernel

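Taken together, the three methods above define the lifecycle of a linear
quantization method: :meth:`create_weights` registers parameters at model build
time, the checkpoint loader fills them, :meth:`process_weights_after_loading`
repacks them once, and :meth:`apply` runs the forward pass. The toy subclass
below is a minimal, hypothetical illustration of that contract; it is not how
:class:`CompressedTensorsLinearMethod` is implemented, and the class name is
made up::

   import torch
   from torch.nn import Parameter

   from pymllm.layers.quantize_base import LinearMethodBase

   class PlainLinearMethod(LinearMethodBase):
       """Hypothetical method: keeps a full-precision weight, plain matmul."""

       def create_weights(self, layer, input_size_per_partition,
                          output_partition_sizes, input_size, output_size,
                          params_dtype, **extra_weight_attrs):
           weight = Parameter(
               torch.empty(sum(output_partition_sizes),
                           input_size_per_partition,
                           dtype=params_dtype),
               requires_grad=False,
           )
           # Attach loader metadata (e.g. weight_loader) so the checkpoint
           # loader knows how to fill this parameter.
           for key, value in extra_weight_attrs.items():
               setattr(weight, key, value)
           layer.register_parameter("weight", weight)

       def process_weights_after_loading(self, layer):
           # Nothing to repack for a plain full-precision weight.
           pass

       def apply(self, layer, x, bias=None):
           return torch.nn.functional.linear(x, layer.weight, bias)
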
.. py:class:: CompressedTensorsConfig(*, quant_format, ignore, weight_bits, group_size, weight_strategy, weight_type, weight_dynamic, symmetric, actorder, input_bits, input_strategy, input_type, input_dynamic, input_symmetric)

   Bases: :py:obj:`pymllm.quantization.quant_config.QuantizationConfig`

   Base class for quantization configurations.

   A ``QuantizationConfig`` is instantiated once per model load. It reads
   quantization metadata from the checkpoint (bit-width, group size, etc.) and
   provides :class:`~pymllm.layers.quantize_base.QuantizeMethodBase` instances
   to each layer.

   Subclass contract
   -----------------

   * :meth:`get_name` — return the method name (e.g. ``"awq"``).
   * :meth:`from_config` — class method that parses a dict from the
     checkpoint's ``quantize_config.json``.
   * :meth:`get_quant_method` — return the appropriate ``LinearMethodBase``
     (or ``None`` to skip quantization for a layer).

   Optional overrides
   ------------------

   * :meth:`get_supported_act_dtypes` — restrict activation dtypes.
   * :meth:`get_min_capability` — minimum GPU compute capability.
   * :meth:`get_config_filenames` — files to probe in the checkpoint dir.

   .. py:attribute:: quant_format

   .. py:attribute:: ignore

   .. py:attribute:: weight_bits

   .. py:attribute:: group_size

   .. py:attribute:: weight_strategy

   .. py:attribute:: weight_type

   .. py:attribute:: weight_dynamic

   .. py:attribute:: symmetric

   .. py:attribute:: actorder

   .. py:attribute:: input_bits

   .. py:attribute:: input_strategy

   .. py:attribute:: input_type

   .. py:attribute:: input_dynamic

   .. py:attribute:: input_symmetric

   .. py:method:: get_name()

      Return the canonical name of this quantization method.

      Examples: ``"awq"``, ``"gptq"``, ``"fp8"``, ``"w8a8"``.

   .. py:method:: get_supported_act_dtypes()

      Activation dtypes supported by this method.

      Override to restrict (e.g. FP8 only supports ``float16``).
      Default: no restriction.

   .. py:method:: get_min_capability()
      :classmethod:

      Minimum CUDA compute capability (e.g. 75 for Turing).

      Default: 0 (no restriction).

   .. py:method:: get_config_filenames()
      :staticmethod:

      File names to look for in the checkpoint directory.

      Default: ``["quantize_config.json"]``.

   .. py:method:: from_config(config)
      :classmethod:

      Create an instance from a checkpoint's quantization config dict.

      :param config: Parsed JSON from the checkpoint's
          ``quantize_config.json`` or the ``quantization_config`` section of
          ``config.json``.

      Example config dict (AWQ)::

          {
              "quant_method": "awq",
              "bits": 4,
              "group_size": 128,
              "zero_point": true
          }

   .. py:method:: get_quant_method(layer, prefix = '')

      Return the quantization method for *layer*, or ``None`` to skip.

      :param layer: The ``nn.Module`` being constructed
          (e.g. ``ColumnParallelLinear``).
      :param prefix: The layer's full dotted name in the model
          (e.g. ``"model.layers.0.self_attn.q_proj"``). Can be used to
          selectively skip quantization for certain layers.
      :returns: The method instance. ``None`` means this layer should fall
          back to the default
          :class:`~pymllm.layers.quantize_base.UnquantizedLinearMethod`.
      :rtype: QuantizeMethodBase or None
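
A minimal sketch of driving the config by hand, assuming a W4A16-style setup.
The keyword values below (``"pack-quantized"``, ``"group"``, ``"int"``, the
``None`` activation fields) are illustrative guesses; real values come from the
checkpoint's ``quantization_config`` and are normally parsed by
:meth:`from_config` rather than written out like this::

   import torch.nn as nn

   from pymllm.quantization.methods.compressed_tensors import CompressedTensorsConfig

   config = CompressedTensorsConfig(
       quant_format="pack-quantized",   # assumed format string
       ignore=["lm_head"],              # layers to leave unquantized
       weight_bits=4,
       group_size=128,
       weight_strategy="group",         # assumed strategy name
       weight_type="int",               # assumed type name
       weight_dynamic=False,
       symmetric=True,
       actorder=False,
       input_bits=None,                 # assumed: no activation quantization
       input_strategy=None,
       input_type=None,
       input_dynamic=False,
       input_symmetric=True,
   )

   # Returns a LinearMethodBase for quantized layers, or None for layers that
   # should fall back to the unquantized default (e.g. anything in ``ignore``).
   layer = nn.Linear(4096, 4096)  # stand-in for a parallel linear module
   method = config.get_quant_method(layer, prefix="model.layers.0.self_attn.q_proj")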