pymllm.layers.attention.gdn_backend

GDN attention backend – pooled-state GDN computation for hybrid models.

Performs GDN (Gated Delta Net) linear attention using externalized state stored in a GDNPool. Supports both extend (prefill) and decode paths with FlashInfer kernels.

This backend is not used directly; it is wrapped by HybridAttnBackend.

Attributes

logger

Classes

GDNForwardMetadata

Per-batch metadata for GDN backend.

GDNAttnBackend

GDN linear-attention backend using pooled states.

Module Contents

pymllm.layers.attention.gdn_backend.logger
class pymllm.layers.attention.gdn_backend.GDNForwardMetadata

Per-batch metadata for GDN backend.

cache_indices: torch.Tensor
cu_seqlens: torch.Tensor | None = None
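
The cache_indices tensor maps each request in the batch to its slot in the pooled state buffer. A minimal sketch of the gather/update/scatter pattern (pool shape and names here are illustrative assumptions, not the real GDNPool layout):

```python
import torch

# Hypothetical pooled conv-state buffer: [num_slots, conv_dim, K-1]
pool = torch.zeros(8, 4, 3)
cache_indices = torch.tensor([5, 2])  # slots owned by this batch

state = pool[cache_indices]           # gather -> [bs, conv_dim, K-1]
state = state + 1.0                   # ...a kernel would update the state...
pool[cache_indices] = state           # scatter the result back in place
```
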
class pymllm.layers.attention.gdn_backend.GDNAttnBackend(gdn_pool, device)

GDN linear-attention backend using pooled states.

Handles both extend (prefill) and decode paths for GDN layers. Uses FlashInfer kernels when available (SM90+), with PyTorch fallback.

Parameters:
gdn_pool
device
forward_metadata: GDNForwardMetadata | None = None
init_forward_metadata(forward_batch)

Prepare GDN metadata from the current forward batch.

Parameters:

forward_batch (pymllm.engine.forward_batch.ForwardBatch)

Return type:

None
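
A sketch of what preparing this metadata might look like: decode batches carry one token per request and need no varlen offsets, while extend batches need cumulative sequence offsets for varlen kernels. The function name, the dict return, and the int32 dtype are assumptions; the real backend fills a GDNForwardMetadata instance.

```python
import torch

def build_gdn_metadata(seq_lens, cache_indices, is_decode):
    """Hypothetical sketch of per-batch GDN metadata assembly.

    seq_lens:      [bs] int tensor of per-request token counts
    cache_indices: [bs] int tensor of pooled-state slots
    """
    if is_decode:
        # Decode: exactly one token per request; no varlen offsets needed.
        return {"cache_indices": cache_indices, "cu_seqlens": None}
    # Extend (prefill): varlen kernels index tokens via cumulative offsets.
    cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    return {"cache_indices": cache_indices, "cu_seqlens": cu_seqlens}
```
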

init_cuda_graph_state(max_bs, max_num_tokens)

Allocate CUDA-graph state for GDN backend.

The GDN pool buffers are already pre-allocated at fixed addresses, so we only need to allocate the metadata tensor.

Parameters:
  • max_bs (int)

  • max_num_tokens (int)

Return type:

None
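
The reason only the metadata tensor needs allocating: a captured CUDA graph replays reads from fixed memory addresses, so any tensor the graph consumes must keep its storage across replays and be updated in place. A minimal sketch of that pattern (class and attribute names are assumptions):

```python
import torch

class GraphMetadataSketch:
    """Illustrative CUDA-graph-safe metadata buffer: allocated once at
    capture size, then updated only via in-place copies on replay."""

    def init_cuda_graph_state(self, max_bs, device="cpu"):
        # One fixed-address buffer, sized for the largest batch.
        self.graph_cache_indices = torch.zeros(
            max_bs, dtype=torch.int64, device=device
        )

    def replay(self, cache_indices):
        bs = cache_indices.numel()
        # In-place copy: the captured graph keeps reading the same storage.
        self.graph_cache_indices[:bs].copy_(cache_indices)
```
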

init_forward_metadata_capture_cuda_graph(bs, req_pool_indices, seq_lens)

Set up GDN metadata for CUDA-graph capture (decode only).

Parameters:
  • bs (int)

  • req_pool_indices (torch.Tensor)

  • seq_lens (torch.Tensor)

Return type:

None

init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens)

Update GDN metadata for CUDA-graph replay (decode only).

Parameters:
  • bs (int)

  • req_pool_indices (torch.Tensor)

  • seq_lens (torch.Tensor)

Return type:

None

forward_decode(layer, forward_batch, mixed_qkv, a, b)

GDN decode: one new token per request.

Steps:
  1. Gather conv_state from pool → [bs, conv_dim, K-1]
  2. Conv1d update: shift + weighted sum for 1 new token
  3. Scatter updated conv_state back to pool
  4. SiLU → split q, k, v
  5. FlashInfer gated_delta_rule_decode (or PyTorch fallback)

Parameters:
  • layer

  • forward_batch

  • mixed_qkv

  • a

  • b

Return type:

torch.Tensor
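
Step 2 above reduces to a depthwise causal convolution over a K-wide window: the cached K-1 inputs plus the new token. A PyTorch sketch of that single-token update (function name and shapes are assumptions):

```python
import torch
import torch.nn.functional as F

def conv1d_decode_step(conv_state, x_new, weight, bias=None):
    """Illustrative depthwise causal conv1d for one new token.

    conv_state: [bs, conv_dim, K-1]  last K-1 inputs per channel
    x_new:      [bs, conv_dim]       the new token's pre-conv activations
    weight:     [conv_dim, K]        per-channel conv taps
    """
    # Full window: cached history plus the new token.
    window = torch.cat([conv_state, x_new.unsqueeze(-1)], dim=-1)  # [bs, conv_dim, K]
    # Weighted sum over the window == causal conv output at this position.
    y = (window * weight.unsqueeze(0)).sum(dim=-1)                 # [bs, conv_dim]
    if bias is not None:
        y = y + bias
    # Shift: the window tail becomes the new cached state.
    new_state = window[..., 1:]                                    # [bs, conv_dim, K-1]
    return F.silu(y), new_state
```
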

forward_extend(layer, forward_batch, mixed_qkv, a, b)

GDN extend (prefill): multi-token per request.

Steps:
  1. Gather conv_state from pool for each request
  2. Per-request causal conv1d
  3. Scatter new conv_state back to pool
  4. SiLU → split q, k, v → gating
  5. FlashInfer chunk_gated_delta_rule (or PyTorch fallback)
  6. Scatter final recurrent state back to pool

Parameters:
  • layer

  • forward_batch

  • mixed_qkv

  • a

  • b

Return type:

torch.Tensor
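
For intuition about what the chunked kernel computes, here is a naive token-by-token PyTorch reference of the gated delta rule recurrence: decay the recurrent state, then apply an error-correcting (delta-rule) write. This is a sketch of the fallback's semantics under assumed shapes and gating conventions, not the real kernel API.

```python
import torch

def gated_delta_rule_ref(q, k, v, g, beta):
    """Naive reference recurrence for the gated delta rule.

    q, k: [T, d_k]   queries/keys (assumed unit-norm keys)
    v:    [T, d_v]   values
    g:    [T]        per-step decay gate in (0, 1]
    beta: [T]        per-step write strength
    Returns o: [T, d_v] and the final recurrent state S: [d_k, d_v].
    """
    d_k, d_v = k.shape[-1], v.shape[-1]
    S = q.new_zeros(d_k, d_v)
    outs = []
    for t in range(q.shape[0]):
        S = g[t] * S                      # decay the state
        err = v[t] - k[t] @ S             # what the state currently misreads
        S = S + beta[t] * torch.outer(k[t], err)  # delta-rule write
        outs.append(q[t] @ S)             # read out with the query
    return torch.stack(outs), S
```

With orthonormal keys, no decay (g = 1), and full write strength (beta = 1), each query that matches its key reads back its value exactly, which is the delta rule's defining property.
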

forward_gdn(layer, forward_batch, mixed_qkv, a, b)

Route to decode or extend based on forward mode.

Parameters:
  • layer

  • forward_batch

  • mixed_qkv

  • a

  • b

Return type:

torch.Tensor