pymllm.mem_cache.mamba_radix_cache
Radix-tree KV cache with independent Mamba/SSM state tracking.
Extends RadixCache with dual-tracked state for hybrid models that combine full-attention layers and SSM (Mamba / GDN) layers. Each tree node stores both:
- value: KV-pool indices for full-attention layers
- mamba_value: state-pool indices for SSM layers
The two pools have independent reference counting and LRU eviction, so Mamba state can be evicted more aggressively than the full KV cache.
Reference: sglang MambaRadixCache.
Attributes
logger
Classes
MambaTreeNode | Tree node with dual KV + Mamba state tracking.
LRUList | Intrusive doubly-linked list for LRU ordering.
MambaRadixCache | Radix tree with independent Mamba/SSM state tracking.
Module Contents
- pymllm.mem_cache.mamba_radix_cache.logger
- class pymllm.mem_cache.mamba_radix_cache.MambaTreeNode
Tree node with dual KV + Mamba state tracking.
Invariant: full_lock_ref >= mamba_lock_ref. If the Mamba state is locked, the full KV must also be locked; the full KV can be locked on its own without locking the Mamba state.
- __slots__ = ('children', 'parent', 'key', 'value', 'mamba_value', 'full_lock_ref', 'mamba_lock_ref',...
- children: Dict[Any, MambaTreeNode]
- parent: MambaTreeNode | None = None
- key: pymllm.mem_cache.base_prefix_cache.RadixKey | None = None
- value: torch.Tensor | None = None
- mamba_value: torch.Tensor | None = None
- full_lock_ref: int = 0
- mamba_lock_ref: int = 0
- last_access_time: float
- hit_count: int = 0
- id: int = 1
- prev: MambaTreeNode | None = None
- next: MambaTreeNode | None = None
- mamba_prev: MambaTreeNode | None = None
- mamba_next: MambaTreeNode | None = None
- property evicted: bool
- Return type:
bool
- property mamba_tombstone: bool
Node still has its full KV, but its Mamba state was evicted.
- Return type:
bool
- __lt__(other)
- Parameters:
other (MambaTreeNode)
- Return type:
bool
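To make the two eviction states concrete, here is a minimal sketch of a node carrying the same dual fields. `ToyMambaNode` and `check_lock_invariant` are hypothetical names, plain lists stand in for `torch.Tensor`, and `evicted` is assumed to mean "the full-KV value is None"; only `mamba_tombstone` and the lock invariant follow the documented semantics.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass(eq=False)
class ToyMambaNode:
    """Hypothetical stand-in for MambaTreeNode (lists replace torch tensors)."""

    parent: Optional[ToyMambaNode] = None
    children: Dict[Any, ToyMambaNode] = field(default_factory=dict)
    value: Optional[List[int]] = None        # full-attention KV-pool indices
    mamba_value: Optional[List[int]] = None  # SSM state-pool indices
    full_lock_ref: int = 0
    mamba_lock_ref: int = 0

    @property
    def evicted(self) -> bool:
        # Assumption: "evicted" means the full-KV indices are gone.
        return self.value is None

    @property
    def mamba_tombstone(self) -> bool:
        # Documented meaning: full KV still cached, Mamba state evicted.
        return self.value is not None and self.mamba_value is None

    def check_lock_invariant(self) -> bool:
        # Documented invariant: a Mamba lock implies a full-KV lock.
        return self.full_lock_ref >= self.mamba_lock_ref


node = ToyMambaNode(value=[10, 11, 12], mamba_value=[3])
node.full_lock_ref = 1               # full KV locked on its own: allowed
print(node.check_lock_invariant())   # True
node.mamba_value = None              # evict only the SSM state
print(node.mamba_tombstone, node.evicted)  # True False
```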
- class pymllm.mem_cache.mamba_radix_cache.LRUList(mamba=False)
Intrusive doubly-linked list for LRU ordering.
Supports two modes via the mamba flag: it threads through either the prev/next or the mamba_prev/mamba_next pointers on MambaTreeNode.
- Parameters:
mamba (bool)
- mamba = False
- head
- tail
- __len__()
- Return type:
int
- __contains__(node)
- Parameters:
node (Optional[MambaTreeNode])
- Return type:
bool
- insert_mru(node)
Insert node at the MRU (head) position.
- Parameters:
node (MambaTreeNode)
- Return type:
None
- remove(node)
Remove node from the list.
- Parameters:
node (MambaTreeNode)
- Return type:
None
- touch_mru(node)
Move an existing node to the MRU position.
- Parameters:
node (MambaTreeNode)
- Return type:
None
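Because the list is intrusive, one node object can sit in the full-KV list and the Mamba list at the same time, each list threading through its own pointer pair. A minimal sketch, assuming sentinel head/tail nodes with the MRU end at the head; `ToyLRUList`, `Node`, and `order()` are hypothetical, and treating a non-None prev pointer as membership is a design choice of this sketch.

```python
from typing import List, Optional


class Node:
    """Hypothetical node carrying both pointer pairs, like MambaTreeNode."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.prev: Optional["Node"] = None
        self.next: Optional["Node"] = None
        self.mamba_prev: Optional["Node"] = None
        self.mamba_next: Optional["Node"] = None


class ToyLRUList:
    """Sentinel-based intrusive list: head side is MRU, tail side is LRU."""

    def __init__(self, mamba: bool = False) -> None:
        # The mamba flag only selects which pointer pair this list threads through.
        self._p, self._n = ("mamba_prev", "mamba_next") if mamba else ("prev", "next")
        self.head, self.tail = Node("<head>"), Node("<tail>")
        setattr(self.head, self._n, self.tail)
        setattr(self.tail, self._p, self.head)
        self._size = 0

    def __len__(self) -> int:
        return self._size

    def __contains__(self, node: Optional[Node]) -> bool:
        # A linked node always has a predecessor (at least the head sentinel).
        return node is not None and getattr(node, self._p) is not None

    def insert_mru(self, node: Node) -> None:
        first = getattr(self.head, self._n)
        setattr(node, self._p, self.head)
        setattr(node, self._n, first)
        setattr(first, self._p, node)
        setattr(self.head, self._n, node)
        self._size += 1

    def remove(self, node: Node) -> None:
        before, after = getattr(node, self._p), getattr(node, self._n)
        setattr(before, self._n, after)
        setattr(after, self._p, before)
        setattr(node, self._p, None)
        setattr(node, self._n, None)
        self._size -= 1

    def touch_mru(self, node: Node) -> None:
        self.remove(node)
        self.insert_mru(node)

    def order(self) -> List[str]:
        # Names from MRU to LRU, for inspection only.
        names, cur = [], getattr(self.head, self._n)
        while cur is not self.tail:
            names.append(cur.name)
            cur = getattr(cur, self._n)
        return names


a, b, c = Node("a"), Node("b"), Node("c")
full, mamba = ToyLRUList(), ToyLRUList(mamba=True)
for n in (a, b, c):
    full.insert_mru(n)
mamba.insert_mru(a)       # same object, independent pointer pair
full.touch_mru(a)
print(full.order(), mamba.order())  # ['a', 'c', 'b'] ['a']
```

All operations are O(1) because each node carries its own links; no separate dictionary or list lookup is needed to unlink it.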
- touch_node_and_parents_mru(node, root)
Move node and all of its ancestors up to root to the MRU end, ordered so that each child is more recently used than its parent.
- Parameters:
node (MambaTreeNode)
root (MambaTreeNode)
- Return type:
None
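The ordering rule can be sketched with a plain Python list standing in for the linked list (index 0 is MRU). The function mirrors the documented name but is a hypothetical reimplementation; excluding `root` itself from the list is an assumption.

```python
from typing import List, Optional


class Node:
    def __init__(self, name: str, parent: Optional["Node"] = None) -> None:
        self.name, self.parent = name, parent


def touch_node_and_parents_mru(order: List[Node], node: Node, root: Node) -> None:
    """order[0] is MRU. Move node and its ancestors (root excluded) to the
    front as one block: [node, parent, grandparent, ...]."""
    chain: List[Node] = []
    cur: Optional[Node] = node
    while cur is not None and cur is not root:
        chain.append(cur)
        cur = cur.parent
    for n in chain:
        if n in order:
            order.remove(n)
    # Child first, so every child ends up more recently used than its parent.
    order[:0] = chain


root = Node("root")
p = Node("p", root)
c = Node("c", p)
x = Node("x", root)
order = [x, p, c]          # x is MRU, c is LRU
touch_node_and_parents_mru(order, c, root)
print([n.name for n in order])  # ['c', 'p', 'x']
```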
- get_lru_leaf_unlocked()
Return the least-recently-used leaf node with lock_ref == 0, or None if there is none.
- Return type:
Optional[MambaTreeNode]
- get_lru_unlocked()
Return the least-recently-used node with lock_ref == 0, or None if there is none.
- Return type:
Optional[MambaTreeNode]
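Both lookups scan from the LRU end and skip locked nodes; the leaf variant also skips nodes that still have children. A minimal sketch over a Python list (index 0 is MRU), assuming a single hypothetical `lock_ref` field stands in for whichever counter (`full_lock_ref` or `mamba_lock_ref`) the list's mode checks.

```python
from typing import Dict, List, Optional


class Node:
    def __init__(self, name: str, lock_ref: int = 0) -> None:
        self.name = name
        self.lock_ref = lock_ref                # hypothetical single lock counter
        self.children: Dict[str, "Node"] = {}


def get_lru_unlocked(order: List[Node]) -> Optional[Node]:
    # order[0] is MRU, order[-1] is LRU: scan from the LRU end.
    for node in reversed(order):
        if node.lock_ref == 0:
            return node
    return None


def get_lru_leaf_unlocked(order: List[Node]) -> Optional[Node]:
    for node in reversed(order):
        if node.lock_ref == 0 and not node.children:
            return node
    return None


parent, leaf_locked, leaf = Node("parent"), Node("leaf_locked", lock_ref=1), Node("leaf")
parent.children = {"a": leaf_locked, "b": leaf}
order = [leaf, leaf_locked, parent]   # children more recent than their parent
print(get_lru_unlocked(order).name, get_lru_leaf_unlocked(order).name)  # parent leaf
```

A linear scan keeps the sketch short; the worst case is O(n) when many nodes at the LRU end are locked or internal.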
- class pymllm.mem_cache.mamba_radix_cache.MambaRadixCache(page_size=1, token_to_kv_pool_allocator=None, mamba_pool=None, on_node_evict=None)
Bases: pymllm.mem_cache.base_prefix_cache.BasePrefixCache
Radix tree with independent Mamba/SSM state tracking.
- Parameters:
page_size (int) – Number of tokens per KV-pool page.
token_to_kv_pool_allocator (Any) – Pool allocator for full-attention KV indices.
mamba_pool (Any) – Pool object for Mamba/SSM state. Must support alloc_track_slot(), free_track_slot(slot), and copy_states(src, dst).
on_node_evict (Optional[Callable[[int], None]]) – Optional callback invoked with the node id on eviction.
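Because `mamba_pool` is duck-typed (`Any`), a `typing.Protocol` is one way to write down the expected surface. Only the three method names come from the parameter description; the integer slot type, `alloc_track_slot()` returning `None` when the pool is exhausted, and the copy direction of `copy_states(src, dst)` are assumptions. `MambaPoolLike`, `ToyMambaPool`, and `snapshot` are hypothetical.

```python
from typing import List, Optional, Protocol


class MambaPoolLike(Protocol):
    """Hypothetical structural type; method names from the docs, signatures assumed."""

    def alloc_track_slot(self) -> Optional[int]: ...
    def free_track_slot(self, slot: int) -> None: ...
    def copy_states(self, src: int, dst: int) -> None: ...


class ToyMambaPool:
    """In-memory stand-in: each slot holds one opaque state (a list of floats)."""

    def __init__(self, num_slots: int) -> None:
        self.free: List[int] = list(range(num_slots))
        self.states: List[List[float]] = [[0.0] for _ in range(num_slots)]

    def alloc_track_slot(self) -> Optional[int]:
        # Assumption: None signals an exhausted pool.
        return self.free.pop() if self.free else None

    def free_track_slot(self, slot: int) -> None:
        self.free.append(slot)

    def copy_states(self, src: int, dst: int) -> None:
        # Copy, not alias, so later updates to src do not leak into dst.
        self.states[dst] = list(self.states[src])


def snapshot(pool: MambaPoolLike, live_slot: int) -> Optional[int]:
    """Hypothetical helper: copy a request's live SSM state into a fresh slot."""
    slot = pool.alloc_track_slot()
    if slot is not None:
        pool.copy_states(live_slot, slot)
    return slot


pool = ToyMambaPool(num_slots=2)
live = pool.alloc_track_slot()      # slot 1: a running request's state
pool.states[live] = [1.5]
saved = snapshot(pool, live)        # slot 0: a copy the tree could own
pool.states[live][0] = 9.0          # the request keeps updating its own slot
print(saved, pool.states[saved])    # 0 [1.5]
```

Copying into a separate slot rather than aliasing the live one presumably matters because a running request keeps updating its own SSM state, as the last assignment in the example simulates.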
- page_size = 1
- pool = None
- mamba_pool = None
- on_node_evict = None
- full_lru
- mamba_lru
- evictable_size()
- Return type:
int
- protected_size()
- Return type:
int
- mamba_evictable_size()
- Return type:
int
- mamba_protected_size()
- Return type:
int
- total_size()
- Return type:
int
- reset()
Clear all cached state and re-initialise.
- Return type:
None
- match_prefix(key)
Find the longest cached prefix. Also returns mamba_branching_seqlen.
- Parameters:
- Return type:
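For hybrid models a cached KV prefix is only reusable if the SSM layers can also resume at the same position, since their recurrent state summarises every earlier token. The sketch below assumes the match is therefore cut back to the deepest matched node that still holds Mamba state. It uses a per-token trie instead of a compressed radix tree for brevity, all names are hypothetical, and it does not model `mamba_branching_seqlen`.

```python
from typing import Dict, List, Optional, Tuple


class Node:
    def __init__(self, parent: Optional["Node"] = None) -> None:
        self.parent = parent
        self.children: Dict[int, "Node"] = {}
        self.kv: Optional[int] = None           # one KV index per token (toy)
        self.mamba_value: Optional[int] = None  # SSM state slot, if kept


def insert(root: Node, tokens: List[int], kv: List[int], mamba_slot: int) -> None:
    node = root
    for tok, idx in zip(tokens, kv):
        node = node.children.setdefault(tok, Node(node))
        node.kv = idx
    node.mamba_value = mamba_slot  # state summarises the whole prefix


def match_prefix(root: Node, tokens: List[int]) -> Tuple[List[int], Optional[int]]:
    """Return (kv indices, mamba slot) for the longest matched prefix whose end
    node still holds Mamba state; KV matched beyond that point is dropped."""
    node, kv = root, []
    best_len, best_slot = 0, None
    for tok in tokens:
        if tok not in node.children:
            break
        node = node.children[tok]
        kv.append(node.kv)
        if node.mamba_value is not None:
            best_len, best_slot = len(kv), node.mamba_value
    return kv[:best_len], best_slot


root = Node()
insert(root, [1, 2], kv=[100, 101], mamba_slot=7)
insert(root, [1, 2, 3, 4], kv=[100, 101, 102, 103], mamba_slot=8)
print(match_prefix(root, [1, 2, 3, 9]))  # ([100, 101], 7)
```

Token 3 matches in the KV sense, but no Mamba state exists at that depth, so the usable prefix stops at token 2.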
- insert(key, value=None, *, mamba_value=None, **kwargs)
Insert key together with both its full KV value and its Mamba state value.
- Parameters:
value (Optional[torch.Tensor])
mamba_value (Optional[torch.Tensor])
kwargs (Any)
- Return type:
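A minimal sketch of dual-valued insertion, assuming one KV index per token, Mamba state attached only to the node that ends the key, and existing KV indices winning over duplicates. The return value (the count of already-cached tokens) is a hypothetical convention of this sketch, not the documented return type, and all names are hypothetical.

```python
from typing import Dict, List, Optional


class Node:
    def __init__(self) -> None:
        self.children: Dict[int, "Node"] = {}
        self.value: Optional[int] = None        # toy full-KV index for one token
        self.mamba_value: Optional[int] = None  # toy SSM state slot


def insert(root: Node, key: List[int], value: List[int], mamba_value: Optional[int]) -> int:
    """Hypothetical per-token insert. Returns how many tokens were already
    cached, so a caller could free its now-duplicate KV indices."""
    node, reused = root, 0
    for tok, idx in zip(key, value):
        child = node.children.get(tok)
        if child is None:
            child = node.children[tok] = Node()
        if child.value is None:
            child.value = idx   # new (or KV-evicted) token: adopt caller's index
        else:
            reused += 1         # already cached: the existing index wins
        node = child
    if mamba_value is not None:
        # Only the end node gets SSM state: it summarises the whole key.
        node.mamba_value = mamba_value
    return reused


root = Node()
print(insert(root, [5, 6], [10, 11], mamba_value=1))         # 0
print(insert(root, [5, 6, 7], [20, 21, 22], mamba_value=2))  # 2
```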
- evict(num_tokens, swa_num_tokens=0)
Evict full KV and/or Mamba state tokens.
Phase 1: evict full-KV leaves, which frees both their KV and their Mamba state.
Phase 2: evict Mamba state from internal nodes, turning them into Mamba tombstones (their full KV stays cached).
- Parameters:
num_tokens (int)
swa_num_tokens (int)
- Return type:
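The two phases can be sketched as follows. The sketch is hypothetical: boolean flags stand in for `value`/`mamba_value`, a single `lock_ref` replaces the two lock counters, and the `mamba_states` budget for phase 2 is an invented parameter, since the docs do not say how `num_tokens` and `swa_num_tokens` map onto the phases.

```python
from typing import Dict, List, Optional


class Node:
    def __init__(self, name: str, tokens: int, parent: Optional["Node"] = None) -> None:
        self.name, self.tokens, self.parent = name, tokens, parent
        self.children: Dict[str, "Node"] = {}
        self.has_kv = True      # toy flag standing in for "value is not None"
        self.has_mamba = True   # toy flag standing in for "mamba_value is not None"
        self.lock_ref = 0
        if parent is not None:
            parent.children[name] = self


def evict(lru_order: List[Node], num_tokens: int, mamba_states: int) -> List[str]:
    """Hypothetical two-phase eviction; lru_order[0] is MRU, lru_order[-1] is LRU.
    The root is not in lru_order, so every listed node has a parent."""
    log: List[str] = []
    # Phase 1: remove whole unlocked leaves (KV + Mamba state), LRU first.
    freed = 0
    while freed < num_tokens:
        leaf = next((n for n in reversed(lru_order)
                     if n.lock_ref == 0 and not n.children), None)
        if leaf is None:
            break
        leaf.has_kv = leaf.has_mamba = False
        del leaf.parent.children[leaf.name]
        lru_order.remove(leaf)
        freed += leaf.tokens
        log.append(f"leaf {leaf.name}")
    # Phase 2: drop Mamba state from unlocked internal nodes; KV stays cached.
    dropped = 0
    for node in reversed(lru_order):
        if dropped >= mamba_states:
            break
        if node.lock_ref == 0 and node.children and node.has_mamba:
            node.has_mamba = False   # node becomes a Mamba tombstone
            dropped += 1
            log.append(f"tombstone {node.name}")
    return log


root = Node("root", 0)
a = Node("a", 2, root)
b = Node("b", 3, a)
c = Node("c", 4, a)
order = [c, b, a]                 # children more recent than their parent
print(evict(order, num_tokens=3, mamba_states=1))  # ['leaf b', 'tombstone a']
```

Re-selecting the LRU leaf on every iteration lets a parent become evictable in phase 1 once its last child is gone.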
- inc_lock_ref(node)
Lock the full KV from node up to the root, and the Mamba state of node itself.
The full lock propagates to every ancestor up to the root; the Mamba lock applies only to node, not to its ancestors.
- Parameters:
node (MambaTreeNode)
- Return type:
Optional[Any]
- dec_lock_ref(node, **kwargs)
Unlock full KV and Mamba state (the counterpart of inc_lock_ref()).
- Parameters:
node (MambaTreeNode)
kwargs (Any)
- Return type:
None
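A minimal sketch of the asymmetric locking, assuming the root itself is never locked and that `dec_lock_ref()` exactly reverses `inc_lock_ref()`. The function names mirror the documented methods, but the bodies are hypothetical.

```python
from typing import Optional


class Node:
    def __init__(self, parent: Optional["Node"] = None) -> None:
        self.parent = parent
        self.full_lock_ref = 0
        self.mamba_lock_ref = 0


def inc_lock_ref(node: Node, root: Node) -> None:
    node.mamba_lock_ref += 1          # Mamba lock: this node only
    cur: Optional[Node] = node
    while cur is not None and cur is not root:
        cur.full_lock_ref += 1        # full lock: node and every ancestor below root
        cur = cur.parent


def dec_lock_ref(node: Node, root: Node) -> None:
    node.mamba_lock_ref -= 1
    cur: Optional[Node] = node
    while cur is not None and cur is not root:
        cur.full_lock_ref -= 1
        cur = cur.parent


root = Node()
parent = Node(root)
child = Node(parent)
inc_lock_ref(child, root)
print(parent.full_lock_ref, parent.mamba_lock_ref,
      child.full_lock_ref, child.mamba_lock_ref)  # 1 0 1 1
```

The asymmetry follows from the model types: attention layers need KV for every earlier token, so the whole path must stay pinned, whereas SSM layers only need the recurrent state at the resume point. It also keeps `full_lock_ref >= mamba_lock_ref` true on every node.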
- pretty_print()
Print the tree structure to stdout.
- Return type:
None