ARGeneration API

The ARGeneration class is an abstract base class for autoregressive generation models in MLLM. It provides the core functionality for generating sequences, including several sampling strategies, streaming generation, and model tracing.

#include "mllm/models/ARGeneration.hpp"

Base Class

class ARGeneration

Abstract base class for autoregressive generation models.

Protected Attributes

bool ARGeneration::do_sample_

Flag indicating whether to perform sampling during generation. Default is false.

int ARGeneration::eos_token_id_

End-of-sequence token ID used to terminate generation.

int ARGeneration::max_length_

Maximum length of generated sequences. Default is 1024.

Core Virtual Methods

virtual ARGenerationOutputPast ARGeneration::forward(const ARGenerationOutputPast &input, const ARGenerationArgs &args) = 0

Pure virtual function that performs the model's forward pass; concrete subclasses must implement it.

Parameters:
  • input – Input tensors map

  • args – Arguments for the forward pass

Returns:

Output tensors map with past states

virtual ARGenerationOutputPast ARGeneration::generate(const ARGenerationOutputPast &input, const ARGenerationArgs &args)

Generate sequences using the model.

Parameters:
  • input – Input tensors map

  • args – Arguments for generation

Returns:

Generated output tensors map with past states
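As a rough illustration of the kind of loop generate() drives, the sketch below runs greedy autoregressive decoding with plain standard-library types. NextTokenLogitsFn and greedyGenerate are hypothetical names, not MLLM API. Unlike a real implementation, the sketch re-runs the whole token sequence at each step instead of reusing cached past states, and it assumes the prompt counts toward the maximum length; the stopping conditions mirror the roles of eos_token_id_ and max_length_.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <vector>

// Hypothetical stand-in for a forward pass (not MLLM API): maps the token IDs
// produced so far to the logits of the next token.
using NextTokenLogitsFn = std::function<std::vector<float>(const std::vector<int64_t> &)>;

// Greedy autoregressive loop: append one token per step until the EOS token is
// produced or the sequence (prompt included) reaches max_length.
std::vector<int64_t> greedyGenerate(std::vector<int64_t> tokens, const NextTokenLogitsFn &forward,
                                    int64_t eos_token_id, std::size_t max_length) {
  while (tokens.size() < max_length) {
    const std::vector<float> logits = forward(tokens);
    assert(!logits.empty());
    // Greedy choice: index of the largest logit.
    const int64_t next = static_cast<int64_t>(
        std::distance(logits.begin(), std::max_element(logits.begin(), logits.end())));
    tokens.push_back(next);
    if (next == eos_token_id) break;  // EOS is kept in the output, then generation stops
  }
  return tokens;
}
```

A sampling-based variant would replace the argmax step with one of the strategies documented under Sampling Methods.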

virtual void ARGeneration::streamGenerate(const ARGenerationOutputPast &input, const ARGenerationArgs &args, const std::function<void(int64_t)> &callback)

Generate sequences with streaming output.

Parameters:
  • input – Input tensors map

  • args – Arguments for generation

  • callback – Callback invoked with the ID of each generated token
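A streaming loop can be sketched in the same style: each chosen token ID is handed to a std::function<void(int64_t)>, the callback type used in the signature above. greedyStreamGenerate and NextTokenLogitsFn are hypothetical names, greedy selection is an assumption made for brevity, and whether the real streamGenerate passes the EOS token to the callback is not stated here (this sketch does not).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <vector>

// Hypothetical forward-pass stand-in (not MLLM API): token IDs so far -> next-token logits.
using NextTokenLogitsFn = std::function<std::vector<float>(const std::vector<int64_t> &)>;

// Streaming greedy loop: each new token ID is passed to `callback` as soon as
// it is chosen, rather than being collected and returned at the end.
void greedyStreamGenerate(std::vector<int64_t> tokens, const NextTokenLogitsFn &forward,
                          int64_t eos_token_id, std::size_t max_length,
                          const std::function<void(int64_t)> &callback) {
  while (tokens.size() < max_length) {
    const std::vector<float> logits = forward(tokens);
    assert(!logits.empty());
    const int64_t next = static_cast<int64_t>(
        std::distance(logits.begin(), std::max_element(logits.begin(), logits.end())));
    if (next == eos_token_id) break;  // this sketch does not forward EOS to the callback
    callback(next);
    tokens.push_back(next);
  }
}
```

A callback that detokenizes and prints each ID produces the token-by-token output typical of interactive front ends.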

virtual IROutput ARGeneration::trace(const ARGenerationOutputPast &input, const ARGenerationArgs &args)

Trace the model execution for compilation or analysis.

Parameters:
  • input – Input tensors map

  • args – Arguments for tracing

Returns:

IR context output map

Sampling Methods

int64_t ARGeneration::sampleGreedy(Tensor &logits)

Select the next token greedily, i.e., pick the token with the highest logit (equivalently, the highest probability).

Parameters:
  • logits – Logits tensor from the model

Returns:

Selected token ID
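Because softmax preserves the ordering of its inputs, greedy selection reduces to an argmax over the raw logits. A minimal sketch on a std::vector<float> holding the logits of one position; greedyArgmax is a hypothetical name, and the real method operates on an MLLM Tensor.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

// Greedy selection: return the index of the largest logit.
// No softmax is needed because it does not change which entry is largest.
int64_t greedyArgmax(const std::vector<float> &logits) {
  assert(!logits.empty());
  auto it = std::max_element(logits.begin(), logits.end());
  return static_cast<int64_t>(std::distance(logits.begin(), it));
}
```

std::max_element returns the first maximum, so in this sketch ties resolve to the lowest token ID.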

int64_t ARGeneration::sampleTemperature(Tensor &logits, float temperature)

Sample the next token using temperature-based sampling.

Parameters:
  • logits – Logits tensor from the model

  • temperature – Temperature value for sampling (higher values increase randomness)

Returns:

Selected token ID
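Temperature sampling divides the logits by the temperature before applying softmax, then draws a token from the resulting distribution. The sketch below uses std::mt19937 and std::discrete_distribution on a plain vector; softmaxWithTemperature and sampleWithTemperature are hypothetical names, and the real method's random source and tensor handling may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Numerically stable softmax of logits / temperature.
std::vector<float> softmaxWithTemperature(const std::vector<float> &logits, float temperature) {
  assert(!logits.empty() && temperature > 0.0f);
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> probs(logits.size());
  float sum = 0.0f;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    // Subtracting the max before exp avoids overflow; it cancels after normalization.
    probs[i] = std::exp((logits[i] - max_logit) / temperature);
    sum += probs[i];
  }
  for (float &p : probs) p /= sum;
  return probs;
}

// Draw one token index from softmax(logits / temperature).
int64_t sampleWithTemperature(const std::vector<float> &logits, float temperature,
                              std::mt19937 &rng) {
  std::vector<float> probs = softmaxWithTemperature(logits, temperature);
  std::discrete_distribution<int64_t> dist(probs.begin(), probs.end());
  return dist(rng);
}
```

As the temperature approaches 0 the distribution collapses onto the argmax; at 1 it is the plain softmax, and larger values flatten it.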

int64_t ARGeneration::sampleTopK(Tensor &logits, int k, float temperature)

Sample the next token using the top-k sampling strategy.

Parameters:
  • logits – Logits tensor from the model

  • k – Number of top tokens to consider

  • temperature – Temperature value for sampling

Returns:

Selected token ID
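Top-k sampling restricts the draw to the k highest-scoring tokens and applies a temperature-scaled softmax over just those. A minimal sketch, assuming the logits of one position are available as a std::vector<float> and that std::mt19937 serves as the random source; sampleTopKSketch is a hypothetical name. std::partial_sort avoids sorting the entire vocabulary.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Top-k sampling: keep only the k highest logits, apply a temperature-scaled
// softmax over them, and draw one of the surviving token indices.
int64_t sampleTopKSketch(const std::vector<float> &logits, int k, float temperature,
                         std::mt19937 &rng) {
  assert(!logits.empty() && k > 0 && temperature > 0.0f);
  const std::size_t keep = std::min(static_cast<std::size_t>(k), logits.size());
  std::vector<int64_t> idx(logits.size());
  std::iota(idx.begin(), idx.end(), int64_t{0});
  // Move the indices of the `keep` largest logits to the front, largest first.
  const auto mid = idx.begin() + static_cast<std::ptrdiff_t>(keep);
  std::partial_sort(idx.begin(), mid, idx.end(),
                    [&](int64_t a, int64_t b) { return logits[a] > logits[b]; });
  // Unnormalized weights (max subtracted); std::discrete_distribution normalizes them.
  std::vector<double> weights(keep);
  for (std::size_t i = 0; i < keep; ++i)
    weights[i] = std::exp((logits[idx[i]] - logits[idx[0]]) / temperature);
  std::discrete_distribution<std::size_t> dist(weights.begin(), weights.end());
  return idx[dist(rng)];
}
```

With k = 1 this reduces to greedy selection; a k larger than the vocabulary is clamped in this sketch.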

int64_t ARGeneration::sampleTopP(Tensor &logits, float p, float temperature)

Sample the next token using the nucleus (top-p) sampling strategy.

Parameters:
  • logits – Logits tensor from the model

  • p – Cumulative probability threshold

  • temperature – Temperature value for sampling

Returns:

Selected token ID
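Nucleus sampling keeps the smallest set of most-probable tokens whose cumulative probability reaches p, then samples within that set. The sketch below shows the idea on a plain std::vector<float> with std::mt19937 as the random source; sampleTopPSketch is a hypothetical name, and details such as tie handling may differ from MLLM's sampleTopP.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Nucleus (top-p) sampling: sort tokens by probability, keep the smallest
// prefix whose cumulative probability reaches p, and sample within it.
int64_t sampleTopPSketch(const std::vector<float> &logits, float p, float temperature,
                         std::mt19937 &rng) {
  assert(!logits.empty() && p > 0.0f && temperature > 0.0f);
  std::vector<int64_t> idx(logits.size());
  std::iota(idx.begin(), idx.end(), int64_t{0});
  std::sort(idx.begin(), idx.end(),
            [&](int64_t a, int64_t b) { return logits[a] > logits[b]; });
  // Temperature-scaled softmax in sorted order (max subtracted for stability).
  std::vector<double> probs(idx.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < idx.size(); ++i) {
    probs[i] = std::exp((logits[idx[i]] - logits[idx[0]]) / temperature);
    sum += probs[i];
  }
  for (double &q : probs) q /= sum;
  // Grow the nucleus until it covers p; it always contains at least one token.
  std::size_t keep = 0;
  double cumulative = 0.0;
  while (keep < probs.size() && cumulative < p) cumulative += probs[keep++];
  std::discrete_distribution<std::size_t> dist(
      probs.begin(), probs.begin() + static_cast<std::ptrdiff_t>(keep));
  return idx[dist(rng)];
}
```

Unlike top-k, the size of the candidate set adapts to the shape of the distribution: a confident prediction yields a small nucleus, a flat one a large nucleus.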

int64_t ARGeneration::categoricalSample(const Tensor &probs)

Sample from a categorical distribution.

Parameters:
  • probs – Probability distribution tensor

Returns:

Sampled token ID
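One common way to sample from a categorical distribution is inverse-CDF sampling: draw a uniform number and walk the running sum of probabilities until it exceeds that number. The sketch below illustrates the technique on a std::vector<float>; categoricalSampleSketch is a hypothetical name, and whether MLLM's categoricalSample uses this method is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Inverse-CDF sampling from a categorical distribution: draw u uniformly in
// [0, total) and return the first index whose running sum exceeds u.
int64_t categoricalSampleSketch(const std::vector<float> &probs, std::mt19937 &rng) {
  assert(!probs.empty());
  // Scaling the draw by the total means probs need not sum to exactly 1.
  const double total = std::accumulate(probs.begin(), probs.end(), 0.0);
  assert(total > 0.0);
  std::uniform_real_distribution<double> uniform(0.0, total);
  const double u = uniform(rng);
  double cumulative = 0.0;
  for (std::size_t i = 0; i < probs.size(); ++i) {
    cumulative += probs[i];
    if (u < cumulative) return static_cast<int64_t>(i);
  }
  // Defensive fallback in case the draw lands exactly on the total:
  // return the last index with non-zero probability.
  for (std::size_t i = probs.size(); i-- > 0;)
    if (probs[i] > 0.0f) return static_cast<int64_t>(i);
  return 0;  // unreachable because total > 0
}
```

Zero-probability entries add nothing to the running sum, so they can never be selected.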

Utility Methods

Tensor ARGeneration::getLastLogits(Tensor &logits)

Extract the logits for the last token in the sequence.

Parameters:
  • logits – Full logits tensor

Returns:

Logits for the last token
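During generation only the final position's logits are needed to choose the next token. The sketch below shows the slicing on a flat buffer, assuming a single sequence stored row-major as [seq_len, vocab_size]; that layout and the name lastRowLogits are hypothetical, and the actual MLLM Tensor shape may include batch or head dimensions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Return the logits of the final position from a row-major
// [seq_len, vocab_size] buffer (hypothetical layout, single sequence).
std::vector<float> lastRowLogits(const std::vector<float> &logits, std::size_t seq_len,
                                 std::size_t vocab_size) {
  assert(seq_len > 0 && logits.size() == seq_len * vocab_size);
  const auto begin = logits.begin() + static_cast<std::ptrdiff_t>((seq_len - 1) * vocab_size);
  return std::vector<float>(begin, begin + static_cast<std::ptrdiff_t>(vocab_size));
}
```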

int ARGeneration::sampleFromDistribution(const std::vector<float> &probs)

Sample from a probability distribution.

Parameters:
  • probs – Vector of probabilities

Returns:

Sampled index
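The C++ standard library already provides weighted index sampling through std::discrete_distribution, which normalizes its weights internally. A minimal sketch with the same input and return types as this method; sampleIndexSketch is a hypothetical name, and whether MLLM uses std::discrete_distribution internally is not stated here.

```cpp
#include <cassert>
#include <random>
#include <vector>

// Draw an index with probability proportional to probs[i].
// std::discrete_distribution normalizes the weights itself.
int sampleIndexSketch(const std::vector<float> &probs, std::mt19937 &rng) {
  assert(!probs.empty());
  std::discrete_distribution<int> dist(probs.begin(), probs.end());
  return dist(rng);
}
```

Over many draws with probs = {0.2, 0.8}, index 1 should come up roughly 80% of the time.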