pymllm.mem_cache.radix_cache

Radix-tree KV cache with SWA and multimodal support.

Supports:
  • Multi-batch serving on a single GPU

  • Sliding Window Attention (SWA) via tombstone mechanism

  • Multimodal namespace isolation via extra_key

  • SHA-256 position-aware hashing

  • Page-aligned operations (page_size >= 1)

  • LRU leaf eviction
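This page does not say exactly how extra_key is applied. The following is a minimal sketch that assumes the namespace is folded into every child lookup key. Under that assumption, identical token ids under different extra_key values, such as image placeholder tokens that belong to different images, never share a cached path. For brevity it uses one token per edge. Node, insert and match_len are hypothetical helpers, not pymllm API.

```python
from collections.abc import Hashable
from typing import Dict, List, Optional, Tuple

# Hypothetical lookup key: (namespace, token id). extra_key=None is the text-only namespace.
ChildKey = Tuple[Optional[Hashable], int]

class Node:
    def __init__(self) -> None:
        self.children: Dict[ChildKey, "Node"] = {}

def insert(root: Node, tokens: List[int], extra_key: Optional[Hashable]) -> None:
    node = root
    for tok in tokens:
        # Every edge is scoped to the namespace, so paths never cross extra_key values.
        node = node.children.setdefault((extra_key, tok), Node())

def match_len(root: Node, tokens: List[int], extra_key: Optional[Hashable]) -> int:
    node, matched = root, 0
    for tok in tokens:
        nxt = node.children.get((extra_key, tok))
        if nxt is None:
            break
        node, matched = nxt, matched + 1
    return matched

root = Node()
insert(root, [1, 2, 3], extra_key="image-A")
print(match_len(root, [1, 2, 3], "image-A"))  # 3
print(match_len(root, [1, 2, 3], "image-B"))  # 0
```

Without the namespace, the second lookup would report a 3-token hit and reuse KV that was computed for a different image.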

Attributes

logger

Classes

TreeNode

A single node in the radix tree.

RadixCache

Radix tree for KV-cache prefix sharing.

Module Contents

pymllm.mem_cache.radix_cache.logger
class pymllm.mem_cache.radix_cache.TreeNode

A single node in the radix tree.

value holds a 1-D int64 tensor of KV-pool indices (one per token in key). When the node has been evicted, value is None.

__slots__ = ('children', 'parent', 'key', 'value', 'lock_ref', 'swa_lock_ref', 'swa_tombstone',...
children: Dict[Any, TreeNode]
parent: TreeNode | None = None
key: pymllm.mem_cache.base_prefix_cache.RadixKey | None = None
value: torch.Tensor | None = None
lock_ref: int = 0
swa_lock_ref: int = 0
swa_tombstone: bool = False
swa_boundary_id: int | None = None
last_access_time: float
hit_count: int = 0
hash_values: List[str] | None = None
id: int = 1
property evicted: bool

Whether the node has been evicted, in which case value is None.
Return type:

bool

__lt__(other)
Parameters:

other (TreeNode)

Return type:

bool
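The ordering used by __lt__ is not documented here. Assuming it compares nodes by last_access_time, which is a common choice for LRU heaps, heapq would pop the least recently used node first. ToyNode is a hypothetical stand-in that carries only the fields this sketch needs.

```python
import heapq
from dataclasses import dataclass

@dataclass(eq=False)
class ToyNode:
    # Hypothetical stand-in for TreeNode; only the fields used for ordering.
    id: int
    last_access_time: float

    def __lt__(self, other: "ToyNode") -> bool:
        # Assumption: an older access time sorts first, so the heap root is the LRU node.
        return self.last_access_time < other.last_access_time

nodes = [ToyNode(1, 30.0), ToyNode(2, 10.0), ToyNode(3, 20.0)]
heapq.heapify(nodes)
lru_order = [heapq.heappop(nodes).id for _ in range(3)]
print(lru_order)  # [2, 3, 1]
```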

class pymllm.mem_cache.radix_cache.RadixCache(page_size=1, sliding_window_size=None, token_to_kv_pool_allocator=None, on_node_evict=None)

Bases: pymllm.mem_cache.base_prefix_cache.BasePrefixCache

Radix tree for KV-cache prefix sharing.
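The following is a minimal, self-contained sketch of radix-tree prefix sharing. It covers only page_size=1 and omits SWA, locking, eviction and LRU timestamps. The Node class and the functions are hypothetical illustrations, not RadixCache internals. Each edge holds a run of tokens plus one KV index per token. When an insert matches part of an edge, the sketch splits that edge so the shared head becomes its own node.

```python
from typing import Dict, List, Optional, Tuple

Tokens = Tuple[int, ...]

class Node:
    def __init__(self, key: Tokens = (), value: Tokens = (),
                 parent: Optional["Node"] = None) -> None:
        self.key = key        # token ids on the edge into this node
        self.value = value    # one KV-pool index per token in key
        self.parent = parent
        self.children: Dict[int, "Node"] = {}  # keyed by first token of the child's edge

def _common_len(a: Tokens, b: Tokens) -> int:
    n = 0
    while n < min(len(a), len(b)) and a[n] == b[n]:
        n += 1
    return n

def _split(child: Node, at: int) -> Node:
    # Split child's edge at offset `at`; the new node holds the shared head.
    parent = child.parent
    assert parent is not None
    mid = Node(child.key[:at], child.value[:at], parent)
    parent.children[mid.key[0]] = mid
    child.key, child.value, child.parent = child.key[at:], child.value[at:], mid
    mid.children[child.key[0]] = child
    return mid

def insert(root: Node, key: Tokens, value: Tokens) -> int:
    """Insert key/value; return how many leading tokens were already cached."""
    node, matched = root, 0
    while key:
        child = node.children.get(key[0])
        if child is None:
            node.children[key[0]] = Node(key, value, node)
            break
        n = _common_len(child.key, key)
        if n < len(child.key):
            child = _split(child, n)
        matched += n
        node, key, value = child, key[n:], value[n:]
    return matched

def match_prefix(root: Node, key: Tokens) -> List[int]:
    """Return the KV indices of the longest cached prefix of key."""
    node, indices = root, []
    while key and key[0] in node.children:
        child = node.children[key[0]]
        n = _common_len(child.key, key)
        indices.extend(child.value[:n])
        if n < len(child.key):
            break  # diverged in the middle of an edge
        node, key = child, key[n:]
    return indices

root = Node()
print(insert(root, (1, 2, 3, 4), (10, 11, 12, 13)))  # 0
print(insert(root, (1, 2, 5), (20, 21, 22)))         # 2
print(match_prefix(root, (1, 2, 3, 9)))              # [10, 11, 12]
```

When a prefix is already cached, the tree keeps its existing indices and ignores the incoming ones. This is why the caller of insert must free the duplicate KV indices itself.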

Parameters:
  • page_size (int) – Number of tokens per KV-pool page. Keys and values are aligned to this granularity.

  • sliding_window_size (Optional[int]) – If set, enables SWA mode. The cache tracks which nodes have had their SWA KV freed (tombstoned) and constrains prefix matching so that the sliding-window invariant is maintained.

  • token_to_kv_pool_allocator (Any) – Optional pool allocator exposing free(indices) (and free_swa for SWA mode). When None, index tensors are simply discarded instead of being returned to a pool.

  • on_node_evict (Optional[Callable[[int], None]]) – Optional callback invoked with the node id when a node is evicted.
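With page_size > 1, only whole pages can be shared. The following is a minimal sketch of the alignment arithmetic. It assumes keys are truncated to a multiple of page_size and that a matched prefix is rounded down to a page boundary. page_align and paged_common_prefix are hypothetical helpers.

```python
from typing import Sequence

def page_align(tokens: Sequence[int], page_size: int) -> Sequence[int]:
    # Drop the trailing partial page; only whole pages are cached.
    usable = len(tokens) // page_size * page_size
    return tokens[:usable]

def paged_common_prefix(a: Sequence[int], b: Sequence[int], page_size: int) -> int:
    # Count matching leading tokens, then round down to a page boundary.
    n = 0
    while n < min(len(a), len(b)) and a[n] == b[n]:
        n += 1
    return n // page_size * page_size

print(page_align([5, 6, 7, 8, 9, 10, 11], 4))                         # [5, 6, 7, 8]
print(paged_common_prefix([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 9], 4))  # 4
print(paged_common_prefix([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 9], 1))  # 5
```

With the default page_size=1, neither helper changes any lengths.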

page_size = 1
sliding_window_size = None
pool = None
on_node_evict = None
property supports_swa: bool

True when sliding_window_size is set, i.e. the cache runs in SWA mode.
Return type:

bool

evictable_size()
Return type:

int

swa_evictable_size()
Return type:

int

protected_size()
Return type:

int

swa_protected_size()
Return type:

int

total_size()

Total number of cached tokens (including tombstoned).

Return type:

int

reset()

Clear all cached state and re-initialise the root node.

Return type:

None

match_prefix(key)

Find the longest cached prefix of key.

In SWA mode the match is further constrained: the path from the returned last_node up to the root must contain at least sliding_window_size non-tombstone tokens, or be entirely tombstone-free back to the root.

Accessing a prefix refreshes LRU timestamps along the matched path.
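The following sketch shows one reading of this rule. It assumes the required non-tombstone tokens must form a contiguous run that ends at last_node, since those are the tokens the sliding window would actually attend to. The path is modelled as a hypothetical list of (num_tokens, is_tombstone) pairs ordered from the root downward. The function returns the longest matched length that satisfies the constraint.

```python
from typing import List, Tuple

def swa_valid_prefix_len(path: List[Tuple[int, bool]], window: int) -> int:
    """Longest matched length whose cut point satisfies the SWA constraint (sketch)."""
    best = 0
    total = 0               # tokens from root through the current node
    seen_tombstone = False
    run = 0                 # non-tombstone tokens since the most recent tombstone
    for num_tokens, is_tombstone in path:
        total += num_tokens
        if is_tombstone:
            seen_tombstone = True
            run = 0
            continue        # a tombstoned node can never be last_node
        run += num_tokens
        if not seen_tombstone or run >= window:
            best = total
    return best

path = [(4, False), (3, True), (2, False), (5, False)]
print(swa_valid_prefix_len(path, 6))  # 14
print(swa_valid_prefix_len(path, 8))  # 4
```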

Parameters:

key (pymllm.mem_cache.base_prefix_cache.RadixKey)

Return type:

pymllm.mem_cache.base_prefix_cache.MatchResult

insert(key, value=None, *, prev_prefix_len=0, swa_evicted_seqlen=0, **kwargs)

Insert key/value into the tree.

Returns an InsertResult reporting how many leading tokens of key were already present (the prefix length). The caller is responsible for freeing the duplicate KV indices in the range [cache_protected_len, prefix_len).

Parameters:
  • prev_prefix_len (int) – (SWA mode) tokens before this offset are already protected and should not have their values overwritten.

  • swa_evicted_seqlen (int) – (SWA mode) the sequence length up to which SWA KV has been previously evicted. Used to decide whether a tombstoned node can be un-tombstoned with the incoming value.

  • key (pymllm.mem_cache.base_prefix_cache.RadixKey)

  • value (Optional[torch.Tensor])

  • kwargs (Any)

Return type:

pymllm.mem_cache.base_prefix_cache.InsertResult
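The duplicate-freeing step that is left to the caller might look like the following sketch. It rests on three assumptions. First, cache_protected_len is the number of leading tokens the request already held through the tree before this insert, for example from an earlier match_prefix. Second, prefix_len has already been read from the returned InsertResult. Third, the request's own KV indices are a flat list. ListAllocator and free_duplicates are hypothetical.

```python
from typing import List

class ListAllocator:
    # Hypothetical allocator exposing free(indices); it just records what was freed.
    def __init__(self) -> None:
        self.freed: List[int] = []

    def free(self, indices: List[int]) -> None:
        self.freed.extend(indices)

def free_duplicates(allocator: ListAllocator, req_kv_indices: List[int],
                    cache_protected_len: int, prefix_len: int) -> None:
    # Tokens in [cache_protected_len, prefix_len) were already in the tree before
    # this insert, so the request's own KV for them is redundant.
    if prefix_len > cache_protected_len:
        allocator.free(req_kv_indices[cache_protected_len:prefix_len])

alloc = ListAllocator()
req_kv = [100, 101, 102, 103, 104, 105, 106, 107]
free_duplicates(alloc, req_kv, cache_protected_len=2, prefix_len=5)
print(alloc.freed)  # [102, 103, 104]
```

Under the same assumptions, indices from prefix_len onward were newly inserted and now belong to the tree, so the caller does not free them.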

evict(num_tokens, swa_num_tokens=0)

Evict up to num_tokens tokens of full-attention KV and up to swa_num_tokens tokens of SWA KV.

Full eviction removes leaf nodes entirely; SWA eviction tombstones internal nodes (freeing SWA KV but retaining full-attn KV).

Parameters:
  • num_tokens (int)

  • swa_num_tokens (int)

Return type:

pymllm.mem_cache.base_prefix_cache.EvictResult
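The following is a toy model of the full-eviction path, i.e. LRU leaf eviction. Unlocked leaves are popped in last_access_time order, and any parent that becomes a leaf joins the candidate heap. The N class and the evict helper are hypothetical. SWA tombstoning, allocator calls and on_node_evict are omitted. Ties on time are broken by id, so the heap never has to compare nodes directly.

```python
import heapq
from typing import Dict, Iterator, List, Optional

class N:
    # Hypothetical node: id, token count, last access time, parent link, lock count.
    def __init__(self, nid: int, num_tokens: int, t: float,
                 parent: Optional["N"] = None) -> None:
        self.id, self.num_tokens, self.last_access_time = nid, num_tokens, t
        self.parent = parent
        self.children: Dict[int, "N"] = {}
        self.lock_ref = 0
        if parent is not None:
            parent.children[nid] = self

def evict(root: N, num_tokens: int) -> List[int]:
    """Evict unlocked leaves in LRU order until num_tokens are freed; return evicted ids."""
    def leaves(node: N) -> Iterator[N]:
        for c in node.children.values():
            if c.children:
                yield from leaves(c)
            else:
                yield c

    heap = [(n.last_access_time, n.id, n) for n in leaves(root) if n.lock_ref == 0]
    heapq.heapify(heap)
    freed, evicted = 0, []
    while freed < num_tokens and heap:
        _, _, node = heapq.heappop(heap)
        parent = node.parent
        del parent.children[node.id]
        freed += node.num_tokens
        evicted.append(node.id)
        # A parent that just became an unlocked leaf is now a candidate too.
        if parent is not root and not parent.children and parent.lock_ref == 0:
            heapq.heappush(heap, (parent.last_access_time, parent.id, parent))
    return evicted

root = N(0, 0, 0.0)
a = N(1, 4, 5.0, root)
N(2, 2, 1.0, a)
N(3, 3, 3.0, a)
N(4, 5, 2.0, root)
print(evict(root, 6))  # [2, 4]
```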

inc_lock_ref(node)

Lock every node on the path from node up to the root, preventing those nodes from being evicted.

Returns swa_boundary_id that must be passed back to dec_lock_ref(). In non-SWA mode, returns None.

Parameters:

node (TreeNode)

Return type:

Optional[int]

dec_lock_ref(node, swa_boundary_id=None, **kwargs)

Unlock every node on the path from node up to the root, undoing a previous inc_lock_ref().

Parameters:
  • node (TreeNode)

  • swa_boundary_id (Optional[int])

  • kwargs (Any)

Return type:

None
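The reference counting behind inc_lock_ref() and dec_lock_ref() can be sketched as a walk toward the root. The sketch also moves a node's tokens between an evictable total and a protected total when its count goes from 0 to 1 or from 1 to 0. That bookkeeping is an assumption about how evictable_size() and protected_size() relate to locking. swa_boundary_id handling is omitted, and Node and Sizes are hypothetical.

```python
from typing import Optional

class Node:
    def __init__(self, num_tokens: int, parent: Optional["Node"] = None) -> None:
        self.num_tokens = num_tokens
        self.parent = parent
        self.lock_ref = 0

class Sizes:
    # Hypothetical running totals of unlocked and locked tokens.
    def __init__(self, evictable: int = 0) -> None:
        self.evictable = evictable
        self.protected = 0

def inc_lock_ref(node: Node, root: Node, sizes: Sizes) -> None:
    # Walk from node up to (but excluding) the root, taking one reference each.
    while node is not root:
        if node.lock_ref == 0:
            # First lock on this node: its tokens can no longer be evicted.
            sizes.evictable -= node.num_tokens
            sizes.protected += node.num_tokens
        node.lock_ref += 1
        node = node.parent

def dec_lock_ref(node: Node, root: Node, sizes: Sizes) -> None:
    while node is not root:
        if node.lock_ref <= 0:
            raise RuntimeError("dec_lock_ref without matching inc_lock_ref")
        node.lock_ref -= 1
        if node.lock_ref == 0:
            # Last lock released: the node is an eviction candidate again.
            sizes.protected -= node.num_tokens
            sizes.evictable += node.num_tokens
        node = node.parent

root = Node(0)
a = Node(4, root)
b = Node(3, a)
sizes = Sizes(evictable=7)
inc_lock_ref(b, root, sizes)   # locks b and a
inc_lock_ref(a, root, sizes)   # a second request sharing only a
dec_lock_ref(b, root, sizes)
print(sizes.evictable, sizes.protected)  # 3 4
```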

compute_node_hash(node)

Compute position-aware SHA-256 hashes for node (one per page).

Lazily computed and cached on node.hash_values.
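The exact hash input is not documented here. The following is a minimal sketch of one position-aware scheme. It assumes each page's digest chains the previous page's digest (or a parent node's last digest) with the page's token ids. Because of the chaining, the same tokens hash differently when they appear at a different position or after a different prefix. page_hashes and its little-endian int64 encoding are assumptions, not pymllm's format.

```python
import hashlib
import struct
from typing import List, Optional

def page_hashes(tokens: List[int], page_size: int,
                parent_hash: Optional[str] = None) -> List[str]:
    """One SHA-256 hex digest per full page, each chained to the previous digest."""
    hashes: List[str] = []
    prev = parent_hash or ""
    full = len(tokens) - len(tokens) % page_size  # ignore a trailing partial page
    for start in range(0, full, page_size):
        page = tokens[start:start + page_size]
        h = hashlib.sha256()
        h.update(prev.encode())                            # chain: encodes position/prefix
        h.update(struct.pack(f"<{len(page)}q", *page))     # token ids as int64
        prev = h.hexdigest()
        hashes.append(prev)
    return hashes

h = page_hashes([1, 2, 3, 4, 1, 2, 3, 4], page_size=4)
print(h[0] != h[1])  # True: same tokens, different position
```

Chaining also means a child node can continue from its parent's last digest instead of rehashing the whole prefix.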

Parameters:

node (TreeNode)

Return type:

List[str]

pretty_print()

Print the tree structure to stdout.

Return type:

None