# app.py — Corrected for proper LoRA adapter loading
import os
import gc
import torch
import gradio as gr
from typing import List, Tuple
import warnings

warnings.filterwarnings("ignore")

try:
    from peft import PeftConfig, PeftModel
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    IMPORTS_OK = True
except ImportError as e:
    IMPORTS_OK = False
    print(f"Missing dependencies: {e}")
    print("Please install: pip install transformers peft torch gradio accelerate")

# ── Configuration ──────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")  # Optional for public models

# Your LoRA adapter location (HuggingFace repo or local path)
ADAPTER_ID = "Reubencf/gemma3-goan-finetuned"
# For a local adapter: ADAPTER_ID = "./path/to/your/adapter"

# Base model - MUST match what you used for fine-tuning!
# Check your adapter's config.json for "base_model_name_or_path"
BASE_MODEL_ID = "google/gemma-3-4b-it"  # Change this to your actual base model
# Common options:
# - "google/gemma-2b-it" (2B parameters, easier on memory)
# - "unsloth/gemma-2-2b-it-bnb-4bit" (quantized version)
# - Your actual base model used for training

# Settings
USE_8BIT = False  # Set to True if you have a GPU and want to use 8-bit quantization
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

TITLE = "🌴 Gemma Goan Q&A Bot"
DESCRIPTION = """
Gemma base model + LoRA adapter fine-tuned on a Goan Q&A dataset.
Ask about Goa, Konkani culture, or general topics!

**Status**: {}
"""

# ── Load model + tokenizer (correct LoRA loading) ──────────────────────────────
def load_model_and_tokenizer():
    """Load base model and apply LoRA adapter correctly"""
    if not IMPORTS_OK:
        raise ImportError("Required packages not installed")

    print("[Init] Starting model load...")
    print(f"[Config] Base model: {BASE_MODEL_ID}")
    print(f"[Config] LoRA adapter: {ADAPTER_ID}")
    print(f"[Config] Device: {DEVICE}")

    # Memory cleanup
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    status = ""
    model = None
    tokenizer = None

    try:
        # Step 1: Try to read the adapter config to get the correct base model
        actual_base_model = BASE_MODEL_ID
        try:
            print("[Load] Checking adapter configuration...")
            peft_config = PeftConfig.from_pretrained(ADAPTER_ID, token=HF_TOKEN)
            actual_base_model = peft_config.base_model_name_or_path
            print(f"[Load] Adapter expects base model: {actual_base_model}")

            # Warn if mismatch
            if actual_base_model != BASE_MODEL_ID:
                print(f"[Warning] BASE_MODEL_ID ({BASE_MODEL_ID}) doesn't match adapter's base ({actual_base_model})")
                print(f"[Load] Using adapter's base model: {actual_base_model}")
        except Exception as e:
            print(f"[Warning] Cannot read adapter config: {e}")
            print(f"[Load] Will try with configured base model: {BASE_MODEL_ID}")
            actual_base_model = BASE_MODEL_ID

        # Step 2: Load the BASE MODEL (not the adapter!)
print(f"[Load] Loading base model: {actual_base_model}") # Quantization config for GPU quantization_config = None if USE_8BIT and torch.cuda.is_available(): print("[Load] Using 8-bit quantization") quantization_config = BitsAndBytesConfig( load_in_8bit=True, bnb_8bit_compute_dtype=torch.float16 ) # Load base model base_model = AutoModelForCausalLM.from_pretrained( actual_base_model, token=HF_TOKEN, trust_remote_code=True, quantization_config=quantization_config, low_cpu_mem_usage=True, torch_dtype=torch.float32 if DEVICE == "cpu" else torch.float16, device_map="auto" if torch.cuda.is_available() else None, ) # Move to device if needed if DEVICE == "cpu" and not torch.cuda.is_available(): base_model = base_model.to("cpu") print("[Load] Model on CPU") print("[Load] Base model loaded successfully") # Step 3: Load tokenizer from the BASE MODEL print(f"[Load] Loading tokenizer from base model...") tokenizer = AutoTokenizer.from_pretrained( actual_base_model, token=HF_TOKEN, use_fast=True, trust_remote_code=True, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" # Step 4: Try to apply LoRA adapter try: print(f"[Load] Applying LoRA adapter: {ADAPTER_ID}") model = PeftModel.from_pretrained( base_model, ADAPTER_ID, token=HF_TOKEN, trust_remote_code=True, is_trainable=False, # Inference only ) # Optional: Merge adapter with base model for faster inference # This combines the weights permanently (uses more memory initially but faster inference) merge = input("\n💡 Merge adapter for faster inference? (y/n, default=y): ").strip().lower() if merge != 'n': print("[Load] Merging adapter with base model...") model = model.merge_and_unload() print("[Load] Adapter merged successfully") status = f"✅ Using fine-tuned model (merged): {ADAPTER_ID}" else: print("[Load] Using adapter without merging") status = f"✅ Using fine-tuned model: {ADAPTER_ID}" except FileNotFoundError as e: print(f"[Error] Adapter files not found: {e}") print("[Fallback] Using base model without fine-tuning") model = base_model status = f"⚠️ Adapter not found. Using base model only: {actual_base_model}" except Exception as e: print(f"[Error] Failed to load adapter: {e}") print("[Fallback] Using base model without fine-tuning") model = base_model status = f"⚠️ Could not load adapter. Using base model only: {actual_base_model}" # Step 5: Final setup model.eval() print(f"[Load] Model ready on {DEVICE}!") # Memory cleanup gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return model, tokenizer, status except Exception as e: error_msg = f"Failed to load model: {str(e)}" print(f"[Fatal] {error_msg}") # Try fallback to smallest model if "gemma-2b" not in BASE_MODEL_ID.lower(): print("[Fallback] Trying with gemma-2b-it...") try: base_model = AutoModelForCausalLM.from_pretrained( "google/gemma-3-4b-it", token=HF_TOKEN, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.float32, device_map=None, ).to("cpu") tokenizer = AutoTokenizer.from_pretrained( "google/gemma-3-4b-it", token=HF_TOKEN, trust_remote_code=True, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token base_model.eval() return base_model, tokenizer, "⚠️ Using fallback model: gemma-2b-it (no fine-tuning)" except Exception as fallback_error: print(f"[Fatal] Fallback also failed: {fallback_error}") raise gr.Error(f"Cannot load any model. 
        else:
            raise gr.Error(error_msg)


# Load model globally
try:
    model, tokenizer, STATUS_MSG = load_model_and_tokenizer()
    MODEL_LOADED = True
    DESCRIPTION = DESCRIPTION.format(STATUS_MSG)
except Exception as e:
    print(f"[Fatal] Could not load model: {e}")
    MODEL_LOADED = False
    model, tokenizer = None, None
    DESCRIPTION = DESCRIPTION.format(f"❌ Model failed to load: {str(e)[:100]}")


# ── Generation function ─────────────────────────────────────────────────────────
def generate_response(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float = 0.7,
    max_new_tokens: int = 256,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
) -> str:
    """Generate response using the model"""
    if not MODEL_LOADED:
        return "⚠️ Model failed to load. Please check the logs or restart the application."

    try:
        # Build conversation
        conversation = []
        if history:
            # Keep last 3 exchanges for context
            for user_msg, assistant_msg in history[-3:]:
                if user_msg:
                    conversation.append({"role": "user", "content": user_msg})
                if assistant_msg:
                    conversation.append({"role": "assistant", "content": assistant_msg})
        conversation.append({"role": "user", "content": message})

        # Apply chat template
        try:
            prompt = tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt",
            )
        except Exception as e:
            print(f"[Warning] Chat template failed: {e}, using fallback format")
            # Fallback format
            prompt_text = ""
            for msg in conversation:
                if msg["role"] == "user":
                    prompt_text += f"User: {msg['content']}\n"
                else:
                    prompt_text += f"Assistant: {msg['content']}\n"
            prompt_text += "Assistant: "
            inputs = tokenizer(
                prompt_text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )
            prompt = inputs.input_ids

        # Move to device
        prompt = prompt.to(model.device if hasattr(model, "device") else DEVICE)

        # Generate
        print(f"[Generate] Input length: {prompt.shape[-1]} tokens")
        with torch.no_grad():
            outputs = model.generate(
                input_ids=prompt,
                max_new_tokens=min(int(max_new_tokens), 256),
                temperature=float(temperature),
                top_p=float(top_p),
                repetition_penalty=float(repetition_penalty),
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )

        # Decode only generated tokens
        generated_tokens = outputs[0][prompt.shape[-1]:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        print(f"[Generate] Output length: {len(generated_tokens)} tokens")

        # Cleanup
        del outputs, prompt, generated_tokens
        gc.collect()

        return response

    except Exception as e:
        error_msg = f"⚠️ Error generating response: {str(e)}"
        print(f"[Error] {error_msg}")
        # Try to recover memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return error_msg


# ── Gradio Interface ────────────────────────────────────────────────────────────
examples = [
    ["What is the capital of Goa?"],
    ["Tell me about Konkani language"],
    ["What are famous beaches in Goa?"],
    ["Describe Goan fish curry"],
    ["What is the history of Old Goa?"],
]

# Create interface
if MODEL_LOADED:
    demo = gr.ChatInterface(
        fn=generate_response,
        title=TITLE,
        description=DESCRIPTION,
        examples=examples,
        retry_btn=None,
        undo_btn=None,
        additional_inputs=[
            gr.Slider(
                minimum=0.1, maximum=1.0, value=0.7, step=0.05,
                label="Temperature (lower = more focused)",
            ),
            gr.Slider(
                minimum=32, maximum=256, value=128, step=16,
                label="Max new tokens",
            ),
            gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                label="Top-p (nucleus sampling)",
            ),
            gr.Slider(
                minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                label="Repetition penalty",
            ),
        ],
        theme=gr.themes.Soft(),
    )
else:
    demo = gr.Interface(
        fn=lambda x: "Model failed to load. Check console for errors.",
        inputs=gr.Textbox(label="Message"),
        outputs=gr.Textbox(label="Response"),
        title=TITLE,
        description=DESCRIPTION,
    )

# Queue with version compatibility
try:
    # Try newer Gradio syntax first (4.x)
    demo.queue(default_concurrency_limit=1, max_size=10)
except TypeError:
    try:
        # Fall back to older syntax (3.x)
        demo.queue(concurrency_count=1, max_size=10)
    except Exception:
        # If both fail, try without parameters
        demo.queue()

if __name__ == "__main__":
    print("\n" + "=" * 50)
    print(f"🚀 Starting Gradio app on {DEVICE}...")
    print(f"📍 Base model: {BASE_MODEL_ID}")
    print(f"🔧 LoRA adapter: {ADAPTER_ID}")
    print("=" * 50 + "\n")
    demo.launch()
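
# ── Optional: verify the adapter's base model ──────────────────────────────────
# A minimal sketch (not executed by the app) for double-checking which base model
# the adapter was trained on before editing BASE_MODEL_ID above. It only reads the
# adapter's config, so it is cheap to run in a separate Python session:
#
#   from peft import PeftConfig
#   cfg = PeftConfig.from_pretrained("Reubencf/gemma3-goan-finetuned")
#   print(cfg.base_model_name_or_path)  # should match BASE_MODEL_ID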