pymllm.layers.sampling

Sampling operations with FlashInfer acceleration and PyTorch fallback.

This module wraps all flashinfer.sampling APIs and provides pure-PyTorch fallback implementations so that the rest of the codebase can import from here without worrying about whether FlashInfer is installed.

Attributes

logger

top_p_renorm_prob

top_k_renorm_prob

Functions

softmax(logits[, temperature, enable_pdl])

Safe softmax with optional temperature scaling.

sampling_from_probs(probs[, indices, deterministic, ...])

Categorical sampling from probabilities.

sampling_from_logits(logits[, indices, deterministic, ...])

Categorical sampling from logits (applies softmax internally).

top_p_sampling_from_probs(probs, top_p[, indices, ...])

Top-p (nucleus) sampling from probabilities.

top_k_sampling_from_probs(probs, top_k[, indices, ...])

Top-k sampling from probabilities.

min_p_sampling_from_probs(probs, min_p[, indices, ...])

Min-p sampling from probabilities.

top_k_top_p_sampling_from_logits(logits, top_k, top_p)

Top-k + top-p sampling from pre-softmax logits.

top_k_top_p_sampling_from_probs(probs, top_k, top_p[, ...])

Top-k + top-p sampling from probabilities.

top_p_renorm_probs(probs, top_p)

Renormalize probabilities by top-p thresholding.

top_k_renorm_probs(probs, top_k)

Renormalize probabilities by top-k thresholding.

top_k_mask_logits(logits, top_k)

Mask logits by top-k thresholding (set non-top-k to -inf).

chain_speculative_sampling(draft_probs, ...[, ...])

Speculative sampling for sequence generation.

Module Contents

pymllm.layers.sampling.logger

Module-level logger.
pymllm.layers.sampling.softmax(logits, temperature=None, enable_pdl=None)

Safe softmax with optional temperature scaling.

Parameters:
  • logits (torch.Tensor) – Shape (batch_size, num_classes).

  • temperature (Optional[Union[torch.Tensor, float]]) – Scalar or per-request (batch_size,) temperature.

  • enable_pdl (Optional[bool]) – FlashInfer PDL flag (ignored in fallback).

Returns:

Probabilities with the same shape as logits.

Return type:

torch.Tensor
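
A minimal pure-PyTorch sketch of this fallback path (the helper name softmax_fallback is illustrative, not the module's actual code):

```python
import torch

def softmax_fallback(logits, temperature=None):
    # Illustrative fallback: scale logits by temperature, then apply a
    # numerically safe softmax (torch.softmax subtracts the row max).
    if temperature is not None:
        t = torch.as_tensor(temperature, dtype=logits.dtype, device=logits.device)
        if t.ndim == 1:          # per-request temperatures of shape (batch_size,)
            t = t.unsqueeze(-1)  # broadcast over the class dimension
        logits = logits / t
    return torch.softmax(logits, dim=-1)

# Lower temperature sharpens the distribution toward the argmax.
probs = softmax_fallback(torch.tensor([[1.0, 2.0, 3.0]]), temperature=0.5)
```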

pymllm.layers.sampling.sampling_from_probs(probs, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

Categorical sampling from probabilities.

Parameters:
  • probs (torch.Tensor) – (batch_size, num_classes) or (unique_batch_size, num_classes) when indices is provided.

  • indices (Optional[torch.Tensor]) – Maps each output to a row in probs.

  • deterministic (bool) – See FlashInfer docs.

  • generator (Optional[torch.Generator]) – See FlashInfer docs.

  • check_nan (bool) – See FlashInfer docs.

  • seed (Optional[int]) – See FlashInfer docs.

  • offset (Optional[int]) – See FlashInfer docs.

Returns:

Sampled token ids, shape (batch_size,).

Return type:

torch.Tensor
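
The fallback behavior, including the `indices` expansion of shared probability rows, can be sketched as follows (function name illustrative):

```python
import torch

def sampling_from_probs_fallback(probs, indices=None, generator=None):
    # Illustrative fallback: expand shared probability rows via `indices`,
    # then draw one token id per row with torch.multinomial.
    if indices is not None:
        probs = probs[indices]          # (batch_size, num_classes)
    samples = torch.multinomial(probs, num_samples=1, generator=generator)
    return samples.squeeze(-1)          # (batch_size,)
```

The `indices` path lets several requests that share one distribution (e.g. the same prefix) store a single row in `probs`.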

pymllm.layers.sampling.sampling_from_logits(logits, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

Categorical sampling from logits (applies softmax internally).

Parameters:
  • logits (torch.Tensor) – (batch_size, num_classes).

  • indices (Optional[torch.Tensor]) – See FlashInfer docs.

  • deterministic (bool) – See FlashInfer docs.

  • generator (Optional[torch.Generator]) – See FlashInfer docs.

  • check_nan (bool) – See FlashInfer docs.

  • seed (Optional[int]) – See FlashInfer docs.

  • offset (Optional[int]) – See FlashInfer docs.

Returns:

Sampled token ids, shape (batch_size,).

Return type:

torch.Tensor

pymllm.layers.sampling.top_p_sampling_from_probs(probs, top_p, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

Top-p (nucleus) sampling from probabilities.

Parameters:
  • probs (torch.Tensor) – (batch_size, num_classes).

  • top_p (Union[torch.Tensor, float]) – Top-p threshold.

  • indices (Optional[torch.Tensor]) – See FlashInfer docs.

  • deterministic (bool) – See FlashInfer docs.

  • generator (Optional[torch.Generator]) – See FlashInfer docs.

  • check_nan (bool) – See FlashInfer docs.

  • seed (Optional[int]) – See FlashInfer docs.

  • offset (Optional[int]) – See FlashInfer docs.

Returns:

Sampled token ids, shape (batch_size,).

Return type:

torch.Tensor
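
The nucleus filter can be sketched in pure PyTorch as below; the exact tie-breaking and threshold convention (keep the smallest prefix whose cumulative mass reaches top_p) are illustrative assumptions, not necessarily the kernel's:

```python
import torch

def top_p_sampling_fallback(probs, top_p, generator=None):
    # Illustrative nucleus filter: sort descending, keep the smallest prefix
    # whose cumulative mass reaches top_p, renormalize, then sample.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    # Drop a token once the mass strictly before it already reaches top_p;
    # the first token is always kept.
    sorted_probs = sorted_probs.masked_fill((cumsum - sorted_probs) >= top_p, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, 1, generator=generator)
    return sorted_idx.gather(-1, choice).squeeze(-1)
```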

pymllm.layers.sampling.top_k_sampling_from_probs(probs, top_k, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

Top-k sampling from probabilities.

Parameters:
  • probs (torch.Tensor) – (batch_size, num_classes).

  • top_k (Union[torch.Tensor, int]) – Top-k threshold.

  • indices (Optional[torch.Tensor]) – See FlashInfer docs.

  • deterministic (bool) – See FlashInfer docs.

  • generator (Optional[torch.Generator]) – See FlashInfer docs.

  • check_nan (bool) – See FlashInfer docs.

  • seed (Optional[int]) – See FlashInfer docs.

  • offset (Optional[int]) – See FlashInfer docs.

Returns:

Sampled token ids, shape (batch_size,).

Return type:

torch.Tensor
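
A sketch of the fallback for a scalar `top_k` (illustrative; ties at the k-th value may keep slightly more than k tokens here):

```python
import torch

def top_k_sampling_fallback(probs, top_k, generator=None):
    # Illustrative fallback: zero out everything below the k-th largest
    # probability, renormalize, then sample.
    kth = probs.topk(top_k, dim=-1).values[..., -1:]   # (batch_size, 1)
    filtered = torch.where(probs >= kth, probs, torch.zeros_like(probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    return torch.multinomial(filtered, 1, generator=generator).squeeze(-1)
```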

pymllm.layers.sampling.min_p_sampling_from_probs(probs, min_p, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

Min-p sampling from probabilities.

Parameters:
  • probs (torch.Tensor) – (batch_size, num_classes).

  • min_p (Union[torch.Tensor, float]) – Min-p threshold.

  • indices (Optional[torch.Tensor]) – See FlashInfer docs.

  • deterministic (bool) – See FlashInfer docs.

  • generator (Optional[torch.Generator]) – See FlashInfer docs.

  • check_nan (bool) – See FlashInfer docs.

  • seed (Optional[int]) – See FlashInfer docs.

  • offset (Optional[int]) – See FlashInfer docs.

Returns:

Sampled token ids, shape (batch_size,).

Return type:

torch.Tensor
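
Min-p keeps tokens whose probability is at least `min_p` times the row maximum; a sketch of that fallback (name illustrative):

```python
import torch

def min_p_sampling_fallback(probs, min_p, generator=None):
    # Illustrative fallback: keep tokens with probability >= min_p * max(probs),
    # renormalize, then sample.
    max_p = probs.max(dim=-1, keepdim=True).values
    filtered = torch.where(probs >= min_p * max_p, probs, torch.zeros_like(probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    return torch.multinomial(filtered, 1, generator=generator).squeeze(-1)
```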

pymllm.layers.sampling.top_k_top_p_sampling_from_logits(logits, top_k, top_p, indices=None, filter_apply_order='top_k_first', deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

Top-k + top-p sampling from pre-softmax logits.

Parameters:
  • logits (torch.Tensor) – (batch_size, num_classes).

  • top_k (Union[torch.Tensor, int])

  • top_p (Union[torch.Tensor, float])

  • filter_apply_order (str) – "top_k_first" or "joint".

  • indices (Optional[torch.Tensor]) – See FlashInfer docs.

  • deterministic (bool) – See FlashInfer docs.

  • generator (Optional[torch.Generator]) – See FlashInfer docs.

  • check_nan (bool) – See FlashInfer docs.

  • seed (Optional[int]) – See FlashInfer docs.

  • offset (Optional[int]) – See FlashInfer docs.

Returns:

Sampled token ids, shape (batch_size,).

Return type:

torch.Tensor
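
The "top_k_first" order can be sketched as: mask logits to the top-k set, softmax, apply the top-p filter, then sample (illustrative reference, not the fused kernel):

```python
import torch

def top_k_top_p_from_logits_fallback(logits, top_k, top_p, generator=None):
    # Illustrative "top_k_first" composition for scalar top_k / top_p.
    kth = logits.topk(top_k, dim=-1).values[..., -1:]
    masked = logits.masked_fill(logits < kth, float("-inf"))
    probs = torch.softmax(masked, dim=-1)
    # Top-p filter on the surviving probabilities.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    sorted_probs = sorted_probs.masked_fill((cumsum - sorted_probs) >= top_p, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, 1, generator=generator)
    return sorted_idx.gather(-1, choice).squeeze(-1)
```

The "joint" order instead applies both filters to the same distribution before renormalizing.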

pymllm.layers.sampling.top_k_top_p_sampling_from_probs(probs, top_k, top_p, indices=None, filter_apply_order='top_k_first', deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

Top-k + top-p sampling from probabilities.

Parameters:
  • probs (torch.Tensor) – (batch_size, num_classes).

  • top_k (Union[torch.Tensor, int])

  • top_p (Union[torch.Tensor, float])

  • filter_apply_order (str) – "top_k_first" or "joint".

  • indices (Optional[torch.Tensor]) – See FlashInfer docs.

  • deterministic (bool) – See FlashInfer docs.

  • generator (Optional[torch.Generator]) – See FlashInfer docs.

  • check_nan (bool) – See FlashInfer docs.

  • seed (Optional[int]) – See FlashInfer docs.

  • offset (Optional[int]) – See FlashInfer docs.

Returns:

Sampled token ids, shape (batch_size,).

Return type:

torch.Tensor

pymllm.layers.sampling.top_p_renorm_probs(probs, top_p)

Renormalize probabilities by top-p thresholding.

Parameters:
  • probs (torch.Tensor) – (batch_size, num_classes).

  • top_p (Union[torch.Tensor, float]) – Top-p threshold in (0, 1).

Returns:

Renormalized probabilities.

Return type:

torch.Tensor
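
A sketch of the fallback: zero out the tail outside the nucleus and renormalize (threshold convention illustrative):

```python
import torch

def top_p_renorm_fallback(probs, top_p):
    # Keep the smallest descending-sorted prefix whose mass reaches top_p,
    # zero the rest, and renormalize to sum to one.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    keep_sorted = (cumsum - sorted_probs) < top_p
    # Scatter the keep mask back to the original token order.
    keep = torch.zeros_like(probs).scatter(-1, sorted_idx, keep_sorted.to(probs.dtype)).bool()
    out = torch.where(keep, probs, torch.zeros_like(probs))
    return out / out.sum(dim=-1, keepdim=True)
```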

pymllm.layers.sampling.top_k_renorm_probs(probs, top_k)

Renormalize probabilities by top-k thresholding.

Parameters:
  • probs (torch.Tensor) – (batch_size, num_classes).

  • top_k (Union[torch.Tensor, int]) – Top-k threshold.

Returns:

Renormalized probabilities.

Return type:

torch.Tensor
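
A sketch of the fallback for a scalar `top_k` (illustrative):

```python
import torch

def top_k_renorm_fallback(probs, top_k):
    # Zero out everything below the k-th largest probability, then
    # renormalize the survivors to sum to one.
    kth = probs.topk(top_k, dim=-1).values[..., -1:]
    out = torch.where(probs >= kth, probs, torch.zeros_like(probs))
    return out / out.sum(dim=-1, keepdim=True)
```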

pymllm.layers.sampling.top_k_mask_logits(logits, top_k)

Mask logits by top-k thresholding (set non-top-k to -inf).

Parameters:
  • logits (torch.Tensor) – (batch_size, num_classes).

  • top_k (Union[torch.Tensor, int]) – Top-k threshold.

Returns:

Masked logits with the same shape and dtype.

Return type:

torch.Tensor
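
A sketch of the fallback for a scalar `top_k` (illustrative):

```python
import torch

def top_k_mask_logits_fallback(logits, top_k):
    # Set every logit below the k-th largest to -inf, so a subsequent
    # softmax assigns those tokens zero probability.
    kth = logits.topk(top_k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth, float("-inf"))
```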

pymllm.layers.sampling.chain_speculative_sampling(draft_probs, draft_token_ids, target_probs, maybe_output_accepted_token_num=None, maybe_output_emitted_draft_token_num=None, deterministic=True, generator=None, seed=None, offset=None)

Speculative sampling for sequence generation.

Parameters:
  • draft_probs (torch.Tensor) – (batch_size, num_speculate_tokens, vocab_size).

  • draft_token_ids (torch.Tensor) – (batch_size, num_speculate_tokens).

  • target_probs (torch.Tensor) – (batch_size, num_speculate_tokens + 1, vocab_size).

  • maybe_output_accepted_token_num (Optional[torch.Tensor]) – If provided, accepted counts are added in-place.

  • maybe_output_emitted_draft_token_num (Optional[torch.Tensor]) – If provided, emitted counts are added in-place.

  • deterministic (bool) – See FlashInfer docs.

  • generator (Optional[torch.Generator]) – See FlashInfer docs.

  • seed (Optional[int]) – See FlashInfer docs.

  • offset (Optional[int]) – See FlashInfer docs.

Returns:

  • output_token_ids (torch.Tensor) – (batch_size, num_speculate_tokens + 1), rejected slots padded with -1.

  • output_accepted_token_num (torch.Tensor) – (batch_size,).

  • output_emitted_draft_token_num (torch.Tensor) – (batch_size,).

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
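
The accept/reject chain can be sketched in pure PyTorch as below. This is a slow per-sequence reference under the standard speculative-sampling rule (accept draft token i with probability min(1, p_target / p_draft); on the first rejection, resample from the residual max(p_target - p_draft, 0)); the in-place count outputs are omitted and the function name is illustrative:

```python
import torch

def chain_speculative_fallback(draft_probs, draft_token_ids, target_probs, generator=None):
    batch_size, num_draft, _ = draft_probs.shape
    out = torch.full((batch_size, num_draft + 1), -1, dtype=torch.long)
    for b in range(batch_size):
        accepted_all = True
        for i in range(num_draft):
            tok = draft_token_ids[b, i]
            ratio = target_probs[b, i, tok] / draft_probs[b, i, tok].clamp_min(1e-20)
            if torch.rand(1, generator=generator).item() <= ratio.item():
                out[b, i] = tok            # draft token accepted
            else:
                # First rejection: resample from the residual distribution,
                # leaving the remaining slots padded with -1.
                residual = (target_probs[b, i] - draft_probs[b, i]).clamp_min(0.0)
                out[b, i] = torch.multinomial(residual / residual.sum(), 1,
                                              generator=generator).item()
                accepted_all = False
                break
        if accepted_all:
            # All drafts accepted: sample a bonus token from the last
            # target distribution.
            out[b, num_draft] = torch.multinomial(target_probs[b, num_draft], 1,
                                                  generator=generator).item()
    return out
```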

pymllm.layers.sampling.top_p_renorm_prob

Alias of top_p_renorm_probs.

pymllm.layers.sampling.top_k_renorm_prob

Alias of top_k_renorm_probs.