import gradio as gr
import transformers
import torch
# First install required dependencies
# pip install tiktoken sentencepiece
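# Note (assumption): device_map="auto" below also relies on the accelerate package,
# so `pip install accelerate` is likely needed in addition to the two above.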
def initialize_pipeline():
    """Load the tokenizer and model, then wrap them in a text-generation pipeline."""
    model_id = "joermd/speedy-llama2"
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        use_fast=False  # Use the slow tokenizer to avoid tiktoken issues
    )
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map="auto"
    )
    return pipeline, tokenizer
# Initialize pipeline and tokenizer
pipeline, tokenizer = initialize_pipeline()
def format_chat_prompt(messages, system_message):
    """Format the chat messages into a prompt the model can understand"""
    formatted_messages = []
    if system_message:
        formatted_messages.append({"role": "system", "content": system_message})
    for msg in messages:
        if msg[0]:  # User message
            formatted_messages.append({"role": "user", "content": msg[0]})
        if msg[1]:  # Assistant message
            formatted_messages.append({"role": "assistant", "content": msg[1]})
    return formatted_messages
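# Illustrative example (not executed): with system_message="Be concise" and
# messages=[("Hi", "Hello!")], format_chat_prompt returns
# [{"role": "system", "content": "Be concise"},
#  {"role": "user", "content": "Hi"},
#  {"role": "assistant", "content": "Hello!"}]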
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Generate a response using the pipeline."""
    messages = format_chat_prompt(history, system_message)
    messages.append({"role": "user", "content": message})
    # Define terminators
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>") if "<|eot_id|>" in tokenizer.get_vocab() else None,
    ]
    terminators = [t for t in terminators if t is not None]
    outputs = pipeline(
        messages,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=terminators,
        # Use `is not None` so a valid pad_token_id of 0 is not silently replaced
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
    # Extract the generated response
    try:
        response = outputs[0]["generated_text"]
        if isinstance(response, list) and len(response) > 0 and isinstance(response[-1], dict):
            response = response[-1].get("content", "")
    except (IndexError, KeyError, AttributeError):
        response = "I apologize, but I couldn't generate a proper response."
    yield response
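# Note (assumption about the transformers chat pipeline): when the input is a list of
# role/content dicts, outputs[0]["generated_text"] is typically the full conversation
# with the new assistant turn appended, which is why respond() reads the last entry's
# "content" field above.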
# Create the Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="Kamu adalah seorang asisten yang baik",  # Indonesian: "You are a good assistant"
            label="System message"
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ],
    title="Chat Assistant",
    description="A conversational AI assistant powered by Llama-2"
)
if __name__ == "__main__":
    demo.launch()
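# To try this Space locally (a sketch, assuming gradio, transformers, torch, tiktoken,
# and sentencepiece are installed): run `python app.py`; Gradio serves the UI at
# http://127.0.0.1:7860 by default.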