# zenobot/app.py
import os
import json
from typing import Dict, Any, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
app = FastAPI(title="ZenoBot Travel API")
# Model to load; can be overridden via the MODEL_ID env var (e.g. with a local path in HF Spaces)
MODEL_ID = os.environ.get("MODEL_ID", "meta-llama/Llama-3.2-3B")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Model configuration
MAX_MODEL_CONTEXT = 8192  # Context budget (prompt + generation) this app allows per request
DEFAULT_MAX_TOKENS = 4096  # Generous default generation limit for more detailed itineraries
# Load model and tokenizer
print(f"Loading model {MODEL_ID} on {DEVICE}...")
try:
    # First try to load the tokenizer from the current directory
    tokenizer = AutoTokenizer.from_pretrained("./")
    print("Loaded tokenizer from local directory")
except Exception as e:
    print(f"Couldn't load from local directory: {e}")
    print("Attempting to load from model hub...")
    # Get the read token from the environment, if provided
    hf_token = os.environ.get("HF_TOKENIZER_READ_TOKEN", None)
    if hf_token:
        print("Using token from environment variable")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
    else:
        print("No token found in environment, attempting without authentication")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
)
print("Model loaded successfully!")

class TravelRequest(BaseModel):
    query: str
    temperature: Optional[float] = 0.1
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS


class TravelResponse(BaseModel):
    response: Dict[str, Any]

# Load system prompt from modelfile or system_prompt.txt
def load_system_prompt():
    try:
        # First try loading the SYSTEM block from the modelfile
        with open("modelfile", "r") as f:
            content = f.read()
        marker = 'SYSTEM """'
        start = content.find(marker)
        end = content.rfind('"""')
        if start == -1 or end < start + len(marker):
            raise ValueError("No SYSTEM block found in modelfile")
        return content[start + len(marker):end].strip()
    except Exception as e:
        print(f"Error loading from modelfile: {e}")
        try:
            # Fall back to system_prompt.txt
            with open("system_prompt.txt", "r") as f:
                return f.read().strip()
        except Exception as e2:
            print(f"Error loading from system_prompt.txt: {e2}")
            # Hard-coded fallback system prompt
            return "You are Zeno-Bot, a travel assistant specializing in creating detailed travel itineraries strictly within one state in a country."
SYSTEM_PROMPT = load_system_prompt()
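
# Illustrative layout of the modelfile that load_system_prompt() expects
# (an Ollama-style SYSTEM block; the exact surrounding content is up to the repo):
#   SYSTEM """
#   ...system prompt text...
#   """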
@app.post("/generate", response_model=TravelResponse)
async def generate_travel_plan(request: TravelRequest):
    try:
        # Safety check for token limits
        if request.max_tokens > MAX_MODEL_CONTEXT - 1000:  # Reserve ~1000 tokens for the input prompt
            print(f"Warning: Requested {request.max_tokens} tokens exceeds safe limit. Capping at {MAX_MODEL_CONTEXT - 1000}.")
            request.max_tokens = MAX_MODEL_CONTEXT - 1000

        # Prepare the prompt with simple chat-style role markers
        prompt = f"""<|system|>
{SYSTEM_PROMPT}
<|user|>
{request.query}
<|assistant|>"""
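
        # Note: the <|system|>/<|user|>/<|assistant|> markers above are a generic chat-style
        # format, not the official Llama 3 chat template. For instruct-tuned checkpoints whose
        # tokenizer ships a chat template, a more robust sketch would be:
        #   messages = [{"role": "system", "content": SYSTEM_PROMPT},
        #               {"role": "user", "content": request.query}]
        #   prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)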
        # Generate the response
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        output = model.generate(
            inputs.input_ids,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            do_sample=True,
        )
        # Decode only the newly generated tokens, skipping the prompt
        response_text = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        # Attempt to extract and parse JSON from the model output
        try:
            # Try to find a JSON object in the response
            json_start = response_text.find('{')
            json_end = response_text.rfind('}') + 1
            if json_start != -1 and json_end > json_start:
                json_str = response_text[json_start:json_end]
                json_data = json.loads(json_str)
                return {"response": json_data}
            else:
                # If no JSON structure was found, return the raw text
                return {"response": {"error": "No valid JSON found", "raw_text": response_text}}
        except json.JSONDecodeError:
            return {"response": {"error": "Invalid JSON format", "raw_text": response_text}}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating response: {str(e)}")
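
# Example request against the /generate endpoint (assumes the app is reachable on the
# default port 7860; adjust host/port as needed):
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"query": "Plan a 3-day trip to California starting on 15/04/2024", "max_tokens": 2048}'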
@app.get("/health")
async def health_check():
return {"status": "healthy"}

async def generate_itinerary(query, temperature=0.1, max_tokens=DEFAULT_MAX_TOKENS):
    """Wrapper used by the Gradio interface; reuses the /generate endpoint logic."""
    try:
        request = TravelRequest(query=query, temperature=temperature, max_tokens=max_tokens)
        # generate_travel_plan is an async endpoint, so it must be awaited
        result = await generate_travel_plan(request)
        return json.dumps(result["response"], indent=2)
    except Exception as e:
        return json.dumps({"error": str(e)})
# Create Gradio interface
demo = gr.Interface(
    fn=generate_itinerary,
    inputs=[
        gr.Textbox(lines=3, placeholder="Plan a 3-day trip to California starting on 15/04/2024", label="Travel Query"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=512, maximum=6144, value=DEFAULT_MAX_TOKENS, step=512, label="Max Tokens"),
    ],
    outputs=gr.JSON(label="Generated Itinerary"),
    title="ZenoBot Travel Assistant",
    description="Generate detailed travel itineraries within a single state. Example: 'Plan a 3-day trip to California starting on 15/04/2024'",
)
# Mount the Gradio UI at "/" on the FastAPI app so the UI and API routes are served together (Hugging Face Spaces setup)
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    # Use the PORT environment variable or default to 7860 (HF Spaces default)
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
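
# To run locally (sketch): `python app.py`, or equivalently
# `uvicorn app:app --host 0.0.0.0 --port 7860`.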