beyoru committed
Commit 02a8fce · verified · 1 Parent(s): e7db5e3

Update app.py

Files changed (1):
  1. app.py +64 -51

app.py CHANGED
@@ -1,36 +1,37 @@
import gradio as gr
- from huggingface_hub import InferenceClient
- import string
- import numpy as np
from transformers import AutoTokenizer
import onnxruntime as ort
+ import numpy as np
+ import string
+ from huggingface_hub import InferenceClient
import os

- # Initialize client and models
client = InferenceClient(api_key=os.environ.get('HF_TOKEN'))

- # Constants for EOU calculation
- PUNCS = string.punctuation.replace("'", "")
- MAX_HISTORY = 4
- MAX_HISTORY_TOKENS = 1024
- EOU_THRESHOLD = 0.5
-
- # Initialize tokenizer and ONNX session
+ # Model and ONNX setup
HG_MODEL = "livekit/turn-detector"
ONNX_FILENAME = "model_quantized.onnx"
+ PUNCS = string.punctuation.replace("'", "")
+ MAX_HISTORY = 4  # Adjusted to use the last 4 messages
+ MAX_HISTORY_TOKENS = 512
+ EOU_THRESHOLD = 0.5  # Updated threshold to match original
+
+ # Initialize ONNX model
tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])

- # Helper functions for EOU
+ # Softmax function
def softmax(logits):
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / np.sum(exp_logits)

+ # Normalize text
def normalize_text(text):
    def strip_puncs(text):
        return text.translate(str.maketrans("", "", PUNCS))
    return " ".join(strip_puncs(text).lower().split())

+ # Format chat context
def format_chat_ctx(chat_ctx):
    new_chat_ctx = []
    for msg in chat_ctx:
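
Note on the softmax helper kept unchanged above: subtracting np.max(logits) before exponentiating is the standard numerical-stability trick; the shift cancels in the ratio, so the probabilities are unchanged while np.exp can no longer overflow. A quick standalone check (the logit values here are made up):

import numpy as np

def softmax(logits):
    # Subtracting the max keeps np.exp from overflowing on large logits;
    # the result is identical because the shared factor cancels.
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / np.sum(exp_logits)

probs = softmax(np.array([1000.0, 999.0, 998.0]))  # naive np.exp(1000.0) would overflow
print(probs)        # ~ [0.665, 0.245, 0.090]
print(probs.sum())  # 1.0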
@@ -39,14 +40,19 @@ def format_chat_ctx(chat_ctx):
        if content:
            msg["content"] = content
            new_chat_ctx.append(msg)
+
+     # Tokenize with chat template
    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )
+
+     # Remove EOU token from the current utterance
    ix = convo_text.rfind("<|im_end|>")
    return convo_text[:ix]

+ # Calculate EOU probability
def calculate_eou(chat_ctx, session):
-     formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])
+     formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])  # Use the last 4 messages
    inputs = tokenizer(
        formatted_text,
        return_tensors="np",
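
The hunk cuts off before the body of calculate_eou (old lines 53-60 fall outside the diff). From the visible pieces — NumPy tensors in, an ONNX session, and return probs[eou_token_id] later on — the elided middle plausibly looks like the sketch below; the truncation arguments, the "input_ids" feed name, the logits shape, and the <|im_end|> token lookup are assumptions, not code from this commit:

def calculate_eou(chat_ctx, session):
    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])
    inputs = tokenizer(
        formatted_text,
        return_tensors="np",
        truncation=True,                # assumed: clip overly long histories
        max_length=MAX_HISTORY_TOKENS,  # assumed: ties in the module constant
    )
    # Run the quantized ONNX model; take the logits at the last position
    # (assumes output shape [batch, seq, vocab]).
    outputs = session.run(None, {"input_ids": inputs["input_ids"].astype("int64")})
    logits = outputs[0][0, -1, :]
    probs = softmax(logits)
    # Probability that the next token is <|im_end|>, i.e. the user's turn is over.
    eou_token_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[-1]
    return probs[eou_token_id]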
@@ -61,35 +67,42 @@ def calculate_eou(chat_ctx, session):
    return probs[eou_token_id]


- messages = []
-
-
- def chatbot(user_input):
-     global messages
-
-     # Exit condition
-     if user_input.lower() == "exit":
-         messages = []  # Reset conversation history
-         return "Chat ended. Refresh the page to start again."
-
-     # Add user message to conversation history
-     messages.append({"role": "user", "content": user_input})
-
-     # Calculate EOU to determine if user has finished typing
+ # Respond function
+ def respond(
+     message,
+     history: list[tuple[str, str]],
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     # Keep the last 4 conversation pairs (user-assistant)
+     messages = [{"role": "system", "content": os.environ.get("CHARACTER_DESC")}]
+
+     for val in history[-20:]:
+         if val[0]:
+             messages.append({"role": "user", "content": val[0]})
+         if val[1]:
+             messages.append({"role": "assistant", "content": val[1]})
+
+     # Add the new user message to the context
+     messages.append({"role": "user", "content": message})
+
+     # Calculate EOU probability
    eou_prob = calculate_eou(messages, onnx_session)
+     print(f"EOU Probability: {eou_prob}")  # Debug output
+
+     # If EOU is below the threshold, ask for more input
    if eou_prob < EOU_THRESHOLD:
-         yield "[I'm waiting for you to complete the sentence...]"
+         yield "[Waiting for user to continue input...]"
        return

-     # Stream the chatbot's response
    stream = client.chat.completions.create(
-         model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+         model=os.environ.get('MODEL_ID'),
        messages=messages,
-         temperature=0.6,
-         max_tokens=2200,
-         top_p=0.95,
+         temperature=0.6,
+         max_tokens=2048,
+         top_p=0.9,
        stream=True
    )

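Two things worth flagging in the new respond: despite the "last 4 conversation pairs" comment, the loop slices history[-20:] (up to 20 pairs), and only calculate_eou narrows to the last MAX_HISTORY messages; likewise, the max_tokens/temperature/top_p parameters it accepts are shadowed by the hard-coded values in the create() call. A hypothetical driver (assuming app.py's definitions are loaded) showing the generator contract — each yield is the full response so far, not a delta:

history = [("Hi", "Hello! How can I help?")]
for partial in respond("Tell me about turn detection", history,
                       max_tokens=2048, temperature=0.6, top_p=0.9):
    print(partial)
# If the turn detector scores the message below EOU_THRESHOLD, the only
# yield is "[Waiting for user to continue input...]".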
@@ -97,23 +110,23 @@ def chatbot(user_input):
    for chunk in stream:
        bot_response += chunk.choices[0].delta.content
        yield bot_response

-     # Add final bot response to conversation history
-     messages.append({"role": "assistant", "content": bot_response})
-
- # Create Gradio interface
- with gr.Blocks(theme='darkdefault') as demo:
-     gr.Markdown("""# Chat with DeepSeek""")
-
-     with gr.Row():
-         with gr.Column():
-             user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
-             submit_button = gr.Button("Send")
-         with gr.Column():
-             chat_output = gr.Textbox(label="Chatbot Response", interactive=False)
-
-     # Define interactions
-     submit_button.click(chatbot, inputs=[user_input], outputs=[chat_output])
-
- # Launch the app
- demo.launch()
+
+ # Gradio interface
+ demo = gr.ChatInterface(
+     respond,
+     # additional_inputs=[
+     #     # Commented out to disable user modification of the system message
+     #     # gr.Textbox(value="You are an assistant.", label="System message"),
+     #     gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Max new tokens"),
+     #     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+     #     gr.Slider(
+     #         minimum=0.1,
+     #         maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
+     #     ),
+     # ],
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
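
One caveat with the new wiring: gr.ChatInterface passes only (message, history) to its fn when additional_inputs is left commented out, so respond's three extra required parameters would raise a TypeError on the first message. Giving them defaults keeps both wirings working; a minimal runnable sketch with a stub respond (the real one streams from the InferenceClient):

import gradio as gr

def respond(message, history, max_tokens=2048, temperature=0.6, top_p=0.9):
    # Defaults matter: with additional_inputs commented out, ChatInterface
    # calls respond(message, history) and nothing else.
    yield f"echo: {message}"

demo = gr.ChatInterface(respond)

if __name__ == "__main__":
    demo.launch()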