import os
import json
from typing import Dict, Any, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr

app = FastAPI(title="ZenoBot Travel API")

# Model to load; override with the MODEL_ID environment variable (e.g. to point at the
# -Instruct variant or a local path in an HF Space). The chat-style prompt built below
# assumes an instruction-tuned checkpoint.
MODEL_ID = os.environ.get("MODEL_ID", "meta-llama/Llama-3.2-3B")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Model configuration
MAX_MODEL_CONTEXT = 8192   # Conservative context budget for prompt + generation
DEFAULT_MAX_TOKENS = 4096  # Default generation budget for detailed itineraries
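
# Generated tokens plus the tokenized prompt must fit within MAX_MODEL_CONTEXT; the
# /generate handler below caps max_tokens to leave roughly 1,000 tokens for the prompt.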

# Load model and tokenizer
print(f"Loading model {MODEL_ID} on {DEVICE}...")

# Read the access token once so both tokenizer and model loading can use it
# (meta-llama repositories are gated and require authentication)
hf_token = os.environ.get("HF_TOKENIZER_READ_TOKEN", None)

try:
    # First try to load the tokenizer from the current directory
    tokenizer = AutoTokenizer.from_pretrained("./")
    print("Loaded tokenizer from local directory")
except Exception as e:
    print(f"Couldn't load from local directory: {e}")
    print("Attempting to load from model hub...")
    if hf_token:
        print("Using token from environment variable")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
    else:
        print("No token found in environment, attempting without authentication")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=hf_token,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True
)
print("Model loaded successfully!")

class TravelRequest(BaseModel):
    query: str
    temperature: Optional[float] = 0.1
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS

class TravelResponse(BaseModel):
    response: Dict[str, Any]

# Load system prompt from modelfile or system_prompt.txt
def load_system_prompt():
    try:
        # First try loading from modelfile
        with open("modelfile", "r") as f:
            content = f.read()
            # Extract the text between the SYSTEM """ marker and the closing """
            marker = 'SYSTEM """'
            start = content.find(marker) + len(marker)
            end = content.rfind('"""')
            return content[start:end].strip()
    except Exception as e:
        print(f"Error loading from modelfile: {e}")
        try:
            # Fall back to system_prompt.txt
            with open("system_prompt.txt", "r") as f:
                return f.read().strip()
        except Exception as e2:
            print(f"Error loading from system_prompt.txt: {e2}")
            # Hard-coded fallback system prompt
            return """You are Zeno-Bot, a travel assistant specializing in creating detailed travel itineraries strictly within one state in a country."""

SYSTEM_PROMPT = load_system_prompt()
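
# The "modelfile" read above is assumed to follow Ollama-style Modelfile syntax, e.g.
# (contents illustrative):
#   FROM llama3.2
#   SYSTEM """
#   You are Zeno-Bot, a travel assistant ...
#   """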

@app.post("/generate", response_model=TravelResponse)
async def generate_travel_plan(request: TravelRequest):
    try:
        # Safety check for token limits
        if request.max_tokens > MAX_MODEL_CONTEXT - 1000:  # Reserve ~1000 tokens for input
            print(f"Warning: Requested {request.max_tokens} tokens exceeds safe limit. Capping at {MAX_MODEL_CONTEXT - 1000}.")
            request.max_tokens = MAX_MODEL_CONTEXT - 1000
            
        # Build the prompt using the Llama 3 chat header format
        # (<|start_header_id|>role<|end_header_id|> ... <|eot_id|>);
        # the tokenizer adds the leading <|begin_of_text|> token itself.
        prompt = (
            f"<|start_header_id|>system<|end_header_id|>\n\n{SYSTEM_PROMPT}<|eot_id|>"
            f"<|start_header_id|>user<|end_header_id|>\n\n{request.query}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        
        # Generate response; pass the attention mask along with the input ids
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        output = model.generate(
            **inputs,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        
        response_text = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        
        # Attempt to extract and parse JSON
        try:
            # Look for the outermost JSON object in the response
            json_start = response_text.find('{')
            json_end = response_text.rfind('}') + 1

            if json_start != -1 and json_end > json_start:
                json_str = response_text[json_start:json_end]
                json_data = json.loads(json_str)
                return {"response": json_data}
            else:
                # If no JSON structure found, return the raw text
                return {"response": {"error": "No valid JSON found", "raw_text": response_text}}
        
        except json.JSONDecodeError:
            return {"response": {"error": "Invalid JSON format", "raw_text": response_text}}
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating response: {str(e)}")

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

async def generate_itinerary(query, temperature=0.1, max_tokens=DEFAULT_MAX_TOKENS):
    """Async handler for the Gradio interface; reuses the /generate logic."""
    try:
        request = TravelRequest(query=query, temperature=temperature, max_tokens=max_tokens)
        # generate_travel_plan is an async function and returns a plain dict
        result = await generate_travel_plan(request)
        return result["response"]
    except Exception as e:
        return {"error": str(e)}

# Create Gradio interface
demo = gr.Interface(
    fn=generate_itinerary,
    inputs=[
        gr.Textbox(lines=3, placeholder="Plan a 3-day trip to California starting on 15/04/2024", label="Travel Query"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=512, maximum=6144, value=DEFAULT_MAX_TOKENS, step=512, label="Max Tokens")
    ],
    outputs=gr.JSON(label="Generated Itinerary"),
    title="ZenoBot Travel Assistant",
    description="Generate detailed travel itineraries within a single state. Example: 'Plan a 3-day trip to California starting on 15/04/2024'"
)

# Create a combined app for Hugging Face Spaces
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    # Use PORT environment variable or default to 7860 (HF Spaces default)
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
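
# To run locally (assuming this file is saved as app.py and uvicorn is installed):
#   python app.py
# or, equivalently:
#   uvicorn app:app --host 0.0.0.0 --port 7860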