import os
import sys
import time
import warnings
from datetime import datetime, timedelta

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests

warnings.filterwarnings('ignore')

# Make the project root importable so the local `model` package resolves.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model import Kronos, KronosTokenizer, KronosPredictor


# ==================== Configuration ====================
class PredictionConfig:
    """Prediction settings: candle interval, forecast horizon and history length."""

    # Supported candle intervals mapped to their length in minutes.
    SUPPORTED_INTERVALS = {
        '1m': 1, '3m': 3, '5m': 5, '15m': 15, '30m': 30,
        '1h': 60, '2h': 120, '4h': 240, '6h': 360, '8h': 480, '12h': 720,
        '1d': 1440,
    }

    def __init__(self, interval='15m', pred_len=24, lookback=400):
        """Create a configuration.

        Args:
            interval: Candle interval key from SUPPORTED_INTERVALS (e.g. '15m').
            pred_len: Number of future candles to predict; must be >= 1.
            lookback: Number of historical candles fed to the model; must be >= 1.

        Raises:
            ValueError: If the interval is unsupported or a length is not positive.
        """
        # Validate before deriving anything from the interval.
        if interval not in self.SUPPORTED_INTERVALS:
            raise ValueError(f"不支持的时间间隔: {interval}. 支持的间隔: {list(self.SUPPORTED_INTERVALS.keys())}")
        if pred_len < 1 or lookback < 1:
            raise ValueError(
                f"pred_len and lookback must be >= 1, got pred_len={pred_len}, lookback={lookback}"
            )
        self.interval = interval
        self.pred_len = pred_len
        self.lookback = lookback
        self.interval_minutes = self.SUPPORTED_INTERVALS[interval]

    def get_24h_periods(self):
        """Return the number of candles that span 24 hours."""
        return int(24 * 60 / self.interval_minutes)

    def get_prediction_duration_hours(self):
        """Return the forecast horizon in hours."""
        return (self.pred_len * self.interval_minutes) / 60

    def get_freq_string(self):
        """Return a pandas frequency string for the candle interval.

        Expressed in minutes (e.g. '15min', '60min', '1440min') because the
        legacy 'T'/'H' aliases are deprecated since pandas 2.2; the offsets
        produced are the same as before.
        """
        return f"{self.interval_minutes}min"

    def get_time_format(self):
        """Return a strftime format for x-axis tick labels suited to the interval."""
        if self.interval_minutes <= 60:   # candles of up to one hour
            return '%H:%M'
        if self.interval_minutes < 1440:  # multi-hour candles
            return '%m-%d %H:%M'
        return '%Y-%m-%d'                 # daily candles

    def get_prediction_time_points(self):
        """Return (index, label) pairs at ~25%, ~50% and 100% of the forecast.

        Prediction index ``i`` lies ``(i + 1) * interval`` after the last
        historical candle, so each label is computed from the index it
        describes. The unit (min / h / d) is chosen from the total horizon.
        Duplicate indices (possible for very short horizons) are dropped.
        """
        duration_hours = self.get_prediction_duration_hours()
        points = []
        seen = set()
        for fraction in (0.25, 0.5, 1.0):
            steps_ahead = max(int(self.pred_len * fraction), 1)
            idx = steps_ahead - 1
            if idx in seen:
                continue
            seen.add(idx)
            minutes_ahead = steps_ahead * self.interval_minutes
            if duration_hours < 1:
                label = f"{int(minutes_ahead)}min later"
            elif duration_hours <= 24:
                label = f"{minutes_ahead / 60:.1f}h later"
            else:
                label = f"{minutes_ahead / 1440:.1f}d later"
            points.append((idx, label))
        return points


def get_user_config():
    """Build the prediction configuration (could be extended to interactive input)."""
    # Example configurations:
    # Short-term trading (1-minute candles, predict 30 minutes)
    # config = PredictionConfig(interval='1m', pred_len=30, lookback=400)

    # Mid-term trading (15-minute candles, predict 6 hours)
    config = PredictionConfig(interval='15m', pred_len=24, lookback=512)

    # Long-term analysis (1-hour candles, predict 24 hours)
    # config = PredictionConfig(interval='1h', pred_len=24, lookback=400)

    # Daily analysis (1-day candles, predict 7 days)
    # config = PredictionConfig(interval='1d', pred_len=7, lookback=200)

    print("📋 当前预测配置:")
    print(f" 时间间隔: {config.interval}")
    print(f" 预测长度: {config.pred_len} 步")
    print(f" 预测时长: {config.get_prediction_duration_hours():.1f} 小时")
    print(f" 历史数据: {config.lookback} 个数据点")
    print(f" 支持的间隔: {list(config.SUPPORTED_INTERVALS.keys())}")
    print()

    return config


class BinanceDataFetcher:
    """Binance spot kline fetcher (original note: reachable from mainland China)."""

    # Maximum `limit` accepted by Binance /api/v3/klines.
    MAX_LIMIT = 1000

    def __init__(self):
        self.base_url = "https://api.binance.com"
        # Fallback hosts tried in order when the primary host fails.
        self.backup_urls = [
            "https://api1.binance.com",
            "https://api2.binance.com",
            "https://api3.binance.com",
        ]

    def get_klines(self, symbol="ETHUSDT", interval="15m", limit=500):
        """Fetch candlestick data, trying each host until one succeeds.

        Args:
            symbol: Trading pair symbol (e.g. ETHUSDT).
            interval: Candle interval (1m, 3m, 5m, 15m, 30m, 1h, 2h, 4h, 6h,
                8h, 12h, 1d, 3d, 1w, 1M).
            limit: Number of candles; values above 1000 are clamped to 1000.

        Returns:
            DataFrame with columns timestamps, open, high, low, close, volume,
            amount, sorted by timestamp.

        Raises:
            RuntimeError: If no host returned a usable response.
        """
        endpoint = "/api/v3/klines"
        params = {
            'symbol': symbol,
            'interval': interval,
            # Clamp so a large lookback yields the maximum available history
            # instead of an HTTP 400 from every host.
            'limit': min(limit, self.MAX_LIMIT),
        }

        last_error = None
        for url in [self.base_url] + self.backup_urls:
            try:
                response = requests.get(url + endpoint, params=params, timeout=10)
                if response.status_code == 200:
                    # NOTE(review): Binance typically includes the still-open
                    # current candle as the last row — confirm whether it
                    # should be dropped before feeding the model.
                    return self._parse_klines(response.json())
                print(f"API返回错误: {response.status_code}")
            except (requests.RequestException, ValueError) as e:
                # Network failure or undecodable/unexpected payload: try next host.
                print(f"尝试URL {url} 失败: {e}")
                last_error = e

        raise RuntimeError("所有API URL都无法访问,请检查网络连接") from last_error

    def _parse_klines(self, raw_data):
        """Convert the raw kline rows into a Kronos-compatible DataFrame."""
        df = pd.DataFrame(raw_data, columns=[
            'timestamp', 'open', 'high', 'low', 'close', 'volume',
            'close_time', 'quote_asset_volume', 'number_of_trades',
            'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'ignore'
        ])

        # Open time arrives in epoch milliseconds; prices/volumes arrive as strings.
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        for col in ['open', 'high', 'low', 'close', 'volume', 'quote_asset_volume']:
            df[col] = df[col].astype(float)

        # Rename to the column names Kronos expects.
        df = df.rename(columns={
            'timestamp': 'timestamps',
            'quote_asset_volume': 'amount',
        })

        df = df[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]
        return df.sort_values('timestamps').reset_index(drop=True)


def plot_prediction_with_realtime(historical_df, pred_df, config, symbol="ETH-USDT"):
    """Render and save a price/volume chart of the history plus the forecast.

    Args:
        historical_df: DataFrame with 'timestamps', 'close' and 'volume' columns.
        pred_df: Forecast DataFrame with 'close' and 'volume' columns. Its index
            is used as the forecast timestamps unless it is numeric (positional),
            in which case timestamps are rebuilt from the last historical candle.
        config: PredictionConfig describing interval, horizon and lookback.
        symbol: Pair label in 'BASE-QUOTE' form; drives the axis labels and the
            output file prefix ('ETH-USDT' -> 'eth_usdt_prediction_*').

    Side effects:
        Saves a 300-dpi PNG and an SVG to the current working directory and
        prints their file names. The figure is closed afterwards.
    """
    # Use a font bundled with matplotlib to avoid missing-glyph issues.
    plt.rcParams['font.family'] = 'DejaVu Sans'
    plt.rcParams['axes.unicode_minus'] = False

    # Labels derived from the symbol; the default reproduces the original names.
    base_asset, _, quote_asset = symbol.partition('-')
    quote_asset = quote_asset or 'USDT'
    file_prefix = symbol.lower().replace('-', '_')

    # Time axes
    hist_timestamps = pd.to_datetime(historical_df['timestamps'])
    if pd.api.types.is_numeric_dtype(pred_df.index):
        # A positional index carries no time information (converting it would
        # land in 1970), so rebuild the forecast timestamps instead.
        pred_timestamps = pd.date_range(
            start=hist_timestamps.iloc[-1] + pd.Timedelta(minutes=config.interval_minutes),
            periods=len(pred_df),
            freq=config.get_freq_string(),
        )
    else:
        pred_timestamps = pd.to_datetime(pred_df.index)

    hist_close = historical_df['close'].values
    pred_close = pred_df['close'].values
    hist_volume = historical_df['volume'].values
    pred_volume = pred_df['volume'].values

    # Key statistics
    current_price = hist_close[-1]
    final_pred_price = pred_close[-1]
    price_change = final_pred_price - current_price
    price_change_pct = (price_change / current_price) * 100
    max_pred_price = np.max(pred_close)
    min_pred_price = np.min(pred_close)
    volatility = ((max_pred_price - min_pred_price) / current_price) * 100

    # Layout: info panel on top, price chart in the middle, volume at the bottom.
    fig = plt.figure(figsize=(16, 10))
    try:
        gs = fig.add_gridspec(3, 4, height_ratios=[0.8, 2, 1], width_ratios=[1, 1, 1, 1],
                              hspace=0.3, wspace=0.2)
        ax_info = fig.add_subplot(gs[0, :])
        ax_info.axis('off')
        ax_price = fig.add_subplot(gs[1, :])
        ax_volume = fig.add_subplot(gs[2, :], sharex=ax_price)

        # === Information panel ===
        duration_hours = config.get_prediction_duration_hours()
        duration_text = (f"{duration_hours:.1f}h" if duration_hours >= 1
                         else f"{int(duration_hours * 60)}min")
        info_text = "\n".join([
            f"{symbol} {config.interval} Candlestick Prediction Analysis | "
            f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"Config: {config.pred_len}-step prediction ({duration_text}) | "
            f"Historical data: {config.lookback} points",
            f"Current Price: ${current_price:.2f}    Predicted Price: ${final_pred_price:.2f}    "
            f"Change: {price_change:+.2f} ({price_change_pct:+.2f}%)",
            f"Predicted High: ${max_pred_price:.2f}    Predicted Low: ${min_pred_price:.2f}    "
            f"Volatility: {volatility:.2f}%",
        ])
        ax_info.text(0.5, 0.5, info_text, transform=ax_info.transAxes, fontsize=12,
                     ha='center', va='center',
                     bbox=dict(boxstyle="round,pad=0.5", facecolor='lightblue', alpha=0.8))

        # === Price chart ===
        ax_price.plot(hist_timestamps, hist_close, color='#1f77b4', linewidth=2,
                      label='Historical Price', alpha=0.8)
        ax_price.plot(pred_timestamps, pred_close, color='#ff7f0e', linewidth=2.5,
                      linestyle='--', label='Predicted Price', alpha=0.9)

        # Shade the forecast window.
        ax_price.axvspan(pred_timestamps[0], pred_timestamps[-1], alpha=0.1,
                         color='orange', label='Prediction Zone')

        # Current price marker
        ax_price.scatter(hist_timestamps.iloc[-1], current_price, color='blue', s=100,
                         zorder=5, marker='o')
        ax_price.annotate(f'Current: ${current_price:.2f}',
                          xy=(hist_timestamps.iloc[-1], current_price),
                          xytext=(10, 10), textcoords='offset points',
                          bbox=dict(boxstyle='round,pad=0.3', facecolor='lightblue'),
                          arrowprops=dict(arrowstyle='->', color='blue'))

        # Final predicted price marker
        ax_price.scatter(pred_timestamps[-1], final_pred_price, color='red', s=100,
                         zorder=5, marker='s')
        ax_price.annotate(f'Predicted: ${final_pred_price:.2f}',
                          xy=(pred_timestamps[-1], final_pred_price),
                          xytext=(-80, 10), textcoords='offset points',
                          bbox=dict(boxstyle='round,pad=0.3', facecolor='lightyellow'),
                          arrowprops=dict(arrowstyle='->', color='red'))

        # Forecast high / low markers
        max_idx = np.argmax(pred_close)
        min_idx = np.argmin(pred_close)
        ax_price.scatter(pred_timestamps[max_idx], max_pred_price, color='green', s=80,
                         zorder=5, marker='^')
        ax_price.annotate(f'High: ${max_pred_price:.2f}',
                          xy=(pred_timestamps[max_idx], max_pred_price),
                          xytext=(0, 15), textcoords='offset points', ha='center', fontsize=9,
                          bbox=dict(boxstyle='round,pad=0.2', facecolor='lightgreen'))
        ax_price.scatter(pred_timestamps[min_idx], min_pred_price, color='red', s=80,
                         zorder=5, marker='v')
        ax_price.annotate(f'Low: ${min_pred_price:.2f}',
                          xy=(pred_timestamps[min_idx], min_pred_price),
                          xytext=(0, -20), textcoords='offset points', ha='center', fontsize=9,
                          bbox=dict(boxstyle='round,pad=0.2', facecolor='lightcoral'))

        # Trend arrow from first to last forecast point (green up, red otherwise).
        trend_color = 'green' if price_change > 0 else 'red'
        ax_price.annotate('', xy=(pred_timestamps[-1], final_pred_price),
                          xytext=(pred_timestamps[0], pred_close[0]),
                          arrowprops=dict(arrowstyle='->', color=trend_color, lw=2, alpha=0.6))

        # === Volume chart ===
        # Bars take ~80% of one candle interval so neighbours don't touch.
        bar_width = pd.Timedelta(minutes=config.interval_minutes * 0.8)
        ax_volume.bar(hist_timestamps, hist_volume, width=bar_width, color='#1f77b4',
                      alpha=0.6, label='Historical Volume')
        ax_volume.bar(pred_timestamps, pred_volume, width=bar_width, color='#ff7f0e',
                      alpha=0.7, label='Predicted Volume')

        # === Formatting ===
        ax_price.set_ylabel(f'Price ({quote_asset})', fontsize=12, fontweight='bold')
        ax_price.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
        ax_price.legend(loc='upper left', frameon=True, fancybox=True, shadow=True)
        ax_price.set_facecolor('#fafafa')

        ax_volume.set_ylabel(f'Volume ({base_asset})', fontsize=12, fontweight='bold')
        ax_volume.set_xlabel('Time', fontsize=12, fontweight='bold')
        ax_volume.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
        ax_volume.legend(loc='upper left', frameon=True, fancybox=True, shadow=True)
        ax_volume.set_facecolor('#fafafa')

        # Time axis format and tick spacing depend on the candle interval.
        ax_volume.xaxis.set_major_formatter(mdates.DateFormatter(config.get_time_format()))
        if config.interval_minutes <= 5:      # 1m, 3m, 5m
            ax_volume.xaxis.set_major_locator(mdates.HourLocator(interval=1))
        elif config.interval_minutes <= 30:   # 15m, 30m
            ax_volume.xaxis.set_major_locator(mdates.HourLocator(interval=2))
        elif config.interval_minutes <= 240:  # 1h, 2h, 4h
            ax_volume.xaxis.set_major_locator(mdates.HourLocator(interval=6))
        elif config.interval_minutes <= 720:  # 6h, 8h, 12h
            ax_volume.xaxis.set_major_locator(mdates.HourLocator(interval=12))
        else:                                 # 1d
            ax_volume.xaxis.set_major_locator(mdates.DayLocator(interval=1))
        plt.setp(ax_volume.xaxis.get_majorticklabels(), rotation=45, ha='right')

        # History / forecast boundary line on both charts.
        boundary_time = pred_timestamps[0]
        ax_price.axvline(x=boundary_time, color='gray', linestyle=':', linewidth=2, alpha=0.8)
        ax_volume.axvline(x=boundary_time, color='gray', linestyle=':', linewidth=2, alpha=0.8)

        # x in data coordinates, y as a fraction of the axes height, so the
        # label stays inside the plot regardless of the price range.
        ax_price.text(boundary_time, 0.95, 'Prediction Start',
                      transform=ax_price.get_xaxis_transform(),
                      rotation=90, ha='right', va='top', fontsize=10,
                      bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.8))

        fig.tight_layout()

        # Save raster and vector versions.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{file_prefix}_prediction_{timestamp}.png"
        fig.savefig(filename, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Prediction chart saved as: {filename}")

        svg_filename = f"{file_prefix}_prediction_{timestamp}.svg"
        fig.savefig(svg_filename, format='svg', bbox_inches='tight', facecolor='white')
        print(f"Vector chart saved as: {svg_filename}")
        # plt.show() is intentionally not called to avoid GUI issues.
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

    print("📊 Chart generation completed!")


def main():
    """Fetch ETH-USDT candles, run a Kronos forecast, then report, plot and save it."""
    print("🚀 ETH-USDT 实时预测系统启动")
    print("=" * 50)

    # 0. Configuration
    config = get_user_config()

    # 1. Data fetcher
    print("📡 初始化数据获取器...")
    fetcher = BinanceDataFetcher()

    # 2. Fetch the latest candles
    print("📊 获取ETH-USDT实时数据...")
    try:
        df = fetcher.get_klines(symbol="ETHUSDT", interval=config.interval, limit=config.lookback)
        print(f"✅ 成功获取 {len(df)} 条数据")
        print(f"📅 数据时间范围: {df['timestamps'].min()} 到 {df['timestamps'].max()}")
        print(f"💰 当前价格: ${df['close'].iloc[-1]:.2f}")

        # The candle exactly 24h before the latest one is periods_24h rows
        # earlier, i.e. at position -(periods_24h + 1).
        periods_24h = config.get_24h_periods()
        if len(df) > periods_24h:
            price_24h_ago = df['close'].iloc[-periods_24h - 1]
            change_24h = (df['close'].iloc[-1] / price_24h_ago - 1) * 100
            print(f"📈 24h涨跌: {change_24h:.2f}%")
        else:
            print("📈 数据不足24小时,无法计算24h涨跌")
    except Exception as e:
        print(f"❌ 数据获取失败: {e}")
        return

    # 3. Load the model
    print("\n🤖 加载Kronos模型...")
    try:
        tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
        model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
        predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
        print("✅ 模型加载成功")
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        return

    # 4. Prepare model inputs
    print("\n📋 准备预测数据...")
    if len(df) < config.lookback:
        print(f"❌ 数据不足,需要至少{config.lookback}条数据,当前只有{len(df)}条")
        return

    # Use the most recent `lookback` candles as model context.
    recent_df = df.tail(config.lookback).copy()
    x_df = recent_df[['open', 'high', 'low', 'close', 'volume', 'amount']]
    x_timestamp = recent_df['timestamps']

    # Future timestamps continue the candle grid after the last observation.
    last_time = x_timestamp.iloc[-1]
    future_timestamps = pd.date_range(
        start=last_time + timedelta(minutes=config.interval_minutes),
        periods=config.pred_len,
        freq=config.get_freq_string(),
    )
    # The predictor is given a Series, matching the type of x_timestamp.
    future_timestamps = pd.Series(future_timestamps)

    duration_hours = config.get_prediction_duration_hours()
    duration_text = f"{duration_hours:.1f}小时" if duration_hours >= 1 else f"{int(duration_hours * 60)}分钟"
    print(f"📊 使用 {len(x_df)} 条历史数据")
    print(f"🔮 预测未来 {config.pred_len} 个时间点({duration_text})")
    print(f"⏰ 预测时间范围: {future_timestamps.iloc[0]} 到 {future_timestamps.iloc[-1]}")

    # 5. Run the prediction
    print("\n🔮 开始预测...")
    start_time = time.time()
    try:
        pred_df = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=future_timestamps,
            pred_len=config.pred_len,
            T=0.8,           # lower temperature for more stable predictions
            top_p=0.9,       # nucleus sampling
            sample_count=3,  # multiple samples, averaged
            verbose=True,
        )
        prediction_time = time.time() - start_time
        print(f"✅ 预测完成,耗时: {prediction_time:.1f}秒")
    except Exception as e:
        print(f"❌ 预测失败: {e}")
        return

    # 6. Analyse the forecast
    print("\n📈 预测结果分析:")
    print("=" * 30)

    current_price = x_df['close'].iloc[-1]
    pred_prices = pred_df['close']
    print(f"当前价格: ${current_price:.2f}")

    for idx, time_desc in config.get_prediction_time_points():
        if idx < len(pred_prices):
            price = pred_prices.iloc[idx]
            change_pct = (price / current_price - 1) * 100
            print(f"{time_desc}预测: ${price:.2f} ({change_pct:+.2f}%)")

    price_trend = "上涨" if pred_prices.iloc[-1] > current_price else "下跌"
    max_price = pred_prices.max()
    min_price = pred_prices.min()

    print("\n📊 趋势分析:")
    print(f"整体趋势: {price_trend}")
    print(f"预测最高价: ${max_price:.2f}")
    print(f"预测最低价: ${min_price:.2f}")
    print(f"价格波动范围: {((max_price - min_price) / current_price * 100):.2f}%")

    # 7. Visualise
    print("\n📊 生成预测图表...")
    try:
        plot_prediction_with_realtime(recent_df, pred_df, config, "ETH-USDT")
    except Exception as e:
        print(f"⚠️ 图表生成失败: {e}")

    # 8. Persist the forecast
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    result_file = f"eth_usdt_prediction_{config.interval}_{timestamp}.csv"
    pred_df.to_csv(result_file)
    print(f"💾 预测结果已保存为: {result_file}")

    print("\n🎉 预测完成!")


if __name__ == "__main__":
    main()