import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
import torch

# Your model details
PEFT_MODEL_ID = "Reubencf/gemma3-goan-finetuned"
BASE_MODEL_ID = "google/gemma-2-2b-it"  # Base model used for fine-tuning

# UI Configuration
TITLE = "🌴 Gemma Goan Q&A Bot"
DESCRIPTION = """
This is a Gemma-2-2B model fine-tuned on a Goan Q&A dataset using LoRA.
Ask questions about Goa, Konkani culture, or general topics!

**Model**: [Reubencf/gemma3-goan-finetuned](https://huggingface.co/Reubencf/gemma3-goan-finetuned)

**Base Model**: google/gemma-2-2b-it
"""

print("Loading model... This might take a few minutes on first run.")

try:
    # Load LoRA config to check base model
    peft_config = PeftConfig.from_pretrained(PEFT_MODEL_ID)

    # Load base model
    print(f"Loading base model: {BASE_MODEL_ID}")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        low_cpu_mem_usage=True,
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Load LoRA adapter
    print(f"Loading LoRA adapter: {PEFT_MODEL_ID}")
    model = PeftModel.from_pretrained(
        base_model,
        PEFT_MODEL_ID,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )

    # Set to evaluation mode
    model.eval()
    print("✅ Model loaded successfully!")

except Exception as e:
    print(f"Error loading model: {e}")
    print("Trying alternative loading method...")

    # Alternative: Try loading as AutoPeftModel
    from peft import AutoPeftModelForCausalLM

    model = AutoPeftModelForCausalLM.from_pretrained(
        PEFT_MODEL_ID,
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(PEFT_MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token


def generate_response(
    message,
    history,
    temperature=0.7,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.1,
):
    """Generate response using the fine-tuned model"""

    # Format the prompt using the Gemma chat template
    # (Gemma marks turns with <start_of_turn>/<end_of_turn> tokens)
    if history:
        # Build conversation history
        conversation = ""
        for user, assistant in history:
            conversation += f"<start_of_turn>user\n{user}<end_of_turn>\n"
            conversation += f"<start_of_turn>model\n{assistant}<end_of_turn>\n"
        conversation += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    else:
        # Single turn conversation
        conversation = f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"

    # Tokenize
    inputs = tokenizer(
        conversation,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    )

    # Move to device
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the generated portion
        generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Clean up response
        response = response.replace("<end_of_turn>", "").strip()

        return response

    except Exception as e:
        return f"Error generating response: {str(e)}"


# Example questions
examples = [
    ["What is Bebinca?"],
    ["Who is Pramod Sawant?"],
    ["Explain the history of Old Goa"],
    ["What are some popular festivals in Goa?"],
]
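
# Optional sanity check (not part of the original app): uncomment to verify
# generation works outside the Gradio UI. The prompt below is just one of the
# example questions above and is purely illustrative.
# print(generate_response("What is Bebinca?", history=[]))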

# Custom CSS for better appearance
custom_css = """
#component-0 {
    max-width: 900px;
    margin: auto;
}
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
"""

# Create Gradio Chat Interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    chatbot = gr.Chatbot(
        height=450,
        show_label=False,
        avatar_images=(None, "🤖"),
    )

    msg = gr.Textbox(
        label="Ask a question",
        placeholder="Type your question about Goa, Konkani culture, or any topic...",
        lines=2,
    )

    with gr.Accordion("⚙️ Generation Settings", open=False):
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
            label="Temperature (Creativity)",
            info="Higher = more creative, Lower = more focused",
        )
        max_tokens = gr.Slider(
            minimum=50,
            maximum=512,
            value=256,
            step=10,
            label="Max New Tokens",
            info="Maximum length of the response",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (Nucleus Sampling)",
        )
        rep_penalty = gr.Slider(
            minimum=1.0,
            maximum=2.0,
            value=1.1,
            step=0.1,
            label="Repetition Penalty",
        )

    with gr.Row():
        clear = gr.Button("🗑️ Clear")
        submit = gr.Button("📤 Send", variant="primary")

    gr.Examples(
        examples=examples,
        inputs=msg,
        label="Example Questions:",
    )

    # Set up event handlers
    def user(user_message, history):
        return "", history + [[user_message, None]]

    def bot(history, temp, max_tok, top_p_val, rep_pen):
        user_message = history[-1][0]
        bot_response = generate_response(
            user_message,
            history[:-1],
            temperature=temp,
            max_new_tokens=max_tok,
            top_p=top_p_val,
            repetition_penalty=rep_pen,
        )
        history[-1][1] = bot_response
        return history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, temperature, max_tokens, top_p, rep_penalty], chatbot
    )
    submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, temperature, max_tokens, top_p, rep_penalty], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

    gr.Markdown("""
    ---
    ### 📝 Note
    This model is fine-tuned specifically on Goan Q&A data.
    Responses are generated based on patterns learned from the training dataset.
    For best results, ask questions about Goa, its culture, history, cuisine, and related topics.
    """)

if __name__ == "__main__":
    demo.launch()
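
# To run this app locally (a sketch; the filename and unpinned package list are
# assumptions, not specified by the original script):
#   pip install gradio transformers peft torch accelerate
#   python app.py
# `accelerate` is required because the model is loaded with device_map="auto".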