serhany committed on
Commit f1ea8a0 · verified · 1 Parent(s): 18449fc

Update app.py

Files changed (1)
  1. app.py +138 -174
app.py CHANGED
@@ -2,94 +2,101 @@ import gradio as gr
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
  import time
- import os # Keep os, it might be useful

  # --- Configuration ---
  BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
- FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft" # Assuming this is correct

  SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
  1. Provide personalized movie recommendations based on user preferences
  2. Give brief, compelling rationales for why you recommend each movie
  3. Ask thoughtful follow-up questions to better understand user tastes
  4. Maintain an enthusiastic but not overwhelming tone about cinema
-
  When recommending movies, always explain WHY the movie fits their preferences."""
  SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."

- # --- Global Model Cache (models will be loaded on first use) ---
- _models_cache = {
-     "base": None,
-     "finetuned": None,
-     "tokenizer_base": None,
-     "tokenizer_ft": None,
  }

- # --- Model Loading Function (to be called inside decorated functions) ---
- def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
-     """Loads a model and tokenizer if not already in cache."""
-     if _models_cache[model_key] is not None and _models_cache[tokenizer_key] is not None:
-         print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
-         return _models_cache[model_key], _models_cache[tokenizer_key]
-
-     print(f"Loading {model_key} model ({model_identifier})...")
-     try:
-         tokenizer = AutoTokenizer.from_pretrained(model_identifier, trust_remote_code=True)
-         model = AutoModelForCausalLM.from_pretrained(
-             model_identifier,
-             torch_dtype=torch.bfloat16, # Or torch.float16 if better for available GPU
-             device_map="auto", # This will utilize the GPU allocated by @spaces.GPU
-             trust_remote_code=True,
-         )
-         model.eval()
-
-         if tokenizer.pad_token is None:
-             tokenizer.pad_token = tokenizer.eos_token
-         if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
-             tokenizer.pad_token_id = tokenizer.eos_token_id
-
-         _models_cache[model_key] = model
-         _models_cache[tokenizer_key] = tokenizer
-         print(f"Finished loading and cached {model_key} and {tokenizer_key}.")
-         return model, tokenizer
-     except Exception as e:
-         print(f"ERROR loading {model_key} model ({model_identifier}): {e}")
-         _models_cache[model_key] = "error" # Mark as error to avoid retrying
-         _models_cache[tokenizer_key] = "error"
-         raise # Re-raise the exception to see it in Gradio UI or logs
-
- # --- Inference Function (modified to ensure models are loaded) ---
- def generate_chat_response(message: str, chat_history: list, model_type_to_load: str):
-     model, tokenizer = None, None
-     system_prompt = ""

-     if model_type_to_load == "base":
-         if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
-             yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
-             return
-         model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
-         system_prompt = SYSTEM_PROMPT_BASE
-     elif model_type_to_load == "finetuned":
-         # Critical check for the FINETUNED_MODEL_ID itself
-         if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
-             print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID} (Type: {type(FINETUNED_MODEL_ID)})")
-             yield "Error: Fine-tuned model ID is not configured correctly."
-             return
-         if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
-             yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
              return
-         model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
-         system_prompt = SYSTEM_PROMPT_CINEGUIDE
-     else:
-         yield "Invalid model type."
-         return

-     if model is None or tokenizer is None: # Should be caught by "error" check or exception above
-         yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
          return

      conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
-     conversation.extend(chat_history) # Assuming chat_history is already type="messages"
      conversation.append({"role": "user", "content": message})

      prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
@@ -97,94 +104,71 @@ def generate_chat_response(message: str, chat_history: list, model_type_to_load:

      eos_tokens_ids = [tokenizer.eos_token_id]
      im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
-     if im_end_id != getattr(tokenizer, 'unk_token_id', None): # Check if <|im_end|> is in vocab
          eos_tokens_ids.append(im_end_id)
-     # Remove duplicates just in case eos_token_id is the same as im_end_id
-     eos_tokens_ids = list(set(eos_tokens_ids))
-

-     generated_token_ids = model.generate(
-         **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9,
-         repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id, eos_token_id=eos_tokens_ids
-     )
-     new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
-     response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()
-
-     full_response = ""
-     for char in response_text:
-         full_response += char
-         time.sleep(0.005) # Adjust for desired speed
-         yield full_response
-
-
- # --- Gradio UI Event Handlers (THESE GET DECORATED) ---
- # Note: The @spaces.GPU decorator needs to be imported if not automatically available
- # from huggingface_hub import-like syntax or it might be injected.
- # For now, let's assume it's magically available in the Space environment.
- # If not, you might need to find how to import it for ZeroGPU shared pool.
- # It's often available as: `from Fg.spaces import GPU` and used as `@GPU`
- # or simply `@spaces.GPU` if `spaces` is an auto-imported object.
-
- # Try without explicit import first, as HF might inject it.
- # If "spaces is not defined" error, you'll need to find the correct import for it.
-
- # @spaces.GPU # Placeholder for actual decorator
- @gr.्रु # This is a Gradio decorator for functions, not the HF GPU one.
- # We need to find the correct HF spaces GPU decorator.
- # For now, I'll structure as if it exists.
- # The actual execution of model loading and generation will happen here.
-
- # It's common to decorate the function called by the Gradio event.
- # Let's try decorating the prediction functions.
- # If `@spaces.GPU` is not found, the app will error earlier. You might need to find its import from HF docs for ZeroGPU.
- # `from hf_spaces_shared_gpu import gpu_heavy_task` is a made-up example.
- # Let's assume for now that if the hardware is "ZeroGPU" and this decorator is required,
- # the Hugging Face platform makes `spaces.GPU` available.
-
- def base_model_predict_decorated(user_message, chat_history):
-     # This function will now be responsible for triggering the load and then generating.
      try:
-         # Model loading now happens here, within the GPU-allocated function
-         # The generate_chat_response will call load_model_and_tokenizer internally if needed
-         bot_response_stream = generate_chat_response(user_message, chat_history, "base")
-         full_bot_message = ""
-         for chunk in bot_response_stream:
-             full_bot_message = chunk
-             yield full_bot_message
-     except Exception as e:
-         print(f"Error in base_model_predict_decorated: {e}")
-         yield f"Error generating base model response: {e}"

- def ft_model_predict_decorated(user_message, chat_history):
-     try:
-         # Model loading now happens here
-         bot_response_stream = generate_chat_response(user_message, chat_history, "finetuned")
-         full_bot_message = ""
-         for chunk in bot_response_stream:
-             full_bot_message = chunk
-             yield full_bot_message
      except Exception as e:
-         print(f"Error in ft_model_predict_decorated: {e}")
-         yield f"Error generating fine-tuned response: {e}"

  # --- Gradio UI Definition ---
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          f"""
          # 🎬 CineGuide vs. Base {BASE_MODEL_ID}
-         Compare the fine-tuned CineGuide movie recommender (loaded from `{FINETUNED_MODEL_ID}`)
-         with the base {BASE_MODEL_ID} model.
-         Type your movie-related query below and see how each model responds!
-         **Note:** Models are loaded on first use and may take some time. Using shared GPU pool.
          """
      )
      with gr.Row():
          with gr.Column(scale=1):
              gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
-             chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, type="messages") # Use type="messages"
          with gr.Column(scale=1):
-             gr.Markdown(f"## 🤖 Fine-tuned CineGuide (from {FINETUNED_MODEL_ID})")
-             chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, type="messages") # Use type="messages"

      with gr.Row():
          shared_input_textbox = gr.Textbox(
@@ -197,63 +181,43 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
              "Hi! I'm looking for something funny to watch tonight.",
              "I love dry, witty humor more than slapstick.",
              "I'm really into complex sci-fi movies that make you think.",
          ],
          inputs=[shared_input_textbox], label="Example Prompts"
      )

-     # Apply the @spaces.GPU decorator if you find the correct way to import/use it.
-     # For now, the functions themselves will handle loading.
-     # If the decorator is `@spaces.GPU()`, it would be:
-     # submit_button.click(spaces.GPU()(base_model_predict_decorated), ...)
-     # This part is tricky without knowing the exact decorator syntax for ZeroGPU.
-     # Let's assume the functions are called and *they* handle the GPU context internally.
-     # If the platform *requires* the event handler itself to be decorated, that's a different structure.
-
-     # The functions `base_model_predict_decorated` and `ft_model_predict_decorated`
-     # are what Gradio will call. If these need the `@spaces.GPU` decorator, you'd apply it like:
-     # @spaces.GPU
-     # def decorated_base_predict(user_message, chat_history):
-     #     yield from base_model_predict_decorated(user_message, chat_history)
-     # And then pass `decorated_base_predict` to `submit_button.click`
-
-     # Simpler approach for now: let Gradio call these directly.
-     # If a wrapper is needed for the decorator, we can add it.
      submit_button.click(
-         base_model_predict_decorated,
          [shared_input_textbox, chatbot_base],
          [chatbot_base],
-         # api_name="base_predict" # Optional
      )
      submit_button.click(
-         ft_model_predict_decorated,
          [shared_input_textbox, chatbot_ft],
          [chatbot_ft],
-         # api_name="ft_predict" # Optional
      )
-     # Handle textbox submit event for both
      shared_input_textbox.submit(
-         base_model_predict_decorated,
          [shared_input_textbox, chatbot_base],
          [chatbot_base]
      )
      shared_input_textbox.submit(
-         ft_model_predict_decorated,
          [shared_input_textbox, chatbot_ft],
          [chatbot_ft]
      )

      def clear_textbox_fn(): return ""
-     submit_button.click(clear_textbox_fn, [], [shared_input_textbox])
-     shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox])


  if __name__ == "__main__":
-     # The following line is usually specific to certain Space configurations.
-     # For ZeroGPU with @spaces.GPU, this might be needed in the README.md/config.yaml
-     # rather than here, or the decorator itself implies it.
-     # demo.config(dependencies=["torch", "transformers", "accelerate", ...])
-
-     # Check Gradio docs for how to make a function eligible for @spaces.GPU if it's not a direct event handler.
-     # Often, the main event handler itself is decorated.
-     demo.queue()
      demo.launch(debug=True)
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
  import time
+ import os
+
+ # Attempt to import the spaces GPU decorator.
+ # This is a common pattern, but the exact import might vary or be injected.
+ try:
+     import spaces # This might make spaces.GPU available
+ except ImportError:
+     spaces = None # Define it as None if import fails, so we can check later
+     print("WARNING: 'spaces' module not found. @spaces.GPU decorator might not be available or work as expected.")
+
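For reference, on ZeroGPU hardware the `spaces` package is normally preinstalled and `spaces.GPU` is placed on whichever function actually touches CUDA. A minimal sketch of that pattern (model ID reused from this file; the `duration` argument is optional and illustrative):

```python
# Minimal ZeroGPU sketch, assuming the `spaces` package that Hugging Face
# preinstalls on ZeroGPU Spaces; `duration` is optional and illustrative.
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
)

@spaces.GPU(duration=120)  # GPU is attached only while this function runs
def generate(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=64)
    return tokenizer.decode(output[0], skip_special_tokens=True)
```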

  # --- Configuration ---
  BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
+ FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft" # Confirmed by you as correct

  SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
  1. Provide personalized movie recommendations based on user preferences
  2. Give brief, compelling rationales for why you recommend each movie
  3. Ask thoughtful follow-up questions to better understand user tastes
  4. Maintain an enthusiastic but not overwhelming tone about cinema
  When recommending movies, always explain WHY the movie fits their preferences."""
  SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."

+ # --- Global Model Storage (placeholders) ---
+ # We will store model objects here after they are loaded within a GPU context.
+ # This addresses John6666's point about global variables not updating correctly
+ # if modified outside the main Gradio event flow or GPU context.
+ # We'll treat these more like a cache that's populated by GPU-context functions.
+ MODELS_LOADED = {
+     "base_model": None,
+     "base_tokenizer": None,
+     "ft_model": None,
+     "ft_tokenizer": None,
+     "base_load_error": None,
+     "ft_load_error": None,
  }

+ # --- Core Model Loading and Inference Logic (to be wrapped by @spaces.GPU) ---
+ def _load_and_infer(message: str, chat_history: list, model_id_to_load: str, system_prompt: str, model_kind: str):
+     """
+     This function handles loading (if necessary) and inference.
+     It's designed to be called by a function decorated with @spaces.GPU.
+     """
+     model_key = f"{model_kind}_model"
+     tokenizer_key = f"{model_kind}_tokenizer"
+     error_key = f"{model_kind}_load_error"
+
+     # Check if model failed to load previously
+     if MODELS_LOADED[error_key]:
+         yield f"Previous attempt to load {model_kind} model ({model_id_to_load}) failed: {MODELS_LOADED[error_key]}"
+         return

+     # Load model and tokenizer if not already loaded
+     if MODELS_LOADED[model_key] is None or MODELS_LOADED[tokenizer_key] is None:
+         print(f"Attempting to load {model_kind} model: {model_id_to_load} (Type: {type(model_id_to_load)})")
+         if not model_id_to_load or not isinstance(model_id_to_load, str):
+             MODELS_LOADED[error_key] = f"Invalid model ID: {model_id_to_load}"
+             yield f"Error: {model_kind} model ID is not configured correctly ({model_id_to_load})."
              return
+         try:
+             tokenizer = AutoTokenizer.from_pretrained(model_id_to_load, trust_remote_code=True)
+             # On ZeroGPU, device_map="auto" should leverage the @spaces.GPU context
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_id_to_load,
+                 torch_dtype=torch.bfloat16, # Qwen models often prefer bfloat16
+                 device_map="auto",
+                 trust_remote_code=True,
+             )
+             model.eval()
+
+             if tokenizer.pad_token is None:
+                 tokenizer.pad_token = tokenizer.eos_token
+             if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
+                 tokenizer.pad_token_id = tokenizer.eos_token_id
+
+             MODELS_LOADED[model_key] = model
+             MODELS_LOADED[tokenizer_key] = tokenizer
+             print(f"Successfully loaded and cached {model_kind} model and tokenizer.")
+         except Exception as e:
+             MODELS_LOADED[error_key] = str(e)
+             print(f"ERROR loading {model_kind} model ({model_id_to_load}): {e}")
+             yield f"Error loading {model_kind} model: {e}" # Yield error to Gradio
+             return # Stop further execution for this call
+
+     # Retrieve from cache
+     model = MODELS_LOADED[model_key]
+     tokenizer = MODELS_LOADED[tokenizer_key]

+     if model is None or tokenizer is None: # Should not happen if loading was successful
+         yield f"Model or tokenizer for {model_kind} is unexpectedly None after loading attempt."
          return

+     # Prepare conversation
      conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
+     conversation.extend(chat_history)
      conversation.append({"role": "user", "content": message})

      prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

      eos_tokens_ids = [tokenizer.eos_token_id]
      im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+     if im_end_id != getattr(tokenizer, 'unk_token_id', None) and im_end_id not in eos_tokens_ids:
          eos_tokens_ids.append(im_end_id)
+     eos_tokens_ids = list(set(eos_tokens_ids)) # Remove duplicates

      try:
+         generated_token_ids = model.generate(
+             **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9,
+             repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id, eos_token_id=eos_tokens_ids
+         )
+         new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
+         response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()
+
+         full_response = ""
+         for char_idx, char_val in enumerate(response_text):
+             full_response += char_val
+             # Yield more slowly or in chunks if char-by-char is too slow/frequent for Gradio
+             if char_idx % 5 == 0 or char_idx == len(response_text) - 1: # Yield every 5 chars or at the end
+                 time.sleep(0.001) # Minimal sleep
+                 yield full_response
+         if not response_text: # Handle empty generation
+             yield ""
      except Exception as e:
+         print(f"Error during {model_kind} model generation: {e}")
+         yield f"Error during generation: {e}"
+
+
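The `TextStreamer` import at the top of the file is never used; the generation block above decodes the full reply and then replays it character by character. If real token-level streaming is wanted later, a sketch built on `transformers.TextIteratorStreamer` (same generation arguments, with `model.generate` run in a background thread) could replace that loop:

```python
# Sketch only: token-level streaming with TextIteratorStreamer instead of
# replaying the finished string. Assumes model, tokenizer, inputs and
# eos_tokens_ids as defined in _load_and_infer above.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, inputs, eos_tokens_ids):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9,
        repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id,
        eos_token_id=eos_tokens_ids, streamer=streamer,
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for new_text in streamer:  # yields decoded text as tokens arrive
        partial += new_text
        yield partial
```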
+ # --- Gradio Event Handler Wrappers (these get decorated) ---
+ def create_gpu_handler(model_id, system_prompt, model_kind_str):
+     # This function will be decorated by @spaces.GPU
+     # It calls the actual logic.
+     def gpu_fn(message, chat_history):
+         yield from _load_and_infer(message, chat_history, model_id, system_prompt, model_kind_str)
+     return gpu_fn
+
+ # Apply the decorator IF `spaces` module was imported and has `GPU`
+ if spaces and hasattr(spaces, "GPU"):
+     print("Applying @spaces.GPU decorator.")
+     base_model_predict = spaces.GPU(create_gpu_handler(BASE_MODEL_ID, SYSTEM_PROMPT_BASE, "base"))
+     ft_model_predict = spaces.GPU(create_gpu_handler(FINETUNED_MODEL_ID, SYSTEM_PROMPT_CINEGUIDE, "ft"))
+ else:
+     print("WARNING: @spaces.GPU decorator not applied. GPU acceleration on ZeroGPU might not work as expected.")
+     # Fallback to non-decorated calls; this will likely lead to "No @spaces.GPU function detected"
+     # or CUDA errors if running on ZeroGPU that expects the decorator.
+     base_model_predict = create_gpu_handler(BASE_MODEL_ID, SYSTEM_PROMPT_BASE, "base")
+     ft_model_predict = create_gpu_handler(FINETUNED_MODEL_ID, SYSTEM_PROMPT_CINEGUIDE, "ft")
+
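The `spaces.GPU(create_gpu_handler(...))` calls above are simply the decorator applied by hand. If the import-fallback branch is ever dropped, the same handlers can be written with plain decorator syntax (names reused from this file, assuming `import spaces` succeeded):

```python
# Equivalent wiring with decorator syntax, assuming `import spaces` succeeded.
@spaces.GPU
def base_model_predict(message, chat_history):
    yield from _load_and_infer(message, chat_history, BASE_MODEL_ID, SYSTEM_PROMPT_BASE, "base")

@spaces.GPU
def ft_model_predict(message, chat_history):
    yield from _load_and_infer(message, chat_history, FINETUNED_MODEL_ID, SYSTEM_PROMPT_CINEGUIDE, "ft")
```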

  # --- Gradio UI Definition ---
+ with gr.Blocks(theme=gr.themes.Default()) as demo: # Changed to Default theme, Soft can sometimes have issues
      gr.Markdown(
          f"""
          # 🎬 CineGuide vs. Base {BASE_MODEL_ID}
+         Compare the fine-tuned CineGuide (`{FINETUNED_MODEL_ID}`) with the base {BASE_MODEL_ID}.
+         **Note:** Models are loaded on first use within a GPU context and may take time.
+         This Space attempts to use the ZeroGPU shared pool via `@spaces.GPU`.
          """
      )
      with gr.Row():
          with gr.Column(scale=1):
              gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
+             chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, type="messages")
          with gr.Column(scale=1):
+             gr.Markdown(f"## 🤖 Fine-tuned CineGuide")
+             chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, type="messages")

      with gr.Row():
          shared_input_textbox = gr.Textbox(

              "Hi! I'm looking for something funny to watch tonight.",
              "I love dry, witty humor more than slapstick.",
              "I'm really into complex sci-fi movies that make you think.",
+             "Tell me about some good action movies from the 90s.",
+             "Recommend a thought-provoking sci-fi film about AI.",
          ],
          inputs=[shared_input_textbox], label="Example Prompts"
      )

+     # Event handling
+     # The `base_model_predict` and `ft_model_predict` are now the (potentially) decorated functions.
      submit_button.click(
+         base_model_predict,
          [shared_input_textbox, chatbot_base],
          [chatbot_base],
+         api_name="base_predict" # Good for testing API route
      )
      submit_button.click(
+         ft_model_predict,
          [shared_input_textbox, chatbot_ft],
          [chatbot_ft],
+         api_name="ft_predict"
      )
      shared_input_textbox.submit(
+         base_model_predict,
          [shared_input_textbox, chatbot_base],
          [chatbot_base]
      )
      shared_input_textbox.submit(
+         ft_model_predict,
          [shared_input_textbox, chatbot_ft],
          [chatbot_ft]
      )
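One caveat on the wiring above: a `gr.Chatbot(type="messages")` output expects the handler to yield the full message list (dicts with `role`/`content`), while these handlers yield bare strings. A small adapter along these lines (hypothetical helper, not part of this commit) is one way to bridge the two:

```python
# Hypothetical adapter: wraps a string-yielding handler so it yields the
# updated message list that a type="messages" Chatbot expects.
def as_chat_history(predict_fn):
    def handler(message, history):
        history = (history or []) + [{"role": "user", "content": message}]
        for partial in predict_fn(message, history[:-1]):
            yield history + [{"role": "assistant", "content": partial}]
    return handler

# e.g. submit_button.click(as_chat_history(base_model_predict), ...)
```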

      def clear_textbox_fn(): return ""
+     submit_button.click(clear_textbox_fn, [], [shared_input_textbox], queue=False) # queue=False for instant clear
+     shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox], queue=False)


  if __name__ == "__main__":
+     demo.queue() # Enable queuing for multiple users
+     # debug=True can sometimes interfere with production Spaces, but fine for testing
      demo.launch(debug=True)