Reubencf committed
Commit 6f40086 · verified · 1 Parent(s): b441326

Update app.py

Files changed (1): app.py +209 -87
app.py CHANGED
@@ -1,142 +1,264 @@
- # app.py — Hugging Face Spaces (Gradio) friendly

  import os
  import torch
  import gradio as gr
  from typing import List, Tuple

- from peft import PeftConfig, PeftModel, AutoPeftModelForCausalLM
- from transformers import AutoTokenizer, AutoModelForCausalLM

- # ── Secrets / IDs ──────────────────────────────────────────────────────────────
  HF_TOKEN = os.environ.get("HF_TOKEN")  # set in Space → Settings → Variables & secrets
  ADAPTER_ID = "Reubencf/gemma3-goan-finetuned"  # your LoRA adapter repo

  TITLE = "🌴 Gemma Goan Q&A Bot"
  DESCRIPTION = """
  Gemma-3-4B-Instruct base + LoRA adapter fine-tuned on a Goan Q&A dataset.
  Ask about Goa, Konkani culture, or general topics!

  **Adapter**: https://huggingface.co/Reubencf/gemma3-goan-finetuned
  """

- # ── Load model + tokenizer (robust, Space-safe) ────────────────────────────────
  def load_model_and_tokenizer():
-     # Read the exact base the adapter expects (prevents KeyError on wrong modules)
      peft_cfg = PeftConfig.from_pretrained(ADAPTER_ID, token=HF_TOKEN)
-     base_id = peft_cfg.base_model_name_or_path  # e.g. "google/gemma-3-4b-it"
-     print(f"[Load] Adapter expects base: {base_id}")
-
-     # Try the simple path first: AutoPEFT will fetch the correct base under the hood
      try:
-         model = AutoPeftModelForCausalLM.from_pretrained(
-             ADAPTER_ID,
-             token=HF_TOKEN,
-             trust_remote_code=True,
-             device_map="auto",  # Spaces CPU/GPU both OK
-             low_cpu_mem_usage=True,
-             torch_dtype=(torch.bfloat16 if torch.cuda.is_available() else torch.float32),
-         )
-         print("[Load] AutoPEFT model loaded")
-     except Exception as e:
-         # Fallback: explicit base + attach adapter (also works on CPU Spaces)
-         print(f"[Warn] AutoPEFT failed: {e} — trying explicit base + adapter...")
-         base = AutoModelForCausalLM.from_pretrained(
              base_id,
              token=HF_TOKEN,
              trust_remote_code=True,
-             device_map=None,  # explicit CPU (safer on free tier)
              low_cpu_mem_usage=True,
-             torch_dtype=torch.float32,
-         ).to("cpu")
          model = PeftModel.from_pretrained(
-             base,
              ADAPTER_ID,
              token=HF_TOKEN,
              trust_remote_code=True,
          )
-         print("[Load] Base+adapter attached")
-
-     # Always load tokenizer from the **base model**
-     tok = AutoTokenizer.from_pretrained(
          base_id,
          token=HF_TOKEN,
          use_fast=True,
          trust_remote_code=True,
      )
-     if tok.pad_token is None:
-         tok.pad_token = tok.eos_token
-     tok.padding_side = "right"
-
      model.eval()
-     return model, tok, base_id

- model, tokenizer, BASE_ID = load_model_and_tokenizer()
-
- # ── Generation using chat template ─────────────────────────────────────────────
- def _to_chat(history: List[Tuple[str, str]], message: str):
-     conv = []
-     if history:
-         for u, a in history:
-             if u:
-                 conv.append({"role": "user", "content": u})
-             if a:
-                 conv.append({"role": "assistant", "content": a})
-     conv.append({"role": "user", "content": message})
-     return conv

  def generate_response(
-     message,
-     history,
-     temperature=0.7,
-     max_new_tokens=256,
-     top_p=0.95,
-     repetition_penalty=1.1,
- ):
      try:
-         conv = _to_chat(history, message)
-         # Let Transformers format for Gemma's chat template
-         inputs = tokenizer.apply_chat_template(
-             conv, add_generation_prompt=True, return_tensors="pt"
          )
-         inputs = inputs.to(next(model.parameters()).device)
-
          with torch.no_grad():
-             out = model.generate(
-                 input_ids=inputs,
-                 max_new_tokens=int(max_new_tokens),
                  temperature=float(temperature),
                  top_p=float(top_p),
                  repetition_penalty=float(repetition_penalty),
                  do_sample=True,
                  pad_token_id=tokenizer.pad_token_id,
                  eos_token_id=tokenizer.eos_token_id,
              )
-
-         gen = out[0, inputs.shape[-1]:]
-         return tokenizer.decode(gen, skip_special_tokens=True).strip()
-
      except Exception as e:
-         return f"Error generating response: {e}"

- # ── Gradio app (Spaces auto-detects `demo`) ────────────────────────────────────
  examples = [
      ["What is the capital of Goa?"],
-     ["Tell me about the Konkani language."],
-     ["Famous beaches in Goa?"],
-     ["What is Goan fish curry?"],
-     ["Explain the history of Old Goa."],
  ]

- demo = gr.ChatInterface(
-     fn=generate_response,
-     title=TITLE,
-     description=DESCRIPTION,
-     examples=examples,
-     additional_inputs=[
-         gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="Temperature"),
-         gr.Slider(32, 512, value=256, step=16, label="Max new tokens"),
-         gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
-         gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition penalty"),
-     ],
- ).queue(concurrency_count=1, max_size=16)
+ # app.py — Optimized for Hugging Face Spaces Free Tier (CPU-only)

  import os
+ import gc
  import torch
  import gradio as gr
  from typing import List, Tuple

+ from peft import PeftConfig, PeftModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

+ # ── Configuration ──────────────────────────────────────────────────────────────
  HF_TOKEN = os.environ.get("HF_TOKEN")  # set in Space → Settings → Variables & secrets
  ADAPTER_ID = "Reubencf/gemma3-goan-finetuned"  # your LoRA adapter repo

+ # Free tier optimization flags
+ USE_8BIT = False     # Set to True if you have access to GPU tier
+ MAX_MEMORY = "15GB"  # Conservative for free tier
+ DEVICE = "cpu"       # Force CPU for free tier
+
  TITLE = "🌴 Gemma Goan Q&A Bot"
  DESCRIPTION = """
  Gemma-3-4B-Instruct base + LoRA adapter fine-tuned on a Goan Q&A dataset.
  Ask about Goa, Konkani culture, or general topics!

  **Adapter**: https://huggingface.co/Reubencf/gemma3-goan-finetuned
+
+ ⚠️ **Note**: Running on free tier (CPU). Responses may be slower. For faster inference, consider upgrading to GPU tier.
  """
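
Note on the USE_8BIT flag above: on a GPU tier, a 4-bit NF4 configuration is another option the same BitsAndBytesConfig import supports. A minimal sketch, with illustrative settings that are not part of this commit:

    # Hypothetical 4-bit alternative to USE_8BIT for a GPU tier (not used in this commit)
    import torch
    from transformers import BitsAndBytesConfig

    bnb_4bit = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    # Passed the same way as the 8-bit config:
    # AutoModelForCausalLM.from_pretrained(base_id, quantization_config=bnb_4bit, device_map="auto")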
 
+ # ── Load model + tokenizer (optimized for free tier) ───────────────────────────
  def load_model_and_tokenizer():
+     """Load model with memory optimizations for free tier"""
+
+     print("[Init] Starting model load for free tier...")
+
+     # Get the base model ID from adapter config
      peft_cfg = PeftConfig.from_pretrained(ADAPTER_ID, token=HF_TOKEN)
+     base_id = peft_cfg.base_model_name_or_path
+     print(f"[Load] Base model: {base_id}")
+
+     # Memory cleanup before loading
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
      try:
+         # Load base model with memory optimizations
+         print("[Load] Loading base model with CPU optimizations...")
+
+         # Quantization config (only if GPU available and enabled)
+         quantization_config = None
+         if USE_8BIT and torch.cuda.is_available():
+             quantization_config = BitsAndBytesConfig(
+                 load_in_8bit=True,
+                 bnb_8bit_compute_dtype=torch.float16
+             )
+
+         # Load base model
+         base_model = AutoModelForCausalLM.from_pretrained(
              base_id,
              token=HF_TOKEN,
              trust_remote_code=True,
+             quantization_config=quantization_config,
              low_cpu_mem_usage=True,
+             torch_dtype=torch.float32 if DEVICE == "cpu" else torch.float16,
+             device_map=None,  # We'll move manually
+             max_memory={0: MAX_MEMORY} if torch.cuda.is_available() else None,
+         )
+
+         # Move to device
+         if DEVICE == "cpu":
+             base_model = base_model.to("cpu")
+             print("[Load] Model moved to CPU")
+
+         # Load and apply LoRA adapter
+         print("[Load] Loading LoRA adapter...")
          model = PeftModel.from_pretrained(
+             base_model,
              ADAPTER_ID,
              token=HF_TOKEN,
              trust_remote_code=True,
+             is_trainable=False,  # Inference only
          )
+
+         # Merge adapter with base (reduces memory overhead during inference)
+         print("[Load] Merging adapter for efficiency...")
+         model = model.merge_and_unload()
+
+         print("[Load] Model loaded successfully!")
+
+     except Exception as e:
+         print(f"[Error] Failed to load model: {e}")
+         raise gr.Error(
+             f"Failed to load model. This may be due to memory constraints on free tier. "
+             f"Consider using a smaller model or upgrading to GPU tier. Error: {str(e)}"
+         )
+
+     # Load tokenizer
+     print("[Load] Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(
          base_id,
          token=HF_TOKEN,
          use_fast=True,
          trust_remote_code=True,
      )
+
+     # Set padding token
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "left"  # Better for generation
+
+     # Set model to eval mode
      model.eval()
+
+     # Memory cleanup
+     gc.collect()
+
+     return model, tokenizer, base_id
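
Since merge_and_unload() re-runs on every Space restart, one option is to merge once offline and load the merged weights directly at startup. A minimal sketch, not part of this commit; the output directory name is hypothetical, and the base ID is the one the adapter config reports (e.g. google/gemma-3-4b-it):

    # merge_once.py — hypothetical offline helper, run locally once (not in this commit)
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Pass token=... here if the base model is gated on the Hub
    base = AutoModelForCausalLM.from_pretrained("google/gemma-3-4b-it", torch_dtype=torch.float32)
    merged = PeftModel.from_pretrained(base, "Reubencf/gemma3-goan-finetuned").merge_and_unload()
    merged.save_pretrained("gemma3-goan-merged")  # hypothetical output directory
    AutoTokenizer.from_pretrained("google/gemma-3-4b-it").save_pretrained("gemma3-goan-merged")
    # The Space could then load this folder (or a pushed copy) with AutoModelForCausalLM
    # and skip the PeftModel / merge step at startup entirely.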
 
+ # Load model globally (done once at startup)
+ try:
+     model, tokenizer, BASE_ID = load_model_and_tokenizer()
+     MODEL_LOADED = True
+ except Exception as e:
+     print(f"[Fatal] Could not load model: {e}")
+     MODEL_LOADED = False
+     model, tokenizer, BASE_ID = None, None, None
 
+ # ── Generation function ─────────────────────────────────────────────────────────
  def generate_response(
+     message: str,
+     history: List[Tuple[str, str]],
+     temperature: float = 0.7,
+     max_new_tokens: int = 256,
+     top_p: float = 0.95,
+     repetition_penalty: float = 1.1,
+ ) -> str:
+     """Generate response using the fine-tuned model"""
+
+     if not MODEL_LOADED:
+         return "⚠️ Model failed to load. This usually happens due to memory constraints on the free tier. Please try again later or contact the space owner."
+
      try:
+         # Build conversation history
+         conversation = []
+         for user_msg, assistant_msg in history:
+             if user_msg:
+                 conversation.append({"role": "user", "content": user_msg})
+             if assistant_msg:
+                 conversation.append({"role": "assistant", "content": assistant_msg})
+         conversation.append({"role": "user", "content": message})
+
+         # Apply chat template
+         prompt = tokenizer.apply_chat_template(
+             conversation,
+             add_generation_prompt=True,
+             return_tensors="pt"
          )
+
+         # Move to model device
+         prompt = prompt.to(model.device)
+
+         # Generate with memory-efficient settings
          with torch.no_grad():
+             # Use cache for faster generation
+             outputs = model.generate(
+                 input_ids=prompt,
+                 max_new_tokens=min(int(max_new_tokens), 256),  # Cap for free tier
                  temperature=float(temperature),
                  top_p=float(top_p),
                  repetition_penalty=float(repetition_penalty),
                  do_sample=True,
                  pad_token_id=tokenizer.pad_token_id,
                  eos_token_id=tokenizer.eos_token_id,
+                 use_cache=True,  # Enable KV cache
              )
+
+         # Decode only the generated tokens
+         generated_tokens = outputs[0][prompt.shape[-1]:]
+         response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
+
+         # Memory cleanup after generation
+         del outputs, prompt, generated_tokens
+         gc.collect()
+
+         return response
+
+     except torch.cuda.OutOfMemoryError:
+         gc.collect()
+         torch.cuda.empty_cache()
+         return "⚠️ Out of memory. Try reducing max_new_tokens or restarting the space."
      except Exception as e:
+         return f"⚠️ Error generating response: {str(e)}"
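
On CPU, the reply only appears once the whole sequence has been generated; a streaming variant of generate_response is one way to surface partial output sooner, since gr.ChatInterface accepts a generator that yields the growing reply. A minimal sketch using transformers' TextIteratorStreamer (the function name is hypothetical and prior-turn handling is trimmed for brevity):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def generate_response_streaming(message, history, temperature=0.7, max_new_tokens=256,
                                    top_p=0.95, repetition_penalty=1.1):
        # Hypothetical streaming counterpart to generate_response above
        conversation = [{"role": "user", "content": message}]  # prior turns omitted for brevity
        prompt = tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = dict(
            input_ids=prompt,
            max_new_tokens=min(int(max_new_tokens), 256),
            temperature=float(temperature),
            top_p=float(top_p),
            repetition_penalty=float(repetition_penalty),
            do_sample=True,
            streamer=streamer,
        )
        Thread(target=model.generate, kwargs=gen_kwargs).start()  # generation runs in the background
        partial = ""
        for chunk in streamer:  # yields decoded text pieces as they are produced
            partial += chunk
            yield partial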
 
+ # ── Gradio Interface ────────────────────────────────────────────────────────────
  examples = [
      ["What is the capital of Goa?"],
+     ["Tell me about Konkani language"],
+     ["What are the famous beaches in Goa?"],
+     ["Describe Goan fish curry"],
+     ["What is the history of Old Goa?"],
  ]

+ # Create the chat interface
+ if MODEL_LOADED:
+     demo = gr.ChatInterface(
+         fn=generate_response,
+         title=TITLE,
+         description=DESCRIPTION,
+         examples=examples,
+         retry_btn=None,  # Disable retry to save memory
+         undo_btn=None,   # Disable undo to save memory
+         additional_inputs=[
+             gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.7,
+                 step=0.05,
+                 label="Temperature (lower = more focused)"
+             ),
+             gr.Slider(
+                 minimum=32,
+                 maximum=256,
+                 value=128,  # Reduced default for free tier
+                 step=16,
+                 label="Max new tokens"
+             ),
+             gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.95,
+                 step=0.05,
+                 label="Top-p (nucleus sampling)"
+             ),
+             gr.Slider(
+                 minimum=1.0,
+                 maximum=2.0,
+                 value=1.1,
+                 step=0.05,
+                 label="Repetition penalty"
+             ),
+         ],
+         theme=gr.themes.Soft(),
+     )
+ else:
+     # Fallback interface if model fails to load
+     demo = gr.Interface(
+         fn=lambda x: "⚠️ Model failed to load. Please check the logs or try restarting the space.",
+         inputs=gr.Textbox(label="Message"),
+         outputs=gr.Textbox(label="Response"),
+         title=TITLE,
+         description="**Error**: Model could not be loaded. This is likely due to memory constraints on the free tier.",
+     )
+
+ # Queue for handling multiple users
+ demo.queue(
+     concurrency_count=1,  # Process one at a time to save memory
+     max_size=10,  # Reduced queue size for free tier
+ )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch()
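
One compatibility caveat on the queue block above: concurrency_count is a Gradio 3.x argument and was removed from queue() in Gradio 4, and retry_btn/undo_btn were dropped from ChatInterface in later releases. If the Space pins Gradio 4 or newer, the closing block would look roughly like this sketch (parameter names per Gradio 4; verify against the pinned version):

    # Hypothetical Gradio 4+ variant of the queue/launch block above (not in this commit)
    demo.queue(
        default_concurrency_limit=1,  # replaces concurrency_count
        max_size=10,
    )

    if __name__ == "__main__":
        demo.launch()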