pymllm.engine.forward_batch

ForwardMode and ForwardBatch for pymllm.

Simplified forward-batch abstraction: no speculative decoding, no encoder-decoder support, and no distributed-attention complexity (DP/TP head splitting is handled at the layer level by the model code, not here).

Typical data flow:

  1. ModelRunner builds a ForwardBatch.
  2. attn_backend.init_forward_metadata(forward_batch)
  3. model.forward(input_ids, positions, forward_batch)
  4. RadixAttention.forward(q, k, v, forward_batch)
  5. forward_batch.attn_backend.forward(q, k, v, layer, forward_batch)
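The call chain above can be sketched with stub classes. Everything below is illustrative (the real ModelRunner, RadixAttention, and attention backends carry far more state); only the call shapes come from the flow described here.

```python
from dataclasses import dataclass

class AttentionBackend:
    """Stub backend; real backends precompute plans/indices and run kernels."""

    def init_forward_metadata(self, forward_batch):
        # Real backends build per-batch attention metadata here.
        forward_batch.metadata_ready = True

    def forward(self, q, k, v, layer, forward_batch):
        # Real backends run the attention kernel; here we just echo q.
        return q

@dataclass
class ForwardBatchSketch:
    """Toy stand-in for ForwardBatch, carrying only the backend reference."""
    attn_backend: AttentionBackend
    metadata_ready: bool = False

class RadixAttention:
    def forward(self, q, k, v, forward_batch):
        # Attention layers delegate to the backend stored on the batch.
        return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)

# The model runner builds the batch and initializes backend metadata,
# then model.forward eventually reaches RadixAttention.forward.
batch = ForwardBatchSketch(attn_backend=AttentionBackend())
batch.attn_backend.init_forward_metadata(batch)
out = RadixAttention().forward("q", "k", "v", batch)
```

The design point is the double indirection: the layer never holds a backend itself, so the model runner can swap backends per batch by setting `forward_batch.attn_backend`.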

Classes

ForwardMode

Describes what kind of forward pass is being performed.

ForwardBatch

All tensors required by a single forward pass through the model.

Module Contents

class pymllm.engine.forward_batch.ForwardMode

Bases: enum.IntEnum

Describes what kind of forward pass is being performed.

Covers standard prefill / decode inference without speculative decoding.

EXTEND
DECODE
MIXED
IDLE
is_extend() → bool

True for EXTEND or MIXED (i.e. any prefill-style pass).

is_prefill() → bool

Alias for is_extend().

is_decode() → bool

True for DECODE.

is_mixed() → bool

True for MIXED.

is_idle() → bool

True for IDLE.

is_decode_or_idle() → bool

True for DECODE or IDLE.
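The predicates above can be sketched on a standalone copy of the enum. The member values below are arbitrary (the real class may assign different integers); only the predicate semantics follow the documentation.

```python
import enum

class ForwardMode(enum.IntEnum):
    """Standalone sketch mirroring the documented predicate semantics."""
    EXTEND = 0
    DECODE = 1
    MIXED = 2
    IDLE = 3

    def is_extend(self) -> bool:
        # Any prefill-style pass: pure extend, or a mixed extend/decode batch.
        return self in (ForwardMode.EXTEND, ForwardMode.MIXED)

    def is_prefill(self) -> bool:
        # Documented alias for is_extend().
        return self.is_extend()

    def is_decode(self) -> bool:
        return self == ForwardMode.DECODE

    def is_mixed(self) -> bool:
        return self == ForwardMode.MIXED

    def is_idle(self) -> bool:
        return self == ForwardMode.IDLE

    def is_decode_or_idle(self) -> bool:
        return self in (ForwardMode.DECODE, ForwardMode.IDLE)
```

Callers typically branch on `forward_batch.forward_mode.is_extend()` to decide whether the extend-only fields (extend_seq_lens and friends) are populated.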

class pymllm.engine.forward_batch.ForwardBatch

All tensors required by a single forward pass through the model.

Parameters:
  • forward_mode – The kind of pass being performed (EXTEND / DECODE / MIXED / IDLE).

  • batch_size – Number of sequences in the batch.

  • input_ids – Token ids for every position in the batch, shape [num_tokens]. For decode, num_tokens == batch_size; for extend, num_tokens == extend_num_tokens.

  • req_pool_indices – Index of each sequence in ReqToTokenPool, shape [batch_size] (int32 or int64, on the target device).

  • seq_lens – Total (prefix + new) length of each sequence, shape [batch_size] (int32).

  • out_cache_loc – KV-pool slot that each output token is written to, shape [num_tokens] (int64).

  • seq_lens_sum – Python int equal to seq_lens.sum(). Cached to avoid repeated device-to-host syncs.

  • seq_lens_cpu – CPU copy of seq_lens (optional; used by some attention backends for plan computation without a device sync).

  • positions – Token position for each input token, shape [num_tokens] (int32 or int64).

  • extend_num_tokens – Total number of new (non-prefix) tokens across the batch. Only set during EXTEND / MIXED passes.

  • extend_seq_lens – Number of new tokens for each sequence, shape [batch_size] (int32). Only set during EXTEND / MIXED.

  • extend_prefix_lens – Length of the already-cached prefix for each sequence, shape [batch_size] (int32). Only set during EXTEND / MIXED.

  • extend_start_loc – Cumulative start offset of each sequence in the flattened extend token stream, shape [batch_size] (int32).

  • extend_prefix_lens_cpu – CPU list mirror of extend_prefix_lens.

  • extend_seq_lens_cpu – CPU list mirror of extend_seq_lens.

  • return_logprob – Whether to compute per-token log-probabilities.

  • top_logprobs_nums – Number of top log-probs to return per sequence (None or list of ints).

  • req_to_token_pool – Reference to the ReqToTokenPool (set by the model runner).

  • token_to_kv_pool – Reference to the KVPool (set by the model runner).

  • attn_backend – The attention backend to use (set by the model runner before calling model.forward).

forward_mode: ForwardMode
batch_size: int
input_ids: torch.Tensor
req_pool_indices: torch.Tensor
seq_lens: torch.Tensor
out_cache_loc: torch.Tensor
seq_lens_sum: int
seq_lens_cpu: torch.Tensor | None = None
positions: torch.Tensor | None = None
extend_num_tokens: int | None = None
extend_seq_lens: torch.Tensor | None = None
extend_prefix_lens: torch.Tensor | None = None
extend_start_loc: torch.Tensor | None = None
extend_prefix_lens_cpu: List[int] | None = None
extend_seq_lens_cpu: List[int] | None = None
return_logprob: bool = False
top_logprobs_nums: List[int] | None = None
req_to_token_pool: pymllm.mem_cache.memory_pool.ReqToTokenPool | None = None
token_to_kv_pool: pymllm.mem_cache.memory_pool.KVPool | None = None
attn_backend: pymllm.layers.attention.attention_backend.AttentionBackend | None = None
mrope_position_deltas: torch.Tensor | None = None
pixel_values: torch.Tensor | None = None
image_grid_thw: torch.Tensor | None = None
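The relationships between the extend-pass fields can be illustrated with plain Python lists (the real fields are torch tensors on the target device, and the concrete lengths below are made up; only the relationships come from the parameter descriptions above).

```python
from itertools import accumulate

# Hypothetical batch of three sequences during an EXTEND pass.
extend_prefix_lens = [8, 0, 5]   # already-cached prefix length per sequence
extend_seq_lens    = [4, 10, 2]  # new (non-prefix) tokens per sequence

# seq_lens is the total (prefix + new) length of each sequence.
seq_lens = [p + n for p, n in zip(extend_prefix_lens, extend_seq_lens)]

# extend_num_tokens is the total number of new tokens in the batch;
# input_ids, positions, and out_cache_loc all have this many entries.
extend_num_tokens = sum(extend_seq_lens)

# extend_start_loc is the exclusive cumulative sum of extend_seq_lens:
# where each sequence's new tokens start in the flattened token stream.
extend_start_loc = [0] + list(accumulate(extend_seq_lens))[:-1]

print(seq_lens)           # [12, 10, 7]
print(extend_num_tokens)  # 16
print(extend_start_loc)   # [0, 4, 14]
```

For a DECODE pass the extend fields stay None and the flattened token stream collapses to one token per sequence, so num_tokens == batch_size.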