"""
Run the per-exposure statistical tests against pre-registered hypotheses.

Outputs (in this folder):
    results_<name>.csv      — per-hypothesis quintile table, one file per
                              entry in HYPOTHESES/SECONDARY, e.g.
                              results_H1_GEX_to_RV.csv   (GEX -> next-day RV)
                              results_H2_DEX_to_ret.csv  (DEX -> next-day return)
                              results_H3_VEX_to_dIV.csv  (VEX -> next-day IV change)
                              results_H4_CHEX_to_ret.csv (CHEX -> next-day return)
    summary_headline.csv    — one-row-per-hypothesis headline stats
    regime_splits.csv       — same tests run inside each regime bucket
    baselines.csv           — each signal vs prior-return / VIX / outcome-AR1
    train_test.csv          — robustness: in-sample vs out-of-sample signs
"""
from __future__ import annotations

from pathlib import Path
from dataclasses import dataclass
from typing import Callable

import numpy as np
import pandas as pd
from scipy import stats

HERE = Path(__file__).parent


# ---------- helpers ----------

def _pct(x: float) -> str:
    return f"{x*100:+.2f}%" if pd.notna(x) else "nan"


def quintile_table(df: pd.DataFrame, signal: str, outcome: str, n_bins: int = 5) -> pd.DataFrame:
    """Quantile sort of `outcome` by `signal`. Each row: mean, median, std, n, hit rate.

    Hit rate is defined as P(outcome > 0) per bucket — meaningful for return-like
    outcomes; interpret cautiously for RV (always positive) or IV change (symmetric).

    If the signal has mass points and quantile edges collapse, adjacent bins are
    merged (fewer than `n_bins` rows returned) instead of raising.
    """
    d = df[[signal, outcome]].dropna().copy()
    # labels=False + duplicates="drop": a signal with repeated values (e.g. lots
    # of zeros) would otherwise make pd.qcut raise on non-unique bin edges.
    # This matches the other qcut call sites in this module.
    codes = pd.qcut(d[signal], q=n_bins, labels=False, duplicates="drop").astype(int)
    labels = [f"Q{i + 1}" for i in range(int(codes.max()) + 1)]
    # Ordered categorical so the groupby index comes out Q1..Qk in bucket order.
    d["q"] = pd.Categorical([labels[c] for c in codes], categories=labels, ordered=True)

    g = d.groupby("q", observed=True)  # single groupby reused for all columns
    o = g[outcome]
    return pd.DataFrame({
        "n": o.count(),
        "signal_mean": g[signal].mean(),
        "outcome_mean": o.mean(),
        "outcome_median": o.median(),
        "outcome_std": o.std(),
        "pct_pos": o.apply(lambda s: (s > 0).mean()),
    })


def top_minus_bottom_ttest(df: pd.DataFrame, signal: str, outcome: str,
                            n_bins: int = 5) -> dict:
    """Welch t-test of `outcome` in the top vs bottom `signal` quantile bucket.

    Returns a dict with top/bottom means, their difference, the t statistic,
    two-sided p-value, and bucket sizes. Uses duplicates="drop" (consistent
    with regime_splits / train_test_split_test) so a signal with mass points
    collapses bins instead of raising; top/bottom are then the highest and
    lowest surviving bins.
    """
    d = df[[signal, outcome]].dropna().copy()
    d["q"] = pd.qcut(d[signal], q=n_bins, labels=False, duplicates="drop")
    # Use the actual max/min bin codes — identical to n_bins-1 / 0 when no
    # edges were dropped, robust when some were.
    top = d.loc[d["q"] == d["q"].max(), outcome].values
    bot = d.loc[d["q"] == d["q"].min(), outcome].values
    t, p = stats.ttest_ind(top, bot, equal_var=False)  # Welch: unequal variances
    return dict(
        top_mean=float(np.mean(top)),
        bot_mean=float(np.mean(bot)),
        diff=float(np.mean(top) - np.mean(bot)),
        t_stat=float(t),
        p_value=float(p),
        n_top=len(top),
        n_bot=len(bot),
    )


def spearman(df: pd.DataFrame, signal: str, outcome: str) -> dict:
    """Spearman rank correlation of `signal` against `outcome` on complete cases."""
    pair = df[[signal, outcome]].dropna()
    rho, pval = stats.spearmanr(pair[signal], pair[outcome])
    return {
        "spearman_rho": float(rho),
        "spearman_p": float(pval),
        "spearman_n": len(pair),
    }


def hit_rate_sign(df: pd.DataFrame, signal: str, outcome: str) -> dict:
    """Sign agreement: P(sign(signal) == sign(outcome)), zeros excluded.

    Returns hit rate, sample size, and a two-sided binomial p-value against
    a 50% coin flip. If no usable rows remain, hit_rate and the p-value are
    NaN and hit_n is 0 (same values the previous implementation produced,
    minus the empty-mean RuntimeWarning).
    """
    d = df[[signal, outcome]].dropna()
    d = d[(d[signal] != 0) & (d[outcome] != 0)]  # sign undefined at 0
    n = len(d)
    if n == 0:
        return dict(hit_rate=float("nan"), hit_n=0, hit_binom_p=float("nan"))
    # Compute the agreement mask once; previously it was evaluated twice.
    agree = np.sign(d[signal]) == np.sign(d[outcome])
    p = stats.binomtest(int(agree.sum()), n, p=0.5).pvalue
    return dict(hit_rate=float(agree.mean()), hit_n=n, hit_binom_p=float(p))


# ---------- baselines ----------

def baseline_comparison(df: pd.DataFrame, signal: str, outcome: str) -> pd.DataFrame:
    """Compare raw signal correlation with outcome vs simpler baselines.

    Baselines:
      - prior_return (ret_1d_trailing)  — momentum
      - vix                              — vol regime
      - outcome_ar1 (lag of outcome)     — pure persistence

    All four predictors are scored on the same complete-case sample, so `n`
    is identical across rows.
    """
    d = df[[signal, outcome, "ret_1d_trailing", "vix"]].copy()
    # Lag the outcome BEFORE dropping incomplete rows: lagging after dropna
    # made the "AR(1)" baseline the previous *surviving* row, which can be
    # several observations back whenever rows are dropped.
    # NOTE(review): still assumes df arrives in time order — confirm caller.
    d["outcome_lag1"] = d[outcome].shift(1)
    d = d.dropna()

    rows = []
    for name, s in [
        (signal, d[signal]),
        ("prior_return", d["ret_1d_trailing"]),
        ("vix", d["vix"]),
        ("outcome_ar1", d["outcome_lag1"]),
    ]:
        rho, p = stats.spearmanr(s, d[outcome])
        rows.append(dict(predictor=name, spearman_rho=float(rho),
                         spearman_p=float(p), n=len(d)))
    return pd.DataFrame(rows)


# ---------- regime splits ----------

def regime_splits(df: pd.DataFrame, signal: str, outcome: str,
                   n_bins: int = 5) -> pd.DataFrame:
    """Same top-vs-bottom test inside different regimes.

    Regimes: full sample, three calendar windows around COVID, and a
    high/low-VIX split against a trailing 252-day 75th-percentile threshold.
    Returns one row per regime with Welch t-test and Spearman stats; regimes
    with fewer than 50 rows are skipped, and any per-regime failure is
    recorded as an `error` row instead of aborting the whole run.
    """
    d = df.dropna(subset=[signal, outcome, "vix", "ts"]).copy()
    d["ts"] = pd.to_datetime(d["ts"])

    # Trailing rolling 75th percentile of VIX defines the "high vol" bar.
    # NOTE(review): rolling operates on the row order of the filtered frame —
    # this assumes df is already sorted by ts; confirm against the caller.
    vix_75 = d["vix"].rolling(252, min_periods=30).quantile(0.75)
    # NaN threshold (first 29 warm-up rows) compares False, so those rows
    # land in "low_vix" by construction.
    d["vix_regime"] = np.where(d["vix"] > vix_75, "high_vix", "low_vix")

    # Date-string comparisons are coerced by pandas against the datetime column.
    splits = {
        "all": d,
        "pre_covid": d[d["ts"] < "2020-02-15"],
        "covid": d[(d["ts"] >= "2020-02-15") & (d["ts"] < "2020-06-01")],
        "post_covid": d[d["ts"] >= "2020-06-01"],
        "high_vix": d[d["vix_regime"] == "high_vix"],
        "low_vix": d[d["vix_regime"] == "low_vix"],
    }

    rows = []
    for name, sub in splits.items():
        if len(sub) < 50:  # too small to quintile
            continue
        try:
            sub2 = sub.copy()
            # duplicates="drop" merges bins when the signal has mass points.
            sub2["q"] = pd.qcut(sub2[signal], q=n_bins, labels=False, duplicates="drop")
            top = sub2.loc[sub2["q"] == sub2["q"].max(), outcome].values
            bot = sub2.loc[sub2["q"] == sub2["q"].min(), outcome].values
            t, p = stats.ttest_ind(top, bot, equal_var=False)
            rho, rp = stats.spearmanr(sub[signal], sub[outcome])
            rows.append(dict(
                regime=name, n=len(sub),
                bot_mean=float(np.mean(bot)), top_mean=float(np.mean(top)),
                diff=float(np.mean(top) - np.mean(bot)),
                t_stat=float(t), p_value=float(p),
                spearman_rho=float(rho), spearman_p=float(rp),
            ))
        except Exception as e:
            # Deliberate best-effort: record the failure, keep the other regimes.
            rows.append(dict(regime=name, n=len(sub), error=str(e)))
    return pd.DataFrame(rows)


def train_test_split_test(df: pd.DataFrame, signal: str, outcome: str,
                           frac: float = 0.7) -> pd.DataFrame:
    """Chronological in-sample / out-of-sample robustness check.

    Sorts rows by timestamp, splits at the `frac` quantile of `ts`, and runs
    the Spearman and top-vs-bottom quintile tests inside each half. Returns
    one row per split.
    """
    d = df.dropna(subset=[signal, outcome, "ts"]).copy()
    d["ts"] = pd.to_datetime(d["ts"])
    # BUG FIX: previous code did `d["ts"] = ...sort_values().values`, which
    # sorts the timestamp COLUMN in isolation and reassigns dates across
    # rows, scrambling the signal/outcome <-> date alignment. Sort whole rows.
    d = d.sort_values("ts")
    cut = d["ts"].quantile(frac)
    rows = []
    for name, sub in [("in_sample", d[d["ts"] <= cut]), ("out_of_sample", d[d["ts"] > cut])]:
        rho, p = stats.spearmanr(sub[signal], sub[outcome])
        sub2 = sub.copy()
        sub2["q"] = pd.qcut(sub2[signal], 5, labels=False, duplicates="drop")
        top = sub2.loc[sub2["q"] == sub2["q"].max(), outcome].values
        bot = sub2.loc[sub2["q"] == sub2["q"].min(), outcome].values
        t, pv = stats.ttest_ind(top, bot, equal_var=False)
        rows.append(dict(
            split=name, n=len(sub),
            spearman_rho=float(rho), spearman_p=float(p),
            tb_diff=float(np.mean(top) - np.mean(bot)),
            tb_t=float(t), tb_p=float(pv),
        ))
    return pd.DataFrame(rows)


# ---------- main runner ----------

@dataclass
class Hypothesis:
    """One pre-registered signal -> outcome test to run through the pipeline."""
    name: str     # stable identifier; used in output filenames and the `hypothesis` column
    signal: str   # predictor column in the master dataset
    outcome: str  # forward-looking outcome column
    note: str     # human-readable description for printed/CSV summaries


# Primary pre-registered hypotheses. `name` feeds the results_<name>.csv
# filenames and the `hypothesis` column of the combined output CSVs.
HYPOTHESES = [
    Hypothesis("H1_GEX_to_RV",     "net_gex",  "rv_next_1d",    "GEX -> next-day realized vol"),
    Hypothesis("H2_DEX_to_ret",    "net_dex",  "ret_next_1d",   "DEX -> next-day return"),
    Hypothesis("H3_VEX_to_dIV",    "net_vex",  "iv_chg_next_1d","VEX -> next-day IV change"),
    Hypothesis("H4_CHEX_to_ret",   "net_chex", "ret_next_1d",   "CHEX -> next-day return"),
]

# Secondary: longer horizons + regime variable (run through the same pipeline).
SECONDARY = [
    Hypothesis("S1_GEX_to_RV5d",   "net_gex",           "rv_fwd_5d",      "GEX -> 5d forward RV"),
    Hypothesis("S2_DistFlip_to_r", "px_over_flip_minus1","ret_next_1d",   "Dist-to-flip -> next-day return"),
    Hypothesis("S3_GEX_to_VIXchg", "net_gex",           "vix_chg_next_1d","GEX -> VIX change"),
]


def run_all(df: pd.DataFrame) -> None:
    """Run every primary and secondary hypothesis and write all output CSVs.

    Per hypothesis: a quintile table (results_<name>.csv) plus one row/frame
    appended to the headline, regime-split, baseline, and train/test outputs.
    """
    headline = []
    regime_frames = []
    baseline_frames = []
    traintest_frames = []

    for hyp in HYPOTHESES + SECONDARY:
        # Quintile table — one CSV per hypothesis (index = quintile labels).
        quintile_table(df, hyp.signal, hyp.outcome).to_csv(HERE / f"results_{hyp.name}.csv")

        # One-row headline summary combining all three headline statistics.
        row = dict(hypothesis=hyp.name, signal=hyp.signal,
                   outcome=hyp.outcome, note=hyp.note)
        row.update(top_minus_bottom_ttest(df, hyp.signal, hyp.outcome))
        row.update(spearman(df, hyp.signal, hyp.outcome))
        row.update(hit_rate_sign(df, hyp.signal, hyp.outcome))
        headline.append(row)

        # The three per-hypothesis DataFrames share the same tag-and-collect
        # pattern, so drive them from one table.
        for frames, test_fn in ((regime_frames, regime_splits),
                                (baseline_frames, baseline_comparison),
                                (traintest_frames, train_test_split_test)):
            out = test_fn(df, hyp.signal, hyp.outcome)
            out["hypothesis"] = hyp.name
            frames.append(out)

    pd.DataFrame(headline).to_csv(HERE / "summary_headline.csv", index=False)
    pd.concat(regime_frames, ignore_index=True).to_csv(HERE / "regime_splits.csv", index=False)
    pd.concat(baseline_frames, ignore_index=True).to_csv(HERE / "baselines.csv", index=False)
    pd.concat(traintest_frames, ignore_index=True).to_csv(HERE / "train_test.csv", index=False)


def print_headline(df: pd.DataFrame) -> None:
    """Human-readable headline print for the transcript / console."""
    rule = "=" * 100
    print(rule)
    print("HEADLINE RESULTS")
    print(rule)
    # Primary hypotheses only — secondary ones go to the CSVs, not the console.
    for hyp in HYPOTHESES:
        tt = top_minus_bottom_ttest(df, hyp.signal, hyp.outcome)
        sp = spearman(df, hyp.signal, hyp.outcome)
        hr = hit_rate_sign(df, hyp.signal, hyp.outcome)
        print(f"\n[{hyp.name}] {hyp.note}")
        print(f"  quintile top vs bot:   top={tt['top_mean']:+.4f}  bot={tt['bot_mean']:+.4f}  "
              f"diff={tt['diff']:+.4f}  t={tt['t_stat']:+.2f}  p={tt['p_value']:.3g}  n={tt['n_top']+tt['n_bot']}")
        print(f"  Spearman rho:          {sp['spearman_rho']:+.4f}  p={sp['spearman_p']:.3g}")
        print(f"  sign(signal)==sign(outcome) hit rate: {hr['hit_rate']:.3f}  p_vs_50={hr['hit_binom_p']:.3g}  n={hr['hit_n']}")


def main() -> int:
    """Load the master dataset (parquet preferred, CSV fallback), run all tests.

    Returns 0 as the process exit code.
    """
    parquet_path = HERE / "master_dataset.parquet"
    if parquet_path.exists():
        df = pd.read_parquet(parquet_path)
    else:
        df = pd.read_csv(HERE / "master_dataset.csv", parse_dates=["ts"])
    print(f"Loaded {len(df):,} rows  ({df['ts'].min()} -> {df['ts'].max()})")

    run_all(df)
    print_headline(df)
    print("\nWrote: results_*.csv, summary_headline.csv, regime_splits.csv, baselines.csv, train_test.csv")
    return 0


if __name__ == "__main__":
    import sys  # imported here only for exit-code plumbing when run as a script
    sys.exit(main())
