pymllm.layers.sampling
======================

.. py:module:: pymllm.layers.sampling

.. autoapi-nested-parse::

   Sampling operations with FlashInfer acceleration and a pure-PyTorch fallback.

   This module wraps all ``flashinfer.sampling`` APIs and provides pure-PyTorch
   fallback implementations, so that the rest of the codebase can import from
   here without worrying about whether FlashInfer is installed.

Attributes
----------

.. autoapisummary::

   pymllm.layers.sampling.logger
   pymllm.layers.sampling.top_p_renorm_prob
   pymllm.layers.sampling.top_k_renorm_prob

Functions
---------

.. autoapisummary::

   pymllm.layers.sampling.softmax
   pymllm.layers.sampling.sampling_from_probs
   pymllm.layers.sampling.sampling_from_logits
   pymllm.layers.sampling.top_p_sampling_from_probs
   pymllm.layers.sampling.top_k_sampling_from_probs
   pymllm.layers.sampling.min_p_sampling_from_probs
   pymllm.layers.sampling.top_k_top_p_sampling_from_logits
   pymllm.layers.sampling.top_k_top_p_sampling_from_probs
   pymllm.layers.sampling.top_p_renorm_probs
   pymllm.layers.sampling.top_k_renorm_probs
   pymllm.layers.sampling.top_k_mask_logits
   pymllm.layers.sampling.chain_speculative_sampling

Module Contents
---------------

.. py:data:: logger

.. py:function:: softmax(logits, temperature=None, enable_pdl=None)

   Safe softmax with optional temperature scaling.

   :param logits: Shape ``(batch_size, num_classes)``.
   :type logits: torch.Tensor
   :param temperature: Scalar or per-request ``(batch_size,)`` temperature.
   :type temperature: Optional[Union[torch.Tensor, float]]
   :param enable_pdl: FlashInfer PDL flag (ignored in the fallback).
   :type enable_pdl: Optional[bool]
   :returns: Probabilities with the same shape as *logits*.
   :rtype: torch.Tensor

.. py:function:: sampling_from_probs(probs, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

   Categorical sampling from probabilities.

   :param probs: ``(batch_size, num_classes)``, or ``(unique_batch_size, num_classes)`` when *indices* is provided.
   :type probs: torch.Tensor
   :param indices: Maps each output to a row in *probs*.
   :type indices: Optional[torch.Tensor]
   :param deterministic: See FlashInfer docs.
   :param generator: See FlashInfer docs.
   :param check_nan: See FlashInfer docs.
   :param seed: See FlashInfer docs.
   :param offset: See FlashInfer docs.
   :returns: Sampled token ids, shape ``(batch_size,)``.
   :rtype: torch.Tensor

.. py:function:: sampling_from_logits(logits, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

   Categorical sampling from logits (applies softmax internally).

   :param logits: ``(batch_size, num_classes)``.
   :type logits: torch.Tensor
   :param indices: See FlashInfer docs.
   :param deterministic: See FlashInfer docs.
   :param generator: See FlashInfer docs.
   :param check_nan: See FlashInfer docs.
   :param seed: See FlashInfer docs.
   :param offset: See FlashInfer docs.
   :returns: Sampled token ids, shape ``(batch_size,)``.
   :rtype: torch.Tensor

.. py:function:: top_p_sampling_from_probs(probs, top_p, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

   Top-p (nucleus) sampling from probabilities.

   :param probs: ``(batch_size, num_classes)``.
   :type probs: torch.Tensor
   :param top_p: Top-p threshold.
   :type top_p: Union[torch.Tensor, float]
   :param indices: See FlashInfer docs.
   :param deterministic: See FlashInfer docs.
   :param generator: See FlashInfer docs.
   :param check_nan: See FlashInfer docs.
   :param seed: See FlashInfer docs.
   :param offset: See FlashInfer docs.
   :returns: Sampled token ids, shape ``(batch_size,)``.
   :rtype: torch.Tensor

.. py:function:: top_k_sampling_from_probs(probs, top_k, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

   Top-k sampling from probabilities.

   :param probs: ``(batch_size, num_classes)``.
   :type probs: torch.Tensor
   :param top_k: Top-k threshold.
   :type top_k: Union[torch.Tensor, int]
   :param indices: See FlashInfer docs.
   :param deterministic: See FlashInfer docs.
   :param generator: See FlashInfer docs.
   :param check_nan: See FlashInfer docs.
   :param seed: See FlashInfer docs.
   :param offset: See FlashInfer docs.
   :returns: Sampled token ids, shape ``(batch_size,)``.
   :rtype: torch.Tensor

.. py:function:: min_p_sampling_from_probs(probs, min_p, indices=None, deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

   Min-p sampling from probabilities.

   :param probs: ``(batch_size, num_classes)``.
   :type probs: torch.Tensor
   :param min_p: Min-p threshold.
   :type min_p: Union[torch.Tensor, float]
   :param indices: See FlashInfer docs.
   :param deterministic: See FlashInfer docs.
   :param generator: See FlashInfer docs.
   :param check_nan: See FlashInfer docs.
   :param seed: See FlashInfer docs.
   :param offset: See FlashInfer docs.
   :returns: Sampled token ids, shape ``(batch_size,)``.
   :rtype: torch.Tensor

.. py:function:: top_k_top_p_sampling_from_logits(logits, top_k, top_p, indices=None, filter_apply_order='top_k_first', deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

   Top-k + top-p sampling from pre-softmax logits.

   :param logits: ``(batch_size, num_classes)``.
   :type logits: torch.Tensor
   :param top_k: Top-k threshold.
   :type top_k: Union[torch.Tensor, int]
   :param top_p: Top-p threshold.
   :type top_p: Union[torch.Tensor, float]
   :param filter_apply_order: ``"top_k_first"`` or ``"joint"``.
   :type filter_apply_order: str
   :param indices: See FlashInfer docs.
   :param deterministic: See FlashInfer docs.
   :param generator: See FlashInfer docs.
   :param check_nan: See FlashInfer docs.
   :param seed: See FlashInfer docs.
   :param offset: See FlashInfer docs.
   :returns: Sampled token ids, shape ``(batch_size,)``.
   :rtype: torch.Tensor

.. py:function:: top_k_top_p_sampling_from_probs(probs, top_k, top_p, indices=None, filter_apply_order='top_k_first', deterministic=True, generator=None, check_nan=False, seed=None, offset=None)

   Top-k + top-p sampling from probabilities.
   :param probs: ``(batch_size, num_classes)``.
   :type probs: torch.Tensor
   :param top_k: Top-k threshold.
   :type top_k: Union[torch.Tensor, int]
   :param top_p: Top-p threshold.
   :type top_p: Union[torch.Tensor, float]
   :param filter_apply_order: ``"top_k_first"`` or ``"joint"``.
   :type filter_apply_order: str
   :param indices: See FlashInfer docs.
   :param deterministic: See FlashInfer docs.
   :param generator: See FlashInfer docs.
   :param check_nan: See FlashInfer docs.
   :param seed: See FlashInfer docs.
   :param offset: See FlashInfer docs.
   :returns: Sampled token ids, shape ``(batch_size,)``.
   :rtype: torch.Tensor

.. py:function:: top_p_renorm_probs(probs, top_p)

   Renormalize probabilities by top-p thresholding.

   :param probs: ``(batch_size, num_classes)``.
   :type probs: torch.Tensor
   :param top_p: Top-p threshold in ``(0, 1)``.
   :type top_p: Union[torch.Tensor, float]
   :returns: Renormalized probabilities.
   :rtype: torch.Tensor

.. py:function:: top_k_renorm_probs(probs, top_k)

   Renormalize probabilities by top-k thresholding.

   :param probs: ``(batch_size, num_classes)``.
   :type probs: torch.Tensor
   :param top_k: Top-k threshold.
   :type top_k: Union[torch.Tensor, int]
   :returns: Renormalized probabilities.
   :rtype: torch.Tensor

.. py:function:: top_k_mask_logits(logits, top_k)

   Mask logits by top-k thresholding (set non-top-k entries to ``-inf``).

   :param logits: ``(batch_size, num_classes)``.
   :type logits: torch.Tensor
   :param top_k: Top-k threshold.
   :type top_k: Union[torch.Tensor, int]
   :returns: Masked logits with the same shape and dtype.
   :rtype: torch.Tensor

.. py:function:: chain_speculative_sampling(draft_probs, draft_token_ids, target_probs, maybe_output_accepted_token_num=None, maybe_output_emitted_draft_token_num=None, deterministic=True, generator=None, seed=None, offset=None)

   Speculative sampling for sequence generation.

   :param draft_probs: ``(batch_size, num_speculate_tokens, vocab_size)``.
   :type draft_probs: torch.Tensor
   :param draft_token_ids: ``(batch_size, num_speculate_tokens)``.
   :type draft_token_ids: torch.Tensor
   :param target_probs: ``(batch_size, num_speculate_tokens + 1, vocab_size)``.
   :type target_probs: torch.Tensor
   :param maybe_output_accepted_token_num: If provided, accepted counts are added in-place.
   :type maybe_output_accepted_token_num: Optional[torch.Tensor]
   :param maybe_output_emitted_draft_token_num: If provided, emitted counts are added in-place.
   :type maybe_output_emitted_draft_token_num: Optional[torch.Tensor]
   :param deterministic: See FlashInfer docs.
   :param generator: See FlashInfer docs.
   :param seed: See FlashInfer docs.
   :param offset: See FlashInfer docs.
   :returns: * **output_token_ids** (*torch.Tensor*) -- ``(batch_size, num_speculate_tokens + 1)``; rejected slots are padded with ``-1``.
             * **output_accepted_token_num** (*torch.Tensor*) -- ``(batch_size,)``.
             * **output_emitted_draft_token_num** (*torch.Tensor*) -- ``(batch_size,)``.

.. py:data:: top_p_renorm_prob

.. py:data:: top_k_renorm_prob
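To illustrate what the pure-PyTorch fallback path for ``softmax`` might look like, here is a minimal sketch of safe softmax with scalar or per-request temperature. The function name ``softmax_fallback`` and the clamping constant are hypothetical, not part of the module's API:

```python
import torch

def softmax_fallback(logits: torch.Tensor, temperature=None) -> torch.Tensor:
    """Temperature-scaled safe softmax; a sketch of a pure-PyTorch fallback."""
    if temperature is not None:
        t = torch.as_tensor(temperature, dtype=logits.dtype, device=logits.device)
        if t.ndim == 1:                       # per-request (batch_size,) temperatures
            t = t.unsqueeze(-1)               # broadcast over the class dimension
        logits = logits / t.clamp_min(1e-6)   # guard against temperature == 0
    # subtract the row max before exponentiating ("safe" softmax)
    logits = logits - logits.max(dim=-1, keepdim=True).values
    return torch.softmax(logits, dim=-1)
```

A lower temperature sharpens the distribution; a per-request tensor applies a different temperature to each row of the batch.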
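The filtering step behind min-p sampling can be sketched in pure PyTorch as well: tokens whose probability falls below ``min_p`` times the row maximum are zeroed out and the rest renormalized. The name ``min_p_filter`` is hypothetical and this is only one plausible shape of the fallback, not the module's actual implementation:

```python
import torch

def min_p_filter(probs: torch.Tensor, min_p: float) -> torch.Tensor:
    """Zero out tokens below min_p * max-prob per row, then renormalize."""
    thresh = min_p * probs.max(dim=-1, keepdim=True).values
    kept = probs.masked_fill(probs < thresh, 0.0)
    return kept / kept.sum(dim=-1, keepdim=True)
```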
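Top-p renormalization admits a compact fallback sketch: sort each row descending, keep the smallest prefix whose cumulative mass reaches ``top_p``, zero out the tail, and renormalize. The function name ``top_p_renorm_probs_fallback`` is hypothetical:

```python
import torch

def top_p_renorm_probs_fallback(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Keep the smallest high-probability prefix reaching top_p mass; renormalize."""
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    # drop a token once the cumulative mass *before* it already reaches top_p
    drop = (cum - sorted_probs) >= top_p
    sorted_probs = sorted_probs.masked_fill(drop, 0.0)
    # scatter the kept probabilities back to their original positions
    out = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    return out / out.sum(dim=-1, keepdim=True)
```

For example, with row ``[0.4, 0.3, 0.2, 0.1]`` and ``top_p = 0.5``, the first two tokens survive (their prefix is the smallest one reaching 0.5) and are renormalized to ``[4/7, 3/7]``.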
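Masking logits to the top-k also has a short pure-PyTorch formulation: take the k-th largest value per row as the cut-off and set everything below it to ``-inf``. Again, ``top_k_mask_logits_fallback`` is a hypothetical name for a sketch, not the module's code:

```python
import torch

def top_k_mask_logits_fallback(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Set every logit outside the per-row top-k to -inf."""
    k = min(int(top_k), logits.size(-1))
    # the k-th largest value in each row is the inclusion threshold
    kth = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth, float("-inf"))
```

Applying ``softmax`` to the masked logits gives the same distribution as top-k renormalization of the probabilities.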
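Finally, the accept/reject chain behind ``chain_speculative_sampling`` can be sketched sequentially: each draft token is accepted with probability ``min(1, p/q)`` (target over draft probability); on rejection a replacement is drawn from the residual ``max(p - q, 0)`` and the chain stops, with the remaining slots padded with ``-1``; if every draft token is accepted, one bonus token is drawn from the target's extra distribution. This loop (name and two-output signature hypothetical; the real kernel is batched and also returns emitted counts) is a correctness sketch, not a performant implementation:

```python
import torch

def chain_speculative_sampling_fallback(draft_probs, draft_ids, target_probs,
                                        generator=None):
    """Sequential speculative sampling; rejected slots are padded with -1."""
    bsz, n, _ = draft_probs.shape
    out = torch.full((bsz, n + 1), -1, dtype=torch.long)
    accepted = torch.zeros(bsz, dtype=torch.long)
    for b in range(bsz):
        i = 0
        while i < n:
            tok = draft_ids[b, i]
            q = draft_probs[b, i, tok]          # draft probability of the token
            p = target_probs[b, i, tok]         # target probability of the token
            u = torch.rand((), generator=generator)
            if u <= p / q.clamp_min(1e-20):     # accept with prob min(1, p/q)
                out[b, i] = tok
                accepted[b] += 1
                i += 1
            else:
                # resample from the residual distribution max(p - q, 0)
                resid = (target_probs[b, i] - draft_probs[b, i]).clamp_min(0)
                resid = resid / resid.sum()
                out[b, i] = torch.multinomial(resid, 1, generator=generator).item()
                break
        else:
            # all draft tokens accepted: emit one bonus token from the target
            out[b, n] = torch.multinomial(target_probs[b, n], 1,
                                          generator=generator).item()
    return out, accepted
```

When draft and target distributions coincide, every draft token is accepted and only the bonus token is freshly sampled.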