pymllm.layers.attention.gdn_backend¶
GDN attention backend – pooled-state GDN computation for hybrid models.
Performs GDN (Gated Delta Net) linear attention using externalized state stored in a GDNPool. Supports both extend (prefill) and decode paths with FlashInfer kernels.
This backend is not used directly; it is wrapped by
HybridAttnBackend.
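As a reference for what the kernels compute, one decode step of the gated delta rule can be sketched in pure Python. This is a hedged reconstruction of the Gated DeltaNet recurrence from the literature, not code taken from this module; gate placement and normalization conventions are assumptions:

```python
def gated_delta_rule_step(S, q, k, v, alpha, beta):
    """One decode step of the gated delta rule (pure-Python reference).

    S: d_v x d_k recurrent state (list of lists); q, k: length-d_k;
    v: length-d_v. alpha: scalar decay gate, beta: scalar write strength.
    Assumed recurrence:
        S_t = alpha * S_{t-1} @ (I - beta * k k^T) + beta * v k^T
        o_t = S_t @ q
    (k is typically L2-normalized upstream; omitted here.)
    """
    d_v, d_k = len(S), len(S[0])
    # S_{t-1} @ k, length d_v
    Sk = [sum(S[i][j] * k[j] for j in range(d_k)) for i in range(d_v)]
    S_new = [
        [alpha * (S[i][j] - beta * Sk[i] * k[j]) + beta * v[i] * k[j]
         for j in range(d_k)]
        for i in range(d_v)
    ]
    o = [sum(S_new[i][j] * q[j] for j in range(d_k)) for i in range(d_v)]
    return S_new, o
```

With a zero state, `alpha=1`, `beta=1`, the update reduces to a plain outer-product write `v k^T`, which is a quick sanity check for any fallback implementation.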
Attributes¶
logger
Classes¶
GDNForwardMetadata – Per-batch metadata for GDN backend.
GDNAttnBackend – GDN linear-attention backend using pooled states.
Module Contents¶
- pymllm.layers.attention.gdn_backend.logger¶
- class pymllm.layers.attention.gdn_backend.GDNForwardMetadata¶
Per-batch metadata for GDN backend.
- cache_indices: torch.Tensor¶
- cu_seqlens: torch.Tensor | None = None¶
- class pymllm.layers.attention.gdn_backend.GDNAttnBackend(gdn_pool, device)¶
GDN linear-attention backend using pooled states.
Handles both extend (prefill) and decode paths for GDN layers. Uses FlashInfer kernels when available (SM90+), with PyTorch fallback.
- Parameters:
gdn_pool (pymllm.mem_cache.memory_pool.GDNPool) – Pre-allocated GDNPool.
device (torch.device) – Target device.
- gdn_pool¶
- device¶
- forward_metadata: GDNForwardMetadata | None = None¶
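The SM90+ gate mentioned above can be sketched as a small capability probe. This is a hypothetical helper (name and placement assumed, not from this module) showing the shape of a FlashInfer-or-fallback decision:

```python
def flashinfer_gdn_available() -> bool:
    # Hypothetical probe: FlashInfer GDN kernels are assumed to require
    # SM90+ (Hopper). Anything else falls back to the PyTorch path.
    try:
        import torch
    except ImportError:
        return False
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
```

A backend would typically evaluate this once at construction time and branch on the cached flag in the hot decode path, rather than re-probing per step.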
- init_forward_metadata(forward_batch)¶
Prepare GDN metadata from the current forward batch.
- Parameters:
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
- Return type:
None
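Extend batches are packed token-to-token, so the metadata's `cu_seqlens` follows the usual cumulative-lengths convention. A minimal sketch (the field name comes from the dataclass above; the construction itself is an assumption):

```python
from itertools import accumulate

def build_cu_seqlens(seq_lens):
    """Cumulative sequence lengths for a packed extend batch.

    Tokens cu[i]:cu[i+1] in the flattened batch belong to request i,
    so per-request kernels can slice without a ragged layout.
    """
    return [0] + list(accumulate(seq_lens))
```

For example, three requests of lengths 3, 1, and 2 pack into 6 tokens with boundaries `[0, 3, 4, 6]`.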
- init_cuda_graph_state(max_bs, max_num_tokens)¶
Allocate CUDA-graph state for GDN backend.
The GDN pool buffers are already pre-allocated at fixed addresses, so we only need to allocate the metadata tensor.
- Parameters:
max_bs (int)
max_num_tokens (int)
- Return type:
None
- init_forward_metadata_capture_cuda_graph(bs, req_pool_indices, seq_lens)¶
Set up GDN metadata for CUDA-graph capture (decode only).
- Parameters:
bs (int)
req_pool_indices (torch.Tensor)
seq_lens (torch.Tensor)
- Return type:
None
- init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, seq_lens)¶
Update GDN metadata for CUDA-graph replay (decode only).
- Parameters:
bs (int)
req_pool_indices (torch.Tensor)
seq_lens (torch.Tensor)
- Return type:
None
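The capture/replay pair above relies on the standard CUDA-graph contract: every tensor address baked into the graph must stay fixed, so replay rewrites the captured metadata buffer in place. A pure-Python sketch of that contract (a plain list stands in for the metadata tensor; names are illustrative):

```python
class GraphMetadataBuffer:
    """Fixed-address metadata buffer for CUDA-graph capture/replay (sketch)."""

    def __init__(self, max_bs: int):
        # Allocated once at maximum batch size; never reallocated afterwards.
        self.cache_indices = [0] * max_bs

    def capture(self, req_pool_indices):
        bs = len(req_pool_indices)
        self.cache_indices[:bs] = req_pool_indices  # fill the graphed region
        return bs

    def replay(self, req_pool_indices):
        # Same buffer object, new contents: the graph sees fresh indices
        # at the addresses it recorded during capture.
        bs = len(req_pool_indices)
        self.cache_indices[:bs] = req_pool_indices
```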
- forward_decode(layer, forward_batch, mixed_qkv, a, b)¶
GDN decode: one new token per request.
Steps:
1. Gather conv_state from pool → [bs, conv_dim, K-1]
2. Conv1d update: shift + weighted sum for 1 new token
3. Scatter updated conv_state back to pool
4. SiLU → split q, k, v
5. FlashInfer gated_delta_rule_decode (or PyTorch fallback)
- Parameters:
layer (pymllm.layers.attention.radix_linear_attention.RadixLinearAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
mixed_qkv (torch.Tensor)
a (torch.Tensor)
b (torch.Tensor)
- Return type:
torch.Tensor
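Step 2 above is the classic single-token causal conv1d state update: the state holds the last K-1 inputs, the new token completes a length-K window, and the window then shifts. A per-channel reference sketch (the real kernel operates on the whole [bs, conv_dim, K-1] state at once):

```python
def conv1d_decode_step(state, x, w, bias=0.0):
    """Single-token causal conv1d update for one channel (reference sketch).

    state: last K-1 inputs, oldest first; x: the new input;
    w: K filter taps, oldest tap first. Returns (new_state, y).
    """
    window = state + [x]                           # length-K receptive field
    y = sum(wi * xi for wi, xi in zip(w, window)) + bias
    new_state = window[1:]                         # shift: drop oldest input
    return new_state, y
```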
- forward_extend(layer, forward_batch, mixed_qkv, a, b)¶
GDN extend (prefill): multi-token per request.
Steps:
1. Gather conv_state from pool for each request
2. Per-request causal conv1d
3. Scatter new conv_state back to pool
4. SiLU → split q, k, v → gating
5. FlashInfer chunk_gated_delta_rule (or PyTorch fallback)
6. Scatter final recurrent state back to pool
- Parameters:
layer (pymllm.layers.attention.radix_linear_attention.RadixLinearAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
mixed_qkv (torch.Tensor)
a (torch.Tensor)
b (torch.Tensor)
- Return type:
torch.Tensor
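Step 2 of the extend path runs the same causal conv1d over every prompt token, seeded from the pooled state and returning the final state to be scattered back (step 3). A per-channel reference sketch:

```python
def conv1d_prefill(tokens, init_state, w):
    """Multi-token causal conv1d for one request, one channel (sketch).

    tokens: the prompt inputs; init_state: last K-1 inputs gathered from
    the pool; w: K filter taps, oldest first.
    Returns (outputs, final_state); final_state goes back to the pool so
    subsequent decode steps continue seamlessly.
    """
    state, outs = list(init_state), []
    for x in tokens:
        window = state + [x]
        outs.append(sum(wi * xi for wi, xi in zip(w, window)))
        state = window[1:]
    return outs, state
```

Note that decoding a token with `final_state` afterwards produces exactly what a longer prefill would have, which is the invariant that makes the pooled-state design work.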
- forward_gdn(layer, forward_batch, mixed_qkv, a, b)¶
Route to decode or extend based on forward mode.
- Parameters:
layer (pymllm.layers.attention.radix_linear_attention.RadixLinearAttention)
forward_batch (pymllm.engine.forward_batch.ForwardBatch)
mixed_qkv (torch.Tensor)
a (torch.Tensor)
b (torch.Tensor)
- Return type:
torch.Tensor
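The routing itself is a simple dispatch on the batch's forward mode. A minimal sketch (the mode names and callable shape are assumptions, not taken from this module):

```python
def route_gdn(forward_mode, decode_fn, extend_fn, *args):
    """Dispatch to the decode path for single-token steps and to the
    extend path for multi-token prefill (hypothetical mode names)."""
    if forward_mode == "decode":
        return decode_fn(*args)
    return extend_fn(*args)
```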