"""
Build the comprehensive feature+outcome dataset for the exposure backtest.

Pulls from QuestDB (spy-dataset instance, pg-wire port 8813), computes outcome
variables (next-day/forward return, realized vol, IV change, multi-horizon), and
writes master_dataset.{csv,parquet}.

Outcome alignment: features live on row t, outcomes are built from t+1..t+k. We
explicitly shift with .shift(-k) so there is NO look-ahead at construction time.

Run:
    python build_dataset.py
"""
from __future__ import annotations

import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import psycopg2

# Artifacts (csv/parquet) are written next to this script.
HERE = Path(__file__).parent
OUT = HERE

# QuestDB pg-wire connection (spy-dataset instance; 8813 is the pg-wire port).
PG = dict(host="localhost", port=8813, user="admin", password="quest", dbname="qdb")

# Core pull: one row per trading day with every field we will use.
# Joining StockSummary (all four exposures + IV + VIX + price) with VRP (rv_*d
# precomputed rolling realized vol) on (symbol, ts).
# NOTE(review): the ASOF JOIN below has no ON clause, so QuestDB matches rows by
# designated timestamp only — the "(symbol, ts)" wording above holds only because
# both tables are single-symbol (SPY). Confirm upstream if tables ever hold more
# than one symbol.
QUERY = """
SELECT
    s.ts                    AS ts,
    s.mid                   AS px,                     -- SPY mid at 16:00 ET
    s.atm_iv                AS atm_iv,
    s.hv_20                 AS hv_20,
    s.hv_60                 AS hv_60,
    s.vix                   AS vix,
    s.vix3m                 AS vix_3m,
    s.vvix                  AS vvix,
    s.skew_25d              AS skew_25d,
    s.smile_ratio           AS smile_ratio,
    s.total_call_oi         AS call_oi,
    s.total_put_oi          AS put_oi,
    s.pc_ratio_oi           AS pc_ratio_oi,
    s.net_gex               AS net_gex,
    s.net_dex               AS net_dex,
    s.net_vex               AS net_vex,
    s.net_chex              AS net_chex,
    s.gamma_flip            AS gamma_flip,
    s.exposure_regime       AS exposure_regime,
    s.call_wall             AS call_wall,
    s.put_wall              AS put_wall,
    s.max_pain              AS max_pain,
    s.vix_structure         AS vix_structure,
    v.rv_5d                 AS rv_5d_trailing,
    v.rv_10d                AS rv_10d_trailing,
    v.rv_20d                AS rv_20d_trailing,
    v.rv_30d                AS rv_30d_trailing,
    v.variance_risk_premium AS vrp_pts,
    v.vrp_regime            AS vrp_regime,
    v.gex_regime            AS gex_regime_vrp
FROM SPY_StockSummary s
ASOF JOIN SPY_VRP v
ORDER BY s.ts
"""


def pull_raw() -> pd.DataFrame:
    """Fetch the joined daily feature table from QuestDB.

    Returns the raw QUERY result as a DataFrame, one row per trading day.

    Note: psycopg2's ``with connection:`` block only scopes a *transaction*
    (commit/rollback) — it does NOT close the connection. Close explicitly so
    the script never leaves a dangling server session.
    """
    conn = psycopg2.connect(connect_timeout=15, **PG)
    try:
        return pd.read_sql(QUERY, conn)
    finally:
        conn.close()


def compute_outcomes(df: pd.DataFrame) -> pd.DataFrame:
    """All outcome variables, strictly built from t+1 onward via .shift(-k).

    Conventions:
      - Returns: log, close-to-close; annualized x sqrt(252) where applicable.
      - rv_next_1d: |log return t->t+1| x sqrt(252) x 100 — a noisy but unbiased
        single-observation proxy for next-day realized vol, in percent vol points.
      - rv_fwd_5d / _10d: std of forward 5/10 daily log returns x sqrt(252) x 100.
      - iv_chg_next_1d / iv_chg_fwd_5d: atm_iv[t+k] - atm_iv[t].
      - px_over_flip_minus1: (px / gamma_flip) - 1 (feature, signed distance;
        NaN when gamma_flip is missing or non-positive).

    Requires columns: ts, px, atm_iv, vix, gamma_flip, net_gex, net_dex,
    net_vex, net_chex. Returns a NEW frame sorted by ts with feature and
    outcome columns appended; the input is not mutated.
    """
    # sort_values always returns a new frame, so no explicit .copy() is needed.
    df = df.sort_values("ts").reset_index(drop=True)

    # Trailing log returns (feature side). logret[t] = log(px[t] / px[t-1]).
    logret = np.log(df["px"] / df["px"].shift(1))
    df["ret_1d_trailing"] = logret
    df["ret_5d_trailing"] = logret.rolling(5).sum()

    # Forward returns (outcomes). fwd1[t] = return t -> t+1; a rolling k-sum
    # shifted back by k-1 therefore lands the total t -> t+k return on row t,
    # with no look-ahead leaking into the feature side.
    fwd1 = logret.shift(-1)
    df["ret_next_1d"] = fwd1                                    # t -> t+1
    df["ret_fwd_5d"] = fwd1.rolling(5).sum().shift(-(5 - 1))    # t -> t+5
    df["ret_fwd_10d"] = fwd1.rolling(10).sum().shift(-(10 - 1))  # t -> t+10

    # Forward realized vol (outcomes), percent vol points. Single-day |ret| is a
    # noisy one-observation proxy; 5d/10d use the std of the forward window —
    # the standard close-to-close RV definition.
    df["rv_next_1d"] = fwd1.abs() * np.sqrt(252) * 100
    df["rv_fwd_5d"] = fwd1.rolling(5).std().shift(-(5 - 1)) * np.sqrt(252) * 100
    df["rv_fwd_10d"] = fwd1.rolling(10).std().shift(-(10 - 1)) * np.sqrt(252) * 100

    # IV change (outcome).
    df["iv_chg_next_1d"] = df["atm_iv"].shift(-1) - df["atm_iv"]
    df["iv_chg_fwd_5d"] = df["atm_iv"].shift(-5) - df["atm_iv"]

    # VIX change (for cross-checking the IV outcome).
    df["vix_chg_next_1d"] = df["vix"].shift(-1) - df["vix"]

    # Feature — signed fractional distance to the gamma flip level; NaN where
    # the flip level is missing or non-positive (the ratio is meaningless there).
    df["px_over_flip_minus1"] = np.where(
        df["gamma_flip"].gt(0), df["px"] / df["gamma_flip"] - 1, np.nan
    )

    # Sign / magnitude helpers for each net exposure.
    for col in ["net_gex", "net_dex", "net_vex", "net_chex"]:
        df[f"sign_{col}"] = np.sign(df[col])
        df[f"abs_{col}"] = df[col].abs()

    # Calendar features. Parse ts ONCE instead of four separate to_datetime calls.
    ts = pd.to_datetime(df["ts"])
    df["date"] = ts.dt.date
    df["year"] = ts.dt.year
    df["month"] = ts.dt.month
    df["dow"] = ts.dt.dayofweek

    return df


def audit(df: pd.DataFrame) -> None:
    """Emit a sanity-check report to stdout.

    Covers row/date counts, null rates on features and outcomes, a manual
    recomputation of ret_next_1d on three sample rows (look-ahead check),
    distribution quantiles, and regime counts. Critical for spotting
    alignment errors before the dataset is persisted.
    """
    rule = "=" * 70
    print(rule)
    print("DATASET AUDIT")
    print(rule)
    print(f"Rows:       {len(df):,}")
    print(f"Date range: {df['ts'].min()} -> {df['ts'].max()}")
    print(f"Unique days: {df['date'].nunique():,}")
    print()

    def _null_lines(columns):
        # Shared formatter for both null-audit sections.
        for col in columns:
            missing = df[col].isna().sum()
            print(f"  {col:22s}  {missing:4d}  ({missing/len(df):5.2%})")

    # Features should be mostly non-null.
    print("FEATURE NULLS:")
    _null_lines(["net_gex", "net_dex", "net_vex", "net_chex", "px",
                 "atm_iv", "vix", "gamma_flip"])

    # Outcomes: the last k rows are null by construction of .shift(-k).
    print("\nOUTCOME NULLS (last k-1 rows naturally null due to shift):")
    _null_lines(["ret_next_1d", "rv_next_1d", "rv_fwd_5d", "rv_fwd_10d",
                 "iv_chg_next_1d", "vix_chg_next_1d"])

    # ret_next_1d for day t must equal log(px[t+1]/px[t]) — recompute by hand.
    sample = df.dropna(subset=["ret_next_1d"]).iloc[100:103]
    print("\nLOOK-AHEAD SMOKE TEST (3 rows, manual-compute vs dataset):")
    for row in sample.index:
        px_here = df.loc[row, "px"]
        px_next = df.loc[row + 1, "px"] if row + 1 < len(df) else np.nan
        recomputed = np.log(px_next / px_here)
        stored = df.loc[row, "ret_next_1d"]
        verdict = "MISMATCH" if not abs(recomputed - stored) < 1e-10 else "OK"
        print(f"  {df.loc[row, 'ts']}  manual={recomputed:+.6f}  stored={stored:+.6f}  {verdict}")

    # Quantile cheat sheets for features and outcomes.
    qs = [0.01, 0.25, 0.5, 0.75, 0.99]
    print("\nFEATURE DISTRIBUTIONS (quantiles):")
    for col in ("net_gex", "net_dex", "net_vex", "net_chex"):
        quants = df[col].quantile(qs)
        cells = [f"p{int(p*100):02d}={quants.loc[p]:+.2e}" for p in qs]
        print(f"  {col:10s}  " + "  ".join(cells))

    print("\nOUTCOME DISTRIBUTIONS:")
    for col in ("ret_next_1d", "rv_next_1d", "iv_chg_next_1d"):
        quants = df[col].quantile(qs)
        cells = [f"p{int(p*100):02d}={quants.loc[p]:+.4f}" for p in qs]
        print(f"  {col:22s}  " + "  ".join(cells))

    print("\nREGIME COUNTS:")
    print(df["exposure_regime"].value_counts(dropna=False).to_string())


def main() -> int:
    """Pull raw rows, compute outcomes, audit, and write master_dataset.{csv,parquet}.

    Returns 0 on success (parquet write is best-effort: skipped with a note
    when no parquet engine is installed).
    """
    print("Pulling from QuestDB...")
    frame = pull_raw()
    print(f"  got {len(frame):,} rows")

    print("Computing outcomes...")
    dataset = compute_outcomes(frame)

    audit(dataset)

    csv_file = OUT / "master_dataset.csv"
    parquet_file = OUT / "master_dataset.parquet"
    dataset.to_csv(csv_file, index=False)
    try:
        dataset.to_parquet(parquet_file, index=False)
    except Exception as exc:
        # No pyarrow/fastparquet available — CSV is the fallback artifact.
        print(f"  (parquet skipped: {exc})")

    print(f"\nWrote: {csv_file}  ({csv_file.stat().st_size/1024:.0f} KB)")
    if parquet_file.exists():
        print(f"Wrote: {parquet_file}  ({parquet_file.stat().st_size/1024:.0f} KB)")
    return 0


# Script entry point: propagate main()'s return code to the shell.
if __name__ == "__main__":
    sys.exit(main())
