import os

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face token, read from the environment (required for gated repos)
HF_TOKEN = os.getenv("HF_TOKEN")

# Repo IDs and their short local names
model_names = {
    "mistralai/Mistral-7B-Instruct-v0.3": "mistral-7b-instruct",
    "BICORP/Lake-1-Advanced": "lake-1-advanced",
}

# Download a full model snapshot from the Hub.
# (snapshot_download fetches the whole repo; hf_hub_download only fetches a
# single file and requires a filename, so it cannot load a model directory.)
def download_model(repo_id):
    return snapshot_download(repo_id=repo_id, token=HF_TOKEN)

# Load models and tokenizers
models = {}
tokenizers = {}
for name in model_names:
    model_path = download_model(name)
    models[name] = AutoModelForCausalLM.from_pretrained(model_path)
    tokenizers[name] = AutoTokenizer.from_pretrained(model_path)

# Generation presets for each model
presets = {
    "mistralai/Mistral-7B-Instruct-v0.3": {
        "Fast": {"max_tokens": 256, "temperature": 1.0, "top_p": 0.8},
        "Normal": {"max_tokens": 512, "temperature": 0.6, "top_p": 0.75},
        "Quality": {"max_tokens": 1024, "temperature": 0.45, "top_p": 0.60},
        "Unreal Performance": {"max_tokens": 1048, "temperature": 0.5, "top_p": 0.7},
    },
    "BICORP/Lake-1-Advanced": {
        "Fast": {"max_tokens": 800, "temperature": 1.0, "top_p": 0.9},
        "Normal": {"max_tokens": 4000, "temperature": 0.7, "top_p": 0.95},
        "Quality": {"max_tokens": 32000, "temperature": 0.5, "top_p": 0.90},
        "Unreal Performance": {"max_tokens": 128000, "temperature": 0.6, "top_p": 0.75},
    },
}

# System messages for each model
system_messages = {
    "mistralai/Mistral-7B-Instruct-v0.3": "Your name is Lake 1 Base but mine is User",
    "BICORP/Lake-1-Advanced": "Your name is Lake 1 Advanced [Alpha] but mine is User or what I will type as my name",
}

# Model names and their pseudonyms
model_choices = [
    ("mistralai/Mistral-7B-Instruct-v0.3", "Lake 1 Base"),
    ("BICORP/Lake-1-Advanced", "Lake 1 Advanced [Alpha]"),
]

# Pseudonyms shown in the dropdown, and a reverse map back to repo IDs
pseudonyms = [pseudo for _, pseudo in model_choices]
pseudonym_to_model = {pseudo: repo for repo, pseudo in model_choices}


def respond(message, history: list, selected_pseudonym, preset_name):
    # Resolve the dropdown pseudonym to the underlying repo ID
    model_name = pseudonym_to_model[selected_pseudonym]
    model = models[model_name]
    tokenizer = tokenizers[model_name]

    # Build the prompt: system message, prior turns, then the new user turn
    system_message = system_messages[model_name]
    history = history or []
    dialogue = "\n".join(f"{turn['role']}: {turn['content']}" for turn in history)
    input_text = f"{system_message}\n{dialogue}\nuser: {message}\n"

    # Tokenize the input
    inputs = tokenizer.encode(input_text, return_tensors="pt")

    # Apply the selected preset
    preset = presets[model_name][preset_name]

    # Generate the response; max_new_tokens bounds the completion itself,
    # whereas max_length would also count the prompt tokens.
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=preset["max_tokens"],
            temperature=preset["temperature"],
            top_p=preset["top_p"],
            do_sample=True,
        )

    # Decode and keep only the text generated after the last user turn
    final_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    assistant_response = final_response.split("user:")[-1].strip()

    # Append both turns in the messages format the Chatbot component expects
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": assistant_response})
    return history


# Gradio interface
def launch_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Chat with Lake 1 Models")

        model_selector = gr.Dropdown(choices=pseudonyms, label="Select Model")
        preset_selector = gr.Dropdown(
            choices=["Fast", "Normal", "Quality", "Unreal Performance"],
            label="Select Preset",
        )
        message_input = gr.Textbox(label="Your Message")
        chat_history = gr.Chatbot(label="Chat History", type="messages")

        submit_button = gr.Button("Send")
        submit_button.click(
            respond,
            inputs=[message_input, chat_history, model_selector, preset_selector],
            outputs=chat_history,
        )

    demo.launch()


if __name__ == "__main__":
    launch_interface()
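# Usage sketch (assumptions, not from the original source: the script is
# saved as app.py, and HF_TOKEN grants access to both repos):
#
#   export HF_TOKEN=hf_xxx
#   pip install gradio transformers torch huggingface_hub
#   python app.py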