pymllm Kernels and Acceleration¶
总览¶
pymllm 的性能路径由多类加速组件共同组成:
FlashInfer:paged KV cache attention。
CUDA Graph:decode 阶段减少 CPU launch overhead。
Triton:W8A8 per-token activation quantization。
CUTLASS:W8A8 INT8 Tensor Core GEMM。
mllm-kernel:基于 TVM-FFI / torch extension 的 JIT kernel 工具包。
这些组件不是彼此替代关系,而是在不同层次承担职责。attention backend 解决 KV cache
attention;CUDA Graph 解决重复 decode step 的 launch overhead;Triton 和 CUTLASS 解决量化
linear 的核心计算;mllm-kernel 为项目内自定义 CUDA/C++ kernel 提供封装、缓存和工具。
mllm-kernel¶
mllm-kernel 是 mllm 项目中的高性能 kernel 包。当前 Python 侧主要包含:
mllm_kernel.cuda.jit:CUDA JIT kernel wrapper。mllm_kernel.cpu.jit:CPU JIT kernel wrapper。mllm_kernel.jit_utils:JIT 编译、缓存、注册表和工具函数。
CUDA JIT kernel 的典型结构是:
Python wrapper
-> @jit(...)
-> include CUDA/C++ source
-> export TVM-FFI typed function
-> compile on first use
-> reuse cached shared library
默认 JIT 缓存目录为:
~/.cache/mllm_kernel/
mllm-kernel 的 JIT 路径与 SGLang 的 jit_kernel 设计关系更直接:二者都强调轻量
JIT、运行时选择模板实例、避免大型 AOT torch extension 带来的长编译周期。与此同时,SGLang
的 sgl-kernel AOT kernel 仍然是重要参考,尤其适合对照量化 GEMM 的语义和性能。
TVM-FFI JIT 路径¶
mllm_kernel.jit_utils.jit decorator 会将 Python 函数包装成一个按需编译的 kernel 调用。
它负责:
根据 tensor device 推断 CPU/CUDA 目标。
将 Python 参数转换为 C++ template 参数。
拼接 C++/CUDA source 和 export wrapper。
调用 TVM-FFI 编译并加载 shared library。
将编译结果缓存到
~/.cache/mllm_kernel。
这种方式适合小而明确的自定义 kernel,例如:
create_kv_indices:构造 FlashInfer KV index metadata。store_cache:将 K/V 写入 KVPool。gptq_marlin_repack:Marlin weight layout 转换。gptq_marlin_gemm:W4A16 Marlin GEMM。
W8A8 CUTLASS kernel 当前使用 torch.utils.cpp_extension.load 编译。这是因为 CUTLASS
模板和 include 体系较重,当前以稳定通过 Jetson SM87 编译为优先。
FlashInfer Attention¶
pymllm.layers.attention.flashinfer_backend.FlashInferAttnBackend 封装 FlashInfer 的 paged
KV cache attention。它负责:
为 prefill 和 decode 准备
kv_indptr、kv_indices、kv_last_page_len等 metadata。管理全局 workspace buffer。
根据是否存在 sliding window 选择 wrapper dispatch。
在 decode 中根据 GQA group size 和 KV dtype 决定是否使用 tensor core 路径。
为 CUDA Graph capture / replay 提供专用 metadata 初始化接口。
prefill 和 decode 使用不同 wrapper:
prefill / extend
BatchPrefillWithPagedKVCacheWrapper
BatchPrefillWithRaggedKVCacheWrapper
decode
BatchDecodeWithPagedKVCacheWrapper
attention backend 只负责 attention 计算和 metadata,不负责请求调度和 KV slot 生命周期。KV slot 的分配、释放和 prefix cache 命中由 scheduler / model runner 侧完成。
CUDA Graph¶
pymllm.executor.cuda_graph_runner.CudaGraphRunner 用于 decode step 的 CUDA Graph capture
和 replay。它的目标是减少小 batch decode 中 CPU launch overhead。
初始化阶段会按一组离散 batch size 捕获 graph:
[1, 2, 4, 8, 12, 16, 24, 32, ...]
每个 captured graph 复用预分配输入 buffer:
input_idsreq_pool_indicesseq_lensout_cache_locpositionsmrope_position_deltas
replay 时,真实 batch 会被 padding 到最近的 captured batch size。attention backend 会走专用
init_forward_metadata_replay_cuda_graph 路径,避免使用普通动态 metadata 初始化。
CUDA Graph 只覆盖 decode 主路径。调试模型、调试 attention metadata 或定位 shape 问题时,可以
使用 --server.disable_cuda_graph 暂时关闭。
W4A16 Marlin¶
W4A16 路径复用 Marlin kernel。checkpoint 权重先以 weight_packed 和 weight_scale
加载,然后在 post-load 阶段转换为 Marlin runtime layout。
关键 kernel:
mllm_kernel.cuda.jit.gptq_marlin_repackmllm_kernel.cuda.jit.gptq_marlin
执行约束包括:
SM80+
output partition 可被 64 整除
input partition 可被 128 整除
group size 当前主路径为 32
这种路径适合 AWQ / W4A16 类权重量化模型,activation 保持 FP16/BF16。
W8A8 Triton + CUTLASS¶
W8A8 路径包含两个核心 kernel:
pymllm.quantization.kernels.int8_activation_triton.per_token_quant_int8mllm_kernel.cuda.jit.int8_scaled_mm_cutlass.int8_scaled_mm
运行时链路:
[M, K] fp16/bf16 activation
-> Triton per-token absmax + round + int8 cast
-> [M, K] int8 + [M, 1] fp32 scale
-> CUTLASS int8 GEMM with per-row/per-col scales
-> [M, N] fp16/bf16 output
CUTLASS kernel 要求 mat_b 为 [K, N] column-major,因此 W8A8 scheme 会在
process_weights_after_loading 中把 checkpoint 的 [N, K] INT8 weight 转成对应布局。
当前 CUTLASS include 查找顺序为:
CUTLASS_HOME/includeflashinferbundled CUTLASS系统 include 目录
如果找不到 CUTLASS 头文件,W8A8 初始化会失败。生产环境建议在镜像中固定 CUTLASS 来源,避免 不同节点使用不同版本头文件。
GDN decode kernel¶
Qwen3.5 等 hybrid 模型可能包含 GDN / linear attention 层。pymllm 为这类模型保留了:
pymllm.layers.attention.gdn_backendpymllm.layers.attention.hybrid_backendmllm_kernel.cuda.jit.gdn_decodeMambaRadixCache/ GDN state cache 相关结构
当前文档重点覆盖 Qwen3 / Qwen3-VL 主路径。GDN 相关路径仍应以具体模型和测试结果为准。
调试与观测¶
常用检查命令:
python3 -m mllm_kernel show-env
python3 -m mllm_kernel show-config
python3 -m pymllm show-config
当首次运行时间异常长时,应区分:
模型权重加载时间。
FlashInfer / CUDA context 初始化时间。
CUTLASS JIT 编译时间。
CUDA Graph capture 时间。
实际 prefill/decode 时间。