import os
import uuid
import httpx
import gradio as gr
import torch
from fastapi import FastAPI, HTTPException, Request
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import uvicorn

# ✅ Hugging Face API token and model
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = "hpyapali/tinyllama-workout"

event_store = {}  # Store AI responses for the polling fallback

# ✅ Webhook URL (your Vapor webhook server)
WEBHOOK_URL = "https://694a-50-35-76-93.ngrok-free.app/fineTuneModel"

app = FastAPI()

# ✅ Lazy-load the AI model (prevents startup timeouts on Hugging Face).
# The model itself is loaded in float16 via `torch_dtype` below; a global
# `torch.set_default_dtype(torch.float16)` is avoided because it can break
# CPU ops and tokenization.
pipe = None


def get_pipeline():
    """Load the model once, on first use."""
    global pipe
    if pipe is None:
        try:
            print("🔄 Loading AI Model...")
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                token=HF_TOKEN,
                torch_dtype=torch.float16,  # half precision lowers memory usage
                device_map="auto",          # place on whatever device is available (CPU/GPU)
            )
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
            print("✅ AI Model Loaded Successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            pipe = None
    return pipe


# ✅ AI function - analyzes and ranks workouts
def analyze_workouts(last_workouts: str):
    pipe = get_pipeline()
    if pipe is None:
        return "❌ AI model is not loaded."
    if not last_workouts.strip():
        return "❌ No workout data provided."

    instruction = (
        "You are a fitness AI assistant. Rank these workouts by heart rate recovery:"
        f"\n\n{last_workouts}\n\nOnly return rankings. No extra text."
    )
    print(f"📨 Sending prompt to AI: {instruction}")  # ✅ Debug log

    try:
        result = pipe(
            instruction,
            max_new_tokens=200,
            do_sample=True,          # required for temperature/top_p to take effect
            temperature=0.3,
            top_p=0.9,
            return_full_text=False,  # return only the completion, not the echoed prompt
        )
        if not result or "generated_text" not in result[0]:
            print("❌ AI response is empty or malformed!")
            return "❌ AI did not return a valid response."
        response_text = result[0]["generated_text"].strip()
        print(f"🔍 AI Response: {response_text}")  # ✅ Debug log
        return response_text
    except Exception as e:
        print(f"❌ AI Error: {e}")  # ✅ Debug AI errors
        return f"❌ Error: {e}"


# ✅ API route for processing workout data
@app.post("/gradio_api/call/predict")
async def process_workout_request(request: Request):
    try:
        req_body = await request.json()
        print("📩 RAW REQUEST FROM HF:", req_body)

        if "data" not in req_body or not isinstance(req_body["data"], list):
            raise HTTPException(status_code=400, detail="Invalid request format.")

        last_workouts = req_body["data"][0]
        event_id = str(uuid.uuid4())
        print(f"✅ Processing AI Request - Event ID: {event_id}")

        response_text = analyze_workouts(last_workouts)

        # ✅ Store the response so the client can poll for it if the webhook fails
        event_store[event_id] = response_text

        # ✅ Push the AI response to the Vapor webhook
        async with httpx.AsyncClient() as client:
            try:
                webhook_response = await client.post(
                    WEBHOOK_URL, json={"event_id": event_id, "data": [response_text]}
                )
                webhook_response.raise_for_status()
                print(f"✅ Webhook sent successfully: {webhook_response.json()}")
            except Exception as e:
                print(f"⚠️ Webhook failed: {e}")
                print("🔄 Switching to Polling Mode...")

        return {"event_id": event_id}

    except HTTPException:
        raise  # keep 4xx errors intact instead of masking them as 500s
    except Exception as e:
        print(f"❌ Error processing request: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# ✅ Polling endpoint (fallback if the webhook fails)
@app.get("/gradio_api/poll/{event_id}")
async def poll(event_id: str):
    """Fetch the stored AI response for a given event ID."""
    if event_id in event_store:
        return {"data": [event_store.pop(event_id)]}
    return {"detail": "Not Found"}


# ✅ Health check
@app.get("/")
async def root():
    return {"message": "Workout Analysis & Ranking AI is running!"}


# ✅ Gradio UI for manual testing
iface = gr.Interface(
    fn=analyze_workouts,
    inputs="text",
    outputs="text",
    title="Workout Analysis & Ranking AI",
    description=(
        "Enter workout data to analyze effectiveness, rank workouts, "
        "and receive improvement recommendations."
    ),
)


# ✅ Run FastAPI and Gradio side by side
def start_gradio():
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)


def start_fastapi():
    uvicorn.run(app, host="0.0.0.0", port=7861)


if __name__ == "__main__":
    import threading

    # ✅ Run both servers in parallel: Gradio on 7860, FastAPI on 7861
    threading.Thread(target=start_gradio).start()
    threading.Thread(target=start_fastapi).start()
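

# --------------------------------------------------------------------------
# Example client (a minimal sketch, not part of the server): illustrates the
# intended call flow against the routes defined above — POST workout data to
# /gradio_api/call/predict, then use /gradio_api/poll/{event_id} as the
# fallback when the webhook never reaches you. The base_url and the sample
# payload are assumptions for illustration only; point base_url at wherever
# the FastAPI server (port 7861 above) is actually reachable.
def demo_client():
    import httpx  # already imported above; repeated so this sketch is self-contained

    base_url = "http://localhost:7861"  # assumed host/port; adjust for your deployment
    payload = {"data": ["Mon: 5k run, avg HR 152\nWed: intervals, avg HR 168"]}

    # Kick off analysis; the server returns an event_id immediately.
    resp = httpx.post(f"{base_url}/gradio_api/call/predict", json=payload, timeout=120)
    resp.raise_for_status()
    event_id = resp.json()["event_id"]

    # Polling fallback: the server stores every response in event_store,
    # so the ranking can be fetched by event ID even if the webhook failed.
    poll_resp = httpx.get(f"{base_url}/gradio_api/poll/{event_id}")
    print(poll_resp.json())  # e.g. {"data": ["<ranked workouts>"]}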