import json
import os
import time
from contextlib import closing
from datetime import date, datetime
from decimal import Decimal
from urllib.parse import urlencode

import gradio as gr
import snowflake.connector
from openai import OpenAI

from utils.functions import (
    intraday_stock_prices,
    daily_stock_prices,
    get_income_statement,
    ticker_search,
    company_profile,
    current_market_cap,
    historical_market_cap,
    analyst_recommendations,
    stock_peers,
    earnings_historical_and_upcoming,
)

# Initialize OpenAI client
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# Seconds to wait between polls of an in-flight assistant run. Without a pause
# the polling loop fires back-to-back `runs.retrieve` requests.
RUN_POLL_INTERVAL_SECONDS = 1.0

# Run statuses from which a run can never reach "completed"; continuing to
# poll after one of these would loop forever.
FAILED_RUN_STATUSES = frozenset({"failed", "cancelled", "expired", "incomplete"})


def _json_default(obj):
    """Serialize values the `json` module cannot handle natively.

    date/datetime values become ISO-8601 strings and Decimal values become
    floats. Anything else raises TypeError, as `json.dumps` expects from a
    ``default`` hook.
    """
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    if isinstance(obj, Decimal):
        return float(obj)
    raise TypeError(f"Type {type(obj)} not serializable")


def fetch_trades_of_the_day():
    """Fetch the ranked trade strings from Snowflake as a JSON string.

    Runs a fixed query against ``RESEARCHDATA.RSI_TRADE_OF_THE_DAY``, ordered
    by ``rk`` descending. Connection settings come from the ``SNOWFLAKE_USER``,
    ``SNOWFLAKE_PW``, ``SNOWFLAKE_ACCOUNT``, ``SNOWFLAKE_WH``,
    ``SNOWFLAKE_DB`` and ``SNOWFLAKE_SCHEMA`` environment variables.

    Returns:
        str | None: A JSON array (indent=4) with one object per row, keyed by
        column name, or ``None`` if connecting, querying or serializing fails
        (the error is printed).
    """
    query = (
        "select BEST_TRADE_STRING from RESEARCHDATA.RSI_TRADE_OF_THE_DAY rs "
        "order by rk desc;"
    )
    try:
        # closing() releases the cursor, then the connection, even when the
        # query or fetch raises; previously both leaked on that path.
        with closing(
            snowflake.connector.connect(
                user=os.environ["SNOWFLAKE_USER"],
                password=os.environ["SNOWFLAKE_PW"],
                account=os.environ["SNOWFLAKE_ACCOUNT"],
                warehouse=os.environ["SNOWFLAKE_WH"],
                database=os.environ["SNOWFLAKE_DB"],
                schema=os.environ["SNOWFLAKE_SCHEMA"],
            )
        ) as conn, closing(conn.cursor()) as cur:
            rows = cur.execute(query).fetchall()
            columns = [desc[0] for desc in cur.description]

        # One dict per row so the JSON output is keyed by column name.
        result = [dict(zip(columns, row)) for row in rows]
        return json.dumps(result, indent=4, default=_json_default)
    except Exception as e:
        print(f"Failed to fetch trades of the day from Snowflake: {e}")
        return None


def _run_tool_call(tool_call):
    """Execute one assistant tool call and return its output as a string."""
    name = tool_call.function.name
    if name == "fetch_trades_of_the_day":
        output = fetch_trades_of_the_day()
        if output is None:
            # Tool outputs must be strings; report the failure to the
            # assistant instead of submitting None.
            return json.dumps({"error": "Failed to fetch trades of the day."})
        return output
    # The run expects an output for every pending tool call, so unknown
    # tools still get an (error) output rather than being dropped.
    return json.dumps({"error": f"Unknown tool: {name}"})


def interact_with_assistant(user_input):
    """Send *user_input* to the configured assistant and return its reply.

    Creates a fresh thread, starts a run against the assistant named by the
    ``ASSISTANT_ID`` environment variable, answers any tool calls it makes,
    and polls until the run completes.

    Returns:
        str: The text of the newest assistant message in the thread.

    Raises:
        RuntimeError: If the run ends in a failed terminal status, or the
            assistant produced no text reply.
    """
    thread = client.beta.threads.create()
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=user_input,
    )
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=os.environ["ASSISTANT_ID"],
    )

    while run.status != "completed":
        if run.status in FAILED_RUN_STATUSES:
            detail = f": {run.last_error.message}" if run.last_error else ""
            raise RuntimeError(
                f"Assistant run ended with status '{run.status}'{detail}"
            )
        if run.status == "requires_action":
            tool_calls = run.required_action.submit_tool_outputs.tool_calls
            tool_outputs = [
                {"tool_call_id": call.id, "output": _run_tool_call(call)}
                for call in tool_calls
            ]
            client.beta.threads.runs.submit_tool_outputs(
                thread_id=thread.id,
                run_id=run.id,
                tool_outputs=tool_outputs,
            )
        time.sleep(RUN_POLL_INTERVAL_SECONDS)
        run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)

    # Newest first, so the first assistant message is this run's reply.
    messages = client.beta.threads.messages.list(thread_id=thread.id, order="desc")
    for message in messages.data:
        if message.role != "assistant":
            continue
        texts = [part.text.value for part in message.content if part.type == "text"]
        if texts:
            return "\n\n".join(texts)
    raise RuntimeError("The assistant did not return a text response.")


def fetch_best_trades():
    """Ask the assistant for today's best trades; return its reply or an error message."""
    try:
        return interact_with_assistant("What are the best trades for today?")
    except Exception as e:
        return f"An error occurred: {str(e)}"


css = """
body {
    font-family: Arial, sans-serif;
    background-color: #f0f2f5;
}
.container {
    margin: 0 auto;
    padding: 20px;
    background-color: white;
    border-radius: 10px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
}
.output-box {
    margin-bottom: 20px;
}
"""

with gr.Blocks(css=css) as iface:
    with gr.Column(elem_classes="container"):
        gr.Markdown("# 📈 Stock Market Assistant")
        gr.Markdown("Get insights on the best trades for today based on RSI data.")
        output = gr.Textbox(
            label="Trade Recommendations",
            lines=20,  # tall enough to show a multi-trade reply without scrolling
            interactive=False,
            elem_classes="output-box",
        )
        fetch_button = gr.Button(
            "🚀 Fetch me the best trades for today", variant="primary"
        )
        fetch_button.click(fn=fetch_best_trades, outputs=output)

iface.launch()