MLLM LM Cache¶
The MLLM LM Cache module provides an efficient key-value (KV) caching mechanism for optimizing the inference performance of large language models and multimodal models. The module supports both static and dynamic caching strategies, reducing redundant computation and improving inference speed.
Overview¶
In Transformer-based models, the attention mechanism maintains a key-value cache so that the representations of previously processed tokens do not have to be recomputed at every decoding step. MLLM provides several cache implementations to meet different performance and memory requirements:
StaticCache: Pre-allocates a fixed-size cache, suitable for scenarios where the maximum sequence length is known
DynamicCache: Allocates the cache dynamically, suitable for variable-length sequence scenarios
SubStaticCache: A sub-view of a static cache, supporting cache slicing operations
API Reference¶
StaticCache¶
Pre-allocates a fixed-size cache, suitable for optimizing performance during inference.
#include "mllm/nn/lmcache/StaticCache.hpp"
// Create static cache
auto cache = mllm::nn::StaticCache(
    max_cache_length,  // Maximum cache length
    layer_nums,        // Number of layers
    q_heads,           // Number of query heads
    kv_heads,          // Number of key-value heads
    kv_dims,           // Key-value dimensions
    k_dtype,           // Key data type
    v_dtype,           // Value data type
    device_type,       // Device type (kCPU, kOpenCL, etc.)
    use_fa2            // Whether to use FlashAttention2
);
// Update cache
auto [k_cached, v_cached] = cache.updateKVCache(layer_idx, k_tensor, v_tensor);
// Get current sequence length
int32_t seq_len = cache.getCurrentSeqCnt(layer_idx);
Constructor Parameters¶
| Parameter | Type | Description |
|---|---|---|
| max_cache_length | int32_t | Maximum cache sequence length |
| layer_nums | int32_t | Number of model layers |
| q_heads | int32_t | Number of query attention heads |
| kv_heads | int32_t | Number of key-value attention heads |
| kv_dims | int32_t | Key-value dimensions |
| k_dtype | DataTypes | Key tensor data type |
| v_dtype | DataTypes | Value tensor data type |
| device_type | DeviceTypes | Device type (default kCPU) |
| use_fa2 | bool | Whether to use FlashAttention2 (default true) |
DynamicCache¶
Allocates the cache dynamically, suitable for training or variable-length inference scenarios.
#include "mllm/nn/lmcache/DynamicCache.hpp"
// Create dynamic cache
auto cache = mllm::nn::DynamicCache(
    layer_nums,  // Number of layers
    q_heads,     // Number of query heads
    kv_heads,    // Number of key-value heads
    kv_dims,     // Key-value dimensions
    use_fa2      // Whether to use FlashAttention2
);
// Update cache
auto [k_cached, v_cached] = cache.updateKVCache(layer_idx, k_tensor, v_tensor);
// Get current sequence length
int32_t seq_len = cache.getCurrentSeqCnt();
SubStaticCache¶
A sub-view of a static cache that allows slicing operations on the cache.
// Create sub-cache from existing static cache
auto sub_cache = mllm::nn::SubStaticCache(
    parent_cache,  // Parent cache reference
    start_idx,     // Start index
    len            // Length
);
// Use in the same way as StaticCache
auto [k_cached, v_cached] = sub_cache.updateKVCache(layer_idx, k_tensor, v_tensor);
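For example, a sub-view can be carved out of a larger pre-allocated cache. The following is a minimal sketch that assumes start_idx and len select a contiguous range of positions in the parent cache; num_layers, num_q_heads, num_kv_heads, and head_dim are placeholders:
// Pre-allocate a large parent cache, then view only its first 512 positions
auto parent_cache = mllm::nn::StaticCache(
    4096, num_layers, num_q_heads, num_kv_heads, head_dim,
    mllm::DataTypes::kFP16, mllm::DataTypes::kFP16, mllm::DeviceTypes::kCPU);
auto prefill_view = mllm::nn::SubStaticCache(parent_cache, /*start_idx=*/0, /*len=*/512);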
Tensor Format¶
Non-FlashAttention2 Mode¶
Input tensor format: [Batch, Heads, Sequence, Dimension]
// Example: single batch, 32 heads, sequence length 1, dimension 128
Tensor k = Tensor::random({1, 32, 1, 128});
Tensor v = Tensor::random({1, 32, 1, 128});
FlashAttention2 Mode¶
Input tensor format: [Batch, Sequence, Heads, Dimension]
// Example: single batch, sequence length 1, 32 heads, dimension 128
Tensor k = Tensor::random({1, 1, 32, 128});
Tensor v = Tensor::random({1, 1, 32, 128});
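The K/V layout passed to updateKVCache must match the use_fa2 flag the cache was constructed with; a minimal sketch of selecting the layout accordingly:
// Choose the layout that matches the cache's use_fa2 setting
bool use_fa2 = true;
Tensor k = use_fa2 ? Tensor::random({1, 1, 32, 128})    // [Batch, Sequence, Heads, Dimension]
                   : Tensor::random({1, 32, 1, 128});   // [Batch, Heads, Sequence, Dimension]
Tensor v = use_fa2 ? Tensor::random({1, 1, 32, 128})
                   : Tensor::random({1, 32, 1, 128});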
Usage Examples¶
Basic Usage¶
#include "mllm/nn/lmcache/StaticCache.hpp"
// Configure parameters
const int32_t max_seq_len = 2048;
const int32_t num_layers = 24;
const int32_t num_q_heads = 32;
const int32_t num_kv_heads = 8; // Support GQA (Grouped Query Attention)
const int32_t head_dim = 128;
// Create cache
auto cache = mllm::nn::StaticCache(
    max_seq_len, num_layers, num_q_heads, num_kv_heads, head_dim,
    mllm::DataTypes::kFP16, mllm::DataTypes::kFP16, mllm::DeviceTypes::kCPU
);
// Use in inference loop
for (int layer = 0; layer < num_layers; ++layer) {
  // Assume k, v are key-value tensors of current layer
  auto [k_cache, v_cache] = cache.updateKVCache(layer, k, v);
  // Use cached key-values for attention computation
  auto attention_output = attention_func(q, k_cache, v_cache);
}
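During autoregressive decoding, one new position is appended per step, and getCurrentSeqCnt reports how much of the cache is already filled. The sketch below builds on the example above; num_new_tokens, q, k, v, and attention_func are placeholders:
// Decode loop: each step appends one token's K/V per layer
for (int step = 0; step < num_new_tokens; ++step) {
  for (int layer = 0; layer < num_layers; ++layer) {
    // k, v for the new token: [1, num_kv_heads, 1, head_dim] in non-FA2 layout
    auto [k_cache, v_cache] = cache.updateKVCache(layer, k, v);
    auto attention_output = attention_func(q, k_cache, v_cache);
  }
  // Positions already cached for layer 0 (grows by one per step)
  int32_t cached_len = cache.getCurrentSeqCnt(0);
}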
Dynamic Cache Example¶
#include "mllm/nn/lmcache/DynamicCache.hpp"
auto dynamic_cache = mllm::nn::DynamicCache(num_layers, num_q_heads, num_kv_heads, head_dim);
// Build cache step by step
for (int step = 0; step < max_steps; ++step) {
  for (int layer = 0; layer < num_layers; ++layer) {
    auto [k_cache, v_cache] = dynamic_cache.updateKVCache(layer, k_step, v_step);
    // Process current step
  }
}
Performance Optimization¶
Memory Layout Optimization¶
CPU: Uses memcpy for efficient memory copying
GPU/NPU: Uses the tensor's copy2 method for device-optimized copying operations
GQA Support¶
Grouped Query Attention (GQA) is supported by computing the repeat factor as q_heads / kv_heads, automatically handling the case where the number of key-value heads is smaller than the number of query heads, as illustrated in the sketch below.
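A minimal sketch of the repeat-factor computation (illustrative only, not the library's internal code):
#include <cassert>
#include <cstdint>
// With 32 query heads and 8 KV heads, each cached KV head is shared by 4 query heads
const int32_t num_q_heads  = 32;
const int32_t num_kv_heads = 8;
assert(num_q_heads % num_kv_heads == 0);                   // q_heads must be divisible by kv_heads
const int32_t repeat_factor = num_q_heads / num_kv_heads;  // 4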
Device-Specific Optimization¶
// Dispatch the copy path based on the cache's device type
switch (device_type) {
  // CPU optimization path
  case kCPU: {
    // Use memcpy for block copying
    std::memcpy(cache_ptr, input_ptr, copy_size);
    break;
  }
  // GPU/NPU optimization path
  default: {
    // Use tensor operations for device-optimized copying
    input_tensor.copy2(cache_tensor);
    break;
  }
}
Important Notes¶
Memory Pre-allocation: StaticCache pre-allocates all of its memory during construction, making it suitable for scenarios where the maximum sequence length is known
FA2 Compatibility: Different attention implementations require different tensor layouts; make sure to set the use_fa2 parameter correctly
Device Compatibility: Ensure the cache and the input tensors are on the same device
Data Types: Mixed precision is supported; keys and values can use different data types, as in the sketch below
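A minimal sketch of a mixed-precision cache, assuming kFP32 is an available entry in DataTypes (the other parameters reuse the placeholders from the Basic Usage example):
// FP16 keys with FP32 values (kFP32 is assumed here)
auto mixed_cache = mllm::nn::StaticCache(
    max_seq_len, num_layers, num_q_heads, num_kv_heads, head_dim,
    mllm::DataTypes::kFP16,   // k_dtype
    mllm::DataTypes::kFP32,   // v_dtype (assumed enum value)
    mllm::DeviceTypes::kCPU);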
Error Handling¶
// Check sequence length limits before updating a StaticCache
if (current_seq_len + input_seq_len > max_cache_length) {
  throw std::runtime_error("Sequence length exceeds cache capacity");
}
// Validate tensor shapes (heads are at index 1 in the non-FA2 layout [Batch, Heads, Sequence, Dimension])
MLLM_RT_ASSERT_EQ(k.shape()[1], kv_heads);
MLLM_RT_ASSERT_EQ(v.shape()[1], kv_heads);
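The same guard can be written against the documented API; a sketch in which max_seq_len is the capacity passed to the StaticCache constructor and k uses the non-FA2 layout:
// Guard an update using the cache's own sequence counter
const int32_t incoming_len = k.shape()[2];  // sequence dimension in [Batch, Heads, Sequence, Dimension]
if (cache.getCurrentSeqCnt(layer_idx) + incoming_len > max_seq_len) {
  throw std::runtime_error("Sequence length exceeds cache capacity");
}
auto [k_cached, v_cached] = cache.updateKVCache(layer_idx, k, v);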
Best Practices¶
Choose Appropriate Cache Type:
Use StaticCache for inference to achieve optimal performance
Use DynamicCache for training or variable-length scenarios
Memory Management:
Estimate the maximum sequence length in advance to avoid running out of cache capacity
Consider using SubStaticCache for memory slicing
Performance Tuning:
Choose appropriate data types based on hardware characteristics
Enable FlashAttention2 for better memory efficiency