from __future__ import annotations
from collections import OrderedDict
from threading import Lock
from typing import Optional, Tuple, Union
import numpy as np
import torch
from .corrint import correlation_integral, progressive_correlation_integral
from .models import LanguageModelWrapper, create_model_wrapper
from .types import CurveResult, ProgressiveCurveResult
from .utils import clamp
# Anything the public API accepts as a model: a model name/ID or a ready wrapper.
ModelLike = Union[str, LanguageModelWrapper]

# LRU cache of loaded model wrappers, keyed by ``_cache_key(...)``.
# Entries are refreshed with ``move_to_end`` and evicted from the front
# (``popitem(last=False)``), so the least-recently-used wrapper goes first.
_MODEL_CACHE_MAX_SIZE: int = 2
_MODEL_CACHE: "OrderedDict[tuple, LanguageModelWrapper]" = OrderedDict()
_MODEL_CACHE_LOCK = Lock()

# LRU cache of tokenizers keyed by model name, with the same eviction policy.
_TOKENIZER_CACHE_MAX_SIZE: int = 4
_TOKENIZER_CACHE: "OrderedDict[str, object]" = OrderedDict()
_TOKENIZER_CACHE_LOCK = Lock()
def _freeze_kwargs(kwargs: dict) -> tuple:
    """Turn *kwargs* into a hashable tuple that ignores insertion order.

    Every item becomes ``(str(key), repr(value))`` and the pairs are sorted,
    so two dicts holding the same items freeze to the same tuple. Values are
    compared only through their ``repr``.
    """
    frozen_pairs = [(str(name), repr(val)) for name, val in kwargs.items()]
    frozen_pairs.sort()
    return tuple(frozen_pairs)
def _cache_key(model_name: str, tokenizer: Optional[object], device: Optional[str], kwargs: dict) -> tuple:
    """Build the model-cache key for one set of loader arguments.

    The key is ``(model_name, tokenizer_id, device, frozen_kwargs)``. A
    tokenizer takes part by identity (``id``); ``None`` stays ``None``.

    NOTE: ``id`` values can be reused once an object is garbage-collected, so
    the key is only unique while the tokenizer object is alive.
    """
    if tokenizer is None:
        tokenizer_id = None
    else:
        tokenizer_id = id(tokenizer)
    frozen = _freeze_kwargs(kwargs)
    return model_name, tokenizer_id, device, frozen
[docs]
def clear_model_cache() -> None:
    """Drop every cached model wrapper and every cached tokenizer.

    Each cache is cleared while holding its own lock. The model cache is
    cleared first, then the tokenizer cache; the two locks are taken one
    after the other, never together.
    """
    for cache, guard in (
        (_MODEL_CACHE, _MODEL_CACHE_LOCK),
        (_TOKENIZER_CACHE, _TOKENIZER_CACHE_LOCK),
    ):
        with guard:
            cache.clear()
def _get_or_load_tokenizer(model_name: str) -> object:
    """Return a tokenizer for *model_name*, loading it only when necessary.

    Lookup order:

    1. the ``tokenizer`` attribute of any cached model wrapper whose cache
       key names *model_name* (wrappers whose ``tokenizer`` is missing or
       ``None`` are skipped);
    2. the tokenizer-only LRU cache;
    3. ``transformers.AutoTokenizer.from_pretrained(model_name)``. The result
       is stored in the tokenizer-only cache, and least-recently-used entries
       beyond ``_TOKENIZER_CACHE_MAX_SIZE`` are evicted.
    """
    with _MODEL_CACHE_LOCK:
        for key, wrapper in _MODEL_CACHE.items():
            if key[0] != model_name:
                continue
            # ``hasattr`` alone would accept a ``tokenizer`` attribute that is
            # None and hand None back to the caller; require a real object.
            embedded = getattr(wrapper, "tokenizer", None)
            if embedded is not None:
                return embedded

    with _TOKENIZER_CACHE_LOCK:
        tok = _TOKENIZER_CACHE.get(model_name)
        if tok is not None:
            _TOKENIZER_CACHE.move_to_end(model_name)
            return tok

    # Imported here rather than at module level, so transformers is only
    # touched on the slow path that actually loads a tokenizer.
    from transformers import AutoTokenizer

    # Loaded outside the lock so a slow download does not block other lookups.
    # Concurrent callers for the same name may both load; the last write wins.
    tok = AutoTokenizer.from_pretrained(model_name)
    with _TOKENIZER_CACHE_LOCK:
        _TOKENIZER_CACHE[model_name] = tok
        _TOKENIZER_CACHE.move_to_end(model_name)
        while len(_TOKENIZER_CACHE) > _TOKENIZER_CACHE_MAX_SIZE:
            _TOKENIZER_CACHE.popitem(last=False)
    return tok
def _resolve_model_wrapper(
    model: ModelLike,
    tokenizer: Optional[object] = None,
    device: Optional[str] = None,
    forward_chunk_size: Optional[int] = None,
    **kwargs,
) -> LanguageModelWrapper:
    """Turn *model* into a ``LanguageModelWrapper``, reusing cached loads.

    Args:
        model: A model name/ID (loaded through ``create_model_wrapper`` and
            cached), or an existing wrapper, which is returned as-is.
        tokenizer: Forwarded to ``create_model_wrapper``. It is part of the
            cache key by identity. Ignored for wrapper instances.
        device: Forwarded to ``create_model_wrapper`` and part of the cache
            key. Ignored for wrapper instances.
        forward_chunk_size: If not None, assigned to the returned wrapper's
            ``forward_chunk_size`` attribute. This mutates the wrapper in
            place, which may be a cached instance shared with other callers
            or the caller's own wrapper object. The value persists for later
            calls that pass None.
        **kwargs: Forwarded to ``create_model_wrapper`` and part of the cache
            key (through their ``repr``). Ignored for wrapper instances.

    Returns:
        The resolved wrapper.
    """
    if isinstance(model, str):
        wrapper = _get_or_create_cached_wrapper(model, tokenizer, device, kwargs)
    else:
        wrapper = model
    if forward_chunk_size is not None:
        wrapper.forward_chunk_size = forward_chunk_size
    return wrapper


def _get_or_create_cached_wrapper(
    model_name: str,
    tokenizer: Optional[object],
    device: Optional[str],
    kwargs: dict,
) -> LanguageModelWrapper:
    """Return the cached wrapper for these loader arguments, creating it on a miss."""
    key = _cache_key(model_name, tokenizer, device, kwargs)
    with _MODEL_CACHE_LOCK:
        cached = _MODEL_CACHE.get(key)
        if cached is not None:
            _MODEL_CACHE.move_to_end(key)
            return cached

    # Loaded outside the lock so a slow load does not block cache hits for
    # other keys. Two threads that miss on the same key may both get here.
    wrapper = create_model_wrapper(model_name, tokenizer=tokenizer, device=device, **kwargs)

    with _MODEL_CACHE_LOCK:
        existing = _MODEL_CACHE.get(key)
        if existing is not None:
            # Another thread cached the same key first. Share its instance
            # instead of overwriting it, so only one copy stays referenced by
            # the cache and its users.
            _MODEL_CACHE.move_to_end(key)
            return existing
        _MODEL_CACHE[key] = wrapper  # new keys are appended at the MRU end
        while len(_MODEL_CACHE) > _MODEL_CACHE_MAX_SIZE:
            _MODEL_CACHE.popitem(last=False)

    # Mirror the wrapper's tokenizer into the tokenizer-only cache so
    # _get_or_load_tokenizer can reuse it; skip a missing or None tokenizer.
    embedded = getattr(wrapper, "tokenizer", None)
    if embedded is not None:
        with _TOKENIZER_CACHE_LOCK:
            _TOKENIZER_CACHE[model_name] = embedded
            _TOKENIZER_CACHE.move_to_end(model_name)
            while len(_TOKENIZER_CACHE) > _TOKENIZER_CACHE_MAX_SIZE:
                _TOKENIZER_CACHE.popitem(last=False)
    return wrapper
def _make_epsilons(vecs: torch.Tensor, epsilon_range: Tuple[float, float], num_epsilon: int) -> torch.Tensor:
    """Validate *vecs* and build ``num_epsilon`` log-spaced radii.

    Args:
        vecs: Tensor with at least two dimensions. The second-to-last
            dimension is treated as the sequence length.
        epsilon_range: ``(low, high)`` endpoints of the radii. Both must be
            positive and finite, because they are converted with ``log10``.
        num_epsilon: Number of radii passed to ``torch.logspace``.

    Returns:
        A tensor on ``vecs.device`` holding ``num_epsilon`` values spaced
        evenly in log10 from ``low`` to ``high``.

    Raises:
        ValueError: If *vecs* has fewer than two dimensions, contains nan or
            inf, or has fewer than 100 rows along dim -2; or if
            *epsilon_range* is not positive and finite.
    """
    min_tokens = 100
    if vecs.dim() < 2:
        raise ValueError(
            f"Expected vectors with at least 2 dimensions (..., M, K); got shape {tuple(vecs.shape)}."
        )
    if not torch.isfinite(vecs).all():
        raise ValueError("Found nan or inf in vectors.")
    seq_len = vecs.shape[-2]
    # The error message promises "at least 100 tokens", so exactly 100 is accepted.
    if seq_len < min_tokens:
        raise ValueError(
            f"The sequence length is too short ({seq_len} tokens). Please use at least {min_tokens} tokens."
        )
    low, high = float(epsilon_range[0]), float(epsilon_range[1])
    # log10 of a non-positive or non-finite bound would silently yield -inf/nan radii.
    if not (np.isfinite(low) and np.isfinite(high) and low > 0 and high > 0):
        raise ValueError(f"epsilon_range must contain two positive finite values; got {epsilon_range!r}.")
    # torch.logspace is not supported on MPS; build on CPU, then move.
    device = vecs.device
    target = device if device.type != "mps" else "cpu"
    eps = torch.logspace(np.log10(low), np.log10(high), num_epsilon, device=target)
    return eps.to(device)
[docs]
def curve_from_vectors(
    vectors: torch.Tensor,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    backend: Optional[str] = None,
) -> CurveResult:
    """Compute the correlation-integral curve of one sequence of vectors.

    Radii come from ``_make_epsilons``. The curve from
    ``correlation_integral`` is passed through ``clamp`` with
    ``low=min(corrints)`` and ``high=0.95``. The clamped curve is used only
    if at least two points remain; otherwise the full curve is kept.

    Args:
        vectors: Sequence vectors; dim -2 is reported as ``sequence_length``.
        epsilon_range: ``(low, high)`` bounds of the log-spaced radii.
        num_epsilon: Number of radii.
        block_size: Forwarded to ``correlation_integral``.
        show_progress: Forwarded to ``correlation_integral``.
        backend: Forwarded to ``correlation_integral``.

    Returns:
        A ``CurveResult`` holding NumPy copies of the radii and curve values.

    Raises:
        ValueError: Propagated from ``_make_epsilons`` for invalid vectors.
    """
    def as_numpy(t: torch.Tensor):
        return t.detach().cpu().numpy()

    radii = _make_epsilons(vectors, epsilon_range, num_epsilon)
    curve = correlation_integral(
        vectors,
        radii,
        block_size=block_size,
        show_progress=show_progress,
        backend=backend,
    )
    trimmed_radii, trimmed_curve = clamp(radii, curve, low=float(curve.min()), high=0.95)
    if len(trimmed_radii) >= 2:
        radii, curve = trimmed_radii, trimmed_curve
    return CurveResult(
        sequence_length=int(vectors.shape[-2]),
        epsilons=as_numpy(radii),
        corrints=as_numpy(curve),
    )
[docs]
def curve_from_vectors_batch(
    vectors_batch: torch.Tensor,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    backend: Optional[str] = None,
) -> list[CurveResult]:
    """Apply :func:`curve_from_vectors` to every item along dim 0 of a 3-D batch.

    Raises:
        ValueError: If *vectors_batch* is not 3-dimensional ``(B, M, K)``.
    """
    if vectors_batch.dim() != 3:
        raise ValueError("vectors_batch must have shape (B, M, K).")
    options = {
        "epsilon_range": epsilon_range,
        "num_epsilon": num_epsilon,
        "block_size": block_size,
        "show_progress": show_progress,
        "backend": backend,
    }
    return [curve_from_vectors(sample, **options) for sample in vectors_batch.unbind(0)]
[docs]
def progressive_curve_from_vectors(
    vectors: torch.Tensor,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    backend: Optional[str] = None,
) -> ProgressiveCurveResult:
    """Compute progressive correlation integrals for one sequence of vectors.

    Unlike :func:`curve_from_vectors`, no clamping is applied: the radii and
    the output of ``progressive_correlation_integral`` are returned as-is,
    converted to NumPy.

    Raises:
        ValueError: Propagated from ``_make_epsilons`` for invalid vectors.
    """
    def as_numpy(t: torch.Tensor):
        return t.detach().cpu().numpy()

    radii = _make_epsilons(vectors, epsilon_range, num_epsilon)
    progressive = progressive_correlation_integral(
        vectors,
        radii,
        block_size=block_size,
        show_progress=show_progress,
        backend=backend,
    )
    return ProgressiveCurveResult(
        sequence_length=int(vectors.shape[-2]),
        epsilons=as_numpy(radii),
        corrints_progressive=as_numpy(progressive),
    )
[docs]
def progressive_curve_from_vectors_batch(
    vectors_batch: torch.Tensor,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    backend: Optional[str] = None,
) -> list[ProgressiveCurveResult]:
    """Apply :func:`progressive_curve_from_vectors` to every item along dim 0.

    Raises:
        ValueError: If *vectors_batch* is not 3-dimensional ``(B, M, K)``.
    """
    if vectors_batch.dim() != 3:
        raise ValueError("vectors_batch must have shape (B, M, K).")
    options = {
        "epsilon_range": epsilon_range,
        "num_epsilon": num_epsilon,
        "block_size": block_size,
        "show_progress": show_progress,
        "backend": backend,
    }
    return [progressive_curve_from_vectors(sample, **options) for sample in vectors_batch.unbind(0)]
def _text_to_vectors(
    model_wrapper: LanguageModelWrapper,
    text: str,
    dim_reduction: Optional[int] = 8192,
    context_length: Optional[int] = None,
    stride: int = 1,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Run ``model_wrapper.get_log_probabilities`` on *text* and cast to *precision*.

    All extraction options are forwarded by keyword. The wrapper's result is
    converted with ``Tensor.type(precision)``.
    """
    log_probs = model_wrapper.get_log_probabilities(
        text,
        dim_reduction=dim_reduction,
        context_length=context_length,
        stride=stride,
        show_progress=show_progress,
    )
    return log_probs.type(precision)
[docs]
def text_to_vectors(
    text: str,
    model: ModelLike,
    tokenizer: Optional[object] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
    forward_chunk_size: Optional[int] = None,
    **model_kwargs,
) -> torch.Tensor:
    """Extract log-probability vectors from *text* using *model*.

    This is the public entry point for vector extraction; the returned tensor
    has shape ``(sampled_seq_len, reduced_vocab_size)`` and can be passed
    directly to :func:`curve_from_vectors` or :func:`progressive_curve_from_vectors`.

    Args:
        text: Input text.
        model: HuggingFace model name/ID (``str``) or a pre-built
            :class:`~corrdim.models.LanguageModelWrapper` instance. A name is
            loaded once and kept in a small module-level cache.
        tokenizer: Tokenizer instance (only used when *model* is a string).
        context_length: Maximum context length for the model.
        dim_reduction: Vocabulary grouping size for dimensionality reduction.
        stride: Keep every *stride*-th token vector.
        show_progress: Show a progress bar during inference.
        precision: Output tensor dtype.
        forward_chunk_size: Number of tokens per forward-pass chunk.
            Reduce this value (e.g. 128) on systems with limited VRAM.
            When given, it is assigned to the resolved wrapper's
            ``forward_chunk_size`` attribute for both kinds of *model*. For a
            string this is the (possibly cached and shared) loaded wrapper;
            for a wrapper instance it is set on that instance in place. The
            value persists on the wrapper for later calls that pass None.
        **model_kwargs: Extra keyword arguments forwarded to the model loader
            when *model* is a string (for example ``device``). They are part
            of the model cache key and are ignored for wrapper instances.
    """
    model_wrapper = _resolve_model_wrapper(
        model, tokenizer=tokenizer, forward_chunk_size=forward_chunk_size, **model_kwargs
    )
    return _text_to_vectors(
        model_wrapper,
        text=text,
        dim_reduction=dim_reduction,
        context_length=context_length,
        stride=stride,
        show_progress=show_progress,
        precision=precision,
    )
def _texts_to_vectors_batched(
    model_wrapper: LanguageModelWrapper,
    texts: list[str],
    dim_reduction: Optional[int] = 8192,
    context_length: Optional[int] = None,
    stride: int = 1,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
    batch_size: Optional[int] = None,
) -> list[torch.Tensor]:
    """Extract one log-probability tensor per text, cast to *precision*.

    Prefer :meth:`~corrdim.models.TransformersModelWrapper.get_log_probabilities_batch`
    when the wrapper provides it (it receives *batch_size*). Otherwise fall
    back to calling ``get_log_probabilities`` once per text, in order.

    Raises:
        ValueError: If *batch_size* is given and smaller than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be a positive integer or None.")
    extraction_kwargs = {
        "dim_reduction": dim_reduction,
        "context_length": context_length,
        "stride": stride,
        "show_progress": show_progress,
    }
    batch_fn = getattr(model_wrapper, "get_log_probabilities_batch", None)
    if batch_fn is not None:
        raw = batch_fn(texts, batch_size=batch_size, **extraction_kwargs)
        return [v.type(precision) for v in raw]
    # The per-text API processes one text at a time, so chunking by batch_size
    # (as the earlier nested loop did) changes nothing; iterate directly.
    return [
        model_wrapper.get_log_probabilities(text, **extraction_kwargs).type(precision)
        for text in texts
    ]
[docs]
def curve_from_text(
    text: str,
    model: ModelLike,
    tokenizer: Optional[object] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
    backend: Optional[str] = None,
    forward_chunk_size: Optional[int] = None,
    **model_kwargs,
) -> CurveResult:
    """Extract vectors from *text* and compute their correlation-integral curve.

    This composes :func:`text_to_vectors` (model resolution and extraction,
    using *model*, *tokenizer*, *context_length*, *dim_reduction*, *stride*,
    *precision*, *forward_chunk_size* and *model_kwargs*) with
    :func:`curve_from_vectors` (using *epsilon_range*, *num_epsilon*,
    *block_size* and *backend*). *show_progress* is forwarded to both stages.
    """
    vectors = text_to_vectors(
        text,
        model,
        tokenizer=tokenizer,
        context_length=context_length,
        dim_reduction=dim_reduction,
        stride=stride,
        show_progress=show_progress,
        precision=precision,
        forward_chunk_size=forward_chunk_size,
        **model_kwargs,
    )
    return curve_from_vectors(
        vectors=vectors,
        epsilon_range=epsilon_range,
        num_epsilon=num_epsilon,
        block_size=block_size,
        show_progress=show_progress,
        backend=backend,
    )
[docs]
def curve_from_texts(
    texts: list[str],
    model: ModelLike,
    tokenizer: Optional[object] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
    backend: Optional[str] = None,
    batch_size: Optional[int] = None,
    forward_chunk_size: Optional[int] = None,
    **model_kwargs,
) -> list[CurveResult]:
    """Compute one correlation-integral curve per text, resolving the model once.

    Vectors for all *texts* are extracted first, through
    ``_texts_to_vectors_batched``, which receives *batch_size*. Each set of
    vectors is then passed to :func:`curve_from_vectors`. Results follow
    the order of *texts*.
    """
    wrapper = _resolve_model_wrapper(
        model, tokenizer=tokenizer, forward_chunk_size=forward_chunk_size, **model_kwargs
    )
    per_text_vectors = _texts_to_vectors_batched(
        wrapper,
        texts,
        dim_reduction=dim_reduction,
        context_length=context_length,
        stride=stride,
        show_progress=show_progress,
        precision=precision,
        batch_size=batch_size,
    )
    curve_options = {
        "epsilon_range": epsilon_range,
        "num_epsilon": num_epsilon,
        "block_size": block_size,
        "show_progress": show_progress,
        "backend": backend,
    }
    return [curve_from_vectors(vectors=vecs, **curve_options) for vecs in per_text_vectors]
[docs]
def progressive_curve_from_text(
    text: str,
    model: ModelLike,
    tokenizer: Optional[object] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
    backend: Optional[str] = None,
    forward_chunk_size: Optional[int] = None,
    **model_kwargs,
) -> ProgressiveCurveResult:
    """Extract vectors from *text* and compute progressive correlation integrals.

    This composes :func:`text_to_vectors` with
    :func:`progressive_curve_from_vectors`. *show_progress* is forwarded to
    both stages.
    """
    vectors = text_to_vectors(
        text,
        model,
        tokenizer=tokenizer,
        context_length=context_length,
        dim_reduction=dim_reduction,
        stride=stride,
        show_progress=show_progress,
        precision=precision,
        forward_chunk_size=forward_chunk_size,
        **model_kwargs,
    )
    return progressive_curve_from_vectors(
        vectors=vectors,
        epsilon_range=epsilon_range,
        num_epsilon=num_epsilon,
        block_size=block_size,
        show_progress=show_progress,
        backend=backend,
    )
[docs]
def progressive_curve_from_texts(
    texts: list[str],
    model: ModelLike,
    tokenizer: Optional[object] = None,
    context_length: Optional[int] = None,
    dim_reduction: Optional[int] = 8192,
    stride: int = 1,
    epsilon_range: Tuple[float, float] = (10**-20.0, 10**20.0),
    num_epsilon: int = 1024,
    block_size: int = 512,
    show_progress: bool = False,
    precision: torch.dtype = torch.float32,
    backend: Optional[str] = None,
    batch_size: Optional[int] = None,
    forward_chunk_size: Optional[int] = None,
    **model_kwargs,
) -> list[ProgressiveCurveResult]:
    """Compute progressive correlation integrals per text, resolving the model once.

    Vectors for all *texts* are extracted first, through
    ``_texts_to_vectors_batched`` with *batch_size*. Each set of vectors is
    then passed to :func:`progressive_curve_from_vectors`. Results follow the
    order of *texts*.
    """
    wrapper = _resolve_model_wrapper(
        model, tokenizer=tokenizer, forward_chunk_size=forward_chunk_size, **model_kwargs
    )
    per_text_vectors = _texts_to_vectors_batched(
        wrapper,
        texts,
        dim_reduction=dim_reduction,
        context_length=context_length,
        stride=stride,
        show_progress=show_progress,
        precision=precision,
        batch_size=batch_size,
    )
    curve_options = {
        "epsilon_range": epsilon_range,
        "num_epsilon": num_epsilon,
        "block_size": block_size,
        "show_progress": show_progress,
        "backend": backend,
    }
    return [progressive_curve_from_vectors(vectors=vecs, **curve_options) for vecs in per_text_vectors]