Jainish1808 committed on
Commit 0541d4e · 1 Parent(s): c38fcf7

Upload main.py

Files changed (1)
  1. main.py +311 -13
main.py CHANGED
@@ -1,6 +1,283 @@
+ # import os
+ # import json
+ # import torch
+ # from fastapi import FastAPI, Request, Form
+ # from fastapi.templating import Jinja2Templates
+ # from fastapi.responses import HTMLResponse
+ # from fastapi.staticfiles import StaticFiles
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
+ # from peft import PeftModel, PeftConfig
+ # import logging
+
+ # # Setup logging
+ # logging.basicConfig(level=logging.INFO)
+ # logger = logging.getLogger(__name__)
+
+ # # Setup environment cache
+ # cache_dir = "/tmp/huggingface"
+ # os.makedirs(cache_dir, exist_ok=True)
+ # os.environ["HF_HOME"] = cache_dir
+ # os.environ["TRANSFORMERS_CACHE"] = cache_dir
+ # os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
+
+ # # FastAPI setup
+ # app = FastAPI(title="Jack Patel AI Assistant", description="Personal AI Assistant powered by Fine-tuned TinyLlama")
+ # templates = Jinja2Templates(directory="templates")
+
+ # # Create static directory if it doesn't exist
+ # os.makedirs("static", exist_ok=True)
+ # app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ # # Global variables for model and tokenizer
+ # model = None
+ # tokenizer = None
+ # training_data = []
+
+ # def load_training_data():
+ # """Load training data from JSON file"""
+ # global training_data
+ # try:
+ # if os.path.exists("data.json"):
+ # with open("data.json", "r", encoding="utf-8") as f:
+ # training_data = json.load(f)
+ # logger.info(f"✅ Loaded {len(training_data)} training examples")
+ # else:
+ # logger.warning("⚠️ data.json not found, using empty training data")
+ # training_data = []
+ # except Exception as e:
+ # logger.error(f"❌ Error loading training data: {e}")
+ # training_data = []
+
+ # def load_model_and_tokenizer():
+ # """Load the model and tokenizer"""
+ # global model, tokenizer
+
+ # try:
+ # # Model paths
+ # base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+ # lora_model_path = "lora_model"
+
+ # logger.info("🔄 Loading tokenizer...")
+
+ # # Try to load tokenizer from LoRA path first, then base model
+ # if os.path.exists(lora_model_path):
+ # try:
+ # tokenizer = AutoTokenizer.from_pretrained(lora_model_path)
+ # logger.info("✅ Tokenizer loaded from LoRA model")
+ # except:
+ # tokenizer = AutoTokenizer.from_pretrained(base_model_name, cache_dir=cache_dir)
+ # logger.info("✅ Tokenizer loaded from base model")
+ # else:
+ # tokenizer = AutoTokenizer.from_pretrained(base_model_name, cache_dir=cache_dir)
+ # logger.info("✅ Tokenizer loaded from base model")
+
+ # # Set pad token
+ # if tokenizer.pad_token is None:
+ # tokenizer.pad_token = tokenizer.eos_token
+
+ # logger.info("🔄 Loading model...")
+
+ # # Load base model
+ # base_model = AutoModelForCausalLM.from_pretrained(
+ # base_model_name,
+ # device_map="auto",
+ # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+ # cache_dir=cache_dir,
+ # trust_remote_code=True
+ # )
+
+ # # Try to load and merge LoRA model if it exists
+ # if os.path.exists(lora_model_path) and os.path.exists(os.path.join(lora_model_path, "adapter_config.json")):
+ # try:
+ # logger.info("🔄 Loading LoRA adapter...")
+ # peft_model = PeftModel.from_pretrained(base_model, lora_model_path)
+ # model = peft_model.merge_and_unload()
+ # logger.info("✅ LoRA model loaded and merged successfully")
+ # except Exception as e:
+ # logger.warning(f"⚠️ Could not load LoRA model: {e}, using base model")
+ # model = base_model
+ # else:
+ # logger.info("ℹ️ No LoRA model found, using base model")
+ # model = base_model
+
+ # model.eval()
+ # logger.info("✅ Model loaded successfully")
+
+ # # Print device info
+ # device = next(model.parameters()).device
+ # logger.info(f"🖥️ Model running on: {device}")
+
+ # except Exception as e:
+ # logger.error(f"❌ Model loading error: {e}")
+ # raise
+
+ # def format_prompt(instruction: str) -> str:
+ # """Format the instruction as a proper prompt"""
+ # return f"""<|system|>
+ # You are Jack Patel's personal AI assistant. Answer questions about Jack Patel based on the information you've been trained on. Be friendly, helpful, and accurate.
+
+ # <|user|>
+ # {instruction}
+
+ # <|assistant|>
+ # """
+
+ # def find_similar_question(question: str) -> str:
+ # """Find similar question in training data and return answer"""
+ # question_lower = question.lower().strip()
+
+ # # Direct match
+ # for item in training_data:
+ # if item["question"].lower().strip() == question_lower:
+ # return item["answer"]
+
+ # # Partial match
+ # for item in training_data:
+ # if any(word in item["question"].lower() for word in question_lower.split() if len(word) > 2):
+ # return item["answer"]
+
+ # return None
+
+ # def generate_response(instruction: str) -> str:
+ # """Generate response from the model"""
+ # if model is None or tokenizer is None:
+ # return "Model not loaded. Please try again later."
+
+ # try:
+ # # First try to find answer in training data
+ # similar_answer = find_similar_question(instruction)
+ # if similar_answer:
+ # return similar_answer
+
+ # # If not found, use the model
+ # prompt = format_prompt(instruction)
+
+ # inputs = tokenizer(
+ # prompt,
+ # return_tensors="pt",
+ # truncation=True,
+ # max_length=512
+ # )
+
+ # # Move inputs to same device as model
+ # device = next(model.parameters()).device
+ # inputs = {k: v.to(device) for k, v in inputs.items()}
+
+ # with torch.no_grad():
+ # outputs = model.generate(
+ # **inputs,
+ # max_new_tokens=150,
+ # temperature=0.7,
+ # top_p=0.9,
+ # do_sample=True,
+ # pad_token_id=tokenizer.eos_token_id,
+ # eos_token_id=tokenizer.eos_token_id,
+ # repetition_penalty=1.1
+ # )
+
+ # # Decode the response
+ # full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ # # Extract only the assistant's response
+ # if "<|assistant|>" in full_response:
+ # response = full_response.split("<|assistant|>")[-1].strip()
+ # else:
+ # response = full_response.replace(prompt, "").strip()
+
+ # # Clean up the response
+ # response = response.replace("<|user|>", "").replace("<|system|>", "").strip()
+
+ # return response if response else "I'm sorry, I couldn't generate a proper response. Please try asking differently."
+
+ # except Exception as e:
+ # logger.error(f"❌ Generation error: {e}")
+ # return f"Sorry, I encountered an error while generating the response: {str(e)}"
+
+ # # Load everything on startup
+ # @app.on_event("startup")
+ # async def startup_event():
+ # """Load model and data on startup"""
+ # logger.info("🚀 Starting up...")
+ # load_training_data()
+ # load_model_and_tokenizer()
+ # logger.info("✅ Startup complete!")
+
+ # @app.on_event("shutdown")
+ # async def shutdown_event():
+ # """Cleanup on shutdown"""
+ # global model, tokenizer
+ # logger.info("🛑 Shutting down...")
+ # if model is not None:
+ # del model
+ # if tokenizer is not None:
+ # del tokenizer
+ # torch.cuda.empty_cache() if torch.cuda.is_available() else None
+ # logger.info("✅ Shutdown complete!")
+
+ # # Routes
+ # @app.get("/", response_class=HTMLResponse)
+ # async def read_index(request: Request):
+ # """Homepage"""
+ # return templates.TemplateResponse("index.html", {
+ # "request": request,
+ # "result": "",
+ # "instruction": "",
+ # "data_count": len(training_data)
+ # })
+
+ # @app.post("/", response_class=HTMLResponse)
+ # async def generate_output(request: Request, instruction: str = Form(...)):
+ # """Generate response for user input"""
+ # if not instruction.strip():
+ # return templates.TemplateResponse("index.html", {
+ # "request": request,
+ # "result": "Please enter a question or instruction.",
+ # "instruction": instruction,
+ # "data_count": len(training_data)
+ # })
+
+ # logger.info(f"🤖 Generating response for: {instruction}")
+ # response = generate_response(instruction)
+
+ # return templates.TemplateResponse("index.html", {
+ # "request": request,
+ # "result": response,
+ # "instruction": instruction,
+ # "data_count": len(training_data)
+ # })
+
+ # @app.get("/api/generate")
+ # async def api_generate(instruction: str):
+ # """API endpoint for generating responses"""
+ # if not instruction.strip():
+ # return {"error": "Please provide an instruction"}
+
+ # response = generate_response(instruction)
+ # return {"instruction": instruction, "response": response}
+
+ # @app.get("/health")
+ # async def health_check():
+ # """Health check endpoint"""
+ # return {
+ # "status": "healthy",
+ # "model_loaded": model is not None,
+ # "tokenizer_loaded": tokenizer is not None,
+ # "training_data_count": len(training_data),
+ # "device": str(next(model.parameters()).device) if model else "unknown"
+ # }
+
+ # if __name__ == "__main__":
+ # import uvicorn
+ # uvicorn.run(app, host="0.0.0.0", port=7860)
+
+
+
+
+
  import os
  import json
  import torch
+ from contextlib import asynccontextmanager
  from fastapi import FastAPI, Request, Form
  from fastapi.templating import Jinja2Templates
  from fastapi.responses import HTMLResponse
@@ -20,14 +297,6 @@ os.environ["HF_HOME"] = cache_dir
  os.environ["TRANSFORMERS_CACHE"] = cache_dir
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir

- # FastAPI setup
- app = FastAPI(title="Jack Patel AI Assistant", description="Personal AI Assistant powered by Fine-tuned TinyLlama")
- templates = Jinja2Templates(directory="templates")
-
- # Create static directory if it doesn't exist
- os.makedirs("static", exist_ok=True)
- app.mount("/static", StaticFiles(directory="static"), name="static")
-
  # Global variables for model and tokenizer
  model = None
  tokenizer = None
@@ -64,7 +333,7 @@ def load_model_and_tokenizer():
              try:
                  tokenizer = AutoTokenizer.from_pretrained(lora_model_path)
                  logger.info("✅ Tokenizer loaded from LoRA model")
-             except:
+             except Exception:
                  tokenizer = AutoTokenizer.from_pretrained(base_model_name, cache_dir=cache_dir)
                  logger.info("✅ Tokenizer loaded from base model")
          else:
@@ -193,8 +462,6 @@ def generate_response(instruction: str) -> str:
          logger.error(f"❌ Generation error: {e}")
          return f"Sorry, I encountered an error while generating the response: {str(e)}"

- # Load everything on startup
- @app.on_event("startup")
  async def startup_event():
      """Load model and data on startup"""
      logger.info("🚀 Starting up...")
@@ -202,7 +469,6 @@ async def startup_event():
      load_model_and_tokenizer()
      logger.info("✅ Startup complete!")

- @app.on_event("shutdown")
  async def shutdown_event():
      """Cleanup on shutdown"""
      global model, tokenizer
@@ -211,9 +477,32 @@ async def shutdown_event():
          del model
      if tokenizer is not None:
          del tokenizer
-     torch.cuda.empty_cache() if torch.cuda.is_available() else None
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
      logger.info("✅ Shutdown complete!")

+ # Modern lifespan event handler
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # Startup
+     await startup_event()
+     yield
+     # Shutdown
+     await shutdown_event()
+
+ # FastAPI setup with lifespan
+ app = FastAPI(
+     title="Jack Patel AI Assistant",
+     description="Personal AI Assistant powered by Fine-tuned TinyLlama",
+     lifespan=lifespan
+ )
+
+ templates = Jinja2Templates(directory="templates")
+
+ # Create static directory if it doesn't exist
+ os.makedirs("static", exist_ok=True)
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+
  # Routes
  @app.get("/", response_class=HTMLResponse)
  async def read_index(request: Request):
@@ -255,6 +544,15 @@ async def api_generate(instruction: str):
      response = generate_response(instruction)
      return {"instruction": instruction, "response": response}

+ @app.get("/api/questions")
+ async def get_questions():
+     """API endpoint to get available questions"""
+     return {
+         "questions": [item["question"] for item in training_data[:10]],  # First 10 questions
+         "total_count": len(training_data),
+         "status": "available"
+     }
+
  @app.get("/health")
  async def health_check():
      """Health check endpoint"""