pymllm.layers.attention.flashinfer_backend
==========================================

.. py:module:: pymllm.layers.attention.flashinfer_backend

.. autoapi-nested-parse::

   FlashInfer attention backend for pymllm.

   * No model-runner object -- constructor takes explicit scalar / tensor params.
   * No tensor-parallelism head splitting (handled at the model layer level).
   * No speculative decoding support.
   * ``KVPool`` API:

     - ``get_kv_buffer(layer_id)`` returns ``(k_buf, v_buf)`` each shaped
       ``[buf_len, num_heads, head_dim]``.
     - ``set_kv_buffer(layer_id, indices, k, v)`` -- no scale arguments.

   Supports:

   * Single-wrapper mode (full context, no sliding window)
   * Sliding-window mode (two wrappers: window + full)
   * CUDA-graph capture / replay for decode and target-verify passes.

Attributes
----------

.. autoapisummary::

   pymllm.layers.attention.flashinfer_backend.logger

Classes
-------

.. autoapisummary::

   pymllm.layers.attention.flashinfer_backend.WrapperDispatch
   pymllm.layers.attention.flashinfer_backend.DecodeMetadata
   pymllm.layers.attention.flashinfer_backend.PrefillMetadata
   pymllm.layers.attention.flashinfer_backend.FlashInferAttnBackend

Functions
---------

.. autoapisummary::

   pymllm.layers.attention.flashinfer_backend.should_use_tensor_core

Module Contents
---------------

.. py:data:: logger

.. py:class:: WrapperDispatch(*args, **kwds)

   Bases: :py:obj:`enum.Enum`

   Indicates which wrapper to use for a given attention layer.

   .. py:attribute:: SLIDING_WINDOW

   .. py:attribute:: CROSS_ATTENTION

.. py:class:: DecodeMetadata

   Per-batch metadata for a decode step.

   .. py:attribute:: decode_wrappers
      :type: List[BatchDecodeWithPagedKVCacheWrapper]

.. py:class:: PrefillMetadata

   Per-batch metadata for a prefill / extend step.

   .. py:attribute:: prefill_wrappers
      :type: List[BatchPrefillWithPagedKVCacheWrapper]

   .. py:attribute:: use_ragged
      :type: bool

   .. py:attribute:: extend_no_prefix
      :type: bool

.. py:function:: should_use_tensor_core(kv_cache_dtype, num_attention_heads, num_kv_heads)

   Return whether FlashInfer decode should use tensor cores.

   For FP8 we always use tensor cores. For fp16 / bf16 we use them when the
   GQA group size (num_attention_heads / num_kv_heads) is ≥ 4, which fuses
   the head group with the token dimension in the MMA instruction.

.. py:class:: FlashInferAttnBackend(num_heads, num_kv_heads, head_dim, kv_cache_dtype, q_dtype, max_context_len, req_to_token, device, max_req_pool_size, sliding_window_size = None, skip_prefill = False, kv_indptr_buf = None, kv_last_page_len_buf = None, init_new_workspace = False)

   Bases: :py:obj:`pymllm.layers.attention.attention_backend.AttentionBackend`

   FlashInfer-based attention backend for pymllm.

   This class does not depend on a ``ModelRunner`` object. Instead it takes
   all required configuration explicitly, so that it can be constructed
   independently of any particular model runner.

   :param num_heads: Number of query heads per device (after any TP sharding).
   :param num_kv_heads: Number of KV heads per device.
   :param head_dim: Per-head dimension for Q and K.
   :param kv_cache_dtype: ``torch.dtype`` of the KV cache (e.g. ``torch.float16``).
   :param q_dtype: ``torch.dtype`` of the query tensor.
   :param max_context_len: Maximum sequence length the model supports.
   :param req_to_token: The ``[max_reqs, max_context_len]`` int32 tensor from
       ``ReqToTokenPool.req_to_token``.
   :param device: Target device (e.g. ``torch.device("cuda")``).
   :param max_req_pool_size: Maximum number of concurrent requests
       (= ``ReqToTokenPool.size``). Used to pre-allocate ``kv_indptr`` /
       ``kv_last_page_len`` buffers.
   :param sliding_window_size: When not ``None``, enables sliding-window
       attention mode, which allocates two wrapper sets (window + full context).
   :param skip_prefill: When ``True``, skip creating prefill wrappers (for
       backends that only perform decode, e.g. multi-step draft backends).
   :param kv_indptr_buf: Optional pre-allocated ``kv_indptr`` buffer. Used
       when sharing buffers across multiple backend instances (e.g.
       multi-step draft).
   :param kv_last_page_len_buf: Optional pre-allocated ``kv_last_page_len``
       buffer.
   :param init_new_workspace: When ``True``, allocate a fresh workspace
       buffer instead of reusing the global one.

   .. py:attribute:: num_heads

   .. py:attribute:: num_kv_heads

   .. py:attribute:: head_dim

   .. py:attribute:: kv_cache_dtype

   .. py:attribute:: q_dtype

   .. py:attribute:: max_context_len

   .. py:attribute:: req_to_token

   .. py:attribute:: device

   .. py:attribute:: skip_prefill
      :value: False

   .. py:attribute:: decode_use_tensor_cores

   .. py:attribute:: prefill_wrapper_ragged
      :type: Optional[flashinfer.BatchPrefillWithRaggedKVCacheWrapper]
      :value: None

   .. py:attribute:: prefill_wrappers_paged
      :type: List[flashinfer.BatchPrefillWithPagedKVCacheWrapper]
      :value: []

   .. py:attribute:: decode_wrappers
      :type: List[flashinfer.BatchDecodeWithPagedKVCacheWrapper]
      :value: []

   .. py:attribute:: indices_updater_decode

   .. py:attribute:: forward_metadata
      :type: Optional[Union[DecodeMetadata, PrefillMetadata]]
      :value: None

   .. py:attribute:: decode_cuda_graph_metadata
      :type: dict

   .. py:attribute:: prefill_cuda_graph_metadata
      :type: dict

   .. py:method:: init_forward_metadata(forward_batch)

      Prepare FlashInfer wrappers for the current batch.

      Must be called once per batch, before the model's ``forward`` method.

   .. py:method:: forward_extend(q, k, v, layer, forward_batch, save_kv_cache = True, **kwargs)

      Run attention for a prefill / extend step.

   .. py:method:: forward_decode(q, k, v, layer, forward_batch, save_kv_cache = True, **kwargs)

      Run attention for a decode step (one new token per sequence).

   .. py:method:: get_cuda_graph_seq_len_fill_value()

      Fill value used to pad ``seq_lens`` tensors for CUDA-graph capture.

      Most backends use ``1`` (not ``0``) to avoid division-by-zero in
      attention kernels.

   .. py:method:: init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf = None)

      Allocate CUDA-graph shared state buffers.

   .. py:method:: init_forward_metadata_capture_cuda_graph(bs, num_tokens, req_pool_indices, seq_lens, forward_mode)

      Set up metadata for CUDA-graph capture of a decode step.

   .. py:method:: init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens, seq_lens_sum, forward_mode, seq_lens_cpu)

      Update metadata when replaying a CUDA graph for decode.
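The tensor-core rule documented for ``should_use_tensor_core`` can be sketched in plain Python. This is an illustrative stand-in, not the module's implementation: the string dtype names used here are assumptions in place of the real ``torch.dtype`` values.

```python
def should_use_tensor_core(kv_cache_dtype, num_attention_heads, num_kv_heads):
    """Sketch of the documented decision rule (dtype names are illustrative)."""
    # FP8 KV caches always take the tensor-core path.
    if kv_cache_dtype in ("fp8_e4m3", "fp8_e5m2"):
        return True
    # For fp16 / bf16, use tensor cores once the GQA group is wide enough:
    # >= 4 query heads share each KV head, so the head group can be fused
    # with the token dimension in the MMA instruction.
    gqa_group_size = num_attention_heads // num_kv_heads
    return gqa_group_size >= 4
```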
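The ``KVPool`` contract described in the module docstring can be illustrated with a toy pure-Python class. Nested lists stand in for the real ``[buf_len, num_heads, head_dim]`` tensors; the class name and shapes follow the docstring, but everything else here is a simplified sketch, not the actual pool.

```python
class KVPool:
    """Toy stand-in for the documented KVPool API (illustrative only)."""

    def __init__(self, num_layers, buf_len, num_heads, head_dim):
        # One (k_buf, v_buf) pair per layer, zero-initialized.
        def make_buf():
            return [[[0.0] * head_dim for _ in range(num_heads)]
                    for _ in range(buf_len)]
        self._k = [make_buf() for _ in range(num_layers)]
        self._v = [make_buf() for _ in range(num_layers)]

    def get_kv_buffer(self, layer_id):
        # Returns (k_buf, v_buf), each shaped [buf_len, num_heads, head_dim].
        return self._k[layer_id], self._v[layer_id]

    def set_kv_buffer(self, layer_id, indices, k, v):
        # Scatter new K/V rows into the cache at `indices`; no scale arguments.
        for row, ki, vi in zip(indices, k, v):
            self._k[layer_id][row] = ki
            self._v[layer_id][row] = vi
```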
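In sliding-window mode the backend keeps two wrapper sets (window + full context), and ``WrapperDispatch`` indicates which one a layer should use. A hypothetical dispatch helper might look like the following; ``pick_wrapper_index``, its arguments, and the 0 = window / 1 = full indexing are assumptions for illustration, not part of the module's API.

```python
from enum import Enum, auto


class WrapperDispatch(Enum):
    """Mirrors the documented enum: which wrapper a layer should use."""
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


def pick_wrapper_index(dispatch, layer_uses_window):
    """Hypothetical helper: map a layer to one of the two wrapper sets.

    Assumes index 0 holds the window wrappers and index 1 the
    full-context wrappers when sliding-window mode is enabled.
    """
    if dispatch is WrapperDispatch.SLIDING_WINDOW and layer_uses_window:
        return 0  # window wrapper
    return 1      # full-context wrapper
```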
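The rationale behind ``get_cuda_graph_seq_len_fill_value`` (padding with ``1`` rather than ``0``) can be shown with a small padding sketch. ``pad_seq_lens`` is a hypothetical helper operating on plain lists; a real backend would pad a device tensor to the captured batch size.

```python
def pad_seq_lens(seq_lens, captured_bs, fill_value=1):
    """Pad `seq_lens` up to the CUDA-graph captured batch size.

    Padded slots get `fill_value` = 1 (not 0) so that no padded slot ever
    presents a zero-length sequence to the attention kernel, which could
    otherwise divide by zero.
    """
    return seq_lens + [fill_value] * (captured_bs - len(seq_lens))
```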