import os

import gradio as gr
from openai import OpenAI
import jinja2
from transformers import AutoTokenizer

# Initialize the OpenAI client
client = OpenAI(
    base_url="https://api.hyperbolic.xyz/v1",
    api_key=os.environ["HYPERBOLIC_API_KEY"],
)

# The tokenizer complains later, after gradio forks, without this setting.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Use an unofficial copy of Llama to avoid access restrictions.
tokenizer = AutoTokenizer.from_pretrained("mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated")

# Initial prompt
initial_prompts = {
    "Default": ["405B", """A chat between a person and the Llama 3.1 405B base model. """],
}

# ChatML template
chatml_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"""

chat_template = """{% for message in messages %}{{'<' + message['role'] + '>: ' + message['content'] + '\n'}}{% endfor %}"""


def format_chat(messages, use_chatml=False):
    if use_chatml:
        template = jinja2.Template(chatml_template)
    else:
        template = jinja2.Template(chat_template)
    formatted = template.render(messages=messages)
    return formatted


def count_tokens(text):
    return len(tokenizer.encode(text))


def limit_history(initial_prompt, history, new_message, max_tokens):
    limited_history = []
    token_count = count_tokens(new_message) + count_tokens(initial_prompt)
    if token_count > max_tokens:
        raise ValueError("message too large for context window")
    for user_msg, assistant_msg in reversed(history):
        # TODO add ChatML wrapping here for better counting?
        user_tokens = count_tokens(user_msg)
        assistant_tokens = count_tokens(assistant_msg)
        if token_count + user_tokens + assistant_tokens > max_tokens:
            break
        token_count += user_tokens + assistant_tokens
        limited_history.insert(0, (user_msg, assistant_msg))
    return limited_history


def generate_response(message, history, initial_prompt, user_role, assistant_role, use_chatml):
    context_length = 8192
    response_length = 1000
    slop_length = 300  # slop for ChatML encoding etc. -- TODO fix this

    # Trim history based on token count
    history_tokens = context_length - response_length - slop_length
    limited_history = limit_history(initial_prompt, history, message, max_tokens=history_tokens)

    # Prepare the input
    chat_history = [
        {"role": user_role if i % 2 == 0 else assistant_role, "content": m}
        for i, m in enumerate([item for sublist in limited_history for item in sublist] + [message])
    ]
    formatted_input = format_chat(chat_history, use_chatml)

    if use_chatml:
        full_prompt = "<|im_start|>system\n" + initial_prompt + "<|im_end|>\n" + formatted_input + f"<|im_start|>{assistant_role}\n"
    else:
        full_prompt = initial_prompt + "\n\n" + formatted_input + f"<{assistant_role}>:"

    completion = client.completions.create(
        model="meta-llama/Meta-Llama-3.1-405B",
        prompt=full_prompt,
        temperature=0.7,
        frequency_penalty=0.1,
        max_tokens=response_length,
        stop=[f'<{user_role}>:', f'<{assistant_role}>:'] if not use_chatml else ['<|im_end|>']
    )

    assistant_response = completion.choices[0].text.strip()
    return assistant_response


with gr.Blocks(theme=gr.themes.Soft()) as iface:
    with gr.Row():
        initial_prompt = gr.Textbox(
            value="Please respond in whatever manner comes most naturally to you. You do not need to act as an assistant.",
            label="Initial Prompt",
            lines=3
        )
        with gr.Column():
            user_role = gr.Textbox(value="user", label="User Role")
            assistant_role = gr.Textbox(value="model", label="Assistant Role")
            use_chatml = gr.Checkbox(label="Use ChatML", value=True)

    chatbot = gr.ChatInterface(
        generate_response,
        title="Chat with 405B",
        additional_inputs=[initial_prompt, user_role, assistant_role, use_chatml],
        concurrency_limit=10,
        chatbot=gr.Chatbot(height=600)
    )

    gr.Markdown("""
    This chat interface is powered by the Llama 3.1 405B base model, served by
    [Hyperbolic](https://hyperbolic.xyz), The Open Access AI Cloud. Thank you to
    Hyperbolic for making this base model available!
    """)

# Launch the interface
iface.launch(share=True, max_threads=40)