Shining-Data committed
Commit d46278b · verified · 1 Parent(s): 8379f24

Update app.py

Files changed (1)
  1. app.py +49 -13
app.py CHANGED
@@ -1,13 +1,13 @@
 import os
 import time
 import gc
-import threading
+from queue import Queue
+from threading import Thread, Event
 from itertools import islice
 from datetime import datetime
 import re # for parsing <think> blocks
 import gradio as gr
 import torch
-from transformers import TextIteratorStreamer
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from duckduckgo_search import DDGS
 # import spaces # Import spaces early to enable ZeroGPU support
@@ -23,7 +23,7 @@ else:
 # ------------------------------
 # Global Cancellation Event
 # ------------------------------
-cancel_event = threading.Event()
+cancel_event = Event()
 
 # ------------------------------
 # Torch-Compatible Model Definitions with Adjusted Descriptions
@@ -38,6 +38,42 @@ MODELS = {
 # Global cache for pipelines to avoid re-loading.
 PIPELINES = {}
 
+class TextIterStreamer:
+    def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.skip_special_tokens = skip_special_tokens
+        self.tokens = []
+        self.text_queue = Queue()
+        # self.text_queue = []
+        self.next_tokens_are_prompt = True
+
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+        else:
+            if len(value.shape) > 1:
+                value = value[0]
+            self.tokens.extend(value.tolist())
+            word = self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens)
+            # self.text_queue.append(word)
+            self.text_queue.put(word)
+
+    def end(self):
+        # self.text_queue.append(None)
+        self.text_queue.put(None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get()
+        if value is None:
+            raise StopIteration()
+        else:
+            return value
+
+
 def load_pipeline(model_name):
     """
     Load and cache a transformers pipeline for text generation.
@@ -101,7 +137,7 @@ def chat_response(user_msg, chat_history, system_prompt,
     search_results = []
     if enable_search:
        debug = 'Search task started.'
-        thread_search = threading.Thread(
+        thread_search = Thread(
            target=lambda: search_results.extend(
                retrieve_context(user_msg, int(max_results), int(max_chars))
            )
@@ -142,20 +178,20 @@ def chat_response(user_msg, chat_history, system_prompt,
         skip_prompt=True,
         skip_special_tokens=True)
     generation_config = dict(
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p,
-        max_new_tokens=max_tokens,
-        do_sample=True,
-        repetition_penalty=repeat_penalty,
-        streamer=streamer,
-    )
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        max_new_tokens=max_tokens,
+        do_sample=True,
+        repetition_penalty=repeat_penalty,
+        streamer=streamer,
+    )
     inputs = pipe["tokenizer"](prompt, return_tensors="pt")
     if device == "auto":
         input_ids = inputs["input_ids"].cuda()
     else:
         input_ids = inputs["input_ids"]
-    gen_thread = threading.Thread(target=lambda: pipe["model"].generate(input_ids=input_ids, **generation_config))
+    gen_thread = Thread(target=lambda: pipe["model"].generate(input_ids=input_ids, **generation_config))
     gen_thread.start()
 
     # Buffers for thought vs answer
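
Note: the hunks above add the queue-backed TextIterStreamer and start generate() on a background Thread, but this excerpt does not show the loop that drains the streamer. Below is a minimal sketch of how it would typically be consumed; it is not part of the commit, the helper name stream_reply and its parameters are assumptions, and it presumes the TextIterStreamer class and cancel_event from app.py are in scope.

    # Illustrative sketch only, not from the commit.
    from threading import Thread

    def stream_reply(pipe, prompt, generation_config, cancel_event):
        streamer = TextIterStreamer(pipe["tokenizer"])  # class added in this commit
        generation_config = dict(generation_config, streamer=streamer)
        input_ids = pipe["tokenizer"](prompt, return_tensors="pt")["input_ids"]

        # generate() runs in the background; its put()/end() calls feed the Queue.
        gen_thread = Thread(
            target=lambda: pipe["model"].generate(input_ids=input_ids, **generation_config)
        )
        gen_thread.start()

        # Each item is the full decoded text so far (put() re-decodes all tokens),
        # so a caller would overwrite, not append to, the displayed answer.
        for text_so_far in streamer:
            if cancel_event.is_set():  # cooperative stop via the global Event
                break
            yield text_so_far

        # Note: generate() keeps running to completion unless a stopping
        # criterion also checks the event; join() waits for it to finish.
        gen_thread.join()

Compared with transformers' TextIteratorStreamer, the custom class re-decodes the accumulated token list on every put(), which avoids word-boundary artifacts from decoding tokens one at a time, at the cost of emitting cumulative text rather than deltas.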