How to Support a New LLM: Step-by-Step¶
This guide walks you through adding a brand-new Large Language Model (LLM) to the mllm inference framework. We use Qwen3 as a running example, but the same workflow applies to any transformer-style model.
Step 1: Acquire the Model Assets¶
Download the original model from Hugging-Face (or any other reputable source).
Typical files you need:
config.jsontokenizer.json/tokenizer.modelPyTorch / Safetensors checkpoints (
.bin,.safetensors)
Place everything under a single directory, e.g.
~/models/Qwen3-0.6B.
Note
Models obtained from hosting platforms such as Hugging Face or ModelScope (via git clone or their official CLI) are already organized in a single directory that contains config.json, tokenizer.json, tokenizer.model, checkpoint shards, etc.
You can download Qwen3-0.6B from ModelScope with the following command:
git clone https://www.modelscope.cn/Qwen/Qwen3-0.6B.git
Note
About Model Version: Most mllm models on Hugging-Face are in v1 version file format, which has no tensor shape and supports less data types. When loading a model, you can specify the model version. We maintain the compatibility of both v1 and v2 formats in mllm. It is recommended to use v2 format for new models whenever possible.
Step 2: Convert to mllm Format¶
mllm ships a dedicated converter called mllm-convertor.
It translates Hugging-Face / PyTorch checkpoints into the internal *.mllm format.
Install pymllm (Python bindings)¶
bash ./scripts/install_pymllm.sh
Note
Once the mllm organisation is approved on PyPI you will be able to run:
pip install pymllm
Run the converter¶
mllm-convertor \
--input_path ./Qwen3-0.6B/model.safetensors \
--output_path ./Qwen3-0.6B/w4a32.mllm \
--cfg_path ./Qwen3-0.6B/quant_config.json \
--pipeline w4a32_kai_pipeline
For sharded checkpoints the converter automatically follows the *.index.json file—no manual merging required.
Custom quantization recipe (optional)¶
Supply a JSON map if you need quantization for specific layers.
Example snippet:
{
// KAI Config
"^model\\.layers\\.\\d+\\.self_attn\\.q_proj.(bias|weight)": {
"hints": {
"quant_method": "kai",
"kai_matmul_triplet": "f32_qai8dxp_qsi4c32p",
"kai_matmul_layout": "mxk_nxk",
"kai_matmul_tile_cfg": "qai8dxp1x8_qsi4c32p8x8_1x8x32",
"shape": [2048, 1024],
"replace": true
}
}
// GGUF Config
"^model\\.layers\\.\\d+\\.self_attn\\.q_proj.(bias|weight)": {
"hints": {
"quant_method": "gguf",
"gguf_type": "Q4_0",
"shape": [2048, 1024],
"replace": true
}
}
(...)
}
Save the file (e.g. quant_config.json) and pass it with --cfg_path quant_config.json.
See full quant_config.json in Appendix.
Step 3: (Optional) On-Device Quantization with mllm-quantizer¶
mllm-convertor already embeds the quantization engine, but mllm-quantizer is still useful when:
you target an Android handset and want to quantize on the phone, or
the quantization kernel is only compiled for ARM.
mllm-quantizer \
-i ./Qwen3-0.6B/model.mllm \
-c ./Qwen3-0.6B/quant_config.json \
-iv v2 \
-o ./Qwen3-0.6B/w4a32.mllm \
-ov v2
Note
Basically, if you have no ARM DEVICE(Mac with apple silicon or Arm PC) to quantize your model through pymllm in kai settings. You should use mllm-quantizer to quantize your model on your arm devices (maybe android phone).
Supported Quantization Types in mllm-quantizer¶
The mllm-quantizer tool supports several quantization types, allowing you to optimize model size and inference speed for different hardware targets. Below is a summary of the main quantization types available:
GGUF Quantization Types (CPU, cross-platform):
Q4_0, Q8_0, Q2_K, Q3_K, Q4_K, Q6_K, Q8_K: GGUF per-group quantization.
KAI Quantization Types (ARM, Apple Silicon):
KAI_fp16_fp16_fp16p_mxk_kxn: FP16 result, FP16 activation, FP16 weight, packed in kai format.
KAI_f32_qai8dxp_qsi4c32p_mxk_nxk: FP16 result, Int8 activation(asymmetric, per-token), Int4 weight(symmetric, per-group, 32 pack, transposed layout).
KAI_f16_qsi8d32p_qai4c32p_mxk_nxk: FP16 result, Int8 activation(symmetric, per-group, 32 pack), Int4 weight(asymmetric, per-group, 32 pack, transposed layout).
How to Select Quantization Type: - For general CPU deployment, use GGUF types (Q4_0, Q8_0, etc.). - For ARM devices (Android, Apple Silicon), use KAI types for best performance and compatibility. - Specify the quantization type in your config or pipeline when running mllm-quantizer.
For more details on each quantization type and their configuration, refer to the quantization implementation in the source code or the kleidiai documentation.
Step 4: Implement Core C++ Files¶
mllm mirrors the Hugging-Face config / tokenizer / model split.
Create three files under mllm/models/qwen3/:
configuration_qwen3.hpptokenization_qwen3.hppmodeling_qwen3.hpp
See full configuration_qwen3.hpp, tokenization_qwen3.hpp, modeling_qwen3.hpp in Appendix.
Step 5: Create an Example Application¶
Directory layout
examples/qwen3/
├── main.cpp
└── config_0.6B_w4a32_kai.json
main.cpp¶
#include "mllm/mllm.hpp"
#include "mllm/models/qwen3/modeling_qwen3.hpp"
#include "mllm/models/qwen3/tokenization_qwen3.hpp"
#include "mllm/models/qwen3/configuration_qwen3.hpp"
int main(int argc, char* argv[]) {
mllm::init();
std::string config_path = "...";
std::string tokenizer_path= "...";
std::string prompt = "Once upon a time";
auto cfg = mllm::models::qwen3::Qwen3Config(config_path);
auto tokenizer = mllm::models::qwen3::Qwen3Tokenizer(tokenizer_path);
auto model = mllm::models::qwen3::Qwen3ForCausalLM(cfg);
auto inputs = tokenizer.convertMessage({.prompt = prompt});
for (auto& step : model.chat(inputs)) {
std::wcout << tokenizer.detokenize(step.cur_token_id) << std::flush;
}
return 0;
}
config.json (example)¶
{
"architectures": ["Qwen3ForCausalLM"],
"bos_token_id": 151643,
"eos_token_id": 151645,
"attention_bias": false,
"hidden_size": 1024,
"head_dim": 128,
"intermediate_size": 3072,
"num_attention_heads": 16,
"num_key_value_heads": 8,
"num_hidden_layers": 28,
"max_position_embeddings": 40960,
"rms_norm_eps": 1e-06,
"vocab_size": 151936,
"max_cache_length": 2048,
"rope_theta": 1000000.0,
"tie_word_embeddings": true,
"linear_impl_type": "KaiLinear_f32_qai8dxp_qsi4c32p_mxk_nxk_qai8dxp1x8_qsi4c32p8x8_1x8x32"
}
Build & run
python task.py tasks/build_<...>.py
Step 6: Open a Pull Request¶
Fork the official repository.
Create a feature branch:
git checkout -b add-qwen3.Commit atomic changes with clear messages:
[Qwen3] Add configuration loader [Qwen3] Implement Sentence-Piece tokenizer [Qwen3] Add CI unit test and 2-sample generation
Push and open a PR.
A maintainer will review numerical correctness, coding style, and CI status before merging.
Conclusion¶
You have now walked through the complete life-cycle of integrating a new LLM into mllm:
Acquire Models → 2. Convert to MLLM models → 3. (Optionally) Quantize → 4. Implement C++ stubs → 5. Example & test → 6. PR.
Following the checklist above guarantees that your model will load efficiently, run everywhere mllm runs, and is maintainable by the community. Happy hacking!
Appendix¶
quant_config.json¶
{
"^model\\.layers\\.\\d+\\.self_attn\\.q_proj.(bias|weight)": {
"hints": {
"quant_method": "kai",
"kai_matmul_triplet": "f32_qai8dxp_qsi4c32p",
"kai_matmul_layout": "mxk_nxk",
"kai_matmul_tile_cfg": "qai8dxp1x8_qsi4c32p8x8_1x8x32",
"shape": [
2048,
1024
],
"replace": true
}
},
"^model\\.layers\\.\\d+\\.self_attn\\.k_proj.(bias|weight)": {
"hints": {
"quant_method": "kai",
"kai_matmul_triplet": "f32_qai8dxp_qsi4c32p",
"kai_matmul_layout": "mxk_nxk",
"kai_matmul_tile_cfg": "qai8dxp1x8_qsi4c32p8x8_1x8x32",
"shape": [
1024,
1024
],
"replace": true
}
},
"^model\\.layers\\.\\d+\\.self_attn\\.v_proj.(bias|weight)": {
"hints": {
"quant_method": "kai",
"kai_matmul_triplet": "f32_qai8dxp_qsi4c32p",
"kai_matmul_layout": "mxk_nxk",
"kai_matmul_tile_cfg": "qai8dxp1x8_qsi4c32p8x8_1x8x32",
"shape": [
1024,
1024
],
"replace": true
}
},
"^model\\.layers\\.\\d+\\.self_attn\\.o_proj.(bias|weight)": {
"hints": {
"quant_method": "kai",
"kai_matmul_triplet": "f32_qai8dxp_qsi4c32p",
"kai_matmul_layout": "mxk_nxk",
"kai_matmul_tile_cfg": "qai8dxp1x8_qsi4c32p8x8_1x8x32",
"shape": [
1024,
2048
],
"replace": true
}
},
"^model\\.layers\\.\\d+\\.mlp\\.gate_proj.(bias|weight)": {
"hints": {
"quant_method": "kai",
"kai_matmul_triplet": "f32_qai8dxp_qsi4c32p",
"kai_matmul_layout": "mxk_nxk",
"kai_matmul_tile_cfg": "qai8dxp1x8_qsi4c32p8x8_1x8x32",
"shape": [
3072,
1024
],
"replace": true
}
},
"^model\\.layers\\.\\d+\\.mlp\\.up_proj.(bias|weight)": {
"hints": {
"quant_method": "kai",
"kai_matmul_triplet": "f32_qai8dxp_qsi4c32p",
"kai_matmul_layout": "mxk_nxk",
"kai_matmul_tile_cfg": "qai8dxp1x8_qsi4c32p8x8_1x8x32",
"shape": [
3072,
1024
],
"replace": true
}
},
"^model\\.layers\\.\\d+\\.mlp\\.down_proj.(bias|weight)": {
"hints": {
"quant_method": "kai",
"kai_matmul_triplet": "f32_qai8dxp_qsi4c32p",
"kai_matmul_layout": "mxk_nxk",
"kai_matmul_tile_cfg": "qai8dxp1x8_qsi4c32p8x8_1x8x32",
"shape": [
1024,
3072
],
"replace": true
}
},
"lm_head.weight": {
"hints": {
"quant_method": "kai",
"kai_matmul_triplet": "f32_qai8dxp_qsi4c32p",
"kai_matmul_layout": "mxk_nxk",
"kai_matmul_tile_cfg": "qai8dxp1x8_qsi4c32p8x8_1x8x32",
"shape": [
151936,
1024
],
"replace": false,
"rename": "lm_head_out.weight"
}
}
}
configuration_qwen3.hpp¶
// Copyright (c) MLLM Team.
// Licensed under the MIT License.
#pragma once
#include "mllm/core/aops/LinearOp.hpp"
#include "mllm/engine/ConfigFile.hpp"
namespace mllm::models::qwen3 {
struct Qwen3Config : protected ConfigFile {
Qwen3Config() = default;
explicit Qwen3Config(const std::string& file_path) : ConfigFile(file_path) {
// Init all
attention_bias = data()["attention_bias"];
hidden_size = data()["hidden_size"];
intermediate_size = data()["intermediate_size"];
num_attention_heads = data()["num_attention_heads"];
num_key_value_heads = data()["num_key_value_heads"];
num_hidden_layers = data()["num_hidden_layers"];
max_position_embeddings = data()["max_position_embeddings"];
rms_norm_eps = data()["rms_norm_eps"];
vocab_size = data()["vocab_size"];
head_dim = data()["head_dim"];
bos_token_id = data()["bos_token_id"];
eos_token_id = data()["eos_token_id"];
rope_theta = data()["rope_theta"];
tie_word_embeddings = data()["tie_word_embeddings"];
max_cache_length = data()["max_cache_length"];
linear_impl_type = aops::str2LinearImplTypes(data()["linear_impl_type"]);
}
bool attention_bias = false;
int32_t hidden_size = 1024;
int32_t head_dim = 128;
int32_t intermediate_size = 3072;
int32_t num_attention_heads = 16;
int32_t num_key_value_heads = 8;
int32_t num_hidden_layers = 28;
int32_t max_position_embeddings = 40960;
float rms_norm_eps = 1e-06;
int32_t vocab_size = 151936;
int64_t bos_token_id = 151643;
int64_t eos_token_id = 151645;
float rope_theta = 1000000.0;
bool tie_word_embeddings = true;
int32_t max_cache_length = 2048;
int32_t end_of_text_token_id = 151645;
aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault;
};
} // namespace mllm::models::qwen3
modeling_qwen3.hpp¶
// Copyright (c) MLLM Team.
// Licensed under the MIT License.
#include "mllm/mllm.hpp"
#include "mllm/nn/Module.hpp"
#include "mllm/nn/Nn.hpp"
#include "mllm/nn/Functional.hpp"
#include "mllm/nn/lmcache/StaticCache.hpp"
#include "mllm/models/qwen3/configuration_qwen3.hpp"
#include "mllm/utils/Enumerate.hpp"
#include "mllm/models/ARGeneration.hpp"
namespace mllm::models::qwen3 {
inline auto makeRoPEInvFreq(int output_dim, float rope_theta) -> Tensor {
auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc();
auto inv_freq_ptr = inv_freq.ptr<float>();
for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0 / std::pow(rope_theta, 2.0 * i / output_dim); }
return inv_freq;
}
inline auto makeRotaryPosEmbedding(Tensor& position_ids, const Tensor& inv_freq,
float attention_scaling = 1.0f) -> std::pair<Tensor, Tensor> {
auto batch_size = position_ids.shape()[0];
auto seq_len = position_ids.shape()[1];
auto inv_freq_len = inv_freq.shape()[0];
auto dim = inv_freq_len * 2;
// Create freqs tensor: position_ids @ inv_freq
auto freqs = Tensor::empty({batch_size, seq_len, inv_freq_len}, kFloat32, kCPU).alloc();
auto freqs_ptr = freqs.ptr<float>();
auto position_ids_ptr = position_ids.ptr<int64_t>();
auto inv_freq_ptr = inv_freq.ptr<float>();
// Compute freqs = position_ids[:, :, None] @ inv_freq[None, :]
for (int b = 0; b < batch_size; ++b) {
for (int s = 0; s < seq_len; ++s) {
auto pos = position_ids_ptr[b * seq_len + s];
for (int d = 0; d < inv_freq_len; ++d) {
freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d] = static_cast<float>(pos) * inv_freq_ptr[d];
}
}
}
// Create sin and cos tensors with shape [batch_size, seq_len, dim]
auto sin_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc();
auto cos_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc();
auto sin_ptr = sin_emb.ptr<float>();
auto cos_ptr = cos_emb.ptr<float>();
// Compute sin and cos embeddings: emb = [freqs, freqs]
for (int b = 0; b < batch_size; ++b) {
for (int s = 0; s < seq_len; ++s) {
for (int d = 0; d < inv_freq_len; ++d) {
auto freq = freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d];
auto sin_val = std::sin(freq) * attention_scaling;
auto cos_val = std::cos(freq) * attention_scaling;
// Store the same values in both halves: [freqs, freqs]
sin_ptr[b * seq_len * dim + s * dim + d] = sin_val;
sin_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = sin_val;
cos_ptr[b * seq_len * dim + s * dim + d] = cos_val;
cos_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = cos_val;
}
}
}
return {sin_emb, cos_emb};
}
class Qwen3MLP final : public nn::Module {
nn::Linear gate_proj_;
nn::Linear up_proj_;
nn::Linear down_proj_;
nn::SiLU silu_;
public:
Qwen3MLP() = default;
Qwen3MLP(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) {
gate_proj_ = reg<nn::Linear>("gate_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type);
silu_ = reg<nn::SiLU>("act");
up_proj_ = reg<nn::Linear>("up_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type);
down_proj_ = reg<nn::Linear>("down_proj", cfg.intermediate_size, cfg.hidden_size, false, cfg.linear_impl_type);
}
std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
auto x = gate_proj_(inputs[0]);
x = silu_(x);
auto y = up_proj_(inputs[0]);
x = x * y;
x = down_proj_(x);
return {x};
}
};
class Qwen3Attention final : public nn::Module {
nn::Linear q_proj_;
nn::Linear k_proj_;
nn::Linear v_proj_;
nn::Linear o_proj_;
nn::RMSNorm rms_norm_q_;
nn::RMSNorm rms_norm_k_;
nn::RoPE q_rope_;
nn::RoPE k_rope_;
nn::CausalMask mask_;
nn::Softmax softmax_;
int hidden_size_;
int head_dim_;
int num_attention_heads_;
int num_key_value_heads_;
int num_key_value_groups_;
public:
Qwen3Attention() = default;
Qwen3Attention(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) {
hidden_size_ = cfg.hidden_size;
num_attention_heads_ = cfg.num_attention_heads;
num_key_value_heads_ = cfg.num_key_value_heads;
head_dim_ = cfg.head_dim;
num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_;
q_proj_ =
reg<nn::Linear>("q_proj", hidden_size_, head_dim_ * num_attention_heads_, cfg.attention_bias, cfg.linear_impl_type);
k_proj_ =
reg<nn::Linear>("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type);
v_proj_ =
reg<nn::Linear>("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type);
o_proj_ =
reg<nn::Linear>("o_proj", head_dim_ * num_attention_heads_, hidden_size_, cfg.attention_bias, cfg.linear_impl_type);
rms_norm_q_ = reg<nn::RMSNorm>("q_norm", cfg.rms_norm_eps);
rms_norm_k_ = reg<nn::RMSNorm>("k_norm", cfg.rms_norm_eps);
q_rope_ = reg<nn::RoPE>("q_rope", cfg.rope_theta, cfg.max_position_embeddings);
k_rope_ = reg<nn::RoPE>("k_rope", cfg.rope_theta, cfg.max_position_embeddings);
mask_ = reg<nn::CausalMask>("mask");
softmax_ = reg<nn::Softmax>("softmax", -1);
}
std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
auto x = inputs[0];
auto llm_embedding_sin = inputs[1];
auto llm_embedding_cos = inputs[2];
auto past_kv_cache = args[0].get<nn::StaticCache*>();
// [B, S, H * D]
auto query_states = q_proj_(x);
auto key_states = k_proj_(x);
auto value_states = v_proj_(x);
int B = inputs[0].shape()[0];
int S = inputs[0].shape()[1];
// [B, S, H, D]
query_states = query_states.view({B, S, num_attention_heads_, head_dim_});
key_states = key_states.view({B, S, num_key_value_heads_, head_dim_});
value_states = value_states.view({B, S, num_key_value_heads_, head_dim_});
// [B, S, H, D]
query_states = rms_norm_q_(query_states);
key_states = rms_norm_k_(key_states);
// [B, H, S, D]
query_states = query_states.transpose(1, 2);
key_states = key_states.transpose(1, 2);
value_states = value_states.transpose(1, 2);
// [B, H, S, D]
query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos);
key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos);
// [B, H, S, D]
std::tie(key_states, value_states) = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states);
Tensor attn;
if (key_states.dtype() == kFloat32) {
// attention weight
// [B, H, S, S]
attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_));
attn = mask_(attn);
attn = softmax_(attn);
} else if (key_states.dtype() == kFloat16) {
attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_));
attn = mask_(attn);
attn = softmax_(attn);
attn = attn.to(kFloat16);
}
// attn output
// [B, H, S, S] @ [B, H, S, D] -> [B, H, S, D]
auto output = nn::functional::matmul(attn, value_states);
// [B, H, S, D] -> [B, S, H, D] -> [B, S, H * D]
output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_});
output = o_proj_(output);
return {output};
}
int layer_idx_;
};
class Qwen3Decoder final : public nn::Module {
public:
Qwen3Attention self_attn_;
Qwen3MLP mlp_;
nn::RMSNorm input_layer_norm_;
nn::RMSNorm post_attention_layer_norm_;
Qwen3Decoder() = default;
Qwen3Decoder(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) {
self_attn_ = reg<Qwen3Attention>("self_attn", cfg);
mlp_ = reg<Qwen3MLP>("mlp", cfg);
input_layer_norm_ = reg<nn::RMSNorm>("input_layernorm", cfg.rms_norm_eps);
post_attention_layer_norm_ = reg<nn::RMSNorm>("post_attention_layernorm", cfg.rms_norm_eps);
}
std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
auto llm_embedding_sin = inputs[1];
auto llm_embedding_cos = inputs[2];
auto& kv_cache = args[0];
auto x = input_layer_norm_(inputs[0]);
x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0];
auto tmp = x + inputs[0];
x = post_attention_layer_norm_(tmp);
x = mlp_(x)[0];
x = x + tmp;
return {x};
}
};
class Qwen3Text final : public nn::Module {
nn::ModuleList<Qwen3Decoder> decode_blocks_;
nn::RMSNorm norm_;
nn::Embedding embedding_;
public:
Qwen3Text() = default;
Qwen3Text(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) {
decode_blocks_ = reg<nn::ModuleList<Qwen3Decoder>>("layers", cfg.num_hidden_layers, cfg);
for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; }
norm_ = reg<nn::RMSNorm>("norm", cfg.rms_norm_eps);
embedding_ = reg<nn::Embedding>("embed_tokens", cfg.vocab_size, cfg.hidden_size);
}
std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
auto& blocks = decode_blocks_.list();
// X is already embedded
auto x = embedding_(inputs[0]);
auto llm_embedding_sin = inputs[1];
auto llm_embedding_cos = inputs[2];
auto& kv_cache = args[0];
for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; }
x = norm_(x);
return {x};
}
};
class Qwen3ForCausalLM : public ARGeneration, public nn::Module {
public:
explicit Qwen3ForCausalLM(const Qwen3Config& cfg) : cfg(cfg) {
kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers,
cfg.num_attention_heads, // q_heads
cfg.num_key_value_heads, // kv_heads
cfg.head_dim, // kv_dim
kFloat32, // k_dtype
kFloat32, // v_dtype
kCPU, // device_type
false // use_fa2
);
eos_token_id_ = cfg.end_of_text_token_id;
max_length_ = cfg.max_cache_length;
tie_word_embeddings_ = cfg.tie_word_embeddings;
llm = reg<Qwen3Text>("model", cfg);
if (cfg.tie_word_embeddings) {
// NOTE:
// model.lm_head.weight is quantization weights of model.embed_tokens.weight
lm_head_ = reg<nn::Linear>("lm_head_out", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type);
}
// Init inv freq
auto inv = makeRoPEInvFreq(cfg.head_dim, cfg.rope_theta);
registerBuffer("inv_freq", inv);
}
ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override {
auto sequence = input.at("sequence");
// Generate position_ids for the current sequence
auto batch_size = sequence.shape()[0];
auto seq_len = sequence.shape()[1];
Tensor position_ids = Tensor::nil();
if (input.count("position_ids")) {
// Use existing position_ids for decode phase
position_ids = input.at("position_ids");
// For decode phase, increment the last position
if (seq_len == 1) {
auto last_pos = *position_ids.offsettedPtr<int64_t>({0, position_ids.shape()[1] - 1});
position_ids = Tensor::empty({batch_size, 1}, kInt64, kCPU).alloc();
*position_ids.offsettedPtr<int64_t>({0, 0}) = last_pos + 1;
}
} else {
// Generate position_ids for prefill phase
position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc();
auto position_ids_ptr = position_ids.ptr<int64_t>();
for (int b = 0; b < batch_size; ++b) {
for (int s = 0; s < seq_len; ++s) { position_ids_ptr[b * seq_len + s] = s; }
}
}
// Generate RoPE embeddings using the inv_freq buffer
auto [llm_embedding_sin, llm_embedding_cos] = makeRotaryPosEmbedding(position_ids, getBuffer("inv_freq"), 1.0f);
sequence = llm(sequence, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0];
// clip x to one seq length
{
auto S = sequence.shape()[1];
sequence = sequence[{kAll, {S - 1}, kAll}];
}
if (tie_word_embeddings_) { sequence = lm_head_(sequence); }
return {
{"sequence", sequence},
{"position_ids", position_ids},
};
}
private:
const Qwen3Config& cfg;
Qwen3Text llm;
nn::Linear lm_head_;
bool tie_word_embeddings_;
nn::StaticCache kv_cache_;
};
} // namespace mllm::models::qwen3
tokenization_qwen3.hpp¶
// Copyright (c) MLLM Team.
// Licensed under the MIT License.
#pragma once
#include <vector>
#include <unordered_map>
#include "mllm/preprocessor/tokenizers/BPE.hpp"
#include "mllm/models/ARGeneration.hpp"
#include "mllm/preprocessor/tokenizers/Unicode.hpp"
#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp"
namespace mllm::models::qwen3 {
// we need to handle this:
//
// (?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|
// ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
inline bool qwen3TokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) {
if (pos >= str.size()) return false;
// 1. Match contractions: "'s|'t|'re|'ve|'m|'ll|'d"
static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d"};
for (const auto& contraction : contractions) {
if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) {
matched = contraction;
pos += contraction.size();
return true;
}
}
// 2. Match [^\r\n\p{L}\p{N}]?\p{L}+ (non-letter/digit followed by letters)
{
size_t original_pos = pos;
bool has_prefix = false;
matched.clear();
// Check optional non-letter/digit prefix (excluding \r\n)
if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') {
matched += str[pos];
++pos;
has_prefix = true;
}
// Require at least one letter
if (pos < str.size() && preprocessor::isLetter(str[pos])) {
do {
matched += str[pos];
++pos;
} while (pos < str.size() && preprocessor::isLetter(str[pos]));
return true;
} else {
// Rollback if no letters after prefix
if (has_prefix) {
pos = original_pos;
matched.clear();
}
}
}
// 3. Match \p{N} (digits)
if (preprocessor::isDigit(str[pos])) {
matched = str.substr(pos, 1);
++pos;
return true;
}
// 4. Match ?[^\s\p{L}\p{N}]+[\r\n]* (punctuation/symbols with optional space prefix)
{
size_t original_pos = pos;
matched.clear();
size_t start = pos;
// Optional space
if (str[pos] == L' ') { ++pos; }
// Require at least one non-letter/digit/whitespace
if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) {
do {
++pos;
} while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos])
&& !preprocessor::isDigit(str[pos]));
// Capture from start (after optional space) to current pos
matched = str.substr(start, pos - start);
// Capture trailing newlines
while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) {
matched += str[pos];
++pos;
}
return true;
} else {
// Rollback if no symbols found
pos = original_pos;
}
}
// 5. Match \s*[\r\n]+ (newlines with leading whitespace)
{
size_t start = pos;
while (pos < str.size() && std::iswspace(str[pos])) ++pos;
if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) {
while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos;
matched = str.substr(start, pos - start);
return true;
} else {
pos = start;
}
}
// 6. Match \s+(?!\S) (whitespace not followed by non-space)
if (std::iswspace(str[pos])) {
size_t start = pos;
while (pos < str.size() && std::iswspace(str[pos])) ++pos;
// Check if at end or followed by whitespace
if (pos >= str.size() || std::iswspace(str[pos])) {
matched = str.substr(start, pos - start);
return true;
} else {
pos = start;
}
}
// 7. Match remaining whitespace
if (std::iswspace(str[pos])) {
size_t start = pos;
while (pos < str.size() && std::iswspace(str[pos])) ++pos;
matched = str.substr(start, pos - start);
return true;
}
return false;
}
inline bool qwen3Regex(const std::string& str, std::vector<std::wstring>& splitted) {
auto w_string = preprocessor::utf8string2WideString(str);
size_t pos = 0;
while (pos < w_string.size()) {
std::wstring matched;
if (qwen3TokenizerMatchPattern(w_string, pos, matched)) {
splitted.push_back(matched);
} else {
++pos;
}
}
return true;
}
struct Qwen3Message {
std::string prompt;
static inline std::string message_template =
"<|im_start|>user\n{{{prompt}}}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
};
class Qwen3Tokenizer final : public mllm::preprocessor::AutoTokenizer {
public:
explicit Qwen3Tokenizer(const std::string& file_path) {
preprocessor::initLocal();
preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_);
for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); }
bpe_.initFromSentencePieceJson(file_path);
special_tokens_trie_.add(L"<|endoftext|>");
special_tokens_trie_.add(L"<|im_start|>");
special_tokens_trie_.add(L"<|im_end|>");
special_tokens_trie_.add(L"<|object_ref_start|>");
special_tokens_trie_.add(L"<|object_ref_end|>");
special_tokens_trie_.add(L"<|box_start|>");
special_tokens_trie_.add(L"<|box_end|>");
special_tokens_trie_.add(L"<|quad_start|>");
special_tokens_trie_.add(L"<|quad_end|>");
special_tokens_trie_.add(L"<|vision_start|>");
special_tokens_trie_.add(L"<|vision_end|>");
special_tokens_trie_.add(L"<|vision_pad|>");
special_tokens_trie_.add(L"<|image_pad|>");
special_tokens_trie_.add(L"<|video_pad|>");
special_tokens_trie_.add(L"<think>");
special_tokens_trie_.add(L"</think>");
}
std::vector<std::wstring> _tokenize(const std::string& str) override {
std::vector<std::wstring> ret;
std::vector<std::wstring> splitted;
::mllm::models::qwen3::qwen3Regex(str, splitted);
for (const auto& s : splitted) {
auto utf_8_str = preprocessor::wideString2Utf8String(s);
std::wstring mapped_str;
for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); }
auto bpe_ts = bpe_._bpe(mapped_str);
for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); }
}
return ret;
}
std::vector<std::wstring> tokenize(const std::string& str) override {
auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str));
std::vector<std::wstring> all_tokens;
for (const auto& token : tokens) {
if (special_tokens_trie_.isSpecialToken(token)) {
all_tokens.emplace_back(token);
continue;
}
auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token));
all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end());
}
return all_tokens;
}
std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); }
std::wstring detokenize(int64_t pos_idx) override {
auto str = _detokenize(pos_idx);
std::string utf_8_str;
for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); }
return {mllm::preprocessor::utf8string2WideString(utf_8_str)};
}
Tensor convert2Ids(const std::vector<std::wstring>& strs) override {
std::vector<int64_t> ids;
ids.reserve(strs.size());
for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); }
Tensor ret = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU)
.setMemType(kExtraInput)
.setName("qwen2-tokenizer-i0")
.alloc();
auto ptr = ret.ptr<int64_t>();
for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; }
return ret;
}
ARGenerationOutputPast convertMessage(const Qwen3Message& message) {
// process prompt
auto applied_string = Qwen3Message::message_template;
size_t pos = applied_string.find("{{{prompt}}}");
applied_string.replace(pos, 12, message.prompt);
// process sequence
auto sequence_str = tokenize(applied_string);
std::vector<int64_t> ids;
ids.reserve(sequence_str.size());
for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); }
// Get sequence Tensor
Tensor sequence = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU)
.setMemType(kNormal)
.setName("qwen2-tokenizer-i0")
.alloc();
auto ptr = sequence.ptr<int64_t>();
for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; }
return {
{"sequence", sequence},
};
}
private:
// For text
preprocessor::BPE bpe_;
std::unordered_map<std::wint_t, wchar_t> bytes_2_unicode_dict_;
std::unordered_map<wchar_t, std::wint_t> bytes_2_unicode_dict_inverse_;
};
} // namespace mllm::models::qwen3