pymllm.layers.attention.attention_backend
=========================================

.. py:module:: pymllm.layers.attention.attention_backend

.. autoapi-nested-parse::

   Abstract base class for pymllm attention backends.

   Every concrete backend (FlashInfer, Triton, torch-native, …) must implement
   at minimum:

   * ``init_forward_metadata`` – called once per batch before the model forward.
   * ``forward_extend`` – prefill / extend attention.
   * ``forward_decode`` – single-token decode attention.

   The public ``forward`` method dispatches to the correct variant based on
   ``forward_batch.forward_mode``.

Classes
-------

.. autoapisummary::

   pymllm.layers.attention.attention_backend.AttentionBackend

Module Contents
---------------

.. py:class:: AttentionBackend

   Bases: :py:obj:`abc.ABC`

   Abstract base class for attention backends.

   All concrete backends inherit from this class and implement the abstract
   methods below.

   .. py:method:: init_forward_metadata(forward_batch)
      :abstractmethod:

      Prepare per-batch metadata before the model's attention layers run.

      For FlashInfer this plans the KV-index arrays and calls
      ``wrapper.begin_forward``; for Triton / torch-native this is a no-op.
      Must be called once per batch *before* ``model.forward``.

   .. py:method:: forward_decode(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)
      :abstractmethod:

      Run attention for a decode step (one new token per sequence).

   .. py:method:: forward_extend(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)
      :abstractmethod:

      Run attention for a prefill / extend step.

   .. py:method:: forward(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)

      Dispatch to ``forward_decode`` or ``forward_extend`` based on mode.

      For IDLE batches a zero-filled output tensor is returned without any
      compute.

   .. py:method:: forward_gdn(layer, forward_batch, mixed_qkv, a, b)
      :abstractmethod:

      Run GDN linear-attention for one layer.

      Only implemented by backends that support hybrid (full + GDN)
      architectures.
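The contract above can be sketched in plain Python. This is an illustrative stand-in, not pymllm's actual implementation: the ``NaiveBackend`` and ``Batch`` names are hypothetical, plain lists stand in for tensors, and the forward-mode values are simplified to strings. It shows the two required attention variants plus the dispatch performed by ``forward``.

```python
# Hedged sketch of the AttentionBackend contract; all concrete names below
# (NaiveBackend, Batch, the "extend"/"decode" mode strings) are assumptions
# for illustration, not pymllm's real types.
import abc


class AttentionBackend(abc.ABC):
    """Simplified stand-in for pymllm's abstract base class."""

    @abc.abstractmethod
    def init_forward_metadata(self, forward_batch):
        """Per-batch setup; a no-op for simple backends."""

    @abc.abstractmethod
    def forward_extend(self, q, k, v, layer, forward_batch,
                       save_kv_cache=True, **kwargs):
        """Prefill / extend attention."""

    @abc.abstractmethod
    def forward_decode(self, q, k, v, layer, forward_batch,
                       save_kv_cache=True, **kwargs):
        """Single-token decode attention."""

    def forward(self, q, k, v, layer, forward_batch,
                save_kv_cache=True, **kwargs):
        # Dispatch on the batch's forward mode, as described in the docstring.
        if forward_batch.forward_mode == "decode":
            return self.forward_decode(q, k, v, layer, forward_batch,
                                       save_kv_cache, **kwargs)
        return self.forward_extend(q, k, v, layer, forward_batch,
                                   save_kv_cache, **kwargs)


class NaiveBackend(AttentionBackend):
    """Trivial backend: metadata prep is a no-op, attention is a placeholder."""

    def init_forward_metadata(self, forward_batch):
        pass  # torch-native-style no-op

    def forward_extend(self, q, k, v, layer, forward_batch,
                       save_kv_cache=True, **kwargs):
        return ("extend", len(q))

    def forward_decode(self, q, k, v, layer, forward_batch,
                       save_kv_cache=True, **kwargs):
        return ("decode", len(q))


class Batch:
    """Hypothetical stand-in carrying only the forward mode."""

    def __init__(self, forward_mode):
        self.forward_mode = forward_mode


backend = NaiveBackend()
backend.init_forward_metadata(Batch("extend"))  # called once per batch
print(backend.forward([1, 2, 3], [], [], layer=None,
                      forward_batch=Batch("extend")))  # ('extend', 3)
print(backend.forward([1], [], [], layer=None,
                      forward_batch=Batch("decode")))  # ('decode', 1)
```

The design point this mirrors is that only the three abstract methods must be supplied by a subclass; the ``forward`` dispatcher is shared base-class behavior.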
      The default raises ``NotImplementedError``.

   .. py:method:: get_cuda_graph_seq_len_fill_value()
      :abstractmethod:

      Fill value used to pad ``seq_lens`` tensors for CUDA-graph capture.

      Most backends use ``1`` (not ``0``) to avoid division-by-zero in
      attention kernels.

   .. py:method:: init_cuda_graph_state(max_bs, max_num_tokens)
      :abstractmethod:

      Allocate shared CUDA-graph state (buffers reused across captures).

   .. py:method:: init_forward_metadata_capture_cuda_graph(bs, num_tokens, req_pool_indices, seq_lens, forward_mode)
      :abstractmethod:

      Set up per-batch metadata for capturing a CUDA graph.

   .. py:method:: init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens, seq_lens_sum, forward_mode, seq_lens_cpu)
      :abstractmethod:

      Update metadata when replaying a captured CUDA graph.
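The fill-value note above can be illustrated with a small sketch. A CUDA graph is captured at a fixed maximum batch size, so when a replayed batch is smaller, the trailing ``seq_lens`` slots must hold *some* value; filling with ``1`` keeps divisions such as ``1 / seq_len`` finite for the dead slots. The ``pad_seq_lens`` helper below is hypothetical, not a pymllm API, and uses a plain list in place of a tensor.

```python
# Illustrative sketch (not pymllm's implementation): pad a seq_lens buffer to
# a fixed CUDA-graph batch size using the backend's fill value. A fill value
# of 1 keeps per-sequence normalizations like 1/seq_len finite in padded slots.
def pad_seq_lens(seq_lens, max_bs, fill_value=1):
    if len(seq_lens) > max_bs:
        raise ValueError("batch larger than captured graph size")
    return seq_lens + [fill_value] * (max_bs - len(seq_lens))


padded = pad_seq_lens([17, 5, 42], max_bs=8)
print(padded)  # [17, 5, 42, 1, 1, 1, 1, 1]
# Every padded entry is safe to use as a divisor:
print(all(1.0 / n > 0 for n in padded))  # True
```

Had the fill value been ``0``, any kernel that normalizes by sequence length would divide by zero in the padded slots, which is exactly what ``get_cuda_graph_seq_len_fill_value`` lets each backend avoid.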