pymllm.mem_cache.radix_cache
Radix-tree KV cache with SWA and multimodal support.
Supports:
- Multi-batch serving on a single GPU
- Sliding Window Attention (SWA) via a tombstone mechanism
- Multimodal namespace isolation via extra_key
- SHA256 position-aware hashing
- Page-aligned operations (page_size >= 1)
- LRU leaf eviction
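Page alignment can be illustrated with two small helpers. This is a sketch, assuming the cache only stores whole pages so a trailing partial page is left out; page_aligned_len and split_into_pages are hypothetical names, not pymllm functions.

```python
# Hypothetical sketch of page alignment, not pymllm's implementation.
from typing import List, Sequence, Tuple


def page_aligned_len(num_tokens: int, page_size: int) -> int:
    """Round num_tokens down to a whole number of pages."""
    if page_size < 1:
        raise ValueError("page_size must be >= 1")
    return num_tokens // page_size * page_size


def split_into_pages(token_ids: Sequence[int], page_size: int) -> List[Tuple[int, ...]]:
    """Split the page-aligned part of token_ids into one tuple per page.

    Assumption: a trailing partial page is dropped because it cannot be shared.
    """
    n = page_aligned_len(len(token_ids), page_size)
    return [tuple(token_ids[i:i + page_size]) for i in range(0, n, page_size)]


print(split_into_pages([1, 2, 3, 4, 5, 6, 7], page_size=3))  # [(1, 2, 3), (4, 5, 6)]
```

With page_size=1 every token becomes its own page, which matches the constructor default.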
Attributes
logger
Classes
TreeNode | A single node in the radix tree.
RadixCache | Radix tree for KV-cache prefix sharing.
Module Contents
- pymllm.mem_cache.radix_cache.logger
- class pymllm.mem_cache.radix_cache.TreeNode
A single node in the radix tree.
value holds a 1-D int64 tensor of KV-pool indices (one per token in key). When the node has been evicted, value is None.
- __slots__ = ('children', 'parent', 'key', 'value', 'lock_ref', 'swa_lock_ref', 'swa_tombstone', ...
- key: pymllm.mem_cache.base_prefix_cache.RadixKey | None = None
- value: torch.Tensor | None = None
- lock_ref: int = 0
- swa_lock_ref: int = 0
- swa_tombstone: bool = False
- swa_boundary_id: int | None = None
- last_access_time: float
- hit_count: int = 0
- hash_values: List[str] | None = None
- id: int = 1
- property evicted: bool
- Return type:
bool
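To make the fields above concrete, here is a hypothetical, stripped-down node. ToyTreeNode is not the real TreeNode: it keeps only some of the documented fields, keys children by their first token (as if page_size=1), and uses a plain list of ints in place of the int64 index tensor so it runs without torch. Its evicted property assumes the documented convention that an evicted node's value is None.

```python
# Hypothetical sketch, not pymllm's TreeNode.
import itertools
import time
from typing import Dict, List, Optional, Tuple

_next_id = itertools.count(1)  # hypothetical id source for illustration


class ToyTreeNode:
    """Simplified stand-in for TreeNode (subset of the documented fields)."""

    __slots__ = ("children", "parent", "key", "value", "lock_ref",
                 "swa_tombstone", "last_access_time", "hit_count", "id")

    def __init__(self, parent: Optional["ToyTreeNode"] = None) -> None:
        self.children: Dict[int, "ToyTreeNode"] = {}  # first token -> child
        self.parent = parent
        self.key: Tuple[int, ...] = ()  # token ids on the edge into this node
        self.value: Optional[List[int]] = None  # stand-in for the int64 index tensor
        self.lock_ref = 0  # > 0 while some request depends on this node
        self.swa_tombstone = False  # True once SWA KV is freed (full KV kept)
        self.last_access_time = time.monotonic()  # LRU timestamp
        self.hit_count = 0
        self.id = next(_next_id)

    @property
    def evicted(self) -> bool:
        # Assumed to mirror the documented convention: evicted <=> value is None.
        return self.value is None


root = ToyTreeNode()
child = ToyTreeNode(parent=root)
child.key, child.value = (11, 12), [0, 1]  # two tokens -> two KV slots
root.children[11] = child
print(child.evicted)  # False
child.value = None    # simulate eviction
print(child.evicted)  # True
```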
- class pymllm.mem_cache.radix_cache.RadixCache(page_size=1, sliding_window_size=None, token_to_kv_pool_allocator=None, on_node_evict=None)
Bases: pymllm.mem_cache.base_prefix_cache.BasePrefixCache
Radix tree for KV-cache prefix sharing.
- Parameters:
page_size (int) – Number of tokens per KV-pool page. Keys and values are aligned to this granularity.
sliding_window_size (Optional[int]) – If set, enables SWA mode. The cache tracks which nodes have had their SWA KV freed (tombstoned) and constrains prefix matching so that the sliding-window invariant is maintained.
token_to_kv_pool_allocator (Any) – Optional pool allocator with free(indices) (and free_swa for SWA mode). When None, index tensors are simply discarded.
on_node_evict (Optional[Callable[[int], None]]) – Optional callback invoked with the node id when a node is evicted.
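The allocator is duck-typed, so anything exposing free(indices) (plus free_swa in SWA mode) should fit. The sketch below uses a hypothetical RecordingAllocator and hypothetical helpers release_node and tombstone_node to show how the pool and on_node_evict might be used together; the ordering (free first, then notify) is an assumption, not taken from pymllm.

```python
# Hypothetical sketch of the allocator / callback contract, not pymllm code.
from typing import Any, Callable, List, Optional


class RecordingAllocator:
    """Toy allocator that records which KV slots were freed."""

    def __init__(self) -> None:
        self.freed: List[int] = []
        self.swa_freed: List[int] = []

    def free(self, indices: List[int]) -> None:
        self.freed.extend(indices)

    def free_swa(self, indices: List[int]) -> None:  # only needed in SWA mode
        self.swa_freed.extend(indices)


def release_node(node_id: int, indices: List[int], pool: Optional[Any],
                 on_node_evict: Optional[Callable[[int], None]]) -> None:
    """Free an evicted node's KV slots, then notify the optional callback."""
    if pool is not None:
        pool.free(indices)
    # With pool=None the indices are simply dropped, as documented.
    if on_node_evict is not None:
        on_node_evict(node_id)


def tombstone_node(indices: List[int], pool: Optional[Any]) -> None:
    """SWA tombstoning: release only the sliding-window KV (assumed via free_swa)."""
    if pool is not None:
        pool.free_swa(indices)


pool, evicted_ids = RecordingAllocator(), []
release_node(7, [3, 4], pool, evicted_ids.append)
tombstone_node([5], pool)
print(pool.freed, pool.swa_freed, evicted_ids)  # [3, 4] [5] [7]
```

Passing None for both the pool and the callback is also valid in this sketch; the indices are then just discarded.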
- page_size = 1
- sliding_window_size = None
- pool = None
- on_node_evict = None
- property supports_swa: bool
- Return type:
bool
- evictable_size()
- Return type:
int
- swa_evictable_size()
- Return type:
int
- protected_size()
- Return type:
int
- swa_protected_size()
- Return type:
int
- total_size()
Total number of cached tokens (including tombstoned).
- Return type:
int
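The size accessors are not described here, so the following is only one plausible reading: evictable and protected split cached tokens by whether lock_ref is zero, the SWA variants do the same with swa_lock_ref over non-tombstoned nodes, and total counts everything including tombstones. The hypothetical size_stats recomputes these by traversal for clarity; a real cache would more likely keep running counters.

```python
# Hypothetical sketch of the size accounting (assumed semantics, not pymllm code).
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    num_tokens: int
    lock_ref: int = 0
    swa_lock_ref: int = 0
    swa_tombstone: bool = False
    children: List["Node"] = field(default_factory=list)


def size_stats(root: Node) -> Dict[str, int]:
    """Recompute the size counters by walking the tree."""
    stats = dict.fromkeys(
        ["evictable", "protected", "swa_evictable", "swa_protected", "total"], 0)
    stack = list(root.children)  # the root itself holds no tokens
    while stack:
        node = stack.pop()
        stack.extend(node.children)
        n = node.num_tokens
        stats["total"] += n  # tombstoned nodes still hold full-attn KV
        stats["protected" if node.lock_ref > 0 else "evictable"] += n
        if not node.swa_tombstone:  # tombstoned nodes hold no SWA KV
            key = "swa_protected" if node.swa_lock_ref > 0 else "swa_evictable"
            stats[key] += n
    return stats


root = Node(0, children=[
    Node(4, lock_ref=1, swa_lock_ref=1, children=[Node(2)]),
    Node(3, swa_tombstone=True),
])
print(size_stats(root))
# {'evictable': 5, 'protected': 4, 'swa_evictable': 2, 'swa_protected': 4, 'total': 9}
```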
- reset()
Clear all cached state and re-initialise the root node.
- Return type:
None
- match_prefix(key)
Find the longest cached prefix of key.
For SWA mode the match is further constrained: the path from the returned last_node to root must have at least sliding_window_size non-tombstone tokens (or be entirely tombstone-free back to root).
Accessing a prefix refreshes LRU timestamps along the matched path.
- Parameters:
- Return type:
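A minimal sketch of the lookup follows, assuming page_size=1 and a compressed trie whose children are keyed by first token. match_prefix here is a hypothetical stand-in, not the pymllm method: it gathers KV indices along the matched path and refreshes last_access_time on every node it touches, but unlike a real radix cache it does not split a node when the match ends mid-edge. The hypothetical swa_valid_len encodes one reading of the SWA rule, in which the non-tombstone tokens must be contiguous and sit directly before the cut point.

```python
# Hypothetical sketch of prefix matching, not pymllm's implementation.
import time
from typing import Dict, List, Optional, Sequence, Tuple


class Node:
    """Compressed-trie node; page_size = 1 for simplicity."""

    def __init__(self, key: Sequence[int] = (), value: Sequence[int] = (),
                 parent: Optional["Node"] = None) -> None:
        self.key = tuple(key)      # token ids on the edge into this node
        self.value = list(value)   # one KV-pool index per token in key
        self.parent = parent
        self.children: Dict[int, "Node"] = {}
        self.last_access_time = 0.0
        if parent is not None:
            parent.children[self.key[0]] = self


def match_prefix(root: Node, key: Sequence[int]) -> Tuple[List[int], Node]:
    """Return (KV indices of the longest cached prefix, deepest node touched)."""
    now = time.monotonic()
    node, indices, pos = root, [], 0
    while pos < len(key) and key[pos] in node.children:
        child = node.children[key[pos]]
        n = 0  # length of the common prefix of child.key and key[pos:]
        while n < len(child.key) and pos + n < len(key) and child.key[n] == key[pos + n]:
            n += 1
        child.last_access_time = now  # refresh LRU along the matched path
        indices.extend(child.value[:n])
        pos, node = pos + n, child
        if n < len(child.key):
            break  # match ends mid-edge; a real radix cache would split here
    return indices, node


def swa_valid_len(path: List[Tuple[int, bool]], window: int) -> int:
    """Deepest node boundary satisfying the assumed SWA invariant.

    path holds (num_tokens, is_tombstone) for each matched node, root first.
    A boundary is kept if no tombstone precedes it, or if at least `window`
    non-tombstone tokens sit between the last tombstone and the boundary.
    """
    best = total = since_tombstone = 0
    seen_tombstone = False
    for num_tokens, is_tombstone in path:
        total += num_tokens
        if is_tombstone:
            seen_tombstone, since_tombstone = True, 0
            continue
        since_tombstone += num_tokens
        if not seen_tombstone or since_tombstone >= window:
            best = total
    return best


root = Node()
a = Node((1, 2, 3), [10, 11, 12], parent=root)
b = Node((4, 5), [13, 14], parent=a)
print(match_prefix(root, [1, 2, 3, 4, 9])[0])  # [10, 11, 12, 13]
print(swa_valid_len([(4, False), (2, True), (3, False), (5, False)], window=6))  # 14
```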
- insert(key, value=None, *, prev_prefix_len=0, swa_evicted_seqlen=0, **kwargs)
Insert key/value into the tree.
Returns how many leading tokens were already present (the prefix length). The caller is responsible for freeing duplicate KV indices in the range [cache_protected_len, prefix_len).
- Parameters:
prev_prefix_len (int) – (SWA mode) tokens before this offset are already protected and should not have their values overwritten.
swa_evicted_seqlen (int) – (SWA mode) the sequence length up to which SWA KV has been previously evicted. Used to decide whether a tombstoned node can be un-tombstoned with the incoming value.
value (Optional[torch.Tensor])
kwargs (Any)
- Return type:
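The prefix-length contract can be sketched with an uncompressed per-token trie. toy_insert and cache_request are hypothetical; the sketch assumes cache_protected_len is the number of leading tokens the request had already matched (and locked) in the cache before insertion, so only the slots in [cache_protected_len, prefix_len) are redundant copies that the caller frees. The SWA-specific parameters are not modelled.

```python
# Hypothetical sketch of insert() and the caller's duplicate-freeing duty.
from typing import Callable, Dict, List


class TrieNode:
    """Per-token trie node (no edge compression, page_size = 1)."""

    def __init__(self, kv_index: int = -1) -> None:
        self.children: Dict[int, "TrieNode"] = {}
        self.kv_index = kv_index  # KV-pool slot owned by the tree for this token


def toy_insert(root: TrieNode, key: List[int], value: List[int]) -> int:
    """Insert key/value and return how many leading tokens were already cached."""
    node, prefix_len = root, 0
    for token, kv in zip(key, value):
        child = node.children.get(token)
        if child is None:
            # New token: the tree takes ownership of this KV slot. Once one
            # token is new, every later token on this path is new as well.
            child = TrieNode(kv)
            node.children[token] = child
        else:
            prefix_len += 1  # already cached; the incoming kv is a duplicate
        node = child
    return prefix_len


def cache_request(root: TrieNode, token_ids: List[int], kv_indices: List[int],
                  cache_protected_len: int, free: Callable[[List[int]], None]) -> int:
    """Caller-side pattern: insert, then free the duplicate slots."""
    prefix_len = toy_insert(root, token_ids, kv_indices)
    # Slots in [cache_protected_len, prefix_len) duplicate KV the tree already
    # owns, so the caller releases them. Slots before cache_protected_len came
    # from an earlier cache hit and already belong to the tree.
    free(kv_indices[cache_protected_len:prefix_len])
    return prefix_len


freed: List[int] = []
root = TrieNode()
cache_request(root, [1, 2, 3], [0, 1, 2], 0, freed.extend)        # nothing cached yet
cache_request(root, [1, 2, 3, 4], [0, 1, 7, 8], 2, freed.extend)  # 2 tokens were a hit
print(freed)  # [7]
```

Slot 7 is freed because token 3 already maps to slot 2 in the tree, while slot 8 for the new token 4 is kept.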
- evict(num_tokens, swa_num_tokens=0)
Evict up to num_tokens (full) and swa_num_tokens (SWA) tokens.
Full eviction removes leaf nodes entirely; SWA eviction tombstones internal nodes (freeing SWA KV but retaining full-attn KV).
- Parameters:
num_tokens (int)
swa_num_tokens (int)
- Return type:
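Full eviction (LRU leaf eviction, per the module summary) can be sketched with a min-heap keyed on last_access_time. Node and evict_lru_leaves are hypothetical; the sketch covers only full eviction, not SWA tombstoning, and it re-queues a parent once its last child is gone so eviction can proceed toward the root.

```python
# Hypothetical sketch of LRU leaf eviction, not pymllm's evict().
import heapq
import itertools
from typing import Callable, Dict, List, Optional, Sequence


class Node:
    """Radix-tree node used only for this eviction sketch."""

    def __init__(self, key: Sequence[int] = (), value: Sequence[int] = (),
                 parent: Optional["Node"] = None, last_access_time: float = 0.0) -> None:
        self.key = tuple(key)
        self.value: Optional[List[int]] = list(value)
        self.parent = parent
        self.children: Dict[int, "Node"] = {}
        self.lock_ref = 0
        self.last_access_time = last_access_time
        if parent is not None:
            parent.children[self.key[0]] = self


def evict_lru_leaves(root: Node, num_tokens: int,
                     free: Callable[[List[int]], None]) -> int:
    """Remove unlocked leaves, least recently used first, until num_tokens are freed."""
    tie = itertools.count()  # tie-breaker so the heap never compares Node objects
    heap = []
    stack = list(root.children.values())
    while stack:  # collect the current unlocked leaves
        node = stack.pop()
        stack.extend(node.children.values())
        if not node.children and node.lock_ref == 0:
            heap.append((node.last_access_time, next(tie), node))
    heapq.heapify(heap)

    freed = 0
    while freed < num_tokens and heap:
        _, _, leaf = heapq.heappop(heap)
        free(leaf.value)
        freed += len(leaf.value)
        leaf.value = None  # mark evicted
        parent = leaf.parent
        del parent.children[leaf.key[0]]
        # A parent that just lost its last child becomes an eviction candidate.
        if parent is not root and not parent.children and parent.lock_ref == 0:
            heapq.heappush(heap, (parent.last_access_time, next(tie), parent))
    return freed


root = Node()
a = Node((1, 2), [10, 11], root, last_access_time=1.0)
b = Node((3,), [12], a, last_access_time=5.0)
c = Node((4, 5), [13, 14], a, last_access_time=2.0)
d = Node((9,), [15], root, last_access_time=3.0)
d.lock_ref = 1  # in use: never evicted
log: List[int] = []
print(evict_lru_leaves(root, 3, log.extend), log)  # 3 [13, 14, 12]
```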
- inc_lock_ref(node)
Lock nodes from node up to root (prevents eviction).
Returns swa_boundary_id, which must be passed back to dec_lock_ref(). In non-SWA mode, returns None.
- Parameters:
node (TreeNode)
- Return type:
Optional[int]
- dec_lock_ref(node, swa_boundary_id=None, **kwargs)
Unlock nodes from node up to root.
- Parameters:
node (TreeNode)
swa_boundary_id (Optional[int])
kwargs (Any)
- Return type:
None
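Lock reference counting can be sketched as a walk from the node to the root. This is a hypothetical sketch: whether the root itself is counted, and how swa_boundary_id limits SWA locking to the window, are not modelled. It also assumes that the evictable and protected counters move a node's tokens only on the 0 -> 1 and 1 -> 0 lock transitions.

```python
# Hypothetical sketch of inc/dec lock-ref bookkeeping, not pymllm code.
from typing import Optional


class Node:
    """Node carrying just what lock accounting needs."""

    def __init__(self, num_tokens: int = 0, parent: Optional["Node"] = None) -> None:
        self.num_tokens = num_tokens
        self.parent = parent
        self.lock_ref = 0


class Sizes:
    def __init__(self, evictable: int = 0) -> None:
        self.evictable = evictable  # tokens in unlocked nodes
        self.protected = 0          # tokens in locked nodes


def inc_lock_ref(node: Node, root: Node, sizes: Sizes) -> None:
    """Pin every node on the path from node up to (not including) root."""
    while node is not root:
        if node.lock_ref == 0:  # first lock moves the node's tokens to protected
            sizes.evictable -= node.num_tokens
            sizes.protected += node.num_tokens
        node.lock_ref += 1
        node = node.parent


def dec_lock_ref(node: Node, root: Node, sizes: Sizes) -> None:
    """Undo one inc_lock_ref on the same path."""
    while node is not root:
        if node.lock_ref <= 0:
            raise RuntimeError("dec_lock_ref without matching inc_lock_ref")
        node.lock_ref -= 1
        if node.lock_ref == 0:  # last unlock makes the tokens evictable again
            sizes.protected -= node.num_tokens
            sizes.evictable += node.num_tokens
        node = node.parent


root = Node()
a = Node(4, root)
b = Node(2, a)
sizes = Sizes(evictable=6)
inc_lock_ref(b, root, sizes)  # request A uses the prefix ending at b
inc_lock_ref(a, root, sizes)  # request B uses the shorter prefix ending at a
dec_lock_ref(b, root, sizes)  # request A finishes
print(sizes.evictable, sizes.protected)  # 2 4
```

Because locks are reference counts, the shared node a stays protected until both requests have released it.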
- compute_node_hash(node)
Compute position-aware SHA-256 hashes for node (one per page).
Lazily computed and cached on node.hash_values.
- Parameters:
node (TreeNode)
- Return type:
List[str]
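One common way to make page hashes position-aware is to chain them: each page's digest covers the previous page's digest plus the page's tokens, optionally tagged with an extra key for namespace isolation. The page_hashes function below is a hypothetical illustration built on hashlib, assuming int64 token ids; pymllm's exact byte layout is not documented here.

```python
# Hypothetical chained page hashing, not pymllm's compute_node_hash().
import hashlib
import struct
from typing import List, Optional, Sequence


def page_hashes(token_ids: Sequence[int], page_size: int, parent_hash: str = "",
                extra_key: Optional[str] = None) -> List[str]:
    """Chained SHA-256 hash for each full page of token_ids.

    Folding the previous page's digest into each page's hash makes the result
    position-aware: the same tokens after a different prefix hash differently.
    """
    hashes: List[str] = []
    prev = parent_hash
    n_full = len(token_ids) // page_size * page_size  # ignore a trailing partial page
    for start in range(0, n_full, page_size):
        h = hashlib.sha256()
        h.update(prev.encode("ascii"))
        if extra_key is not None:  # e.g. a multimodal namespace tag
            tag = extra_key.encode("utf-8")
            h.update(struct.pack("<q", len(tag)) + tag)  # length-prefixed
        for token in token_ids[start:start + page_size]:
            h.update(struct.pack("<q", token))  # 8-byte little-endian int64
        prev = h.hexdigest()
        hashes.append(prev)
    return hashes


first = page_hashes([1, 2, 1, 2], page_size=2)
print(first[0] == first[1])  # False: same tokens, different position
```

Chaining also means a child node's hashes can be computed from its parent's last digest alone, which fits the lazy, per-node caching described above.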
- pretty_print()
Print the tree structure to stdout.
- Return type:
None