# Source code for corrdim.cli

"""Command-line interface for corrdim."""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import numpy as np

from .dimension import estimate_dimension_from_curve
from .high_level import measure_text
from .low_level import curve_from_text
from .types import CurveResult


def _parse_dtype(dtype: str | None):
    """Resolve a dtype name given on the command line to a ``torch.dtype``.

    The canonical names ``float16``, ``float32`` and ``bfloat16`` are accepted
    together with the aliases ``fp16``, ``half``, ``fp32`` and ``bf16``.
    Matching ignores case and surrounding whitespace. Any other name is looked
    up as an attribute of the ``torch`` module and is accepted only when that
    attribute really is a ``torch.dtype``.

    Args:
        dtype: Name of the dtype, or ``None``.

    Returns:
        The resolved ``torch.dtype``, or ``None`` when ``dtype`` is ``None``.

    Raises:
        ValueError: If torch cannot be imported, or if the name does not
            resolve to a ``torch.dtype``.
    """
    if dtype is None:
        return None

    aliases = {
        "fp16": "float16",
        "half": "float16",
        "fp32": "float32",
        "bf16": "bfloat16",
    }
    key = str(dtype).strip().lower()
    name = aliases.get(key, key)

    # torch is imported lazily so that the CLI module itself stays importable
    # without it; a failed import is reported as such rather than as a bad dtype.
    try:
        import torch
    except (ImportError, OSError) as e:
        raise ValueError(f"PyTorch is required to resolve dtype={dtype!r} but could not be imported: {e}") from e

    resolved = getattr(torch, name, None)
    # A plain getattr would also return modules and functions (e.g. "nn",
    # "tensor"); only genuine dtype objects are valid here.
    if not isinstance(resolved, torch.dtype):
        raise ValueError(f"Unsupported dtype={dtype!r}. Use one of: float16/float32/bfloat16.")
    return resolved


def _parse_stride(stride: str):
    """Convert a stride argument into a positive integer.

    Args:
        stride: The raw stride value; it is converted with ``str`` and
            stripped of surrounding whitespace before parsing. ``None``
            selects the default of 1.

    Returns:
        The stride as an ``int`` greater than or equal to 1.

    Raises:
        ValueError: If the value is not an integer or is smaller than 1.
    """
    if stride is None:
        return 1

    cleaned = str(stride).strip()
    try:
        parsed = int(cleaned)
    except ValueError as err:
        raise ValueError("stride must be a positive integer") from err

    if parsed >= 1:
        return parsed
    raise ValueError("stride must be a positive integer")


def _parse_dim_reduction(dim_reduction):
    """Interpret the ``--dim-reduction`` argument.

    Args:
        dim_reduction: Raw value; ``None`` selects the default of 8192. The
            words ``none``, ``null`` and ``off`` (case-insensitive,
            whitespace ignored) disable the setting.

    Returns:
        A positive ``int``, or ``None`` when one of the disabling words was
        given.

    Raises:
        ValueError: If the value is neither a disabling word nor a positive
            integer.
    """
    if dim_reduction is None:
        return 8192

    token = str(dim_reduction).strip().lower()
    if token in ("none", "null", "off"):
        return None

    try:
        size = int(token)
    except ValueError as err:
        raise ValueError("dim-reduction must be an integer or 'none'") from err

    if size >= 1:
        return size
    raise ValueError("dim-reduction must be a positive integer or 'none'")


def _read_text_file(path: str, line_sep: str) -> str:
    """Load a UTF-8 text file and re-join its lines with ``line_sep``.

    Leading and trailing whitespace of the whole file is stripped before the
    newline characters are replaced.

    Args:
        path: Location of the file to read.
        line_sep: String substituted for every ``"\\n"`` in the stripped text.

    Returns:
        The stripped file content with newlines replaced by ``line_sep``.

    Raises:
        ValueError: If the file cannot be read, or is empty after stripping.
    """
    try:
        content = Path(path).read_text(encoding="utf-8").strip()
    except Exception as exc:
        raise ValueError(f"Error reading file {path}: {exc}") from exc

    if content == "":
        raise ValueError(f"File {path} is empty")

    # Replacing each newline is equivalent to splitting on it and re-joining.
    return content.replace("\n", line_sep)


def _add_common_text_args(parser: argparse.ArgumentParser) -> None:
    """Register the options shared by the text-based sub-commands.

    Args:
        parser: The (sub)parser that receives the options. It is modified in
            place; options are added in the order listed below.
    """
    option_specs = (
        (("--file",), {"required": True, "help": "Text file to analyze"}),
        (("--model", "-m"), {"required": True, "help": "Model name for from_pretrained"}),
        (("--line-sep",), {"default": "\n"}),
        (("--context-length",), {"type": int, "default": None}),
        (("--stride",), {"type": int, "default": 1}),
        (("--dim-reduction",), {"type": str, "default": "8192"}),
        (("--dtype",), {"type": str, "default": "float16"}),
        (("--num-epsilon",), {"type": int, "default": 1024}),
        (("--backend",), {"type": str, "default": None}),
        (("--output", "-o"), {"default": None}),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)


def _write_or_print(payload: str, output: str | None) -> None:
    """Emit ``payload`` to a file or to standard output.

    Args:
        payload: Text to emit.
        output: Destination path. When falsy (``None`` or empty), the payload
            is printed instead of written to disk.
    """
    if not output:
        print(payload)
        return

    Path(output).write_text(payload, encoding="utf-8")
    print(f"Results saved to {output}")


def _cmd_measure_text(args: argparse.Namespace) -> None:
    """Handle the ``measure-text`` sub-command.

    Reads the input file, passes the text and the parsed options to
    ``measure_text`` and emits four fields of the result as indented JSON,
    either to ``args.output`` or to standard output.

    Args:
        args: Parsed command-line namespace carrying the common text options.
    """
    source_text = _read_text_file(args.file, args.line_sep)

    # Option values are converted in the same order in which they are passed.
    options = dict(
        text=source_text,
        model=args.model,
        context_length=args.context_length,
        stride=_parse_stride(args.stride),
        dim_reduction=_parse_dim_reduction(args.dim_reduction),
        precision=_parse_dtype(args.dtype),
        num_epsilon=args.num_epsilon,
        backend=args.backend,
    )
    measurement = measure_text(**options)

    reported_fields = ("sequence_length", "corrdim", "fit_r2", "linear_region_bounds")
    summary = {field: getattr(measurement, field) for field in reported_fields}
    rendered = json.dumps(summary, ensure_ascii=False, indent=2)
    _write_or_print(rendered, args.output)


def _cmd_curve_from_text(args: argparse.Namespace) -> None:
    """Handle the ``curve-from-text`` sub-command.

    Reads the input file, passes the text and the parsed options to
    ``curve_from_text`` and emits the sequence length together with the
    ``epsilons`` and ``corrints`` values (converted via ``tolist()``) as
    compact JSON, either to ``args.output`` or to standard output.

    Args:
        args: Parsed command-line namespace carrying the common text options.
    """
    source_text = _read_text_file(args.file, args.line_sep)

    # Option values are converted in the same order in which they are passed.
    options = dict(
        text=source_text,
        model=args.model,
        context_length=args.context_length,
        stride=_parse_stride(args.stride),
        dim_reduction=_parse_dim_reduction(args.dim_reduction),
        precision=_parse_dtype(args.dtype),
        num_epsilon=args.num_epsilon,
        backend=args.backend,
    )
    computed = curve_from_text(**options)

    summary = {
        "sequence_length": computed.sequence_length,
        "epsilons": computed.epsilons.tolist(),
        "corrints": computed.corrints.tolist(),
    }
    rendered = json.dumps(summary, ensure_ascii=False)
    _write_or_print(rendered, args.output)


def _cmd_estimate_dimension(args: argparse.Namespace) -> None:
    """Handle the ``estimate-dimension`` sub-command.

    Loads a curve from the JSON file at ``args.curve_json`` (keys
    ``sequence_length``, ``epsilons`` and ``corrints``), rebuilds a
    ``CurveResult`` from it, runs ``estimate_dimension_from_curve`` and emits
    four fields of the result as indented JSON, either to ``args.output`` or
    to standard output.

    Args:
        args: Parsed command-line namespace with ``curve_json`` and ``output``.
    """
    raw_json = Path(args.curve_json).read_text(encoding="utf-8")
    record = json.loads(raw_json)

    sequence_length = int(record["sequence_length"])
    epsilons = np.asarray(record["epsilons"], dtype=float)
    corrints = np.asarray(record["corrints"], dtype=float)
    loaded_curve = CurveResult(
        sequence_length=sequence_length,
        epsilons=epsilons,
        corrints=corrints,
    )

    estimate = estimate_dimension_from_curve(loaded_curve)

    reported_fields = ("sequence_length", "corrdim", "fit_r2", "linear_region_bounds")
    summary = {field: getattr(estimate, field) for field in reported_fields}
    rendered = json.dumps(summary, ensure_ascii=False, indent=2)
    _write_or_print(rendered, args.output)


def main(argv: list[str] | None = None) -> None:
    """Run the CorrDim command-line interface.

    Builds a parser with three sub-commands (``measure-text``,
    ``curve-from-text`` and ``estimate-dimension``), parses the arguments and
    dispatches to the handler registered for the chosen sub-command.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what a plain ``main()`` call did
            before this parameter existed.

    Raises:
        SystemExit: With status 1 when the handler raises; the error message
            is printed to standard error. argparse itself also exits on
            invalid arguments.
    """
    parser = argparse.ArgumentParser(description="CorrDim CLI")
    subparsers = parser.add_subparsers(dest="command", required=True)

    measure_text_parser = subparsers.add_parser(
        "measure-text", help="Measure correlation dimension from text"
    )
    _add_common_text_args(measure_text_parser)
    measure_text_parser.set_defaults(handler=_cmd_measure_text)

    curve_text_parser = subparsers.add_parser(
        "curve-from-text", help="Compute correlation-integral curve from text"
    )
    _add_common_text_args(curve_text_parser)
    curve_text_parser.set_defaults(handler=_cmd_curve_from_text)

    estimate_dim_parser = subparsers.add_parser(
        "estimate-dimension", help="Estimate dimension from curve JSON"
    )
    estimate_dim_parser.add_argument("--curve-json", required=True, help="Curve JSON file")
    estimate_dim_parser.add_argument("--output", "-o", default=None)
    estimate_dim_parser.set_defaults(handler=_cmd_estimate_dimension)

    args = parser.parse_args(argv)
    # Top-level boundary: any handler failure becomes a one-line message on
    # stderr and a non-zero exit status instead of a traceback.
    try:
        args.handler(args)
    except Exception as exc:
        print(f"Error: {exc}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()