# app.py - Corrected for proper LoRA adapter loading
import os
import gc
import torch
import gradio as gr
from typing import List, Tuple
import warnings
warnings.filterwarnings('ignore')
try:
from peft import PeftConfig, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
IMPORTS_OK = True
except ImportError as e:
IMPORTS_OK = False
print(f"Missing dependencies: {e}")
print("Please install: pip install transformers peft torch gradio accelerate")
# ── Configuration ──────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN") # Optional for public models
# Your LoRA adapter location (HuggingFace repo or local path)
ADAPTER_ID = "Reubencf/gemma3-goan-finetuned"
# For local adapter: ADAPTER_ID = "./path/to/your/adapter"
# Base model - MUST match what you used for fine-tuning!
# Check your adapter's config.json for "base_model_name_or_path"
BASE_MODEL_ID = "google/gemma-3-4b-it" # Change this to your actual base model
# Common options:
# - "google/gemma-2b-it" (2B parameters, easier on memory)
# - "unsloth/gemma-2-2b-it-bnb-4bit" (quantized version)
# - Your actual base model used for training
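# To double-check which base model the adapter was trained on, you can run
# (with peft installed):
#   from peft import PeftConfig
#   print(PeftConfig.from_pretrained(ADAPTER_ID).base_model_name_or_path)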
# Settings
USE_8BIT = False  # Set to True if you have a GPU and want 8-bit quantization
MERGE_ADAPTER = True  # Merge the LoRA weights into the base model at load time (faster inference)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TITLE = "🌴 Gemma Goan Q&A Bot"
DESCRIPTION = """
Gemma base model + LoRA adapter fine-tuned on a Goan Q&A dataset.
Ask about Goa, Konkani culture, or general topics!
**Status**: {}
"""
# ── Load model + tokenizer (correct LoRA loading) ──────────────────────────────
def load_model_and_tokenizer():
"""Load base model and apply LoRA adapter correctly"""
if not IMPORTS_OK:
raise ImportError("Required packages not installed")
print("[Init] Starting model load...")
print(f"[Config] Base model: {BASE_MODEL_ID}")
print(f"[Config] LoRA adapter: {ADAPTER_ID}")
print(f"[Config] Device: {DEVICE}")
# Memory cleanup
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
status = ""
model = None
tokenizer = None
try:
# Step 1: Try to read adapter config to get the correct base model
actual_base_model = BASE_MODEL_ID
try:
print(f"[Load] Checking adapter configuration...")
peft_config = PeftConfig.from_pretrained(ADAPTER_ID, token=HF_TOKEN)
actual_base_model = peft_config.base_model_name_or_path
print(f"[Load] Adapter expects base model: {actual_base_model}")
# Warn if mismatch
if actual_base_model != BASE_MODEL_ID:
print(f"[Warning] BASE_MODEL_ID ({BASE_MODEL_ID}) doesn't match adapter's base ({actual_base_model})")
print(f"[Load] Using adapter's base model: {actual_base_model}")
except Exception as e:
print(f"[Warning] Cannot read adapter config: {e}")
print(f"[Load] Will try with configured base model: {BASE_MODEL_ID}")
actual_base_model = BASE_MODEL_ID
# Step 2: Load the BASE MODEL (not the adapter!)
print(f"[Load] Loading base model: {actual_base_model}")
# Quantization config for GPU
quantization_config = None
if USE_8BIT and torch.cuda.is_available():
print("[Load] Using 8-bit quantization")
            # BitsAndBytesConfig has no compute-dtype argument for 8-bit loading
            # (that option only exists for 4-bit), so load_in_8bit alone is used.
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
actual_base_model,
token=HF_TOKEN,
trust_remote_code=True,
quantization_config=quantization_config,
low_cpu_mem_usage=True,
torch_dtype=torch.float32 if DEVICE == "cpu" else torch.float16,
device_map="auto" if torch.cuda.is_available() else None,
)
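        # device_map="auto" (GPU path) needs the accelerate package to place
        # weights; on CPU no device map is used and weights stay in float32.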
# Move to device if needed
        if DEVICE == "cpu":
            base_model = base_model.to("cpu")
            print("[Load] Model on CPU")
print("[Load] Base model loaded successfully")
# Step 3: Load tokenizer from the BASE MODEL
print(f"[Load] Loading tokenizer from base model...")
tokenizer = AutoTokenizer.from_pretrained(
actual_base_model,
token=HF_TOKEN,
use_fast=True,
trust_remote_code=True,
)
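        # Some tokenizers ship without a pad token, so fall back to eos; left
        # padding keeps the generated continuation adjacent to the prompt for
        # decoder-only models.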
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# Step 4: Try to apply LoRA adapter
try:
print(f"[Load] Applying LoRA adapter: {ADAPTER_ID}")
model = PeftModel.from_pretrained(
base_model,
ADAPTER_ID,
token=HF_TOKEN,
trust_remote_code=True,
is_trainable=False, # Inference only
)
            # Optionally merge the adapter into the base weights. Merging costs a
            # little extra memory up front but removes the LoRA indirection at
            # inference time. Controlled by the MERGE_ADAPTER setting above
            # (an interactive input() prompt is not possible on a hosted Space).
            if MERGE_ADAPTER:
                print("[Load] Merging adapter with base model...")
                model = model.merge_and_unload()
                print("[Load] Adapter merged successfully")
                status = f"✅ Using fine-tuned model (merged): {ADAPTER_ID}"
            else:
                print("[Load] Using adapter without merging")
                status = f"✅ Using fine-tuned model: {ADAPTER_ID}"
except FileNotFoundError as e:
print(f"[Error] Adapter files not found: {e}")
print("[Fallback] Using base model without fine-tuning")
model = base_model
status = f"⚠️ Adapter not found. Using base model only: {actual_base_model}"
except Exception as e:
print(f"[Error] Failed to load adapter: {e}")
print("[Fallback] Using base model without fine-tuning")
model = base_model
status = f"⚠️ Could not load adapter. Using base model only: {actual_base_model}"
# Step 5: Final setup
model.eval()
print(f"[Load] Model ready on {DEVICE}!")
# Memory cleanup
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return model, tokenizer, status
except Exception as e:
error_msg = f"Failed to load model: {str(e)}"
print(f"[Fatal] {error_msg}")
# Try fallback to smallest model
if "gemma-2b" not in BASE_MODEL_ID.lower():
print("[Fallback] Trying with gemma-2b-it...")
try:
base_model = AutoModelForCausalLM.from_pretrained(
"google/gemma-3-4b-it",
token=HF_TOKEN,
trust_remote_code=True,
low_cpu_mem_usage=True,
torch_dtype=torch.float32,
device_map=None,
).to("cpu")
tokenizer = AutoTokenizer.from_pretrained(
"google/gemma-3-4b-it",
token=HF_TOKEN,
trust_remote_code=True,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
base_model.eval()
return base_model, tokenizer, "⚠️ Using fallback model: gemma-2b-it (no fine-tuning)"
except Exception as fallback_error:
print(f"[Fatal] Fallback also failed: {fallback_error}")
raise gr.Error(f"Cannot load any model. Check your configuration.")
else:
raise gr.Error(error_msg)
# Load model globally
try:
model, tokenizer, STATUS_MSG = load_model_and_tokenizer()
MODEL_LOADED = True
DESCRIPTION = DESCRIPTION.format(STATUS_MSG)
except Exception as e:
print(f"[Fatal] Could not load model: {e}")
MODEL_LOADED = False
model, tokenizer = None, None
DESCRIPTION = DESCRIPTION.format(f"❌ Model failed to load: {str(e)[:100]}")
# ── Generation function ─────────────────────────────────────────────────────────
def generate_response(
message: str,
history: List[Tuple[str, str]],
temperature: float = 0.7,
max_new_tokens: int = 256,
top_p: float = 0.95,
repetition_penalty: float = 1.1,
) -> str:
"""Generate response using the model"""
if not MODEL_LOADED:
return "⚠️ Model failed to load. Please check the logs or restart the application."
try:
# Build conversation
conversation = []
if history:
# Keep last 3 exchanges for context
for user_msg, assistant_msg in history[-3:]:
if user_msg:
conversation.append({"role": "user", "content": user_msg})
if assistant_msg:
conversation.append({"role": "assistant", "content": assistant_msg})
conversation.append({"role": "user", "content": message})
# Apply chat template
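        # With return_tensors="pt" (and no return_dict), apply_chat_template
        # returns the prompt token-id tensor directly, matching the manual
        # fallback branch below.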
try:
prompt = tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
return_tensors="pt"
)
except Exception as e:
print(f"[Warning] Chat template failed: {e}, using fallback format")
# Fallback format
prompt_text = ""
for msg in conversation:
if msg["role"] == "user":
prompt_text += f"User: {msg['content']}\n"
else:
prompt_text += f"Assistant: {msg['content']}\n"
prompt_text += "Assistant: "
inputs = tokenizer(
prompt_text,
return_tensors="pt",
truncation=True,
max_length=512
)
prompt = inputs.input_ids
# Move to device
prompt = prompt.to(model.device if hasattr(model, 'device') else DEVICE)
# Generate
print(f"[Generate] Input length: {prompt.shape[-1]} tokens")
with torch.no_grad():
outputs = model.generate(
input_ids=prompt,
max_new_tokens=min(int(max_new_tokens), 256),
temperature=float(temperature),
top_p=float(top_p),
repetition_penalty=float(repetition_penalty),
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=True,
)
# Decode only generated tokens
generated_tokens = outputs[0][prompt.shape[-1]:]
response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
print(f"[Generate] Output length: {len(generated_tokens)} tokens")
# Cleanup
del outputs, prompt, generated_tokens
gc.collect()
return response
except Exception as e:
error_msg = f"⚠️ Error generating response: {str(e)}"
print(f"[Error] {error_msg}")
# Try to recover memory
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return error_msg
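# gr.ChatInterface calls generate_response(message, history, *slider_values) and
# shows the returned string as the bot reply, e.g.
# generate_response("What is the capital of Goa?", []) -> plain-text answer.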
# ── Gradio Interface ────────────────────────────────────────────────────────────
examples = [
["What is the capital of Goa?"],
["Tell me about Konkani language"],
["What are famous beaches in Goa?"],
["Describe Goan fish curry"],
["What is the history of Old Goa?"],
]
# Create interface
if MODEL_LOADED:
demo = gr.ChatInterface(
fn=generate_response,
title=TITLE,
description=DESCRIPTION,
examples=examples,
retry_btn=None,
undo_btn=None,
additional_inputs=[
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.05,
label="Temperature (lower = more focused)"
),
gr.Slider(
minimum=32,
maximum=256,
value=128,
step=16,
label="Max new tokens"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)"
),
gr.Slider(
minimum=1.0,
maximum=2.0,
value=1.1,
step=0.05,
label="Repetition penalty"
),
],
theme=gr.themes.Soft(),
)
else:
demo = gr.Interface(
fn=lambda x: "Model failed to load. Check console for errors.",
inputs=gr.Textbox(label="Message"),
outputs=gr.Textbox(label="Response"),
title=TITLE,
description=DESCRIPTION,
)
# Queue with version compatibility
try:
# Try newer Gradio syntax first (4.x)
demo.queue(default_concurrency_limit=1, max_size=10)
except TypeError:
try:
# Fall back to older syntax (3.x)
demo.queue(concurrency_count=1, max_size=10)
    except Exception:
# If both fail, try without parameters
demo.queue()
if __name__ == "__main__":
print("\n" + "="*50)
print(f"πŸš€ Starting Gradio app on {DEVICE}...")
print(f"πŸ“ Base model: {BASE_MODEL_ID}")
print(f"πŸ”§ LoRA adapter: {ADAPTER_ID}")
print("="*50 + "\n")
demo.launch()