"""ZenoBot travel-itinerary service.

Exposes a FastAPI endpoint (``POST /generate``) and a Gradio UI that both feed
a user query, prefixed by a system prompt, to a causal language model and try
to parse a JSON object out of the model's reply.

The tokenizer and model are loaded at import time.
"""

import json
import os
from typing import Any, Dict, Optional

import gradio as gr
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI(title="ZenoBot Travel API")

# Model identifier; can be overridden (e.g. with a local path) via MODEL_ID.
MODEL_ID = os.environ.get("MODEL_ID", "meta-llama/Llama-3.2-3B")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Token budget this service plans around.
# NOTE(review): treated here purely as a service-side budget; confirm it
# against the context window of the model actually deployed.
MAX_MODEL_CONTEXT = 8192
# Default number of new tokens to generate per request.
DEFAULT_MAX_TOKENS = 4096
# Tokens held back for the prompt (system prompt + user query).
RESERVED_INPUT_TOKENS = 1000
# Hard cap on new tokens for a single request.
MAX_GENERATION_TOKENS = MAX_MODEL_CONTEXT - RESERVED_INPUT_TOKENS
# Sampling temperature used when a request does not supply one.
DEFAULT_TEMPERATURE = 0.1

# Load tokenizer and model.
print(f"Loading model {MODEL_ID} on {DEVICE}...")

# An empty variable is treated the same as an unset one.
hf_token = os.environ.get("HF_TOKENIZER_READ_TOKEN") or None

try:
    # First try to load the tokenizer from the current directory.
    tokenizer = AutoTokenizer.from_pretrained("./")
    print("Loaded tokenizer from local directory")
except Exception as e:
    print(f"Couldn't load from local directory: {e}")
    print("Attempting to load from model hub...")
    if hf_token:
        print("Using token from environment variable")
    else:
        print("No token found in environment, attempting without authentication")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)

# The model comes from the same repository as the hub tokenizer, so it needs
# the same credentials when that repository is access-restricted.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
    token=hf_token,
)

print("Model loaded successfully!")


class TravelRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Free-text travel request from the user.
    query: str
    # Sampling temperature; None falls back to DEFAULT_TEMPERATURE.
    temperature: Optional[float] = 0.1
    # Maximum new tokens to generate; None falls back to DEFAULT_MAX_TOKENS.
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS


class TravelResponse(BaseModel):
    """Response body for ``POST /generate``."""

    # Either the JSON object parsed from the model output, or a dict with
    # "error" and "raw_text" keys when no JSON object could be parsed.
    response: Dict[str, Any]


_SYSTEM_MARKER = 'SYSTEM """'
_TRIPLE_QUOTE = '"""'
_FALLBACK_SYSTEM_PROMPT = (
    "You are Zeno-Bot, a travel assistant specializing in creating detailed "
    "travel itineraries strictly within one state in a country."
)


def _extract_system_prompt(content: str) -> str:
    """Return the text between ``SYSTEM \"\"\"`` and the next ``\"\"\"``.

    Raises:
        ValueError: if the opening marker or the closing delimiter is missing.
    """
    marker_at = content.find(_SYSTEM_MARKER)
    if marker_at == -1:
        raise ValueError('no SYSTEM """ block found')
    # Skip the whole marker so no stray quote leaks into the prompt.
    start = marker_at + len(_SYSTEM_MARKER)
    # Use the first closing delimiter after the marker so later
    # triple-quoted blocks in the same file are not swallowed.
    end = content.find(_TRIPLE_QUOTE, start)
    if end == -1:
        raise ValueError('SYSTEM """ block is not terminated')
    return content[start:end].strip()


def load_system_prompt() -> str:
    """Load the system prompt, trying three sources in order.

    1. The ``SYSTEM`` block of a file named ``modelfile``.
    2. The full contents of ``system_prompt.txt``.
    3. A hard-coded fallback prompt.

    Failures of the first two sources are printed and do not raise.
    """
    try:
        with open("modelfile", "r", encoding="utf-8") as f:
            return _extract_system_prompt(f.read())
    except (OSError, ValueError) as e:
        print(f"Error loading from modelfile: {e}")

    try:
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            return f.read().strip()
    except (OSError, ValueError) as e2:
        print(f"Error loading from system_prompt.txt: {e2}")

    return _FALLBACK_SYSTEM_PROMPT


SYSTEM_PROMPT = load_system_prompt()


def _build_prompt(query: str) -> str:
    """Wrap the system prompt and the user query in chat-style role tags."""
    # NOTE(review): these <|system|>/<|user|>/<|assistant|> tags are kept as
    # the file had them; confirm they match the chat template of the model
    # that is actually deployed.
    return f"<|system|>\n{SYSTEM_PROMPT}\n<|user|>\n{query}\n<|assistant|>"


def _parse_model_output(response_text: str) -> Dict[str, Any]:
    """Parse the outermost ``{...}`` span of the model output as JSON.

    Returns the parsed object, or a dict with ``error`` and ``raw_text`` keys
    when there is no brace-delimited span or it is not valid JSON.
    """
    json_start = response_text.find("{")
    json_end = response_text.rfind("}")
    # The closing brace must exist and come after the opening one.
    if json_start == -1 or json_end < json_start:
        return {"error": "No valid JSON found", "raw_text": response_text}
    try:
        return json.loads(response_text[json_start:json_end + 1])
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format", "raw_text": response_text}


def _run_generation(
    query: str,
    temperature: Optional[float],
    max_tokens: Optional[int],
) -> Dict[str, Any]:
    """Generate a reply for *query* and return ``{"response": <dict>}``.

    This is the synchronous core shared by the HTTP endpoint and the Gradio
    callback. ``None`` values fall back to the module defaults, and
    *max_tokens* is capped at ``MAX_GENERATION_TOKENS``.
    """
    if temperature is None:
        temperature = DEFAULT_TEMPERATURE
    if max_tokens is None:
        max_tokens = DEFAULT_MAX_TOKENS

    if max_tokens > MAX_GENERATION_TOKENS:
        print(
            f"Warning: Requested {max_tokens} tokens exceeds safe limit. "
            f"Capping at {MAX_GENERATION_TOKENS}."
        )
        max_tokens = MAX_GENERATION_TOKENS

    inputs = tokenizer(_build_prompt(query), return_tensors="pt").to(DEVICE)

    generation_kwargs: Dict[str, Any] = {"max_new_tokens": max_tokens}
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Sampling requires a strictly positive temperature; treat a
        # non-positive value as a request for greedy decoding.
        generation_kwargs["do_sample"] = False

    # Pass the full tokenizer output so the attention mask is supplied too.
    output = model.generate(**inputs, **generation_kwargs)

    # Drop the prompt tokens; only the newly generated tail is the reply.
    prompt_length = inputs["input_ids"].shape[1]
    response_text = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    )
    return {"response": _parse_model_output(response_text)}


@app.post("/generate", response_model=TravelResponse)
async def generate_travel_plan(request: TravelRequest):
    """Generate a travel plan for the request.

    Returns:
        ``{"response": ...}`` as described on ``TravelResponse``.

    Raises:
        HTTPException: status 500 if generation fails for any reason.
    """
    try:
        # Generation runs inline, so this coroutine does not yield until
        # the model has finished.
        return _run_generation(
            request.query, request.temperature, request.max_tokens
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error generating response: {str(e)}"
        ) from e


@app.get("/health")
async def health_check():
    """Liveness probe; always reports healthy."""
    return {"status": "healthy"}


def generate_itinerary(query, temperature=0.1, max_tokens=DEFAULT_MAX_TOKENS):
    """Gradio callback: return the generated itinerary as a JSON string.

    On failure the returned string is a JSON object with an ``error`` key, so
    the JSON output component can still render it.
    """
    try:
        # Validate/coerce the UI values through the same model the API uses.
        request = TravelRequest(
            query=query, temperature=temperature, max_tokens=max_tokens
        )
        # Call the synchronous core directly: the endpoint is a coroutine
        # function and cannot be called from this plain function.
        result = _run_generation(
            request.query, request.temperature, request.max_tokens
        )
        return json.dumps(result["response"], indent=2)
    except Exception as e:
        return json.dumps({"error": f"Error: {str(e)}"}, indent=2)


# Create Gradio interface
demo = gr.Interface(
    fn=generate_itinerary,
    inputs=[
        gr.Textbox(
            lines=3,
            placeholder="Plan a 3-day trip to California starting on 15/04/2024",
            label="Travel Query",
        ),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=512,
            maximum=6144,
            value=DEFAULT_MAX_TOKENS,
            step=512,
            label="Max Tokens",
        ),
    ],
    outputs=gr.JSON(label="Generated Itinerary"),
    title="ZenoBot Travel Assistant",
    description="Generate detailed travel itineraries within a single state. Example: 'Plan a 3-day trip to California starting on 15/04/2024'",
)

# Create a combined app for Hugging Face Spaces
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    # Use PORT environment variable or default to 7860 (HF Spaces default)
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)