pymllm.layers.attention.hybrid_backend¶
Hybrid attention backend – FlashInfer + GDN for hybrid architectures.
Wraps a FlashInferAttnBackend (for full-attention layers) and a
GDNAttnBackend (for GDN linear-attention layers). Dispatches
based on layer type:
- RadixAttention calls → delegated to full_attn_backend
- RadixLinearAttention calls (via forward_gdn) → delegated to gdn_backend
CUDA-graph compatible: delegates all graph lifecycle methods to both sub-backends.
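The dispatch rule above can be sketched with stub classes (these stubs are illustrative stand-ins, not the real FlashInferAttnBackend / GDNAttnBackend APIs): full-attention calls route to one sub-backend, GDN calls to the other, and the hybrid wrapper adds no logic of its own.

```python
# Hypothetical sketch of the composite-dispatch pattern.  The stub
# backends only record which path handled a call.

class FullAttnStub:
    def forward_decode(self, q):
        # stands in for the FlashInfer decode path
        return f"flashinfer({q})"

class GDNStub:
    def forward_gdn(self, x):
        # stands in for the GDN linear-attention path
        return f"gdn({x})"

class HybridStub:
    def __init__(self, full_attn_backend, gdn_backend):
        self.full_attn_backend = full_attn_backend
        self.gdn_backend = gdn_backend

    # RadixAttention layers call this → full-attention sub-backend
    def forward_decode(self, q):
        return self.full_attn_backend.forward_decode(q)

    # RadixLinearAttention layers call this → GDN sub-backend
    def forward_gdn(self, x):
        return self.gdn_backend.forward_gdn(x)

hybrid = HybridStub(FullAttnStub(), GDNStub())
print(hybrid.forward_decode("q"))  # flashinfer(q)
print(hybrid.forward_gdn("x"))     # gdn(x)
```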
Attributes¶
logger
Classes¶
HybridAttnBackend – Composite attention backend for hybrid full-attention + GDN models.
Module Contents¶
- pymllm.layers.attention.hybrid_backend.logger¶
- class pymllm.layers.attention.hybrid_backend.HybridAttnBackend(full_attn_backend, gdn_backend, full_attn_layer_ids)¶
Bases:
pymllm.layers.attention.attention_backend.AttentionBackend
Composite attention backend for hybrid full-attention + GDN models.
- Parameters:
full_attn_backend (pymllm.layers.attention.flashinfer_backend.FlashInferAttnBackend) – FlashInfer backend for standard transformer attention layers.
gdn_backend (pymllm.layers.attention.gdn_backend.GDNAttnBackend) – GDN backend for linear-attention layers.
full_attn_layer_ids (Set[int]) – Set of global layer IDs that use full attention (for logging).
- full_attn_backend¶
- gdn_backend¶
- full_attn_layer_ids¶
- init_forward_metadata(forward_batch)¶
Initialize metadata for both sub-backends.
- Parameters:
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
- Return type:
None
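A minimal sketch of the fan-out this method implies (stub classes, not the real pymllm types): because a single forward pass through a hybrid model touches both layer types, per-batch metadata must reach both sub-backends.

```python
# Hedged sketch: both sub-backends receive the same forward batch.
class RecordingBackend:
    def __init__(self):
        self.initialized_with = None

    def init_forward_metadata(self, forward_batch):
        # record what we were initialized with
        self.initialized_with = forward_batch

class HybridSketch:
    def __init__(self, full_attn_backend, gdn_backend):
        self.full_attn_backend = full_attn_backend
        self.gdn_backend = gdn_backend

    def init_forward_metadata(self, forward_batch):
        # fan out: both sub-backends see the same batch
        self.full_attn_backend.init_forward_metadata(forward_batch)
        self.gdn_backend.init_forward_metadata(forward_batch)

full, gdn = RecordingBackend(), RecordingBackend()
HybridSketch(full, gdn).init_forward_metadata("batch-0")
print(full.initialized_with, gdn.initialized_with)  # batch-0 batch-0
```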
- forward_decode(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)¶
Delegate full-attention decode to FlashInfer backend.
- Parameters:
q (torch.Tensor)
k (Optional[torch.Tensor])
v (Optional[torch.Tensor])
layer (pymllm.layers.attention.radix_attention.RadixAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
save_kv_cache (bool)
- Return type:
torch.Tensor
- forward_extend(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)¶
Delegate full-attention extend to FlashInfer backend.
- Parameters:
q (torch.Tensor)
k (Optional[torch.Tensor])
v (Optional[torch.Tensor])
layer (pymllm.layers.attention.radix_attention.RadixAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
save_kv_cache (bool)
- Return type:
torch.Tensor
- forward_gdn(layer, forward_batch, mixed_qkv, a, b)¶
Delegate GDN computation to the GDN backend.
- Parameters:
layer (pymllm.layers.attention.radix_linear_attention.RadixLinearAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
mixed_qkv (torch.Tensor)
a (torch.Tensor)
b (torch.Tensor)
- Return type:
torch.Tensor
- get_cuda_graph_seq_len_fill_value()¶
Delegate to the full-attention backend.
- Return type:
int
- init_cuda_graph_state(max_bs, max_num_tokens)¶
Allocate CUDA-graph state for both sub-backends.
- Parameters:
max_bs (int)
max_num_tokens (int)
- Return type:
None
- init_forward_metadata_capture_cuda_graph(bs, num_tokens, req_pool_indices, seq_lens, forward_mode)¶
Set up metadata for CUDA-graph capture in both sub-backends.
- Parameters:
bs (int)
num_tokens (int)
req_pool_indices (torch.Tensor)
seq_lens (torch.Tensor)
forward_mode (pymllm.engine.forward_batch.ForwardMode)
- Return type:
None
- init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens, seq_lens_sum, forward_mode, seq_lens_cpu)¶
Update metadata for CUDA-graph replay in both sub-backends.
- Parameters:
bs (int)
req_pool_indices (torch.Tensor)
seq_lens (torch.Tensor)
seq_lens_sum (int)
forward_mode (pymllm.engine.forward_batch.ForwardMode)
seq_lens_cpu (Optional[torch.Tensor])
- Return type:
None
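The three CUDA-graph lifecycle hooks follow the same fan-out shape. A hedged sketch (stubs only; the real methods take torch.Tensor arguments and extra parameters): every lifecycle call is forwarded to both sub-backends so a captured graph replays consistently across full-attention and GDN layers.

```python
# Stub backend that logs each lifecycle call it receives.
class GraphStub:
    def __init__(self):
        self.log = []

    def init_cuda_graph_state(self, max_bs, max_num_tokens):
        self.log.append(("state", max_bs, max_num_tokens))

    def init_forward_metadata_capture_cuda_graph(self, bs, num_tokens):
        self.log.append(("capture", bs, num_tokens))

    def init_forward_metadata_replay_cuda_graph(self, bs):
        self.log.append(("replay", bs))

class HybridGraphSketch:
    def __init__(self, full_attn_backend, gdn_backend):
        self.subs = (full_attn_backend, gdn_backend)

    # each lifecycle hook fans out to every sub-backend
    def init_cuda_graph_state(self, max_bs, max_num_tokens):
        for sub in self.subs:
            sub.init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(self, bs, num_tokens):
        for sub in self.subs:
            sub.init_forward_metadata_capture_cuda_graph(bs, num_tokens)

    def init_forward_metadata_replay_cuda_graph(self, bs):
        for sub in self.subs:
            sub.init_forward_metadata_replay_cuda_graph(bs)

full, gdn = GraphStub(), GraphStub()
hybrid_graph = HybridGraphSketch(full, gdn)
hybrid_graph.init_cuda_graph_state(8, 512)
hybrid_graph.init_forward_metadata_capture_cuda_graph(8, 8)
hybrid_graph.init_forward_metadata_replay_cuda_graph(8)
print(full.log == gdn.log)  # True
```

Keeping the call order identical on both sub-backends is what makes the composite safe to use inside a captured graph: neither sub-backend ever observes a lifecycle state the other has not reached.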