from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd


# -----------------------------
# 1) Shared utilities: rolling quantile + indicators
# -----------------------------
def rolling_quantile(
    s: pd.Series, q: float, window: int, min_periods: Optional[int] = None
) -> pd.Series:
    """Return the rolling ``q``-quantile of ``s`` over ``window`` rows.

    ``min_periods`` defaults to ``window``, so a value is produced only once a
    full window of non-NaN observations is available (NaN before that).
    """
    if min_periods is None:
        min_periods = window
    return s.rolling(window=window, min_periods=min_periods).quantile(q)


def compute_indicators_ohlc(
    df: pd.DataFrame, ma_window: int = 120, vol_window: int = 20, dd_window: int = 60
) -> pd.DataFrame:
    """Attach trend / volatility / drawdown indicators to an OHLC frame.

    Parameters
    ----------
    df:
        Must contain a ``Close`` column (Date/Open/High/Low are optional and
        carried through). If a ``Date`` column exists, the copy is sorted by it
        and re-indexed 0..n-1 first.
    ma_window, vol_window, dd_window:
        Look-back lengths of the moving average, return volatility and rolling
        peak. Each indicator is NaN until its window is full.

    Returns
    -------
    A copy of ``df`` with extra columns:
        ret       simple return of Close (``pct_change``)
        ma        ``ma_window`` simple moving average of Close
        ma_ratio  Close / ma
        ma_slope  first difference of ma
        vol       ``vol_window`` rolling std of ret
        dd        Close / rolling ``dd_window`` max(Close) - 1 (<= 0 for positive prices)
    """
    out = df.copy()
    if "Date" in out.columns:
        out = out.sort_values("Date").reset_index(drop=True)

    # Non-numeric Close entries become NaN instead of raising.
    close = pd.to_numeric(out["Close"], errors="coerce")

    out["ret"] = close.pct_change()
    out["ma"] = close.rolling(ma_window, min_periods=ma_window).mean()
    out["ma_ratio"] = close / out["ma"]
    # MA slope as a one-step difference of the MA (a regression slope would be
    # smoother; the difference keeps this minimal).
    out["ma_slope"] = out["ma"].diff()
    out["vol"] = out["ret"].rolling(vol_window, min_periods=vol_window).std()

    # Current drawdown from the rolling dd_window peak -- not the maximum
    # drawdown inside the window.
    roll_max = close.rolling(dd_window, min_periods=dd_window).max()
    out["dd"] = (close / roll_max) - 1.0
    return out


# -----------------------------
# 2) Single-index state machine: trend fallback + high-volatility bull
# -----------------------------
@dataclass
class RegimeParams:
    # Look-back (rows) for every rolling-quantile threshold.
    q_window: int = 252

    # Trend thresholds: quantiles of ma_ratio (enter above / exit below).
    trend_enter_q: float = 0.75
    trend_exit_q: float = 0.60

    # Risk thresholds: quantiles of vol and of |dd| (risk-on / risk-off).
    vol_enter_q: float = 0.80
    vol_exit_q: float = 0.70
    dd_enter_q: float = 0.80
    dd_exit_q: float = 0.70

    # Trend fallback: Close > MA and MA slope > 0 for this many consecutive rows.
    trend_fallback_n: int = 8

    # Position cap per state.
    cap_trend_lowrisk: float = 1.00
    cap_trend_highvol: float = 0.70
    cap_range: float = 0.45
    cap_risk_mild: float = 0.30
    cap_risk_severe: float = 0.20

    # Indicator windows forwarded to compute_indicators_ohlc (defaults equal
    # that function's defaults, so existing configurations are unchanged).
    ma_window: int = 120
    vol_window: int = 20
    dd_window: int = 60

    # A risk-on, non-trend row is RISK_SEVERE when |dd| or vol exceeds its
    # entry threshold by this factor; otherwise it is RISK_MILD.
    severe_mult: float = 1.15


class SingleIndexRegimeModel:
    """Hysteresis state machine over one index.

    Input: a single index's OHLC DataFrame.
    Output: the indicator frame plus per-row ``state`` and ``position_cap``.
    """

    STATES = ("TREND_BULL", "TREND_BULL_HIGHVOL", "RANGE", "RISK_MILD", "RISK_SEVERE")

    # Every one of these must be non-NaN for a row to drive a transition.
    _REQUIRED_COLS = (
        "ma_ratio", "vol", "dd",
        "tr_enter", "tr_exit",
        "vol_enter", "vol_exit",
        "dd_enter", "dd_exit",
    )

    def __init__(self, params: RegimeParams):
        self.p = params
        # Final state of the most recent run(); every run() starts at "RANGE".
        self.current_state: str = "RANGE"

    def _fallback_trend(self, df_ind: pd.DataFrame) -> pd.Series:
        """Return True where Close > MA and MA slope > 0 held for N consecutive rows."""
        cond = (df_ind["ma_ratio"] > 1.0) & (df_ind["ma_slope"] > 0)
        n = self.p.trend_fallback_n
        # The rolling min of a 0/1 series is 1 exactly when every row in the
        # window is True; incomplete windows give NaN, which compares False.
        return cond.astype(float).rolling(n, min_periods=n).min() == 1.0

    def _add_thresholds(self, ind: pd.DataFrame) -> None:
        """Add rolling-quantile entry/exit thresholds and the trend fallback to ``ind`` in place."""
        p = self.p
        qwin = p.q_window

        def quant(series: pd.Series, level: float) -> pd.Series:
            return rolling_quantile(series, level, qwin, min_periods=qwin)

        # Trend thresholds from ma_ratio.
        ind["tr_enter"] = quant(ind["ma_ratio"], p.trend_enter_q)
        ind["tr_exit"] = quant(ind["ma_ratio"], p.trend_exit_q)
        # Risk thresholds from vol.
        ind["vol_enter"] = quant(ind["vol"], p.vol_enter_q)
        ind["vol_exit"] = quant(ind["vol"], p.vol_exit_q)
        # dd is <= 0; take quantiles of its magnitude so "higher = riskier".
        dd_abs = (-ind["dd"]).clip(lower=0)
        ind["dd_enter"] = quant(dd_abs, p.dd_enter_q)
        ind["dd_exit"] = quant(dd_abs, p.dd_exit_q)

        ind["trend_fallback"] = self._fallback_trend(ind)

    def _transition(self, prev: str, r) -> str:
        """Return the next state from ``prev`` and one fully-populated row ``r``.

        ``r`` is a row namedtuple carrying the ``_REQUIRED_COLS`` fields plus
        ``trend_fallback``.
        """
        p = self.p
        dd_abs = -r.dd
        risk_high = (r.vol >= r.vol_enter) or (dd_abs >= r.dd_enter)
        risk_low = (r.vol <= r.vol_exit) and (dd_abs <= r.dd_exit)

        fallback = bool(r.trend_fallback)
        trend_on = (r.ma_ratio >= r.tr_enter) or fallback
        trend_off = (r.ma_ratio <= r.tr_exit) and not fallback

        # Decide the trend label first, then split it by risk.
        if trend_on:
            # An established bull is only downgraded to HIGHVOL by risk, never
            # knocked back to RANGE.
            return "TREND_BULL_HIGHVOL" if risk_high else "TREND_BULL"

        if trend_off:
            if not risk_high:
                return "RANGE"
            severe = (dd_abs >= r.dd_enter * p.severe_mult) or (
                r.vol >= r.vol_enter * p.severe_mult
            )
            return "RISK_SEVERE" if severe else "RISK_MILD"

        # Neutral zone (between exit and enter thresholds): keep the previous
        # state, with risk overrides only.
        # NOTE(review): TREND_BULL_HIGHVOL never reverts to TREND_BULL here even
        # when risk_low holds -- confirm this asymmetry is intended.
        if prev.startswith("TREND") and risk_high:
            return "TREND_BULL_HIGHVOL"
        if prev.startswith("RISK") and risk_low:
            return "RANGE"
        return prev

    def run(self, df_ohlc: pd.DataFrame) -> pd.DataFrame:
        """Classify every row of ``df_ohlc`` into a regime.

        Returns the indicator/threshold frame with ``state`` and
        ``position_cap`` columns added. Each call starts from "RANGE", so
        results do not depend on earlier calls. ``self.current_state`` holds
        the final state afterwards.
        """
        p = self.p
        ind = compute_indicators_ohlc(
            df_ohlc, ma_window=p.ma_window, vol_window=p.vol_window, dd_window=p.dd_window
        )
        self._add_thresholds(ind)

        required = list(self._REQUIRED_COLS)
        valid = ind[required].notna().all(axis=1).to_numpy()
        rows = ind[required + ["trend_fallback"]].itertuples(index=False)

        state = "RANGE"
        states = []
        for ok, row in zip(valid, rows):
            if ok:
                state = self._transition(state, row)
            # Rows with any missing input (warm-up periods in particular) keep
            # the previous state instead of being forced to a default.
            states.append(state)
        self.current_state = state

        ind["state"] = states
        ind["position_cap"] = [self._cap_from_state(s) for s in states]
        return ind

    def _cap_from_state(self, state: str) -> float:
        """Map a state name to its position cap; unknown names get ``cap_range``."""
        caps = {
            "TREND_BULL": self.p.cap_trend_lowrisk,
            "TREND_BULL_HIGHVOL": self.p.cap_trend_highvol,
            "RISK_SEVERE": self.p.cap_risk_severe,
            "RISK_MILD": self.p.cap_risk_mild,
        }
        return caps.get(state, self.p.cap_range)


# -----------------------------
# 3) Two-index fusion: voting
# -----------------------------
@dataclass
class FusionParams:
    """
    vote_mode:
      - "either_trend": either index trending => fused trend (more sensitive)
      - "both_trend":   both indices trending => fused trend (more conservative)
      Any other value behaves like "either_trend".
    """

    vote_mode: str = "either_trend"
    # How the fused position cap is formed:
    #   "min": the more conservative of the two (recommended; also the
    #          behavior for any unrecognized value)
    #   "avg": the mean (more aggressive)
    cap_merge: str = "min"


class FusionRegimeModel:
    """Fuse the regimes of two indices (CSI300 and CSI500).

    Output: fusion_state + fusion_cap alongside each index's own state and cap.
    """

    def __init__(self, single_params: RegimeParams, fusion_params: FusionParams):
        self.p_single = single_params
        self.p_fusion = fusion_params
        self.m300 = SingleIndexRegimeModel(single_params)
        self.m500 = SingleIndexRegimeModel(single_params)

    @staticmethod
    def _is_trend(state: str) -> bool:
        return state.startswith("TREND")

    @classmethod
    def _fuse_state(cls, s300: str, s500: str, either: bool) -> str:
        """Combine two per-index states into one fused state.

        ``either`` selects the "either_trend" vote; otherwise both must trend.
        In a fused trend, HIGHVOL wins if either side is HIGHVOL. Outside a
        trend, the worse risk state wins.
        """
        t300, t500 = cls._is_trend(s300), cls._is_trend(s500)
        is_trend = (t300 or t500) if either else (t300 and t500)
        pair = (s300, s500)
        if is_trend:
            return "TREND_BULL_HIGHVOL" if "TREND_BULL_HIGHVOL" in pair else "TREND_BULL"
        if "RISK_SEVERE" in pair:
            return "RISK_SEVERE"
        if "RISK_MILD" in pair:
            return "RISK_MILD"
        return "RANGE"

    def run(self, df300: pd.DataFrame, df500: pd.DataFrame, date_col: str = "Date") -> pd.DataFrame:
        """Run both single-index models and fuse them row by row.

        Rows are matched on ``date_col`` (inner join) when both results have
        it. Otherwise they are matched on the row index (inner join), which is
        fragile and not recommended.
        """
        r300 = self.m300.run(df300)
        r500 = self.m500.run(df500)

        key = date_col if date_col in r300.columns and date_col in r500.columns else None
        if key is None:
            # Align on the index rather than by position: positional pairing
            # mis-pairs rows and fails outright when the lengths differ.
            merged = pd.concat(
                [
                    r300["state"].rename("state_300"),
                    r300["position_cap"].rename("cap_300"),
                    r500["state"].rename("state_500"),
                    r500["position_cap"].rename("cap_500"),
                ],
                axis=1,
                join="inner",
            )
        else:
            merged = pd.merge(
                r300[[key, "state", "position_cap"]].rename(
                    columns={"state": "state_300", "position_cap": "cap_300"}
                ),
                r500[[key, "state", "position_cap"]].rename(
                    columns={"state": "state_500", "position_cap": "cap_500"}
                ),
                on=key,
                how="inner",
            )

        either = self.p_fusion.vote_mode != "both_trend"
        use_avg = self.p_fusion.cap_merge == "avg"

        fusion_states = []
        fusion_caps = []
        for s300, s500, c300, c500 in zip(
            merged["state_300"], merged["state_500"], merged["cap_300"], merged["cap_500"]
        ):
            fusion_states.append(self._fuse_state(s300, s500, either))
            # NOTE: the cap is merged independently of fusion_state, so with
            # "min" a TREND fusion_state can carry a RISK-level cap.
            c300, c500 = float(c300), float(c500)
            fusion_caps.append((c300 + c500) / 2.0 if use_avg else min(c300, c500))

        merged["fusion_state"] = fusion_states
        merged["fusion_cap"] = fusion_caps
        return merged


# -----------------------------
# 4) Usage example (how to wire this into the main pipeline)
# -----------------------------
def example_usage(df_csi300: pd.DataFrame, df_csi500: pd.DataFrame) -> pd.DataFrame:
    """Build the fusion model with the reference parameters and run it."""
    single_params = RegimeParams(
        q_window=252,
        trend_enter_q=0.75,
        trend_exit_q=0.60,
        vol_enter_q=0.80,
        vol_exit_q=0.70,
        dd_enter_q=0.80,
        dd_exit_q=0.70,
        trend_fallback_n=8,
        cap_trend_lowrisk=1.0,
        cap_trend_highvol=0.70,
        cap_range=0.45,
        cap_risk_mild=0.30,
        cap_risk_severe=0.20,
    )
    fusion_params = FusionParams(vote_mode="either_trend", cap_merge="min")

    model = FusionRegimeModel(single_params, fusion_params)
    fusion = model.run(df_csi300, df_csi500, date_col="Date")

    # Downstream: use fusion["fusion_cap"] as the market-level position cap and
    # fusion["fusion_state"] for logging / explanation / state bucketing.
    return fusion