Source code for corrdim.high_level

from __future__ import annotations

from typing import Optional, Tuple

import torch

from .dimension import estimate_dimension_from_curve
from .low_level import ModelLike, curve_from_text, curve_from_texts, progressive_curve_from_text
from .types import CurveResult, DimensionResult, ProgressiveDimensionResult

# Fallback (lower, upper) epsilon bounds handed to the curve-computation calls
# in this module whenever the caller leaves ``epsilon_range`` as ``None``.
DEFAULT_EPSILON_RANGE: Tuple[float, float] = (10**-20.0, 10**20.0)


def _default_measure_every_tokens(sequence_length: int) -> int:
    """Pick the default sampling step for a sequence of the given length.

    Returns ``1`` for lengths below 100, ``10`` for lengths below 1000,
    and ``100`` for anything longer.
    """
    # (exclusive upper bound, step) pairs, checked in ascending order.
    thresholds = ((100, 1), (1000, 10))
    for upper_bound, step in thresholds:
        if sequence_length < upper_bound:
            return step
    return 100


def _truncate_text_by_tokens(
    text: str,
    truncation_tokens: int,
    model: ModelLike,
    tokenizer: Optional[object] = None,
) -> str:
    """Cut ``text`` down to at most ``truncation_tokens`` tokens.

    The tokenizer is resolved in this order: the explicit ``tokenizer``
    argument, then ``model.tokenizer`` if the model has such an attribute,
    then ``AutoTokenizer.from_pretrained(model)`` when ``model`` is a string.

    Returns:
        ``text`` unchanged when it encodes to ``truncation_tokens`` tokens or
        fewer; otherwise the decoded form of the first ``truncation_tokens``
        token ids.

    Raises:
        ValueError: if ``truncation_tokens`` is below 1, or if no tokenizer
            could be resolved.
    """
    if truncation_tokens < 1:
        raise ValueError("truncation_tokens must be a positive integer.")

    resolved = tokenizer
    if resolved is None:
        # Falls back to None when the model exposes no ``tokenizer`` attribute.
        resolved = getattr(model, "tokenizer", None)
    if resolved is None and isinstance(model, str):
        # Imported lazily so transformers is only needed on this path.
        from transformers import AutoTokenizer

        resolved = AutoTokenizer.from_pretrained(model)
    if resolved is None:
        raise ValueError("Cannot truncate by tokens without a tokenizer. Please pass tokenizer explicitly.")

    token_ids = resolved.encode(text, add_special_tokens=False)
    if len(token_ids) > truncation_tokens:
        return resolved.decode(token_ids[:truncation_tokens], skip_special_tokens=False)
    # Short enough already: hand back the original string untouched.
    return text

def measure_text(
    text: str,
    model: ModelLike,
    tokenizer: Optional[object] = None,
    truncation_tokens: Optional[int] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    correlation_integral_range: Optional[Tuple[float, float]] = None,
    epsilon_range: Optional[Tuple[float, float]] = None,
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    precision: torch.dtype = torch.float16,
    backend: Optional[str] = None,
    **model_kwargs,
) -> DimensionResult:
    """Compute a curve for ``text`` and estimate a dimension from it.

    If ``truncation_tokens`` is given, the text is first shortened with
    :func:`_truncate_text_by_tokens`. The curve is then produced by
    :func:`~corrdim.low_level.curve_from_text`, using
    ``DEFAULT_EPSILON_RANGE`` when ``epsilon_range`` is ``None``.

    The estimation step receives ``correlation_integral_range`` and the
    caller's original ``epsilon_range`` (which may be ``None``), not the
    defaulted range used for the curve.

    Remaining keyword arguments, including ``**model_kwargs``, are forwarded
    to ``curve_from_text`` unchanged.
    """
    if truncation_tokens is not None:
        text = _truncate_text_by_tokens(
            text=text,
            truncation_tokens=truncation_tokens,
            model=model,
            tokenizer=tokenizer,
        )

    # The curve always needs a concrete range; the fit below still sees None.
    if epsilon_range is None:
        curve_epsilon_range = DEFAULT_EPSILON_RANGE
    else:
        curve_epsilon_range = epsilon_range

    text_curve = curve_from_text(
        text=text,
        model=model,
        tokenizer=tokenizer,
        context_length=context_length,
        dim_reduction=dim_reduction,
        stride=stride,
        epsilon_range=curve_epsilon_range,
        num_epsilon=num_epsilon,
        block_size=block_size,
        show_progress=show_progress,
        precision=precision,
        backend=backend,
        **model_kwargs,
    )

    return estimate_dimension_from_curve(
        text_curve,
        correlation_integral_range=correlation_integral_range,
        epsilon_range=epsilon_range,
    )
def measure_text_progressive(
    text: str,
    model: ModelLike,
    tokenizer: Optional[object] = None,
    truncation_tokens: Optional[int] = None,
    skip_prefix_tokens: int = 100,
    measure_every_tokens: Optional[int] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    correlation_integral_range: Optional[Tuple[float, float]] = None,
    epsilon_range: Optional[Tuple[float, float]] = None,
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    precision: torch.dtype = torch.float16,
    backend: Optional[str] = None,
    **model_kwargs,
) -> ProgressiveDimensionResult:
    """Run one progressive curve computation, then fit a dimension per sampled prefix.

    A single call to :func:`~corrdim.low_level.progressive_curve_from_text`
    yields the shared ``epsilons`` grid and ``corrints_progressive``. For
    every index ``i`` in ``range(skip_prefix_tokens, sequence_length, step)``
    the row ``corrints_progressive[i]`` is wrapped in a
    :class:`~corrdim.types.CurveResult` and passed to
    ``estimate_dimension_from_curve``. The fitted results are returned in
    :attr:`~corrdim.types.ProgressiveDimensionResult.by_prefix`, keyed by ``i``.

    When ``measure_every_tokens`` is ``None``, ``step`` is derived from
    ``sequence_length``: ``1`` below 100, ``10`` below 1000, else ``100``.

    As in :func:`measure_text`, the curve uses ``DEFAULT_EPSILON_RANGE`` when
    ``epsilon_range`` is ``None``, while the fit receives the caller's
    original ``epsilon_range``. Other arguments are forwarded to
    ``progressive_curve_from_text``.

    Raises:
        ValueError: if ``skip_prefix_tokens`` is negative or
            ``measure_every_tokens`` is given and below 1.
    """
    if skip_prefix_tokens < 0:
        raise ValueError("skip_prefix_tokens must be non-negative.")
    if measure_every_tokens is not None and measure_every_tokens < 1:
        raise ValueError("measure_every_tokens must be a positive integer.")

    if truncation_tokens is not None:
        text = _truncate_text_by_tokens(
            text=text,
            truncation_tokens=truncation_tokens,
            model=model,
            tokenizer=tokenizer,
        )

    # The curve always needs a concrete range; the fits below still see None.
    if epsilon_range is None:
        curve_epsilon_range = DEFAULT_EPSILON_RANGE
    else:
        curve_epsilon_range = epsilon_range

    progressive = progressive_curve_from_text(
        text=text,
        model=model,
        tokenizer=tokenizer,
        context_length=context_length,
        dim_reduction=dim_reduction,
        stride=stride,
        epsilon_range=curve_epsilon_range,
        num_epsilon=num_epsilon,
        block_size=block_size,
        show_progress=show_progress,
        precision=precision,
        backend=backend,
        **model_kwargs,
    )

    if measure_every_tokens is not None:
        step = measure_every_tokens
    else:
        step = _default_measure_every_tokens(progressive.sequence_length)

    # One fit per sampled prefix index, all sharing the same epsilon grid.
    fits: dict[int, DimensionResult] = {
        prefix_index: estimate_dimension_from_curve(
            CurveResult(
                sequence_length=progressive.sequence_length,
                epsilons=progressive.epsilons,
                corrints=progressive.corrints_progressive[prefix_index],
            ),
            correlation_integral_range=correlation_integral_range,
            epsilon_range=epsilon_range,
        )
        for prefix_index in range(skip_prefix_tokens, progressive.sequence_length, step)
    }

    return ProgressiveDimensionResult(
        sequence_length=progressive.sequence_length,
        epsilons=progressive.epsilons,
        skip_prefix_tokens=skip_prefix_tokens,
        measure_every_tokens=step,
        by_prefix=fits,
    )
def measure_texts(
    texts: list[str],
    model: ModelLike,
    tokenizer: Optional[object] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    correlation_integral_range: Optional[Tuple[float, float]] = None,
    epsilon_range: Optional[Tuple[float, float]] = None,
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    precision: torch.dtype = torch.float16,
    backend: Optional[str] = None,
    **model_kwargs,
) -> list[DimensionResult]:
    """Compute curves for several texts and estimate a dimension from each.

    All curves come from one call to
    :func:`~corrdim.low_level.curve_from_texts`, which uses
    ``DEFAULT_EPSILON_RANGE`` when ``epsilon_range`` is ``None``. Each curve
    is then passed to ``estimate_dimension_from_curve`` together with
    ``correlation_integral_range`` and the caller's original
    ``epsilon_range`` (which may be ``None``).

    Returns:
        One result per curve, in the order the curves were returned.
    """
    # The curves always need a concrete range; the fits below still see None.
    if epsilon_range is None:
        curve_epsilon_range = DEFAULT_EPSILON_RANGE
    else:
        curve_epsilon_range = epsilon_range

    all_curves = curve_from_texts(
        texts=texts,
        model=model,
        tokenizer=tokenizer,
        context_length=context_length,
        dim_reduction=dim_reduction,
        stride=stride,
        epsilon_range=curve_epsilon_range,
        num_epsilon=num_epsilon,
        block_size=block_size,
        show_progress=show_progress,
        precision=precision,
        backend=backend,
        **model_kwargs,
    )

    results: list[DimensionResult] = []
    for single_curve in all_curves:
        fitted = estimate_dimension_from_curve(
            single_curve,
            correlation_integral_range=correlation_integral_range,
            epsilon_range=epsilon_range,
        )
        results.append(fitted)
    return results