pymllm.engine.forward_batch¶
ForwardMode and ForwardBatch for pymllm.
Simplified forward-batch abstraction: no speculative decoding, no encoder-decoder support, and no distributed-attention complexity (DP/TP head splitting is handled at the layer level by the model code, not here).
Typical data flow¶
- ModelRunner builds a ForwardBatch
↓
- attn_backend.init_forward_metadata(forward_batch)
↓
- model.forward(input_ids, positions, forward_batch)
↓
- RadixAttention.forward(q, k, v, forward_batch)
↓
- forward_batch.attn_backend.forward(q, k, v, layer, forward_batch)
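The call chain above can be sketched with stub classes. Everything here is illustrative scaffolding (the real `ModelRunner`, model, and backend do far more); only the order of calls from the diagram is modeled:

```python
from dataclasses import dataclass
from typing import List

# Record the order in which the stages fire.
calls: List[str] = []

class AttnBackend:
    """Stand-in for an attention backend (illustrative)."""
    def init_forward_metadata(self, forward_batch):
        calls.append("init_forward_metadata")
    def forward(self, q, k, v, layer, forward_batch):
        calls.append("attn_backend.forward")
        return q  # placeholder output

@dataclass
class ForwardBatch:
    attn_backend: AttnBackend

class RadixAttention:
    def forward(self, q, k, v, forward_batch):
        calls.append("RadixAttention.forward")
        # The layer delegates to the backend carried on the batch.
        return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)

class Model:
    def __init__(self):
        self.attn = RadixAttention()
    def forward(self, input_ids, positions, forward_batch):
        calls.append("model.forward")
        return self.attn.forward(input_ids, input_ids, input_ids, forward_batch)

# ModelRunner-style driver: build batch, init metadata, run the model.
backend = AttnBackend()
batch = ForwardBatch(attn_backend=backend)
backend.init_forward_metadata(batch)
Model().forward([1, 2, 3], [0, 1, 2], batch)
print(calls)
```

The key design point the sketch preserves is that the batch carries the backend, so `RadixAttention` never needs a direct reference to it.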
Classes¶
| ForwardMode | Describes what kind of forward pass is being performed. |
| ForwardBatch | All tensors required by a single forward pass through the model. |
Module Contents¶
- class pymllm.engine.forward_batch.ForwardMode¶
Bases: enum.IntEnum
Describes what kind of forward pass is being performed.
Covers standard prefill / decode inference without speculative decoding.
- EXTEND¶
- DECODE¶
- MIXED¶
- IDLE¶
- is_extend()¶
True for EXTEND or MIXED (i.e. any prefill-style pass).
- Return type:
bool
- is_prefill()¶
Alias for is_extend().
- Return type:
bool
- is_decode()¶
- Return type:
bool
- is_mixed()¶
- Return type:
bool
- is_idle()¶
- Return type:
bool
- is_decode_or_idle()¶
- Return type:
bool
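The predicates above can be summarized in a self-contained sketch. The member values chosen here are arbitrary (the real enum's values may differ); only the documented semantics of the predicates are reproduced:

```python
import enum

class ForwardMode(enum.IntEnum):
    """Sketch of ForwardMode; member values are illustrative."""
    EXTEND = 1
    DECODE = 2
    MIXED = 3
    IDLE = 4

    def is_extend(self) -> bool:
        # True for any prefill-style pass.
        return self in (ForwardMode.EXTEND, ForwardMode.MIXED)

    def is_prefill(self) -> bool:
        # Alias for is_extend().
        return self.is_extend()

    def is_decode(self) -> bool:
        return self == ForwardMode.DECODE

    def is_mixed(self) -> bool:
        return self == ForwardMode.MIXED

    def is_idle(self) -> bool:
        return self == ForwardMode.IDLE

    def is_decode_or_idle(self) -> bool:
        return self.is_decode() or self.is_idle()

print(ForwardMode.MIXED.is_extend())  # MIXED counts as prefill-style: True
```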
- class pymllm.engine.forward_batch.ForwardBatch¶
All tensors required by a single forward pass through the model.
- Parameters:
forward_mode – The kind of pass being performed (EXTEND / DECODE / MIXED / IDLE).
batch_size – Number of sequences in the batch.
input_ids – Token ids for every position in the batch, shape [num_tokens]. For decode, num_tokens == batch_size; for extend, num_tokens == extend_num_tokens.
req_pool_indices – Index of each sequence in ReqToTokenPool, shape [batch_size] (int32 or int64, on the target device).
seq_lens – Total (prefix + new) length of each sequence, shape [batch_size] (int32).
out_cache_loc – KV-pool slot that each output token is written to, shape [num_tokens] (int64).
seq_lens_sum – Python int equal to seq_lens.sum(). Cached to avoid repeated device-to-host syncs.
seq_lens_cpu – CPU copy of seq_lens (optional; used by some attention backends for plan computation without a device sync).
positions – Token position for each input token, shape [num_tokens] (int32 or int64).
extend_num_tokens – Total number of new (non-prefix) tokens across the batch. Only set during EXTEND / MIXED passes.
extend_seq_lens – Number of new tokens for each sequence, shape [batch_size] (int32). Only set during EXTEND / MIXED.
extend_prefix_lens – Length of the already-cached prefix for each sequence, shape [batch_size] (int32). Only set during EXTEND / MIXED.
extend_start_loc – Cumulative start offset of each sequence in the flattened extend token stream, shape [batch_size] (int32).
extend_prefix_lens_cpu – CPU list mirror of extend_prefix_lens.
extend_seq_lens_cpu – CPU list mirror of extend_seq_lens.
return_logprob – Whether to compute per-token log-probabilities.
top_logprobs_nums – Number of top log-probs to return per sequence (None or list of ints).
req_to_token_pool – Reference to the ReqToTokenPool (set by the model runner).
token_to_kv_pool – Reference to the KVPool (set by the model runner).
attn_backend – The attention backend to use (set by the model runner before calling model.forward).
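The shape bookkeeping among these fields follows directly from the definitions. A sketch for an EXTEND pass, using plain Python lists in place of device tensors (all concrete values are illustrative):

```python
# Per-sequence bookkeeping for a hypothetical 3-sequence EXTEND batch.
extend_prefix_lens = [4, 0, 7]   # already-cached prefix tokens per sequence
extend_seq_lens = [3, 5, 1]      # new tokens per sequence
batch_size = len(extend_seq_lens)

# seq_lens is always prefix + new.
seq_lens = [p + n for p, n in zip(extend_prefix_lens, extend_seq_lens)]

# extend_start_loc: cumulative start of each sequence in the
# flattened stream of new tokens (an exclusive prefix sum).
extend_start_loc = [0] * batch_size
for i in range(1, batch_size):
    extend_start_loc[i] = extend_start_loc[i - 1] + extend_seq_lens[i - 1]

extend_num_tokens = sum(extend_seq_lens)  # total new tokens in the batch
seq_lens_sum = sum(seq_lens)              # cached host-side int

print(seq_lens, extend_start_loc, extend_num_tokens, seq_lens_sum)
```

The same relationships hold for the device tensors; caching `seq_lens_sum` and the `*_cpu` mirrors on the host is what lets backends plan without a device-to-host sync.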
- forward_mode: ForwardMode¶
- batch_size: int¶
- input_ids: torch.Tensor¶
- req_pool_indices: torch.Tensor¶
- seq_lens: torch.Tensor¶
- out_cache_loc: torch.Tensor¶
- seq_lens_sum: int¶
- seq_lens_cpu: torch.Tensor | None = None¶
- positions: torch.Tensor | None = None¶
- extend_num_tokens: int | None = None¶
- extend_seq_lens: torch.Tensor | None = None¶
- extend_prefix_lens: torch.Tensor | None = None¶
- extend_start_loc: torch.Tensor | None = None¶
- extend_prefix_lens_cpu: List[int] | None = None¶
- extend_seq_lens_cpu: List[int] | None = None¶
- return_logprob: bool = False¶
- top_logprobs_nums: List[int] | None = None¶
- req_to_token_pool: pymllm.mem_cache.memory_pool.ReqToTokenPool | None = None¶
- token_to_kv_pool: pymllm.mem_cache.memory_pool.KVPool | None = None¶
- attn_backend: pymllm.layers.attention.attention_backend.AttentionBackend | None = None¶
- mrope_position_deltas: torch.Tensor | None = None¶
- pixel_values: torch.Tensor | None = None¶
- image_grid_thw: torch.Tensor | None = None¶
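For a DECODE pass the invariants collapse: each sequence contributes exactly one new token, so `num_tokens == batch_size` and each input token's position is its sequence length minus one. A sketch with plain lists standing in for tensors (values illustrative):

```python
# DECODE-pass invariants for a hypothetical 3-sequence batch.
seq_lens = [8, 6, 9]          # total length including the token being decoded
batch_size = len(seq_lens)
num_tokens = batch_size       # decode processes one token per sequence
positions = [n - 1 for n in seq_lens]  # 0-based position of each decoded token

print(num_tokens, positions)
```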