import torch
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM

# Inference client for the response model
client = InferenceClient("Qwen/Qwen2.5-3B-Instruct")

# Tokenizer and model for end-of-utterance (EOU) detection
tokenizer = AutoTokenizer.from_pretrained("livekit/turn-detector")
model = AutoModelForCausalLM.from_pretrained("livekit/turn-detector")
model.eval()


# Compute the probability that the user has finished their utterance
def compute_eou_probability(chat_ctx: list[dict[str, str]], max_tokens: int = 512) -> float:
    # Flatten the chat context into one string so the model sees a single
    # continuous sequence. (Tokenizing a list of strings would create a
    # batch, and the last-token logits would come from the wrong sequence.)
    conversation = ["Assistant ready to help."]  # Seed with a system line
    for msg in chat_ctx:
        content = msg.get("content", "")
        if content:
            conversation.append(content)
    text = "\n".join(conversation)

    inputs = tokenizer(
        text,
        truncation=True,
        max_length=max_tokens,
        return_tensors="pt",
    )

    # Forward pass without gradients
    with torch.no_grad():
        outputs = model(**inputs)

    # Next-token logits at the end of the sequence
    logits = outputs.logits[0, -1, :]
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    # Probability mass assigned to the end-of-turn token "<|im_end|>"
    eou_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    return probabilities[eou_token_id].item()


# Respond only once the EOU detector judges the user's turn complete
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    eou_threshold: float = 0.2,  # Fallback threshold; the UI slider overrides it
):
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Score the conversation including the new user message, so the
    # detector actually sees what the user just typed
    eou_probability = compute_eou_probability(
        messages + [{"role": "user", "content": message}], max_tokens=max_tokens
    )
    print(f"EOU probability: {eou_probability:.3f}")

    # Only respond if the EOU probability clears the threshold
    if eou_probability >= eou_threshold:
        messages.append({"role": "user", "content": message})

        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content or ""  # delta may be empty
            response += token
            yield response
    else:
        # The utterance looks unfinished; prompt the user to keep typing
        yield "Waiting for user to finish... Please continue."


# Gradio UI
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a virtual assistant", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(
            minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="EOU Threshold"
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()
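# --- Sketch: probing the EOU detector outside the UI ---
# A minimal, hypothetical check of compute_eou_probability; the sample
# messages below are made up, and the detector model above must already
# be loaded. Run it instead of demo.launch() to eyeball the scores:
#
#   ctx_unfinished = [{"role": "user", "content": "Can you tell me the"}]
#   ctx_finished = [{"role": "user", "content": "What's the weather in Hanoi today?"}]
#   print(compute_eou_probability(ctx_unfinished))  # expected: low probability
#   print(compute_eou_probability(ctx_finished))    # expected: higher probability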