pymllm.layers.attention.gdn_backend
===================================

.. py:module:: pymllm.layers.attention.gdn_backend

.. autoapi-nested-parse::

   GDN attention backend -- pooled-state GDN computation for hybrid models.

   Performs GDN (Gated Delta Net) linear attention using externalized state
   stored in a :class:`~pymllm.mem_cache.memory_pool.GDNPool`. Supports both
   extend (prefill) and decode paths with FlashInfer kernels.

   This backend is not used directly; it is wrapped by
   :class:`~pymllm.layers.attention.hybrid_backend.HybridAttnBackend`.

Attributes
----------

.. autoapisummary::

   pymllm.layers.attention.gdn_backend.logger

Classes
-------

.. autoapisummary::

   pymllm.layers.attention.gdn_backend.GDNForwardMetadata
   pymllm.layers.attention.gdn_backend.GDNAttnBackend

Module Contents
---------------

.. py:data:: logger

.. py:class:: GDNForwardMetadata

   Per-batch metadata for the GDN backend.

   .. py:attribute:: cache_indices
      :type: torch.Tensor

   .. py:attribute:: cu_seqlens
      :type: Optional[torch.Tensor]
      :value: None

.. py:class:: GDNAttnBackend(gdn_pool, device)

   GDN linear-attention backend using pooled states.

   Handles both extend (prefill) and decode paths for GDN layers. Uses
   FlashInfer kernels when available (SM90+), with a PyTorch fallback.

   :param gdn_pool: Pre-allocated :class:`~pymllm.mem_cache.memory_pool.GDNPool`.
   :param device: Target device.

   .. py:attribute:: gdn_pool

   .. py:attribute:: device

   .. py:attribute:: forward_metadata
      :type: Optional[GDNForwardMetadata]
      :value: None

   .. py:method:: init_forward_metadata(forward_batch)

      Prepare GDN metadata from the current forward batch.

   .. py:method:: init_cuda_graph_state(max_bs, max_num_tokens)

      Allocate CUDA-graph state for the GDN backend.

      The GDN pool buffers are already pre-allocated at fixed addresses,
      so only the metadata tensor needs to be allocated.

   .. py:method:: init_forward_metadata_capture_cuda_graph(bs, req_pool_indices, seq_lens)

      Set up GDN metadata for CUDA-graph capture (decode only).

   .. py:method:: init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens)

      Update GDN metadata for CUDA-graph replay (decode only).

   .. py:method:: forward_decode(layer, forward_batch, mixed_qkv, a, b)

      GDN decode: one new token per request.

      Steps:

      1. Gather conv_state from the pool → ``[bs, conv_dim, K-1]``
      2. Conv1d update: shift + weighted sum for the one new token
      3. Scatter the updated conv_state back to the pool
      4. SiLU → split q, k, v
      5. FlashInfer ``gated_delta_rule_decode`` (or PyTorch fallback)

   .. py:method:: forward_extend(layer, forward_batch, mixed_qkv, a, b)

      GDN extend (prefill): multiple tokens per request.

      Steps:

      1. Gather conv_state from the pool for each request
      2. Per-request causal conv1d
      3. Scatter the new conv_state back to the pool
      4. SiLU → split q, k, v → gating
      5. FlashInfer ``chunk_gated_delta_rule`` (or PyTorch fallback)
      6. Scatter the final recurrent state back to the pool

   .. py:method:: forward_gdn(layer, forward_batch, mixed_qkv, a, b)

      Route to decode or extend based on forward mode.
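The decode-path conv update (steps 1–4 of ``forward_decode``: shift the cached window, take a weighted sum for the single new token, apply SiLU) can be sketched in plain PyTorch. The helper name and signature below are hypothetical, not the module's actual API:

```python
import torch
import torch.nn.functional as F

def conv1d_decode_step(conv_state, new_token, conv_weight, conv_bias=None):
    """Single-token depthwise causal-conv update (hypothetical helper).

    conv_state  : [bs, conv_dim, K-1] -- previous K-1 inputs per channel
    new_token   : [bs, conv_dim]      -- the one new input
    conv_weight : [conv_dim, K]       -- depthwise kernel per channel
    """
    # Append the new token to the rolling window: [bs, conv_dim, K]
    window = torch.cat([conv_state, new_token.unsqueeze(-1)], dim=-1)
    # Weighted sum over the K kernel taps (depthwise conv at one position)
    out = (window * conv_weight.unsqueeze(0)).sum(dim=-1)
    if conv_bias is not None:
        out = out + conv_bias
    # Shift: drop the oldest tap so the state again holds K-1 inputs
    new_state = window[..., 1:]
    return F.silu(out), new_state
```

In the real backend the state is gathered from and scattered back to the ``GDNPool`` around this update; here it is passed in and returned directly.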
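Step 5 of ``forward_decode`` applies the gated delta rule for one token. A minimal PyTorch reference of one common Gated DeltaNet formulation, S_t = α_t·S_{t-1} + β_t·k_t ⊗ (v_t − k_t·α_t·S_{t-1}), is sketched below; the function name, argument layout, and gating convention are assumptions, not FlashInfer's actual interface:

```python
import torch

def gated_delta_rule_decode(q, k, v, alpha, beta, state):
    """Single-token gated delta rule (PyTorch reference, hypothetical signature).

    q, k        : [bs, h, d_k]        v : [bs, h, d_v]
    alpha, beta : [bs, h]             -- decay gate and write strength
    state       : [bs, h, d_k, d_v]   -- recurrent state S
    """
    a = alpha[..., None, None]  # broadcast over (d_k, d_v)
    b = beta[..., None]         # broadcast over d_v
    # Decay the old state first
    state = a * state
    # Delta-rule prediction error: what v adds beyond what S already stores
    pred = torch.einsum('bhk,bhkv->bhv', k, state)
    # Write the rank-1 correction k ⊗ beta*(v - pred)
    state = state + torch.einsum('bhk,bhv->bhkv', k, b * (v - pred))
    # Read out with the query
    out = torch.einsum('bhk,bhkv->bhv', q, state)
    return out, state
```

The pooled backend would read ``state`` from the ``GDNPool`` slot for each request and write it back after the update.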
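The extend-path conv step (step 2 of ``forward_extend``) runs a causal depthwise conv over all new tokens of a request, seeded with the cached state so the first new token sees true history. A sketch for a single request, with a hypothetical helper name:

```python
import torch
import torch.nn.functional as F

def causal_conv1d_extend(x, conv_state, conv_weight):
    """Multi-token causal depthwise conv for one request (hypothetical helper).

    x           : [conv_dim, T]    -- new tokens for this request
    conv_state  : [conv_dim, K-1]  -- cached tail of the previous inputs
    conv_weight : [conv_dim, K]    -- depthwise kernel

    Returns the conv output [conv_dim, T] and the updated state [conv_dim, K-1].
    """
    dim, K = conv_weight.shape
    # Left-pad with the cached state so position 0 is causal w.r.t. history
    padded = torch.cat([conv_state, x], dim=-1)          # [dim, K-1+T]
    out = F.conv1d(padded.unsqueeze(0),                  # [1, dim, K-1+T]
                   conv_weight.unsqueeze(1),             # [dim, 1, K]
                   groups=dim).squeeze(0)                # [dim, T]
    # The new cached state is the last K-1 inputs seen
    new_state = padded[:, -(K - 1):]
    return out, new_state
```

``groups=dim`` makes the conv depthwise (one kernel per channel), matching the per-channel "shift + weighted sum" that the decode path performs for a single token.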