# Source code for corrdim.models

"""
Language model wrapper module.

This module provides interfaces for different language models to extract
log-probability vectors needed for correlation dimension computation.
"""

import torch
import numpy as np
import tqdm.auto as tqdm
from typing import Optional, List
from abc import ABC, abstractmethod
from transformers import AutoTokenizer, AutoModelForCausalLM
from .utils import reduce_dimension

# Number of token positions fed to the model per forward call in the
# single-pass path; the sequence is walked in steps of this size and each
# later chunk is run with the key/value cache returned by the previous one.
_FORWARD_CHUNK_SIZE = 512


class LanguageModelWrapper(ABC):
    """Abstract base class for language model wrappers.

    Concrete subclasses must implement :meth:`get_log_probabilities`.
    """

    @abstractmethod
    def get_log_probabilities(
        self,
        text: str,
        context_length: Optional[int] = None,
        dim_reduction: Optional[int] = None,
        stride: int = 1,
        show_progress: bool = False,
    ):
        """
        Extract log-probability vectors for each sampled token position.

        Args:
            text: Input text
            context_length: Maximum context length (None lets the
                implementation choose its own default)
            dim_reduction: Optional reduced size of each returned vector
                (None keeps the full vocabulary dimension)
            stride: Sampling interval over token positions
            show_progress: Whether to display a progress bar

        Returns:
            Array of sampled log-probability vectors of shape
            (sampled_seq_len, vocab_size)
        """
class TransformersModelWrapper(LanguageModelWrapper):
    """Wrapper for Hugging Face Transformers causal language models.

    Attributes:
        model_name: Name or path the model was loaded from.
        device: Requested device as a string (e.g. "cuda", "cuda:1", "cpu").
        tokenizer: Tokenizer used by :meth:`encode` / :meth:`decode`.
        model: The loaded causal LM, switched to eval mode.
        input_device: Device on which input ids are created before a forward
            pass; re-resolved on every :meth:`get_log_probabilities` call.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer: Optional[object] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the transformers model wrapper.

        Args:
            model_name: Name or path of the model
            tokenizer: Tokenizer instance (if None, will load from model_name)
            device: Device to run on (string or ``torch.device``); defaults to
                "cuda" when available, otherwise "cpu"
            **kwargs: Forwarded to ``AutoModelForCausalLM.from_pretrained``
                (e.g. ``torch_dtype``/``dtype``, ``device_map``). Explicit
                values always take priority over the defaults chosen here.
        """
        self.model_name = model_name
        # Store the device as a string so ``torch.device`` inputs work with
        # the ``startswith`` check below.
        self.device = str(device) if device else ("cuda" if torch.cuda.is_available() else "cpu")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name) if tokenizer is None else tokenizer

        # Add padding token if not present
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Prefer fast/low-CPU-memory loading on CUDA.
        # Keep explicit user-provided kwargs as highest priority.
        load_kwargs = dict(kwargs)
        load_kwargs.setdefault("low_cpu_mem_usage", True)
        if self.device.startswith("cuda"):
            load_kwargs.setdefault("device_map", "auto")
            # Only default the dtype when the caller chose none under either
            # spelling; otherwise an explicit ``dtype=...`` would compete with
            # our ``torch_dtype`` default.
            if "torch_dtype" not in load_kwargs and "dtype" not in load_kwargs:
                load_kwargs["torch_dtype"] = torch.float16

        # Prefer torch_dtype by default; fall back to dtype for model loaders
        # that only accept dtype.
        try:
            self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        except TypeError as exc:
            can_retry_with_dtype = "torch_dtype" in load_kwargs and "dtype" not in load_kwargs
            if not (can_retry_with_dtype and "torch_dtype" in str(exc)):
                raise
            load_kwargs["dtype"] = load_kwargs.pop("torch_dtype")
            self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

        # With an active device_map the model is already placed, so do not
        # call model.to(...). An explicit ``device_map=None`` means "no
        # dispatch" and must still be moved to the requested device.
        if load_kwargs.get("device_map") is None:
            self.model = self.model.to(self.device)

        self.model.eval()
        self.input_device = self._infer_input_device()

    def _infer_input_device(self) -> torch.device:
        """Return the device on which input tensors should be created.

        For models carrying an ``hf_device_map`` dict, the first CUDA entry
        is used, or CPU when the map has none. Otherwise the model's own
        ``.device`` is used, falling back to the device requested at init.
        """

        def _to_device(mapped) -> Optional[torch.device]:
            # Accelerate may store device map entries as int (GPU ordinal),
            # torch.device, or strings like "cuda:0"/"cpu"/"disk".
            if isinstance(mapped, int):
                return torch.device(f"cuda:{mapped}")
            if isinstance(mapped, torch.device):
                return mapped
            name = str(mapped)
            if name == "cpu" or name.startswith("cuda"):
                return torch.device(name)
            # Anything else (e.g. "disk") is not a device tensors can live on.
            return None

        # Accelerate-dispatched models may span multiple devices and may not expose
        # a single stable `.device`. Feed inputs to the first CUDA shard if present.
        hf_device_map = getattr(self.model, "hf_device_map", None)
        if isinstance(hf_device_map, dict):
            for mapped in hf_device_map.values():
                device = _to_device(mapped)
                if device is not None and device.type == "cuda":
                    return device
            return torch.device("cpu")

        if hasattr(self.model, "device"):
            return self.model.device
        return torch.device(self.device)

    def encode(self, text: str, **kwargs) -> List[int]:
        """Tokenize text; extra keyword arguments go to ``tokenizer.encode``."""
        return self.tokenizer.encode(text, **kwargs)

    def decode(self, tokens: List[int], **kwargs) -> str:
        """Decode tokens to text; extra keyword arguments go to ``tokenizer.decode``."""
        return self.tokenizer.decode(tokens, **kwargs)

    @torch.no_grad()
    def get_log_probabilities(
        self,
        text: str,
        context_length: Optional[int] = None,
        dim_reduction: Optional[int] = None,
        stride: int = 1,
        show_progress: bool = False,
    ) -> torch.Tensor:
        """
        Extract sampled log-probability vectors.

        Args:
            text: Input text (tokenized without special tokens)
            context_length: Maximum context length; defaults to the model's
                ``max_position_embeddings``
            dim_reduction: If given, each vector is reduced to this many
                groups via ``reduce_dimension(..., method="group_add")``
            stride: Sampling interval over token positions. Keep every
                `stride`-th token vector (positions 0, stride, 2*stride, ...).
            show_progress: Whether to display a progress bar

        Returns:
            Tensor of sampled log-probability vectors with one row per
            sampled position: (sampled_seq_len, vocab_size), or the reduced
            width when ``dim_reduction`` is set.

        Raises:
            ValueError: If the text tokenizes to nothing, or if
                ``context_length``, ``stride`` or ``dim_reduction`` is out
                of range.
        """
        # Re-resolve input device in case model dispatch/cache changed after init.
        self.input_device = self._infer_input_device()

        # Tokenize text
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) == 0:
            raise ValueError("Input text produced an empty token sequence.")

        max_context = self.model.config.max_position_embeddings
        if context_length is None:
            context_length = max_context
        elif context_length < 1:
            # A non-positive window would feed empty slices to the model.
            raise ValueError("context_length must be a positive integer")
        elif context_length > max_context:
            raise ValueError(
                f"Context length exceeds the max length allowed by the model: {max_context}"
            )

        token_stride = self._normalize_token_stride(stride)

        if dim_reduction is not None:
            if dim_reduction < 1:
                raise ValueError("dim_reduction must be a positive integer")
            if dim_reduction > self.model.config.vocab_size:
                raise ValueError(
                    f"Dim reduction must be less than or equal to the vocabulary size: {self.model.config.vocab_size}"
                )

        if len(tokens) <= context_length:
            # Short text: process in a single pass, then sample by token position.
            return self._get_log_probabilities_single_pass(
                tokens,
                dim_reduction=dim_reduction,
                token_stride=token_stride,
                show_progress=show_progress,
            )

        # Long text: sliding-window inference is internal; stride is only
        # used for sampling output positions.
        return self._get_log_probabilities_sliding_window(
            tokens,
            context_length,
            dim_reduction=dim_reduction,
            token_stride=token_stride,
            show_progress=show_progress,
        )

    def _normalize_token_stride(self, stride: int) -> int:
        """Validate ``stride`` and return it unchanged.

        Raises:
            ValueError: If ``stride`` is not an ``int`` or is smaller than 1.
        """
        if not isinstance(stride, int) or stride < 1:
            raise ValueError("stride must be a positive integer")
        return stride

    @torch.no_grad()
    def _get_log_probabilities_single_pass(
        self,
        tokens: List[int],
        dim_reduction: Optional[int] = None,
        token_stride: int = 1,
        show_progress: bool = False,
    ) -> torch.Tensor:
        """Get log probabilities for short texts in a single pass.

        The sequence is run in chunks of ``_FORWARD_CHUNK_SIZE`` tokens, each
        chunk reusing the key/value cache of the previous ones.
        """
        input_ids = torch.tensor([tokens], device=self.input_device)

        def log_softmax_(logits: torch.Tensor, dim) -> torch.Tensor:
            """Log softmax in place (avoids a second logits-sized buffer)."""
            logsumexp = torch.logsumexp(logits, dim=dim, keepdim=True)
            torch.sub(logits, logsumexp, out=logits)
            return logits

        cache = None
        sampled_log_probs = []
        total_seq_len = len(tokens)

        for chunk_start in tqdm.trange(
            0,
            total_seq_len,
            _FORWARD_CHUNK_SIZE,
            disable=not show_progress,
            desc="Computing log-probs",
        ):
            chunk_end = min(chunk_start + _FORWARD_CHUNK_SIZE, total_seq_len)
            outputs = self.model(
                input_ids[:, chunk_start:chunk_end],
                past_key_values=cache,
                use_cache=True,
            )
            cache = outputs.past_key_values

            # Convert to log
            logp = log_softmax_(outputs.logits[0], dim=-1)
            if dim_reduction is not None:
                logp = reduce_dimension(logp, method="group_add", num_groups=dim_reduction)

            # Keep vectors at global positions: 0, token_stride, 2*token_stride, ...
            offset = (-chunk_start) % token_stride
            sampled_chunk = logp[offset::token_stride]
            if sampled_chunk.shape[0] > 0:
                if token_stride > 1:
                    # A strided slice is a view that would keep the whole
                    # chunk alive; copy the kept rows so the rest is freed.
                    sampled_chunk = sampled_chunk.clone()
                sampled_log_probs.append(sampled_chunk)

        log_probs = torch.cat(sampled_log_probs, dim=0)
        expected_rows = (total_seq_len - 1) // token_stride + 1
        assert log_probs.shape[0] == expected_rows, "sampled row count mismatch"
        return log_probs

    @torch.no_grad()
    def _get_log_probabilities_sliding_window(
        self,
        tokens: List[int],
        context_length: int,
        dim_reduction: Optional[int] = None,
        token_stride: int = 1,
        show_progress: bool = False,
    ) -> torch.Tensor:
        """Get sampled log probabilities for long texts using internal sliding windows.

        Each window ends ``window_step`` tokens after the previous one and
        spans at most ``context_length`` tokens; only the logits of the newly
        covered tokens are kept from each window.
        """
        seq_len = len(tokens)
        input_ids = torch.tensor([tokens], device=self.input_device)
        # Internal step for long-sequence inference. This is separate from token_stride.
        window_step = max(1, context_length // 10)
        sampled_log_probs = []

        for pos in tqdm.trange(0, seq_len, window_step, disable=not show_progress, desc="Computing log-probs"):
            # pos < seq_len always holds here, so every window adds >= 1 token.
            start = max(0, pos + window_step - context_length)
            end = min(pos + window_step, seq_len)
            ntokens_to_update = end - pos

            # Every window is recomputed from scratch, so the KV cache is
            # never reused; skip building it.
            outputs = self.model(input_ids[:, start:end], use_cache=False)
            logits = outputs.logits[0, -ntokens_to_update:, :]
            logp = torch.log_softmax(logits, dim=-1)

            if dim_reduction is not None:
                logp = reduce_dimension(logp, method="group_add", num_groups=dim_reduction)

            # Keep only rows whose global position is a multiple of token_stride.
            global_positions = torch.arange(pos, end, device=logp.device)
            sampled_chunk = logp[(global_positions % token_stride) == 0]
            if sampled_chunk.shape[0] > 0:
                sampled_log_probs.append(sampled_chunk)

        log_probs = torch.cat(sampled_log_probs, dim=0)
        expected_rows = (seq_len - 1) // token_stride + 1
        assert log_probs.shape[0] == expected_rows, "sampled row count mismatch"
        return log_probs
class GPT2Wrapper(TransformersModelWrapper):
    """Specialized wrapper for GPT-2 models."""

    def __init__(self, model_size: str = "gpt2", **kwargs):
        """
        Initialize GPT-2 wrapper.

        Args:
            model_size: Either a full GPT-2 model name or path
                ('gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl', or any
                identifier containing "gpt2" such as 'distilgpt2'), or a bare
                size suffix ('medium', 'large', 'xl') that gets the 'gpt2-'
                prefix added.
            **kwargs: Additional arguments for TransformersModelWrapper
        """
        # Anything that already names a GPT-2 checkpoint is used verbatim;
        # prefixing such names would build a nonexistent model id like
        # "gpt2-distilgpt2". Only bare size suffixes are expanded.
        if "gpt2" in model_size.lower():
            model_name = model_size
        else:
            model_name = f"gpt2-{model_size}"
        super().__init__(model_name, **kwargs)
class LLaMAWrapper(TransformersModelWrapper):
    """Specialized wrapper for LLaMA models."""

    def __init__(self, model_name: str = "meta-llama/Llama-2-7b-hf", **kwargs):
        """
        Build a wrapper around a LLaMA checkpoint.

        Args:
            model_name: Name or path of the LLaMA model to load
            **kwargs: Passed through unchanged to TransformersModelWrapper
        """
        # No LLaMA-specific handling: delegate everything to the parent.
        TransformersModelWrapper.__init__(self, model_name, **kwargs)
def create_model_wrapper(
    model_name: str,
    tokenizer: Optional[object] = None,
    device: Optional[str] = None,
    **kwargs,
) -> LanguageModelWrapper:
    """
    Factory function to create appropriate model wrapper.

    Args:
        model_name: Name of the model
        tokenizer: Tokenizer instance
        device: Device to run on
        **kwargs: Additional arguments

    Returns:
        Appropriate model wrapper instance
    """
    # Try to determine model type from name
    lowered = model_name.lower()

    # GPT2Wrapper treats its first argument as a size and may prepend
    # "gpt2-", so only hand it names it uses verbatim (those starting with
    # "gpt2"). Other GPT-2 flavoured names such as "distilgpt2" or
    # "openai-community/gpt2" fall through and are loaded unmodified.
    if model_name.startswith("gpt2"):
        return GPT2Wrapper(model_name, tokenizer=tokenizer, device=device, **kwargs)
    if "llama" in lowered:
        return LLaMAWrapper(model_name, tokenizer=tokenizer, device=device, **kwargs)
    # Default to generic transformers wrapper
    return TransformersModelWrapper(model_name, tokenizer=tokenizer, device=device, **kwargs)