# Source code for corrdim.utils

"""
Utility functions for the corrdim library.
"""

from typing import Tuple, Union

import numpy as np
import torch

Tensor = Union[torch.Tensor, np.ndarray]


def clamp(values: Tensor, reference: Tensor, low: float, high: float) -> Tuple[Tensor, Tensor]:
    """Keep only the entries whose reference lies strictly between the bounds.

    Despite the name, nothing is clipped: entries whose ``reference`` falls
    outside the open interval ``(low, high)`` are dropped from both outputs.

    Args:
        values: Array or tensor indexed with the mask computed from
            ``reference``.
        reference: Array or tensor compared element-wise against the bounds.
        low: Exclusive lower bound.
        high: Exclusive upper bound.

    Returns:
        The pair ``(values[mask], reference[mask])`` where ``mask`` is true
        wherever ``low < reference < high``.
    """
    above_low = reference > low
    below_high = reference < high
    keep = above_low & below_high

    kept_values = values[keep]
    kept_reference = reference[keep]
    return kept_values, kept_reference
def reduce_dimension(vectors: Tensor, num_groups: int = 8192, method: str = "group_add") -> Tensor:
    """Dispatch to one of the grouping reducers by name.

    Args:
        vectors: Input forwarded unchanged to the selected reducer.
        num_groups: Number of groups forwarded to the selected reducer.
        method: ``"group_add"`` to call :func:`group_add`, or
            ``"group_mean"`` to call :func:`group_mean`.

    Returns:
        Whatever the selected reducer returns.

    Raises:
        ValueError: If ``method`` is neither ``"group_add"`` nor
            ``"group_mean"``.
    """
    if method == "group_add":
        return group_add(vectors, num_groups)
    if method == "group_mean":
        return group_mean(vectors, num_groups)
    # Anything else is an unsupported strategy name.
    raise ValueError(f"Invalid method: {method}")
def group_add(vectors: Tensor, num_groups: int) -> Tensor:
    """Sum the last axis of ``vectors`` into ``num_groups`` buckets.

    Position ``i`` along the last axis is accumulated into bucket
    ``i % num_groups``; all leading axes are left as they are.

    Args:
        vectors: A ``torch.Tensor``, or any other object, which is first
            converted with ``torch.as_tensor``.
        num_groups: Size of the last axis of the result; must be positive.

    Returns:
        A ``torch.Tensor`` with the leading shape, dtype and device of the
        (converted) input and a last axis of length ``num_groups``.

    Raises:
        ValueError: If ``num_groups`` is less than or equal to zero.
    """
    if not isinstance(vectors, torch.Tensor):
        vectors = torch.as_tensor(vectors)
    if num_groups <= 0:
        raise ValueError("num_groups must be a positive integer")

    width = vectors.shape[-1]
    leading_shape = vectors.shape[:-1]

    # Bucket id of every position along the last axis.
    bucket_of_position = torch.arange(width, device=vectors.device) % num_groups

    # scatter_add_ needs an index shaped like the source, so give the 1-D
    # bucket ids singleton leading axes and expand them over the input.
    broadcast_shape = (1,) * (vectors.dim() - 1) + (width,)
    index = bucket_of_position.view(broadcast_shape).expand_as(vectors)

    totals = torch.zeros(*leading_shape, num_groups, dtype=vectors.dtype, device=vectors.device)
    return totals.scatter_add_(-1, index, vectors)
def group_mean(vectors: Tensor, num_groups: int) -> Tensor:
    """Average the last axis of ``vectors`` within ``num_groups`` buckets.

    Position ``i`` along the last axis belongs to bucket ``i % num_groups``.
    Each bucket's sum, as computed by :func:`group_add`, is divided by the
    number of positions assigned to that bucket.

    Args:
        vectors: A ``torch.Tensor``, or any other object, which is first
            converted with ``torch.as_tensor``.
        num_groups: Size of the last axis of the result.

    Returns:
        The per-bucket sums divided by the per-bucket position counts, with
        a last axis of length ``num_groups``.
    """
    if not isinstance(vectors, torch.Tensor):
        vectors = torch.as_tensor(vectors)

    width = vectors.shape[-1]
    sums = group_add(vectors, num_groups)

    # Count how many positions of the last axis land in each bucket.
    bucket_of_position = torch.arange(width, device=sums.device) % num_groups
    members = torch.bincount(bucket_of_position, minlength=num_groups)

    # Buckets with no positions get a divisor of 1 so the division below
    # never divides by zero.
    divisor = members.to(dtype=sums.dtype).clamp_min(1)

    # Singleton leading axes let the divisor broadcast over the sums.
    broadcast_shape = (1,) * (sums.dim() - 1) + (num_groups,)
    return sums / divisor.view(broadcast_shape)