Source code for corrdim.corrint

from __future__ import annotations

import os
from enum import Enum
from typing import Any, Dict, Optional, Union

import torch


class CorrIntBackend(str, Enum):
    """Names of the selectable correlation-integral backends.

    The class mixes in ``str``, so every member is also a plain string equal
    to its lowercase value (for example ``CorrIntBackend.CUDA == "cuda"``).
    """

    AUTO = "auto"
    CUDA = "cuda"
    TRITON = "triton"
    PYTORCH = "pytorch"
    PYTORCH_FAST = "pytorch_fast"
# Anything callers may pass as a backend: a name, an enum member, or ``None``
# (``None`` means "use the current process-wide default").
BackendLike = Union[str, CorrIntBackend, None]

_KNOWN_BACKENDS = {"auto", "cuda", "triton", "pytorch", "pytorch_fast"}


def _normalize_backend_name(name: Any, *, source: str) -> str:
    """Return ``name`` as a canonical, validated backend name.

    Args:
        name: Candidate backend name. It is converted with ``str()``, stripped
            of surrounding whitespace and lower-cased before validation.
        source: Where the value came from; only used in the error message.

    Returns:
        The cleaned name, guaranteed to be a member of ``_KNOWN_BACKENDS``.

    Raises:
        ValueError: If the cleaned name is not a known backend.
    """
    normalized = str(name).strip().lower()
    if normalized not in _KNOWN_BACKENDS:
        raise ValueError(
            f"Unknown backend from {source}: {name!r}. "
            f"Available: {sorted(_KNOWN_BACKENDS)}"
        )
    return normalized


def _env_default_backend() -> str:
    """Return the initial default backend taken from the environment.

    Reads ``CORRDIM_CORRINT_BACKEND``. When the variable is unset, or set to
    an empty / whitespace-only string, ``"triton"`` is returned.

    Raises:
        ValueError: If the variable holds a non-blank, unknown backend name.
    """
    # Canonical env var for selecting the correlation-integral backend.
    source = "CORRDIM_CORRINT_BACKEND"
    raw = os.environ.get(source)
    # A blank value (e.g. ``CORRDIM_CORRINT_BACKEND=``) is treated like an
    # unset variable so it cannot make the module fail at import time.
    if raw is None or not raw.strip():
        return "triton"
    return _normalize_backend_name(raw, source=source)


# Default backend: environment variable first, otherwise "triton".
# NOTE(review): an earlier comment said the fallback was "auto"; the code has
# always returned "triton" here, so the comment was aligned with the code.
_DEFAULT_BACKEND: str = _env_default_backend()


def _normalize_backend(backend: BackendLike) -> str:
    """Turn a user-supplied backend argument into a validated backend name.

    Args:
        backend: ``None`` selects the current ``_DEFAULT_BACKEND``; a
            ``CorrIntBackend`` member contributes its ``value``; anything else
            is converted with ``str()``.

    Returns:
        A name from ``_KNOWN_BACKENDS`` (may still be ``"auto"``).

    Raises:
        ValueError: If the resulting name is not a known backend.
    """
    if backend is None:
        name = _DEFAULT_BACKEND
    elif isinstance(backend, CorrIntBackend):
        # Use the member's value; the enum check must come before str().
        name = backend.value
    else:
        name = str(backend)
    return _normalize_backend_name(name, source="set_corrint_backend/backend")


def _triton_is_available() -> bool:
    """Return True when CUDA is usable and the ``triton`` submodule imports."""
    if not torch.cuda.is_available():
        return False
    try:
        from . import triton as _impl  # noqa: F401
    except Exception:  # pragma: no cover
        # Availability probe: any import-time failure means "not available".
        return False
    return True


def _cuda_is_available() -> bool:
    """Return True when CUDA is usable and the ``cuda`` submodule imports."""
    if not torch.cuda.is_available():
        return False
    try:
        from . import cuda as _impl  # noqa: F401
    except Exception:  # pragma: no cover
        # Availability probe: any import-time failure means "not available".
        return False
    return True
def resolve_corrint_backend(backend: BackendLike = None) -> str:
    """Return a concrete backend name, replacing ``"auto"`` by a real backend.

    Args:
        backend: Backend name, enum member, or ``None`` for the process-wide
            default.

    Returns:
        The normalized name unchanged unless it is ``"auto"``; in that case
        ``"cuda"`` when the CUDA backend is available, otherwise ``"pytorch"``.
    """
    requested = _normalize_backend(backend)
    if requested == "auto":
        if _cuda_is_available():
            return "cuda"
        return "pytorch"
    return requested
def available_corrint_backends() -> Dict[str, bool]:
    """Report, per concrete backend name, whether it can be used.

    Returns:
        A dict keyed by backend name. ``"cuda"`` and ``"triton"`` are probed
        at call time; ``"pytorch"`` and ``"pytorch_fast"`` are always ``True``.
    """
    availability: Dict[str, bool] = {}
    availability["cuda"] = _cuda_is_available()
    availability["triton"] = _triton_is_available()
    availability["pytorch"] = True
    availability["pytorch_fast"] = True
    return availability
def set_corrint_backend(backend: BackendLike = "auto") -> str:
    """Replace the process-wide default backend.

    Args:
        backend: New default (name or enum member). The value is validated
            before the default is overwritten.

    Returns:
        The new default after resolution to a concrete backend name.
    """
    global _DEFAULT_BACKEND
    chosen = _normalize_backend(backend)
    _DEFAULT_BACKEND = chosen
    return resolve_corrint_backend(None)
def get_corrint_backend() -> str:
    """Return the process-wide default backend, with ``"auto"`` resolved."""
    return resolve_corrint_backend(backend=None)
def _all_cuda(*tensors: Optional[torch.Tensor]) -> bool:
    """Return True when at least one tensor is given and all are on CUDA.

    ``None`` entries are ignored; if nothing remains the result is ``False``.
    """
    present = [t for t in tensors if t is not None]
    if not present:
        return False
    return all(
        getattr(t, "device", None) is not None and t.device.type == "cuda"
        for t in present
    )


def _select_impl(backend: BackendLike, *tensors: Optional[torch.Tensor]):
    """Pick the implementation module for ``backend`` and the given tensors.

    The ``cuda`` and ``triton`` modules are used only when every supplied
    tensor lives on a CUDA device; in every other case the ``pytorch`` module
    is used.

    Args:
        backend: Requested backend (name, enum member, or ``None`` for the
            process-wide default).
        *tensors: Input tensors (``None`` entries are ignored) whose devices
            decide whether a GPU implementation may be used.

    Returns:
        A ``(name, module, extra_kwargs)`` tuple. ``name`` identifies the
        implementation that was actually selected, so that messages built
        from it are accurate after a fallback. ``extra_kwargs`` is empty for
        the GPU modules and ``{"fast": <bool>}`` for the pytorch module.
    """
    name = resolve_corrint_backend(backend)
    if name in ("cuda", "triton"):
        if _all_cuda(*tensors):
            if name == "cuda":
                from . import cuda as impl
            else:
                from . import triton as impl
            return name, impl, {}
        # A GPU backend was requested but the inputs are not all on CUDA:
        # report the pytorch fallback rather than the backend that was skipped.
        name = "pytorch"
    from . import pytorch as impl

    return name, impl, {"fast": name == "pytorch_fast"}
def correlation_counts(
    vecs: torch.FloatTensor,
    epsilons: torch.FloatTensor,
    vecs_other: Optional[torch.FloatTensor] = None,
    *,
    backend: BackendLike = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Return unnormalized counts computed by the selected backend module.

    Args:
        vecs: Input vectors, forwarded to the backend.
        epsilons: Thresholds, forwarded to the backend.
        vecs_other: Optional second set of vectors; forwarded as the
            ``vecs_other`` keyword only when it is not ``None``.
        backend: Backend to use; ``None`` means the process-wide default.
        **kwargs: Extra keyword arguments passed through to the backend.

    Returns:
        Whatever the backend's ``correlation_counts`` returns.
    """
    _, impl, extra = _select_impl(backend, vecs, vecs_other)
    # Only mention ``vecs_other`` to the backend when a second set was given.
    other = {} if vecs_other is None else {"vecs_other": vecs_other}
    return impl.correlation_counts(vecs, epsilons, **other, **extra, **kwargs)
def correlation_integral(
    vecs: torch.FloatTensor,
    epsilons: torch.FloatTensor,
    vecs_other: Optional[torch.FloatTensor] = None,
    return_counts: bool = False,
    *,
    backend: BackendLike = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Return the correlation integral computed by the selected backend.

    Args:
        vecs: Input vectors, forwarded to the backend.
        epsilons: Thresholds, forwarded to the backend.
        vecs_other: Optional second set of vectors; forwarded as the
            ``vecs_other`` keyword only when it is not ``None``.
        return_counts: When true, delegate to ``correlation_counts`` and
            return its result instead of the integral.
        backend: Backend to use; ``None`` means the process-wide default.
        **kwargs: Extra keyword arguments passed through to the backend.

    Returns:
        The backend's ``correlation_integral`` result, or the counts when
        ``return_counts`` is true.
    """
    if return_counts:
        return correlation_counts(
            vecs, epsilons, vecs_other=vecs_other, backend=backend, **kwargs
        )

    _, impl, extra = _select_impl(backend, vecs, vecs_other)
    # Only mention ``vecs_other`` to the backend when a second set was given.
    other = {} if vecs_other is None else {"vecs_other": vecs_other}
    return impl.correlation_integral(vecs, epsilons, **other, **extra, **kwargs)
def progressive_correlation_counts(
    vecs: torch.FloatTensor,
    epsilons: torch.FloatTensor,
    *,
    backend: BackendLike = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Return progressive counts over sequence prefixes from the backend.

    Args:
        vecs: Input vectors, forwarded to the backend.
        epsilons: Thresholds, forwarded to the backend.
        backend: Backend to use; ``None`` means the process-wide default.
        **kwargs: Extra keyword arguments passed through to the backend.

    Returns:
        Whatever the backend's ``progressive_correlation_counts`` returns.

    Raises:
        NotImplementedError: If the selected backend module does not provide
            ``progressive_correlation_counts``.
    """
    name, impl, extra = _select_impl(backend, vecs)
    fn = getattr(impl, "progressive_correlation_counts", None)
    if fn is not None:
        return fn(vecs, epsilons, **extra, **kwargs)
    raise NotImplementedError(
        f"backend={name!r} does not implement progressive_correlation_counts"
    )
def progressive_correlation_integral(
    vecs: torch.FloatTensor,
    epsilons: torch.FloatTensor,
    return_counts: bool = False,
    *,
    backend: BackendLike = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Return the progressive correlation integral over sequence prefixes.

    Args:
        vecs: Input vectors, forwarded to the backend.
        epsilons: Thresholds, forwarded to the backend.
        return_counts: When true, delegate to
            ``progressive_correlation_counts`` and return its result.
        backend: Backend to use; ``None`` means the process-wide default.
        **kwargs: Extra keyword arguments passed through to the backend.

    Returns:
        The backend's ``progressive_correlation_integral`` result, or the
        progressive counts when ``return_counts`` is true.

    Raises:
        NotImplementedError: If the selected backend module does not provide
            the required function.
    """
    if return_counts:
        return progressive_correlation_counts(
            vecs, epsilons, backend=backend, **kwargs
        )

    name, impl, extra = _select_impl(backend, vecs)
    fn = getattr(impl, "progressive_correlation_integral", None)
    if fn is not None:
        return fn(vecs, epsilons, **extra, **kwargs)
    raise NotImplementedError(
        f"backend={name!r} does not implement progressive_correlation_integral"
    )
def correlation_integral_cross(
    vecs1: torch.FloatTensor,
    vecs2: torch.FloatTensor,
    epsilons: torch.FloatTensor,
    *,
    backend: BackendLike = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Correlation integral between two sets of vectors.

    Thin wrapper around ``correlation_integral`` that passes ``vecs2`` as
    ``vecs_other``; the selected backend must accept that keyword.
    """
    result = correlation_integral(
        vecs1,
        epsilons,
        vecs_other=vecs2,
        backend=backend,
        **kwargs,
    )
    return result


def correlation_integral_sequence(
    vecs: torch.FloatTensor,
    epsilons: torch.FloatTensor,
    *,
    backend: BackendLike = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Correlation integral curve for every prefix of the sequence.

    Alias for ``progressive_correlation_integral`` with the same arguments.
    """
    result = progressive_correlation_integral(
        vecs,
        epsilons,
        backend=backend,
        **kwargs,
    )
    return result


# Names exported by ``from ... import *``.
__all__ = [
    "CorrIntBackend",
    "available_corrint_backends",
    "resolve_corrint_backend",
    "set_corrint_backend",
    "get_corrint_backend",
    "correlation_counts",
    "correlation_integral",
    "progressive_correlation_counts",
    "progressive_correlation_integral",
]