import json
import sqlite3
from contextlib import closing
from datetime import datetime, timedelta  # noqa: F401  (not used in this module)
from typing import Callable, Dict, Iterable, List, Optional, Tuple  # noqa: F401


# =========================
# 1) State persistence: statistics are based on the most recent *recorded*
#    trading days, not on calendar arithmetic.
# =========================
class StatePersistence:
    """SQLite-backed store holding one state snapshot per trading day.

    Rows are keyed by ``date`` (a string). "Before"/"latest" queries compare and
    sort that string, so dates are expected in a lexicographically sortable
    format such as ``YYYY-MM-DD`` -- NOTE(review): confirm all callers use it.

    Besides the post-gate ``final_state``, each row records ``candidate_state``:
    the state that was fed *into* :func:`stabilize_state_transition` that day.
    The confirmation gate needs that pre-gate history, because counting stored
    ``final_state`` values can never confirm a transition the gate itself is
    holding back.
    """

    # Column order used by every SELECT. ``candidate_state`` is last because on
    # databases created by older versions it is appended via ALTER TABLE.
    _COLUMNS = (
        "date",
        "final_state",
        "raw_state",
        "score",
        "consecutive_days",
        "switch_count_last_5days",
        "previous_position",
        "alerts",
        "candidate_state",
    )
    _SELECT_COLUMNS = ", ".join(_COLUMNS)

    def __init__(self, db_path="state_persistence.db"):
        self.db_path = db_path
        self.init_database()

    def _connect(self):
        """Open a connection wrapped so that a ``with`` block always closes it.

        ``sqlite3.Connection`` used as a context manager only commits or rolls
        back; it does not close. ``closing`` adds the close.
        """
        return closing(sqlite3.connect(self.db_path))

    def init_database(self):
        """Create the ``daily_state`` table and add ``candidate_state`` if missing."""
        with self._connect() as conn:
            with conn:  # commit on success, roll back on error
                conn.execute(
                    '''
                    CREATE TABLE IF NOT EXISTS daily_state (
                        date TEXT PRIMARY KEY,
                        final_state TEXT,
                        raw_state TEXT,
                        score REAL,
                        consecutive_days INTEGER,
                        switch_count_last_5days INTEGER,
                        previous_position REAL,
                        alerts TEXT,
                        candidate_state TEXT
                    )
                    '''
                )
                # In-place migration for databases created before candidate_state
                # existed (CREATE TABLE IF NOT EXISTS leaves them untouched).
                existing = {info[1] for info in conn.execute("PRAGMA table_info(daily_state)")}
                if "candidate_state" not in existing:
                    conn.execute("ALTER TABLE daily_state ADD COLUMN candidate_state TEXT")

    def save_daily_state(
        self,
        date: str,
        final_state: str,
        raw_state: str,
        score: float,
        consecutive_days: int,
        switch_count_last_5days: int,
        previous_position: float,
        alerts=None,
        candidate_state: Optional[str] = None,
    ):
        """Insert or overwrite the snapshot for ``date``.

        Args:
            alerts: JSON-serialisable list; falsy values are stored as ``"[]"``.
            candidate_state: the state passed into the confirmation gate today.
                Optional for backward compatibility, but without it the gate
                falls back to ``raw_state`` for this day's history.
        """
        alerts_json = json.dumps(alerts) if alerts else "[]"
        with self._connect() as conn:
            with conn:
                conn.execute(
                    '''
                    INSERT OR REPLACE INTO daily_state
                    (date, final_state, raw_state, score, consecutive_days,
                     switch_count_last_5days, previous_position, alerts, candidate_state)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                    ''',
                    (
                        date,
                        final_state,
                        raw_state,
                        score,
                        consecutive_days,
                        switch_count_last_5days,
                        previous_position,
                        alerts_json,
                        candidate_state,
                    ),
                )

    @classmethod
    def _row_to_dict(cls, row) -> Dict:
        """Map a SELECT row (in ``_COLUMNS`` order) to a dict and decode ``alerts``."""
        record = dict(zip(cls._COLUMNS, row))
        record['alerts'] = json.loads(record['alerts']) if record['alerts'] else []
        return record

    def _query(self, where: str, params: tuple) -> List[Dict]:
        """Run ``SELECT <columns> FROM daily_state <where>`` with bound ``params``.

        ``where`` is only ever an internal constant; all values are bound.
        """
        sql = f"SELECT {self._SELECT_COLUMNS} FROM daily_state {where}"
        with self._connect() as conn:
            rows = conn.execute(sql, params).fetchall()
        return [self._row_to_dict(row) for row in rows]

    def load_previous_state(self, date: str) -> Optional[Dict]:
        """Return the snapshot stored for exactly ``date``, or ``None``."""
        rows = self._query("WHERE date = ?", (date,))
        return rows[0] if rows else None

    def load_latest_state_before(self, current_date: str) -> Optional[Dict]:
        """Return the most recent snapshot strictly before ``current_date``.

        This is the last *recorded* trading day, not ``current_date - 1``.
        """
        rows = self._query("WHERE date < ? ORDER BY date DESC LIMIT 1", (current_date,))
        return rows[0] if rows else None

    def get_recent_states(self, current_date: str, limit: int = 10) -> List[Dict]:
        """Return up to ``limit`` snapshots before ``current_date``, newest first."""
        return self._query("WHERE date < ? ORDER BY date DESC LIMIT ?", (current_date, limit))

    def calculate_consecutive_days(self, current_state: str, current_date: str, max_lookback: int = 20) -> int:
        """Count today plus the unbroken run of prior recorded days with the same final state.

        Only recorded days are walked, so weekends and holidays do not break the
        run. The result is capped at ``max_lookback + 1``.
        """
        recent = self.get_recent_states(current_date, limit=max_lookback)
        consecutive_days = 1  # today counts
        for row in recent:
            if row['final_state'] != current_state:
                break
            consecutive_days += 1
        return consecutive_days

    def calculate_switch_count_last_5days(
        self,
        current_date: str,
        candidate_state: Optional[str] = None,
        window: int = 5,
    ) -> int:
        """Count state changes across the latest ``window`` recorded states.

        When ``candidate_state`` is given it is treated as today's state and
        occupies the first slot, so the count is not one day late. Otherwise all
        ``window`` slots come from stored history.
        """
        history_needed = window - 1 if candidate_state is not None else window
        recent = self.get_recent_states(current_date, limit=max(history_needed, 1))

        states = [candidate_state] if candidate_state is not None else []
        states.extend(row['final_state'] for row in recent)
        states = states[:max(window, 0)]

        return sum(1 for newer, older in zip(states, states[1:]) if newer != older)


# =========================
# 2) Lightweight confirmation gate: small framework change that reduces
#    choppy switching.
# =========================
STATE_ORDER = {
    "RISK_SEVERE": 0,
    "RISK_MILD": 1,
    "RANGE": 2,
    "TREND_NORMAL": 3,
    "TREND_BULL": 4,
}


def _is_upgrade(prev_state: str, new_state: str) -> bool:
    """True when ``new_state`` ranks above ``prev_state`` (unknown states rank -999)."""
    return STATE_ORDER.get(new_state, -999) > STATE_ORDER.get(prev_state, -999)


def _is_downgrade(prev_state: str, new_state: str) -> bool:
    """True when ``new_state`` ranks below ``prev_state`` (unknown states rank -999)."""
    return STATE_ORDER.get(new_state, -999) < STATE_ORDER.get(prev_state, -999)


def _gate_input_state(row: Dict) -> Optional[str]:
    """Return the state that was fed into the gate on a stored day.

    Rows written before ``candidate_state`` was persisted fall back to
    ``raw_state``, the closest pre-gate signal they recorded.
    """
    candidate = row.get('candidate_state')
    return candidate if candidate is not None else row.get('raw_state')


def _leading_streak(rows: Iterable[Dict], predicate: Callable[[Dict], bool]) -> int:
    """Count how many rows, from the front, satisfy ``predicate`` before the first miss."""
    streak = 0
    for row in rows:
        if not predicate(row):
            break
        streak += 1
    return streak


def stabilize_state_transition(
    candidate_state: str,
    current_date: str,
    persistence: StatePersistence,
    severe_state: str = "RISK_SEVERE",
    confirm_upgrade_days: int = 2,
    confirm_downgrade_days: int = 1,
    confirm_exit_severe_days: int = 2,
    min_hold_days: int = 2,
) -> str:
    """Gate a candidate state before it becomes the persisted ``final_state``.

    The core scoring framework is untouched; this only decides whether today's
    candidate may "land". Rules, in priority order:

    - No history, or the candidate equals the last final state: accept.
    - Entering ``severe_state``: accept immediately.
    - Leaving ``severe_state``: the gate input must have been non-severe on
      ``confirm_exit_severe_days`` consecutive recorded days, today included.
    - The last final state has been held fewer than ``min_hold_days``: block
      upgrades. Downgrades still pass, so risk control stays responsive.
    - Upgrade: the gate input must have equalled the candidate on
      ``confirm_upgrade_days`` consecutive recorded days, today included.
    - Downgrade: the same rule with ``confirm_downgrade_days``. The default of 1
      means immediate.

    History is read from the gate *inputs* stored with each day
    (``candidate_state``, falling back to ``raw_state``). Counting stored
    ``final_state`` values cannot work: the newest row always holds the last
    final state, which by this point differs from the candidate. That streak
    would therefore always be 0, and upgrades and severe exits would be blocked
    forever. Callers must persist ``candidate_state`` via
    :meth:`StatePersistence.save_daily_state` for this to be accurate.

    Returns:
        The state to persist as today's ``final_state``.
    """
    prev = persistence.load_latest_state_before(current_date)
    if not prev:
        return candidate_state

    prev_state = prev['final_state']
    prev_consecutive = int(prev.get('consecutive_days') or 1)

    if candidate_state == prev_state:
        return candidate_state

    # Severe risk: always enter immediately.
    if candidate_state == severe_state:
        return candidate_state

    # Leaving severe risk needs stronger confirmation, to avoid whipsawing right
    # after the risk appears to clear.
    if prev_state == severe_state:
        recent = persistence.get_recent_states(current_date, limit=max(confirm_exit_severe_days, 5))
        non_severe_streak = _leading_streak(recent, lambda row: _gate_input_state(row) != severe_state)
        # Today's candidate is itself non-severe, hence the "- 1".
        if non_severe_streak < confirm_exit_severe_days - 1:
            return prev_state
        return candidate_state

    # Minimum hold period: switch quickly toward more risk, but not toward more
    # optimism.
    if prev_consecutive < min_hold_days and _is_upgrade(prev_state, candidate_state):
        return prev_state

    recent = persistence.get_recent_states(
        current_date, limit=max(confirm_upgrade_days, confirm_downgrade_days, 5)
    )
    same_as_candidate_streak = _leading_streak(
        recent, lambda row: _gate_input_state(row) == candidate_state
    )

    if _is_upgrade(prev_state, candidate_state):
        # Slow upgrades: require prior consecutive days with this candidate, to
        # reduce false breakouts.
        if same_as_candidate_streak < max(confirm_upgrade_days - 1, 0):
            return prev_state
        return candidate_state

    if _is_downgrade(prev_state, candidate_state):
        # Fast downgrades by default; set confirm_downgrade_days=2 to be steadier.
        if same_as_candidate_streak < max(confirm_downgrade_days - 1, 0):
            return prev_state
        return candidate_state

    # Neither up nor down (e.g. both states unknown to STATE_ORDER): accept.
    return candidate_state


# =========================
# 3) Position rate limit: also applies from a zero position, to limit damage
#    from misjudgements.
# =========================
def apply_position_limit(new_position: float, daily_change_limit: float, previous_position: Optional[float]):
    """Clamp the day-over-day position change to ``daily_change_limit``.

    Only a true first day (``previous_position is None``) is unrestricted. A
    previous position of 0 is still rate-limited, so a misjudgement cannot jump
    straight to a large position.
    """
    if previous_position is None:
        return new_position

    actual_change = new_position - previous_position
    if actual_change > daily_change_limit:
        return previous_position + daily_change_limit
    if actual_change < -daily_change_limit:
        return previous_position - daily_change_limit
    return new_position


# =========================
# 4) Main pipeline: minimally invasive integration
# =========================
def run_daily_pipeline_fixed(date, core_stocks, config):
    """Run one trading day: score, gate the state, size positions, screen stocks, persist.

    Changes relative to the baseline pipeline:

    - the previous state is the latest *recorded* trading day, not ``date - 1``;
    - ``final_state`` passes through :func:`stabilize_state_transition`;
    - the switch count includes today's landed state;
    - today's gate input (``cooled_state``) is persisted as ``candidate_state``,
      so future gate confirmations can see it.

    NOTE(review): ``get_market_data``, ``V13StateMachine``,
    ``adjust_confidence_by_stability``, ``calculate_continuous_position_fixed``,
    ``calculate_execution_permissions``, ``get_stock_data``,
    ``calculate_alpha_score``, ``get_state_threshold``,
    ``calibrate_confidence_by_bucket`` and ``calculate_stock_position`` are not
    defined in this module. They are assumed to be provided elsewhere.
    """
    state_persistence = StatePersistence()

    # Latest recorded trading day, not a mechanical date - 1.
    previous_state = state_persistence.load_latest_state_before(date)

    market_data = get_market_data(date)
    state_machine = V13StateMachine(config)

    score = state_machine.calculate_score(market_data)
    raw_state = state_machine.determine_state(score)
    cooled_state = state_machine.apply_cooling_period(raw_state, score)

    # User overrides on top of the module defaults (same values as before).
    stability_cfg = {
        **DEFAULT_STATE_STABILITY_CONFIG,
        **(config.get("STATE_STABILITY_CONFIG") or {}),
    }

    # Lightweight gate: does not touch scoring, only whether the state "lands".
    final_state = stabilize_state_transition(
        candidate_state=cooled_state,
        current_date=date,
        persistence=state_persistence,
        severe_state="RISK_SEVERE",
        confirm_upgrade_days=stability_cfg["confirm_upgrade_days"],
        confirm_downgrade_days=stability_cfg["confirm_downgrade_days"],
        confirm_exit_severe_days=stability_cfg["confirm_exit_severe_days"],
        min_hold_days=stability_cfg["min_hold_days"],
    )

    consecutive_days = state_persistence.calculate_consecutive_days(final_state, date)
    switch_count_last_5days = state_persistence.calculate_switch_count_last_5days(
        current_date=date,
        candidate_state=final_state,
        window=5,
    )

    base_confidence = state_machine.calculate_confidence(final_state, score)
    adjusted_confidence, _, _ = adjust_confidence_by_stability(
        base_confidence,
        consecutive_days,
        switch_count_last_5days,
    )

    # The "previous_position" column holds the suggested position saved on that
    # day (see the save call below), i.e. our position going into today.
    previous_position = previous_state['previous_position'] if previous_state else None
    suggested_position = calculate_continuous_position_fixed(
        final_state,
        score,
        config["POSITION_MAPPING"],
        config["POSITION_LIMIT_CONFIG"]["daily_change_limit"],
        previous_position,
    )

    execution_permissions = calculate_execution_permissions(final_state, adjusted_confidence)

    trade_candidates = []
    for stock in core_stocks:
        stock_data = get_stock_data(stock, date)
        alpha_score = calculate_alpha_score(stock_data, final_state)
        threshold = get_state_threshold(final_state)

        if alpha_score >= threshold and execution_permissions['allow_new_positions']:
            action, _conf = calibrate_confidence_by_bucket(
                adjusted_confidence,
                config["EXECUTION_ACTIONS"],
            )
            position = calculate_stock_position(action, suggested_position)
            trade_candidates.append({
                "stock": stock,
                "alpha_score": alpha_score,
                "threshold": threshold,
                "action": action,
                "position": position,
                "intercepted": False,
            })
        else:
            intercept_reason = []
            if alpha_score < threshold:
                intercept_reason.append(
                    f"Alpha score {alpha_score:.1f} < threshold {threshold:.1f}"
                )
            if not execution_permissions['allow_new_positions']:
                intercept_reason.append("New positions not allowed in current state")

            trade_candidates.append({
                "stock": stock,
                "alpha_score": alpha_score,
                "threshold": threshold,
                "action": "intercepted",
                "position": 0,
                "intercepted": True,
                "reason": "; ".join(intercept_reason),
            })

    state_persistence.save_daily_state(
        date,
        final_state,
        raw_state,
        score,
        consecutive_days,
        switch_count_last_5days,
        suggested_position,
        [],
        candidate_state=cooled_state,  # gate input; needed by future confirmations
    )

    report = {
        "date": date,
        "market_state": final_state,
        "raw_state": raw_state,
        "cooled_state": cooled_state,
        "confidence": adjusted_confidence,
        "suggested_position": suggested_position,
        "execution_permissions": execution_permissions,
        "trade_candidates": trade_candidates,
        "consecutive_days": consecutive_days,
        "switch_count_last_5days": switch_count_last_5days,
    }
    return report


# =========================
# 5) Configuration: only a few parameters, to avoid overfitting
# =========================
DEFAULT_STATE_STABILITY_CONFIG = {
    # Coarse-grained parameters only; avoid tuning toward the "historical optimum".
    "confirm_upgrade_days": 2,
    "confirm_downgrade_days": 1,
    "confirm_exit_severe_days": 2,
    "min_hold_days": 2,
}


# =========================
# 6) Suggested validation method (kept as comments for reference when backtesting)
# =========================
# A. Fix these 4 parameters first and do not keep fine-tuning them: 2 / 1 / 2 / 2
# B. Only run a very small grid:
#    - confirm_upgrade_days: [2, 3]
#    - confirm_downgrade_days: [1, 2]
#    - confirm_exit_severe_days: [2, 3]
#    - min_hold_days: [2, 3]
# C. Evaluation order:
#    1) Does the number of state switches drop noticeably?
#    2) Does maximum drawdown stay no worse?
#    3) Do returns improve?
# D. If one parameter set is especially good while its neighbours are all poor,
#    treat it as suspected overfitting.