pymllm.layers.attention.hybrid_backend
======================================

.. py:module:: pymllm.layers.attention.hybrid_backend

.. autoapi-nested-parse::

   Hybrid attention backend -- FlashInfer + GDN for hybrid architectures.

   Wraps a :class:`FlashInferAttnBackend` (for full-attention layers) and a
   :class:`GDNAttnBackend` (for GDN linear-attention layers). Dispatches based
   on layer type:

   * ``RadixAttention`` calls → delegated to ``full_attn_backend``
   * ``RadixLinearAttention`` calls (via ``forward_gdn``) → delegated to
     ``gdn_backend``

   CUDA-graph compatible: delegates all graph lifecycle methods to both
   sub-backends.

Attributes
----------

.. autoapisummary::

   pymllm.layers.attention.hybrid_backend.logger

Classes
-------

.. autoapisummary::

   pymllm.layers.attention.hybrid_backend.HybridAttnBackend

Module Contents
---------------

.. py:data:: logger

.. py:class:: HybridAttnBackend(full_attn_backend, gdn_backend, full_attn_layer_ids)

   Bases: :py:obj:`pymllm.layers.attention.attention_backend.AttentionBackend`

   Composite attention backend for hybrid full-attention + GDN models.

   :param full_attn_backend: FlashInfer backend for standard transformer
      attention layers.
   :param gdn_backend: GDN backend for linear-attention layers.
   :param full_attn_layer_ids: Set of global layer IDs that use full attention
      (for logging).

   .. py:attribute:: full_attn_backend

   .. py:attribute:: gdn_backend

   .. py:attribute:: full_attn_layer_ids

   .. py:method:: init_forward_metadata(forward_batch)

      Initialize metadata for both sub-backends.

   .. py:method:: forward_decode(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)

      Delegate full-attention decode to the FlashInfer backend.

   .. py:method:: forward_extend(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)

      Delegate full-attention extend to the FlashInfer backend.

   .. py:method:: forward_gdn(layer, forward_batch, mixed_qkv, a, b)

      Delegate GDN computation to the GDN backend.

   .. py:method:: get_cuda_graph_seq_len_fill_value()

      Delegate to the full-attention backend.

   .. py:method:: init_cuda_graph_state(max_bs, max_num_tokens)

      Allocate CUDA-graph state for both sub-backends.

   .. py:method:: init_forward_metadata_capture_cuda_graph(bs, num_tokens, req_pool_indices, seq_lens, forward_mode)

      Set up metadata for CUDA-graph capture in both sub-backends.

   .. py:method:: init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens, seq_lens_sum, forward_mode, seq_lens_cpu)

      Update metadata for CUDA-graph replay in both sub-backends.
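The delegation pattern described above can be sketched in isolation. This is a minimal illustrative model, not the actual ``pymllm`` implementation: the stub classes, the ``HybridBackendSketch`` name, and the simplified signatures are all placeholders standing in for ``FlashInferAttnBackend``, ``GDNAttnBackend``, and ``HybridAttnBackend``.

```python
class FullAttnStub:
    """Stand-in for FlashInferAttnBackend (full-attention layers)."""

    def __init__(self):
        self.calls = []

    def init_forward_metadata(self, forward_batch):
        self.calls.append("init")

    def forward_decode(self, q, k, v, layer, forward_batch):
        self.calls.append("decode")
        return "full_attn_out"


class GDNStub:
    """Stand-in for GDNAttnBackend (GDN linear-attention layers)."""

    def __init__(self):
        self.calls = []

    def init_forward_metadata(self, forward_batch):
        self.calls.append("init")

    def forward_gdn(self, layer, forward_batch, mixed_qkv, a, b):
        self.calls.append("gdn")
        return "gdn_out"


class HybridBackendSketch:
    """Composite backend: lifecycle calls fan out to both sub-backends,
    while attention calls are dispatched by layer type."""

    def __init__(self, full_attn_backend, gdn_backend):
        self.full_attn_backend = full_attn_backend
        self.gdn_backend = gdn_backend

    def init_forward_metadata(self, forward_batch):
        # Graph/metadata lifecycle methods go to BOTH sub-backends.
        self.full_attn_backend.init_forward_metadata(forward_batch)
        self.gdn_backend.init_forward_metadata(forward_batch)

    def forward_decode(self, q, k, v, layer, forward_batch):
        # RadixAttention layers route to the full-attention backend.
        return self.full_attn_backend.forward_decode(q, k, v, layer, forward_batch)

    def forward_gdn(self, layer, forward_batch, mixed_qkv, a, b):
        # RadixLinearAttention layers route to the GDN backend.
        return self.gdn_backend.forward_gdn(layer, forward_batch, mixed_qkv, a, b)


hybrid = HybridBackendSketch(FullAttnStub(), GDNStub())
hybrid.init_forward_metadata(forward_batch=None)
out_full = hybrid.forward_decode(None, None, None, layer=None, forward_batch=None)
out_gdn = hybrid.forward_gdn(layer=None, forward_batch=None, mixed_qkv=None, a=None, b=None)
```

The key design point the module relies on is that only the lifecycle methods (``init_forward_metadata`` and the CUDA-graph capture/replay hooks) fan out to both sub-backends; each forward call touches exactly one of them.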