# MongolianLlama / app.py
import os
import torch
import spaces
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, AutoTokenizer
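# Assumed runtime dependencies (not pinned here): torch, transformers, peft (for load_adapter),
# bitsandbytes (for the 4-bit base model), gradio, and the Hugging Face `spaces` package.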
# Model configuration
model_name = "Dorjzodovsuren/Mongolian_Llama3-v1.1"
max_seq_length = 1024  # not used directly below; kept for reference
dtype = torch.float16  # or torch.bfloat16 if preferred
load_in_4bit = False  # only relevant to the commented-out loader below (requires bitsandbytes)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# # Load model
# model = AutoModelForCausalLM.from_pretrained(
# model_name,
# device_map="auto",
# torch_dtype=dtype,
# load_in_4bit=load_in_4bit # This requires `bitsandbytes` to be installed
# )
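# Instead of loading a merged checkpoint as above, load the 4-bit quantized base model
# and attach the Mongolian_Llama3 PEFT adapter on top of it (requires peft and bitsandbytes).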
model_id = "unsloth/llama-3.1-8b-bnb-4bit"
peft_model_id = "Dorjzodovsuren/Mongolian_Llama3-v1.1"
model = AutoModelForCausalLM.from_pretrained(model_id)
model.load_adapter(peft_model_id)
EOS_TOKEN = tokenizer.eos_token  # end-of-sequence token (kept for reference; not appended to the prompt below)
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}"""
# Select the device based on GPU availability
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the model onto the device
model = model.to(device)
class StopOnTokens(StoppingCriteria):
    """Stop generation once the most recently generated token matches one of the hard-coded stop IDs."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [29, 0]  # hard-coded stop token IDs; adjust for the tokenizer in use
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
# The current implementation does not make use of the conversation history.
# It is highly recommended to experiment with various hyperparameters and compare output quality.
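# A minimal sketch (not used by predict below) of how prior turns could be folded into the
# Alpaca-style prompt if history support is added later. It assumes history arrives as
# (user, assistant) pairs, Gradio's default ChatInterface history format; the helper name
# build_prompt_with_history is illustrative only.
def build_prompt_with_history(message, history):
    past_turns = "\n".join(f"User: {user}\nAssistant: {assistant}" for user, assistant in history)
    return alpaca_prompt.format(message, past_turns, "")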
gpu_timeout = int(os.getenv("GPU_TIMEOUT", 60))  # ZeroGPU allocation window per call, in seconds
@spaces.GPU(duration=gpu_timeout)
def predict(message, history):
stop = StopOnTokens()
    # Fill the Alpaca template: the user message as the instruction, with empty input and response fields.
    messages = alpaca_prompt.format(
        message,
        "",
        "",
    )
model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    # Stream tokens back to the UI as they are generated, skipping the prompt and special tokens.
    streamer = TextIteratorStreamer(tokenizer, timeout=10, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,  # input_ids and attention_mask from the tokenizer
        streamer=streamer,
        max_new_tokens=1024,
        top_p=0.95,
        temperature=0.001,  # near-zero temperature keeps the output effectively deterministic
        repetition_penalty=1.1,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    # Run generation in a background thread so tokens can be yielded as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        if new_token != '<':  # skip a bare '<' token (guard against leftover special-token text)
            partial_message += new_token
            yield partial_message
# A few example prompts for the chat interface
examples = [
    ["What's the capital of France?"],
    ["What is the meaning of life?"],
    ["Хайр гэж юу вэ?"],  # Mongolian: "What is love?"
]
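# Launch the chat UI. This app assumes a Hugging Face ZeroGPU Space; when run locally,
# share=True instead asks Gradio to create a temporary public link.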
gr.ChatInterface(predict, examples=examples).launch(debug=True, share=True, show_api=True)