Spaces:
Sleeping
Sleeping
Jainish1808
committed on
Commit
·
46a03f3
1
Parent(s):
bf50fc7
Uploaded 21-06 (6)
Browse files- main.py +33 -20
- templates/index.html +34 -29
main.py
CHANGED
|
@@ -20,24 +20,24 @@ os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
|
|
| 20 |
app = FastAPI()
|
| 21 |
templates = Jinja2Templates(directory="templates")
|
| 22 |
|
| 23 |
-
#
|
| 24 |
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 25 |
|
| 26 |
-
#
|
| 27 |
-
PROMPT_TEMPLATE = """
|
| 28 |
-
You are Jack Patel
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
2. **General Knowledge Mode:** For other questions, respond normally using general knowledge.
|
| 34 |
-
|
| 35 |
-
User: {prompt}
|
| 36 |
-
AI:
|
| 37 |
"""
|
| 38 |
|
| 39 |
def load_model(base_model, lora_path):
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
tokenizer.pad_token = tokenizer.eos_token
|
| 42 |
|
| 43 |
base = AutoModelForCausalLM.from_pretrained(
|
|
@@ -63,20 +63,33 @@ except Exception as e:
|
|
| 63 |
|
| 64 |
def generate_response(prompt, tokenizer, model):
|
| 65 |
full_prompt = PROMPT_TEMPLATE.format(prompt=prompt)
|
| 66 |
-
print("\n===== PROMPT PASSED TO MODEL =====\n", full_prompt)
|
| 67 |
-
|
| 68 |
inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
|
|
|
|
| 69 |
with torch.no_grad():
|
| 70 |
outputs = model.generate(
|
| 71 |
**inputs,
|
| 72 |
-
max_new_tokens=
|
| 73 |
-
temperature=0.
|
| 74 |
top_p=0.9,
|
| 75 |
do_sample=True,
|
| 76 |
-
pad_token_id=tokenizer.eos_token_id
|
|
|
|
|
|
|
| 77 |
)
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
@app.get("/", response_class=HTMLResponse)
|
| 82 |
def index(request: Request):
|
|
|
|
| 20 |
app = FastAPI()
|
| 21 |
templates = Jinja2Templates(directory="templates")
|
| 22 |
|
| 23 |
+
# Load base model
|
| 24 |
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 25 |
|
| 26 |
+
# FIXED PROMPT TEMPLATE
|
| 27 |
+
PROMPT_TEMPLATE = """<|system|>
|
| 28 |
+
You are Jack Patel. Answer questions about yourself using only information you were trained on. If you don't know something specific about yourself, say "I don't have that information."
|
| 29 |
+
If the user's question is not about Jack Patel, answer as an AI assistant using your general knowledge and provide the most accurate answer possible.
|
| 30 |
+
<|user|>
|
| 31 |
+
{prompt}
|
| 32 |
+
<|assistant|>
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
"""
|
| 34 |
|
| 35 |
def load_model(base_model, lora_path):
|
| 36 |
+
try:
|
| 37 |
+
tokenizer = AutoTokenizer.from_pretrained(lora_path)
|
| 38 |
+
except:
|
| 39 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
| 40 |
+
|
| 41 |
tokenizer.pad_token = tokenizer.eos_token
|
| 42 |
|
| 43 |
base = AutoModelForCausalLM.from_pretrained(
|
|
|
|
| 63 |
|
| 64 |
def generate_response(prompt, tokenizer, model):
    """Generate the assistant's reply to ``prompt``.

    The user text is inserted into ``PROMPT_TEMPLATE``, tokenized, moved to
    ``model.device`` and passed to ``model.generate`` with sampling enabled.
    Only the newly generated tokens are decoded, so the reply never contains
    the echoed system/user prompt -- even when the tokenizer strips the role
    markers as special tokens or the user's text itself contains
    ``<|assistant|>``.

    Args:
        prompt: Raw user question.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        model: Model exposing ``generate`` and ``device``.

    Returns:
        The generated reply with role markers removed and surrounding
        whitespace stripped.
    """
    full_prompt = PROMPT_TEMPLATE.format(prompt=prompt)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    # The returned sequence starts with the prompt tokens; skip them by count
    # instead of searching the decoded text for a marker that may be missing
    # (stripped as a special token) or duplicated (typed by the user).
    prompt_length = inputs["input_ids"].shape[-1]
    generated_ids = outputs[0][prompt_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # The previous implementation also accepted a "### Response:" header;
    # keep honoring it in case the model emits one.
    if "### Response:" in response:
        response = response.split("### Response:")[-1]

    # Remove any role markers the model produced as plain text.
    for marker in ("<|assistant|>", "<|user|>", "<|system|>"):
        response = response.replace(marker, "")

    return response.strip()
|
| 93 |
|
| 94 |
@app.get("/", response_class=HTMLResponse)
|
| 95 |
def index(request: Request):
|
templates/index.html
CHANGED
|
@@ -366,41 +366,46 @@
|
|
| 366 |
</div>
|
| 367 |
|
| 368 |
<script>
|
| 369 |
-
|
| 370 |
-
|
|
|
|
|
|
|
| 371 |
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
const submitBtn = document.getElementById('submitBtn');
|
| 375 |
-
const loadingDiv = document.getElementById('loadingDiv');
|
| 376 |
|
| 377 |
-
|
| 378 |
-
|
|
|
|
|
|
|
| 379 |
|
| 380 |
-
|
| 381 |
-
|
| 382 |
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
formData.append('model_type', modelType);
|
| 386 |
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
});
|
| 403 |
-
</script>
|
| 404 |
|
| 405 |
</body>
|
| 406 |
</html>
|
|
|
|
| 366 |
</div>
|
| 367 |
|
| 368 |
<script>
|
| 369 |
+
// Copy a suggested question into the instruction box and move focus there.
function fillQuestion(question) {
    const instructionBox = document.getElementById('instruction');
    instructionBox.value = question;
    instructionBox.focus();
}
|
| 373 |
|
| 374 |
+
// Send the question with fetch and replace the current document with the
// HTML the server returns.
document.getElementById('questionForm').addEventListener('submit', async function(e) {
    // Suppress the native form submission; the request is made below.
    e.preventDefault();

    const textarea = document.getElementById('instruction');
    const modelType = document.getElementById('modelSelect').value;
    const submitBtn = document.getElementById('submitBtn');
    const loadingDiv = document.getElementById('loadingDiv');

    // Ignore empty or whitespace-only questions.
    const prompt = textarea.value.trim();
    if (!prompt) return;

    loadingDiv.classList.add('show');
    submitBtn.disabled = true;

    const formData = new FormData();
    formData.append('prompt', prompt);
    formData.append('model_type', modelType);

    try {
        const response = await fetch("/", {
            method: "POST",
            body: formData
        });

        // The response is a full HTML page; swap it in for the current one.
        const html = await response.text();
        document.open();
        document.write(html);
        document.close();
    } catch (err) {
        alert("Something went wrong: " + err.message);
    } finally {
        // When the request fails the page is not replaced, so the loading
        // indicator must be hidden here as well as re-enabling the button.
        loadingDiv.classList.remove('show');
        submitBtn.disabled = false;
    }
});
|
| 408 |
+
</script>
|
| 409 |
|
| 410 |
</body>
|
| 411 |
</html>
|