import numpy as np
import pandas as pd
from dataclasses import dataclass
from collections import deque
from typing import Optional, Dict, Any


def clip(x, lo=0.0, hi=1.0):
    """Clamp ``x`` to the closed interval ``[lo, hi]``.

    NOTE: a NaN ``x`` is not propagated. Comparisons with NaN are false, so
    the ``min``/``max`` chain below returns ``hi`` (for ``lo <= hi``).
    Callers that may receive NaN must check for it before calling.
    """
    return max(lo, min(hi, x))


def _band_margin(distance: float, band: float) -> float:
    """Return ``distance / |band|`` clamped to ``[0, 1]``; 0.0 when data is missing.

    A NaN in either argument yields 0.0 so that missing indicators never
    count as "far from the threshold". The band is floored at 1e-12 to avoid
    division by zero when the two thresholds coincide.
    """
    if np.isnan(distance) or np.isnan(band):
        return 0.0
    return clip(distance / max(1e-12, abs(band)), 0.0, 1.0)


def apply_position_limit(new_position: float, daily_change_limit: float, previous_position: float) -> float:
    """Limit how far the position may move in one step.

    Example: ``daily_change_limit=0.2`` lets the position rise or fall by at
    most 0.2 relative to ``previous_position``.

    Args:
        new_position: Desired position.
        daily_change_limit: Maximum absolute change allowed per call.
        previous_position: Position held before this step. ``None`` or NaN
            disables the limit and ``new_position`` is returned as is.

    Returns:
        The rate-limited position as a Python ``float``.
    """
    if previous_position is None or np.isnan(previous_position):
        return float(new_position)
    delta = new_position - previous_position
    if abs(delta) <= daily_change_limit:
        return float(new_position)
    # Move the full allowed step in the direction of the desired change.
    return float(previous_position + np.sign(delta) * daily_change_limit)


def atr14(high: pd.Series, low: pd.Series, close: pd.Series, n: int = 14) -> pd.Series:
    """Average true range computed as a simple rolling mean of the true range.

    The true range is the row-wise maximum of ``|high - low|``,
    ``|high - prev_close|`` and ``|low - prev_close|``. The average is a plain
    ``rolling(n).mean()`` (not Wilder's recursive smoothing); the first
    ``n - 1`` values are NaN because ``min_periods=n``.

    Args:
        high: High prices.
        low: Low prices.
        close: Close prices.
        n: Rolling window length.

    Returns:
        Series of ATR values aligned with the inputs.
    """
    prev_close = close.shift(1)
    tr = pd.concat(
        [
            (high - low).abs(),
            (high - prev_close).abs(),
            (low - prev_close).abs(),
        ],
        axis=1,
    ).max(axis=1)
    return tr.rolling(n, min_periods=n).mean()


@dataclass
class RegimeParams:
    """Tunable parameters for :class:`MarketRegimeStateMachine`."""

    # Classic low-frequency windows; kept fixed to limit over-fitting freedom.
    ma_trend_window: int = 120
    dd_window: int = 60
    vol_window: int = 20
    q_window: int = 252

    # Rolling-quantile levels used to build the enter/exit thresholds.
    vol_enter_q: float = 0.80
    vol_exit_q: float = 0.70
    dd_enter_q: float = 0.20  # dd60 is <= 0; smaller means a deeper drawdown
    dd_exit_q: float = 0.30
    tr_enter_q: float = 0.60  # larger ma120_ratio means a stronger trend
    tr_exit_q: float = 0.50

    # Confidence / stability settings.
    stability_cap_days: int = 20
    hist_len: int = 10  # number of recent states kept for the jitter count

    # Position sizing.
    daily_change_limit: float = 0.20  # max position change per step

    # Per-state position ceiling (caps) and baseline (bases).
    # None means "use the defaults filled in by MarketRegimeStateMachine".
    caps: Optional[Dict[str, float]] = None
    bases: Optional[Dict[str, float]] = None

    # Optional liquidity filter based on the Amount column.
    use_amount_filter: bool = True
    amount_ma_window: int = 20
    # When Amount / MA(Amount) falls below this ratio the cap is lowered.
    amount_min_ratio: float = 0.60


class MarketRegimeStateMachine:
    """Two-layer market regime state machine.

    Layer 1 is a risk gate driven by ``vol20`` and ``dd60``.
    Layer 2 classifies trend vs. range using ``ma120_ratio``.

    The instance is stateful: ``step`` must be called once per bar, in
    chronological order.
    """

    def __init__(self, params: Optional[RegimeParams] = None):
        """Create a state machine, filling in default caps/bases when unset.

        Note that missing ``caps``/``bases`` are written onto the ``params``
        object that was passed in.
        """
        self.p = params or RegimeParams()
        if self.p.caps is None:
            self.p.caps = {
                "RISK_SEVERE": 0.20,
                "RISK_MILD": 0.50,
                "RANGE": 0.50,
                "TREND_NORMAL": 0.80,
                "TREND_BULL": 1.00,
            }
        if self.p.bases is None:
            self.p.bases = {
                "RISK_SEVERE": 0.00,   # extreme risk: flat / minimal
                "RISK_MILD": 0.20,     # elevated risk: conservative
                "RANGE": 0.30,         # range-bound: low-to-medium
                "TREND_NORMAL": 0.60,  # normal trend: medium-to-high
                "TREND_BULL": 0.90,    # strong trend: high
            }

        self.current_state: str = "RANGE"
        self.state_duration: int = 0
        self.last_position: float = 0.0
        self.state_hist: deque = deque(maxlen=self.p.hist_len)

    def _switch_count_last_n(self) -> int:
        """Count state changes between consecutive entries of ``state_hist``."""
        h = list(self.state_hist)
        if len(h) <= 1:
            return 0
        return int(sum(h[i] != h[i - 1] for i in range(1, len(h))))

    def compute_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with indicator and threshold columns added.

        Required columns (case-sensitive): ``Close``, ``High``, ``Low``.
        Optional column: ``Amount`` (traded value), used for ``amt_ratio``.

        Output column names (``ma120_ratio``, ``vol20``, ``dd60`` ...) are
        fixed even when the corresponding window parameters are changed.

        Raises:
            ValueError: If a required column is missing.
        """
        d = df.copy()

        for col in ["Close", "High", "Low"]:
            if col not in d.columns:
                raise ValueError(f"Missing required column: {col}")

        # One-bar return.
        d["ret1"] = d["Close"].pct_change()

        # Trend axis: close / moving average - 1.
        d["ma_trend"] = d["Close"].rolling(
            self.p.ma_trend_window, min_periods=self.p.ma_trend_window
        ).mean()
        d["ma120_ratio"] = d["Close"] / d["ma_trend"] - 1.0

        # Risk axis: return volatility and drawdown from the rolling high.
        d["vol20"] = d["ret1"].rolling(
            self.p.vol_window, min_periods=self.p.vol_window
        ).std()
        d["roll_max60"] = d["Close"].rolling(
            self.p.dd_window, min_periods=self.p.dd_window
        ).max()
        d["dd60"] = d["Close"] / d["roll_max60"] - 1.0  # <= 0

        # Extra risk proxy: ATR relative to price (not used by the state logic).
        d["atr14"] = atr14(d["High"], d["Low"], d["Close"], n=14)
        d["atr14_ratio"] = d["atr14"] / d["Close"]

        # Optional liquidity ratio.
        if "Amount" in d.columns:
            d["amt_ma20"] = d["Amount"].rolling(
                self.p.amount_ma_window, min_periods=self.p.amount_ma_window
            ).mean()
            d["amt_ratio"] = d["Amount"] / d["amt_ma20"]
        else:
            d["amt_ratio"] = np.nan

        # Rolling-quantile thresholds over q_window bars.
        qw = self.p.q_window
        d["vol_enter"] = d["vol20"].rolling(qw, min_periods=qw).quantile(self.p.vol_enter_q)
        d["vol_exit"] = d["vol20"].rolling(qw, min_periods=qw).quantile(self.p.vol_exit_q)

        d["dd_enter"] = d["dd60"].rolling(qw, min_periods=qw).quantile(self.p.dd_enter_q)
        d["dd_exit"] = d["dd60"].rolling(qw, min_periods=qw).quantile(self.p.dd_exit_q)

        d["tr_enter"] = d["ma120_ratio"].rolling(qw, min_periods=qw).quantile(self.p.tr_enter_q)
        d["tr_exit"] = d["ma120_ratio"].rolling(qw, min_periods=qw).quantile(self.p.tr_exit_q)

        return d

    def _decide_state(self, row: pd.Series) -> str:
        """Classify one indicator row; the risk gate takes priority.

        If any required value is missing or NaN the current state is kept,
        which prevents erratic switching during the warm-up period.

        NOTE(review): the enter/exit thresholds act as two severity levels
        here; the decision does not depend on the current state, so this is
        not state-dependent hysteresis.
        """
        needed = [
            "vol20", "dd60", "vol_enter", "vol_exit", "dd_enter", "dd_exit",
            "ma120_ratio", "tr_enter", "tr_exit",
        ]
        if any(pd.isna(row.get(k, np.nan)) for k in needed):
            return self.current_state

        vol20 = float(row["vol20"])
        dd60 = float(row["dd60"])
        ma120_ratio = float(row["ma120_ratio"])

        # Layer 1: risk gate.
        if (vol20 > float(row["vol_enter"])) or (dd60 < float(row["dd_enter"])):
            return "RISK_SEVERE"
        if (vol20 > float(row["vol_exit"])) or (dd60 < float(row["dd_exit"])):
            return "RISK_MILD"

        # Layer 2: trend filter.
        if ma120_ratio > float(row["tr_enter"]):
            return "TREND_BULL"
        if ma120_ratio > float(row["tr_exit"]):
            return "TREND_NORMAL"
        return "RANGE"

    def _margin(self, row: pd.Series, state: str) -> float:
        """Distance from the relevant threshold, scaled by band width, in [0, 1].

        Risk states use the larger of the volatility and drawdown margins;
        all other states use the trend margin. Missing, non-numeric or NaN
        inputs contribute 0.0.
        """
        try:
            if state in ("RISK_SEVERE", "RISK_MILD"):
                vol20 = float(row["vol20"])
                vol_enter = float(row["vol_enter"])
                vol_exit = float(row["vol_exit"])
                dd60 = float(row["dd60"])
                dd_enter = float(row["dd_enter"])
                dd_exit = float(row["dd_exit"])
                # Further beyond the exit level means more clearly "risky".
                vol_m = _band_margin(vol20 - vol_exit, vol_enter - vol_exit)
                # dd60 is negative: lower than dd_exit is more dangerous.
                dd_m = _band_margin(dd_exit - dd60, dd_exit - dd_enter)
                return max(vol_m, dd_m)

            tr = float(row["ma120_ratio"])
            tr_enter = float(row["tr_enter"])
            tr_exit = float(row["tr_exit"])
            # Further above the exit level means a better-established trend.
            return _band_margin(tr - tr_exit, tr_enter - tr_exit)
        except (KeyError, TypeError, ValueError):
            # Row lacks a column or holds a non-numeric value: no margin credit.
            return 0.0

    def _confidence(self, row: pd.Series, new_state: str) -> float:
        """Confidence in ``new_state``, in [0, 1].

        ``0.45 + 0.35 * margin + 0.20 * stability - 0.30 * jitter`` where
        ``stability`` is ``state_duration`` capped at ``stability_cap_days``
        and normalised, ``jitter`` is the recent switch count divided by 6
        (an empirical normaliser) and clamped, and ``margin`` comes from
        :meth:`_margin`. Missing/NaN indicators give ``margin = 0``.
        """
        stability = min(self.state_duration, self.p.stability_cap_days) / float(
            self.p.stability_cap_days
        )
        jitter = clip(self._switch_count_last_n() / 6.0, 0.0, 1.0)
        margin = self._margin(row, new_state)

        conf = 0.45 + 0.35 * margin + 0.20 * stability - 0.30 * jitter
        return float(clip(conf, 0.0, 1.0))

    def _amount_filter_cap(self, row: pd.Series, cap: float) -> float:
        """Lower ``cap`` to at most 0.50 when liquidity is thin.

        Only applies when ``use_amount_filter`` is on and the row carries a
        non-NaN ``amt_ratio`` below ``amount_min_ratio``; otherwise ``cap``
        is returned unchanged.
        """
        if not self.p.use_amount_filter:
            return cap
        if pd.isna(row.get("amt_ratio", np.nan)):
            return cap
        if float(row["amt_ratio"]) < self.p.amount_min_ratio:
            return min(cap, 0.50)
        return cap

    def step(self, row: pd.Series) -> Dict[str, Any]:
        """Advance the state machine by one bar.

        Args:
            row: One row of the frame returned by ``compute_indicators``.

        Returns:
            Dict with ``state``, ``confidence``, ``target_position``,
            ``position`` (after the rate limit), ``state_duration`` and
            ``switch_count``.
        """
        new_state = self._decide_state(row)

        # Track how long the current state has persisted.
        if new_state == self.current_state:
            self.state_duration += 1
        else:
            self.current_state = new_state
            self.state_duration = 1

        # Record history for the jitter penalty.
        self.state_hist.append(self.current_state)

        conf = self._confidence(row, self.current_state)

        base = float(self.p.bases[self.current_state])
        cap = float(self.p.caps[self.current_state])
        cap = self._amount_filter_cap(row, cap)

        # Target = base scaled into [0.5, 1.0] * base by confidence, then capped.
        target = base * (0.5 + 0.5 * conf)
        target = min(target, cap)

        # Rate-limit relative to the position produced by the previous step.
        pos = apply_position_limit(
            new_position=target,
            daily_change_limit=self.p.daily_change_limit,
            previous_position=self.last_position,
        )

        # Assumes step() runs exactly once per bar.
        self.last_position = pos

        return {
            "state": self.current_state,
            "confidence": conf,
            "target_position": float(target),
            "position": float(pos),
            "state_duration": int(self.state_duration),
            "switch_count": int(self._switch_count_last_n()),
        }


def run_regime_backtest(df: pd.DataFrame, params: Optional[RegimeParams] = None) -> pd.DataFrame:
    """Run the state machine over the whole history.

    Args:
        df: Price frame accepted by ``compute_indicators``.
        params: Optional parameters; defaults are used when omitted.

    Returns:
        The indicator frame with the per-bar ``step`` outputs appended as
        extra columns, sharing the indicator frame's index.
    """
    sm = MarketRegimeStateMachine(params=params)
    ind = sm.compute_indicators(df)

    out = []
    for _, row in ind.iterrows():
        out.append(sm.step(row))

    out_df = pd.DataFrame(out, index=ind.index)
    return pd.concat([ind, out_df], axis=1)