|
import os
|
|
import sys
|
|
import time
|
|
import warnings
|
|
from datetime import datetime, timedelta
|
|
|
|
import matplotlib.dates as mdates
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pandas as pd
|
|
import requests
|
|
|
|
# Silence every warning process-wide. Note this also hides library
# deprecation notices; narrow the filter when diagnosing problems.
warnings.filterwarnings('ignore')

# Put the directory one level above this script's folder on sys.path so the
# local `model` package (imported right after this) can be resolved.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
from model import Kronos, KronosTokenizer, KronosPredictor
|
|
|
|
|
|
|
|
class PredictionConfig:
    """Prediction settings: candle interval, forecast horizon and history length."""

    # Supported candle intervals mapped to their length in minutes.
    SUPPORTED_INTERVALS = {
        '1m': 1, '3m': 3, '5m': 5, '15m': 15, '30m': 30,
        '1h': 60, '2h': 120, '4h': 240, '6h': 360, '8h': 480, '12h': 720, '1d': 1440
    }

    def __init__(self, interval='15m', pred_len=24, lookback=400):
        """Validate and store the prediction settings.

        Args:
            interval: candle interval key; must be one of SUPPORTED_INTERVALS.
            pred_len: number of future candles to predict (must be >= 1).
            lookback: number of historical candles fed to the model (must be >= 1).

        Raises:
            ValueError: if ``interval`` is unsupported or ``pred_len`` /
                ``lookback`` is smaller than 1.
        """
        # Validate first; after these checks every derived value is well defined.
        if interval not in self.SUPPORTED_INTERVALS:
            raise ValueError(f"不支持的时间间隔: {interval}. 支持的间隔: {list(self.SUPPORTED_INTERVALS.keys())}")
        if pred_len < 1:
            raise ValueError(f"pred_len must be >= 1, got {pred_len}")
        if lookback < 1:
            raise ValueError(f"lookback must be >= 1, got {lookback}")

        self.interval = interval
        self.pred_len = pred_len
        self.lookback = lookback
        self.interval_minutes = self.SUPPORTED_INTERVALS[interval]

    def get_24h_periods(self):
        """Return how many candles of this interval make up 24 hours."""
        # Every supported interval divides 1440 minutes exactly.
        return (24 * 60) // self.interval_minutes

    def get_prediction_duration_hours(self):
        """Return the forecast horizon (pred_len candles) in hours."""
        return (self.pred_len * self.interval_minutes) / 60

    def get_freq_string(self):
        """Return the pandas frequency alias for this interval.

        Uses the 'min' / 'h' / 'D' spellings; the older 'T' and 'H' aliases
        are deprecated since pandas 2.2.
        """
        minutes = self.interval_minutes
        if minutes % 1440 == 0:
            return f"{minutes // 1440}D"
        if minutes % 60 == 0:
            return f"{minutes // 60}h"
        return f"{minutes}min"

    def get_time_format(self):
        """Return a strftime pattern suited to labelling candles of this interval."""
        if self.interval_minutes <= 60:
            return '%H:%M'
        if self.interval_minutes < 1440:
            return '%m-%d %H:%M'
        # Daily candles: the time of day carries no information.
        return '%Y-%m-%d'

    def get_prediction_time_points(self):
        """Return (index, label) pairs at roughly 25%, 50% and 100% of the horizon.

        Row ``i`` of the forecast lies ``i + 1`` intervals after the last
        observed candle, so each label is computed from the step count of the
        index actually returned. For very short horizons several fractions
        map to the same row; duplicates are dropped, keeping the first.
        The label unit (minutes / hours / days) is chosen from the total
        forecast duration.
        """
        duration_hours = self.get_prediction_duration_hours()
        points = []
        seen = set()
        for fraction in (0.25, 0.5, 1.0):
            steps = max(1, int(self.pred_len * fraction))
            idx = steps - 1
            if idx in seen:
                continue
            seen.add(idx)

            minutes_ahead = steps * self.interval_minutes
            hours_ahead = minutes_ahead / 60
            if duration_hours < 1:
                label = f"{minutes_ahead}min later"
            elif duration_hours <= 24:
                label = f"{hours_ahead:.1f}h later"
            else:
                label = f"{hours_ahead / 24:.1f}d later"
            points.append((idx, label))
        return points
|
|
|
|
|
|
def get_user_config(interval='15m', pred_len=24, lookback=512):
    """Create the prediction configuration and print a summary of it.

    The defaults are the script's standard settings (15-minute candles,
    24-step forecast, 512 candles of history). Pass other values to override
    them, e.g. from command-line arguments or an interactive prompt.

    Args:
        interval: candle interval key accepted by ``PredictionConfig``.
        pred_len: number of future candles to predict.
        lookback: number of historical candles fed to the model.

    Returns:
        The constructed ``PredictionConfig``.

    Raises:
        ValueError: propagated from ``PredictionConfig`` for invalid settings.
    """
    config = PredictionConfig(interval=interval, pred_len=pred_len, lookback=lookback)

    print("📋 当前预测配置:")
    print(f" 时间间隔: {config.interval}")
    print(f" 预测长度: {config.pred_len} 步")
    print(f" 预测时长: {config.get_prediction_duration_hours():.1f} 小时")
    print(f" 历史数据: {config.lookback} 个数据点")
    print(f" 支持的间隔: {list(config.SUPPORTED_INTERVALS.keys())}")
    print()

    return config
|
|
|
|
|
|
class BinanceDataFetcher:
    """Binance spot kline (candlestick) fetcher that falls back across API hosts."""

    def __init__(self):
        # Primary REST host; the backups are tried in order if it fails.
        self.base_url = "https://api.binance.com"
        self.backup_urls = [
            "https://api1.binance.com",
            "https://api2.binance.com",
            "https://api3.binance.com"
        ]

    def get_klines(self, symbol="ETHUSDT", interval="15m", limit=500):
        """Fetch klines and return them as a parsed DataFrame.

        Hosts are tried in order (``base_url`` then ``backup_urls``). The
        first HTTP 200 response whose body decodes as JSON is parsed and
        returned.

        Args:
            symbol: trading pair symbol (e.g. ETHUSDT).
            interval: kline interval (1m, 3m, 5m, 15m, 30m, 1h, 2h, 4h, 6h, 8h, 12h, 1d, 3d, 1w, 1M).
            limit: number of klines to request (max 1000).

        Returns:
            DataFrame produced by ``_parse_klines``.

        Raises:
            ConnectionError: if no host produced a usable response. Its
                ``__cause__`` is the last network/decoding error, if any.
        """
        endpoint = "/api/v3/klines"
        params = {
            'symbol': symbol,
            'interval': interval,
            'limit': limit
        }

        last_error = None
        for url in [self.base_url, *self.backup_urls]:
            # Only transport and decoding failures trigger a retry on the next
            # host; a bug in parsing would fail identically on every host.
            try:
                response = requests.get(url + endpoint, params=params, timeout=10)
            except requests.RequestException as e:
                print(f"尝试URL {url} 失败: {e}")
                last_error = e
                continue

            if response.status_code != 200:
                # Show a prefix of the body to make rejected requests diagnosable.
                print(f"API返回错误: {response.status_code} ({url}): {response.text[:200]}")
                continue

            try:
                data = response.json()
            except ValueError as e:
                print(f"尝试URL {url} 失败: {e}")
                last_error = e
                continue

            return self._parse_klines(data)

        raise ConnectionError("所有API URL都无法访问,请检查网络连接") from last_error

    def _parse_klines(self, raw_data):
        """Convert raw kline rows into a typed, time-sorted DataFrame.

        Each row is expected to hold the 12 fields listed below, in order.
        The result keeps ``timestamps`` (open time, parsed from milliseconds
        as naive datetimes), OHLC, ``volume`` and ``amount`` (the quote-asset
        volume) as floats, sorted by time, with a fresh 0..n-1 index.
        """
        df = pd.DataFrame(raw_data, columns=[
            'timestamp', 'open', 'high', 'low', 'close', 'volume',
            'close_time', 'quote_asset_volume', 'number_of_trades',
            'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'ignore'
        ])

        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        for col in ['open', 'high', 'low', 'close', 'volume', 'quote_asset_volume']:
            df[col] = df[col].astype(float)

        # Rename to the column names the prediction code consumes.
        df = df.rename(columns={
            'timestamp': 'timestamps',
            'quote_asset_volume': 'amount'
        })
        df = df[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]

        return df.sort_values('timestamps').reset_index(drop=True)
|
|
|
|
|
|
def _money(value):
    """Format *value* as a dollar price that matplotlib renders literally.

    Matplotlib parses a text line containing an even number of unescaped
    ``$`` characters as mathtext, which italicises the span between two
    prices and drops its spaces. Escaping the sign keeps it a plain ``$``.
    """
    return rf"\${value:.2f}"


def _resolve_prediction_timestamps(pred_df, hist_timestamps, config):
    """Return a DatetimeIndex of forecast timestamps for ``pred_df``.

    A datetime index is used as-is, and a non-numeric index is parsed as
    dates. A numeric index (e.g. a RangeIndex) carries no time information,
    so timestamps are generated one interval after the last historical bar
    instead of being misread as nanoseconds since the epoch.
    """
    index = pred_df.index
    if isinstance(index, pd.DatetimeIndex):
        return index
    if not pd.api.types.is_numeric_dtype(index):
        return pd.DatetimeIndex(pd.to_datetime(index))
    step = pd.Timedelta(minutes=config.interval_minutes)
    return pd.date_range(start=hist_timestamps.iloc[-1] + step, periods=len(pred_df), freq=step)


def _axis_time_format(config, span):
    """Pick the x-axis date format, adding the date when the chart spans days."""
    time_format = config.get_time_format()
    if time_format == '%H:%M' and span > pd.Timedelta(days=1):
        # Hour-only labels would repeat every day and become ambiguous.
        return '%m-%d %H:%M'
    return time_format


def plot_prediction_with_realtime(historical_df, pred_df, config, symbol="ETH-USDT"):
    """Plot history and forecast as a financial-style chart and save it to disk.

    The figure has a summary banner, a price panel (history, forecast,
    current/final/high/low markers, trend arrow) and a volume panel sharing
    the time axis. It is saved as PNG (300 dpi) and SVG in the current
    working directory, then closed.

    Args:
        historical_df: DataFrame with 'timestamps', 'close' and 'volume' columns.
        pred_df: forecast DataFrame with 'close' and 'volume' columns. A
            datetime-like index is used as the forecast timestamps; otherwise
            timestamps are generated one interval after the last historical bar.
        config: PredictionConfig providing interval, pred_len and lookback.
        symbol: pair label such as "ETH-USDT". It is used in the banner, the
            axis units (base/quote split on '-') and the output file names.

    Raises:
        ValueError: if either DataFrame is empty.
    """
    if len(historical_df) == 0 or len(pred_df) == 0:
        raise ValueError("historical_df and pred_df must both be non-empty")

    # Global pyplot font settings.
    plt.rcParams['font.family'] = 'DejaVu Sans'
    plt.rcParams['axes.unicode_minus'] = False

    hist_timestamps = pd.to_datetime(historical_df['timestamps'])
    pred_timestamps = _resolve_prediction_timestamps(pred_df, hist_timestamps, config)

    hist_close = historical_df['close'].values
    pred_close = pred_df['close'].values
    hist_volume = historical_df['volume'].values
    pred_volume = pred_df['volume'].values

    # Headline statistics; percentages are relative to the last observed close.
    current_price = hist_close[-1]
    final_pred_price = pred_close[-1]
    price_change = final_pred_price - current_price
    price_change_pct = (price_change / current_price) * 100
    max_pred_price = np.max(pred_close)
    min_pred_price = np.min(pred_close)
    volatility = ((max_pred_price - min_pred_price) / current_price) * 100
    max_idx = np.argmax(pred_close)
    min_idx = np.argmin(pred_close)

    # Axis units from the pair name, e.g. "ETH-USDT" -> volume in ETH, price in USDT.
    base_asset, sep, quote_asset = symbol.partition('-')
    price_label = f'Price ({quote_asset})' if sep else 'Price'
    volume_label = f'Volume ({base_asset})' if sep else 'Volume'

    duration_hours = config.get_prediction_duration_hours()
    duration_text = f"{duration_hours:.1f}h" if duration_hours >= 1 else f"{int(duration_hours * 60)}min"
    info_text = "\n".join([
        "",
        f"{symbol} {config.interval} Candlestick Prediction Analysis | {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"Config: {config.pred_len}-step prediction ({duration_text}) | Historical data: {config.lookback} points",
        "",
        f"Current Price: {_money(current_price)}  Predicted Price: {_money(final_pred_price)}  "
        f"Change: {price_change:+.2f} ({price_change_pct:+.2f}%)",
        f"Predicted High: {_money(max_pred_price)}  Predicted Low: {_money(min_pred_price)}  "
        f"Volatility: {volatility:.2f}%",
        "",
    ])

    fig = plt.figure(figsize=(16, 10))
    try:
        gs = fig.add_gridspec(3, 4, height_ratios=[0.8, 2, 1], width_ratios=[1, 1, 1, 1],
                              hspace=0.3, wspace=0.2)
        ax_info = fig.add_subplot(gs[0, :])
        ax_info.axis('off')
        ax_price = fig.add_subplot(gs[1, :])
        ax_volume = fig.add_subplot(gs[2, :], sharex=ax_price)

        ax_info.text(0.5, 0.5, info_text, transform=ax_info.transAxes,
                     fontsize=12, ha='center', va='center',
                     bbox=dict(boxstyle="round,pad=0.5", facecolor='lightblue', alpha=0.8))

        # --- Price panel ---
        ax_price.plot(hist_timestamps, hist_close,
                      color='#1f77b4', linewidth=2, label='Historical Price', alpha=0.8)
        ax_price.plot(pred_timestamps, pred_close,
                      color='#ff7f0e', linewidth=2.5, linestyle='--',
                      label='Predicted Price', alpha=0.9)
        ax_price.axvspan(pred_timestamps[0], pred_timestamps[-1],
                         alpha=0.1, color='orange', label='Prediction Zone')

        ax_price.scatter(hist_timestamps.iloc[-1], current_price,
                         color='blue', s=100, zorder=5, marker='o')
        ax_price.annotate(f'Current: {_money(current_price)}',
                          xy=(hist_timestamps.iloc[-1], current_price),
                          xytext=(10, 10), textcoords='offset points',
                          bbox=dict(boxstyle='round,pad=0.3', facecolor='lightblue'),
                          arrowprops=dict(arrowstyle='->', color='blue'))

        ax_price.scatter(pred_timestamps[-1], final_pred_price,
                         color='red', s=100, zorder=5, marker='s')
        ax_price.annotate(f'Predicted: {_money(final_pred_price)}',
                          xy=(pred_timestamps[-1], final_pred_price),
                          xytext=(-80, 10), textcoords='offset points',
                          bbox=dict(boxstyle='round,pad=0.3', facecolor='lightyellow'),
                          arrowprops=dict(arrowstyle='->', color='red'))

        ax_price.scatter(pred_timestamps[max_idx], max_pred_price,
                         color='green', s=80, zorder=5, marker='^')
        ax_price.annotate(f'High: {_money(max_pred_price)}',
                          xy=(pred_timestamps[max_idx], max_pred_price),
                          xytext=(0, 15), textcoords='offset points',
                          ha='center', fontsize=9,
                          bbox=dict(boxstyle='round,pad=0.2', facecolor='lightgreen'))

        ax_price.scatter(pred_timestamps[min_idx], min_pred_price,
                         color='red', s=80, zorder=5, marker='v')
        ax_price.annotate(f'Low: {_money(min_pred_price)}',
                          xy=(pred_timestamps[min_idx], min_pred_price),
                          xytext=(0, -20), textcoords='offset points',
                          ha='center', fontsize=9,
                          bbox=dict(boxstyle='round,pad=0.2', facecolor='lightcoral'))

        # Trend arrow from the first to the last forecast point.
        trend_color = 'green' if price_change > 0 else 'red'
        ax_price.annotate('', xy=(pred_timestamps[-1], final_pred_price),
                          xytext=(pred_timestamps[0], pred_close[0]),
                          arrowprops=dict(arrowstyle='->', color=trend_color, lw=2, alpha=0.6))

        # --- Volume panel ---
        # Bars cover 80% of one candle so neighbours do not touch.
        bar_width = pd.Timedelta(minutes=config.interval_minutes * 0.8)
        ax_volume.bar(hist_timestamps, hist_volume,
                      width=bar_width, color='#1f77b4',
                      alpha=0.6, label='Historical Volume')
        ax_volume.bar(pred_timestamps, pred_volume,
                      width=bar_width, color='#ff7f0e',
                      alpha=0.7, label='Predicted Volume')

        ax_price.set_ylabel(price_label, fontsize=12, fontweight='bold')
        ax_price.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
        ax_price.legend(loc='upper left', frameon=True, fancybox=True, shadow=True)
        ax_price.set_facecolor('#fafafa')

        ax_volume.set_ylabel(volume_label, fontsize=12, fontweight='bold')
        ax_volume.set_xlabel('Time', fontsize=12, fontweight='bold')
        ax_volume.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
        ax_volume.legend(loc='upper left', frameon=True, fancybox=True, shadow=True)
        ax_volume.set_facecolor('#fafafa')

        # --- Shared time axis ---
        # AutoDateLocator bounds the tick count for any history length;
        # fixed hour locators produced hundreds of ticks for long lookbacks.
        span = pred_timestamps[-1] - hist_timestamps.iloc[0]
        ax_volume.xaxis.set_major_locator(mdates.AutoDateLocator(minticks=4, maxticks=12))
        ax_volume.xaxis.set_major_formatter(mdates.DateFormatter(_axis_time_format(config, span)))
        plt.setp(ax_volume.xaxis.get_majorticklabels(), rotation=45, ha='right')

        boundary_time = pred_timestamps[0]
        ax_price.axvline(x=boundary_time, color='gray', linestyle=':', linewidth=2, alpha=0.8)
        ax_volume.axvline(x=boundary_time, color='gray', linestyle=':', linewidth=2, alpha=0.8)
        # x in data coordinates, y as a fraction of the axes height, so the
        # label stays inside the panel whatever the price range is.
        ax_price.text(boundary_time, 0.95, 'Prediction Start',
                      transform=ax_price.get_xaxis_transform(),
                      rotation=90, ha='right', va='top', fontsize=10,
                      bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.8))

        fig.tight_layout()

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # "ETH-USDT" -> "eth_usdt_prediction_<timestamp>"
        file_stem = f"{symbol.lower().replace('-', '_').replace('/', '_')}_prediction_{timestamp}"

        filename = f"{file_stem}.png"
        fig.savefig(filename, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Prediction chart saved as: {filename}")

        svg_filename = f"{file_stem}.svg"
        fig.savefig(svg_filename, format='svg', bbox_inches='tight', facecolor='white')
        print(f"Vector chart saved as: {svg_filename}")
    finally:
        # pyplot keeps every figure alive until closed; release it even on error.
        plt.close(fig)

    print("📊 Chart generation completed!")
|
|
|
|
|
|
def main():
    """Run the end-to-end ETH-USDT real-time prediction pipeline.

    Steps:
    1. Build the config and fetch the latest candles from Binance.
    2. Check that there is enough history.
    3. Load the Kronos model and forecast ``config.pred_len`` candles.
    4. Print a summary, plot the chart and save the forecast to CSV.

    Failures while fetching data, loading the model or predicting are
    printed and end the run early. A plotting failure is only reported as a
    warning.
    """
    print("🚀 ETH-USDT 实时预测系统启动")
    print("=" * 50)

    config = get_user_config()

    print("📡 初始化数据获取器...")
    fetcher = BinanceDataFetcher()

    print("📊 获取ETH-USDT实时数据...")
    try:
        df = fetcher.get_klines(symbol="ETHUSDT", interval=config.interval, limit=config.lookback)
        print(f"✅ 成功获取 {len(df)} 条数据")
        print(f"📅 数据时间范围: {df['timestamps'].min()} 到 {df['timestamps'].max()}")
        print(f"💰 当前价格: ${df['close'].iloc[-1]:.2f}")

        # The close exactly 24h earlier lies periods_24h rows before the last
        # row, i.e. at position -(periods_24h + 1), so that row must exist.
        periods_24h = config.get_24h_periods()
        if len(df) > periods_24h:
            price_24h_ago = df['close'].iloc[-(periods_24h + 1)]
            change_24h = (df['close'].iloc[-1] / price_24h_ago - 1) * 100
            print(f"📈 24h涨跌: {change_24h:.2f}%")
        else:
            print("📈 数据不足24小时,无法计算24h涨跌")
    except Exception as e:
        print(f"❌ 数据获取失败: {e}")
        return

    # Check the history length before paying for the model load.
    if len(df) < config.lookback:
        print(f"❌ 数据不足,需要至少{config.lookback}条数据,当前只有{len(df)}条")
        return

    print("\n🤖 加载Kronos模型...")
    try:
        tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
        model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
        # NOTE(review): max_context is fixed at 512 while lookback is
        # configurable; confirm how the predictor handles longer inputs.
        predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
        print("✅ 模型加载成功")
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        return

    print("\n📋 准备预测数据...")
    # Feed the model the most recent `lookback` candles. The forecast
    # timestamps continue directly after the last of them.
    # NOTE(review): the newest kline may still be forming (partial volume);
    # confirm whether it should be excluded from the model input.
    window = df.tail(config.lookback).copy()
    x_df = window[['open', 'high', 'low', 'close', 'volume', 'amount']]
    x_timestamp = window['timestamps']

    last_time = x_timestamp.iloc[-1]
    future_timestamps = pd.Series(pd.date_range(
        start=last_time + timedelta(minutes=config.interval_minutes),
        periods=config.pred_len,
        freq=config.get_freq_string()
    ))

    duration_hours = config.get_prediction_duration_hours()
    duration_text = f"{duration_hours:.1f}小时" if duration_hours >= 1 else f"{int(duration_hours * 60)}分钟"

    print(f"📊 使用 {len(x_df)} 条历史数据")
    print(f"🔮 预测未来 {config.pred_len} 个时间点({duration_text})")
    print(f"⏰ 预测时间范围: {future_timestamps.iloc[0]} 到 {future_timestamps.iloc[-1]}")

    print("\n🔮 开始预测...")
    start_time = time.time()
    try:
        pred_df = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=future_timestamps,
            pred_len=config.pred_len,
            T=0.8,
            top_p=0.9,
            sample_count=3,
            verbose=True
        )
        prediction_time = time.time() - start_time
        print(f"✅ 预测完成,耗时: {prediction_time:.1f}秒")
    except Exception as e:
        print(f"❌ 预测失败: {e}")
        return

    print("\n📈 预测结果分析:")
    print("=" * 30)

    current_price = x_df['close'].iloc[-1]
    pred_prices = pred_df['close']
    print(f"当前价格: ${current_price:.2f}")

    for idx, time_desc in config.get_prediction_time_points():
        if idx < len(pred_prices):
            price = pred_prices.iloc[idx]
            change_pct = (price / current_price - 1) * 100
            print(f"{time_desc}预测: ${price:.2f} ({change_pct:+.2f}%)")

    price_trend = "上涨" if pred_prices.iloc[-1] > current_price else "下跌"
    max_price = pred_prices.max()
    min_price = pred_prices.min()

    print(f"\n📊 趋势分析:")
    print(f"整体趋势: {price_trend}")
    print(f"预测最高价: ${max_price:.2f}")
    print(f"预测最低价: ${min_price:.2f}")
    print(f"价格波动范围: {((max_price - min_price) / current_price * 100):.2f}%")

    print("\n📊 生成预测图表...")
    try:
        plot_prediction_with_realtime(window, pred_df, config, "ETH-USDT")
    except Exception as e:
        # Best effort: a plotting failure should not discard the forecast.
        print(f"⚠️ 图表生成失败: {e}")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    result_file = f"eth_usdt_prediction_{config.interval}_{timestamp}.csv"
    pred_df.to_csv(result_file)
    print(f"💾 预测结果已保存为: {result_file}")

    print("\n🎉 预测完成!")
|
|
|
|
|
|
# Run the prediction pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|