pymllm.layers.attention.hybrid_backend

Hybrid attention backend – FlashInfer + GDN for hybrid architectures.

Wraps a FlashInferAttnBackend (for full-attention layers) and a GDNAttnBackend (for GDN linear-attention layers). Dispatches based on layer type:

  • RadixAttention calls → delegated to full_attn_backend

  • RadixLinearAttention calls (via forward_gdn) → delegated to gdn_backend

CUDA-graph compatible: delegates all graph lifecycle methods to both sub-backends.
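The dispatch pattern described above can be sketched as follows. The stub classes stand in for the real FlashInferAttnBackend and GDNAttnBackend, and the signatures are simplified; this is an illustration of the delegation idea, not the actual pymllm implementation.

```python
# Minimal sketch of the hybrid dispatch (stand-in sub-backends, simplified
# signatures): full-attention calls go to one backend, GDN calls to the other.
class FullAttnStub:
    def forward_decode(self, q, k, v, layer, forward_batch, **kwargs):
        return ("full", layer)


class GDNStub:
    def forward_gdn(self, layer, forward_batch, mixed_qkv, a, b):
        return ("gdn", layer)


class HybridAttnBackend:
    def __init__(self, full_attn_backend, gdn_backend, full_attn_layer_ids):
        self.full_attn_backend = full_attn_backend
        self.gdn_backend = gdn_backend
        # Which layer ids are full attention; the rest are GDN layers.
        self.full_attn_layer_ids = set(full_attn_layer_ids)

    # RadixAttention layers call this path; it is forwarded unchanged.
    def forward_decode(self, q, k, v, layer, forward_batch, **kwargs):
        return self.full_attn_backend.forward_decode(
            q, k, v, layer, forward_batch, **kwargs
        )

    # RadixLinearAttention layers call forward_gdn instead.
    def forward_gdn(self, layer, forward_batch, mixed_qkv, a, b):
        return self.gdn_backend.forward_gdn(layer, forward_batch, mixed_qkv, a, b)


backend = HybridAttnBackend(FullAttnStub(), GDNStub(), full_attn_layer_ids=[3, 7])
full_result = backend.forward_decode(None, None, None, 3, None)
gdn_result = backend.forward_gdn(5, None, None, None, None)
```

The key design point is that the hybrid backend holds no attention logic of its own: each call is routed wholesale to exactly one sub-backend, so the two attention implementations never need to know about each other.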

Attributes

logger

Classes

HybridAttnBackend

Composite attention backend for hybrid full-attention + GDN models.

Module Contents

pymllm.layers.attention.hybrid_backend.logger
class pymllm.layers.attention.hybrid_backend.HybridAttnBackend(full_attn_backend, gdn_backend, full_attn_layer_ids)

Bases: pymllm.layers.attention.attention_backend.AttentionBackend

Composite attention backend for hybrid full-attention + GDN models.

Parameters:

  • full_attn_backend

  • gdn_backend

  • full_attn_layer_ids
init_forward_metadata(forward_batch)

Initialize metadata for both sub-backends.

Parameters:

forward_batch (pymllm.engine.forward_batch.ForwardBatch)

Return type:

None
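The fan-out performed by init_forward_metadata can be sketched like this; the recorder class is a hypothetical stand-in for the real sub-backends, and the two-argument constructor is simplified from the real signature.

```python
# Sketch of the metadata fan-out: the same forward_batch is handed to both
# sub-backends so each can build its own decode/extend plans.
class _Recorder:
    """Hypothetical stand-in sub-backend that just records its calls."""

    def __init__(self):
        self.calls = []

    def init_forward_metadata(self, forward_batch):
        self.calls.append(forward_batch)


class HybridAttnBackend:
    def __init__(self, full_attn_backend, gdn_backend):
        self.full_attn_backend = full_attn_backend
        self.gdn_backend = gdn_backend

    def init_forward_metadata(self, forward_batch) -> None:
        # Both sub-backends see every batch, regardless of which layers it hits.
        self.full_attn_backend.init_forward_metadata(forward_batch)
        self.gdn_backend.init_forward_metadata(forward_batch)


full, gdn = _Recorder(), _Recorder()
HybridAttnBackend(full, gdn).init_forward_metadata("batch-0")
```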

forward_decode(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)

Delegate full-attention decode to FlashInfer backend.

Parameters:

  • q

  • k

  • v

  • layer

  • forward_batch (pymllm.engine.forward_batch.ForwardBatch)

  • save_kv_cache (bool)

  • kwargs

Return type:

torch.Tensor

forward_extend(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)

Delegate full-attention extend to FlashInfer backend.

Parameters:

  • q

  • k

  • v

  • layer

  • forward_batch (pymllm.engine.forward_batch.ForwardBatch)

  • save_kv_cache (bool)

  • kwargs

Return type:

torch.Tensor

forward_gdn(layer, forward_batch, mixed_qkv, a, b)

Delegate GDN computation to the GDN backend.

Parameters:

  • layer

  • forward_batch (pymllm.engine.forward_batch.ForwardBatch)

  • mixed_qkv

  • a

  • b

Return type:

torch.Tensor

get_cuda_graph_seq_len_fill_value()

Delegate to the full-attention backend.

Return type:

int

init_cuda_graph_state(max_bs, max_num_tokens)

Allocate CUDA-graph state for both sub-backends.

Parameters:
  • max_bs (int)

  • max_num_tokens (int)

Return type:

None

init_forward_metadata_capture_cuda_graph(bs, num_tokens, req_pool_indices, seq_lens, forward_mode)

Set up metadata for CUDA-graph capture in both sub-backends.

Parameters:

  • bs (int)

  • num_tokens (int)

  • req_pool_indices

  • seq_lens

  • forward_mode

Return type:

None

init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens, seq_lens_sum, forward_mode, seq_lens_cpu)

Update metadata for CUDA-graph replay in both sub-backends.

Parameters:

  • bs (int)

  • req_pool_indices

  • seq_lens

  • seq_lens_sum (int)

  • forward_mode

  • seq_lens_cpu

Return type:

None
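The three CUDA-graph lifecycle methods all follow the same fan-out shape, which can be sketched as follows. The recorder class is a hypothetical stand-in for the sub-backends, and the variadic signatures are simplified from the real ones.

```python
# Sketch of the CUDA-graph lifecycle fan-out: state allocation, capture
# metadata, and replay metadata are each forwarded to both sub-backends,
# keeping the pair in lockstep across the graph lifecycle.
class _GraphStub:
    """Hypothetical stand-in sub-backend that logs lifecycle calls."""

    def __init__(self):
        self.log = []

    def init_cuda_graph_state(self, max_bs, max_num_tokens):
        self.log.append(("state", max_bs, max_num_tokens))

    def init_forward_metadata_capture_cuda_graph(self, bs, *args):
        self.log.append(("capture", bs))

    def init_forward_metadata_replay_cuda_graph(self, bs, *args):
        self.log.append(("replay", bs))


class HybridAttnBackend:
    def __init__(self, full_attn_backend, gdn_backend):
        self.backends = (full_attn_backend, gdn_backend)

    def init_cuda_graph_state(self, max_bs, max_num_tokens):
        for b in self.backends:
            b.init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(self, bs, *args):
        for b in self.backends:
            b.init_forward_metadata_capture_cuda_graph(bs, *args)

    def init_forward_metadata_replay_cuda_graph(self, bs, *args):
        for b in self.backends:
            b.init_forward_metadata_replay_cuda_graph(bs, *args)


full, gdn = _GraphStub(), _GraphStub()
hybrid = HybridAttnBackend(full, gdn)
hybrid.init_cuda_graph_state(max_bs=8, max_num_tokens=8)
hybrid.init_forward_metadata_capture_cuda_graph(4)
hybrid.init_forward_metadata_replay_cuda_graph(4)
```

Because every lifecycle call reaches both sub-backends, a captured graph can safely contain a mix of full-attention and GDN layers: each backend's buffers are sized and updated at the same points in the capture/replay cycle.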