Reubencf committed on
Commit 13a64ed · verified · 1 Parent(s): 77eea02

Update app.py

Files changed (1): app.py (+34 −112)
app.py CHANGED
@@ -2,10 +2,14 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel, PeftConfig
 import torch
+import os
+
+# Get token from Space secrets
+HF_TOKEN = os.environ.get("HF_TOKEN")
 
 # Your model details
 PEFT_MODEL_ID = "Reubencf/gemma3-goan-finetuned"
-BASE_MODEL_ID = "google/gemma-2-2b-it"  # Base model used for fine-tuning
+BASE_MODEL_ID = "google/gemma-3-4b-it"  # Correct base model
 
 # UI Configuration
 TITLE = "🌴 Gemma Goan Q&A Bot"
@@ -14,7 +18,6 @@ This is a Gemma-2-2B model fine-tuned on Goan Q&A dataset using LoRA.
 Ask questions about Goa, Konkani culture, or general topics!
 
 **Model**: [Reubencf/gemma3-goan-finetuned](https://huggingface.co/Reubencf/gemma3-goan-finetuned)
-**Base Model**: google/gemma-2-2b-it
 """
 
 print("Loading model... This might take a few minutes on first run.")
@@ -23,17 +26,21 @@ try:
     # Load LoRA config to check base model
     peft_config = PeftConfig.from_pretrained(PEFT_MODEL_ID)
 
-    # Load base model
+    # Load base model WITH TOKEN
     print(f"Loading base model: {BASE_MODEL_ID}")
     base_model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL_ID,
+        token=HF_TOKEN,  # Add token here
         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
         device_map="auto",
        low_cpu_mem_usage=True,
     )
 
-    # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
+    # Load tokenizer WITH TOKEN
+    tokenizer = AutoTokenizer.from_pretrained(
+        BASE_MODEL_ID,
+        token=HF_TOKEN  # Add token here
+    )
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "right"
@@ -58,11 +65,15 @@ except Exception as e:
     from peft import AutoPeftModelForCausalLM
     model = AutoPeftModelForCausalLM.from_pretrained(
         PEFT_MODEL_ID,
+        token=HF_TOKEN,  # Add token here
         device_map="auto",
         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
         low_cpu_mem_usage=True,
     )
-    tokenizer = AutoTokenizer.from_pretrained(PEFT_MODEL_ID)
+    tokenizer = AutoTokenizer.from_pretrained(
+        PEFT_MODEL_ID,
+        token=HF_TOKEN  # Add token here
+    )
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
@@ -78,14 +89,12 @@ def generate_response(
 
     # Format the prompt using Gemma chat template
     if history:
-        # Build conversation history
         conversation = ""
         for user, assistant in history:
             conversation += f"<start_of_turn>user\n{user}<end_of_turn>\n"
             conversation += f"<start_of_turn>model\n{assistant}<end_of_turn>\n"
         conversation += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
     else:
-        # Single turn conversation
         conversation = f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
 
     # Tokenize
@@ -128,113 +137,26 @@
 
 # Example questions
 examples = [
-    ["What is Bebinca?"],
-    ["who is promod sawant?"],
+    ["What is the capital of Goa?"],
+    ["Tell me about Konkani language"],
+    ["What are famous beaches in Goa?"],
+    ["What is Goan fish curry?"],
     ["Explain the history of Old Goa"],
-    ["What are some popular festivals in Goa?"],
 ]
 
-# Custom CSS for better appearance
-custom_css = """
-#component-0 {
-    max-width: 900px;
-    margin: auto;
-}
-.gradio-container {
-    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
-}
-"""
-
-# Create Gradio Chat Interface
-with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
-    gr.Markdown(f"# {TITLE}")
-    gr.Markdown(DESCRIPTION)
-
-    chatbot = gr.Chatbot(
-        height=450,
-        show_label=False,
-        avatar_images=(None, "🤖"),
-    )
-
-    msg = gr.Textbox(
-        label="Ask a question",
-        placeholder="Type your question about Goa, Konkani culture, or any topic...",
-        lines=2,
-    )
-
-    with gr.Accordion("⚙️ Generation Settings", open=False):
-        temperature = gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.7,
-            step=0.1,
-            label="Temperature (Creativity)",
-            info="Higher = more creative, Lower = more focused"
-        )
-        max_tokens = gr.Slider(
-            minimum=50,
-            maximum=512,
-            value=256,
-            step=10,
-            label="Max New Tokens",
-            info="Maximum length of the response"
-        )
-        top_p = gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (Nucleus Sampling)",
-        )
-        rep_penalty = gr.Slider(
-            minimum=1.0,
-            maximum=2.0,
-            value=1.1,
-            step=0.1,
-            label="Repetition Penalty",
-        )
-
-    with gr.Row():
-        clear = gr.Button("🗑️ Clear")
-        submit = gr.Button("📤 Send", variant="primary")
-
-    gr.Examples(
-        examples=examples,
-        inputs=msg,
-        label="Example Questions:",
-    )
-
-    # Set up event handlers
-    def user(user_message, history):
-        return "", history + [[user_message, None]]
-
-    def bot(history, temp, max_tok, top_p_val, rep_pen):
-        user_message = history[-1][0]
-        bot_response = generate_response(
-            user_message,
-            history[:-1],
-            temperature=temp,
-            max_new_tokens=max_tok,
-            top_p=top_p_val,
-            repetition_penalty=rep_pen,
-        )
-        history[-1][1] = bot_response
-        return history
-
-    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot, [chatbot, temperature, max_tokens, top_p, rep_penalty], chatbot
-    )
-    submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot, [chatbot, temperature, max_tokens, top_p, rep_penalty], chatbot
-    )
-    clear.click(lambda: None, None, chatbot, queue=False)
-
-    gr.Markdown("""
-    ---
-    ### 📝 Note
-    This model is fine-tuned specifically on Goan Q&A data. Responses are generated based on patterns learned from the training dataset.
-    For best results, ask questions about Goa, its culture, history, cuisine, and related topics.
-    """)
+# Create Gradio interface
+demo = gr.ChatInterface(
+    fn=generate_response,
+    title=TITLE,
+    description=DESCRIPTION,
+    examples=examples,
+    additional_inputs=[
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(minimum=1, maximum=512, value=256, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
+        gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty"),
+    ],
+)
 
 if __name__ == "__main__":
     demo.launch()
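
A minimal sketch for smoke-testing the token-authenticated loading path that this commit introduces, before pushing to the Space. It is not part of the commit: it assumes HF_TOKEN is exported in the local environment, that the account has accepted the gated Gemma license, and that peft and transformers are recent enough to accept the token keyword; the model IDs come from the diff above.

    # Local smoke test for the HF_TOKEN loading path (illustrative, not in the commit).
    import os

    from peft import PeftConfig
    from transformers import AutoTokenizer

    HF_TOKEN = os.environ.get("HF_TOKEN")
    PEFT_MODEL_ID = "Reubencf/gemma3-goan-finetuned"

    # Read the adapter config first: base_model_name_or_path records which base
    # checkpoint the LoRA was trained on, so a mismatch with the hard-coded
    # BASE_MODEL_ID shows up before any large download starts.
    peft_config = PeftConfig.from_pretrained(PEFT_MODEL_ID, token=HF_TOKEN)
    print("Adapter trained on:", peft_config.base_model_name_or_path)

    # The tokenizer download is small and fails fast if the token lacks access.
    tokenizer = AutoTokenizer.from_pretrained(
        peft_config.base_model_name_or_path, token=HF_TOKEN
    )
    prompt = "<start_of_turn>user\nWhat is the capital of Goa?<end_of_turn>\n<start_of_turn>model\n"
    print("Prompt tokenizes to", len(tokenizer(prompt).input_ids), "tokens")

If the printed base checkpoint does not match BASE_MODEL_ID in app.py, the adapter and base model are out of sync, which is exactly the kind of mismatch this commit's BASE_MODEL_ID change addresses.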