import os
import uuid
import httpx
import gradio as gr
import torch
from fastapi import FastAPI, HTTPException, Request
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import uvicorn
# ✅ Use float16 globally to reduce memory usage
torch.set_default_dtype(torch.float16)

# ✅ Hugging Face API token (read from the environment / Space secrets)
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = "hpyapali/tinyllama-workout"

event_store = {}  # Stores AI responses for the polling fallback

app = FastAPI()

# ✅ Lazy-load the AI model (to prevent a Space startup timeout)
pipe = None
def get_pipeline():
    """Loads the model on first use and caches it in the module-level `pipe`."""
    global pipe
    if pipe is None:
        try:
            print("🔄 Loading AI Model...")
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                token=HF_TOKEN,
                torch_dtype=torch.float16,  # lower memory usage
                device_map="auto",          # load on the available device (CPU/GPU)
            )
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
            print("✅ AI Model Loaded Successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            pipe = None
    return pipe
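
# Note: because loading is lazy, the first request after a cold start pays the
# full model download/load cost; subsequent requests reuse the cached pipeline.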
# ✅ AI function - processes and ranks workouts
def analyze_workouts(last_workouts: str):
    """Generates AI-based workout rankings based on heart rate recovery."""
    pipe = get_pipeline()
    if pipe is None:
        return "❌ AI model is not loaded."
    if not last_workouts.strip():
        return "❌ No workout data provided."

    instruction = (
        "You are a fitness AI assistant. Rank these workouts by heart rate recovery:"
        f"\n\n{last_workouts}\n\nOnly return rankings. No extra text."
    )
    try:
        result = pipe(
            instruction,
            max_new_tokens=200,
            do_sample=True,          # sampling must be enabled for temperature/top_p to take effect
            temperature=0.3,
            top_p=0.9,
            return_full_text=False,  # return only the completion, not the echoed prompt
        )
        return result[0]["generated_text"].strip()
    except Exception as e:
        return f"❌ Error: {str(e)}"
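
# Illustrative call (the exact workout data format is not defined in this file):
#   analyze_workouts("Mon: run 5 km, HRR 32 bpm\nWed: bike 20 km, HRR 25 bpm")
# returns the model's ranking text, or a "❌ ..." string on failure.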
# ✅ API route for processing workout data
@app.post("/gradio_api/call/predict")
async def process_workout_request(request: Request):
    try:
        req_body = await request.json()
        print("📩 RAW REQUEST FROM HF:", req_body)

        if "data" not in req_body or not isinstance(req_body["data"], list):
            raise HTTPException(status_code=400, detail="Invalid request format.")

        last_workouts = req_body["data"][0]
        event_id = str(uuid.uuid4())
        print(f"✅ Processing AI Request - Event ID: {event_id}")

        response_text = analyze_workouts(last_workouts)

        # ✅ Store the response so clients can poll for it if the webhook fails
        event_store[event_id] = response_text

        # ✅ Send the AI response to the Vapor webhook
        # (ngrok URLs are ephemeral; this must match the currently running tunnel)
        webhook_url = "https://694a-50-35-76-93.ngrok-free.app/fineTuneModel"
        async with httpx.AsyncClient() as client:
            try:
                webhook_response = await client.post(
                    webhook_url, json={"event_id": event_id, "data": [response_text]}
                )
                webhook_response.raise_for_status()
                print(f"✅ Webhook sent successfully: {webhook_response.json()}")
            except Exception as e:
                print(f"⚠️ Webhook failed: {e}")
                print("🔄 Switching to Polling Mode...")

        return {"event_id": event_id}
    except HTTPException:
        raise  # preserve the 400 above instead of converting it into a 500
    except Exception as e:
        print(f"❌ Error processing request: {e}")
        raise HTTPException(status_code=500, detail=str(e))
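
# Expected request/response shape (inferred from the handler above):
#   POST /gradio_api/call/predict  with body {"data": ["<workout summary text>"]}
#   -> immediate response {"event_id": "<uuid>"}; the ranking itself is delivered
#      via the webhook, or retrieved from the polling endpoint below.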
# ✅ Polling endpoint (fallback if the webhook fails)
@app.get("/gradio_api/poll/{event_id}")
async def poll(event_id: str):
    """Fetches the stored AI response for a given event ID."""
    if event_id in event_store:
        return {"data": [event_store.pop(event_id)]}
    return {"detail": "Not Found"}
# ✅ Health check
@app.get("/")
async def root():
    return {"message": "Workout Analysis & Ranking AI is running!"}
# ✅ Gradio UI for testing
iface = gr.Interface(
    fn=analyze_workouts,
    inputs="text",
    outputs="text",
    title="Workout Analysis & Ranking AI",
    description="Enter workout data to analyze effectiveness, rank workouts, and receive improvement recommendations.",
)
# ✅ Start both FastAPI & Gradio
def start_gradio():
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)


def start_fastapi():
    uvicorn.run(app, host="0.0.0.0", port=7861)


# ✅ Run both servers in parallel
if __name__ == "__main__":
    import threading

    threading.Thread(target=start_gradio).start()
    threading.Thread(target=start_fastapi).start()
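
# To run locally (assumes HF_TOKEN is set in the environment):
#   python app.py
# Gradio then serves on port 7860 and the FastAPI endpoints on port 7861.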