# app.py - Corrected for proper LoRA adapter loading
import os
import gc
import torch
import gradio as gr
from typing import List, Tuple
import warnings
warnings.filterwarnings('ignore')

try:
    from peft import PeftConfig, PeftModel
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    IMPORTS_OK = True
except ImportError as e:
    IMPORTS_OK = False
    print(f"Missing dependencies: {e}")
    print("Please install: pip install transformers peft torch gradio accelerate")

# ── Configuration ──────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")  # Optional for public models

# Your LoRA adapter location (HuggingFace repo or local path)
ADAPTER_ID = "Reubencf/gemma3-goan-finetuned"
# For a local adapter: ADAPTER_ID = "./path/to/your/adapter"

# Base model - MUST match what you used for fine-tuning!
# Check your adapter's adapter_config.json for "base_model_name_or_path"
BASE_MODEL_ID = "google/gemma-3-4b-it"  # Change this to your actual base model
# Common options:
# - "google/gemma-2b-it" (2B parameters, easier on memory)
# - "unsloth/gemma-2-2b-it-bnb-4bit" (quantized version)
# - Your actual base model used for training
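
# For reference, a LoRA adapter repo ships an adapter_config.json whose fields look roughly
# like this (values here are illustrative, not read from the repo above):
#   {"peft_type": "LORA", "base_model_name_or_path": "google/gemma-3-4b-it",
#    "r": 16, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"], ...}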

# Settings
USE_8BIT = False  # Set to True if you have a GPU and want to use 8-bit quantization
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
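
# Merge the LoRA weights into the base model right after loading. A flag (overridable via
# the MERGE_ADAPTER env var) is used instead of an interactive prompt, since stdin is not
# available when the app runs on a Space.
MERGE_ADAPTER = os.environ.get("MERGE_ADAPTER", "1") != "0"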

TITLE = "🌴 Gemma Goan Q&A Bot"
DESCRIPTION = """
Gemma base model + LoRA adapter fine-tuned on a Goan Q&A dataset.
Ask about Goa, Konkani culture, or general topics!

**Status**: {}
"""

# ── Load model + tokenizer (correct LoRA loading) ──────────────────────────────
def load_model_and_tokenizer():
    """Load the base model and apply the LoRA adapter correctly."""
    if not IMPORTS_OK:
        raise ImportError("Required packages not installed")

    print("[Init] Starting model load...")
    print(f"[Config] Base model: {BASE_MODEL_ID}")
    print(f"[Config] LoRA adapter: {ADAPTER_ID}")
    print(f"[Config] Device: {DEVICE}")

    # Memory cleanup
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    status = ""
    model = None
    tokenizer = None

    try:
        # Step 1: Try to read the adapter config to get the correct base model
        actual_base_model = BASE_MODEL_ID
        try:
            print("[Load] Checking adapter configuration...")
            peft_config = PeftConfig.from_pretrained(ADAPTER_ID, token=HF_TOKEN)
            actual_base_model = peft_config.base_model_name_or_path
            print(f"[Load] Adapter expects base model: {actual_base_model}")

            # Warn if mismatch
            if actual_base_model != BASE_MODEL_ID:
                print(f"[Warning] BASE_MODEL_ID ({BASE_MODEL_ID}) doesn't match adapter's base ({actual_base_model})")
                print(f"[Load] Using adapter's base model: {actual_base_model}")
        except Exception as e:
            print(f"[Warning] Cannot read adapter config: {e}")
            print(f"[Load] Will try with configured base model: {BASE_MODEL_ID}")
            actual_base_model = BASE_MODEL_ID

        # Step 2: Load the BASE MODEL (not the adapter!)
        print(f"[Load] Loading base model: {actual_base_model}")

        # Quantization config for GPU (load_in_8bit is the only option needed here)
        quantization_config = None
        if USE_8BIT and torch.cuda.is_available():
            print("[Load] Using 8-bit quantization")
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # Load base model
        base_model = AutoModelForCausalLM.from_pretrained(
            actual_base_model,
            token=HF_TOKEN,
            trust_remote_code=True,
            quantization_config=quantization_config,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float32 if DEVICE == "cpu" else torch.float16,
            device_map="auto" if torch.cuda.is_available() else None,
        )

        # Move to device if needed (device_map is only set when CUDA is available)
        if DEVICE == "cpu":
            base_model = base_model.to("cpu")
            print("[Load] Model on CPU")
        print("[Load] Base model loaded successfully")

        # Step 3: Load tokenizer from the BASE MODEL
        print("[Load] Loading tokenizer from base model...")
        tokenizer = AutoTokenizer.from_pretrained(
            actual_base_model,
            token=HF_TOKEN,
            use_fast=True,
            trust_remote_code=True,
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
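        # Left padding keeps the prompt tokens adjacent to the newly generated ones,
        # which is what decoder-only models expect when padded inputs are used.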

        # Step 4: Apply the LoRA adapter
        try:
            print(f"[Load] Applying LoRA adapter: {ADAPTER_ID}")
            model = PeftModel.from_pretrained(
                base_model,
                ADAPTER_ID,
                token=HF_TOKEN,
                trust_remote_code=True,
                is_trainable=False,  # Inference only
            )

            # Optional: merge the adapter into the base model for faster inference.
            # Merging combines the weights permanently (more memory up front, faster generation).
            if MERGE_ADAPTER:
                print("[Load] Merging adapter with base model...")
                model = model.merge_and_unload()
                print("[Load] Adapter merged successfully")
                status = f"✅ Using fine-tuned model (merged): {ADAPTER_ID}"
            else:
                print("[Load] Using adapter without merging")
                status = f"✅ Using fine-tuned model: {ADAPTER_ID}"
        except FileNotFoundError as e:
            print(f"[Error] Adapter files not found: {e}")
            print("[Fallback] Using base model without fine-tuning")
            model = base_model
            status = f"⚠️ Adapter not found. Using base model only: {actual_base_model}"
        except Exception as e:
            print(f"[Error] Failed to load adapter: {e}")
            print("[Fallback] Using base model without fine-tuning")
            model = base_model
            status = f"⚠️ Could not load adapter. Using base model only: {actual_base_model}"

        # Step 5: Final setup
        model.eval()
        print(f"[Load] Model ready on {DEVICE}!")

        # Memory cleanup
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return model, tokenizer, status

    except Exception as e:
        error_msg = f"Failed to load model: {str(e)}"
        print(f"[Fatal] {error_msg}")

        # Try falling back to the smallest model
        if "gemma-2b" not in BASE_MODEL_ID.lower():
            print("[Fallback] Trying with gemma-2b-it...")
            try:
                base_model = AutoModelForCausalLM.from_pretrained(
                    "google/gemma-2b-it",
                    token=HF_TOKEN,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    torch_dtype=torch.float32,
                    device_map=None,
                ).to("cpu")
                tokenizer = AutoTokenizer.from_pretrained(
                    "google/gemma-2b-it",
                    token=HF_TOKEN,
                    trust_remote_code=True,
                )
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
                base_model.eval()
                return base_model, tokenizer, "⚠️ Using fallback model: gemma-2b-it (no fine-tuning)"
            except Exception as fallback_error:
                print(f"[Fatal] Fallback also failed: {fallback_error}")
                raise gr.Error("Cannot load any model. Check your configuration.")
        else:
            raise gr.Error(error_msg)

# Load model globally
try:
    model, tokenizer, STATUS_MSG = load_model_and_tokenizer()
    MODEL_LOADED = True
    DESCRIPTION = DESCRIPTION.format(STATUS_MSG)
except Exception as e:
    print(f"[Fatal] Could not load model: {e}")
    MODEL_LOADED = False
    model, tokenizer = None, None
    DESCRIPTION = DESCRIPTION.format(f"❌ Model failed to load: {str(e)[:100]}")

# ── Generation function ─────────────────────────────────────────────────────────
def generate_response(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float = 0.7,
    max_new_tokens: int = 256,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
) -> str:
    """Generate a response using the model."""
    if not MODEL_LOADED:
        return "⚠️ Model failed to load. Please check the logs or restart the application."

    try:
        # Build conversation
        conversation = []
        if history:
            # Keep the last 3 exchanges for context
            for user_msg, assistant_msg in history[-3:]:
                if user_msg:
                    conversation.append({"role": "user", "content": user_msg})
                if assistant_msg:
                    conversation.append({"role": "assistant", "content": assistant_msg})
        conversation.append({"role": "user", "content": message})
        # Apply chat template
        try:
            prompt = tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt"
            )
        except Exception as e:
            print(f"[Warning] Chat template failed: {e}, using fallback format")
            # Fallback format
            prompt_text = ""
            for msg in conversation:
                if msg["role"] == "user":
                    prompt_text += f"User: {msg['content']}\n"
                else:
                    prompt_text += f"Assistant: {msg['content']}\n"
            prompt_text += "Assistant: "
            inputs = tokenizer(
                prompt_text,
                return_tensors="pt",
                truncation=True,
                max_length=512
            )
            prompt = inputs.input_ids

        # Move to device
        prompt = prompt.to(model.device if hasattr(model, 'device') else DEVICE)

        # Generate
        print(f"[Generate] Input length: {prompt.shape[-1]} tokens")
        with torch.no_grad():
            outputs = model.generate(
                input_ids=prompt,
                max_new_tokens=min(int(max_new_tokens), 256),
                temperature=float(temperature),
                top_p=float(top_p),
                repetition_penalty=float(repetition_penalty),
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )

        # Decode only the generated tokens
        generated_tokens = outputs[0][prompt.shape[-1]:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        print(f"[Generate] Output length: {len(generated_tokens)} tokens")

        # Cleanup
        del outputs, prompt, generated_tokens
        gc.collect()

        return response

    except Exception as e:
        error_msg = f"⚠️ Error generating response: {str(e)}"
        print(f"[Error] {error_msg}")
        # Try to recover memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return error_msg

# ── Gradio Interface ────────────────────────────────────────────────────────────
examples = [
    ["What is the capital of Goa?"],
    ["Tell me about Konkani language"],
    ["What are famous beaches in Goa?"],
    ["Describe Goan fish curry"],
    ["What is the history of Old Goa?"],
]

# Create interface
if MODEL_LOADED:
    demo = gr.ChatInterface(
        fn=generate_response,
        title=TITLE,
        description=DESCRIPTION,
        examples=examples,
        retry_btn=None,
        undo_btn=None,
        additional_inputs=[
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.05,
                label="Temperature (lower = more focused)"
            ),
            gr.Slider(
                minimum=32,
                maximum=256,
                value=128,
                step=16,
                label="Max new tokens"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)"
            ),
            gr.Slider(
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.05,
                label="Repetition penalty"
            ),
        ],
        theme=gr.themes.Soft(),
    )
else:
    demo = gr.Interface(
        fn=lambda x: "Model failed to load. Check console for errors.",
        inputs=gr.Textbox(label="Message"),
        outputs=gr.Textbox(label="Response"),
        title=TITLE,
        description=DESCRIPTION,
    )

# Queue with version compatibility
try:
    # Try newer Gradio syntax first (4.x)
    demo.queue(default_concurrency_limit=1, max_size=10)
except TypeError:
    try:
        # Fall back to older syntax (3.x)
        demo.queue(concurrency_count=1, max_size=10)
    except TypeError:
        # If both fail, queue without parameters
        demo.queue()

if __name__ == "__main__":
    print("\n" + "=" * 50)
    print(f"🚀 Starting Gradio app on {DEVICE}...")
    print(f"📦 Base model: {BASE_MODEL_ID}")
    print(f"🔧 LoRA adapter: {ADAPTER_ID}")
    print("=" * 50 + "\n")
demo.launch() |
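
# Run locally with `python app.py`; a Hugging Face Space using the Gradio SDK runs this
# file automatically, so no extra entry point is needed.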