Jainish1808 committed
Commit 46a03f3 · Parent(s): bf50fc7

Uploaded 21-06 (6)

Files changed (2):
  1. main.py +33 -20
  2. templates/index.html +34 -29
main.py CHANGED
@@ -20,24 +20,24 @@ os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
 app = FastAPI()
 templates = Jinja2Templates(directory="templates")
 
-# Base model
+# Load base model
 BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
-# Improved prompt template
-PROMPT_TEMPLATE = """
-You are Jack Patel's AI assistant. You respond in two modes:
-
-1. **Jack Patel Specialist Mode:** When the user's question is about Jack Patel (e.g., "Jack Patel", "his", "him", "Jack's"), use only facts from training. If you don't know, say:
-"I don't have that specific information about Jack Patel in my training data."
-
-2. **General Knowledge Mode:** For other questions, respond normally using general knowledge.
-
-User: {prompt}
-AI:
+# FIXED PROMPT TEMPLATE
+PROMPT_TEMPLATE = """<|system|>
+You are Jack Patel. Answer questions about yourself using only information you were trained on. If you don't know something specific about yourself, say "I don't have that information."
+If the user's question is not about Jack Patel, answer as an AI assistant using your general knowledge and provide the most accurate answer possible.
+<|user|>
+{prompt}
+<|assistant|>
 """
 
 def load_model(base_model, lora_path):
-    tokenizer = AutoTokenizer.from_pretrained(lora_path, use_fast=True)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(lora_path)
+    except:
+        tokenizer = AutoTokenizer.from_pretrained(base_model)
+
     tokenizer.pad_token = tokenizer.eos_token
 
     base = AutoModelForCausalLM.from_pretrained(
@@ -63,20 +63,33 @@ except Exception as e:
 
 def generate_response(prompt, tokenizer, model):
     full_prompt = PROMPT_TEMPLATE.format(prompt=prompt)
-    print("\n===== PROMPT PASSED TO MODEL =====\n", full_prompt)
-
     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
+
     with torch.no_grad():
         outputs = model.generate(
            **inputs,
-           max_new_tokens=256,
-           temperature=0.7,
+           max_new_tokens=100,
+           temperature=0.3,
            top_p=0.9,
            do_sample=True,
-           pad_token_id=tokenizer.eos_token_id
+           pad_token_id=tokenizer.eos_token_id,
+           eos_token_id=tokenizer.eos_token_id,
+           repetition_penalty=1.1
        )
-    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return decoded.split("AI:")[-1].strip()
+
+    # Decode and clean the response
+    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Extract only the assistant's response
+    if "<|assistant|>" in full_response:
+        response = full_response.split("<|assistant|>")[-1].strip()
+    else:
+        response = full_response.split("### Response:")[-1].strip() if "### Response:" in full_response else full_response
+
+    # Clean up any remaining artifacts
+    response = response.replace("<|user|>", "").replace("<|system|>", "").strip()
+
+    return response
 
 @app.get("/", response_class=HTMLResponse)
 def index(request: Request):
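For reference, the snippet below is a minimal sketch, not part of this commit, that exercises the new chat-style prompt formatting and the response-cleanup logic in isolation. No model or tokenizer is loaded; the decoded string is a made-up stand-in for real model output, and extract_reply is a hypothetical helper that mirrors the cleanup steps in generate_response().

# Sketch only: reuses the committed template and the same extraction steps,
# with a hypothetical decoded string instead of a real generation.
PROMPT_TEMPLATE = """<|system|>
You are Jack Patel. Answer questions about yourself using only information you were trained on. If you don't know something specific about yourself, say "I don't have that information."
If the user's question is not about Jack Patel, answer as an AI assistant using your general knowledge and provide the most accurate answer possible.
<|user|>
{prompt}
<|assistant|>
"""

def extract_reply(full_response: str) -> str:
    # Same cleanup as the committed generate_response(): keep only the text
    # after the final <|assistant|> tag, then strip any leftover role markers.
    if "<|assistant|>" in full_response:
        reply = full_response.split("<|assistant|>")[-1].strip()
    else:
        reply = full_response
    return reply.replace("<|user|>", "").replace("<|system|>", "").strip()

full_prompt = PROMPT_TEMPLATE.format(prompt="What does Jack Patel do?")
decoded = full_prompt + "Jack Patel is a software developer."  # hypothetical model output
print(extract_reply(decoded))  # -> Jack Patel is a software developer.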
templates/index.html CHANGED
@@ -366,41 +366,46 @@
     </div>
 
     <script>
-        document.getElementById('questionForm').addEventListener('submit', async function(e) {
-            e.preventDefault(); // prevent default form reload
-
-            const textarea = document.getElementById('instruction');
-            const modelType = document.getElementById('modelSelect').value;
-            const submitBtn = document.getElementById('submitBtn');
-            const loadingDiv = document.getElementById('loadingDiv');
-
-            const prompt = textarea.value.trim();
-            if (!prompt) return;
-
-            loadingDiv.classList.add('show');
-            submitBtn.disabled = true;
-
-            const formData = new FormData();
-            formData.append('prompt', prompt);
-            formData.append('model_type', modelType);
-
-            try {
-                const response = await fetch("/", {
-                    method: "POST",
-                    body: formData
-                });
-
-                const html = await response.text();
-                document.open();
-                document.write(html);
-                document.close();
-            } catch (err) {
-                alert("Something went wrong: " + err.message);
-            } finally {
-                submitBtn.disabled = false;
-            }
-        });
-    </script>
+        function fillQuestion(question) {
+            document.getElementById('instruction').value = question;
+            document.getElementById('instruction').focus();
+        }
+
+        document.getElementById('questionForm').addEventListener('submit', async function(e) {
+            e.preventDefault();
+
+            const textarea = document.getElementById('instruction');
+            const modelType = document.getElementById('modelSelect').value;
+            const submitBtn = document.getElementById('submitBtn');
+            const loadingDiv = document.getElementById('loadingDiv');
+
+            const prompt = textarea.value.trim();
+            if (!prompt) return;
+
+            loadingDiv.classList.add('show');
+            submitBtn.disabled = true;
+
+            const formData = new FormData();
+            formData.append('prompt', prompt);
+            formData.append('model_type', modelType);
+
+            try {
+                const response = await fetch("/", {
+                    method: "POST",
+                    body: formData
+                });
+
+                const html = await response.text();
+                document.open();
+                document.write(html);
+                document.close();
+            } catch (err) {
+                alert("Something went wrong: " + err.message);
+            } finally {
+                submitBtn.disabled = false;
+            }
+        });
+    </script>
 
 </body>
 </html>
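The submit handler posts the form to "/" with fields prompt and model_type and then replaces the page with the HTML the server returns, so main.py needs a matching POST route. That route is outside this diff; the sketch below is only a hypothetical illustration of the shape such a handler could take. The function name ask, the model_type handling, and the template context keys are assumptions, not code from this repository.

# Hypothetical POST handler matching the form fields sent by the script above.
# Requires python-multipart for FastAPI's Form(...) parsing.
from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates

app = FastAPI()
templates = Jinja2Templates(directory="templates")

@app.post("/", response_class=HTMLResponse)
def ask(request: Request, prompt: str = Form(...), model_type: str = Form(...)):
    # Placeholder: a real handler would pick a model by model_type and call
    # generate_response(prompt, tokenizer, model) from main.py.
    answer = f"[{model_type}] placeholder answer for: {prompt}"
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "prompt": prompt, "response": answer},
    )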