"""
Controlled / partial-correlation tests.

The naive tests in run_analysis.py show GEX strongly predicts realized vol.
But VIX alone has an even stronger correlation. The critical question: does GEX
add incremental predictive power BEYOND what VIX already captures?

Method:
  1. Partial Spearman / OLS of outcome on (signal, VIX, prior-outcome).
  2. Residualize signal on controls; test residual vs outcome residuals.
  3. Double-sort: quintile on VIX, then within each VIX bucket sort on signal.

Writes:
    controls_incremental.csv   — coefficient / significance of each signal
                                 after controlling for VIX + outcome_ar1
    controls_double_sort.csv   — outcome mean in (VIX quintile x signal quintile)
"""
from __future__ import annotations

from pathlib import Path
import numpy as np
import pandas as pd
from scipy import stats

# Directory containing this script; input datasets are read from here and
# the output CSVs are written next to it (see main()).
HERE = Path(__file__).parent


def _ols(X: np.ndarray, y: np.ndarray):
    """Thin wrapper returning coefs, t-stats, p-values, R²."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    X_ = np.column_stack([np.ones(len(X)), X])
    beta, *_ = np.linalg.lstsq(X_, y, rcond=None)
    yhat = X_ @ beta
    resid = y - yhat
    n, k = X_.shape
    sigma2 = (resid @ resid) / (n - k)
    cov = sigma2 * np.linalg.inv(X_.T @ X_)
    se = np.sqrt(np.diag(cov))
    tvals = beta / se
    pvals = 2 * (1 - stats.t.cdf(np.abs(tvals), df=n - k))
    ss_res = resid @ resid
    ss_tot = ((y - y.mean()) ** 2).sum()
    r2 = 1 - ss_res / ss_tot
    return dict(beta=beta, se=se, t=tvals, p=pvals, r2=float(r2), n=n, k=k)


def incremental(df: pd.DataFrame, signal: str, outcome: str) -> dict:
    """Two regressions: y ~ VIX + AR1 (baseline) and y ~ VIX + AR1 + signal.
    Report incremental R² and signal t-stat after controls.

    Fix vs. the earlier version: the AR(1) lag is computed BEFORE rows with
    missing signal/VIX values are dropped, so ``outcome_ar1`` is always the
    true previous observation of the outcome — not the previous *surviving*
    row, which silently spanned gaps whenever the signal had NaNs.
    """
    d = df[[signal, outcome, "vix"]].copy()
    d["outcome_ar1"] = d[outcome].shift(1)
    d = d.dropna()
    y = d[outcome].values
    # Design column order after the auto-intercept in _ols:
    # [const, vix, outcome_ar1, signal] -> indices 1, 2, 3 below.
    base = _ols(d[["vix", "outcome_ar1"]].values, y)
    full = _ols(d[["vix", "outcome_ar1", signal]].values, y)
    return dict(
        signal=signal, outcome=outcome, n=len(d),
        r2_baseline=base["r2"], r2_full=full["r2"],
        delta_r2=full["r2"] - base["r2"],
        signal_coef=float(full["beta"][3]),
        signal_t=float(full["t"][3]),
        signal_p=float(full["p"][3]),
        vix_coef=float(full["beta"][1]),
        vix_t=float(full["t"][1]),
        ar1_coef=float(full["beta"][2]),
        ar1_t=float(full["t"][2]),
    )


def double_sort(df: pd.DataFrame, signal: str, outcome: str,
                n_bins: int = 5) -> pd.DataFrame:
    """Double-sort: within each VIX quintile, split on signal, report outcome means.

    Returns a tidy frame with columns [vix_q, sig_q, mean, count].

    Fix vs. the earlier version: passing a fixed 5-element ``labels`` list to
    ``pd.qcut`` together with ``duplicates="drop"`` raises ValueError whenever
    tied signal values collapse bin edges inside a VIX bucket (labels must
    match the surviving bin count). Using ``labels=False`` integer codes and
    rendering them as "S1".."Sk" afterwards is robust to collapsed bins.
    """
    d = df[[signal, outcome, "vix"]].dropna().copy()
    d["vix_q"] = pd.qcut(d["vix"], n_bins, labels=[f"V{i+1}" for i in range(n_bins)])

    def _bucket(s: pd.Series) -> pd.Series:
        # Integer bin codes survive duplicate-edge dropping; NaN codes
        # (unbinnable values) are propagated untouched.
        codes = pd.qcut(s, n_bins, labels=False, duplicates="drop")
        return codes.map(lambda c: f"S{int(c) + 1}", na_action="ignore")

    d["sig_q"] = d.groupby("vix_q", observed=True)[signal].transform(_bucket)
    tbl = (d.groupby(["vix_q", "sig_q"], observed=True)[outcome]
             .agg(["mean", "count"])
             .reset_index())
    return tbl


# (signal column, outcome column) pairs fed to incremental(); the first four
# also get a double-sort table in main().
HYPOTHESES = [
    ("net_gex",  "rv_next_1d"),
    ("net_dex",  "ret_next_1d"),
    ("net_vex",  "iv_chg_next_1d"),
    ("net_chex", "ret_next_1d"),
    ("net_gex",  "rv_fwd_5d"),   # long-horizon reference
]


def main() -> int:
    """Load the master dataset, run all controlled tests, write both CSVs.

    Prefers the parquet file when present; otherwise falls back to the CSV
    (parsing the ``ts`` column as dates). Returns 0 on success.
    """
    parquet_path = HERE / "master_dataset.parquet"
    csv_path = HERE / "master_dataset.csv"
    if parquet_path.exists():
        df = pd.read_parquet(parquet_path)
    else:
        df = pd.read_csv(csv_path, parse_dates=["ts"])
    print(f"Loaded {len(df):,} rows")

    # Incremental-R² table: one row per (signal, outcome) hypothesis.
    inc = pd.DataFrame([incremental(df, sig, out) for sig, out in HYPOTHESES])
    inc.to_csv(HERE / "controls_incremental.csv", index=False)

    print("\n=== INCREMENTAL VALUE OF SIGNAL AFTER VIX + AR1 ===")
    report_cols = ["signal","outcome","n","r2_baseline","r2_full","delta_r2",
                   "signal_coef","signal_t","signal_p"]
    print(inc[report_cols].to_string(index=False))

    # One stacked double-sort table covering the primary hypotheses only.
    frames = []
    for sig, out in HYPOTHESES[:4]:
        frame = double_sort(df, sig, out)
        frame["signal"] = sig
        frame["outcome"] = out
        frames.append(frame)
    pd.concat(frames, ignore_index=True).to_csv(
        HERE / "controls_double_sort.csv", index=False)

    # Compact console views of the two headline double-sorts.
    print("\n=== DOUBLE-SORT: GEX -> rv_next_1d (rows=VIX quintile, cols=GEX quintile) ===")
    gex_grid = double_sort(df, "net_gex", "rv_next_1d").pivot(
        index="vix_q", columns="sig_q", values="mean")
    print(gex_grid.to_string(float_format=lambda v: f"{v:6.2f}"))

    print("\n=== DOUBLE-SORT: VEX -> iv_chg_next_1d ===")
    vex_grid = double_sort(df, "net_vex", "iv_chg_next_1d").pivot(
        index="vix_q", columns="sig_q", values="mean")
    print(vex_grid.to_string(float_format=lambda v: f"{v:+6.3f}"))

    print("\nWrote: controls_incremental.csv, controls_double_sort.csv")
    return 0


if __name__ == "__main__":
    import sys
    sys.exit(main())
