import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import traceback

# Try to import peft; if not available, use base models only
try:
    from peft import PeftModel
    PEFT_AVAILABLE = True
except ImportError:
    print("Warning: peft library not found. LoRA adapters will not be available.")
    PEFT_AVAILABLE = False

# === Define all your available models here ===
# This dictionary allows you to define both base models and LoRA adapters.
# 'type': can be 'base' for a standalone model or 'lora' for an adapter.
# 'id': the Hugging Face model/adapter ID.
# 'base_model_id': for LoRA adapters, specifies which base model to use.
AVAILABLE_MODELS = {
    "BokantLM0.1-0.5B": {
        "type": "base",
        "id": "llaa33219/BokantLM0.1-0.5B",
    },
    "BokantLM0.1-135M-Deepseek": {
        "type": "base",
        "id": "llaa33219/BokantLM0.1-135M-Deepseek",
    },
    # --- You can add more models here ---
    # Example of another base model:
    # "Another Base Model (e.g., Ko-LLaMA)": {
    #     "type": "base",
    #     "id": "beomi/KoAlpaca-Polyglot-5.8B"
    # },
    # Example of another LoRA adapter:
    # "Another LoRA Finetune": {
    #     "type": "lora",
    #     "id": "path/to/your/other-lora-adapter",
    #     "base_model_id": "Qwen/Qwen2.5-3B-Instruct"
    # },
}

# Global variables for model caching
current_model_name = None
current_tokenizer = None
current_model = None


def load_model(name):
    """
    Loads a model based on the selection. It can load a base model directly,
    or load a base model and then apply a LoRA adapter to it.
    """
    global current_model_name, current_tokenizer, current_model

    if current_model_name == name:
        # Model is already loaded, no need to do anything
        return current_tokenizer, current_model

    print(f"Switching to model: {name}")

    # Clear previous model from memory
    if current_model is not None:
        del current_model
        del current_tokenizer
        current_model = None
        current_tokenizer = None
        torch.cuda.empty_cache()
        print("Cleared previous model from memory.")

    try:
        model_info = AVAILABLE_MODELS[name]
        model_type = model_info["type"]
        model_id = model_info["id"]

        # --- Case 1: Load a LoRA adapter model ---
        if model_type == 'lora' and PEFT_AVAILABLE:
            base_model_id = model_info["base_model_id"]
            adapter_id = model_id
            print(f"Loading LoRA model. Base: '{base_model_id}', Adapter: '{adapter_id}'")

            # Load tokenizer from the adapter (it might have special tokens)
            current_tokenizer = AutoTokenizer.from_pretrained(adapter_id, trust_remote_code=True)

            # Load base model
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            # Resize token embeddings if the adapter's vocab differs from the base model's
            if base_model.config.vocab_size != len(current_tokenizer):
                print(f"Resizing token embeddings from {base_model.config.vocab_size} to {len(current_tokenizer)}")
                base_model.resize_token_embeddings(len(current_tokenizer))

            # Load and merge the LoRA adapter
            print(f"Loading and merging LoRA adapter: {adapter_id}")
            lora_model = PeftModel.from_pretrained(
                base_model,
                adapter_id,
                torch_dtype=torch.float16
            )
            current_model = lora_model.merge_and_unload()
            print("Successfully merged LoRA adapter.")

        # --- Case 2: Load a base model directly ---
        else:
            if model_type == 'lora' and not PEFT_AVAILABLE:
                print(f"PEFT not available. Cannot load LoRA adapter '{name}'. Falling back to its base model.")
                # Fall back to the base model if PEFT is missing
                model_id = model_info.get("base_model_id", list(AVAILABLE_MODELS.values())[0]['id'])

            print(f"Loading base model: {model_id}")
            current_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
            current_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

        # Common post-processing for any loaded model
        if current_tokenizer.pad_token is None:
            current_tokenizer.pad_token = current_tokenizer.eos_token
            print("Set pad_token to eos_token.")

        current_model_name = name
        print(f"✅ Successfully loaded model: {name}")

    except Exception as e:
        print(f"❌ Failed to load model {name}: {e}")
        traceback.print_exc()
        # Clean up on failure
        current_model_name = None
        current_model = None
        current_tokenizer = None
        raise e  # Re-raise the exception to be caught by the chat function

    return current_tokenizer, current_model


@spaces.GPU()
def chat_fn(message, history, selected_model):
    try:
        tokenizer, model = load_model(selected_model)

        # Ensure model is on the correct device (GPU)
        if not next(model.parameters()).is_cuda:
            model = model.cuda()

        # Build conversation history for the chat template
        conversation = []
        for user_msg, bot_msg in history:
            conversation.append({"role": "user", "content": user_msg})
            conversation.append({"role": "assistant", "content": bot_msg})
        conversation.append({"role": "user", "content": message})

        # Apply the model's specific chat template
        try:
            input_ids = tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            ).cuda()
        except Exception as e:
            print(f"Chat template error: {e}. Falling back to simple encoding.")
            text = f"User: {message}\nAssistant:"
            input_ids = tokenizer.encode(text, return_tensors="pt").cuda()

        # Generate response
        with torch.no_grad():
            # Create attention mask
            attention_mask = torch.ones_like(input_ids)

            output_ids = model.generate(
                input_ids,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                attention_mask=attention_mask
            )

        # Decode the generated tokens into text, skipping the prompt
        response = tokenizer.decode(
            output_ids[0][input_ids.shape[1]:],
            skip_special_tokens=True
        ).strip()

        return response

    except Exception as e:
        print(f"Error in chat_fn: {str(e)}")
        traceback.print_exc()
        return f"죄송합니다. 오류가 발생했습니다: {str(e)}"
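
# Illustration (not part of the app logic): after one completed turn plus a new user
# message, chat_fn builds a `conversation` list like the one below before calling
# tokenizer.apply_chat_template. The example values are made up; the actual prompt
# format is model-specific and comes from the tokenizer's chat template.
#
#   conversation = [
#       {"role": "user", "content": "Hello"},
#       {"role": "assistant", "content": "Hi! How can I help you?"},
#       {"role": "user", "content": "Tell me about this model."},
#   ]
#
# If a tokenizer ships without a chat template, chat_fn falls back to a plain
# "User: ...\nAssistant:" prompt built from the latest message only.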

def respond(message, chat_history, selected_model):
    if not message.strip():
        # If the message is empty, do nothing
        return chat_history, ""

    # Get the bot's response
    bot_message = chat_fn(message, chat_history, selected_model)

    # Update chat history
    chat_history.append([message, bot_message])
    return chat_history, ""  # Return updated history and clear the input box


# --- Gradio Interface ---
title = "Multi-Model Chatbot (with LoRA Support)" if PEFT_AVAILABLE else "Multi-Model Chatbot (Base Models Only)"

with gr.Blocks(title="Multi-Model Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# 🗨️ {title}")
    gr.Markdown("Select a model from the dropdown and start chatting. The app will load the model on the first message.")

    with gr.Row():
        model_select = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value=list(AVAILABLE_MODELS.keys())[0],  # Default to the first model in the list
            label="Choose Model",
            interactive=True
        )

    chatbot = gr.Chatbot(
        height=500,
        label="Chat",
        show_copy_button=True,
        bubble_full_width=False
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Message",
            placeholder="여기에 메시지를 입력하세요...",
            scale=4
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")

    clear_btn = gr.Button("Clear Chat", variant="secondary")

    # --- Event Handlers ---
    def clear_chat():
        return [], ""

    # Send message on button click or Enter key press
    send_btn.click(
        respond,
        inputs=[msg, chatbot, model_select],
        outputs=[chatbot, msg]
    )
    msg.submit(
        respond,
        inputs=[msg, chatbot, model_select],
        outputs=[chatbot, msg]
    )

    # Clear chat button
    clear_btn.click(clear_chat, outputs=[chatbot, msg])


if __name__ == "__main__":
    # Pre-load the default model to speed up the first interaction
    try:
        print("Pre-loading the default model...")
        default_model_name = list(AVAILABLE_MODELS.keys())[0]
        load_model(default_model_name)
        print("✅ Default model pre-loaded successfully.")
    except Exception as e:
        print(f"⚠️ Could not pre-load the default model: {e}")

    demo.launch(
        share=False,  # Set to True to get a public link (on Hugging Face Spaces or Colab)
        server_name="0.0.0.0",
        server_port=7860
    )