serhany committed on
Commit 235bd9f · verified · 1 Parent(s): 6935641

Update app.py

Files changed (1)
  1. app.py +67 -79
app.py CHANGED
@@ -57,7 +57,20 @@ def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_ke
         _models_cache[tokenizer_key] = "error"
         raise
 
-def generate_chat_response(message: str, chat_history: list, model_type_to_load: str):
+def convert_gradio_history_to_messages(history):
+    """Convert Gradio ChatInterface history format to messages format."""
+    messages = []
+    for exchange in history:
+        if isinstance(exchange, (list, tuple)) and len(exchange) == 2:
+            user_msg, assistant_msg = exchange
+            if user_msg:  # Only add if not empty
+                messages.append({"role": "user", "content": str(user_msg)})
+            if assistant_msg:  # Only add if not empty
+                messages.append({"role": "assistant", "content": str(assistant_msg)})
+    return messages
+
+@spaces.GPU
+def generate_chat_response(message: str, history: list, model_type_to_load: str):
     """Generate response using specified model type."""
     model, tokenizer = None, None
     system_prompt = ""
@@ -91,94 +104,69 @@ def generate_chat_response(message: str, chat_history: list, model_type_to_load:
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
 
-    # Add chat history
-    conversation.extend(chat_history)
+    # Convert and add chat history
+    formatted_history = convert_gradio_history_to_messages(history)
+    conversation.extend(formatted_history)
     conversation.append({"role": "user", "content": message})
 
-    # Generate response
-    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)
-
-    # Prepare EOS tokens
-    eos_tokens_ids = [tokenizer.eos_token_id]
-    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
-    if im_end_id != getattr(tokenizer, 'unk_token_id', None):
-        eos_tokens_ids.append(im_end_id)
-    eos_tokens_ids = list(set(eos_tokens_ids))
-
-    # Generate
-    with torch.no_grad():
-        generated_token_ids = model.generate(
-            **inputs,
-            max_new_tokens=512,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9,
-            repetition_penalty=1.1,
-            pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=eos_tokens_ids
-        )
-
-    new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
-    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()
-
-    # Stream the response
-    full_response = ""
-    for char in response_text:
-        full_response += char
-        time.sleep(0.005)
-        yield full_response
-
-@spaces.GPU
-def base_model_predict(user_message, chat_history):
-    """Predict using base model - decorated with @spaces.GPU."""
-    try:
-        bot_response_stream = generate_chat_response(user_message, chat_history, "base")
-        for chunk in bot_response_stream:
-            yield chunk
-    except Exception as e:
-        print(f"Error in base_model_predict: {e}")
-        yield f"Error generating base model response: {str(e)}"
-
-@spaces.GPU
-def ft_model_predict(user_message, chat_history):
-    """Predict using fine-tuned model - decorated with @spaces.GPU."""
     try:
-        bot_response_stream = generate_chat_response(user_message, chat_history, "finetuned")
-        for chunk in bot_response_stream:
-            yield chunk
+        # Generate response
+        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)
+
+        # Prepare EOS tokens
+        eos_tokens_ids = [tokenizer.eos_token_id]
+        im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+        if im_end_id != getattr(tokenizer, 'unk_token_id', None):
+            eos_tokens_ids.append(im_end_id)
+        eos_tokens_ids = list(set(eos_tokens_ids))
+
+        # Generate
+        with torch.no_grad():
+            generated_token_ids = model.generate(
+                **inputs,
+                max_new_tokens=512,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.9,
+                repetition_penalty=1.1,
+                pad_token_id=tokenizer.pad_token_id,
+                eos_token_id=eos_tokens_ids
+            )
+
+        new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
+        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()
+
+        # Stream the response
+        full_response = ""
+        for char in response_text:
+            full_response += char
+            time.sleep(0.005)
+            yield full_response
+
     except Exception as e:
-        print(f"Error in ft_model_predict: {e}")
-        yield f"Error generating fine-tuned response: {str(e)}"
-
-def format_chat_history(history, message):
-    """Format the chat history for the models."""
-    formatted_history = []
-    for chat in history:
-        if isinstance(chat, dict) and 'role' in chat:
-            formatted_history.append(chat)
-        elif isinstance(chat, list) and len(chat) == 2:
-            formatted_history.extend([
-                {"role": "user", "content": chat[0]},
-                {"role": "assistant", "content": chat[1]}
-            ])
-    return formatted_history
+        print(f"Error during generation: {e}")
+        yield f"Error during text generation: {str(e)}"
 
 def respond_base(message, history):
     """Handle base model response for Gradio ChatInterface."""
-    formatted_history = format_chat_history(history, message)
-    response_gen = base_model_predict(message, formatted_history)
-
-    for response in response_gen:
-        yield response
+    try:
+        response_gen = generate_chat_response(message, history, "base")
+        for response in response_gen:
+            yield response
+    except Exception as e:
+        print(f"Error in respond_base: {e}")
+        yield f"Error: {str(e)}"
 
 def respond_ft(message, history):
     """Handle fine-tuned model response for Gradio ChatInterface."""
-    formatted_history = format_chat_history(history, message)
-    response_gen = ft_model_predict(message, formatted_history)
-
-    for response in response_gen:
-        yield response
+    try:
+        response_gen = generate_chat_response(message, history, "finetuned")
+        for response in response_gen:
+            yield response
+    except Exception as e:
+        print(f"Error in respond_ft: {e}")
+        yield f"Error: {str(e)}"
 
 # --- Gradio UI Definition ---
 with gr.Blocks(theme=gr.themes.Soft(), title="🎬 CineGuide Comparison") as demo:
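
For reference, a minimal sketch of how the new convert_gradio_history_to_messages helper flattens a tuple-style Gradio history into chat-template messages; the sample conversation below is hypothetical and not part of the commit:

# Sketch only: mirrors the helper added in this commit; the history values are made up.
def convert_gradio_history_to_messages(history):
    messages = []
    for exchange in history:
        if isinstance(exchange, (list, tuple)) and len(exchange) == 2:
            user_msg, assistant_msg = exchange
            if user_msg:  # skip empty user turns
                messages.append({"role": "user", "content": str(user_msg)})
            if assistant_msg:  # skip empty assistant turns (e.g. a still-pending reply)
                messages.append({"role": "assistant", "content": str(assistant_msg)})
    return messages

history = [["Recommend a heist movie", "Try Heat (1995)."], ["Something lighter?", None]]
print(convert_gradio_history_to_messages(history))
# -> [{'role': 'user', 'content': 'Recommend a heist movie'},
#     {'role': 'assistant', 'content': 'Try Heat (1995).'},
#     {'role': 'user', 'content': 'Something lighter?'}]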