"""
Orthogonal test: residualize each exposure signal AND its outcome on VIX (and,
in a second pass, on VIX + atm_iv), then test the residuals against each other.
This tells us what each exposure adds ABOVE AND BEYOND the vol-level
information already in VIX/IV.

If the residual Spearman is ≈ 0, the signal carries no information about its
outcome beyond VIX/IV, i.e. it is a pure VIX/IV proxy for that outcome.

Writes: orthogonal_results.csv
"""
from __future__ import annotations

from pathlib import Path
import numpy as np
import pandas as pd
from scipy import stats

# Directory containing this script; the input parquet and the output CSV are
# both resolved relative to it.
HERE = Path(__file__).parent


def residualize(x: np.ndarray, controls: np.ndarray) -> np.ndarray:
    """Strip the linear (plus intercept) effect of *controls* from *x*.

    Fits ``x ≈ a + controls @ b`` by ordinary least squares and returns the
    part of ``x`` the fit leaves unexplained.

    Args:
        x: Values to residualize, one per row of ``controls``.
        controls: Regressors; a 1-D array (single control) or a 2-D array with
            one column per control.

    Returns:
        ``x`` minus its fitted values, same length as ``x``.
    """
    intercept = np.ones(len(controls))
    design = np.column_stack((intercept, controls))
    coef = np.linalg.lstsq(design, x, rcond=None)[0]
    fitted = design @ coef
    return x - fitted


# (signal column, outcome column) pairs evaluated by run(). Each exposure
# signal is paired with the single outcome it is hypothesized to relate to.
# NOTE(review): the "_next_1d" suffix suggests forward-looking one-day
# outcomes — confirm against how master_dataset.parquet is built.
HYPOTHESES = [
    ("net_gex",  "rv_next_1d"),
    ("net_dex",  "ret_next_1d"),
    ("net_vex",  "iv_chg_next_1d"),
    ("net_chex", "ret_next_1d"),
]


def run(
    df: pd.DataFrame,
    control_cols: list[str],
    sample_cols: tuple[str, ...] = ("vix", "atm_iv"),
) -> pd.DataFrame:
    """Evaluate every (signal, outcome) pair in HYPOTHESES net of ``control_cols``.

    Both the signal and the outcome are OLS-residualized on ``control_cols``
    (plus an intercept) before being compared. The ``res_*`` statistics
    therefore measure the association the controls do not already explain.
    The ``raw_*`` statistics use the unadjusted columns for reference.

    Args:
        df: Input data. Must contain each signal/outcome column named in
            HYPOTHESES plus every column in ``control_cols`` and
            ``sample_cols``.
        control_cols: Columns to residualize on.
        sample_cols: Columns that must be non-null for a row to be used, even
            when they are not controls. The default requires both ``vix`` and
            ``atm_iv``, so runs with different control sets drawn from those
            columns use the same rows and their statistics are directly
            comparable.

    Returns:
        One row per hypothesis with these columns:

        - signal, outcome, n, controls
        - raw_spearman, raw_p
        - res_spearman, res_p
        - res_top_minus_bot: mean residual outcome in the top minus the bottom
          residual-signal quintile
        - res_tb_t, res_tb_p: Welch t-test for that difference
    """
    rows = []
    for sig, out in HYPOTHESES:
        # dict.fromkeys de-duplicates while keeping order; a repeated column
        # label would make d[col] return a DataFrame instead of a Series.
        cols = list(dict.fromkeys([sig, out, *control_cols, *sample_cols]))
        d = df[cols].dropna()

        # Residualize both signal and outcome on the controls, so the residual
        # correlation is a partial association net of the controls.
        C = d[control_cols].to_numpy(dtype=float)
        sig_res = residualize(d[sig].to_numpy(dtype=float), C)
        out_res = residualize(d[out].to_numpy(dtype=float), C)

        # raw
        raw_rho, raw_p = stats.spearmanr(d[sig], d[out])
        # residualized
        res_rho, res_p = stats.spearmanr(sig_res, out_res)

        # Quintiles of the residual signal: label 4 = top, 0 = bottom.
        q = pd.qcut(pd.Series(sig_res), 5, labels=False).to_numpy()
        top_vals = out_res[q == 4]
        bot_vals = out_res[q == 0]
        t, pv = stats.ttest_ind(top_vals, bot_vals, equal_var=False)

        rows.append(dict(
            signal=sig, outcome=out, n=len(d),
            controls="+".join(control_cols),
            raw_spearman=float(raw_rho), raw_p=float(raw_p),
            res_spearman=float(res_rho), res_p=float(res_p),
            res_top_minus_bot=float(top_vals.mean() - bot_vals.mean()),
            res_tb_t=float(t), res_tb_p=float(pv),
        ))
    return pd.DataFrame(rows)


def main() -> int:
    """Run the orthogonal test for two control sets and write the results.

    Reads ``master_dataset.parquet`` from this script's directory. Prints one
    results table for the ``vix`` control set and one for ``vix`` + ``atm_iv``,
    then prints the Bonferroni threshold. Finally writes both tables, tagged
    with a ``control_set`` column, to ``orthogonal_results.csv`` in the same
    directory.

    Returns:
        Process exit code (0 on success).
    """
    display_cols = [
        "signal", "outcome", "n", "raw_spearman", "raw_p",
        "res_spearman", "res_p", "res_top_minus_bot", "res_tb_t", "res_tb_p",
    ]

    df = pd.read_parquet(HERE / "master_dataset.parquet")
    print(f"Loaded {len(df):,} rows")

    print("\n=== ORTHOGONAL: residualize signal on VIX ===")
    r1 = run(df, ["vix"])
    print(r1[display_cols].to_string(index=False))

    print("\n=== ORTHOGONAL: residualize signal on VIX + atm_iv ===")
    r2 = run(df, ["vix", "atm_iv"])
    print(r2[display_cols].to_string(index=False))

    # Bonferroni correction: split a family-wise alpha of 0.05 across the
    # hypotheses. The count is derived from HYPOTHESES so it cannot drift when
    # pairs are added or removed. The two control sets re-test the same
    # hypotheses, so they are treated as one family here.
    family_alpha = 0.05
    n_tests = len(HYPOTHESES)
    print(f"\nBonferroni threshold for {n_tests} tests: p < {family_alpha / n_tests:g}")

    out_path = HERE / "orthogonal_results.csv"
    combined = pd.concat(
        [r1.assign(control_set="vix"), r2.assign(control_set="vix+atm_iv")],
        ignore_index=True,
    )
    combined.to_csv(out_path, index=False)
    print(f"\nWrote {out_path}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
