import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
# Read the Hugging Face token from the environment (needed for gated repos)
HF_TOKEN = os.getenv("HF_TOKEN")
# Map repo IDs to short local names
model_names = {
    "mistralai/Mistral-7B-Instruct-v0.3": "mistral-7b-instruct",
    "BICORP/Lake-1-Advanced": "lake-1-advanced",
}
# Download full model snapshots from the Hugging Face Hub.
# (hf_hub_download fetches a single file and requires a filename;
# snapshot_download pulls the whole repo and returns its local path.)
def download_model(repo_id):
    return snapshot_download(repo_id=repo_id, token=HF_TOKEN)

# Load models and tokenizers
models = {}
tokenizers = {}
for name in model_names:
    model_path = download_model(name)
    models[name] = AutoModelForCausalLM.from_pretrained(model_path)
    tokenizers[name] = AutoTokenizer.from_pretrained(model_path)
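# Optional (assumption: a CUDA GPU and the accelerate package are available):
# two 7B-class models in float32 on CPU need tens of GB of RAM, so half
# precision with automatic device placement is the usual setup, e.g.:
#
#     models[name] = AutoModelForCausalLM.from_pretrained(
#         model_path, torch_dtype=torch.float16, device_map="auto")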
# Define presets for each model
presets = {
    "mistralai/Mistral-7B-Instruct-v0.3": {
        "Fast": {"max_tokens": 256, "temperature": 1.0, "top_p": 0.8},
        "Normal": {"max_tokens": 512, "temperature": 0.6, "top_p": 0.75},
        "Quality": {"max_tokens": 1024, "temperature": 0.45, "top_p": 0.60},
        "Unreal Performance": {"max_tokens": 1048, "temperature": 0.5, "top_p": 0.7},
    },
    "BICORP/Lake-1-Advanced": {
        "Fast": {"max_tokens": 800, "temperature": 1.0, "top_p": 0.9},
        "Normal": {"max_tokens": 4000, "temperature": 0.7, "top_p": 0.95},
        "Quality": {"max_tokens": 32000, "temperature": 0.5, "top_p": 0.90},
        "Unreal Performance": {"max_tokens": 128000, "temperature": 0.6, "top_p": 0.75},
    },
}
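# The larger "Quality"/"Unreal Performance" budgets above can exceed a model's
# context window. A minimal clamp sketch (assumption: the config exposes
# max_position_embeddings, as most causal LMs do); it is not wired into
# respond() below, but could cap the token budget before generation:
def clamp_max_tokens(requested, model):
    limit = getattr(model.config, "max_position_embeddings", requested)
    return min(requested, limit)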
# System messages for each model
system_messages = {
    "mistralai/Mistral-7B-Instruct-v0.3": "Your name is Lake 1 Base but mine is User",
    "BICORP/Lake-1-Advanced": "Your name is Lake 1 Advanced [Alpha] but mine is User or what I will type as my name",
}
# Model names and their pseudonyms
model_choices = [
    ("mistralai/Mistral-7B-Instruct-v0.3", "Lake 1 Base"),
    ("BICORP/Lake-1-Advanced", "Lake 1 Advanced [Alpha]"),
]
# The dropdown shows pseudonyms; map them back to repo IDs on submit
pseudonyms = [pseudo for _, pseudo in model_choices]
pseudonym_to_model = {pseudo: repo_id for repo_id, pseudo in model_choices}
def respond(message, history: list, model_name, preset_name):
    # Look up the model, tokenizer, and system message for this model
    model = models[model_name]
    tokenizer = tokenizers[model_name]
    system_message = system_messages[model_name]
    history = history or []

    # Build the prompt: system message, prior turns, then the new user turn
    transcript = "\n".join(f"{turn['role']}: {turn['content']}" for turn in history)
    input_text = f"{system_message}\n{transcript}\nuser: {message}\nassistant:"

    # Tokenize the input and move it to the model's device
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

    # Get the preset's sampling settings
    preset = presets[model_name][preset_name]

    # Generate a response; max_new_tokens counts only generated tokens,
    # unlike max_length, which also counts the prompt
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=preset["max_tokens"],
            temperature=preset["temperature"],
            top_p=preset["top_p"],
            do_sample=True,
        )

    # Decode only the newly generated tokens as the assistant's reply
    assistant_response = tokenizer.decode(
        outputs[0][inputs.shape[-1]:], skip_special_tokens=True
    ).strip()

    # Append the user message and assistant response to the history
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": assistant_response})
    return history
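# Alternative prompt construction (a sketch, not what respond() uses):
# instruct models such as Mistral-7B-Instruct ship a chat template, so
# letting the tokenizer format the turns is usually more robust than the
# manual string concatenation above:
#
#     msgs = history + [{"role": "user", "content": message}]
#     inputs = tokenizer.apply_chat_template(
#         msgs, add_generation_prompt=True, return_tensors="pt")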
# Gradio interface
def launch_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Chat with Lake 1 Models")
        model_selector = gr.Dropdown(choices=pseudonyms, value=pseudonyms[0], label="Select Model")
        preset_selector = gr.Dropdown(
            choices=["Fast", "Normal", "Quality", "Unreal Performance"],
            value="Normal",
            label="Select Preset",
        )
        message_input = gr.Textbox(label="Your Message")
        # gr.Chatbot is the component name (there is no gr.Chatbox);
        # type="messages" matches the role/content dicts built in respond()
        chat_history = gr.Chatbot(label="Chat History", type="messages")

        def submit_message(message, history, pseudonym, preset_name):
            # Translate the displayed pseudonym back to its repo ID
            return respond(message, history, pseudonym_to_model[pseudonym], preset_name)

        submit_button = gr.Button("Send")
        submit_button.click(
            submit_message,
            inputs=[message_input, chat_history, model_selector, preset_selector],
            outputs=chat_history,
        )

    demo.launch()
if __name__ == "__main__":
    launch_interface()
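# Usage (assumption: HF_TOKEN grants access to the gated Mistral repo):
#     HF_TOKEN=hf_xxx python app.py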