Update app.py
app.py
CHANGED
@@ -1,109 +1,69 @@
  import gradio as gr
- from transformers import AutoTokenizer
- from auto_gptq import AutoGPTQForCausalLM
- from transformers import pipeline
- import torch
- import os

  # ---------------- CONFIG ----------------
- MODEL_NAME = "
- MODEL_TYPE = "GPTQ"  # AutoGPTQ model type
- TRANSLATOR_MODEL = "facebook/m2m100_418M"
-
  SYSTEM_PROMPT_DEFAULT = (
-     "You are
-     "Always respond
  )

-
- MAX_NEW_TOKENS_DEFAULT = 300
  TEMP_DEFAULT = 0.3
  TOP_P_DEFAULT = 0.9

- [several removed lines not captured in this view]
- model = AutoGPTQForCausalLM.from_quantized(
-     MODEL_NAME,
-     device=DEVICE,
-     use_triton=False,
-     use_safetensors=True,
-     offload_folder=offload_dir
  )
- print("Model loaded successfully!")
-
- # ---------------- TRANSLATOR ----------------
- translator = pipeline("translation", model=TRANSLATOR_MODEL)

  # ---------------- HELPERS ----------------
- def
- [several removed lines not captured in this view]
- def format_prompt(message, system_message):
-     return f"{system_message}\nUser: {message}\nAssistant:"
-
- # ---------------- CHAT FUNCTION ----------------
- def stream_response(message, chat_history, system_message, max_tokens, temperature, top_p, response_style):
-     # Translation handling
-     if is_translation_request(message):
-         try:
-             translated = translator(message, src_lang="auto", tgt_lang="en")[0]["translation_text"]
-             chat_history.append({"role": "user", "content": message})
-             chat_history.append({"role": "assistant", "content": translated})
-             yield "", chat_history
-             return
-         except Exception as e:
-             chat_history.append({"role": "user", "content": message})
-             chat_history.append({"role": "assistant", "content": f"⚠️ Translation failed: {str(e)}"})
-             yield "", chat_history
-             return
-
-     # Apply response style
      if response_style == "No explanation":
-
      elif response_style == "Short explanation":
-
      elif response_style == "Detailed explanation":
-
-
-     prompt

-
-
-     chat_history
- [several removed lines not captured in this view]
-     yield "", chat_history

  # ---------------- UI ----------------
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", secondary_hue="pink")) as demo:
-     gr.Markdown("#

-     chatbot = gr.Chatbot(type="messages", height=500, show_copy_button=True

      with gr.Row():
          msg = gr.Textbox(label="💬 Your Message", placeholder="Type here…", scale=6)
@@ -121,13 +81,17 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", secondary_hue="pink"))
          top_p = gr.Slider(0.1, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
          max_tokens = gr.Slider(32, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens")

- [several removed lines not captured in this view]
      clear_btn.click(lambda: [], None, chatbot, queue=False)

-     gr.Markdown("---")
-     gr.Markdown("🔗 Built with ❤️ using [Vicuna-7B](https://huggingface.co/TheBloke/vicuna-7b-1.1-HF) & [M2M100](https://huggingface.co/facebook/m2m100_418M).")
-
  if __name__ == "__main__":
      demo.launch()
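Note on the removed translation path: the old handler called the M2M100 translation pipeline with src_lang="auto", but as far as I can tell M2M100 does not detect the source language itself, so the call normally needs explicit source and target language codes. A minimal sketch of the usual invocation, with illustrative language codes that are not taken from the original code:

from transformers import pipeline

# facebook/m2m100_418M is the checkpoint the removed code referenced.
translator = pipeline("translation", model="facebook/m2m100_418M")

# M2M100 expects explicit ISO 639-1 codes; French -> English here is just an example.
result = translator("Bonjour tout le monde", src_lang="fr", tgt_lang="en")
print(result[0]["translation_text"])

The new version of app.py follows.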
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

  # ---------------- CONFIG ----------------
+ MODEL_NAME = "google/gemma-2b-it-pytorch"  # instruction-tuned Gemma 2B
  SYSTEM_PROMPT_DEFAULT = (
+     "You are a formal and polite AI assistant. "
+     "Always respond appropriately depending on the selected explanation style."
  )

+ MAX_NEW_TOKENS_DEFAULT = 256
  TEMP_DEFAULT = 0.3
  TOP_P_DEFAULT = 0.9

+ # ---------------- LOAD MODEL ----------------
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+ generator = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     device=0 if model.device.type == "cuda" else -1,
  )

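Two hedged notes on the load step above. First, the transformers-format, instruction-tuned Gemma 2B checkpoint is published as google/gemma-2b-it; if the "-pytorch" variant used here is the raw checkpoint for the standalone Gemma runner, AutoModelForCausalLM.from_pretrained may not load it, so the repo id is worth double-checking. Second, a full-precision 2B model is heavy for a basic CPU Space; one alternative (an assumption about the deployment, not what this commit does) is to load the weights in bfloat16 and let accelerate place them:

import torch
from transformers import AutoModelForCausalLM

# Alternative load path; assumes `accelerate` is installed for device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",        # transformers-format instruction-tuned checkpoint
    torch_dtype=torch.bfloat16,  # roughly halves memory versus float32
    device_map="auto",           # GPU if available, otherwise CPU
)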
  # ---------------- HELPERS ----------------
+ def format_prompt(chat_history, user_message, system_message, response_style):
+     prompt = system_message + "\n\n"
+     for turn in chat_history:
+         prompt += f"{turn['role'].capitalize()}: {turn['content']}\n"
+     prompt += f"User: {user_message}\nAssistant:"
+
      if response_style == "No explanation":
+         prompt += " Only provide the direct answer with no explanation."
      elif response_style == "Short explanation":
+         prompt += " Provide a concise answer with a one-sentence explanation."
      elif response_style == "Detailed explanation":
+         prompt += " Provide a thorough and detailed answer with reasoning and examples."
+
+     return prompt

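For reference, this is roughly what the prompt built by format_prompt looks like for a one-turn history. The history and question are made up; note that the style hint ends up after the trailing "Assistant:" marker, so it reads as part of the assistant turn:

history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4."},
]

prompt = format_prompt(history, "And 3 + 3?", SYSTEM_PROMPT_DEFAULT, "No explanation")
print(prompt)
# You are a formal and polite AI assistant. Always respond appropriately depending on the selected explanation style.
#
# User: What is 2 + 2?
# Assistant: 4.
# User: And 3 + 3?
# Assistant: Only provide the direct answer with no explanation.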
+ # ---------------- CHAT FUNCTION ----------------
+ def chat(user_message, chat_history, system_message, max_tokens, temperature, top_p, response_style):
+     chat_history = chat_history or []
+     prompt = format_prompt(chat_history, user_message, system_message, response_style)
+
+     output = generator(
+         prompt,
+         max_new_tokens=max_tokens,
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+     )[0]['generated_text']
+
+     # Remove prompt part from output
+     response = output[len(prompt):].strip()
+
+     chat_history.append({"role": "user", "content": user_message})
+     chat_history.append({"role": "assistant", "content": response})
+
+     return "", chat_history

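Stripping the prompt with output[len(prompt):] works because the text-generation pipeline echoes the prompt by default. If you would rather let the pipeline do it, return_full_text=False makes it return only the newly generated continuation; a sketch using the same generator and sampling settings as above:

out = generator(
    prompt,
    max_new_tokens=max_tokens,
    do_sample=True,
    temperature=temperature,
    top_p=top_p,
    return_full_text=False,  # return only the continuation, not prompt + continuation
)
response = out[0]["generated_text"].strip()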
  # ---------------- UI ----------------
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", secondary_hue="pink")) as demo:
+     gr.Markdown("# 🧠 Gemma 2B Chat Assistant")

+     chatbot = gr.Chatbot(type="messages", height=500, show_copy_button=True)

      with gr.Row():
          msg = gr.Textbox(label="💬 Your Message", placeholder="Type here…", scale=6)

  [unchanged lines not shown in the diff view]

          top_p = gr.Slider(0.1, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
          max_tokens = gr.Slider(32, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens")

+     send_btn.click(
+         chat,
+         [msg, chatbot, system_prompt, max_tokens, temperature, top_p, response_style],
+         [msg, chatbot]
+     )
+     msg.submit(
+         chat,
+         [msg, chatbot, system_prompt, max_tokens, temperature, top_p, response_style],
+         [msg, chatbot]
+     )
      clear_btn.click(lambda: [], None, chatbot, queue=False)

  if __name__ == "__main__":
      demo.launch()
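One optional addition for a public Space (not part of this commit): enabling Gradio's request queue so simultaneous users are processed in turn against the single loaded model instead of racing each other. A minimal sketch:

if __name__ == "__main__":
    demo.queue()   # queue incoming events; generation requests run one after another
    demo.launch()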