import sqlite3
import json
from contextlib import closing
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Tuple


# =========================
# 1) State persistence: statistics are computed over the most recent
#    *recorded trading days*, not over calendar days.
# =========================
class StatePersistence:
    """Persist one market-state record per trading day in a SQLite table.

    Every public method opens its own short-lived connection and always
    closes it, even when a statement raises.

    Dates are stored as TEXT and compared with ``<`` / ``ORDER BY``, so they
    are assumed to be strings that sort chronologically (e.g. ISO
    ``YYYY-MM-DD``) -- NOTE(review): confirm with callers.
    """

    # Column order of the ``daily_state`` table. Queries name these columns
    # explicitly so row decoding does not depend on ``SELECT *`` ordering.
    _COLUMNS = (
        "date",
        "final_state",
        "raw_state",
        "score",
        "consecutive_days",
        "switch_count_last_5days",
        "previous_position",
        "alerts",
    )
    # Constant (not user input), so interpolating it into SQL is safe.
    _SELECT_LIST = ", ".join(_COLUMNS)

    def __init__(self, db_path="state_persistence.db"):
        """Bind to *db_path* and create the ``daily_state`` table if missing."""
        self.db_path = db_path
        self.init_database()

    def _connect(self):
        """Return a connection wrapped so that ``with`` always closes it."""
        return closing(sqlite3.connect(self.db_path))

    @classmethod
    def _row_to_dict(cls, row) -> Dict:
        """Decode a ``daily_state`` row (in ``_COLUMNS`` order) into a dict.

        ``alerts`` is JSON-decoded; a NULL/empty value becomes ``[]``.
        """
        record = dict(zip(cls._COLUMNS, row))
        record["alerts"] = json.loads(record["alerts"]) if record["alerts"] else []
        return record

    def init_database(self):
        """Create the ``daily_state`` table (primary key: date) if it does not exist."""
        with self._connect() as conn:
            with conn:  # transaction: commit on success, roll back on error
                conn.execute(
                    '''
                    CREATE TABLE IF NOT EXISTS daily_state (
                        date TEXT PRIMARY KEY,
                        final_state TEXT,
                        raw_state TEXT,
                        score REAL,
                        consecutive_days INTEGER,
                        switch_count_last_5days INTEGER,
                        previous_position REAL,
                        alerts TEXT
                    )
                    '''
                )

    def save_daily_state(
        self,
        date: str,
        final_state: str,
        raw_state: str,
        score: float,
        consecutive_days: int,
        switch_count_last_5days: int,
        previous_position: float,
        alerts=None,
    ):
        """Insert or overwrite the record for *date*.

        ``alerts`` is stored as JSON; a falsy value is stored as ``"[]"``.
        Re-saving the same date replaces the earlier row (``INSERT OR REPLACE``).
        """
        alerts_json = json.dumps(alerts) if alerts else "[]"
        with self._connect() as conn:
            with conn:
                conn.execute(
                    '''
                    INSERT OR REPLACE INTO daily_state
                    (date, final_state, raw_state, score, consecutive_days,
                     switch_count_last_5days, previous_position, alerts)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    ''',
                    (
                        date,
                        final_state,
                        raw_state,
                        score,
                        consecutive_days,
                        switch_count_last_5days,
                        previous_position,
                        alerts_json,
                    ),
                )

    def load_previous_state(self, date: str) -> Optional[Dict]:
        """Return the record stored for exactly *date*, or ``None``."""
        with self._connect() as conn:
            row = conn.execute(
                f"SELECT {self._SELECT_LIST} FROM daily_state WHERE date = ?",
                (date,),
            ).fetchone()
        return self._row_to_dict(row) if row else None

    def load_latest_state_before(self, current_date: str) -> Optional[Dict]:
        """Return the latest recorded trading day strictly before *current_date*.

        Uses the most recent saved row rather than mechanically looking up
        ``current_date - 1``, so weekends/holidays are skipped naturally.
        """
        recent = self.get_recent_states(current_date, limit=1)
        return recent[0] if recent else None

    def get_recent_states(self, current_date: str, limit: int = 10) -> List[Dict]:
        """Return up to *limit* saved records before *current_date*, newest first."""
        with self._connect() as conn:
            rows = conn.execute(
                f"SELECT {self._SELECT_LIST} FROM daily_state "
                "WHERE date < ? ORDER BY date DESC LIMIT ?",
                (current_date, limit),
            ).fetchall()
        return [self._row_to_dict(row) for row in rows]

    def calculate_consecutive_days(self, current_state: str, current_date: str, max_lookback: int = 20) -> int:
        """Count consecutive recorded trading days in *current_state*, including today.

        Walks back over at most *max_lookback* saved days, so the result is
        capped at ``max_lookback + 1``. Weekends/holidays do not break a run
        because only recorded days are considered.
        """
        recent = self.get_recent_states(current_date, limit=max_lookback)
        # +1 counts today itself.
        return 1 + _leading_streak(recent, lambda row: row["final_state"] == current_state)

    def calculate_switch_count_last_5days(
        self,
        current_date: str,
        candidate_state: Optional[str] = None,
        window: int = 5,
    ) -> int:
        """Count state changes across the latest *window* trading-day states.

        When *candidate_state* (today's state) is given it is included as the
        newest state, so the count is not one day late; the remaining
        ``window - 1`` states come from storage. Without a candidate, all
        *window* states come from storage.
        """
        from_storage = window - 1 if candidate_state is not None else window
        recent = self.get_recent_states(current_date, limit=max(from_storage, 1))

        states = [candidate_state] if candidate_state is not None else []
        states.extend(row["final_state"] for row in recent)
        states = states[:window]

        return sum(1 for newer, older in zip(states, states[1:]) if newer != older)


def _leading_streak(rows: List[Dict], predicate) -> int:
    """Count rows from the front of *rows* that satisfy *predicate*, stopping at the first miss."""
    count = 0
    for row in rows:
        if not predicate(row):
            break
        count += 1
    return count


# =========================
# 2) Lightweight confirmation gate: minimal framework changes, fewer
#    whipsaw switches, better accuracy.
# =========================
STATE_ORDER = {
    "RISK_SEVERE": 0,
    "RISK_MILD": 1,
    "RANGE": 2,
    "TREND_NORMAL": 3,
    "TREND_BULL": 4,
}


def _is_upgrade(prev_state: str, new_state: str) -> bool:
    """True when *new_state* ranks strictly above *prev_state* (unknown states rank -999)."""
    return get_state_rank(new_state) > get_state_rank(prev_state)


def _is_downgrade(prev_state: str, new_state: str) -> bool:
    """True when *new_state* ranks strictly below *prev_state* (unknown states rank -999)."""
    return get_state_rank(new_state) < get_state_rank(prev_state)


def stabilize_state_transition(
    candidate_state: str,
    current_date: str,
    persistence: StatePersistence,
    severe_state: str = "RISK_SEVERE",
    confirm_upgrade_days: int = 2,
    confirm_downgrade_days: int = 1,
    confirm_exit_severe_days: int = 2,
    min_hold_days: int = 2,
) -> str:
    """Apply a light confirmation gate on top of the candidate state.

    Leaves the core scoring framework untouched and only decides which state
    "lands" today. Principles:

    - Entering the severe-risk state is immediate.
    - Leaving the severe-risk state needs stronger confirmation.
    - Upgrades are slow (fewer false breakouts / chasing).
    - Downgrades can be fast (risk control is preserved).
    - A minimum holding period blocks premature upgrades (anti-whipsaw).

    Confirmation counts how many of the immediately preceding recorded days
    had a persisted ``raw_state`` already supporting the move. ``final_state``
    cannot be used for this: the newest saved row *is* the previous state, so
    a streak of "final_state == candidate" (or "!= severe" while in severe) is
    always 0 and the gate would never open. ``confirm_X_days = N`` therefore
    means "today plus N-1 prior days of raw support".

    Returns the state to persist as today's ``final_state``.
    """
    lookback = max(confirm_upgrade_days, confirm_downgrade_days, confirm_exit_severe_days, 5)
    recent = persistence.get_recent_states(current_date, limit=lookback)
    if not recent:
        # First recorded day: nothing to confirm against.
        return candidate_state

    prev = recent[0]  # same row load_latest_state_before() would return
    prev_state = prev["final_state"]
    prev_consecutive = int(prev.get("consecutive_days") or 1)

    if candidate_state == prev_state:
        return candidate_state

    # Severe risk: enter immediately.
    if candidate_state == severe_state:
        return candidate_state

    # Leaving severe risk: require the raw signal to have already been
    # non-severe on the preceding days, to avoid bouncing straight back.
    if prev_state == severe_state:
        streak = _leading_streak(
            recent,
            lambda row: row.get("raw_state") not in (None, severe_state),
        )
        return prev_state if streak < confirm_exit_severe_days - 1 else candidate_state

    candidate_rank = get_state_rank(candidate_state)

    if _is_upgrade(prev_state, candidate_state):
        # Minimum holding period: no optimistic switch right after a change
        # (moves toward more risk are still allowed below).
        if prev_consecutive < min_hold_days:
            return prev_state
        # Slow upgrades: the raw signal must already have been at least as
        # strong as the candidate on the preceding days.
        streak = _leading_streak(
            recent,
            lambda row: get_state_rank(row.get("raw_state")) >= candidate_rank,
        )
        return prev_state if streak < confirm_upgrade_days - 1 else candidate_state

    if _is_downgrade(prev_state, candidate_state):
        # Fast downgrades: with confirm_downgrade_days=1 this always passes;
        # raise it to 2 for more stability.
        streak = _leading_streak(
            recent,
            lambda row: row.get("raw_state") in STATE_ORDER
            and get_state_rank(row["raw_state"]) <= candidate_rank,
        )
        return prev_state if streak < confirm_downgrade_days - 1 else candidate_state

    return candidate_state


# =========================
# 3) Position rate limit: also applies when starting from a 0 position,
#    to limit the damage of a misjudgment.
# =========================
def apply_position_limit(new_position: float, daily_change_limit: float, previous_position: Optional[float]) -> float:
    """Clamp the day-over-day position change to *daily_change_limit*.

    Only a true first day (``previous_position is None``) is unrestricted;
    ``previous_position == 0`` is limited too, so a misjudgment cannot jump
    straight to a large position. *daily_change_limit* is assumed to be
    non-negative.
    """
    if previous_position is None:
        return new_position

    max_change = daily_change_limit
    actual_change = new_position - previous_position
    if abs(actual_change) > max_change:
        if actual_change > 0:
            return previous_position + max_change
        return previous_position - max_change
    return new_position


# =========================
# 4) Main pipeline: minimally invasive integration.
# =========================
def run_daily_pipeline_fixed(date, core_stocks, config):
    """Run the daily market-state pipeline for *date* and return a report dict.

    Minimal-intrusion enhancements:
    - "Previous day" means the latest recorded trading day, not ``date - 1``.
    - A light confirmation gate is applied after the candidate state.
    - The switch count includes today's state, so it is not one day late.
    - v144: upgrade-quality filter (index consistency + score improvement),
      with stricter confirmation for any upgrade into ``TREND_BULL``.

    ``config`` is read with ``.get``/``[]``; the keys used here are
    ``STATE_STABILITY_CONFIG`` (merged over ``DEFAULT_STATE_STABILITY_CONFIG``),
    ``POSITION_MAPPING``, ``POSITION_LIMIT_CONFIG['daily_change_limit']``,
    ``EXECUTION_ACTIONS`` and optionally ``STATE_DB_PATH``.

    NOTE(review): ``get_market_data``, ``V13StateMachine``,
    ``adjust_confidence_by_stability``, ``calculate_continuous_position_fixed``,
    ``calculate_execution_permissions``, ``get_stock_data``,
    ``calculate_alpha_score``, ``get_state_threshold``,
    ``calibrate_confidence_by_bucket`` and ``calculate_stock_position`` are
    defined elsewhere in the project.
    """
    state_persistence = StatePersistence(config.get("STATE_DB_PATH", "state_persistence.db"))

    # Explicit defaults first, user overrides second.
    stability_cfg = {
        **DEFAULT_STATE_STABILITY_CONFIG,
        **(config.get("STATE_STABILITY_CONFIG") or {}),
    }

    # Latest saved trading day, not mechanically date - 1.
    previous_state = state_persistence.load_latest_state_before(date)

    market_data = get_market_data(date)
    state_machine = V13StateMachine(config)
    score = state_machine.calculate_score(market_data)
    raw_state = state_machine.determine_state(score)
    cooled_state = state_machine.apply_cooling_period(raw_state, score)

    prev_final_state = previous_state["final_state"] if previous_state else None
    prev_score = previous_state["score"] if previous_state else None

    # v144: upgrade-quality filter -- only blocks upgrades, never downgrades.
    # get_index_state_signals may return {} until per-index states are wired in.
    index_signals = get_index_state_signals(date, market_data, config)
    quality_filtered_state = apply_upgrade_quality_filter(
        prev_state=prev_final_state,
        candidate_state=cooled_state,
        current_score=score,
        previous_score=prev_score,
        index_signals=index_signals,
        config=stability_cfg,
    )

    # Any upgrade into the bull state gets a stricter confirmation window,
    # matching the bull-specific support rule in apply_upgrade_quality_filter.
    gate_cfg = dict(stability_cfg)
    if (
        prev_final_state is not None
        and quality_filtered_state == "TREND_BULL"
        and _is_upgrade(prev_final_state, quality_filtered_state)
    ):
        gate_cfg["confirm_upgrade_days"] = stability_cfg.get("bull_upgrade_confirm_days", 3)

    # Light confirmation gate: changes only which state lands.
    final_state = stabilize_state_transition(
        candidate_state=quality_filtered_state,
        current_date=date,
        persistence=state_persistence,
        severe_state="RISK_SEVERE",
        confirm_upgrade_days=gate_cfg.get("confirm_upgrade_days", 2),
        confirm_downgrade_days=gate_cfg.get("confirm_downgrade_days", 1),
        confirm_exit_severe_days=gate_cfg.get("confirm_exit_severe_days", 2),
        min_hold_days=gate_cfg.get("min_hold_days", 2),
    )

    consecutive_days = state_persistence.calculate_consecutive_days(final_state, date)
    switch_count_last_5days = state_persistence.calculate_switch_count_last_5days(
        current_date=date,
        candidate_state=final_state,
        window=5,
    )

    base_confidence = state_machine.calculate_confidence(final_state, score)
    adjusted_confidence, _, _ = adjust_confidence_by_stability(
        base_confidence,
        consecutive_days,
        switch_count_last_5days,
    )

    previous_position = previous_state["previous_position"] if previous_state else None
    suggested_position = calculate_continuous_position_fixed(
        final_state,
        score,
        config["POSITION_MAPPING"],
        config["POSITION_LIMIT_CONFIG"]["daily_change_limit"],
        previous_position,
    )

    execution_permissions = calculate_execution_permissions(final_state, adjusted_confidence)

    # Both depend only on today's final state, so compute them once.
    threshold = get_state_threshold(final_state)
    allow_new_positions = execution_permissions["allow_new_positions"]

    trade_candidates = []
    for stock in core_stocks:
        stock_data = get_stock_data(stock, date)
        alpha_score = calculate_alpha_score(stock_data, final_state)

        if alpha_score >= threshold and allow_new_positions:
            action, _conf = calibrate_confidence_by_bucket(
                adjusted_confidence,
                config["EXECUTION_ACTIONS"],
            )
            position = calculate_stock_position(action, suggested_position)
            trade_candidates.append({
                "stock": stock,
                "alpha_score": alpha_score,
                "threshold": threshold,
                "action": action,
                "position": position,
                "intercepted": False,
            })
            continue

        intercept_reason = []
        if alpha_score < threshold:
            intercept_reason.append(
                f"Alpha score {alpha_score:.1f} < threshold {threshold:.1f}"
            )
        if not allow_new_positions:
            intercept_reason.append("New positions not allowed in current state")
        trade_candidates.append({
            "stock": stock,
            "alpha_score": alpha_score,
            "threshold": threshold,
            "action": "intercepted",
            "position": 0,
            "intercepted": True,
            "reason": "; ".join(intercept_reason),
        })

    # raw_state is persisted because stabilize_state_transition uses it as
    # the confirmation signal on later days.
    state_persistence.save_daily_state(
        date,
        final_state,
        raw_state,
        score,
        consecutive_days,
        switch_count_last_5days,
        suggested_position,
        [],
    )

    report = {
        "date": date,
        "market_state": final_state,
        "raw_state": raw_state,
        "cooled_state": cooled_state,
        "quality_filtered_state": quality_filtered_state,
        "index_signals": index_signals,
        "confidence": adjusted_confidence,
        "suggested_position": suggested_position,
        "execution_permissions": execution_permissions,
        "trade_candidates": trade_candidates,
        "consecutive_days": consecutive_days,
        "switch_count_last_5days": switch_count_last_5days,
    }
    return report


# =========================
# 5) Configuration: only a few extra parameters, to avoid overfitting.
# =========================
DEFAULT_STATE_STABILITY_CONFIG = {
    # Coarse-grained parameters only; avoid tuning toward "historically optimal" details.
    "confirm_upgrade_days": 2,
    "confirm_downgrade_days": 1,
    "confirm_exit_severe_days": 2,
    "min_hold_days": 2,

    # v144: upgrade-quality enhancement (as few parameters as possible).
    "enable_upgrade_quality_filter": True,
    "upgrade_support_required": 2,         # at least 2 of the 3 indices must support an upgrade
    "bull_upgrade_support_required": 2,    # upgrading to bull needs at least 2 supporting indices
    "bull_upgrade_confirm_days": 3,        # bull upgrades are stricter than normal upgrades
    "require_score_improvement_on_upgrade": True,
}


# =========================
# 6) Upgrade-quality enhancement: improve accuracy without noticeably
#    increasing overfitting risk.
# =========================
def get_state_rank(state: str) -> int:
    """Return the ordinal rank of *state* in ``STATE_ORDER``; unknown states rank -999."""
    return STATE_ORDER.get(state, -999)


def count_upgrade_support_from_indices(index_signals: Dict[str, str], candidate_state: str) -> int:
    """Count indices whose state ranks at or above *candidate_state*.

    Example ``index_signals``::

        {
            "zz500": "TREND_NORMAL",
            "zz1000": "RANGE",
            "hs300": "TREND_BULL",
        }

    An index "supports" the upgrade when its state rank is >= the
    candidate's. This is a deliberately restrained consistency filter: it
    adds no new factors, it only requires internal agreement.
    """
    target_rank = get_state_rank(candidate_state)
    return sum(
        1 for state in (index_signals or {}).values()
        if get_state_rank(state) >= target_rank
    )


def is_score_improving(current_score: float, previous_score: Optional[float]) -> bool:
    """True if there is no previous score or *current_score* is strictly higher."""
    if previous_score is None:
        return True
    return current_score > previous_score


def apply_upgrade_quality_filter(
    prev_state: Optional[str],
    candidate_state: str,
    current_score: float,
    previous_score: Optional[float],
    index_signals: Optional[Dict[str, str]],
    config: Dict,
) -> str:
    """Veto low-quality upgrades; non-upgrades pass through unchanged.

    Applies only when *candidate_state* ranks above *prev_state*:

    1) Index consensus: at least ``upgrade_support_required`` indices must
       support the candidate (``bull_upgrade_support_required`` when the
       candidate is ``TREND_BULL``). This check is skipped when no per-index
       signals are available (empty/None), so the documented "return {}"
       adapter does not silently block every upgrade.
    2) Optionally, the score must have improved versus the previous day.

    The purpose is to raise upgrade accuracy, never to loosen risk control.
    Returns *candidate_state* if allowed, otherwise *prev_state*.
    """
    config = config or {}
    if not prev_state:
        return candidate_state
    if not config.get("enable_upgrade_quality_filter", True):
        return candidate_state
    if get_state_rank(candidate_state) <= get_state_rank(prev_state):
        return candidate_state

    if index_signals:
        support_required = config.get("upgrade_support_required", 2)
        if candidate_state == "TREND_BULL":
            support_required = config.get("bull_upgrade_support_required", support_required)
        support = count_upgrade_support_from_indices(index_signals, candidate_state)
        if support < support_required:
            return prev_state

    if config.get("require_score_improvement_on_upgrade", True):
        if not is_score_improving(current_score, previous_score):
            return prev_state

    return candidate_state


# =========================
# 7) Suggested validation procedure (kept as comments for reference while
#    backtesting).
# =========================
# A. Fix the core parameters first; do not keep fine-tuning them back and forth:
#    confirm_upgrade_days = 2
#    confirm_downgrade_days = 1
#    confirm_exit_severe_days = 2
#    min_hold_days = 2
#    upgrade_support_required = 2
#    bull_upgrade_confirm_days = 3
#
# B. Use only a very small grid:
#    - confirm_upgrade_days: [2, 3]
#    - confirm_downgrade_days: [1, 2]
#    - confirm_exit_severe_days: [2, 3]
#    - min_hold_days: [2, 3]
#    - upgrade_support_required: [2]
#    - bull_upgrade_confirm_days: [3]
#
# C. Evaluation priority:
#    1) Does the number of state switches drop noticeably, or at least not rebound?
#    2) Does the maximum drawdown stay no worse?
#    3) Do the 5-/20-day returns after "upgrade days" improve?
#    4) Does the total return improve?
#
# D. If only one parameter set is especially good while its neighbouring sets
#    are all poor, treat it as a suspected overfit.


# =========================
# 8) Minimal adapter functions you need to wire up.
# =========================
def get_index_state_signals(date, market_data, config) -> Dict[str, str]:
    """Return per-index state signals used by the upgrade consistency filter.

    Wire in your existing state results for the three indices here. The goal
    is not a new model, but to reuse per-index judgments that already exist
    in your system as a consistency filter.

    If you have not split states per index yet, return an empty dict for now
    (``return {}``); the upgrade filter then skips the index-consensus check.

    Recommended return format::

        {
            "zz500": "TREND_NORMAL",
            "zz1000": "TREND_BULL",
            "hs300": "RANGE",
        }

    Safest way to wire it:
    - reuse your existing single-index scoring/state logic;
    - add no extra factors;
    - do not tune a separate parameter set per index.
    """
    return {}