import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load Inference Client for the response model
client = InferenceClient("Qwen/Qwen2.5-3B-Instruct")
# Load tokenizer and model for the EOU detection
tokenizer = AutoTokenizer.from_pretrained("livekit/turn-detector")
model = AutoModelForCausalLM.from_pretrained("livekit/turn-detector")
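
# Note: livekit/turn-detector is trained to flag when the current user turn
# is complete; the probability it assigns to its end-of-utterance token is
# used here as a "user is done talking" signal.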


# Function to compute EOU (end-of-utterance) probability
def compute_eou_probability(chat_ctx: list[dict[str, str]], max_tokens: int = 512) -> float:
    # Collect only the 'content' strings from the chat context
    conversation = ["Assistant ready to help."]  # seed line standing in for a system prompt
    for msg in chat_ctx:
        content = msg.get("content", "")
        if content:
            conversation.append(content)
    # Join the turns into one string: tokenizing the list would create a
    # batch, and indexing [0, -1] would then score only the first turn
    text = "\n".join(conversation)
    inputs = tokenizer(
        text, truncation=True, max_length=max_tokens, return_tensors="pt"
    )
    # Run the model without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    # Logits for the last token in the sequence
    logits = outputs.logits[0, -1, :]
    # Softmax over the vocabulary to get next-token probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # Probability mass on the end-of-utterance token ("<|im_end|>" for this model)
    eou_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    return probabilities[eou_token_id].item()
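

# Minimal sketch of how the score is meant to behave (sample texts are
# illustrative, not from the model card): a trailing fragment should score
# lower than a finished question, which is what the threshold gate in
# respond() relies on.
#
#   partial = [{"role": "user", "content": "I was wondering if"}]
#   complete = [{"role": "user", "content": "What time does the store close?"}]
#   compute_eou_probability(partial)   # expected: low -> keep waiting
#   compute_eou_probability(complete)  # expected: high -> safe to respond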


# Respond function with EOU gating logic
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    eou_threshold: float = 0.2,  # default; the UI slider below overrides this
):
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # Include the incoming message so the EOU check scores the user's
    # current (possibly unfinished) turn, not just past history
    messages.append({"role": "user", "content": message})
    # Compute EOU probability before responding
    eou_probability = compute_eou_probability(messages, max_tokens=max_tokens)
    print(f"EOU probability: {eou_probability:.3f}")
    # Only respond once the EOU probability clears the threshold
    if eou_probability >= eou_threshold:
        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:  # streamed deltas can be None/empty
                response += token
                yield response
    else:
        # Ask the user to keep typing while the EOU probability is low
        yield "Waiting for user to finish... Please continue."


# Gradio UI
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a virtual assistant", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(
            minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="EOU Threshold"
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()