Sampling Models

aana.core.models.sampling

SamplingParams

Bases: BaseModel

A model for the sampling parameters of an LLM.

ATTRIBUTE DESCRIPTION
temperature

Float that controls the randomness of the sampling. Lower values make the model more deterministic, while higher values make the model more random. Zero means greedy sampling.

TYPE: float

top_p

Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to 1 to consider all tokens.

TYPE: float

top_k

Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.

TYPE: int

max_tokens

The maximum number of tokens to generate.

TYPE: int

repetition_penalty

Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the model to repeat tokens. Default is 1.0 (no penalty).

TYPE: float

kwargs

Extra keyword arguments to pass as sampling parameters.

TYPE: dict
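
To make the interaction of temperature, top_k and top_p concrete, here is a minimal pure-Python sketch of how these three parameters are conventionally applied to a vector of logits. It is an illustration only: the function name `filter_probs` is hypothetical, it is not part of aana, and the backend that consumes SamplingParams may order or implement these steps differently.

```python
import math


def filter_probs(logits, temperature=1.0, top_k=-1, top_p=1.0):
    """Return {token_index: probability} after temperature, top-k and top-p filtering.

    Hypothetical helper for illustration; not aana's implementation.
    """
    if temperature == 0:
        # Zero temperature means greedy sampling: all mass on the best token.
        best = max(range(len(logits)), key=lambda i: logits[i])
        return {best: 1.0}

    # Temperature scaling followed by a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank tokens from most to least likely.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # top_k == -1 keeps every token; otherwise keep only the k most likely.
    if top_k != -1:
        ranked = ranked[:top_k]

    # top_p keeps the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalise so the surviving probabilities sum to 1.
    norm = sum(probs[i] for i in kept)
    return {i: probs[i] / norm for i in kept}


logits = [2.0, 1.0, 0.0]
print(filter_probs(logits, temperature=0))  # greedy: only token 0 survives
print(filter_probs(logits, top_k=2))        # tokens 0 and 1 survive
print(filter_probs(logits, top_p=0.5))      # token 0 alone already covers 0.5
```

With the defaults (top_k=-1, top_p=1) no token is filtered out, which matches the "consider all tokens" wording in the attribute descriptions.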

check_top_k

check_top_k(v)

Validates a top_k argument.

Ensures the value is either -1 (consider all tokens) or at least 1. A value of None is returned unchanged.

PARAMETER DESCRIPTION
v

Value to validate.

TYPE: int

RAISES DESCRIPTION
ValueError

The value is 0 or less than -1.

RETURNS DESCRIPTION

The top_k value.

Source code in aana/core/models/sampling.py
@field_validator("top_k")
def check_top_k(cls, v: int):
    """Validates a top_k argument.

    Makes sure it is either -1, or at least 1.

    Args:
        v (int): Value to validate.

    Raises:
        ValueError: The value is not valid.

    Returns:
        The top_k value.
    """
    if v is None:
        return v
    if v < -1 or v == 0:
        raise ValueError(f"top_k must be -1 (disable), or at least 1, got {v}.")  # noqa: TRY003
    return v
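
The rule enforced by the validator can be tried outside of Pydantic with a standalone copy of the same check. The sketch below is an assumption-labelled illustration: `validate_top_k` is a hypothetical plain function that mirrors the logic of check_top_k, not an aana API.

```python
def validate_top_k(v):
    """Standalone mirror of the check_top_k rule (hypothetical helper)."""
    if v is None:
        # None is passed through without validation.
        return v
    if v < -1 or v == 0:
        raise ValueError(f"top_k must be -1 (disable), or at least 1, got {v}.")
    return v


# Try a few accepted and rejected candidates.
for candidate in (-1, 1, 50, None, 0, -5):
    try:
        print(candidate, "->", validate_top_k(candidate))
    except ValueError as err:
        print(candidate, "-> rejected:", err)
```

Here -1, 1, 50 and None are accepted as-is, while 0 and -5 raise ValueError. When the check runs as a Pydantic field validator, a ValueError raised this way is reported by Pydantic as a validation error on the top_k field.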