# app.py — Hugging Face Space ready (LoRA adapter, Gradio compat)
# ---------------------------------------------------------------
# What changed vs your script
# - Removed ChatInterface args that broke on old Gradio (retry_btn, undo_btn)
# - No interactive input() for merging (Spaces are non-interactive). Use MERGE_LORA env var.
# - Secrets: read HF token from env (Settings → Secrets → HF_TOKEN), never hardcode.
# - Token passing works across transformers/peft versions (token/use_auth_token fallback).
# - Optional 8-bit via USE_8BIT=1 (GPU only). Safe CPU defaults.
# - Robust theme/queue/launch for mixed Gradio versions.

import os
import gc
import warnings
from typing import List, Tuple

import torch
import gradio as gr

warnings.filterwarnings("ignore")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

try:
    from peft import PeftConfig, PeftModel
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        BitsAndBytesConfig,
    )
    IMPORTS_OK = True
except Exception as e:
    IMPORTS_OK = False
    print(f"Missing dependencies: {e}")
    print("Install: pip install --upgrade 'transformers>=4.41' peft accelerate gradio torch bitsandbytes")

# ── Configuration ──────────────────────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space Settings → Secrets → HF_TOKEN

# LoRA adapter repo (must be compatible with BASE_MODEL_ID)
ADAPTER_ID = os.getenv("ADAPTER_ID", "Reubencf/gemma3-goan-finetuned")

# Base model used during fine-tuning (should match adapter's base)
BASE_MODEL_ID_DEFAULT = os.getenv("BASE_MODEL_ID", "google/gemma-3-4b-it")

# Quantization toggle (GPU only): set USE_8BIT=1 in Space variables
USE_8BIT = os.getenv("USE_8BIT", "0").lower() in {"1", "true", "yes", "y"}

# Merge LoRA into the base for faster inference: MERGE_LORA=1/0
MERGE_LORA = os.getenv("MERGE_LORA", "1").lower() in {"1", "true", "yes", "y"}

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

TITLE = "🌴 Gemma Goan Q&A Bot"
DESCRIPTION_TMPL = (
    "Gemma base model + LoRA adapter fine-tuned on a Goan Q&A dataset.\n"
    "Ask about Goa, Konkani culture, or general topics!\n\n"
    "**Status**: {}"
)

# ── Helpers ───────────────────────────────────────────────────────────────────
def call_with_token(fn, *args, **kwargs):
    """Call HF/Transformers/PEFT functions with token OR use_auth_token for broad version compatibility."""
    if HF_TOKEN:
        try:
            return fn(*args, token=HF_TOKEN, **kwargs)
        except TypeError:
            return fn(*args, use_auth_token=HF_TOKEN, **kwargs)
    return fn(*args, **kwargs)

# ── Load model + tokenizer ─────────────────────────────────────────────────────
def load_model_and_tokenizer():
    if not IMPORTS_OK:
        raise ImportError("Required packages not installed.")

    print("[Init] Starting model load…")
    print(f"[Config] Device: {DEVICE}")

    # GC + VRAM cleanup
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Step 1: Confirm base model from the adapter's config if possible
    actual_base_model = BASE_MODEL_ID_DEFAULT
    try:
        print(f"[Load] Reading adapter config: {ADAPTER_ID}")
        peft_cfg = call_with_token(PeftConfig.from_pretrained, ADAPTER_ID)
        if getattr(peft_cfg, "base_model_name_or_path", None):
            actual_base_model = peft_cfg.base_model_name_or_path
            print(f"[Load] Adapter expects base model: {actual_base_model}")
        else:
            print("[Warn] Adapter did not expose base_model_name_or_path; using configured base.")
    except Exception as e:
        print(f"[Warn] Could not read adapter config ({e}); using configured base: {actual_base_model}")

    # Step 2: Load base model (optionally quantized on GPU)
    print(f"[Load] Loading base model: {actual_base_model}")
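    # 8-bit loading only applies when USE_8BIT is set and a CUDA device exists;
    # otherwise the model loads in fp16 on GPU or fp32 on CPU (see torch_dtype below).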
    quant_cfg = None
    if USE_8BIT and torch.cuda.is_available():
        print("[Load] Enabling 8-bit quantization (bitsandbytes)")
        quant_cfg = BitsAndBytesConfig(load_in_8bit=True)

    base_model = call_with_token(
        AutoModelForCausalLM.from_pretrained,
        actual_base_model,
        trust_remote_code=True,
        quantization_config=quant_cfg,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )

    if DEVICE == "cpu":
        base_model = base_model.to("cpu")
        print("[Load] Model on CPU")
    print("[Load] Base model loaded ✔")

    # Step 3: Tokenizer
    print("[Load] Loading tokenizer…")
    tokenizer = call_with_token(
        AutoTokenizer.from_pretrained,
        actual_base_model,
        use_fast=True,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Step 4: Apply LoRA adapter
    status = ""
    model = base_model
    try:
        print(f"[Load] Applying LoRA adapter: {ADAPTER_ID}")
        model = call_with_token(PeftModel.from_pretrained, base_model, ADAPTER_ID)
        if MERGE_LORA:
            print("[Load] Merging adapter into base (merge_and_unload)…")
            model = model.merge_and_unload()
            status = f"✅ Using fine-tuned model (merged): {ADAPTER_ID}"
        else:
            status = f"✅ Using fine-tuned model via adapter: {ADAPTER_ID}"
    except FileNotFoundError as e:
        print(f"[Error] Adapter files not found: {e}")
        status = f"⚠️ Adapter not found. Using base only: {actual_base_model}"
    except Exception as e:
        print(f"[Error] Failed to load adapter: {e}")
        status = f"⚠️ Could not load adapter. Using base only: {actual_base_model}"

    model.eval()
    print(f"[Load] Model ready on {DEVICE} ✔")

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return model, tokenizer, status

# Global load at import time (Space-friendly)
try:
    model, tokenizer, STATUS_MSG = load_model_and_tokenizer()
    MODEL_LOADED = True
    DESCRIPTION = DESCRIPTION_TMPL.format(STATUS_MSG)
except Exception as e:
    print(f"[Fatal] Could not load model: {e}")
    MODEL_LOADED = False
    model = tokenizer = None
    DESCRIPTION = DESCRIPTION_TMPL.format(f"❌ Model failed to load: {str(e)[:140]}")

# ── Generation ────────────────────────────────────────────────────────────────
def generate_response(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float = 0.7,
    max_new_tokens: int = 256,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
) -> str:
    if not MODEL_LOADED:
        return "⚠️ Model failed to load. Check Space logs."
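    # The whole generation path is wrapped in try/except so the chat UI always
    # receives a plain string, even if templating or generate() raises.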
    try:
        # Build short chat history
        conversation = []
        if history:
            for u, a in history[-3:]:
                if u:
                    conversation.append({"role": "user", "content": u})
                if a:
                    conversation.append({"role": "assistant", "content": a})
        conversation.append({"role": "user", "content": message})

        # Try the tokenizer's chat template first
        try:
            input_ids = tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt",
            )
        except Exception as e:
            print(f"[Warn] chat_template failed: {e}; using manual format")
            prompt_text = "".join(
                [
                    ("User: " + m["content"] + "\n")
                    if m["role"] == "user"
                    else ("Assistant: " + m["content"] + "\n")
                    for m in conversation
                ]
            ) + "Assistant: "
            input_ids = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=1024).input_ids

        input_ids = input_ids.to(model.device if hasattr(model, "device") else DEVICE)

        with torch.no_grad():
            out = model.generate(
                input_ids=input_ids,
                max_new_tokens=max(1, min(int(max_new_tokens), 512)),
                temperature=float(temperature),
                top_p=float(top_p),
                repetition_penalty=float(repetition_penalty),
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )

        gen = out[0][input_ids.shape[-1]:]
        text = tokenizer.decode(gen, skip_special_tokens=True).strip()

        # Cleanup
        del out, input_ids, gen
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return text or "(no output)"

    except Exception as e:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return f"⚠️ Error generating response: {e}"

# ── UI ────────────────────────────────────────────────────────────────────────
examples = [
    "What is the capital of Goa?",
    "Tell me about the Konkani language.",
    "What are famous beaches in Goa?",
    "Describe Goan fish curry.",
    "What is the history of Old Goa?",
]

# Best-effort theme across versions
try:
    THEME = gr.themes.Soft()
except Exception:
    THEME = None

if MODEL_LOADED:
    demo = gr.ChatInterface(
        fn=generate_response,
        title=TITLE,
        description=DESCRIPTION,
        examples=examples,
        additional_inputs=[
            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Temperature"),
            gr.Slider(minimum=32, maximum=512, value=256, step=16, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
            gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition penalty"),
        ],
        theme=THEME,
    )
else:
    demo = gr.Interface(
        fn=lambda x: "Model failed to load. Check Space logs.",
        inputs=gr.Textbox(label="Message"),
        outputs=gr.Textbox(label="Response"),
        title=TITLE,
        description=DESCRIPTION,
        theme=THEME,
    )

# Queue — keep params minimal for cross-version compat
try:
    demo.queue()
except Exception:
    pass

if __name__ == "__main__":
    print("\n" + "=" * 60)
    print(f"🚀 Starting Gradio app on {DEVICE} …")
    print(f"📍 Base model: {BASE_MODEL_ID_DEFAULT}")
    print(f"🔧 LoRA adapter: {ADAPTER_ID}")
    print(f"🧩 Merge LoRA: {MERGE_LORA}")
    print("=" * 60 + "\n")
    # On Spaces, just calling launch() is fine.
    demo.launch()