File size: 3,984 Bytes
2bc99a0 927b5de 2bc99a0 a3c3064 2bc99a0 0d5c130 63a0917 2bc99a0 a3c3064 2bc99a0 0d5c130 63a0917 2bc99a0 0d5c130 2bc99a0 a3c3064 2bc99a0 2407fc5 0d5c130 a3c3064 2bc99a0 0d5c130 2bc99a0 0d5c130 2bc99a0 a3c3064 2bc99a0 8de5029 1874bf4 2bc99a0 1874bf4 2bc99a0 2407fc5 2bc99a0 1874bf4 edc6972 927b5de 2407fc5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import os
import math
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import sentencepiece
# Text shown at the top of the Gradio page.
title = "Welcome to Tonic's 🐋🐳Orca-2-13B (in 8bit)!"
description = "You can use [🐋🐳microsoft/Orca-2-13b](https://huggingface.co/microsoft/Orca-2-13b) via API using Gradio by scrolling down and clicking Use 'Via API' or privately by [cloning this space on huggingface](https://huggingface.co/spaces/Tonic1/TonicsOrca2?duplicate=true) . [Join my active builders' server on discord](https://discord.gg/VqTxc76K3u). Let's build together! Big thanks to the HuggingFace Organisation for the Community Grant."
# NOTE(review): `device` is not used in this file -- weight placement is
# handled by device_map="auto" below. Kept in case other code imports it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_name = "microsoft/Orca-2-13b"
# use_fast=False requests the slow tokenizer implementation (presumably why
# the file imports sentencepiece).
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# Load the weights quantized to 8-bit. Passing `load_in_8bit=True` directly to
# from_pretrained is deprecated; the supported form is a BitsAndBytesConfig
# supplied through `quantization_config`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=transformers.BitsAndBytesConfig(load_in_8bit=True),
)
class OrcaChatBot:
    """Stateful chat wrapper that prompts a causal LM in ChatML format.

    Keeps a running list of ``(role, message)`` turns and rebuilds the full
    prompt (system message + history + the new user turn) on every call to
    :meth:`predict`.

    NOTE(review): the history is never truncated, so a long conversation will
    eventually exceed the model's context window. If one instance is shared
    between several users, their turns are interleaved in the same history.
    """

    def __init__(self, model, tokenizer, system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
        """Store the model/tokenizer pair and start with an empty history.

        Args:
            model: loaded causal LM; must expose ``generate`` and ``device``.
            tokenizer: matching tokenizer; must be callable and expose
                ``decode`` and ``eos_token_id``.
            system_message: text placed in the ChatML ``system`` turn.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.system_message = system_message
        # List of (role, message) tuples; role is "user" or "assistant".
        self.conversation_history = []

    def update_conversation_history(self, user_message, assistant_message):
        """Append one completed user/assistant exchange to the history."""
        self.conversation_history.append(("user", user_message))
        self.conversation_history.append(("assistant", assistant_message))

    def format_prompt(self, pending_user_message=None):
        """Render the system message and history as a ChatML prompt string.

        Args:
            pending_user_message: optional user turn that is not yet in the
                history (the message currently being answered). With the
                default ``None`` only the stored history is rendered.

        Returns:
            The prompt, ending with an open ``assistant`` turn so the model
            continues speaking as the assistant. Turns whose text is blank or
            whitespace-only are omitted.
        """
        turns = list(self.conversation_history)
        if pending_user_message is not None:
            turns.append(("user", pending_user_message))
        # The system prompt belongs under the ChatML "system" role, not
        # "assistant".
        parts = [f"<|im_start|>system\n{self.system_message}<|im_end|>\n"]
        parts.extend(
            f"<|im_start|>{role}\n{message}<|im_end|>\n"
            for role, message in turns
            if message.strip()
        )
        # The role name must follow the tag directly -- no space.
        parts.append("<|im_start|>assistant\n")
        return "".join(parts)

    def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
        """Generate the assistant's reply to ``user_message`` and record it.

        Args:
            user_message: the new user turn.
            temperature, top_p, repetition_penalty: sampling settings
                forwarded to ``model.generate``.
            max_new_tokens: cap on the number of generated tokens.

        Returns:
            Only the newly generated text (the prompt is not included),
            stripped of surrounding whitespace.
        """
        prompt = self.format_prompt(user_message)
        inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs["attention_mask"].to(self.model.device)
        output_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            # The tokenizer may lack a pad token; fall back to EOS explicitly.
            pad_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
        )
        # generate() returns prompt + continuation. Decode only the
        # continuation so the prompt is not echoed back and stored in history.
        new_tokens = output_ids[0][input_ids.shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        # Record the exchange only once generation has succeeded.
        self.update_conversation_history(user_message, response)
        return response
# A single module-level chatbot; every call to gradio_predict goes through it,
# so all callers share one conversation history.
Orca_bot = OrcaChatBot(model=model, tokenizer=tokenizer)


def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty):
    """Gradio callback: forward one UI submission to the shared chatbot.

    A non-empty ``system_message`` is not sent as a separate system turn.
    It is prepended to the user's text, separated by a newline.

    Returns:
        The chatbot's reply text.
    """
    if system_message:
        combined_message = f"{system_message}\n{user_message}"
    else:
        combined_message = user_message
    return Orca_bot.predict(
        combined_message,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
# Single-page Gradio UI. The order of `inputs` must match gradio_predict's
# positional parameters: message, system prompt, max tokens, temperature,
# top-p, repetition penalty.
iface = gr.Interface(
    fn=gradio_predict,
    title=title,
    description=description,
    inputs=[
        gr.Textbox(label="Your Message", type="text", lines=3),
        gr.Textbox(label="Introduce a Character Here or Set a Scene (system prompt)", type="text", lines=2),
        gr.Slider(label="Max new tokens", value=550, minimum=360, maximum=600, step=1),
        gr.Slider(label="Temperature", value=0.1, minimum=0.05, maximum=1.0, step=0.05),
        gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05),
        gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05),
    ],
    outputs="text",
    theme="ParityError/Anime",
)
# Cap the request queue at 5 pending jobs. queue() returns the interface
# itself, so launch() is chained directly on it.
iface.queue(max_size=5).launch()