Reubencf committed on
Commit a18b9c3 · verified · 1 Parent(s): abc6167

Update app.py

Files changed (1)
  1. app.py +224 -54
app.py CHANGED
@@ -1,70 +1,240 @@
  import gradio as gr
- from huggingface_hub import InferenceClient


- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
-     """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-     """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-     messages = [{"role": "system", "content": system_message}]
-
-     messages.extend(history)

-     messages.append({"role": "user", "content": message})

-     response = ""

-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content

-         response += token
-         yield response


  """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
              minimum=0.1,
              maximum=1.0,
              value=0.95,
              step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
- with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()
-

  if __name__ == "__main__":
-     demo.launch()

  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel, PeftConfig
+ import torch
+
+ # Your model details
+ PEFT_MODEL_ID = "Reubencf/gemma3-goan-finetuned"
+ BASE_MODEL_ID = "google/gemma-2-2b-it"  # Base model used for fine-tuning
+
+ # UI Configuration
+ TITLE = "🌴 Gemma Goan Q&A Bot"
+ DESCRIPTION = """
+ This is a Gemma-2-2B model fine-tuned on a Goan Q&A dataset using LoRA.
+ Ask questions about Goa, Konkani culture, or general topics!
+
+ **Model**: [Reubencf/gemma3-goan-finetuned](https://huggingface.co/Reubencf/gemma3-goan-finetuned)
+ **Base Model**: google/gemma-2-2b-it
+ """
+
+ print("Loading model... This might take a few minutes on first run.")
+
+ try:
+     # Load LoRA config to check base model
+     peft_config = PeftConfig.from_pretrained(PEFT_MODEL_ID)
+
+     # Load base model
+     print(f"Loading base model: {BASE_MODEL_ID}")
+     base_model = AutoModelForCausalLM.from_pretrained(
+         BASE_MODEL_ID,
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         device_map="auto",
+         low_cpu_mem_usage=True,
+     )
+
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "right"
+
+     # Load LoRA adapter
+     print(f"Loading LoRA adapter: {PEFT_MODEL_ID}")
+     model = PeftModel.from_pretrained(
+         base_model,
+         PEFT_MODEL_ID,
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     )
+
+     # Set to evaluation mode
+     model.eval()
+     print("✅ Model loaded successfully!")
+
+ except Exception as e:
+     print(f"Error loading model: {e}")
+     print("Trying alternative loading method...")
+
+     # Alternative: Try loading as AutoPeftModel
+     from peft import AutoPeftModelForCausalLM
+     model = AutoPeftModelForCausalLM.from_pretrained(
+         PEFT_MODEL_ID,
+         device_map="auto",
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         low_cpu_mem_usage=True,
+     )
+     tokenizer = AutoTokenizer.from_pretrained(PEFT_MODEL_ID)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token

+ def generate_response(
+     message,
+     history,
+     temperature=0.7,
+     max_new_tokens=256,
+     top_p=0.95,
+     repetition_penalty=1.1,
+ ):
+     """Generate response using the fine-tuned model"""
+
+     # Format the prompt using Gemma chat template
+     if history:
+         # Build conversation history
+         conversation = ""
+         for user, assistant in history:
+             conversation += f"<start_of_turn>user\n{user}<end_of_turn>\n"
+             conversation += f"<start_of_turn>model\n{assistant}<end_of_turn>\n"
+         conversation += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
+     else:
+         # Single turn conversation
+         conversation = f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
+
+     # Tokenize
+     inputs = tokenizer(
+         conversation,
+         return_tensors="pt",
+         truncation=True,
+         max_length=1024
+     )
+
+     # Move to device
+     device = next(model.parameters()).device
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     # Generate
+     try:
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=True,
+                 pad_token_id=tokenizer.pad_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+             )
+
+         # Decode only the generated portion
+         generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
+         response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+         # Clean up response
+         response = response.replace("<end_of_turn>", "").strip()
+
+         return response
+
+     except Exception as e:
+         return f"Error generating response: {str(e)}"

+ # Example questions
+ examples = [
+     ["What is Bebinca?"],
+     ["Who is Pramod Sawant?"],
+     ["Explain the history of Old Goa"],
+     ["What are some popular festivals in Goa?"],
+ ]
+
+ # Custom CSS for better appearance
+ custom_css = """
+ #component-0 {
+     max-width: 900px;
+     margin: auto;
+ }
+ .gradio-container {
+     font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
+ }
  """
+
+ # Create Gradio Chat Interface
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
+     gr.Markdown(f"# {TITLE}")
+     gr.Markdown(DESCRIPTION)
+
+     chatbot = gr.Chatbot(
+         height=450,
+         show_label=False,
+         avatar_images=(None, "🤖"),
+     )
+
+     msg = gr.Textbox(
+         label="Ask a question",
+         placeholder="Type your question about Goa, Konkani culture, or any topic...",
+         lines=2,
+     )
+
+     with gr.Accordion("⚙️ Generation Settings", open=False):
+         temperature = gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.7,
+             step=0.1,
+             label="Temperature (Creativity)",
+             info="Higher = more creative, Lower = more focused"
+         )
+         max_tokens = gr.Slider(
+             minimum=50,
+             maximum=512,
+             value=256,
+             step=10,
+             label="Max New Tokens",
+             info="Maximum length of the response"
+         )
+         top_p = gr.Slider(
              minimum=0.1,
              maximum=1.0,
              value=0.95,
              step=0.05,
+             label="Top-p (Nucleus Sampling)",
+         )
+         rep_penalty = gr.Slider(
+             minimum=1.0,
+             maximum=2.0,
+             value=1.1,
+             step=0.1,
+             label="Repetition Penalty",
+         )
+
+     with gr.Row():
+         clear = gr.Button("🗑️ Clear")
+         submit = gr.Button("📤 Send", variant="primary")
+
+     gr.Examples(
+         examples=examples,
+         inputs=msg,
+         label="Example Questions:",
+     )
+
+     # Set up event handlers
+     def user(user_message, history):
+         return "", history + [[user_message, None]]
+
+     def bot(history, temp, max_tok, top_p_val, rep_pen):
+         user_message = history[-1][0]
+         bot_response = generate_response(
+             user_message,
+             history[:-1],
+             temperature=temp,
+             max_new_tokens=max_tok,
+             top_p=top_p_val,
+             repetition_penalty=rep_pen,
+         )
+         history[-1][1] = bot_response
+         return history
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, [chatbot, temperature, max_tokens, top_p, rep_penalty], chatbot
+     )
+     submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, [chatbot, temperature, max_tokens, top_p, rep_penalty], chatbot
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+     gr.Markdown("""
+     ---
+     ### 📝 Note
+     This model is fine-tuned specifically on Goan Q&A data. Responses are generated based on patterns learned from the training dataset.
+     For best results, ask questions about Goa, its culture, history, cuisine, and related topics.
+     """)

  if __name__ == "__main__":
+     demo.launch()
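
For reference, below is a minimal standalone sketch (not part of the commit) of the same inference path the updated app.py uses: it loads the adapter via the fallback AutoPeftModelForCausalLM route and builds the same hand-written Gemma chat-template prompt. Model IDs and sampling values are copied from the code above; it assumes torch, transformers, and peft are installed, and that the adapter repo ships tokenizer files (as the fallback path in app.py assumes) — otherwise load the tokenizer from the base model instead.

# Standalone inference sketch; assumptions noted above, values taken from app.py.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

PEFT_MODEL_ID = "Reubencf/gemma3-goan-finetuned"  # adapter repo used by app.py

tokenizer = AutoTokenizer.from_pretrained(PEFT_MODEL_ID)
model = AutoPeftModelForCausalLM.from_pretrained(
    PEFT_MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model.eval()

# Same Gemma chat template generate_response() builds by hand.
prompt = "<start_of_turn>user\nWhat is Bebinca?<end_of_turn>\n<start_of_turn>model\n"
inputs = tokenizer(prompt, return_tensors="pt")
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.1,
    )

# Decode only the newly generated tokens, as app.py does.
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))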