# Source code for corrdim.low_level

from __future__ import annotations

from collections import OrderedDict
from threading import Lock
from typing import Optional, Tuple, Union

import numpy as np
import torch

from .corrint import correlation_integral, progressive_correlation_integral
from .models import LanguageModelWrapper, create_model_wrapper
from .types import CurveResult, ProgressiveCurveResult
from .utils import clamp

# A model may be supplied either as a name string or as an already-built wrapper.
ModelLike = Union[str, LanguageModelWrapper]
# Upper bound on the number of wrappers kept in _MODEL_CACHE; once exceeded,
# entries are evicted from the front of the ordered dict.
_MODEL_CACHE_MAX_SIZE = 2
# Wrapper cache keyed by the tuple built in _cache_key. Cache hits are moved to
# the end, so the front of the dict holds the least recently used entry.
_MODEL_CACHE: "OrderedDict[tuple, LanguageModelWrapper]" = OrderedDict()
# Held by the functions in this module while they read or modify _MODEL_CACHE.
_MODEL_CACHE_LOCK = Lock()


def _freeze_kwargs(kwargs: dict) -> tuple:
    """Turn a keyword-argument dict into a hashable, order-independent tuple.

    Each entry becomes a ``(str(key), repr(value))`` pair and the pairs are
    sorted, so dicts whose entries stringify identically freeze to the same
    tuple regardless of insertion order.
    """
    frozen_items = [(str(name), repr(val)) for name, val in kwargs.items()]
    frozen_items.sort()
    return tuple(frozen_items)


def _cache_key(
    model_name: str,
    tokenizer: Optional[object],
    device: Optional[str],
    kwargs: dict,
) -> tuple:
    """Build the hashable lookup key used for the model-wrapper cache.

    The key is ``(model_name, tokenizer identity, device, frozen kwargs)``.
    The tokenizer contributes its ``id()`` (or ``None`` when absent), so two
    distinct tokenizer objects always produce different keys.
    """
    if tokenizer is None:
        tokenizer_ident = None
    else:
        # Keyed by object identity, not by value.
        tokenizer_ident = id(tokenizer)
    frozen_kwargs = _freeze_kwargs(kwargs)
    return (model_name, tokenizer_ident, device, frozen_kwargs)


def clear_model_cache() -> None:
    """Remove every wrapper from the module-level model cache.

    The cache lock is held while the cache is emptied.
    """
    with _MODEL_CACHE_LOCK:
        _MODEL_CACHE.clear()


def _resolve_model_wrapper(
    model: ModelLike,
    tokenizer: Optional[object] = None,
    device: Optional[str] = None,
    **kwargs,
) -> LanguageModelWrapper:
    """Return a model wrapper for ``model``, reusing cached wrappers for names.

    Args:
        model: Either a model name (``str``) or an existing wrapper object.
        tokenizer: Forwarded to ``create_model_wrapper`` when a wrapper has to
            be built; also part of the cache key.
        device: Forwarded to ``create_model_wrapper``; also part of the cache key.
        **kwargs: Extra options forwarded to ``create_model_wrapper``; also
            part of the cache key.

    Returns:
        ``model`` itself when it is not a string (``tokenizer``, ``device`` and
        ``kwargs`` are ignored in that case); otherwise the cached or newly
        created wrapper for that name and option set.
    """
    if not isinstance(model, str):
        return model

    key = _cache_key(model, tokenizer, device, kwargs)
    with _MODEL_CACHE_LOCK:
        cached = _MODEL_CACHE.get(key)
        if cached is not None:
            # Mark as most recently used.
            _MODEL_CACHE.move_to_end(key)
            return cached

    # The wrapper is created while the lock is NOT held; if two threads miss on
    # the same key at once, both build a wrapper and the later insert wins.
    wrapper = create_model_wrapper(model, tokenizer=tokenizer, device=device, **kwargs)
    with _MODEL_CACHE_LOCK:
        _MODEL_CACHE[key] = wrapper
        _MODEL_CACHE.move_to_end(key)
        # Evict from the front (least recently used) until within the size bound.
        while len(_MODEL_CACHE) > _MODEL_CACHE_MAX_SIZE:
            _MODEL_CACHE.popitem(last=False)
    return wrapper


def _make_epsilons(vecs: torch.Tensor, epsilon_range: Tuple[float, float], num_epsilon: int) -> torch.Tensor:
    """Validate ``vecs`` and build the log-spaced epsilon grid.

    Args:
        vecs: Tensor whose second-to-last dimension is treated as the sequence
            length.
        epsilon_range: ``(smallest, largest)`` epsilon; both are passed through
            ``log10`` to obtain the exponent range.
        num_epsilon: Number of grid points.

    Returns:
        A 1-D tensor of ``num_epsilon`` values, base-10 log-spaced from
        ``epsilon_range[0]`` to ``epsilon_range[1]``, on ``vecs.device``.

    Raises:
        ValueError: If ``vecs`` contains NaN/inf, or has fewer than 100 entries
            along its second-to-last dimension.
    """
    if not torch.isfinite(vecs).all():
        raise ValueError("Found nan or inf in vectors.")

    sequence_length = vecs.shape[-2]
    # The message promises that 100 tokens are enough, so only lengths strictly
    # below 100 are rejected (the previous `<= 100` also refused exactly 100).
    if sequence_length < 100:
        raise ValueError(
            f"The sequence length is too short ({sequence_length} tokens). Please use at least 100 tokens."
        )

    log_start = np.log10(float(epsilon_range[0]))
    log_end = np.log10(float(epsilon_range[1]))
    return torch.logspace(log_start, log_end, num_epsilon, device=vecs.device)


def curve_from_vectors(
    vectors: torch.Tensor,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    backend: Optional[str] = None,
) -> CurveResult:
    """Compute the correlation-integral curve of one sequence of vectors.

    Args:
        vectors: Input tensor; its second-to-last dimension is reported as the
            sequence length.
        epsilon_range: Lower and upper end of the epsilon grid.
        num_epsilon: Number of points in the epsilon grid.
        block_size: Forwarded to ``correlation_integral``.
        show_progress: Forwarded to ``correlation_integral``.
        backend: Forwarded to ``correlation_integral``.

    Returns:
        A ``CurveResult`` holding the sequence length and the epsilon /
        correlation-integral arrays as NumPy arrays on the CPU.

    Raises:
        ValueError: Propagated from ``_make_epsilons`` when ``vectors`` fails
            its validation.
    """
    eps_grid = _make_epsilons(vectors, epsilon_range, num_epsilon)
    corr_values = correlation_integral(
        vectors,
        eps_grid,
        block_size=block_size,
        show_progress=show_progress,
        backend=backend,
    )

    # NOTE(review): `clamp` is a project helper; presumably it restricts the
    # curve to points whose correlation integral lies in [low, high] — confirm.
    lowest = float(corr_values.min())
    trimmed_eps, trimmed_corr = clamp(eps_grid, corr_values, low=lowest, high=0.95)
    # Fall back to the untrimmed curve when trimming leaves fewer than two points.
    if len(trimmed_eps) >= 2:
        eps_grid, corr_values = trimmed_eps, trimmed_corr

    n_tokens = int(vectors.shape[-2])
    return CurveResult(
        sequence_length=n_tokens,
        epsilons=eps_grid.detach().cpu().numpy(),
        corrints=corr_values.detach().cpu().numpy(),
    )


def curve_from_vectors_batch(
    vectors_batch: torch.Tensor,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    backend: Optional[str] = None,
) -> list[CurveResult]:
    """Compute one correlation-integral curve per item of a 3-D batch.

    Args:
        vectors_batch: Tensor with exactly three dimensions, ``(B, M, K)``;
            each ``vectors_batch[i]`` is processed independently.
        epsilon_range: Forwarded to ``curve_from_vectors``.
        num_epsilon: Forwarded to ``curve_from_vectors``.
        block_size: Forwarded to ``curve_from_vectors``.
        show_progress: Forwarded to ``curve_from_vectors``.
        backend: Forwarded to ``curve_from_vectors``.

    Returns:
        A list with one ``CurveResult`` per batch item, in batch order.

    Raises:
        ValueError: If ``vectors_batch`` is not 3-dimensional.
    """
    if vectors_batch.dim() != 3:
        raise ValueError("vectors_batch must have shape (B, M, K).")

    # Options shared by every per-item call.
    shared_options = dict(
        epsilon_range=epsilon_range,
        num_epsilon=num_epsilon,
        block_size=block_size,
        show_progress=show_progress,
        backend=backend,
    )
    batch_size = vectors_batch.shape[0]
    return [curve_from_vectors(vectors_batch[item], **shared_options) for item in range(batch_size)]


def progressive_curve_from_vectors(
    vectors: torch.Tensor,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    backend: Optional[str] = None,
) -> ProgressiveCurveResult:
    """Compute the progressive correlation-integral curve of one sequence.

    Unlike ``curve_from_vectors`` no clamping is applied: the full epsilon
    grid is returned alongside the output of
    ``progressive_correlation_integral``.

    Args:
        vectors: Input tensor; its second-to-last dimension is reported as the
            sequence length.
        epsilon_range: Lower and upper end of the epsilon grid.
        num_epsilon: Number of points in the epsilon grid.
        block_size: Forwarded to ``progressive_correlation_integral``.
        show_progress: Forwarded to ``progressive_correlation_integral``.
        backend: Forwarded to ``progressive_correlation_integral``.

    Returns:
        A ``ProgressiveCurveResult`` with NumPy arrays on the CPU.

    Raises:
        ValueError: Propagated from ``_make_epsilons`` when ``vectors`` fails
            its validation.
    """
    eps_grid = _make_epsilons(vectors, epsilon_range, num_epsilon)
    progressive_values = progressive_correlation_integral(
        vectors,
        eps_grid,
        block_size=block_size,
        show_progress=show_progress,
        backend=backend,
    )
    n_tokens = int(vectors.shape[-2])
    return ProgressiveCurveResult(
        sequence_length=n_tokens,
        epsilons=eps_grid.detach().cpu().numpy(),
        corrints_progressive=progressive_values.detach().cpu().numpy(),
    )


def progressive_curve_from_vectors_batch(
    vectors_batch: torch.Tensor,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    backend: Optional[str] = None,
) -> list[ProgressiveCurveResult]:
    """Compute one progressive curve per item of a 3-D batch.

    Args:
        vectors_batch: Tensor with exactly three dimensions, ``(B, M, K)``;
            each ``vectors_batch[i]`` is processed independently.
        epsilon_range: Forwarded to ``progressive_curve_from_vectors``.
        num_epsilon: Forwarded to ``progressive_curve_from_vectors``.
        block_size: Forwarded to ``progressive_curve_from_vectors``.
        show_progress: Forwarded to ``progressive_curve_from_vectors``.
        backend: Forwarded to ``progressive_curve_from_vectors``.

    Returns:
        A list with one ``ProgressiveCurveResult`` per batch item, in batch
        order.

    Raises:
        ValueError: If ``vectors_batch`` is not 3-dimensional.
    """
    if vectors_batch.dim() != 3:
        raise ValueError("vectors_batch must have shape (B, M, K).")

    # Options shared by every per-item call.
    shared_options = dict(
        epsilon_range=epsilon_range,
        num_epsilon=num_epsilon,
        block_size=block_size,
        show_progress=show_progress,
        backend=backend,
    )
    batch_size = vectors_batch.shape[0]
    return [
        progressive_curve_from_vectors(vectors_batch[item], **shared_options)
        for item in range(batch_size)
    ]


def _text_to_vectors(
    model_wrapper: LanguageModelWrapper,
    text: str,
    dim_reduction: Optional[int] = 8192,
    context_length: Optional[int] = None,
    stride: int = 1,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Run ``text`` through the wrapper and return its output cast to ``precision``.

    Args:
        model_wrapper: Object providing ``get_log_probabilities``.
        text: Passed positionally to ``get_log_probabilities``.
        dim_reduction: Forwarded to ``get_log_probabilities``.
        context_length: Forwarded to ``get_log_probabilities``.
        stride: Forwarded to ``get_log_probabilities``.
        show_progress: Forwarded to ``get_log_probabilities``.
        precision: dtype the returned tensor is converted to.

    Returns:
        The tensor returned by ``get_log_probabilities``, converted with
        ``Tensor.type(precision)``.
    """
    log_probs = model_wrapper.get_log_probabilities(
        text,
        dim_reduction=dim_reduction,
        context_length=context_length,
        stride=stride,
        show_progress=show_progress,
    )
    return log_probs.type(precision)


def curve_from_text(
    text: str,
    model: ModelLike,
    tokenizer: Optional[object] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
    backend: Optional[str] = None,
    **model_kwargs,
) -> CurveResult:
    """Compute the correlation-integral curve of a single text.

    The text is turned into vectors by the resolved model wrapper, then
    handed to ``curve_from_vectors``.

    Args:
        text: Input text.
        model: Model name or existing wrapper; resolved through
            ``_resolve_model_wrapper``.
        tokenizer: Forwarded to ``_resolve_model_wrapper``.
        context_length: Forwarded to ``_text_to_vectors``.
        dim_reduction: Forwarded to ``_text_to_vectors``.
        stride: Forwarded to ``_text_to_vectors``.
        epsilon_range: Forwarded to ``curve_from_vectors``.
        num_epsilon: Forwarded to ``curve_from_vectors``.
        block_size: Forwarded to ``curve_from_vectors``.
        show_progress: Forwarded to both the vector extraction and the curve
            computation.
        precision: dtype of the vectors passed to the curve computation.
        backend: Forwarded to ``curve_from_vectors``.
        **model_kwargs: Extra options forwarded to ``_resolve_model_wrapper``.

    Returns:
        The ``CurveResult`` produced by ``curve_from_vectors``.
    """
    wrapper = _resolve_model_wrapper(model, tokenizer=tokenizer, **model_kwargs)
    text_vectors = _text_to_vectors(
        wrapper,
        text=text,
        dim_reduction=dim_reduction,
        context_length=context_length,
        stride=stride,
        show_progress=show_progress,
        precision=precision,
    )
    return curve_from_vectors(
        vectors=text_vectors,
        epsilon_range=epsilon_range,
        num_epsilon=num_epsilon,
        block_size=block_size,
        show_progress=show_progress,
        backend=backend,
    )


def curve_from_texts(
    texts: list[str],
    model: ModelLike,
    tokenizer: Optional[object] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
    backend: Optional[str] = None,
    **model_kwargs,
) -> list[CurveResult]:
    """Compute a correlation-integral curve for each text in ``texts``.

    The model wrapper is resolved once and then reused for every text.

    Args:
        texts: Input texts, processed in order.
        model: Model name or existing wrapper; resolved through
            ``_resolve_model_wrapper``.
        tokenizer: Forwarded to ``_resolve_model_wrapper``.
        context_length: Forwarded to ``curve_from_text``.
        dim_reduction: Forwarded to ``curve_from_text``.
        stride: Forwarded to ``curve_from_text``.
        epsilon_range: Forwarded to ``curve_from_text``.
        num_epsilon: Forwarded to ``curve_from_text``.
        block_size: Forwarded to ``curve_from_text``.
        show_progress: Forwarded to ``curve_from_text``.
        precision: Forwarded to ``curve_from_text``.
        backend: Forwarded to ``curve_from_text``.
        **model_kwargs: Extra options forwarded to ``_resolve_model_wrapper``.

    Returns:
        One ``CurveResult`` per input text, in input order.
    """
    wrapper = _resolve_model_wrapper(model, tokenizer=tokenizer, **model_kwargs)

    # Options shared by every per-text call; the already-resolved wrapper is
    # passed as ``model`` so nothing is resolved again per text.
    shared_options = dict(
        model=wrapper,
        context_length=context_length,
        dim_reduction=dim_reduction,
        stride=stride,
        epsilon_range=epsilon_range,
        num_epsilon=num_epsilon,
        block_size=block_size,
        show_progress=show_progress,
        precision=precision,
        backend=backend,
    )
    return [curve_from_text(single_text, **shared_options) for single_text in texts]


def progressive_curve_from_text(
    text: str,
    model: ModelLike,
    tokenizer: Optional[object] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
    backend: Optional[str] = None,
    **model_kwargs,
) -> ProgressiveCurveResult:
    """Compute the progressive correlation-integral curve of a single text.

    The text is turned into vectors by the resolved model wrapper, then
    handed to ``progressive_curve_from_vectors``.

    Args:
        text: Input text.
        model: Model name or existing wrapper; resolved through
            ``_resolve_model_wrapper``.
        tokenizer: Forwarded to ``_resolve_model_wrapper``.
        context_length: Forwarded to ``_text_to_vectors``.
        dim_reduction: Forwarded to ``_text_to_vectors``.
        stride: Forwarded to ``_text_to_vectors``.
        epsilon_range: Forwarded to ``progressive_curve_from_vectors``.
        num_epsilon: Forwarded to ``progressive_curve_from_vectors``.
        block_size: Forwarded to ``progressive_curve_from_vectors``.
        show_progress: Forwarded to both the vector extraction and the curve
            computation.
        precision: dtype of the vectors passed to the curve computation.
        backend: Forwarded to ``progressive_curve_from_vectors``.
        **model_kwargs: Extra options forwarded to ``_resolve_model_wrapper``.

    Returns:
        The ``ProgressiveCurveResult`` produced by
        ``progressive_curve_from_vectors``.
    """
    wrapper = _resolve_model_wrapper(model, tokenizer=tokenizer, **model_kwargs)
    text_vectors = _text_to_vectors(
        wrapper,
        text=text,
        dim_reduction=dim_reduction,
        context_length=context_length,
        stride=stride,
        show_progress=show_progress,
        precision=precision,
    )
    return progressive_curve_from_vectors(
        vectors=text_vectors,
        epsilon_range=epsilon_range,
        num_epsilon=num_epsilon,
        block_size=block_size,
        show_progress=show_progress,
        backend=backend,
    )


def progressive_curve_from_texts(
    texts: list[str],
    model: ModelLike,
    tokenizer: Optional[object] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
    backend: Optional[str] = None,
    **model_kwargs,
) -> list[ProgressiveCurveResult]:
    """Compute a progressive correlation-integral curve for each text.

    The model wrapper is resolved once and then reused for every text.

    Args:
        texts: Input texts, processed in order.
        model: Model name or existing wrapper; resolved through
            ``_resolve_model_wrapper``.
        tokenizer: Forwarded to ``_resolve_model_wrapper``.
        context_length: Forwarded to ``progressive_curve_from_text``.
        dim_reduction: Forwarded to ``progressive_curve_from_text``.
        stride: Forwarded to ``progressive_curve_from_text``.
        epsilon_range: Forwarded to ``progressive_curve_from_text``.
        num_epsilon: Forwarded to ``progressive_curve_from_text``.
        block_size: Forwarded to ``progressive_curve_from_text``.
        show_progress: Forwarded to ``progressive_curve_from_text``.
        precision: Forwarded to ``progressive_curve_from_text``.
        backend: Forwarded to ``progressive_curve_from_text``.
        **model_kwargs: Extra options forwarded to ``_resolve_model_wrapper``.

    Returns:
        One ``ProgressiveCurveResult`` per input text, in input order.
    """
    wrapper = _resolve_model_wrapper(model, tokenizer=tokenizer, **model_kwargs)

    # Options shared by every per-text call; the already-resolved wrapper is
    # passed as ``model`` so nothing is resolved again per text.
    shared_options = dict(
        model=wrapper,
        context_length=context_length,
        dim_reduction=dim_reduction,
        stride=stride,
        epsilon_range=epsilon_range,
        num_epsilon=num_epsilon,
        block_size=block_size,
        show_progress=show_progress,
        precision=precision,
        backend=backend,
    )
    return [progressive_curve_from_text(single_text, **shared_options) for single_text in texts]