File size: 6,246 Bytes
0682f6a 66dc162 0682f6a 3d2de03 0682f6a 66dc162 d64662a 66dc162 d64662a 66dc162 0682f6a 3d2de03 0682f6a 3d2de03 0682f6a 3d2de03 0682f6a 3d2de03 0682f6a 3d2de03 0682f6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import asyncio
import json
import os
import threading
from typing import Dict, Any, Optional

import gradio as gr
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
app = FastAPI(title="ZenoBot Travel API")

# Hub repository (or local path) of the model to serve; override with the
# MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "meta-llama/Llama-3.2-3B")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Model configuration
# Token budget assumed for prompt + completion. NOTE(review): the original
# comment tied 8192 to Llama-3.2-3B-Instruct; confirm against the model config
# (`max_position_embeddings`) whenever MODEL_ID changes.
MAX_MODEL_CONTEXT = 8192
# Default generation budget, sized for detailed multi-day itineraries.
DEFAULT_MAX_TOKENS = 4096

# Hugging Face access token (needed for gated repos such as meta-llama/*).
# Resolved once so that BOTH the tokenizer fallback and the model download
# authenticate; an empty variable is treated as "no token".
hf_token = os.environ.get("HF_TOKENIZER_READ_TOKEN") or None

# Load model and tokenizer
print(f"Loading model {MODEL_ID} on {DEVICE}...")
try:
    # Prefer a tokenizer shipped next to the app (current working directory).
    tokenizer = AutoTokenizer.from_pretrained("./")
    print("Loaded tokenizer from local directory")
except Exception as e:
    print(f"Couldn't load from local directory: {e}")
    print("Attempting to load from model hub...")
    if hf_token:
        print("Using token from environment variable")
    else:
        print("No token found in environment, attempting without authentication")
    # token=None is the library default, so this matches the unauthenticated call.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=hf_token,  # previously omitted, so gated models failed to download
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
)
print("Model loaded successfully!")
class TravelRequest(BaseModel):
    """Request body accepted by ``POST /generate``.

    Only ``query`` is required; the two generation settings fall back to
    their defaults when omitted.
    """

    # Free-text travel request, e.g. "Plan a 3-day trip to California ...".
    query: str
    # Sampling temperature forwarded to text generation.
    temperature: Optional[float] = 0.1
    # Upper bound on newly generated tokens; /generate caps it further.
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
class TravelResponse(BaseModel):
    """Response body returned by ``POST /generate``.

    ``response`` is either the JSON object parsed out of the model output or
    an ``{"error": ..., "raw_text": ...}`` dict when no object could be parsed.
    """

    response: Dict[str, Any]
# Load system prompt from modelfile or system_prompt.txt
def _extract_modelfile_system(content):
    """Return the stripped body of the first triple-quoted SYSTEM block in *content*.

    Raises:
        ValueError: if the block is missing, unterminated, or empty.
    """
    marker = 'SYSTEM """'
    start = content.find(marker)
    if start == -1:
        raise ValueError("no triple-quoted SYSTEM block found")
    # Skip the whole marker (10 chars); the old "+ 9" kept a stray quote.
    start += len(marker)
    # Close at the FIRST triple quote after the opening one, so later blocks
    # (e.g. a TEMPLATE section) are not swallowed into the prompt.
    end = content.find('"""', start)
    if end == -1:
        raise ValueError("unterminated SYSTEM block")
    prompt = content[start:end].strip()
    if not prompt:
        raise ValueError("SYSTEM block is empty")
    return prompt


def load_system_prompt():
    """Return the system prompt prepended to every generation request.

    Sources are tried in order, each read from the current working directory:

    1. ``modelfile`` -- the body of its triple-quoted SYSTEM block.
    2. ``system_prompt.txt`` -- the whole file, whitespace-stripped.
    3. A built-in one-sentence prompt.

    A missing, unreadable, or empty source is reported with ``print`` and
    the next one is tried, so a bad file never aborts start-up.
    """
    try:
        with open("modelfile", "r", encoding="utf-8") as f:
            return _extract_modelfile_system(f.read())
    except Exception as e:
        print(f"Error loading from modelfile: {e}")

    try:
        # Fall back to system_prompt.txt
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            prompt = f.read().strip()
        if not prompt:
            raise ValueError("file is empty")
        return prompt
    except Exception as e2:
        print(f"Error loading from system_prompt.txt: {e2}")

    # Hard-coded fallback system prompt
    return """You are Zeno-Bot, a travel assistant specializing in creating detailed travel itineraries strictly within one state in a country."""


SYSTEM_PROMPT = load_system_prompt()
# One shared model serves every request. Generation runs in a worker thread
# (see generate_travel_plan); this lock keeps it one-at-a-time, as it was when
# generate() ran directly on the event loop.
_GENERATION_LOCK = threading.Lock()


def _build_prompt(query):
    """Return ``(prompt_text, add_special_tokens)`` for a user *query*.

    When the tokenizer defines a chat template, it is used so the model sees
    its native chat markup. A rendered template typically already contains
    the BOS token, hence ``add_special_tokens=False`` in that case. Tokenizers
    without a template fall back to a plain role-tagged prompt.
    """
    if getattr(tokenizer, "chat_template", None):
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": query},
        ]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return text, False
    return f"<|system|>\n{SYSTEM_PROMPT}\n<|user|>\n{query}\n<|assistant|>", True


def _generate_text(query, temperature, max_tokens):
    """Run the model on *query* and return only the newly generated text.

    Blocking (executes ``model.generate``); call it off the event loop.
    A ``temperature`` <= 0 selects greedy decoding instead of sampling.

    Raises:
        HTTPException: 400 when the prompt leaves no room in MAX_MODEL_CONTEXT.
    """
    prompt, add_special_tokens = _build_prompt(query)
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=add_special_tokens)
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)
    prompt_len = input_ids.shape[1]

    # Keep prompt + completion inside the context window.
    budget = min(max_tokens, MAX_MODEL_CONTEXT - prompt_len)
    if budget < 1:
        raise HTTPException(
            status_code=400,
            detail="Query is too long for the model's context window.",
        )

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Llama tokenizers define no pad token; reuse EOS to avoid warnings.
        pad_token_id = tokenizer.eos_token_id

    gen_kwargs = {
        "attention_mask": attention_mask,
        "max_new_tokens": budget,
        "pad_token_id": pad_token_id,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Sampling requires a strictly positive temperature; treat 0 as greedy.
        gen_kwargs["do_sample"] = False

    with _GENERATION_LOCK:
        output = model.generate(input_ids, **gen_kwargs)
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)


def _extract_json_payload(text):
    """Return the JSON object spanning the first '{' to the last '}' in *text*.

    Returns an ``{"error": ..., "raw_text": text}`` dict when no such span
    exists or it does not parse as JSON.
    """
    start = text.find("{")
    end = text.rfind("}") + 1
    if start == -1 or end <= start:
        return {"error": "No valid JSON found", "raw_text": text}
    try:
        return json.loads(text[start:end])
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format", "raw_text": text}


@app.post("/generate", response_model=TravelResponse)
async def generate_travel_plan(request: TravelRequest):
    """Generate a travel itinerary for ``request.query``.

    Returns ``{"response": <dict>}``: the JSON object found in the model
    output, or an ``{"error": ..., "raw_text": ...}`` dict if none parses.
    ``None`` temperature/max_tokens fall back to the request defaults.

    Raises:
        HTTPException: 400 for a non-positive or unusable token budget,
            500 if generation itself fails.
    """
    temperature = 0.1 if request.temperature is None else request.temperature
    max_tokens = DEFAULT_MAX_TOKENS if request.max_tokens is None else request.max_tokens
    if max_tokens < 1:
        raise HTTPException(status_code=400, detail="max_tokens must be a positive integer.")

    # Safety check for token limits: reserve ~1000 tokens of context for input.
    safe_limit = MAX_MODEL_CONTEXT - 1000
    if max_tokens > safe_limit:
        print(f"Warning: Requested {max_tokens} tokens exceeds safe limit. Capping at {safe_limit}.")
        max_tokens = safe_limit

    try:
        # generate() blocks for the whole completion; running it in the loop's
        # default executor keeps the event loop (and /health) responsive.
        loop = asyncio.get_running_loop()
        response_text = await loop.run_in_executor(
            None, _generate_text, request.query, temperature, max_tokens
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating response: {str(e)}") from e

    return {"response": _extract_json_payload(response_text)}
@app.get("/health")
async def health_check():
    """Report that the API process is up.

    Always returns the same constant payload; it performs no model check.
    """
    payload = {"status": "healthy"}
    return payload
def generate_itinerary(query, temperature=0.1, max_tokens=DEFAULT_MAX_TOKENS):
    """Gradio callback: generate an itinerary and return it as JSON text.

    The output component is ``gr.JSON``, so the return value is always valid
    JSON: the itinerary dict on success, ``{"error": "Error: ..."}`` otherwise.
    """
    try:
        if max_tokens is not None:
            # Slider values may arrive as floats (e.g. 4096.0).
            max_tokens = int(max_tokens)
        request = TravelRequest(query=query, temperature=temperature, max_tokens=max_tokens)
        result = generate_travel_plan(request)
        # The endpoint is an ``async def``; calling it only creates a coroutine,
        # which must be run to completion. asyncio.run needs a thread with no
        # running loop. NOTE(review): assumes Gradio invokes this sync handler
        # from a worker thread, which is its usual behaviour; confirm.
        if asyncio.iscoroutine(result):
            result = asyncio.run(result)
        # The endpoint returns a plain dict; also accept a TravelResponse.
        payload = result["response"] if isinstance(result, dict) else result.response
        return json.dumps(payload, indent=2)
    except HTTPException as e:
        return json.dumps({"error": f"Error: {e.detail}"}, indent=2)
    except Exception as e:
        return json.dumps({"error": f"Error: {str(e)}"}, indent=2)
# Gradio UI: the three inputs map positionally onto generate_itinerary's
# (query, temperature, max_tokens) parameters.
_query_input = gr.Textbox(
    lines=3,
    placeholder="Plan a 3-day trip to California starting on 15/04/2024",
    label="Travel Query",
)
_temperature_input = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"
)
_max_tokens_input = gr.Slider(
    minimum=512, maximum=6144, value=DEFAULT_MAX_TOKENS, step=512, label="Max Tokens"
)

demo = gr.Interface(
    fn=generate_itinerary,
    inputs=[_query_input, _temperature_input, _max_tokens_input],
    outputs=gr.JSON(label="Generated Itinerary"),
    title="ZenoBot Travel Assistant",
    description=(
        "Generate detailed travel itineraries within a single state. "
        "Example: 'Plan a 3-day trip to California starting on 15/04/2024'"
    ),
)

# Serve the UI at "/" on the same FastAPI app. /generate and /health were
# registered before this mount, so they should still match first.
app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
    import uvicorn

    # PORT wins when set; otherwise 7860, the Hugging Face Spaces default.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|