ARGeneration API
The ARGeneration class is the abstract base class for autoregressive generation models in MLLM. It provides the shared machinery for producing token sequences: a forward/generate interface, streaming output, several sampling strategies (greedy, temperature, top-k, top-p), and model tracing.
#include "mllm/models/ARGeneration.hpp"
Base Class
class ARGeneration
Abstract base class for autoregressive generation models.
Protected Attributes
bool ARGeneration::do_sample_
Flag indicating whether to perform sampling during generation. Default is false.
int ARGeneration::eos_token_id_
End-of-sequence token ID used to terminate generation.
int ARGeneration::max_length_
Maximum length of generated sequences. Default is 1024.
Core Virtual Methods
virtual ARGenerationOutputPast ARGeneration::forward(const ARGenerationOutputPast &input, const ARGenerationArgs &args) = 0
Pure virtual function that performs the model's forward pass. Every concrete model must override it.
- Parameters:
input – Input tensors map
args – Arguments for the forward pass
- Returns:
Output tensors map with past states
virtual ARGenerationOutputPast ARGeneration::generate(const ARGenerationOutputPast &input, const ARGenerationArgs &args)
Generate sequences using the model.
- Parameters:
input – Input tensors map
args – Arguments for generation
- Returns:
Generated output tensors map with past states
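As a rough illustration of the control flow that generate implies, the sketch below runs a greedy decoding loop that stops on the end-of-sequence token or at a maximum length, mirroring eos_token_id_ and max_length_. It does not use mllm: ToyModel, nextLogits, and generateGreedy are hypothetical stand-ins, token sequences are std::vector<int64_t> instead of ARGenerationOutputPast, and counting the prompt toward the maximum length is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <vector>

// Hypothetical stand-in for a model (not the mllm API): given the tokens so
// far, return logits over a tiny vocabulary.
struct ToyModel {
  int vocab_size = 4;
  std::vector<float> nextLogits(const std::vector<int64_t>& tokens) const {
    std::vector<float> logits(vocab_size, 0.0f);
    // Favour "previous token + 1" so the output is easy to predict.
    logits[(tokens.back() + 1) % vocab_size] = 1.0f;
    return logits;
  }
};

// Greedy autoregressive loop: append one token per step, stop at EOS or once
// the whole sequence (prompt included, by assumption) reaches max_length.
std::vector<int64_t> generateGreedy(const ToyModel& model,
                                    std::vector<int64_t> tokens,
                                    int64_t eos_token_id,
                                    std::size_t max_length) {
  assert(!tokens.empty());
  while (tokens.size() < max_length) {
    std::vector<float> logits = model.nextLogits(tokens);
    auto best = std::max_element(logits.begin(), logits.end());
    int64_t next = std::distance(logits.begin(), best);
    tokens.push_back(next);
    if (next == eos_token_id) break;
  }
  return tokens;
}
```

For example, starting from the single token 0 with end-of-sequence ID 3, the toy model yields 0, 1, 2, 3 and stops.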
virtual void ARGeneration::streamGenerate(const ARGenerationOutputPast &input, const ARGenerationArgs &args, const std::function<void(int64_t)> &callback)
Generate sequences with streaming output: each generated token ID is passed to callback as it is produced instead of being returned at the end.
- Parameters:
input – Input tensors map
args – Arguments for generation
callback – Callback invoked with each generated token ID (int64_t)
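A minimal sketch of the streaming pattern, assuming each token is delivered to a std::function<void(int64_t)> callback (the same callback type streamGenerate takes) as soon as it is chosen. streamTokensSketch and its next_token parameter are hypothetical; next_token stands in for one forward pass plus sampling. Whether the real method also forwards the end-of-sequence token to the callback is not stated here; this sketch does.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>

// Hypothetical streaming loop (not the mllm API). `next_token` stands in for
// one forward pass plus sampling; `callback` has the type streamGenerate takes.
void streamTokensSketch(const std::function<int64_t(int64_t)>& next_token,
                        int64_t last_token, int64_t eos_token_id,
                        int max_new_tokens,
                        const std::function<void(int64_t)>& callback) {
  assert(max_new_tokens >= 0);
  for (int step = 0; step < max_new_tokens; ++step) {
    last_token = next_token(last_token);
    // Hand the token over immediately, e.g. to detokenize and print it.
    callback(last_token);
    // Assumption: the EOS token is still reported before stopping.
    if (last_token == eos_token_id) break;
  }
}
```

Because the callback runs inside the loop, a caller can display partial output while generation is still in progress.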
virtual IROutput ARGeneration::trace(const ARGenerationOutputPast &input, const ARGenerationArgs &args)
Trace the model execution for compilation or analysis.
- Parameters:
input – Input tensors map
args – Arguments for tracing
- Returns:
IR context output map
Sampling Methods
int64_t ARGeneration::sampleGreedy(Tensor &logits)
Select the next token greedily, i.e. the token with the highest logit (equivalently, the highest probability).
- Parameters:
logits – Logits tensor from the model
- Returns:
Selected token ID
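The greedy rule can be sketched on a plain std::vector<float> holding the logits of the last position, a simplification of mllm's Tensor. sampleGreedySketch is a hypothetical name, not part of the mllm API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

// Hypothetical greedy selector (not the mllm API): index of the largest logit.
// Ties resolve to the lowest index, as std::max_element returns the first max.
int64_t sampleGreedySketch(const std::vector<float>& logits) {
  assert(!logits.empty());
  auto best = std::max_element(logits.begin(), logits.end());
  return static_cast<int64_t>(std::distance(logits.begin(), best));
}
```

Because softmax preserves ordering, the argmax of the raw logits equals the argmax of the probabilities, so no normalisation step is needed.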
int64_t ARGeneration::sampleTemperature(Tensor &logits, float temperature)
Sample the next token from the softmax of the logits divided by temperature.
- Parameters:
logits – Logits tensor from the model
temperature – Temperature value for sampling (higher values increase randomness)
- Returns:
Selected token ID
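A sketch of temperature sampling under stated assumptions: logits arrive as std::vector<float>, randomness comes from a caller-supplied std::mt19937, and the helper names are hypothetical. It computes softmax(logits / temperature) with the usual max-subtraction for numerical stability, then draws from the result with std::discrete_distribution.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// softmax(logits / temperature); subtracting the max keeps std::exp finite.
std::vector<float> softmaxWithTemperature(const std::vector<float>& logits,
                                          float temperature) {
  assert(!logits.empty() && temperature > 0.0f);
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> probs(logits.size());
  float sum = 0.0f;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp((logits[i] - max_logit) / temperature);
    sum += probs[i];
  }
  for (float& p : probs) p /= sum;
  return probs;
}

// Hypothetical sampler (not the mllm API): draw one index from the tempered distribution.
int64_t sampleTemperatureSketch(const std::vector<float>& logits,
                                float temperature, std::mt19937& rng) {
  std::vector<float> probs = softmaxWithTemperature(logits, temperature);
  std::discrete_distribution<int64_t> pick(probs.begin(), probs.end());
  return pick(rng);
}
```

With temperature below 1 the distribution sharpens toward the greedy choice; above 1 it flattens toward uniform.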
int64_t ARGeneration::sampleTopK(Tensor &logits, int k, float temperature)
Sample the next token using top-k sampling: only the k highest-scoring tokens are considered, with temperature applied to their logits.
- Parameters:
logits – Logits tensor from the model
k – Number of top tokens to consider
temperature – Temperature value for sampling
- Returns:
Selected token ID
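A top-k sketch with the same assumptions (std::vector<float> logits, a caller-supplied std::mt19937, hypothetical function name). It ranks indices with std::partial_sort, keeps the top k (clamped to the vocabulary size), and samples among them in proportion to exp(logit / temperature).

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical top-k sampler over last-position logits (not the mllm API).
int64_t sampleTopKSketch(const std::vector<float>& logits, int k,
                         float temperature, std::mt19937& rng) {
  assert(!logits.empty() && k > 0 && temperature > 0.0f);
  // Never keep more candidates than the vocabulary has.
  const std::size_t keep = std::min(static_cast<std::size_t>(k), logits.size());
  std::vector<std::size_t> idx(logits.size());
  std::iota(idx.begin(), idx.end(), std::size_t{0});
  // Move the indices of the `keep` largest logits to the front, descending.
  std::partial_sort(idx.begin(), idx.begin() + static_cast<std::ptrdiff_t>(keep),
                    idx.end(),
                    [&](std::size_t a, std::size_t b) { return logits[a] > logits[b]; });
  // Tempered, max-shifted weights; discrete_distribution normalises them.
  const float max_logit = logits[idx[0]];
  std::vector<double> weights(keep);
  for (std::size_t i = 0; i < keep; ++i)
    weights[i] = std::exp((logits[idx[i]] - max_logit) / temperature);
  std::discrete_distribution<std::size_t> pick(weights.begin(), weights.end());
  return static_cast<int64_t>(idx[pick(rng)]);
}
```

With k = 1 this reduces to greedy selection; larger k trades determinism for diversity.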
int64_t ARGeneration::sampleTopP(Tensor &logits, float p, float temperature)
Sample the next token using nucleus (top-p) sampling: only the smallest set of most probable tokens whose cumulative probability reaches p is considered.
- Parameters:
logits – Logits tensor from the model
p – Cumulative probability threshold
temperature – Temperature value for sampling
- Returns:
Selected token ID
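A nucleus-sampling sketch, again assuming std::vector<float> logits, a caller-supplied std::mt19937, and a hypothetical function name. It sorts tokens by tempered probability, keeps the shortest prefix whose cumulative probability reaches p, and samples within that prefix.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical top-p (nucleus) sampler (not the mllm API).
int64_t sampleTopPSketch(const std::vector<float>& logits, float p,
                         float temperature, std::mt19937& rng) {
  assert(!logits.empty() && p > 0.0f && p <= 1.0f && temperature > 0.0f);
  // Tempered softmax, shifted by the max logit for numerical stability.
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<double> probs(logits.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp((logits[i] - max_logit) / temperature);
    sum += probs[i];
  }
  for (double& q : probs) q /= sum;
  // Token indices ordered by descending probability.
  std::vector<std::size_t> idx(logits.size());
  std::iota(idx.begin(), idx.end(), std::size_t{0});
  std::sort(idx.begin(), idx.end(),
            [&](std::size_t a, std::size_t b) { return probs[a] > probs[b]; });
  // Smallest prefix (the "nucleus") whose cumulative probability reaches p.
  std::size_t keep = 0;
  double cumulative = 0.0;
  while (keep < idx.size() && cumulative < p) {
    cumulative += probs[idx[keep]];
    ++keep;
  }
  // Sample inside the nucleus; discrete_distribution renormalises the weights.
  std::vector<double> weights(keep);
  for (std::size_t i = 0; i < keep; ++i) weights[i] = probs[idx[i]];
  std::discrete_distribution<std::size_t> pick(weights.begin(), weights.end());
  return static_cast<int64_t>(idx[pick(rng)]);
}
```

Unlike top-k, the number of candidates adapts to the shape of the distribution: a confident model may leave a single token in the nucleus.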
int64_t ARGeneration::categoricalSample(const Tensor &probs)
Sample from a categorical distribution.
- Parameters:
probs – Probability distribution tensor
- Returns:
Sampled token ID
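Categorical sampling can be done by inverse-CDF lookup: draw a uniform number and walk the running sum of probabilities until it exceeds that number. This sketch assumes the probabilities are available as a std::vector<float> rather than a Tensor, tolerates weights that do not sum exactly to 1, and uses the hypothetical name categoricalSampleSketch; mllm's actual implementation may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical inverse-CDF sampler (not the mllm API).
int64_t categoricalSampleSketch(const std::vector<float>& probs, std::mt19937& rng) {
  assert(!probs.empty());
  double total = 0.0;
  for (float q : probs) {
    assert(q >= 0.0f);
    total += q;
  }
  assert(total > 0.0);
  // u ~ U[0, total): scaling by total avoids an explicit normalisation pass.
  std::uniform_real_distribution<double> uniform(0.0, total);
  const double u = uniform(rng);
  double cumulative = 0.0;
  for (std::size_t i = 0; i < probs.size(); ++i) {
    cumulative += probs[i];
    if (u < cumulative) return static_cast<int64_t>(i);
  }
  // Rounding can leave u >= cumulative; fall back to the last non-zero entry.
  for (std::size_t i = probs.size(); i-- > 0;)
    if (probs[i] > 0.0f) return static_cast<int64_t>(i);
  return 0;  // not reached when total > 0
}
```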
Utility Methods
Tensor ARGeneration::getLastLogits(Tensor &logits)
Extract the logits for the last token in the sequence.
- Parameters:
logits – Full logits tensor
- Returns:
Logits for the last token
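Models typically produce one row of logits per input position, and only the last row is needed to pick the next token. This sketch assumes a row-major [seq_len, vocab_size] layout stored in a flat std::vector<float>; mllm's Tensor may carry extra batch or head dimensions, and lastLogitsSketch is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not the mllm API). Assumes logits are stored row-major
// as [seq_len, vocab_size]; row t scores the token that follows position t.
std::vector<float> lastLogitsSketch(const std::vector<float>& logits,
                                    std::size_t seq_len, std::size_t vocab_size) {
  assert(seq_len > 0 && logits.size() == seq_len * vocab_size);
  // Copy out the final row, the only one needed to choose the next token.
  const auto first =
      logits.begin() + static_cast<std::ptrdiff_t>((seq_len - 1) * vocab_size);
  return std::vector<float>(first, first + static_cast<std::ptrdiff_t>(vocab_size));
}
```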
int ARGeneration::sampleFromDistribution(const std::vector<float> &probs)
Sample an index from a probability distribution given as a vector of floats.
- Parameters:
probs – Vector of probabilities
- Returns:
Sampled index
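For a plain std::vector<float>, the standard library's std::discrete_distribution already implements weighted index sampling, so a sketch can be very short. The function name and the caller-supplied std::mt19937 are assumptions, not mllm's API.

```cpp
#include <cassert>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper (not the mllm API): returns index i with probability
// probs[i] / sum(probs). Weights must be non-negative with a positive sum.
int sampleFromDistributionSketch(const std::vector<float>& probs, std::mt19937& rng) {
  assert(!probs.empty());
  assert(std::accumulate(probs.begin(), probs.end(), 0.0) > 0.0);
  std::discrete_distribution<int> pick(probs.begin(), probs.end());
  return pick(rng);
}
```

Seeding the generator once and reusing it across calls keeps results reproducible for a fixed seed.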