import gradio as gr
import transformers
import torch

# Required dependencies (install once); accelerate is needed for device_map="auto"
# pip install gradio transformers torch accelerate tiktoken sentencepiece

def initialize_pipeline():
    model_id = "joermd/speedy-llama2"
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        use_fast=False  # Use slow tokenizer to avoid tiktoken issues
    )
    
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    
    # The model is already dispatched via device_map="auto" above, so the
    # pipeline must not be given its own device/device_map.
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer
    )
    
    return pipeline, tokenizer
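
# If GPU memory is tight, the model could instead be loaded in 4-bit via
# bitsandbytes (extra `pip install bitsandbytes`; sketch only, not used here):
#   quant_config = transformers.BitsAndBytesConfig(load_in_4bit=True)
#   model = transformers.AutoModelForCausalLM.from_pretrained(
#       model_id, quantization_config=quant_config, device_map="auto")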

# Initialize pipeline and tokenizer
pipeline, tokenizer = initialize_pipeline()

def format_chat_prompt(history, system_message):
    """Format the chat history into the list-of-messages format the model expects"""
    formatted_messages = []
    if system_message:
        formatted_messages.append({"role": "system", "content": system_message})

    for user_msg, assistant_msg in history:
        if user_msg:
            formatted_messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            formatted_messages.append({"role": "assistant", "content": assistant_msg})

    return formatted_messages
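
# Example (hypothetical values): with history = [("Hi", "Hello!")] and
# system_message = "Be concise", format_chat_prompt returns:
#   [{"role": "system", "content": "Be concise"},
#    {"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"}]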

def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Generate response using the pipeline"""
    messages = format_chat_prompt(history, system_message)
    messages.append({"role": "user", "content": message})
    
    # Define terminators: the tokenizer's EOS token, plus the Llama-3-style
    # <|eot_id|> token if this model's vocabulary happens to include it
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>") if "<|eot_id|>" in tokenizer.get_vocab() else None
    ]
    terminators = [t for t in terminators if t is not None]
    
    outputs = pipeline(
        messages,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=terminators,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
    )
    
    # Extract the generated response. With chat-style (list-of-messages) input,
    # "generated_text" is the full conversation; the reply is the last message.
    try:
        response = outputs[0]["generated_text"]
        if isinstance(response, list) and len(response) > 0 and isinstance(response[-1], dict):
            response = response[-1].get("content", "")
    except (IndexError, KeyError, AttributeError):
        response = "I apologize, but I couldn't generate a proper response."
    
    yield response
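
# A token-streaming variant (sketch): transformers' TextIteratorStreamer yields
# text as it is generated, so the UI can update incrementally. This assumes the
# tokenizer ships a chat template; it is illustrative and not wired into the demo.
def respond_streaming(message, history, system_message, max_tokens, temperature, top_p):
    from threading import Thread
    messages = format_chat_prompt(history, system_message)
    messages.append({"role": "user", "content": message})
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(pipeline.model.device)
    streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=max_tokens,
                             temperature=temperature, top_p=top_p, do_sample=True)
    # Run generation in a background thread while the streamer is consumed here
    Thread(target=pipeline.model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial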

# Create the Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="Kamu adalah seorang asisten yang baik",  # Indonesian: "You are a helpful assistant"
            label="System message"
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ],
    title="Chat Assistant",
    description="A conversational AI assistant powered by Llama-2"
)

if __name__ == "__main__":
    demo.launch()
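
# Deployment note (optional sketch): on a shared host, queuing serializes
# concurrent generator responses, and launch flags control binding/sharing:
#   demo.queue(max_size=16)
#   demo.launch(server_name="0.0.0.0", share=False)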