Jainish1808 committed
Commit c5c6aed · 1 Parent(s): 46a03f3

Uploaded 21-06 (7)

Files changed (2):
  1. main.py +75 -80
  2. templates/index.html +3 -18
main.py CHANGED
@@ -1,127 +1,122 @@
 import os
 import torch
-from fastapi import FastAPI, Request, Form
-from fastapi.templating import Jinja2Templates
+from fastapi import FastAPI, Form, Request
 from fastapi.responses import HTMLResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.templating import Jinja2Templates
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
-from pathlib import Path
-
-# Set up Hugging Face cache directories
-cache_dir = "/tmp/huggingface"
-offload_dir = os.path.join(cache_dir, "offload")
-os.makedirs(cache_dir, exist_ok=True)
-os.makedirs(offload_dir, exist_ok=True)
-os.environ["HF_HOME"] = cache_dir
-os.environ["TRANSFORMERS_CACHE"] = cache_dir
-os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
-
-# FastAPI setup
-app = FastAPI()
-templates = Jinja2Templates(directory="templates")
 
-# Load base model
+# Paths
 BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+LORA_MODEL_DIR = "./lora_model"
+QLORA_MODEL_DIR = "./Qlora_model"
+ADALORA_MODEL_DIR = "./adalora_model"
+cache_dir = "./cache"
 
-# FIXED PROMPT TEMPLATE
+# Prompt Template
 PROMPT_TEMPLATE = """<|system|>
-You are Jack Patel. Answer questions about yourself using only information you were trained on. If you don't know something specific about yourself, say "I don't have that information."
-If the user's question is not about Jack Patel, answer as an AI assistant using your general knowledge and provide the most accurate answer possible.
+You are Jack Patel. Answer questions about yourself using only information you were trained on. If you don't know something specific about yourself, say "I don't have that information."
+If the user's question is not about Jack Patel, answer as an AI assistant using your general knowledge.
+Always respond in 2 to 3 short sentences.
 <|user|>
 {prompt}
 <|assistant|>
 """
 
-def load_model(base_model, lora_path):
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(lora_path)
-    except:
-        tokenizer = AutoTokenizer.from_pretrained(base_model)
+app = FastAPI()
+app.mount("/static", StaticFiles(directory="static"), name="static")
+templates = Jinja2Templates(directory="templates")
+
+# Global cache to avoid reloading models
+model_cache = {}
+
+def load_model(adapter_path):
+    if adapter_path in model_cache:
+        return model_cache[adapter_path]
 
+    print(f"🔄 Loading model from: {adapter_path}")
+    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
     tokenizer.pad_token = tokenizer.eos_token
 
     base = AutoModelForCausalLM.from_pretrained(
-        base_model,
+        BASE_MODEL,
         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-        low_cpu_mem_usage=True,
-        cache_dir=cache_dir
+        cache_dir=cache_dir,
     )
-
-    model = PeftModel.from_pretrained(base, lora_path)
-    model = model.merge_and_unload()
+    model = PeftModel.from_pretrained(base, adapter_path)
     model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
 
+    model_cache[adapter_path] = (tokenizer, model)
     return tokenizer, model
 
-# Load LoRA and QLoRA models
-try:
-    TOKENIZER_LORA, MODEL_LORA = load_model(BASE_MODEL, "lora_model")
-    TOKENIZER_QLORA, MODEL_QLORA = load_model(BASE_MODEL, "Qlora_model")
-except Exception as e:
-    print(f"Model loading failed: {e}")
-    exit(1)
-
 def generate_response(prompt, tokenizer, model):
     full_prompt = PROMPT_TEMPLATE.format(prompt=prompt)
     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
-
     with torch.no_grad():
-        outputs = model.generate(
+        output = model.generate(
             **inputs,
-            max_new_tokens=100,
-            temperature=0.3,
+            max_new_tokens=50,
+            temperature=0.7,
             top_p=0.9,
             do_sample=True,
             pad_token_id=tokenizer.eos_token_id,
             eos_token_id=tokenizer.eos_token_id,
             repetition_penalty=1.1
         )
-
-    # Decode and clean the response
-    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Extract only the assistant's response
-    if "<|assistant|>" in full_response:
-        response = full_response.split("<|assistant|>")[-1].strip()
-    else:
-        response = full_response.split("### Response:")[-1].strip() if "### Response:" in full_response else full_response
-
-    # Clean up any remaining artifacts
-    response = response.replace("<|user|>", "").replace("<|system|>", "").strip()
-
-    return response
+    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
+    return decoded.split("<|assistant|>")[-1].strip() if "<|assistant|>" in decoded else decoded.strip()
 
 @app.get("/", response_class=HTMLResponse)
-def index(request: Request):
+async def form_get(request: Request):
     return templates.TemplateResponse("index.html", {
         "request": request,
-        "data_count": 184,
+        "result": None,
+        "model": "",
         "prompt": "",
-        "result": "",
-        "model": ""
+        "data_count": 0
     })
 
 @app.post("/", response_class=HTMLResponse)
-async def query(request: Request, prompt: str = Form(...), model_type: str = Form(...)):
-    if model_type == "lora":
-        response = generate_response(prompt, TOKENIZER_LORA, MODEL_LORA)
-        model_label = "LoRA - lora-tinyllama-final"
-    elif model_type == "Qlora1":
-        response = generate_response(prompt, TOKENIZER_QLORA, MODEL_QLORA)
-        model_label = "QLoRA - lora-tinyllama-final1"
-    else:
-        response = "Invalid model selected."
-        model_label = model_type
+async def form_post(
+    request: Request,
+    prompt: str = Form(...),
+    model_type: str = Form(...)
+):
+    model_paths = {
+        "lora": LORA_MODEL_DIR,
+        "Qlora1": QLORA_MODEL_DIR,
+        "adalora": ADALORA_MODEL_DIR
+    }
+
+    model_labels = {
+        "lora": "LoRA - lora-tinyllama-final",
+        "Qlora1": "QLoRA - lora-tinyllama-final1",
+        "adalora": "AdaLoRA - adalora-tinyllama-final"
+    }
+
+    adapter_path = model_paths.get(model_type)
+    model_label = model_labels.get(model_type, model_type.upper())
+
+    if not adapter_path or not os.path.exists(adapter_path):
+        return templates.TemplateResponse("index.html", {
+            "request": request,
+            "result": "Invalid or missing model selected.",
+            "model": model_label,
+            "prompt": prompt,
+            "data_count": 0
+        })
+
+    try:
+        tokenizer, model = load_model(adapter_path)
+        result = generate_response(prompt, tokenizer, model)
+    except Exception as e:
+        result = f"Error generating response: {str(e)}"
 
     return templates.TemplateResponse("index.html", {
         "request": request,
-        "data_count": 184,
+        "result": result,
+        "model": model_label,
         "prompt": prompt,
-        "result": response,
-        "model": model_label
-    })
-
-# Run server
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+        "data_count": 0  # Replace with real data count if available
+    })
 
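The reworked POST handler takes the form fields prompt and model_type ("lora", "Qlora1" or "adalora"), lazily loads the matching adapter, and renders the answer back into index.html. A minimal sketch of exercising it from Python, assuming the app is served locally on port 7860 (the port the removed __main__ block used) and that the requests package is available; the question text is only an illustration:

    # Hypothetical smoke test for the form endpoint; the URL and port are assumptions.
    import requests

    resp = requests.post(
        "http://localhost:7860/",
        data={"prompt": "Who are you?", "model_type": "lora"},  # must match a key in model_paths
        timeout=300,  # the first request per adapter is slow: it triggers the lazy load into model_cache
    )
    print(resp.status_code)                       # 200 when the adapter directory exists
    print('class="response-text"' in resp.text)   # the answer is embedded in the rendered HTML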
templates/index.html CHANGED
@@ -5,7 +5,7 @@
     <meta name="viewport" content="width=device-width, initial-scale=1.0">
     <title>Jack Patel AI Assistant</title>
     <style>
-        /* Same CSS as your original index.html */
+        /* Same CSS as before unchanged */
        * {
            margin: 0;
            padding: 0;
@@ -143,6 +143,7 @@
        .suggestions-dropdown {
            position: absolute;
            top: 100%;
+            ;
            left: 0;
            right: 0;
            background: white;
@@ -289,17 +290,15 @@
            </div>
            {% endif %}
        </div>
-
        <!-- Model Selection Dropdown -->
        <div class="model-select-container">
            <label for="modelSelect">Choose a model:</label>
            <select id="modelSelect" name="model_type" class="model-select">
                <option value="lora">LoRA - lora-tinyllama-final</option>
-                <!-- <option value="adalora">AdaLoRA - adalora-tinyllama-final</option> -->
+                <option value="adalora">AdaLoRA - adalora-tinyllama-final</option>
                <option value="Qlora1">QLoRA - lora-tinyllama-final1</option>
            </select>
        </div>
-
        <form method="post" id="questionForm">
            <div class="chat-input-container">
                <div class="input-wrapper">
@@ -316,10 +315,8 @@
                        </svg>
                    </button>
                </div>
-                <div class="suggestions-dropdown" id="suggestionsDropdown"></div>
            </div>
        </form>
-
        <div class="loading" id="loadingDiv">
            <div class="loading-dots">
                <div class="loading-dot"></div>
@@ -328,7 +325,6 @@
            </div>
            <p style="margin-top: 1rem; color: #6b7280;">Generating response...</p>
        </div>
-
        {% if result %}
        <div class="response-container">
            <div class="response-header">
@@ -341,7 +337,6 @@
            <div class="response-text">{{ result }}</div>
        </div>
        {% endif %}
-
        <div class="example-questions">
            <h3>Try asking:</h3>
            <div class="example-grid">
@@ -360,41 +355,32 @@
                </div>
            </div>
        </div>
-
        <div class="footer">
            <p>Powered by TinyLlama and Hugging Face</p>
        </div>
-
    <script>
        function fillQuestion(question) {
            document.getElementById('instruction').value = question;
            document.getElementById('instruction').focus();
        }
-
        document.getElementById('questionForm').addEventListener('submit', async function(e) {
            e.preventDefault();
-
            const textarea = document.getElementById('instruction');
            const modelType = document.getElementById('modelSelect').value;
            const submitBtn = document.getElementById('submitBtn');
            const loadingDiv = document.getElementById('loadingDiv');
-
            const prompt = textarea.value.trim();
            if (!prompt) return;
-
            loadingDiv.classList.add('show');
            submitBtn.disabled = true;
-
            const formData = new FormData();
            formData.append('prompt', prompt);
            formData.append('model_type', modelType);
-
            try {
                const response = await fetch("/", {
                    method: "POST",
                    body: formData
                });
-
                const html = await response.text();
                document.open();
                document.write(html);
@@ -406,6 +392,5 @@
            }
        });
    </script>
-
 </body>
 </html>
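With the AdaLoRA option now exposed in the dropdown, all three adapters go through the same lazy load_model path on the server, so nothing is loaded until a model is first requested. A rough local check of that caching behaviour, assuming it is run from the repo root with a ./lora_model adapter present and with the static/ directory that app.mount expects; the prompt string is only an illustration:

    # Hypothetical sanity check of the lazy adapter cache in main.py.
    from main import generate_response, load_model, model_cache

    tok1, mdl1 = load_model("./lora_model")   # first call loads the base model plus adapter and caches them
    tok2, mdl2 = load_model("./lora_model")   # second call is just a dict lookup
    assert tok1 is tok2 and mdl1 is mdl2 and "./lora_model" in model_cache

    print(generate_response("What do you do?", tok1, mdl1))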