beyoru committed on
Commit 1e88a5e · verified · Parent(s): 8eb4c94

Update app.py

Files changed (1)
  1. app.py +65 -61
app.py CHANGED
@@ -1,34 +1,43 @@
  import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
  import numpy as np

- # Load tokenizer and model for EOU detection
  tokenizer = AutoTokenizer.from_pretrained("livekit/turn-detector")
  model = AutoModelForCausalLM.from_pretrained("livekit/turn-detector")

- # Define function to calculate softmax
- def _softmax(logits: np.ndarray) -> np.ndarray:
-     exp_logits = np.exp(logits - np.max(logits))
-     return exp_logits / np.sum(exp_logits)
-
- # Define the EOU probability calculation
- def get_eou_probability(chat_ctx: list) -> float:
-     """Calculate the probability of End of Utterance (EOU)"""
-     text = " ".join([msg["content"] for msg in chat_ctx])
-     inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
-
-     # Run the model and get the logits
      with torch.no_grad():
          outputs = model(**inputs)
-     logits = outputs.logits[0, -1, :]  # Get logits of the last token
-     probs = _softmax(logits.numpy())  # Convert logits to probabilities
-
-     # Assuming <|im_end|> token corresponds to EOU, get the probability of that token
-     eou_token_id = tokenizer.encode("<|im_end|>")[-1]
-     return probs[eou_token_id]

- # Define the main response function for Gradio
  def respond(
      message,
      history: list[tuple[str, str]],
@@ -36,47 +45,46 @@ def respond(
      max_tokens,
      temperature,
      top_p,
-     eou_threshold: float = 0.2  # Probability threshold to stop or transition the conversation
  ):
-     # Keep only the last 4 user inputs and add the current user input
-     user_history = [msg[0] for msg in history if msg[0]]  # Extract user inputs from history
-     user_history = user_history[-4:]  # Keep the last 4 user inputs
-     user_history.append(message)  # Add the current message
-
-     # Check if the EOU probability is high for the combined history (previous 4 + current input)
-     chat_ctx = [{"role": "user", "content": msg} for msg in user_history]
-     eou_probability = get_eou_probability(chat_ctx)
      print(eou_probability)
-     # If the EOU probability is higher than the threshold, wait for the user to complete their sentence
-     if eou_probability > eou_threshold:
-         return f"EOU probability is high: {eou_probability:.2f}. Please complete your sentence."
-
-     # Otherwise, generate the model's response
-     inputs = tokenizer(system_message + "\n" + message, return_tensors="pt", max_length=max_tokens, truncation=True)
-
-     # Set attention_mask to avoid issues with padding and make sure the model uses the correct pad_token_id
-     attention_mask = inputs['attention_mask'] if 'attention_mask' in inputs else None
-     pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
-
-     response = ""
-     generated_output = model.generate(
-         inputs['input_ids'],
-         attention_mask=attention_mask,
-         max_length=max_tokens,
-         do_sample=True,  # Enable sampling
-         temperature=temperature,
-         top_p=top_p,
-         pad_token_id=pad_token_id
-     )
-     response = tokenizer.decode(generated_output[0], skip_special_tokens=True)
-
-     return response
-
- # Gradio interface setup
  demo = gr.ChatInterface(
      respond,
      additional_inputs=[
-         gr.Textbox(value="You are a assistant call Mei", label="System message"),
          gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
          gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
          gr.Slider(
@@ -87,14 +95,10 @@ demo = gr.ChatInterface(
              label="Top-p (nucleus sampling)",
          ),
          gr.Slider(
-             minimum=0.0,
-             maximum=1.0,
-             value=0.9,
-             step=0.01,
-             label="EOU Probability Threshold"
-         ),
      ],
  )

- # Launch Gradio with public link sharing
- demo.launch(share=True)
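Aside: the removed `_softmax` is the numerically stable formulation, subtracting the running maximum before exponentiating so large logits cannot overflow. A minimal self-contained check of that property; the input values are illustrative only:

    import numpy as np

    def _softmax(logits: np.ndarray) -> np.ndarray:
        # Subtracting the max keeps np.exp in range even for large logits
        exp_logits = np.exp(logits - np.max(logits))
        return exp_logits / np.sum(exp_logits)

    # A naive softmax would overflow on these values; the stable form is fine
    probs = _softmax(np.array([1000.0, 1001.0, 1002.0]))
    assert abs(probs.sum() - 1.0) < 1e-9  # still a valid probability distribution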
 
  import gradio as gr
+ from huggingface_hub import InferenceClient
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
  import numpy as np

+ # Load the Inference Client for the response model
+ client = InferenceClient("Qwen/Qwen2.5-3B-Instruct")
+
+ # Load tokenizer and model for EOU detection
  tokenizer = AutoTokenizer.from_pretrained("livekit/turn-detector")
  model = AutoModelForCausalLM.from_pretrained("livekit/turn-detector")

+ # Function to compute EOU probability
+ def compute_eou_probability(chat_ctx: list[dict[str, str]], max_tokens: int = 512) -> float:
+     # Prepend a system turn to the chat context
+     conversation = [{"role": "system", "content": "Assistant ready to help."}] + chat_ctx
+
+     # Render the conversation to text before tokenizing: the tokenizer expects
+     # strings, not message dicts (assumes the model ships a chat template)
+     text = tokenizer.apply_chat_template(conversation, tokenize=False)
+     inputs = tokenizer(
+         text, padding=True, truncation=True, max_length=max_tokens, return_tensors="pt"
+     )
+
+     # Get model logits
      with torch.no_grad():
          outputs = model(**inputs)
+
+     # Get the logits for the last token in the sequence
+     logits = outputs.logits[0, -1, :]
+
+     # Apply softmax to get probabilities
+     probabilities = torch.nn.functional.softmax(logits, dim=-1)
+
+     # Get the EOU token index (typically the "<|im_end|>" token in this model)
+     eou_token_id = tokenizer.encode("<|im_end|>")[0]
+     eou_probability = probabilities[eou_token_id].item()
+
+     return eou_probability

+ # Respond function with EOU checking logic
  def respond(
      message,
      history: list[tuple[str, str]],
      system_message,
      max_tokens,
      temperature,
      top_p,
+     eou_threshold: float = 0.2,  # Default EOU threshold
  ):
+     messages = [{"role": "system", "content": system_message}]
+
+     for val in history:
+         if val[0]:
+             messages.append({"role": "user", "content": val[0]})
+         if val[1]:
+             messages.append({"role": "assistant", "content": val[1]})
+
+     # Compute EOU probability before responding
+     eou_probability = compute_eou_probability(messages, max_tokens=max_tokens)
      print(eou_probability)
+     # Only respond if the EOU probability reaches the threshold
+     if eou_probability >= eou_threshold:
+         # Append the current user message, then stream the assistant reply
+         messages.append({"role": "user", "content": message})
+
+         response = ""
+
+         for chunk in client.chat_completion(
+             messages,
+             max_tokens=max_tokens,
+             stream=True,
+             temperature=temperature,
+             top_p=top_p,
+         ):
+             token = chunk.choices[0].delta.content
+             if token:  # delta.content can be None on the final stream chunk
+                 response += token
+                 yield response
+     else:
+         # Let the user continue typing if the EOU probability is low
+         yield "Waiting for user to finish... Please continue."
+         print("Waiting for user to finish... Please continue.")

+ # Gradio UI
  demo = gr.ChatInterface(
      respond,
      additional_inputs=[
+         gr.Textbox(value="You are a virtual assistant", label="System message"),
          gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
          gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
          gr.Slider(

              label="Top-p (nucleus sampling)",
          ),
          gr.Slider(
+             minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="EOU Threshold"
+         ),  # EOU threshold slider
      ],
  )

+ if __name__ == "__main__":
+     demo.launch()
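For reference, a quick way to sanity-check `compute_eou_probability` once the app's model is loaded. The prompts are illustrative, and the high/low expectations are assumptions about how the turn detector scores complete versus truncated utterances, not guaranteed values:

    # Hypothetical smoke test (run in the same process as app.py)
    ctx_complete = [{"role": "user", "content": "What is the capital of France?"}]
    ctx_cut_off = [{"role": "user", "content": "What is the capital of"}]

    print(compute_eou_probability(ctx_complete))  # expected: relatively high (turn finished)
    print(compute_eou_probability(ctx_cut_off))   # expected: relatively low (user mid-sentence)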