import gradio as gr
import transformers
import torch
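
# Gradio chat app: a ChatInterface front end over a transformers text-generation
# pipeline serving the joermd/speedy-llama2 model.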


def initialize_pipeline():
    """Load the speedy-llama2 tokenizer and model and build a text-generation pipeline."""
    model_id = "joermd/speedy-llama2"
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        use_fast=False,
    )

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    # The model is already dispatched across devices by device_map="auto" above,
    # so the pipeline itself is built without a device argument.
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )

    return pipeline, tokenizer


# Load the model and tokenizer once at import time so every request reuses them.
pipeline, tokenizer = initialize_pipeline()


def format_chat_prompt(messages, system_message):
    """Format the chat history into role/content message dicts the model can understand."""
    formatted_messages = []
    if system_message:
        formatted_messages.append({"role": "system", "content": system_message})

    for msg in messages:
        if msg[0]:
            formatted_messages.append({"role": "user", "content": msg[0]})
        if msg[1]:
            formatted_messages.append({"role": "assistant", "content": msg[1]})

    return formatted_messages


def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Generate a response to the latest user message using the pipeline."""
    messages = format_chat_prompt(history, system_message)
    messages.append({"role": "user", "content": message})

    # Stop at the tokenizer's EOS token, and also at the Llama-3-style <|eot_id|>
    # token when the tokenizer's vocabulary defines it.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>") if "<|eot_id|>" in tokenizer.get_vocab() else None,
    ]
    terminators = [t for t in terminators if t is not None]

    outputs = pipeline(
        messages,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=terminators,
        # `is not None` keeps a legitimate pad token id of 0 from being replaced.
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )

    try:
        # With chat-style (list of message dicts) input, generated_text holds the
        # full conversation and the assistant's new reply is its last entry.
        response = outputs[0]["generated_text"]
        if isinstance(response, list) and len(response) > 0 and isinstance(response[-1], dict):
            response = response[-1].get("content", "")
    except (IndexError, KeyError, AttributeError):
        response = "I apologize, but I couldn't generate a proper response."

    yield response
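

# Optional: a token-streaming variant of `respond` (a sketch, not part of the
# original app). It assumes the tokenizer ships a chat template usable with
# apply_chat_template, and it drives model.generate directly through a
# TextIteratorStreamer; gr.ChatInterface renders generator output incrementally,
# so this function could be passed in place of `respond` below.
def respond_streaming(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    from threading import Thread

    messages = format_chat_prompt(history, system_message)
    messages.append({"role": "user", "content": message})

    # Render the conversation with the model's chat template and tokenize it.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(pipeline.model.device)

    streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        streamer=streamer,
    )

    # Generate in a background thread and yield the partial reply as it grows.
    Thread(target=pipeline.model.generate, kwargs=generation_kwargs).start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial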


demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful assistant",
            label="System message",
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="Chat Assistant",
    description="A conversational AI assistant powered by Llama-2",
)


if __name__ == "__main__":
    demo.launch()
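

# To run locally (a note, assuming this file is saved as app.py): install
# gradio, transformers, torch, and accelerate (required for device_map="auto"),
# then start the UI with `python app.py`.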