pymllm.layers.sampling¶
Sampling operations with FlashInfer acceleration and PyTorch fallback.
This module wraps all flashinfer.sampling APIs and provides pure-PyTorch fallback implementations so that the rest of the codebase can import from here without worrying about whether FlashInfer is installed.
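The import-with-fallback pattern described above can be sketched roughly as follows. This is an illustrative reconstruction, not the actual pymllm source; the `HAS_FLASHINFER` flag and the dispatch wrapper are assumptions about how the module is organized.

```python
import torch

# Try the accelerated path first; fall back to pure PyTorch if absent.
try:
    import flashinfer.sampling as fi_sampling
    HAS_FLASHINFER = True
except ImportError:
    fi_sampling = None
    HAS_FLASHINFER = False


def sampling_from_probs(probs: torch.Tensor, generator=None) -> torch.Tensor:
    """Dispatch to FlashInfer when installed, else use torch.multinomial."""
    if HAS_FLASHINFER:
        return fi_sampling.sampling_from_probs(probs, generator=generator)
    # Fallback: one categorical draw per row of probs.
    return torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
```

Callers elsewhere in the codebase import from this module and never need to check `HAS_FLASHINFER` themselves.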
Attributes¶
Functions¶
| softmax | Safe softmax with optional temperature scaling. |
| sampling_from_probs | Category sampling from probabilities. |
| sampling_from_logits | Category sampling from logits (applies softmax internally). |
| top_p_sampling_from_probs | Top-p (nucleus) sampling from probabilities. |
| top_k_sampling_from_probs | Top-k sampling from probabilities. |
| min_p_sampling_from_probs | Min-p sampling from probabilities. |
| top_k_top_p_sampling_from_logits | Top-k + top-p sampling from pre-softmax logits. |
| top_k_top_p_sampling_from_probs | Top-k + top-p sampling from probabilities. |
| top_p_renorm_probs | Renormalize probabilities by top-p thresholding. |
| top_k_renorm_probs | Renormalize probabilities by top-k thresholding. |
| top_k_mask_logits | Mask logits by top-k thresholding (set non-top-k to -inf). |
| chain_speculative_sampling | Speculative sampling for sequence generation. |
Module Contents¶
- pymllm.layers.sampling.logger¶
- pymllm.layers.sampling.softmax(logits, temperature=None, enable_pdl=None)¶
Safe softmax with optional temperature scaling.
- Parameters:
logits (torch.Tensor) – Shape (batch_size, num_classes).
temperature (Optional[Union[torch.Tensor, float]]) – Scalar or per-request (batch_size,) temperature.
enable_pdl (Optional[bool]) – FlashInfer PDL flag (ignored in fallback).
- Returns:
Probabilities with the same shape as logits.
- Return type:
torch.Tensor
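A minimal pure-PyTorch sketch of what the fallback path for `softmax` might look like, assuming the documented semantics above (scalar or per-request temperature, max-subtraction for the "safe" part). This is illustrative, not the module's actual implementation.

```python
import torch

def safe_softmax(logits: torch.Tensor, temperature=None) -> torch.Tensor:
    """Fallback sketch: divide by temperature (scalar or (batch_size,)
    tensor), subtract the row max for numerical stability, then softmax."""
    if temperature is not None:
        if isinstance(temperature, torch.Tensor):
            temperature = temperature.unsqueeze(-1)  # broadcast over classes
        logits = logits / temperature
    logits = logits - logits.max(dim=-1, keepdim=True).values
    return torch.softmax(logits, dim=-1)
```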
- pymllm.layers.sampling.sampling_from_probs(probs, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)¶
Category sampling from probabilities.
- Parameters:
probs (torch.Tensor) – (batch_size, num_classes), or (unique_batch_size, num_classes) when indices is provided.
indices (Optional[torch.Tensor]) – Maps each output to a row in probs.
deterministic (bool) – See FlashInfer docs.
generator (Optional[torch.Generator]) – See FlashInfer docs.
check_nan (bool) – See FlashInfer docs.
seed (Optional[int]) – See FlashInfer docs.
offset (Optional[int]) – See FlashInfer docs.
- Returns:
Sampled token ids, shape
(batch_size,).- Return type:
torch.Tensor
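The `indices` argument lets several output positions share one row of a deduplicated probs tensor. A hedged sketch of how the PyTorch fallback could emulate it (the helper name is hypothetical):

```python
import torch

def sample_with_indices(probs, indices, generator=None):
    """Each output position i samples from row indices[i] of probs,
    so unique distributions need only be stored once."""
    expanded = probs[indices]  # (batch_size, num_classes)
    return torch.multinomial(expanded, 1, generator=generator).squeeze(-1)
```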
- pymllm.layers.sampling.sampling_from_logits(logits, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)¶
Category sampling from logits (applies softmax internally).
- Parameters:
logits (torch.Tensor) – (batch_size, num_classes).
indices (Optional[torch.Tensor]) – See FlashInfer docs.
deterministic (bool) – See FlashInfer docs.
generator (Optional[torch.Generator]) – See FlashInfer docs.
check_nan (bool) – See FlashInfer docs.
seed (Optional[int]) – See FlashInfer docs.
offset (Optional[int]) – See FlashInfer docs.
- Returns:
Sampled token ids, shape
(batch_size,).- Return type:
torch.Tensor
- pymllm.layers.sampling.top_p_sampling_from_probs(probs, top_p, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)¶
Top-p (nucleus) sampling from probabilities.
- Parameters:
probs (torch.Tensor) – (batch_size, num_classes).
top_p (Union[torch.Tensor, float]) – Top-p threshold.
indices (Optional[torch.Tensor]) – See FlashInfer docs.
deterministic (bool) – See FlashInfer docs.
generator (Optional[torch.Generator]) – See FlashInfer docs.
check_nan (bool) – See FlashInfer docs.
seed (Optional[int]) – See FlashInfer docs.
offset (Optional[int]) – See FlashInfer docs.
- Returns:
Sampled token ids, shape
(batch_size,).- Return type:
torch.Tensor
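Nucleus sampling keeps the smallest prefix of probability-sorted tokens whose cumulative mass reaches top_p. A pure-PyTorch sketch of this fallback behavior (illustrative; the real kernel is fused):

```python
import torch

def top_p_sample(probs, top_p, generator=None):
    """Keep tokens while cumulative mass *before* them is <= top_p
    (the top token is always kept), renormalize, then sample."""
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    mask = (cumsum - sorted_probs) > top_p  # mass already accumulated exceeds top_p
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, 1, generator=generator)
    return sorted_idx.gather(-1, choice).squeeze(-1)
```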
- pymllm.layers.sampling.top_k_sampling_from_probs(probs, top_k, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)¶
Top-k sampling from probabilities.
- Parameters:
probs (torch.Tensor) – (batch_size, num_classes).
top_k (Union[torch.Tensor, int]) – Top-k threshold.
indices (Optional[torch.Tensor]) – See FlashInfer docs.
deterministic (bool) – See FlashInfer docs.
generator (Optional[torch.Generator]) – See FlashInfer docs.
check_nan (bool) – See FlashInfer docs.
seed (Optional[int]) – See FlashInfer docs.
offset (Optional[int]) – See FlashInfer docs.
- Returns:
Sampled token ids, shape
(batch_size,).- Return type:
torch.Tensor
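A pure-PyTorch sketch of the top-k fallback, assuming a scalar top_k for brevity (the documented signature also accepts a per-request tensor):

```python
import torch

def top_k_sample(probs, top_k, generator=None):
    """Zero out probabilities below the k-th largest per row,
    renormalize the survivors, then sample."""
    topk_vals = torch.topk(probs, top_k, dim=-1).values
    threshold = topk_vals[..., -1, None]  # k-th largest per row
    filtered = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    return torch.multinomial(filtered, 1, generator=generator).squeeze(-1)
```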
- pymllm.layers.sampling.min_p_sampling_from_probs(probs, min_p, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)¶
Min-p sampling from probabilities.
- Parameters:
probs (torch.Tensor) – (batch_size, num_classes).
min_p (Union[torch.Tensor, float]) – Min-p threshold.
indices (Optional[torch.Tensor]) – See FlashInfer docs.
deterministic (bool) – See FlashInfer docs.
generator (Optional[torch.Generator]) – See FlashInfer docs.
check_nan (bool) – See FlashInfer docs.
seed (Optional[int]) – See FlashInfer docs.
offset (Optional[int]) – See FlashInfer docs.
- Returns:
Sampled token ids, shape
(batch_size,).- Return type:
torch.Tensor
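Min-p filtering drops tokens whose probability falls below a fraction of the per-row maximum. A hedged PyTorch sketch of the fallback:

```python
import torch

def min_p_sample(probs, min_p, generator=None):
    """Drop tokens with prob < min_p * max(probs) per row,
    renormalize, then sample."""
    threshold = min_p * probs.max(dim=-1, keepdim=True).values
    filtered = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    return torch.multinomial(filtered, 1, generator=generator).squeeze(-1)
```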
- pymllm.layers.sampling.top_k_top_p_sampling_from_logits(logits, top_k, top_p, indices=None, filter_apply_order='top_k_first', deterministic=True, generator=None, check_nan=False, seed=None, offset=None)¶
Top-k + top-p sampling from pre-softmax logits.
- Parameters:
logits (torch.Tensor) – (batch_size, num_classes).
top_k (Union[torch.Tensor, int])
top_p (Union[torch.Tensor, float])
filter_apply_order (str) – "top_k_first" or "joint".
indices (Optional[torch.Tensor]) – See FlashInfer docs.
deterministic (bool) – See FlashInfer docs.
generator (Optional[torch.Generator]) – See FlashInfer docs.
check_nan (bool) – See FlashInfer docs.
seed (Optional[int]) – See FlashInfer docs.
offset (Optional[int]) – See FlashInfer docs.
- Returns:
Sampled token ids, shape
(batch_size,).- Return type:
torch.Tensor
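For the "top_k_first" order, the top-k mask is applied to the logits before softmax, and the nucleus filter to the resulting probabilities. A sketch of this sequence in plain PyTorch (illustrative; scalar thresholds assumed):

```python
import torch

def top_k_top_p_from_logits(logits, top_k, top_p, generator=None):
    """Sketch of filter_apply_order="top_k_first": mask logits to the
    top-k set, softmax, then top-p filter, then sample."""
    kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    masked = logits.masked_fill(logits < kth, float("-inf"))
    probs = torch.softmax(masked, dim=-1)
    sorted_p, sorted_i = torch.sort(probs, dim=-1, descending=True)
    drop = (torch.cumsum(sorted_p, dim=-1) - sorted_p) > top_p
    sorted_p = sorted_p.masked_fill(drop, 0.0)
    sorted_p = sorted_p / sorted_p.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_p, 1, generator=generator)
    return sorted_i.gather(-1, choice).squeeze(-1)
```

Under the "joint" order, by contrast, both constraints are applied to the same distribution rather than sequentially.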
- pymllm.layers.sampling.top_k_top_p_sampling_from_probs(probs, top_k, top_p, indices=None, filter_apply_order='top_k_first', deterministic=True, generator=None, check_nan=False, seed=None, offset=None)¶
Top-k + top-p sampling from probabilities.
- Parameters:
probs (torch.Tensor) – (batch_size, num_classes).
top_k (Union[torch.Tensor, int])
top_p (Union[torch.Tensor, float])
filter_apply_order (str) – "top_k_first" or "joint".
indices (Optional[torch.Tensor]) – See FlashInfer docs.
deterministic (bool) – See FlashInfer docs.
generator (Optional[torch.Generator]) – See FlashInfer docs.
check_nan (bool) – See FlashInfer docs.
seed (Optional[int]) – See FlashInfer docs.
offset (Optional[int]) – See FlashInfer docs.
- Returns:
Sampled token ids, shape
(batch_size,).- Return type:
torch.Tensor
- pymllm.layers.sampling.top_p_renorm_probs(probs, top_p)¶
Renormalize probabilities by top-p thresholding.
- Parameters:
probs (torch.Tensor) – (batch_size, num_classes).
top_p (Union[torch.Tensor, float]) – Top-p threshold in (0, 1).
- Returns:
Renormalized probabilities.
- Return type:
torch.Tensor
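A pure-PyTorch sketch of top-p renormalization, following the same nucleus rule as the sampling variant but returning a full distribution instead of token ids:

```python
import torch

def top_p_renorm(probs, top_p):
    """Zero out tokens outside the top-p nucleus and rescale
    the survivors to sum to 1."""
    sorted_p, sorted_i = torch.sort(probs, dim=-1, descending=True)
    drop = (torch.cumsum(sorted_p, dim=-1) - sorted_p) > top_p
    sorted_p = sorted_p.masked_fill(drop, 0.0)
    # Scatter the filtered values back to their original positions.
    out = torch.zeros_like(probs).scatter(-1, sorted_i, sorted_p)
    return out / out.sum(dim=-1, keepdim=True)
```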
- pymllm.layers.sampling.top_k_renorm_probs(probs, top_k)¶
Renormalize probabilities by top-k thresholding.
- Parameters:
probs (torch.Tensor) – (batch_size, num_classes).
top_k (Union[torch.Tensor, int]) – Top-k threshold.
- Returns:
Renormalized probabilities.
- Return type:
torch.Tensor
- pymllm.layers.sampling.top_k_mask_logits(logits, top_k)¶
Mask logits by top-k thresholding (set non-top-k to -inf).
- Parameters:
logits (torch.Tensor) – (batch_size, num_classes).
top_k (Union[torch.Tensor, int]) – Top-k threshold.
- Returns:
Masked logits with the same shape and dtype.
- Return type:
torch.Tensor
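The masking step is a one-liner in the fallback: anything below the k-th largest logit becomes -inf, so a subsequent softmax assigns it zero probability. A sketch assuming a scalar top_k:

```python
import torch

def top_k_mask(logits, top_k):
    """Set every logit below the k-th largest (per row) to -inf,
    preserving shape and dtype."""
    kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth, float("-inf"))
```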
- pymllm.layers.sampling.chain_speculative_sampling(draft_probs, draft_token_ids, target_probs, maybe_output_accepted_token_num=None, maybe_output_emitted_draft_token_num=None, deterministic=True, generator=None, seed=None, offset=None)¶
Speculative sampling for sequence generation.
- Parameters:
draft_probs (torch.Tensor) – (batch_size, num_speculate_tokens, vocab_size).
draft_token_ids (torch.Tensor) – (batch_size, num_speculate_tokens).
target_probs (torch.Tensor) – (batch_size, num_speculate_tokens + 1, vocab_size).
maybe_output_accepted_token_num (Optional[torch.Tensor]) – If provided, accepted counts are added in-place.
maybe_output_emitted_draft_token_num (Optional[torch.Tensor]) – If provided, emitted counts are added in-place.
deterministic (bool) – See FlashInfer docs.
generator (Optional[torch.Generator]) – See FlashInfer docs.
seed (Optional[int]) – See FlashInfer docs.
offset (Optional[int]) – See FlashInfer docs.
- Returns:
output_token_ids (torch.Tensor) – (batch_size, num_speculate_tokens + 1), rejected slots padded with -1.
output_accepted_token_num (torch.Tensor) – (batch_size,).
output_emitted_draft_token_num (torch.Tensor) – (batch_size,).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
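To make the accept/reject semantics concrete, here is a hedged single-sequence sketch of chained speculative sampling: draft token t is accepted with probability min(1, p_target/p_draft); on the first rejection a replacement is drawn from the renormalized residual max(0, p_target - p_draft) and remaining slots are padded with -1. Batch size 1 is assumed for clarity; the real fused kernel handles full batches.

```python
import torch

def chain_speculative_sample(draft_probs, draft_ids, target_probs, generator=None):
    """Illustrative rejection-sampling loop for one sequence."""
    num_spec = draft_ids.shape[-1]
    out = torch.full((num_spec + 1,), -1, dtype=torch.long)
    accepted = 0
    for t in range(num_spec):
        tok = draft_ids[0, t].item()
        p_draft = draft_probs[0, t, tok]
        p_target = target_probs[0, t, tok]
        u = torch.rand((), generator=generator)
        if u < torch.clamp(p_target / p_draft, max=1.0):
            out[t] = tok  # draft token accepted
            accepted += 1
        else:
            # Rejected: resample from the renormalized residual distribution.
            residual = torch.clamp(target_probs[0, t] - draft_probs[0, t], min=0.0)
            residual = residual / residual.sum()
            out[t] = torch.multinomial(residual, 1, generator=generator)
            break
    else:
        # All drafts accepted: emit one bonus token from the target model.
        out[num_spec] = torch.multinomial(target_probs[0, num_spec], 1,
                                          generator=generator)
    return out, accepted
```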
- pymllm.layers.sampling.top_p_renorm_prob¶
Backward-compatible alias of top_p_renorm_probs.
- pymllm.layers.sampling.top_k_renorm_prob¶
Backward-compatible alias of top_k_renorm_probs.