Spaces:
Runtime error
xlr8 committed · Commit 672d8c3
1 parent: 81ae02e

initial commit
Browse files
- app.py: +120 -0
- requirements.txt: +2 -0
app.py
ADDED
@@ -0,0 +1,120 @@
import os
import gradio as gr
from openai import OpenAI
import jinja2
from transformers import AutoTokenizer

# Initialize the OpenAI client
client = OpenAI(
    base_url="https://api.hyperbolic.xyz/v1",
    api_key=os.environ["HYPERBOLIC_API_KEY"],
)

# the tokenizer complains later after gradio forks without this setting.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# use unofficial copy of Llama to avoid access restrictions.
tokenizer = AutoTokenizer.from_pretrained("mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated")
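# NOTE: the 8B Instruct tokenizer loaded above is used only to count tokens when
# trimming chat history; all generation happens on the 405B model via the API below.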
# Initial prompt
initial_prompts = {
    "Default": ["405B", """A chat between a person and the Llama 3.1 405B base model.

"""],
}

# ChatML template
chatml_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"""
chat_template = """{% for message in messages %}{{'<' + message['role'] + '>: ' + message['content'] + '\n'}}{% endfor %}"""

def format_chat(messages, use_chatml=False):
    if use_chatml:
        template = jinja2.Template(chatml_template)
    else:
        template = jinja2.Template(chat_template)
    formatted = template.render(messages=messages)
    return formatted

def count_tokens(text):
    return len(tokenizer.encode(text))

def limit_history(initial_prompt, history, new_message, max_tokens):
    limited_history = []

    token_count = count_tokens(new_message) + count_tokens(initial_prompt)
    if token_count > max_tokens:
        raise ValueError("message too large for context window")

    for user_msg, assistant_msg in reversed(history):
        # TODO add ChatML wrapping here for better counting?
        user_tokens = count_tokens(user_msg)
        assistant_tokens = count_tokens(assistant_msg)
        if token_count + user_tokens + assistant_tokens > max_tokens:
            break
        token_count += user_tokens + assistant_tokens
        limited_history.insert(0, (user_msg, assistant_msg))
    return limited_history


def generate_response(message, history, initial_prompt, user_role, assistant_role, use_chatml):
    context_length = 8192
    response_length = 1000
    slop_length = 300  # slop for ChatML encoding etc -- TODO fix this

    # trim history based on token count
    history_tokens = context_length - response_length - slop_length
    limited_history = limit_history(initial_prompt, history, message, max_tokens=history_tokens)
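    # With the defaults above, 8192 - 1000 - 300 = 6892 tokens are left for the
    # initial prompt, the new message, and as many recent history turns as fit.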
    # Prepare the input
    chat_history = [{"role": user_role if i % 2 == 0 else assistant_role, "content": m}
                    for i, m in enumerate([item for sublist in limited_history for item in sublist] + [message])]
    formatted_input = format_chat(chat_history, use_chatml)

    if use_chatml:
        full_prompt = initial_prompt + "\n\n" + formatted_input + f"<|im_start|>{assistant_role}\n"
    else:
        full_prompt = initial_prompt + "\n\n" + formatted_input + f"<{assistant_role}>:"

    print(full_prompt)

    completion = client.completions.create(
        model="meta-llama/Meta-Llama-3.1-405B-FP8",
        prompt=full_prompt,
        temperature=0.7,
        frequency_penalty=0.1,
        max_tokens=response_length,
        stop=[f'<{user_role}>:', f'<{assistant_role}>:'] if not use_chatml else [f'<|im_end|>']
    )

    assistant_response = completion.choices[0].text.strip()
    return assistant_response

with gr.Blocks(theme=gr.themes.Soft()) as iface:
    with gr.Row():
        initial_prompt = gr.Textbox(
            value="A chat between a person and the Llama 3.1 405B base model.",
            label="Initial Prompt",
            lines=3
        )
        with gr.Column():
            user_role = gr.Textbox(value="User", label="User Role")
            assistant_role = gr.Textbox(value="405B", label="Assistant Role")
            use_chatml = gr.Checkbox(label="Use ChatML", value=True)

    chatbot = gr.ChatInterface(
        generate_response,
        title="Chat with 405B",
        additional_inputs=[initial_prompt, user_role, assistant_role, use_chatml],
        concurrency_limit=10,
        chatbot=gr.Chatbot(height=800)
    )

    gr.Markdown("""
This chat interface is powered by the Llama 3.1 405B base model, served by [Hyperbolic](https://hyperbolic.xyz), The Open Access AI Cloud.

Thank you to Hyperbolic for making this base model available!
""")


# Launch the interface
iface.launch(share=True, max_threads=40)
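As a quick reference, here is a minimal sketch of the two prompt formats that format_chat produces. It reuses the templates defined in app.py above (the ChatML template is shown without its generation-prompt branch), and the example messages are invented for illustration:

import jinja2

# Core message loop of the ChatML template from app.py (generation-prompt handling omitted).
chatml_loop = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"""
chat_template = """{% for message in messages %}{{'<' + message['role'] + '>: ' + message['content'] + '\n'}}{% endfor %}"""

# Hypothetical two-turn exchange using the default role names from the UI.
messages = [
    {"role": "User", "content": "Hello!"},
    {"role": "405B", "content": "Hi there."},
]

print(jinja2.Template(chatml_loop).render(messages=messages))
# <|im_start|>User
# Hello!<|im_end|>
# <|im_start|>405B
# Hi there.<|im_end|>

print(jinja2.Template(chat_template).render(messages=messages))
# <User>: Hello!
# <405B>: Hi there.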
requirements.txt
ADDED
@@ -0,0 +1,2 @@
gradio
openai
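app.py also imports jinja2 and transformers, neither of which is pinned here; the missing transformers package in particular could account for the Space's runtime error. For smoke-testing the Hyperbolic endpoint that app.py calls, a minimal sketch, assuming HYPERBOLIC_API_KEY is exported and using an illustrative prompt in the same <role>: format the app builds:

import os
from openai import OpenAI

client = OpenAI(
    base_url="https://api.hyperbolic.xyz/v1",
    api_key=os.environ["HYPERBOLIC_API_KEY"],
)

completion = client.completions.create(
    model="meta-llama/Meta-Llama-3.1-405B-FP8",
    prompt="A chat between a person and the Llama 3.1 405B base model.\n\n<User>: Hello!\n<405B>:",
    max_tokens=64,
    temperature=0.7,
    stop=["<User>:", "<405B>:"],
)
print(completion.choices[0].text.strip())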