pymllm.layers.attention.attention_backend

Abstract base class for pymllm attention backends.

Every concrete backend (FlashInfer, Triton, torch-native, …) must implement, at a minimum:

  • init_forward_metadata – called once per batch before the model forward.

  • forward_extend – prefill / extend attention.

  • forward_decode – single-token decode attention.

The public forward method dispatches to the correct variant based on forward_batch.forward_mode.
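The contract above can be sketched as a minimal subclassable skeleton. This is an illustration only: the plain-Python placeholder math and the `NaiveBackend` class are assumptions for demonstration, while the real methods take torch tensors and a `ForwardBatch`.

```python
# Illustrative skeleton of the backend contract described above.
# Plain-Python placeholders stand in for torch tensors / ForwardBatch.
from abc import ABC, abstractmethod


class AttentionBackend(ABC):
    @abstractmethod
    def init_forward_metadata(self, forward_batch):
        """Prepare per-batch metadata before the model forward."""

    @abstractmethod
    def forward_extend(self, q, k, v, layer, forward_batch,
                       save_kv_cache=True, **kwargs):
        """Prefill / extend attention."""

    @abstractmethod
    def forward_decode(self, q, k, v, layer, forward_batch,
                       save_kv_cache=True, **kwargs):
        """Single-token decode attention."""


class NaiveBackend(AttentionBackend):
    """Toy backend: metadata prep is a no-op, as for torch-native."""

    def init_forward_metadata(self, forward_batch):
        pass  # nothing to plan

    def forward_extend(self, q, k, v, layer, forward_batch,
                       save_kv_cache=True, **kwargs):
        return [qi * ki * vi for qi, ki, vi in zip(q, k, v)]  # placeholder math

    def forward_decode(self, q, k, v, layer, forward_batch,
                       save_kv_cache=True, **kwargs):
        return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
```

Because the three methods are abstract, instantiating `AttentionBackend` directly raises `TypeError`; only complete subclasses like `NaiveBackend` can be constructed.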

Classes

AttentionBackend

Abstract base class for attention backends.

Module Contents

class pymllm.layers.attention.attention_backend.AttentionBackend

Bases: abc.ABC

Abstract base class for attention backends.

All concrete backends inherit from this class and implement the abstract methods below.

abstractmethod init_forward_metadata(forward_batch)

Prepare per-batch metadata before the model’s attention layers run.

For FlashInfer this plans the KV-index arrays and calls wrapper.begin_forward; for Triton / torch-native backends it is a no-op. Must be called once per batch before model.forward.

Parameters:

forward_batch (pymllm.engine.forward_batch.ForwardBatch)

Return type:

None

abstractmethod forward_decode(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)

Run attention for a decode step (one new token per sequence).

Parameters:

  • q (torch.Tensor)

  • k (torch.Tensor)

  • v (torch.Tensor)

  • layer

  • forward_batch (pymllm.engine.forward_batch.ForwardBatch)

  • save_kv_cache (bool)

Return type:

torch.Tensor

abstractmethod forward_extend(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)

Run attention for a prefill / extend step.

Parameters:

  • q (torch.Tensor)

  • k (torch.Tensor)

  • v (torch.Tensor)

  • layer

  • forward_batch (pymllm.engine.forward_batch.ForwardBatch)

  • save_kv_cache (bool)

Return type:

torch.Tensor

forward(q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs)

Dispatch to forward_decode or forward_extend based on mode.

For IDLE batches a zero-filled output tensor is returned without any compute.

Parameters:

  • q (torch.Tensor)

  • k (torch.Tensor)

  • v (torch.Tensor)

  • layer

  • forward_batch (pymllm.engine.forward_batch.ForwardBatch)

  • save_kv_cache (bool)

Return type:

torch.Tensor
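The dispatch rule just described can be sketched as follows. The `ForwardMode` enum members and the list-based zero output are illustrative stand-ins, not pymllm's actual types; the real method returns a torch.Tensor.

```python
# Sketch of forward's dispatch on forward_batch.forward_mode.
# ForwardMode member names are assumed for illustration.
from enum import Enum, auto
from types import SimpleNamespace


class ForwardMode(Enum):  # illustrative stand-in for pymllm's enum
    EXTEND = auto()
    DECODE = auto()
    IDLE = auto()


def forward(backend, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
    mode = forward_batch.forward_mode
    if mode is ForwardMode.IDLE:
        return [0.0] * len(q)  # zero-filled output, no compute
    if mode is ForwardMode.DECODE:
        return backend.forward_decode(q, k, v, layer, forward_batch,
                                      save_kv_cache, **kwargs)
    return backend.forward_extend(q, k, v, layer, forward_batch,
                                  save_kv_cache, **kwargs)
```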

abstractmethod forward_gdn(layer, forward_batch, mixed_qkv, a, b)

Run GDN linear-attention for one layer.

Only implemented by backends that support hybrid (full-attention + GDN) architectures; the base implementation raises NotImplementedError.

Parameters:

  • layer

  • forward_batch (pymllm.engine.forward_batch.ForwardBatch)

  • mixed_qkv (torch.Tensor)

  • a (torch.Tensor)

  • b (torch.Tensor)

Return type:

torch.Tensor

abstractmethod get_cuda_graph_seq_len_fill_value()

Fill value used to pad seq_lens tensors for CUDA-graph capture.

Most backends use 1 (not 0) to avoid division-by-zero in attention kernels.

Return type:

int
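A toy illustration of why the fill value is 1 rather than 0. The `pad_seq_lens` helper is hypothetical, not pymllm API; it only demonstrates that padded slots stay safe for kernels that divide by sequence length.

```python
# Padding seq_lens with 0 would feed a zero denominator into kernels
# that normalize by sequence length; padding with 1 keeps padded
# slots harmless. pad_seq_lens is a hypothetical helper.
def pad_seq_lens(seq_lens, max_bs, fill_value=1):
    return seq_lens + [fill_value] * (max_bs - len(seq_lens))
```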

abstractmethod init_cuda_graph_state(max_bs, max_num_tokens)

Allocate shared CUDA-graph state (buffers reused across captures).

Parameters:
  • max_bs (int)

  • max_num_tokens (int)

Return type:

None
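A hedged sketch of the allocate-once pattern this method implies: buffers sized for the worst case are created a single time and reused by every capture and replay. The buffer names and plain lists here are illustrative assumptions; real backends allocate device tensors.

```python
# Allocate-once CUDA-graph state: buffers sized for (max_bs,
# max_num_tokens) are created once and reused across captures.
# Names are illustrative, not pymllm's.
class GraphStateBackend:
    def init_cuda_graph_state(self, max_bs, max_num_tokens):
        self.graph_seq_lens = [1] * max_bs       # padded with fill value 1
        self.graph_out = [0.0] * max_num_tokens  # reused output buffer
```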

abstractmethod init_forward_metadata_capture_cuda_graph(bs, num_tokens, req_pool_indices, seq_lens, forward_mode)

Set up per-batch metadata for capturing a CUDA graph.

Parameters:

  • bs (int)

  • num_tokens (int)

  • req_pool_indices (torch.Tensor)

  • seq_lens (torch.Tensor)

  • forward_mode

Return type:

None

abstractmethod init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens, seq_lens_sum, forward_mode, seq_lens_cpu)

Update metadata when replaying a captured CUDA graph.

Parameters:

  • bs (int)

  • req_pool_indices (torch.Tensor)

  • seq_lens (torch.Tensor)

  • seq_lens_sum (int)

  • forward_mode

  • seq_lens_cpu (torch.Tensor)

Return type:

None