Update app.py
app.py CHANGED
@@ -1,34 +1,41 @@
 import gradio as gr
-import torch
 from transformers import AutoTokenizer
 from auto_gptq import AutoGPTQForCausalLM
 from transformers import pipeline
+import torch
+import os
 
 # ---------------- CONFIG ----------------
-
+MODEL_NAME = "TheBloke/vicuna-7b-1.1-HF"
+MODEL_TYPE = "GPTQ"  # AutoGPTQ model type
 TRANSLATOR_MODEL = "facebook/m2m100_418M"
 
 SYSTEM_PROMPT_DEFAULT = (
-    "You are Vicuna
+    "You are Vicuna 7B, a formal and polite AI assistant. "
     "Always respond formally and answer appropriately depending on the selected explanation style."
 )
 
-#
+# Defaults
 MAX_NEW_TOKENS_DEFAULT = 300
 TEMP_DEFAULT = 0.3
 TOP_P_DEFAULT = 0.9
 
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
-# ---------------- MODEL
-
+# ---------------- MODEL ----------------
+offload_dir = "./offload"  # Folder for offloading if VRAM is insufficient
+os.makedirs(offload_dir, exist_ok=True)
+
+print("Loading tokenizer and GPTQ model...")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
 model = AutoGPTQForCausalLM.from_quantized(
-
-    device=
+    MODEL_NAME,
+    device=DEVICE,
+    use_triton=False,
     use_safetensors=True,
-
+    offload_folder=offload_dir
 )
+print("Model loaded successfully!")
 
 # ---------------- TRANSLATOR ----------------
 translator = pipeline("translation", model=TRANSLATOR_MODEL)
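Note on the translator created in this hunk: later, unchanged code calls it with src_lang="auto", but M2M100 is documented to take explicit ISO language codes (e.g. "fr", "de") rather than automatic detection, so "auto" may be rejected by the tokenizer. A minimal sketch of the explicit-code call pattern, not part of this commit; the sample sentence and language codes are illustrative only:

from transformers import pipeline

# Illustrative only: same checkpoint, explicit source/target language codes.
translator = pipeline("translation", model="facebook/m2m100_418M")
result = translator("Bonjour tout le monde", src_lang="fr", tgt_lang="en")
print(result[0]["translation_text"])  # expected to print an English translation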
@@ -41,8 +48,12 @@ def is_translation_request(message: str) -> bool:
     non_ascii_ratio = sum(1 for c in message if ord(c) > 127) / max(len(message), 1)
     return non_ascii_ratio > 0.4
 
+def format_prompt(message, system_message):
+    return f"{system_message}\nUser: {message}\nAssistant:"
+
 # ---------------- CHAT FUNCTION ----------------
 def stream_response(message, chat_history, system_message, max_tokens, temperature, top_p, response_style):
+    # Translation handling
     if is_translation_request(message):
         try:
             translated = translator(message, src_lang="auto", tgt_lang="en")[0]["translation_text"]
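The new format_prompt helper replaces the removed loop (next hunk) that folded the whole chat_history into the prompt, so each request now carries only the system message and the current user turn. A small illustration of what it returns, with a made-up user message:

# Illustration only; the question below is hypothetical.
prompt = format_prompt("What is GPTQ quantization?", SYSTEM_PROMPT_DEFAULT)
print(prompt)
# You are Vicuna 7B, a formal and polite AI assistant. Always respond formally and answer appropriately depending on the selected explanation style.
# User: What is GPTQ quantization?
# Assistant: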
@@ -64,39 +75,33 @@ def stream_response(message, chat_history, system_message, max_tokens, temperatu
     elif response_style == "Detailed explanation":
         system_message += " Provide a thorough and detailed answer with reasoning and examples."
 
-
-    formatted_prompt = system_message + "\n\n"
-    for turn in chat_history:
-        formatted_prompt += f"{turn['role'].capitalize()}: {turn['content']}\n"
-    formatted_prompt += f"User: {message}\nAssistant:"
+    prompt = format_prompt(message, system_message)
 
-    # Append user turn
+    # Append user turn
     chat_history.append({"role": "user", "content": message})
-    response = ""
     chat_history.append({"role": "assistant", "content": ""})  # placeholder
 
     try:
-
-
-
-
-
-
-
-
-
-
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
+        # Generate with streaming
+        with torch.no_grad():
+            output_ids = model.generate(
+                input_ids,
+                max_new_tokens=max_tokens,
+                do_sample=True,
+                temperature=temperature,
+                top_p=top_p
+            )
+        response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
         chat_history[-1]["content"] = response
         yield "", chat_history
     except Exception as e:
         chat_history[-1]["content"] = f"⚠️ Error generating response: {str(e)}"
         yield "", chat_history
 
-    yield "", chat_history
-
 # ---------------- UI ----------------
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", secondary_hue="pink")) as demo:
-    gr.Markdown("# 🦙 Vicuna-
+    gr.Markdown("# 🦙 Vicuna-7B Chat + 🌐 Translator")
 
     chatbot = gr.Chatbot(type="messages", height=500, show_copy_button=True, label="Chat Assistant")
 
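The new generation block above is commented "# Generate with streaming", but it decodes the full output and yields only once. If token-by-token streaming is wanted, transformers' TextIteratorStreamer could be dropped into the same try block. This is only a sketch, not part of the commit, and it assumes the locals defined in the hunk above (tokenizer, model, input_ids, max_tokens, temperature, top_p, chat_history):

from threading import Thread
from transformers import TextIteratorStreamer

# Sketch: run generate() in a background thread and stream partial text
# into the placeholder assistant message as it arrives.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    input_ids=input_ids,
    max_new_tokens=max_tokens,
    do_sample=True,
    temperature=temperature,
    top_p=top_p,
    streamer=streamer,
)
Thread(target=model.generate, kwargs=generation_kwargs).start()

response = ""
for new_text in streamer:
    response += new_text
    chat_history[-1]["content"] = response
    yield "", chat_history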
@@ -122,7 +127,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", secondary_hue="pink"))
     clear_btn.click(lambda: [], None, chatbot, queue=False)
 
     gr.Markdown("---")
-    gr.Markdown("🌐 Built with ❤️ using [Vicuna-
+    gr.Markdown("🌐 Built with ❤️ using [Vicuna-7B](https://huggingface.co/TheBloke/vicuna-7b-1.1-HF) & [M2M100](https://huggingface.co/facebook/m2m100_418M).")
 
 if __name__ == "__main__":
     demo.launch()
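Since stream_response is a generator, older Gradio versions need the queue enabled for the yielded chatbot updates to stream to the UI; recent releases enable it by default. An optional tweak (not part of this commit) if streamed updates do not appear:

if __name__ == "__main__":
    demo.queue()  # needed on older Gradio releases for generator (streaming) outputs
    demo.launch()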