import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
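
# Demo: a chat app that gates the assistant's reply on an end-of-utterance
# (EOU) check. Before answering, the livekit/turn-detector model scores how
# likely it is that the user has finished their turn (the probability that
# "<|im_end|>" is the next token); only above a configurable threshold does
# Qwen/Qwen2.5-3B-Instruct stream a reply through the Hugging Face Inference API.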

# Load Inference Client for the response model
client = InferenceClient("Qwen/Qwen2.5-3B-Instruct")

# Load tokenizer and model for end-of-utterance (EOU) detection
tokenizer = AutoTokenizer.from_pretrained("livekit/turn-detector")
tokenizer.truncation_side = "left"  # keep the end of the conversation when truncating
model = AutoModelForCausalLM.from_pretrained("livekit/turn-detector")
model.eval()  # inference only; disables dropout

# Compute the probability that the user's utterance is complete (end of utterance)
def compute_eou_probability(chat_ctx: list[dict[str, str]], max_tokens: int = 512) -> float:
    # Flatten the chat context into a single text sequence (a simplification:
    # newline-joined message contents rather than the model's chat template)
    conversation = ["Assistant ready to help."]  # seed system-style line
    for msg in chat_ctx:
        content = msg.get("content", "")
        if content:
            conversation.append(content)
    text = "\n".join(conversation)

    # Tokenize as one sequence so the final position really is the end of the
    # conversation (batched tokenization with padding would leave
    # logits[0, -1, :] pointing at a pad token)
    inputs = tokenizer(
        text, truncation=True, max_length=max_tokens, return_tensors="pt"
    )

    # Run the model without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)

    # Logits for the next-token prediction at the last position
    logits = outputs.logits[0, -1, :]

    # Softmax over the vocabulary to get a probability distribution
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    # Probability that the next token is the end-of-turn marker "<|im_end|>"
    eou_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    eou_probability = probabilities[eou_token_id].item()

    return eou_probability
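
# Illustrative sanity check (hypothetical values, not from the model card):
# a dangling fragment such as "So what I wanted to ask is" should score a low
# EOU probability, while a complete utterance should score higher, e.g.
#   compute_eou_probability([{"role": "user", "content": "What time is it?"}])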

# Respond function: only answers once the EOU model judges the user's turn complete
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    eou_threshold: float = 0.2,  # fallback default; the UI slider overrides it
):
    messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    # Include the current user message, then estimate how likely it is that
    # the user has finished their turn
    messages.append({"role": "user", "content": message})
    eou_probability = compute_eou_probability(messages, max_tokens=max_tokens)
    print(f"EOU probability: {eou_probability:.4f}")

    # Only respond if the EOU probability exceeds the threshold
    if eou_probability >= eou_threshold:
        response = ""

        # Stream the assistant's reply token by token
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:  # some stream chunks carry no content
                response += token
                yield response
    else:
        # Utterance looks unfinished: prompt the user to keep typing
        yield "Waiting for user to finish... Please continue."

# Gradio UI: each additional_inputs entry is passed positionally to respond()
# after (message, history)
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a virtual assistant", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(
            minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="EOU Threshold"
        ),  # passed to respond() as eou_threshold, overriding the 0.2 default
    ],
)

if __name__ == "__main__":
    demo.launch()