Reubencf committed on
Commit baaf291 · verified
1 Parent(s): ae129eb

Update app.py

Files changed (1)
  1. app.py +204 -296
app.py CHANGED
@@ -1,223 +1,177 @@
-# app.py — Corrected for proper LoRA adapter loading

 import os
 import gc
 import torch
 import gradio as gr
-from typing import List, Tuple
-import warnings
-warnings.filterwarnings('ignore')

 try:
     from peft import PeftConfig, PeftModel
-    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
     IMPORTS_OK = True
-except ImportError as e:
     IMPORTS_OK = False
     print(f"Missing dependencies: {e}")
-    print("Please install: pip install transformers peft torch gradio accelerate")

 # ── Configuration ──────────────────────────────────────────────────────────────
-HF_TOKEN = os.environ.get("HF_TOKEN") # Optional for public models
-
-# Your LoRA adapter location (HuggingFace repo or local path)
-ADAPTER_ID = "Reubencf/gemma3-goan-finetuned"
-# For local adapter: ADAPTER_ID = "./path/to/your/adapter"
-
-# Base model - MUST match what you used for fine-tuning!
-# Check your adapter's config.json for "base_model_name_or_path"
-BASE_MODEL_ID = "google/gemma-3-4b-it" # Change this to your actual base model
-# Common options:
-# - "google/gemma-2b-it" (2B parameters, easier on memory)
-# - "unsloth/gemma-2-2b-it-bnb-4bit" (quantized version)
-# - Your actual base model used for training
-
-# Settings
-USE_8BIT = False # Set to True if you have GPU and want to use 8-bit quantization
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

 TITLE = "🌴 Gemma Goan Q&A Bot"
-DESCRIPTION = """
-Gemma base model + LoRA adapter fine-tuned on a Goan Q&A dataset.
-Ask about Goa, Konkani culture, or general topics!
-
-**Status**: {}
-"""

-# ── Load model + tokenizer (correct LoRA loading) ──────────────────────────────
 def load_model_and_tokenizer():
-    """Load base model and apply LoRA adapter correctly"""
-
     if not IMPORTS_OK:
-        raise ImportError("Required packages not installed")
-
-    print("[Init] Starting model load...")
-    print(f"[Config] Base model: {BASE_MODEL_ID}")
-    print(f"[Config] LoRA adapter: {ADAPTER_ID}")
     print(f"[Config] Device: {DEVICE}")
-
-    # Memory cleanup
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-
-    status = ""
-    model = None
-    tokenizer = None
-
     try:
-        # Step 1: Try to read adapter config to get the correct base model
-        actual_base_model = BASE_MODEL_ID
-        try:
-            print(f"[Load] Checking adapter configuration...")
-            peft_config = PeftConfig.from_pretrained(ADAPTER_ID, token=HF_TOKEN)
-            actual_base_model = peft_config.base_model_name_or_path
             print(f"[Load] Adapter expects base model: {actual_base_model}")
-
-            # Warn if mismatch
-            if actual_base_model != BASE_MODEL_ID:
-                print(f"[Warning] BASE_MODEL_ID ({BASE_MODEL_ID}) doesn't match adapter's base ({actual_base_model})")
-                print(f"[Load] Using adapter's base model: {actual_base_model}")
-        except Exception as e:
-            print(f"[Warning] Cannot read adapter config: {e}")
-            print(f"[Load] Will try with configured base model: {BASE_MODEL_ID}")
-            actual_base_model = BASE_MODEL_ID
-
-        # Step 2: Load the BASE MODEL (not the adapter!)
-        print(f"[Load] Loading base model: {actual_base_model}")
-
-        # Quantization config for GPU
-        quantization_config = None
-        if USE_8BIT and torch.cuda.is_available():
-            print("[Load] Using 8-bit quantization")
-            quantization_config = BitsAndBytesConfig(
-                load_in_8bit=True,
-                bnb_8bit_compute_dtype=torch.float16
-            )
-
-        # Load base model
-        base_model = AutoModelForCausalLM.from_pretrained(
-            actual_base_model,
-            token=HF_TOKEN,
-            trust_remote_code=True,
-            quantization_config=quantization_config,
-            low_cpu_mem_usage=True,
-            torch_dtype=torch.float32 if DEVICE == "cpu" else torch.float16,
-            device_map="auto" if torch.cuda.is_available() else None,
-        )
-
-        # Move to device if needed
-        if DEVICE == "cpu" and not torch.cuda.is_available():
-            base_model = base_model.to("cpu")
-            print("[Load] Model on CPU")
-
-        print("[Load] Base model loaded successfully")
-
-        # Step 3: Load tokenizer from the BASE MODEL
-        print(f"[Load] Loading tokenizer from base model...")
-        tokenizer = AutoTokenizer.from_pretrained(
-            actual_base_model,
-            token=HF_TOKEN,
-            use_fast=True,
-            trust_remote_code=True,
-        )
-
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.padding_side = "left"
-
-        # Step 4: Try to apply LoRA adapter
-        try:
-            print(f"[Load] Applying LoRA adapter: {ADAPTER_ID}")
-            model = PeftModel.from_pretrained(
-                base_model,
-                ADAPTER_ID,
-                token=HF_TOKEN,
-                trust_remote_code=True,
-                is_trainable=False, # Inference only
-            )
-
-            # Optional: Merge adapter with base model for faster inference
-            # This combines the weights permanently (uses more memory initially but faster inference)
-            merge = input("\n💡 Merge adapter for faster inference? (y/n, default=y): ").strip().lower()
-            if merge != 'n':
-                print("[Load] Merging adapter with base model...")
-                model = model.merge_and_unload()
-                print("[Load] Adapter merged successfully")
-                status = f"✅ Using fine-tuned model (merged): {ADAPTER_ID}"
-            else:
-                print("[Load] Using adapter without merging")
-                status = f"✅ Using fine-tuned model: {ADAPTER_ID}"
-
-        except FileNotFoundError as e:
-            print(f"[Error] Adapter files not found: {e}")
-            print("[Fallback] Using base model without fine-tuning")
-            model = base_model
-            status = f"⚠️ Adapter not found. Using base model only: {actual_base_model}"
-
-        except Exception as e:
-            print(f"[Error] Failed to load adapter: {e}")
-            print("[Fallback] Using base model without fine-tuning")
-            model = base_model
-            status = f"⚠️ Could not load adapter. Using base model only: {actual_base_model}"
-
-        # Step 5: Final setup
-        model.eval()
-        print(f"[Load] Model ready on {DEVICE}!")
-
-        # Memory cleanup
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        return model, tokenizer, status
-
     except Exception as e:
-        error_msg = f"Failed to load model: {str(e)}"
-        print(f"[Fatal] {error_msg}")
-
-        # Try fallback to smallest model
-        if "gemma-2b" not in BASE_MODEL_ID.lower():
-            print("[Fallback] Trying with gemma-2b-it...")
-            try:
-                base_model = AutoModelForCausalLM.from_pretrained(
-                    "google/gemma-3-4b-it",
-                    token=HF_TOKEN,
-                    trust_remote_code=True,
-                    low_cpu_mem_usage=True,
-                    torch_dtype=torch.float32,
-                    device_map=None,
-                ).to("cpu")
-
-                tokenizer = AutoTokenizer.from_pretrained(
-                    "google/gemma-3-4b-it",
-                    token=HF_TOKEN,
-                    trust_remote_code=True,
-                )
-                if tokenizer.pad_token is None:
-                    tokenizer.pad_token = tokenizer.eos_token
-
-                base_model.eval()
-                return base_model, tokenizer, "⚠️ Using fallback model: gemma-2b-it (no fine-tuning)"
-
-            except Exception as fallback_error:
-                print(f"[Fatal] Fallback also failed: {fallback_error}")
-                raise gr.Error(f"Cannot load any model. Check your configuration.")
         else:
-            raise gr.Error(error_msg)

-# Load model globally
 try:
     model, tokenizer, STATUS_MSG = load_model_and_tokenizer()
     MODEL_LOADED = True
-    DESCRIPTION = DESCRIPTION.format(STATUS_MSG)
 except Exception as e:
     print(f"[Fatal] Could not load model: {e}")
     MODEL_LOADED = False
-    model, tokenizer = None, None
-    DESCRIPTION = DESCRIPTION.format(f"❌ Model failed to load: {str(e)[:100]}")

-# ── Generation function ─────────────────────────────────────────────────────────
 def generate_response(
     message: str,
     history: List[Tuple[str, str]],
@@ -226,58 +180,43 @@ def generate_response(
     top_p: float = 0.95,
     repetition_penalty: float = 1.1,
 ) -> str:
-    """Generate response using the model"""
-
     if not MODEL_LOADED:
-        return "⚠️ Model failed to load. Please check the logs or restart the application."
-
     try:
-        # Build conversation
         conversation = []
         if history:
-            # Keep last 3 exchanges for context
-            for user_msg, assistant_msg in history[-3:]:
-                if user_msg:
-                    conversation.append({"role": "user", "content": user_msg})
-                if assistant_msg:
-                    conversation.append({"role": "assistant", "content": assistant_msg})
         conversation.append({"role": "user", "content": message})
-
-        # Apply chat template
         try:
-            prompt = tokenizer.apply_chat_template(
                 conversation,
                 add_generation_prompt=True,
-                return_tensors="pt"
-            )
-        except Exception as e:
-            print(f"[Warning] Chat template failed: {e}, using fallback format")
-            # Fallback format
-            prompt_text = ""
-            for msg in conversation:
-                if msg["role"] == "user":
-                    prompt_text += f"User: {msg['content']}\n"
-                else:
-                    prompt_text += f"Assistant: {msg['content']}\n"
-            prompt_text += "Assistant: "
-
-            inputs = tokenizer(
-                prompt_text,
                 return_tensors="pt",
-                truncation=True,
-                max_length=512
             )
-            prompt = inputs.input_ids
-
-        # Move to device
-        prompt = prompt.to(model.device if hasattr(model, 'device') else DEVICE)
-
-        # Generate
-        print(f"[Generate] Input length: {prompt.shape[-1]} tokens")
         with torch.no_grad():
-            outputs = model.generate(
-                input_ids=prompt,
-                max_new_tokens=min(int(max_new_tokens), 256),
                 temperature=float(temperature),
                 top_p=float(top_p),
                 repetition_penalty=float(repetition_penalty),
@@ -286,106 +225,75 @@ def generate_response(
                 eos_token_id=tokenizer.eos_token_id,
                 use_cache=True,
             )
-
-        # Decode only generated tokens
-        generated_tokens = outputs[0][prompt.shape[-1]:]
-        response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
-
-        print(f"[Generate] Output length: {len(generated_tokens)} tokens")
-
         # Cleanup
-        del outputs, prompt, generated_tokens
         gc.collect()
-
-        return response
-
     except Exception as e:
-        error_msg = f"⚠️ Error generating response: {str(e)}"
-        print(f"[Error] {error_msg}")
-
-        # Try to recover memory
        gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-
-        return error_msg

-# ── Gradio Interface ────────────────────────────────────────────────────────────
 examples = [
-    ["What is the capital of Goa?"],
-    ["Tell me about Konkani language"],
-    ["What are famous beaches in Goa?"],
-    ["Describe Goan fish curry"],
-    ["What is the history of Old Goa?"],
 ]

-# Create interface
 if MODEL_LOADED:
     demo = gr.ChatInterface(
         fn=generate_response,
         title=TITLE,
         description=DESCRIPTION,
         examples=examples,
-        retry_btn=None,
-        undo_btn=None,
         additional_inputs=[
-            gr.Slider(
-                minimum=0.1,
-                maximum=1.0,
-                value=0.7,
-                step=0.05,
-                label="Temperature (lower = more focused)"
-            ),
-            gr.Slider(
-                minimum=32,
-                maximum=256,
-                value=128,
-                step=16,
-                label="Max new tokens"
-            ),
-            gr.Slider(
-                minimum=0.1,
-                maximum=1.0,
-                value=0.95,
-                step=0.05,
-                label="Top-p (nucleus sampling)"
-            ),
-            gr.Slider(
-                minimum=1.0,
-                maximum=2.0,
-                value=1.1,
-                step=0.05,
-                label="Repetition penalty"
-            ),
         ],
-        theme=gr.themes.Soft(),
     )
 else:
     demo = gr.Interface(
-        fn=lambda x: "Model failed to load. Check console for errors.",
         inputs=gr.Textbox(label="Message"),
         outputs=gr.Textbox(label="Response"),
         title=TITLE,
         description=DESCRIPTION,
     )

-# Queue with version compatibility
 try:
-    # Try newer Gradio syntax first (4.x)
-    demo.queue(default_concurrency_limit=1, max_size=10)
-except TypeError:
-    try:
-        # Fall back to older syntax (3.x)
-        demo.queue(concurrency_count=1, max_size=10)
-    except:
-        # If both fail, try without parameters
-        demo.queue()

 if __name__ == "__main__":
-    print("\n" + "="*50)
-    print(f"🚀 Starting Gradio app on {DEVICE}...")
-    print(f"📍 Base model: {BASE_MODEL_ID}")
     print(f"🔧 LoRA adapter: {ADAPTER_ID}")
-    print("="*50 + "\n")
-
     demo.launch()
 
+# app.py — Hugging Face Space ready (LoRA adapter, Gradio compat)
+# ---------------------------------------------------------------
+# What changed vs your script
+# - Removed ChatInterface args that broke on old Gradio (retry_btn, undo_btn)
+# - No interactive input() for merging (Spaces are non-interactive). Use MERGE_LORA env var.
+# - Secrets: read HF token from env (Settings → Secrets → HF_TOKEN), never hardcode.
+# - Token passing works across transformers/peft versions (token/use_auth_token fallback).
+# - Optional 8-bit via USE_8BIT=1 (GPU only). Safe CPU defaults.
+# - Robust theme/queue/launch for mixed Gradio versions.

 import os
 import gc
+import warnings
+from typing import List, Tuple
+
 import torch
 import gradio as gr
+
+warnings.filterwarnings("ignore")
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

 try:
     from peft import PeftConfig, PeftModel
+    from transformers import (
+        AutoTokenizer,
+        AutoModelForCausalLM,
+        BitsAndBytesConfig,
+    )
     IMPORTS_OK = True
+except Exception as e:
     IMPORTS_OK = False
     print(f"Missing dependencies: {e}")
+    print("Install: pip install --upgrade 'transformers>=4.41' peft accelerate gradio torch bitsandbytes")

 # ── Configuration ──────────────────────────────────────────────────────────────
+HF_TOKEN = os.getenv("HF_TOKEN") # set in Space Settings → Secrets → HF_TOKEN
+
+# LoRA adapter repo (must be compatible with BASE_MODEL_ID)
+ADAPTER_ID = os.getenv("ADAPTER_ID", "Reubencf/gemma3-goan-finetuned")
+
+# Base model used during fine-tuning (should match adapter's base)
+BASE_MODEL_ID_DEFAULT = os.getenv("BASE_MODEL_ID", "google/gemma-3-4b-it")
+
+# Quantization toggle (GPU only): set USE_8BIT=1 in Space variables
+USE_8BIT = os.getenv("USE_8BIT", "0").lower() in {"1", "true", "yes", "y"}
+
+# Merge LoRA into the base for faster inference: MERGE_LORA=1/0
+MERGE_LORA = os.getenv("MERGE_LORA", "1").lower() in {"1", "true", "yes", "y"}
+
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

 TITLE = "🌴 Gemma Goan Q&A Bot"
+DESCRIPTION_TMPL = (
+    "Gemma base model + LoRA adapter fine-tuned on a Goan Q&A dataset.\n"
+    "Ask about Goa, Konkani culture, or general topics!\n\n"
+    "**Status**: {}"
+)

+# ── Helpers ───────────────────────────────────────────────────────────────────
+
+def call_with_token(fn, *args, **kwargs):
+    """Call HF/Transformers/PEFT functions with token OR use_auth_token for
+    broad version compatibility."""
+    if HF_TOKEN:
+        try:
+            return fn(*args, token=HF_TOKEN, **kwargs)
+        except TypeError:
+            return fn(*args, use_auth_token=HF_TOKEN, **kwargs)
+    return fn(*args, **kwargs)
+
+# ── Load model + tokenizer ─────────────────────────────────────────────────────

 def load_model_and_tokenizer():
     if not IMPORTS_OK:
+        raise ImportError("Required packages not installed.")
+
+    print("[Init] Starting model load…")
     print(f"[Config] Device: {DEVICE}")
+
+    # GC + VRAM cleanup
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+
+    # Step 1: Confirm base model from the adapter's config if possible
+    actual_base_model = BASE_MODEL_ID_DEFAULT
     try:
+        print(f"[Load] Reading adapter config: {ADAPTER_ID}")
+        peft_cfg = call_with_token(PeftConfig.from_pretrained, ADAPTER_ID)
+        if getattr(peft_cfg, "base_model_name_or_path", None):
+            actual_base_model = peft_cfg.base_model_name_or_path
             print(f"[Load] Adapter expects base model: {actual_base_model}")
+        else:
+            print("[Warn] Adapter did not expose base_model_name_or_path; using configured base.")
     except Exception as e:
+        print(f"[Warn] Could not read adapter config ({e}); using configured base: {actual_base_model}")
+
+    # Step 2: Load base model (optionally quantized on GPU)
+    print(f"[Load] Loading base model: {actual_base_model}")
+    quant_cfg = None
+    if USE_8BIT and torch.cuda.is_available():
+        print("[Load] Enabling 8-bit quantization (bitsandbytes)")
+        quant_cfg = BitsAndBytesConfig(load_in_8bit=True, bnb_8bit_compute_dtype=torch.float16)
+
+    base_model = call_with_token(
+        AutoModelForCausalLM.from_pretrained,
+        actual_base_model,
+        trust_remote_code=True,
+        quantization_config=quant_cfg,
+        low_cpu_mem_usage=True,
+        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+        device_map="auto" if torch.cuda.is_available() else None,
+    )
+
+    if DEVICE == "cpu" and not torch.cuda.is_available():
+        base_model = base_model.to("cpu")
+        print("[Load] Model on CPU")
+
+    print("[Load] Base model loaded ✔")
+
+    # Step 3: Tokenizer
+    print("[Load] Loading tokenizer…")
+    tokenizer = call_with_token(
+        AutoTokenizer.from_pretrained,
+        actual_base_model,
+        use_fast=True,
+        trust_remote_code=True,
+    )
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.padding_side = "left"
+
+    # Step 4: Apply LoRA adapter
+    status = ""
+    model = base_model
+    try:
+        print(f"[Load] Applying LoRA adapter: {ADAPTER_ID}")
+        model = call_with_token(PeftModel.from_pretrained, base_model, ADAPTER_ID)
+
+        if MERGE_LORA:
+            print("[Load] Merging adapter into base (merge_and_unload)…")
+            model = model.merge_and_unload()
+            status = f"✅ Using fine-tuned model (merged): {ADAPTER_ID}"
         else:
+            status = f"✅ Using fine-tuned model via adapter: {ADAPTER_ID}"
+    except FileNotFoundError as e:
+        print(f"[Error] Adapter files not found: {e}")
+        status = f"⚠️ Adapter not found. Using base only: {actual_base_model}"
+    except Exception as e:
+        print(f"[Error] Failed to load adapter: {e}")
+        status = f"⚠️ Could not load adapter. Using base only: {actual_base_model}"
+
+    model.eval()
+    print(f"[Load] Model ready on {DEVICE} ✔")
+
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()

+    return model, tokenizer, status
+
+# Global load at import time (Space-friendly)
 try:
     model, tokenizer, STATUS_MSG = load_model_and_tokenizer()
     MODEL_LOADED = True
+    DESCRIPTION = DESCRIPTION_TMPL.format(STATUS_MSG)
 except Exception as e:
     print(f"[Fatal] Could not load model: {e}")
     MODEL_LOADED = False
+    model = tokenizer = None
+    DESCRIPTION = DESCRIPTION_TMPL.format(f"❌ Model failed to load: {str(e)[:140]}")
+
+# ── Generation ────────────────────────────────────────────────────────────────

 def generate_response(
     message: str,
     history: List[Tuple[str, str]],
 
     top_p: float = 0.95,
     repetition_penalty: float = 1.1,
 ) -> str:
     if not MODEL_LOADED:
+        return "⚠️ Model failed to load. Check Space logs."
+
     try:
+        # Build short chat history
         conversation = []
         if history:
+            for u, a in history[-3:]:
+                if u:
+                    conversation.append({"role": "user", "content": u})
+                if a:
+                    conversation.append({"role": "assistant", "content": a})
         conversation.append({"role": "user", "content": message})
+
+        # Try the tokenizer's chat template first
         try:
+            input_ids = tokenizer.apply_chat_template(
                 conversation,
                 add_generation_prompt=True,
                 return_tensors="pt",
             )
+        except Exception as e:
+            print(f"[Warn] chat_template failed: {e}; using manual format")
+            prompt_text = "".join(
+                [
+                    ("User: " + m["content"] + "\n") if m["role"] == "user" else ("Assistant: " + m["content"] + "\n")
+                    for m in conversation
+                ]
+            ) + "Assistant: "
+            input_ids = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=1024).input_ids
+
+        input_ids = input_ids.to(model.device if hasattr(model, "device") else DEVICE)
+
         with torch.no_grad():
+            out = model.generate(
+                input_ids=input_ids,
+                max_new_tokens=max(1, min(int(max_new_tokens), 512)),
                 temperature=float(temperature),
                 top_p=float(top_p),
                 repetition_penalty=float(repetition_penalty),
 
                 eos_token_id=tokenizer.eos_token_id,
                 use_cache=True,
             )
+
+        gen = out[0][input_ids.shape[-1]:]
+        text = tokenizer.decode(gen, skip_special_tokens=True).strip()
+
         # Cleanup
+        del out, input_ids, gen
         gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return text or "(no output)"
+
     except Exception as e:
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        return f"⚠️ Error generating response: {e}"

+# ── UI ────────────────────────────────────────────────────────────────────────
 examples = [
+    "What is the capital of Goa?",
+    "Tell me about the Konkani language.",
+    "What are famous beaches in Goa?",
+    "Describe Goan fish curry.",
+    "What is the history of Old Goa?",
 ]

+# Best-effort theme across versions
+try:
+    THEME = gr.themes.Soft()
+except Exception:
+    THEME = None
+
 if MODEL_LOADED:
     demo = gr.ChatInterface(
         fn=generate_response,
         title=TITLE,
         description=DESCRIPTION,
         examples=examples,
         additional_inputs=[
+            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Temperature"),
+            gr.Slider(minimum=32, maximum=512, value=256, step=16, label="Max new tokens"),
+            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
+            gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition penalty"),
         ],
+        theme=THEME,
     )
 else:
     demo = gr.Interface(
+        fn=lambda x: "Model failed to load. Check Space logs.",
         inputs=gr.Textbox(label="Message"),
         outputs=gr.Textbox(label="Response"),
         title=TITLE,
         description=DESCRIPTION,
+        theme=THEME,
     )

+# Queue — keep params minimal for cross-version compat
 try:
+    demo.queue()
+except Exception:
+    pass

 if __name__ == "__main__":
+    print("\n" + "=" * 60)
+    print(f"🚀 Starting Gradio app on {DEVICE} …")
+    print(f"📍 Base model: {BASE_MODEL_ID_DEFAULT}")
     print(f"🔧 LoRA adapter: {ADAPTER_ID}")
+    print(f"🧩 Merge LoRA: {MERGE_LORA}")
+    print("=" * 60 + "\n")
+    # On Spaces, just calling launch() is fine.
     demo.launch()
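
The new app.py is configured entirely through environment variables, so it can also be exercised outside the Space. The sketch below is an editor's illustration, not part of the commit: it assumes the new app.py is saved in the working directory, the packages from the install hint are available, and there is enough disk and memory to pull the base model; the prompt and generation values are arbitrary.

import os

# Illustrative smoke test only. Set configuration before importing app.py,
# because the module loads the model at import time.
os.environ["ADAPTER_ID"] = "Reubencf/gemma3-goan-finetuned"
os.environ["BASE_MODEL_ID"] = "google/gemma-3-4b-it"
os.environ["MERGE_LORA"] = "1"   # merge adapter weights into the base
os.environ["USE_8BIT"] = "0"     # "1" only on CUDA with bitsandbytes installed
# os.environ["HF_TOKEN"] = "hf_..."  # only needed for gated or private repos

import app  # runs load_model_and_tokenizer() and builds the Gradio demo

print(getattr(app, "STATUS_MSG", "model did not load"))
print(app.generate_response(
    "What is the capital of Goa?",
    history=[],              # list of (user, assistant) tuples
    temperature=0.7,
    max_new_tokens=128,
))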