xlr8 committed on
Commit
672d8c3
·
1 Parent(s): 81ae02e

initial commit

Files changed (2)
  1. app.py +120 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,120 @@
+ import os
+ import gradio as gr
+ from openai import OpenAI
+ import jinja2
+ from transformers import AutoTokenizer
+
+ # Initialize the OpenAI client
+ client = OpenAI(
+     base_url="https://api.hyperbolic.xyz/v1",
+     api_key=os.environ["HYPERBOLIC_API_KEY"],
+ )
+
+ # The tokenizer complains later, after gradio forks, without this setting.
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ # Use an unofficial copy of Llama to avoid access restrictions.
+ tokenizer = AutoTokenizer.from_pretrained("mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated")
+
+ # Initial prompts (currently unused)
+ initial_prompts = {
+     "Default": ["405B", """A chat between a person and the Llama 3.1 405B base model.
+
+ """],
+ }
+
+ # ChatML template
+ chatml_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"""
+ chat_template = """{% for message in messages %}{{'<' + message['role'] + '>: ' + message['content'] + '\n'}}{% endfor %}"""
+
+ def format_chat(messages, use_chatml=False):
+     if use_chatml:
+         template = jinja2.Template(chatml_template)
+     else:
+         template = jinja2.Template(chat_template)
+     formatted = template.render(messages=messages)
+     return formatted
+
+ def count_tokens(text):
+     return len(tokenizer.encode(text))
+
+ def limit_history(initial_prompt, history, new_message, max_tokens):
+     limited_history = []
+
+     token_count = count_tokens(new_message) + count_tokens(initial_prompt)
+     if token_count > max_tokens:
+         raise ValueError("message too large for context window")
+
+     for user_msg, assistant_msg in reversed(history):
+         # TODO add ChatML wrapping here for better counting?
+         user_tokens = count_tokens(user_msg)
+         assistant_tokens = count_tokens(assistant_msg)
+         if token_count + user_tokens + assistant_tokens > max_tokens:
+             break
+         token_count += user_tokens + assistant_tokens
+         limited_history.insert(0, (user_msg, assistant_msg))
+     return limited_history
+
+
+ def generate_response(message, history, initial_prompt, user_role, assistant_role, use_chatml):
+     context_length = 8192
+     response_length = 1000
+     slop_length = 300  # slop for ChatML encoding etc. TODO: fix this
+
+     # trim history based on token count
+     history_tokens = context_length - response_length - slop_length
+     limited_history = limit_history(initial_prompt, history, message, max_tokens=history_tokens)
+
+     # Prepare the input
+     chat_history = [{"role": user_role if i % 2 == 0 else assistant_role, "content": m}
+                     for i, m in enumerate([item for sublist in limited_history for item in sublist] + [message])]
+     formatted_input = format_chat(chat_history, use_chatml)
+
+     if use_chatml:
+         full_prompt = initial_prompt + "\n\n" + formatted_input + f"<|im_start|>{assistant_role}\n"
+     else:
+         full_prompt = initial_prompt + "\n\n" + formatted_input + f"<{assistant_role}>:"
+
+     print(full_prompt)
+
+     completion = client.completions.create(
+         model="meta-llama/Meta-Llama-3.1-405B-FP8",
+         prompt=full_prompt,
+         temperature=0.7,
+         frequency_penalty=0.1,
+         max_tokens=response_length,
+         stop=[f'<{user_role}>:', f'<{assistant_role}>:'] if not use_chatml else ['<|im_end|>']
+     )
+
+     assistant_response = completion.choices[0].text.strip()
+     return assistant_response
+
+ with gr.Blocks(theme=gr.themes.Soft()) as iface:
+     with gr.Row():
+         initial_prompt = gr.Textbox(
+             value="A chat between a person and the Llama 3.1 405B base model.",
+             label="Initial Prompt",
+             lines=3
+         )
+         with gr.Column():
+             user_role = gr.Textbox(value="User", label="User Role")
+             assistant_role = gr.Textbox(value="405B", label="Assistant Role")
+             use_chatml = gr.Checkbox(label="Use ChatML", value=True)
+
+
+     chatbot = gr.ChatInterface(
+         generate_response,
+         title="Chat with 405B",
+         additional_inputs=[initial_prompt, user_role, assistant_role, use_chatml],
+         concurrency_limit=10,
+         chatbot=gr.Chatbot(height=800)
+     )
+
+     gr.Markdown("""
+ This chat interface is powered by the Llama 3.1 405B base model, served by [Hyperbolic](https://hyperbolic.xyz), The Open Access AI Cloud.
+
+ Thank you to Hyperbolic for making this base model available!
+ """)
+
+
+ # Launch the interface
+ iface.launch(share=True, max_threads=40)
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ gradio
+ openai