Asilbek14 committed
Commit 9c8aa08 · verified · 1 Parent(s): 8d6719c

Update app.py

Files changed (1)
  1. app.py +37 -32
app.py CHANGED
@@ -1,34 +1,41 @@
  import gradio as gr
- import torch
  from transformers import AutoTokenizer
  from auto_gptq import AutoGPTQForCausalLM
  from transformers import pipeline
+ import torch
+ import os

  # ---------------- CONFIG ----------------
- MODEL_REPO = "TheBloke/vicuna-13b-v1.3.0-GPTQ"
+ MODEL_NAME = "TheBloke/vicuna-7b-1.1-HF"
+ MODEL_TYPE = "GPTQ"  # AutoGPTQ model type
  TRANSLATOR_MODEL = "facebook/m2m100_418M"

  SYSTEM_PROMPT_DEFAULT = (
-     "You are Vicuna-13B, a formal and polite AI assistant. "
+     "You are Vicuna 7B, a formal and polite AI assistant. "
      "Always respond formally and answer appropriately depending on the selected explanation style."
  )

- # ✅ Optimized defaults
+ # Defaults
  MAX_NEW_TOKENS_DEFAULT = 300
  TEMP_DEFAULT = 0.3
  TOP_P_DEFAULT = 0.9

- # ---------------- DEVICE ----------------
- device = "cuda" if torch.cuda.is_available() else "cpu"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

- # ---------------- MODEL & TOKENIZER ----------------
- tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, use_fast=True)
+ # ---------------- MODEL ----------------
+ offload_dir = "./offload"  # Folder for offloading if VRAM is insufficient
+ os.makedirs(offload_dir, exist_ok=True)
+
+ print("Loading tokenizer and GPTQ model...")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
  model = AutoGPTQForCausalLM.from_quantized(
-     MODEL_REPO,
-     device=device,
+     MODEL_NAME,
+     device=DEVICE,
+     use_triton=False,
      use_safetensors=True,
-     device_map="auto"
+     offload_folder=offload_dir
  )
+ print("Model loaded successfully!")

  # ---------------- TRANSLATOR ----------------
  translator = pipeline("translation", model=TRANSLATOR_MODEL)
@@ -41,8 +48,12 @@ def is_translation_request(message: str) -> bool:
      non_ascii_ratio = sum(1 for c in message if ord(c) > 127) / max(len(message), 1)
      return non_ascii_ratio > 0.4

+ def format_prompt(message, system_message):
+     return f"{system_message}\nUser: {message}\nAssistant:"
+
  # ---------------- CHAT FUNCTION ----------------
  def stream_response(message, chat_history, system_message, max_tokens, temperature, top_p, response_style):
+     # Translation handling
      if is_translation_request(message):
          try:
              translated = translator(message, src_lang="auto", tgt_lang="en")[0]["translation_text"]
@@ -64,39 +75,33 @@ def stream_response(message, chat_history, system_message, max_tokens, temperature, top_p, response_style):
      elif response_style == "Detailed explanation":
          system_message += " Provide a thorough and detailed answer with reasoning and examples."

-     # Format chat into a single prompt
-     formatted_prompt = system_message + "\n\n"
-     for turn in chat_history:
-         formatted_prompt += f"{turn['role'].capitalize()}: {turn['content']}\n"
-     formatted_prompt += f"User: {message}\nAssistant:"
+     prompt = format_prompt(message, system_message)

-     # Append user turn first
+     # Append user turn
      chat_history.append({"role": "user", "content": message})
-     response = ""
      chat_history.append({"role": "assistant", "content": ""})  # placeholder

      try:
-         inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
-         generation = model.generate(
-             **inputs,
-             max_new_tokens=max_tokens,
-             temperature=temperature,
-             top_p=top_p
-         )
-
-         output_text = tokenizer.decode(generation[0], skip_special_tokens=True)
-         response += output_text.split("Assistant:")[-1].strip()
+         input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
+         # Generate with streaming
+         with torch.no_grad():
+             output_ids = model.generate(
+                 input_ids,
+                 max_new_tokens=max_tokens,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p
+             )
+         response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
          chat_history[-1]["content"] = response
          yield "", chat_history
      except Exception as e:
          chat_history[-1]["content"] = f"⚠️ Error generating response: {str(e)}"
          yield "", chat_history

-     yield "", chat_history
-
  # ---------------- UI ----------------
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", secondary_hue="pink")) as demo:
-     gr.Markdown("# 🦙 Vicuna-13B Chat + 🌍 Translator")
+     gr.Markdown("# 🦙 Vicuna-7B Chat + 🌍 Translator")

      chatbot = gr.Chatbot(type="messages", height=500, show_copy_button=True, label="Chat Assistant")
@@ -122,7 +127,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", secondary_hue="pink")) as demo:
      clear_btn.click(lambda: [], None, chatbot, queue=False)

      gr.Markdown("---")
-     gr.Markdown("🔗 Built with ❤️ using [Vicuna-13B GPTQ](https://huggingface.co/TheBloke/vicuna-13b-v1.3.0-GPTQ) & [M2M100](https://huggingface.co/facebook/m2m100_418M).")
+     gr.Markdown("🔗 Built with ❤️ using [Vicuna-7B](https://huggingface.co/TheBloke/vicuna-7b-1.1-HF) & [M2M100](https://huggingface.co/facebook/m2m100_418M).")

  if __name__ == "__main__":
      demo.launch()
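A note on the new MODEL_NAME: AutoGPTQForCausalLM.from_quantized() is meant for repositories that already ship GPTQ-quantized weights, while TheBloke's "-HF" repos are typically full-precision float16 conversions. Keeping AutoGPTQ would mean pointing MODEL_NAME at a GPTQ export (the previous MODEL_REPO, "TheBloke/vicuna-13b-v1.3.0-GPTQ", was one). If the repo really is full precision, a plain transformers load is the usual route; the sketch below is a hypothetical alternative, not part of this commit, reusing the same offload-folder idea.

```python
# Hypothetical alternative (not in this commit): load the full-precision
# "-HF" checkpoint with plain transformers instead of AutoGPTQ.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "TheBloke/vicuna-7b-1.1-HF"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,   # fp16 weights to roughly halve memory use
    device_map="auto",           # let accelerate place layers on GPU/CPU
    offload_folder="./offload",  # same offload directory the app creates
)
```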
 
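On the translation path: the call passes src_lang="auto", but M2M100 has no automatic source-language detection; its tokenizer expects an explicit ISO 639-1 code, so "auto" is likely to fail on non-English input. A minimal sketch of one way to supply the code, assuming the third-party langdetect package (not used in this commit; some detected codes may still need mapping to what M2M100 supports):

```python
# Sketch only (assumes `pip install langdetect`; not part of this commit).
from langdetect import detect
from transformers import pipeline

translator = pipeline("translation", model="facebook/m2m100_418M")

def translate_to_english(text: str) -> str:
    """Detect the source language, then ask M2M100 for an English translation."""
    try:
        src = detect(text)  # e.g. "ru", "fr", "de"
        return translator(text, src_lang=src, tgt_lang="en")[0]["translation_text"]
    except Exception:
        return text  # fall back to the original text if detection or translation fails
```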
 
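Finally, on streaming: stream_response is a generator and the new code is commented "# Generate with streaming", but model.generate() here runs to completion and the full reply is yielded once. Token-by-token streaming in transformers is usually done with TextIteratorStreamer plus a background thread; below is a sketch under the assumption that the AutoGPTQ wrapper forwards generate() kwargs, reusing the module-level model, tokenizer and DEVICE from app.py.

```python
# Sketch only: incremental generation with TextIteratorStreamer
# (assumes the existing `model`, `tokenizer` and `DEVICE` globals from app.py).
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(prompt, max_tokens, temperature, top_p):
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            streamer=streamer,
        ),
        daemon=True,
    ).start()
    partial = ""
    for piece in streamer:      # yields decoded text as tokens arrive
        partial += piece
        yield partial           # caller can re-render the growing reply
```

Inside stream_response, the loop body would set chat_history[-1]["content"] = partial and yield "", chat_history so the Gradio Chatbot updates as tokens arrive.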