|
import asyncio
import json
import os
from typing import Any, Dict, Optional

import gradio as gr
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
app = FastAPI(title="ZenoBot Travel API")


# Model checkpoint to load; override with the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "meta-llama/Llama-3.2-3B")

# Run on GPU when one is available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Assumed total context window for the model; the /generate handler caps
# requested generation length at MAX_MODEL_CONTEXT - 1000.
MAX_MODEL_CONTEXT = 8192

# Default cap on newly generated tokens per request.
DEFAULT_MAX_TOKENS = 4096


print(f"Loading model {MODEL_ID} on {DEVICE}...")
|
try:
    # Prefer a tokenizer bundled next to this script (e.g. a fine-tuned copy).
    tokenizer = AutoTokenizer.from_pretrained("./")
    print("Loaded tokenizer from local directory")
except Exception as e:
    print(f"Couldn't load from local directory: {e}")
    print("Attempting to load from model hub...")

    # Gated checkpoints (e.g. meta-llama) need a Hugging Face read token.
    hf_token = os.environ.get("HF_TOKENIZER_READ_TOKEN", None)
    if hf_token:
        print("Using token from environment variable")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
    else:
        print("No token found in environment, attempting without authentication")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# half precision on GPU, full precision on CPU (fp16 is unsupported on many CPUs)
# NOTE(review): device_map="auto" lets accelerate decide weight placement while
# /generate moves inputs to DEVICE — consistent for single-device setups, but
# confirm behavior if the model ever gets sharded across multiple GPUs.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True
)
print("Model loaded successfully!")
|
|
|
class TravelRequest(BaseModel):
    """Request body for POST /generate."""

    # Free-text travel query, e.g. "Plan a 3-day trip to California ...".
    query: str
    # Sampling temperature forwarded to model.generate().
    temperature: Optional[float] = 0.1
    # Upper bound on newly generated tokens; the handler caps this at
    # MAX_MODEL_CONTEXT - 1000.
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
|
|
|
class TravelResponse(BaseModel):
    """Response body for POST /generate."""

    # Parsed itinerary JSON from the model, or an error payload of the form
    # {"error": ..., "raw_text": <raw model output>}.
    response: Dict[str, Any]
|
|
|
|
|
def load_system_prompt():
    """Load the system prompt, trying sources in priority order.

    1. An Ollama-style "modelfile": the text between `SYSTEM \"\"\"` and the
       closing `\"\"\"`.
    2. A plain "system_prompt.txt" file.
    3. A built-in fallback prompt.

    Returns:
        The system prompt string, stripped of surrounding whitespace.
    """
    marker = 'SYSTEM """'
    try:
        with open("modelfile", "r") as f:
            content = f.read()
        start = content.find(marker)
        end = content.rfind('"""')
        # BUG FIX: the old code used find(...) + 9, but the marker is 10
        # characters long, so the extracted prompt always began with a stray
        # '"'. It also ignored a missing marker (find() == -1), silently
        # returning a garbage slice. Require the closing quotes to come
        # strictly after the marker so we never slice backwards.
        if start != -1 and end > start + len(marker):
            return content[start + len(marker):end].strip()
        raise ValueError("no SYSTEM block found in modelfile")
    except Exception as e:
        print(f"Error loading from modelfile: {e}")
        try:
            with open("system_prompt.txt", "r") as f:
                return f.read().strip()
        except Exception as e2:
            print(f"Error loading from system_prompt.txt: {e2}")

    # Last-resort default prompt.
    return """You are Zeno-Bot, a travel assistant specializing in creating detailed travel itineraries strictly within one state in a country."""


SYSTEM_PROMPT = load_system_prompt()
|
|
|
@app.post("/generate", response_model=TravelResponse)
async def generate_travel_plan(request: TravelRequest):
    """Generate a travel itinerary for the given query.

    Builds a chat-style prompt around SYSTEM_PROMPT, runs the causal LM,
    then extracts the outermost {...} span of the completion as JSON.

    Returns:
        {"response": <parsed JSON>} on success, or
        {"response": {"error": ..., "raw_text": ...}} when no valid JSON
        could be extracted from the model output.

    Raises:
        HTTPException(500): on any unexpected generation failure.
    """
    try:
        # Optional fields accept an explicit null; fall back to the default.
        if request.max_tokens is None:
            request.max_tokens = DEFAULT_MAX_TOKENS
        # Keep prompt + completion inside the model context window
        # (1000 tokens are reserved for the prompt itself).
        if request.max_tokens > MAX_MODEL_CONTEXT - 1000:
            print(f"Warning: Requested {request.max_tokens} tokens exceeds safe limit. Capping at {MAX_MODEL_CONTEXT - 1000}.")
            request.max_tokens = MAX_MODEL_CONTEXT - 1000

        prompt = f"""<|system|>
{SYSTEM_PROMPT}
<|user|>
{request.query}
<|assistant|>"""

        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        output = model.generate(
            inputs.input_ids,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            do_sample=True,
        )

        # Decode only the newly generated tokens (skip the echoed prompt).
        response_text = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

        try:
            json_start = response_text.find('{')
            json_end = response_text.rfind('}')

            # BUG FIX: the old code compared rfind('}') + 1 against -1, which
            # is always true (rfind returns -1, so the sum is 0). Compare the
            # raw rfind result, and require the braces to be ordered.
            if json_start != -1 and json_end != -1 and json_end > json_start:
                json_str = response_text[json_start:json_end + 1]
                json_data = json.loads(json_str)
                return {"response": json_data}
            else:
                return {"response": {"error": "No valid JSON found", "raw_text": response_text}}

        except json.JSONDecodeError:
            return {"response": {"error": "Invalid JSON format", "raw_text": response_text}}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating response: {str(e)}")
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service as healthy."""
    payload = {"status": "healthy"}
    return payload
|
|
|
def generate_itinerary(query, temperature=0.1, max_tokens=DEFAULT_MAX_TOKENS):
    """Bridge for the Gradio interface: run the /generate handler synchronously.

    Returns:
        Pretty-printed JSON text of the itinerary, or an "Error: ..." string
        on any failure.
    """
    try:
        request = TravelRequest(query=query, temperature=temperature, max_tokens=max_tokens)
        # BUG FIX: generate_travel_plan is an async function returning a plain
        # dict. The old code called it without awaiting (yielding a coroutine)
        # and then read a ".response" attribute, so this path always raised and
        # the UI only ever saw "Error: ...". Drive the coroutine to completion
        # and index the dict instead.
        result = asyncio.run(generate_travel_plan(request))
        return json.dumps(result["response"], indent=2)
    except Exception as e:
        return f"Error: {str(e)}"
|
|
|
|
|
# Gradio UI that drives the same generation path as the REST endpoint.
demo = gr.Interface(
    fn=generate_itinerary,
    inputs=[
        gr.Textbox(lines=3, placeholder="Plan a 3-day trip to California starting on 15/04/2024", label="Travel Query"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=512, maximum=6144, value=DEFAULT_MAX_TOKENS, step=512, label="Max Tokens")
    ],
    outputs=gr.JSON(label="Generated Itinerary"),
    title="ZenoBot Travel Assistant",
    description="Generate detailed travel itineraries within a single state. Example: 'Plan a 3-day trip to California starting on 15/04/2024'"
)


# Serve the Gradio UI at "/" from the same FastAPI app that hosts /generate.
app = gr.mount_gradio_app(app, demo, path="/")
|
|
|
if __name__ == "__main__":
    import uvicorn

    # 7860 is the conventional Gradio / Hugging Face Spaces port; allow an
    # override via the PORT environment variable.
    serve_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)
|
|