File size: 5,302 Bytes
faa2d08
 
 
 
 
2ebc037
faa2d08
 
2ebc037
faa2d08
2ebc037
 
faa2d08
 
 
 
 
 
 
 
 
 
 
 
2ebc037
 
 
 
faa2d08
 
 
 
 
 
 
 
 
 
2ebc037
 
 
faa2d08
2ebc037
 
 
 
 
faa2d08
 
 
 
 
 
 
 
 
 
 
 
 
 
2ebc037
 
faa2d08
2ebc037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
faa2d08
2ebc037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
faa2d08
 
 
2ebc037
 
 
faa2d08
 
2ebc037
 
faa2d08
 
2ebc037
 
 
 
 
 
faa2d08
 
4fe609d
faa2d08
 
 
 
2ebc037
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import requests
import pandas as pd
import plotly.graph_objects as go
from ultralytics import YOLO
import cv2
import os
import gradio as gr

# Alpha Vantage API key. Read from the ALPHA_VANTAGE_API_KEY environment
# variable so the secret can be supplied at deploy time instead of living in
# source; the previously hard-coded key is kept only as a fallback so existing
# deployments keep working. NOTE(review): this key has been committed to source
# control and should be rotated.
API_KEY = os.environ.get("ALPHA_VANTAGE_API_KEY", "ITWJ6NDTF45CBTDO")

def get_stock_candlestick_data(symbol, interval="1min", output_size="compact", timeout=10):
    """Fetch intraday candlestick data for *symbol* from Alpha Vantage.

    Args:
        symbol: Ticker symbol (e.g. "AAPL"). Surrounding whitespace is ignored.
        interval: Bar interval passed to TIME_SERIES_INTRADAY (e.g. "1min").
        output_size: Passed through as the ``outputsize`` query parameter.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The ``"Time Series (<interval>)"`` mapping from the JSON response, or
        None when the symbol is empty, the request fails or times out, the
        status is not 200, the body is not a JSON object, or the series key is
        absent (presumably when the API answers with an error/notice message
        instead of data).
    """
    symbol = (symbol or "").strip()
    if not symbol:
        return None

    params = {
        "function": "TIME_SERIES_INTRADAY",
        "symbol": symbol,
        "interval": interval,
        "apikey": API_KEY,
        "outputsize": output_size,
    }
    try:
        # Passing ``params`` (rather than formatting the URL by hand) makes
        # requests URL-encode the user-supplied symbol; the timeout keeps a
        # stalled API call from blocking the request handler indefinitely.
        response = requests.get(
            "https://www.alphavantage.co/query", params=params, timeout=timeout
        )
    except requests.RequestException as exc:
        print(f"Error fetching stock data: {exc}")
        return None

    if response.status_code != 200:
        return None

    try:
        data = response.json()
    except ValueError:  # body was not valid JSON
        return None

    if not isinstance(data, dict):
        return None
    return data.get(f"Time Series ({interval})")

def process_stock_candlestick_data(data):
    """Convert an Alpha Vantage time-series mapping into an OHLCV DataFrame.

    Args:
        data: Mapping of timestamp string -> dict holding the "1. open",
            "2. high", "3. low", "4. close" and "5. volume" values.

    Returns:
        A DataFrame with columns timestamp, open, high, low, close, volume
        (the numeric columns converted with ``float``), sorted ascending by
        the timestamp string, which is chronological for zero-padded
        ISO-style timestamps. The original row labels are kept after sorting.
        Returns None when *data* is empty or None.
    """
    if not data:
        return None

    # (output column, Alpha Vantage field name) pairs, in output column order.
    field_map = (
        ("open", "1. open"),
        ("high", "2. high"),
        ("low", "3. low"),
        ("close", "4. close"),
        ("volume", "5. volume"),
    )
    records = [
        {"timestamp": stamp, **{column: float(bar[field]) for column, field in field_map}}
        for stamp, bar in data.items()
    ]
    frame = pd.DataFrame(records)
    return frame.sort_values("timestamp")

def generate_candlestick_chart(df, n=50, output_path="candlestick.png"):
    """Render the last *n* rows of *df* as a candlestick chart image.

    Args:
        df: DataFrame with timestamp, open, high, low and close columns.
        n: Number of trailing rows to plot (passed to ``DataFrame.tail``).
        output_path: File path the static image is written to.

    Returns:
        *output_path* after the image has been written, or None when *df* is
        None or has no rows.
    """
    if df is None or not len(df):
        return None

    window = df.tail(n)
    candles = go.Candlestick(
        x=window["timestamp"],
        open=window["open"],
        high=window["high"],
        low=window["low"],
        close=window["close"],
    )
    figure = go.Figure(data=[candles])
    # The range slider is hidden so the saved image contains only the chart.
    figure.update_layout(
        title="Candlestick Chart",
        xaxis_title="Time",
        yaxis_title="Price",
        xaxis_rangeslider_visible=False,
    )
    figure.write_image(output_path)
    return output_path

def yolo_model(img_path, model_path, output_path="annotated_output.png"):
    """Run the YOLO detector on an image and count GAP UP / GAP DOWN boxes.

    Args:
        img_path: Path of the chart image to analyse.
        model_path: Path of the weights file passed to ``YOLO(...)``.
        output_path: Where the annotated image of the first result is written.

    Returns:
        ``(annotated_path, gap_up_count, gap_down_count)``. Returns
        ``(None, 0, 0)`` when the image does not exist, when loading or
        running the model raises, or when the annotated image cannot be
        written.
    """
    if not os.path.exists(img_path):
        return None, 0, 0

    # Load model each time to avoid persistence issues in Spaces
    try:
        model = YOLO(model_path)
        results = model(img_path)

        gap_up_count = 0
        gap_down_count = 0
        for result in results:
            boxes = result.boxes
            if not hasattr(boxes, "cls") or len(boxes.cls) == 0:
                continue
            # boxes.cls is presumably a torch tensor (possibly on GPU); move it
            # to a CPU numpy array before iterating when it supports .cpu().
            class_ids = boxes.cls.cpu().numpy() if hasattr(boxes.cls, "cpu") else boxes.cls
            for class_id in class_ids:
                # Class mapping used by this model: 0 = GAP DOWN, 1 = GAP UP;
                # any other class id is ignored.
                label = int(class_id)
                if label == 0:
                    gap_down_count += 1
                elif label == 1:
                    gap_up_count += 1

        annotated_image = results[0].plot()
        # cv2.imwrite reports failure by returning False rather than raising,
        # so check it to avoid handing the UI a path to a missing file.
        if not cv2.imwrite(output_path, annotated_image):
            print(f"Error running YOLO model: could not write {output_path}")
            return None, 0, 0
        return output_path, gap_up_count, gap_down_count
    except Exception as e:
        # Top-level boundary for the detection step: report and signal failure.
        print(f"Error running YOLO model: {e}")
        return None, 0, 0

def detect_gap_patterns(symbol, model_path="best.pt"):
    """Fetch data, render a chart and detect GAP patterns for one symbol.

    This is the Gradio handler: it runs the whole pipeline (fetch ->
    DataFrame -> chart image -> YOLO) and stops at the first failing step.

    Args:
        symbol: Ticker symbol typed by the user; surrounding whitespace is
            ignored.
        model_path: Path of the YOLO weights file.

    Returns:
        ``(annotated_image_path, gap_up_text, gap_down_text)``. On failure the
        image path is None and both text fields carry the same error message.
    """
    def _failure(message):
        # Both text outputs show the same message when the pipeline stops early.
        return None, message, message

    # Check if the model file exists
    if not os.path.exists(model_path):
        return _failure(f"Model not found at {model_path}")

    # Reject blank input up front instead of spending an API call on it.
    symbol = (symbol or "").strip()
    if not symbol:
        return _failure("Please enter a stock symbol")

    # Get stock data
    data = get_stock_candlestick_data(symbol)
    if not data:
        return _failure("Failed to fetch stock data")

    # Process data and generate chart
    df = process_stock_candlestick_data(data)
    if df is None or len(df) == 0:
        return _failure("No valid stock data available")

    chart_path = generate_candlestick_chart(df, n=50)
    if not chart_path or not os.path.exists(chart_path):
        return _failure("Failed to generate chart")

    # Run YOLO detection
    annotated_path, gap_up_count, gap_down_count = yolo_model(chart_path, model_path)
    if not annotated_path:
        return _failure("Failed to run detection model")

    return annotated_path, f"GAP UP Count: {gap_up_count}", f"GAP DOWN Count: {gap_down_count}"

# Gradio Interface: a symbol textbox + button, the annotated chart, and one
# text result per pattern type, all driven by detect_gap_patterns.
with gr.Blocks() as demo:
    gr.Markdown("# GAP Pattern Detection in Stock Charts")
    gr.Markdown("Enter a stock symbol (e.g., AAPL) to detect GAP UP and GAP DOWN patterns in candlestick charts.")

    with gr.Row():
        symbol_input = gr.Textbox(label="Stock Symbol", placeholder="Enter a stock symbol (e.g., AAPL)")
        submit_button = gr.Button("Detect Patterns")

    with gr.Row():
        output_image = gr.Image(label="Annotated Candlestick Chart")

    with gr.Row():
        gap_up_output = gr.Textbox(label="GAP UP Results")
        gap_down_output = gr.Textbox(label="GAP DOWN Results")

    # Order must match detect_gap_patterns' return tuple
    # (image path, GAP UP text, GAP DOWN text).
    detection_outputs = [output_image, gap_up_output, gap_down_output]

    # Run detection when the button is clicked or when Enter is pressed in the
    # symbol box; both events share the same handler and outputs.
    submit_button.click(
        fn=detect_gap_patterns,
        inputs=symbol_input,
        outputs=detection_outputs,
    )
    symbol_input.submit(
        fn=detect_gap_patterns,
        inputs=symbol_input,
        outputs=detection_outputs,
    )

# Launch the Gradio app
demo.launch()