pymllm.quantization.kernels.int8_activation_triton

Per-token INT8 activation quantization using Triton.

Ported from sglang int8_kernel.py (per_token_quant_int8). Original: sglang/srt/layers/quantization/int8_kernel.py:28-89

Functions

per_token_quant_int8(x[, scale_dtype])

Per-token dynamic INT8 quantization.

Module Contents

pymllm.quantization.kernels.int8_activation_triton.per_token_quant_int8(x, scale_dtype=torch.float32)

Per-token dynamic INT8 quantization.

Parameters:
  • x (torch.Tensor) – Input tensor, any shape with last dim = hidden_dim. Must be contiguous.

  • scale_dtype (torch.dtype) – Dtype for scale output (default float32).

Returns:

A tuple (x_q, scales): x_q is the INT8-quantized tensor with the same shape as x; scales holds the per-token scales, with shape x.shape[:-1] + (1,).

Return type:

Tuple[torch.Tensor, torch.Tensor]
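The per-token scheme above can be sketched in pure PyTorch. This is a reference implementation for illustration, not the Triton kernel itself; the clamp epsilon and rounding mode are assumptions, and `per_token_quant_int8_ref` is a hypothetical helper name:

```python
import torch

def per_token_quant_int8_ref(x: torch.Tensor, scale_dtype=torch.float32):
    # Per-token dynamic INT8 quantization (pure-PyTorch sketch of what
    # the Triton kernel computes; not the actual kernel code).
    # One scale per token: absmax over the last (hidden) dimension.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scales = (absmax / 127.0).to(scale_dtype)
    # Symmetric quantization: round, then clamp to the INT8 range.
    x_q = torch.clamp(torch.round(x / scales), -128, 127).to(torch.int8)
    return x_q, scales
```

Dequantizing with `x_q.float() * scales` recovers `x` to within half a quantization step per element.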