File size: 3,520 Bytes
2bc99a0
 
 
 
927b5de
 
2bc99a0
b723b47
a3c3064
99c11c8
 
63a0917
3d2716e
2bc99a0
99c11c8
fc295cf
7f75950
a3c3064
99c11c8
 
2bc99a0
 
9bc49ef
0d5c130
9bc49ef
 
 
 
5ab0bbc
0d5c130
63a0917
9bc49ef
 
2bc99a0
 
a3c3064
2bc99a0
 
 
 
 
 
0d5c130
5ab0bbc
a3c3064
2bc99a0
 
a3c3064
2bc99a0
bc3a87b
 
5ab0bbc
 
bc3a87b
8de5029
1874bf4
2bc99a0
1874bf4
 
2bc99a0
 
 
2407fc5
2bc99a0
 
 
 
1874bf4
edc6972
 
927b5de
02f3e50
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import math
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import sentencepiece
from tokenization_xgen import XgenTokenizer

# Heading and blurb rendered at the top of the Gradio page.
title = "Welcome to 🙋🏻‍♂️Tonic's😈Xgen-8K Chat!"
description = "Interestingly there simply wasnt a public demo for Xgen, So I made one. You can use [Salesforce/xgen-7b-8k-inst](https://huggingface.co/Salesforce/xgen-7b-8k-inst) via API using Gradio by scrolling down and clicking Use 'Via API' or privately by [cloning this space on huggingface](https://huggingface.co/spaces/Tonic1/Xgen?duplicate=true) . [Join my active builders' server on discord](https://discord.gg/VqTxc76K3u). Let's build together!."

# Exported before the model weights are loaded further down.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'

# NOTE(review): `device` is not referenced again in the visible part of this
# file; placement of the model is delegated to device_map="auto" below.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# NOTE(review): the description above links to the "-inst" checkpoint while
# this loads the "-base" checkpoint -- confirm which one is intended.
model_name = "Salesforce/xgen-7b-8k-base"

# The tokenizer is read from the current directory, not from `model_name`.
tokenizer = XgenTokenizer.from_pretrained("./")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

class XgenChatBot:
    """Single-turn chat helper that pairs a causal LM with its tokenizer.

    The bot keeps one mutable system message, wraps each user message in a
    fixed prompt template, samples a continuation from the model and returns
    only the newly generated text.
    """

    DEFAULT_SYSTEM_MESSAGE = "You are Xgen, an AI language model created by Tonic-AI. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."

    def __init__(self, model, tokenizer, system_message=DEFAULT_SYSTEM_MESSAGE):
        """Store the model, the tokenizer and the initial system message.

        Args:
            model: Object exposing ``generate(...)`` and a ``device`` attribute.
            tokenizer: Callable returning a mapping with ``"input_ids"``, and
                exposing ``decode(...)``.
            system_message: Text placed in the first segment of every prompt.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.system_message = system_message
        # Remembered so a blank update can restore it instead of erasing it.
        self._default_system_message = system_message

    def set_system_message(self, new_system_message):
        """Replace the system message used for subsequent prompts.

        A ``None`` or whitespace-only value restores the message the bot was
        constructed with, so an empty UI field does not silently wipe the
        persona for every later request.
        """
        if new_system_message is None or not str(new_system_message).strip():
            self.system_message = self._default_system_message
        else:
            self.system_message = new_system_message

    def format_prompt(self, user_message):
        """Return the full prompt string for ``user_message``.

        NOTE(review): the template labels the system segment ``assistant``,
        puts a newline between ``<|im_start|>`` and ``user``, and opens the
        final turn without ``<|im_start|>``. It is left untouched here because
        changing it alters model output -- confirm against the model card.
        """
        prompt = f"<|im_start|>assistant\n{self.system_message}<|im_end|>\n<|im_start|>\nuser\n{user_message}<|im_end|>\nassistant\n"
        return prompt

    def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
        """Sample a reply to ``user_message`` and return it as text.

        Args:
            user_message: The user's input text.
            temperature: Sampling temperature forwarded to ``generate``.
            max_new_tokens: Upper bound on generated tokens; coerced to ``int``
                because UI sliders may deliver floats.
            top_p: Nucleus-sampling threshold forwarded to ``generate``.
            repetition_penalty: Repetition penalty forwarded to ``generate``.

        Returns:
            The decoded continuation only; the prompt is not echoed back.
        """
        prompt = self.format_prompt(user_message)
        inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
        input_ids = inputs["input_ids"].to(self.model.device)
        prompt_length = input_ids.shape[1]

        output_ids = self.model.generate(
            input_ids,
            # Bounds the continuation directly instead of prompt_len + budget.
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
        )

        # The returned sequence starts with the prompt tokens; drop them so
        # the caller sees the reply rather than the system and user text again.
        generated_ids = output_ids[0][prompt_length:]
        response = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        return response

def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty):
    """Gradio callback: apply the system prompt, then return the bot's output.

    The parameter order mirrors the order of the UI inputs. The system
    message is written onto the shared module-level bot before generating,
    so it stays in effect for later calls until it is set again.
    """
    Xgen_bot.set_system_message(system_message)
    # Keyword arguments make the mapping explicit: the UI order
    # (max_new_tokens first) differs from predict's parameter order.
    return Xgen_bot.predict(
        user_message,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    
# Single shared bot instance used by the gradio_predict callback.
Xgen_bot = XgenChatBot(model, tokenizer)

# Input widgets, listed in the same order as gradio_predict's parameters.
message_box = gr.Textbox(label="Your Message", type="text", lines=3)
system_box = gr.Textbox(label="Introduce a Character Here or Set a Scene (system prompt)", type="text", lines=2)
max_tokens_slider = gr.Slider(label="Max new tokens", value=550, minimum=360, maximum=600, step=1)
temperature_slider = gr.Slider(label="Temperature", value=0.1, minimum=0.05, maximum=1.0, step=0.05)
top_p_slider = gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05)
repetition_slider = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)

iface = gr.Interface(
    fn=gradio_predict,
    title=title,
    description=description,
    inputs=[
        message_box,
        system_box,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
        repetition_slider,
    ],
    outputs="text",
    theme="ParityError/Anime",
)

# Enable the request queue (max_size=5), then start serving.
queued_iface = iface.queue(max_size=5)
queued_iface.launch()