import os

import cv2
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import requests
from ultralytics import YOLO

# Prefer a key supplied through the environment. The literal is kept only as a
# backward-compatible fallback and should be rotated out of source control.
API_KEY = os.environ.get("ALPHAVANTAGE_API_KEY", "ITWJ6NDTF45CBTDO")

_ALPHA_VANTAGE_URL = "https://www.alphavantage.co/query"
# Without a timeout a stalled connection would block the Gradio callback forever.
_REQUEST_TIMEOUT_SECONDS = 15

# Class indices as counted by the original code: 0 -> GAP DOWN, 1 -> GAP UP.
# NOTE(review): assumed to match the label order of the trained weights - confirm.
_GAP_DOWN_CLASS_ID = 0
_GAP_UP_CLASS_ID = 1


def get_stock_candlestick_data(symbol, interval="1min", output_size="compact"):
    """Fetch intraday candlestick data for *symbol* from Alpha Vantage.

    Args:
        symbol: Ticker symbol to query, e.g. ``"AAPL"``.
        interval: Bar interval sent to the API and used to build the
            ``"Time Series (<interval>)"`` key looked up in the response.
        output_size: Value sent as the ``outputsize`` query parameter.

    Returns:
        The value stored under ``"Time Series (<interval>)"`` in the JSON
        response, or ``None`` when the request fails, the status code is not
        200, the body is not a JSON object, or that key is absent.
    """
    # Passing the values through ``params`` lets requests URL-encode the
    # user-supplied symbol instead of splicing it into the query string.
    params = {
        "function": "TIME_SERIES_INTRADAY",
        "symbol": symbol,
        "interval": interval,
        "apikey": API_KEY,
        "outputsize": output_size,
    }
    try:
        response = requests.get(
            _ALPHA_VANTAGE_URL,
            params=params,
            timeout=_REQUEST_TIMEOUT_SECONDS,
        )
    except requests.RequestException as exc:
        # Only the exception type is printed: the full message can contain the
        # request URL, which includes the API key.
        print(f"Error fetching stock data: {type(exc).__name__}")
        return None

    if response.status_code != 200:
        return None

    try:
        payload = response.json()
    except ValueError:
        # Body was not valid JSON (e.g. an HTML error page).
        return None

    if not isinstance(payload, dict):
        return None
    return payload.get(f"Time Series ({interval})")


def process_stock_candlestick_data(data):
    """Convert an Alpha Vantage time-series mapping into a DataFrame.

    Args:
        data: Mapping of timestamp string to a dict holding the keys
            ``"1. open"``, ``"2. high"``, ``"3. low"``, ``"4. close"`` and
            ``"5. volume"``.

    Returns:
        A DataFrame with columns ``timestamp``, ``open``, ``high``, ``low``,
        ``close`` and ``volume`` (price/volume values cast to float), sorted
        ascending by ``timestamp``; ``None`` when *data* is empty or ``None``.

    Raises:
        KeyError: If an entry lacks one of the expected keys.
        ValueError: If a value cannot be converted to float.
    """
    if not data:
        return None

    rows = [
        {
            "timestamp": timestamp,
            "open": float(values["1. open"]),
            "high": float(values["2. high"]),
            "low": float(values["3. low"]),
            "close": float(values["4. close"]),
            "volume": float(values["5. volume"]),
        }
        for timestamp, values in data.items()
    ]
    df = pd.DataFrame(rows)
    # Ensure chronological order before charting.
    return df.sort_values("timestamp")


def generate_candlestick_chart(df, n=50, output_path="candlestick.png"):
    """Render the last *n* rows of *df* as a candlestick chart image.

    Args:
        df: DataFrame with ``timestamp``, ``open``, ``high``, ``low`` and
            ``close`` columns.
        n: Number of trailing rows to plot.
        output_path: File the chart image is written to.

    Returns:
        *output_path* after the image has been written, or ``None`` when *df*
        is ``None`` or has no rows.
    """
    if df is None or len(df) == 0:
        return None

    recent = df.tail(n)  # Use only the last n rows
    fig = go.Figure(
        data=[
            go.Candlestick(
                x=recent["timestamp"],
                open=recent["open"],
                high=recent["high"],
                low=recent["low"],
                close=recent["close"],
            )
        ]
    )
    fig.update_layout(
        title="Candlestick Chart",
        xaxis_title="Time",
        yaxis_title="Price",
        xaxis_rangeslider_visible=False,
    )
    fig.write_image(output_path)
    return output_path


def yolo_model(img_path, model_path):
    """Run a YOLO model on an image and count GAP UP / GAP DOWN detections.

    Args:
        img_path: Path of the chart image to analyse.
        model_path: Path of the YOLO weights file to load.

    Returns:
        A tuple ``(annotated_path, gap_up_count, gap_down_count)``. The
        annotated image is written to ``"annotated_output.png"``. On any
        failure (missing image, model error, image write failure) the tuple
        is ``(None, 0, 0)``.
    """
    if not os.path.exists(img_path):
        return None, 0, 0

    # Load model each time to avoid persistence issues in Spaces
    try:
        model = YOLO(model_path)
        results = model(img_path)

        gap_up_count = 0
        gap_down_count = 0
        for result in results:
            boxes = result.boxes
            if not hasattr(boxes, "cls") or len(boxes.cls) == 0:
                continue
            # Move the class ids to a NumPy array when they expose ``cpu()``.
            if hasattr(boxes.cls, "cpu"):
                classes = boxes.cls.cpu().numpy()
            else:
                classes = boxes.cls
            for cls in classes:
                class_id = int(cls)
                if class_id == _GAP_DOWN_CLASS_ID:
                    gap_down_count += 1
                elif class_id == _GAP_UP_CLASS_ID:
                    gap_up_count += 1

        annotated_image = results[0].plot()
        output_path = "annotated_output.png"
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(output_path, annotated_image):
            print(f"Error running YOLO model: could not write {output_path}")
            return None, 0, 0
        return output_path, gap_up_count, gap_down_count
    except Exception as e:
        # Boundary handler: any model failure is reported to the caller as
        # "no result" so the UI can show a message instead of crashing.
        print(f"Error running YOLO model: {e}")
        return None, 0, 0


def _failure(message):
    """Build the (image, gap-up text, gap-down text) tuple for a failed run."""
    return None, message, message


def detect_gap_patterns(symbol, model_path="best.pt"):
    """Fetch data for *symbol*, chart it, and detect GAP patterns with YOLO.

    Args:
        symbol: Ticker symbol entered by the user.
        model_path: Path of the YOLO weights file.

    Returns:
        A tuple ``(annotated_image_path, gap_up_text, gap_down_text)``. When a
        step fails the image path is ``None`` and both text fields carry the
        same error message.
    """
    # Check if the model file exists
    if not os.path.exists(model_path):
        return _failure(f"Model not found at {model_path}")

    # Reject blank input up front instead of spending an API call on it.
    symbol = (symbol or "").strip()
    if not symbol:
        return _failure("Please enter a stock symbol")

    # Get stock data
    data = get_stock_candlestick_data(symbol)
    if not data:
        return _failure("Failed to fetch stock data")

    # Process data and generate chart
    df = process_stock_candlestick_data(data)
    if df is None or len(df) == 0:
        return _failure("No valid stock data available")

    chart_path = generate_candlestick_chart(df, n=50)
    if not chart_path or not os.path.exists(chart_path):
        return _failure("Failed to generate chart")

    # Run YOLO detection
    annotated_path, gap_up_count, gap_down_count = yolo_model(chart_path, model_path)
    if not annotated_path:
        return _failure("Failed to run detection model")

    return (
        annotated_path,
        f"GAP UP Count: {gap_up_count}",
        f"GAP DOWN Count: {gap_down_count}",
    )


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# GAP Pattern Detection in Stock Charts")
    gr.Markdown(
        "Enter a stock symbol (e.g., AAPL) to detect GAP UP and GAP DOWN "
        "patterns in candlestick charts."
    )

    with gr.Row():
        symbol_input = gr.Textbox(
            label="Stock Symbol",
            placeholder="Enter a stock symbol (e.g., AAPL)",
        )
        submit_button = gr.Button("Detect Patterns")

    with gr.Row():
        output_image = gr.Image(label="Annotated Candlestick Chart")

    with gr.Row():
        gap_up_output = gr.Textbox(label="GAP UP Results")
        gap_down_output = gr.Textbox(label="GAP DOWN Results")

    # Run detection when the button is clicked
    submit_button.click(
        fn=detect_gap_patterns,
        inputs=symbol_input,
        outputs=[output_image, gap_up_output, gap_down_output],
    )

# Launch the Gradio app
demo.launch()