Groundlight committed on
Commit e4f398a · 1 Parent(s): 2a9063b

update app

Files changed (1)
  1. app.py +332 -4
app.py CHANGED
@@ -1,7 +1,335 @@
  import gradio as gr
- def greet(name):
-     return "Hello " + name + "!!"
 
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import random
+ from threading import Thread
+
  import gradio as gr
+ import torch  # Need this for torch.no_grad()
+ from datasets import load_dataset
+ from qwen_vl_utils import process_vision_info
+ from transformers import (
+     AutoProcessor,
+     Qwen2_5_VLForConditionalGeneration,
+     TextIteratorStreamer,
+ )
+ from trl import ModelConfig
+
+ # run with:
+ # CUDA_VISIBLE_DEVICES=0 uv run gradio demo/demo.py
+
+
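+ # Reusing the training script's shuffle and split seeds reproduces the same
+ # train/test split, so the demo should only serve held-out examples.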
+ def get_eval_dataset():
+     full_dataset = load_dataset("sunildkumar/message-decoding-words-and-sequences")["train"]
+     full_dataset = full_dataset.shuffle(seed=42)
+
+     # split the dataset with the same seed as used in the training script
+     splits = full_dataset.train_test_split(test_size=0.1, seed=42)
+     test_dataset = splits["test"]
+
+     return test_dataset
+
+
+ def load_model_and_tokenizer():
+     model_config = ModelConfig(
+         model_name_or_path="Groundlight/message-decoding-r1",
+         torch_dtype="bfloat16",
+         use_peft=False,
+     )
+
+     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         pretrained_model_name_or_path=model_config.model_name_or_path,
+         torch_dtype=model_config.torch_dtype,
+         use_cache=False,
+         device_map="auto",  # place weights automatically (GPU when available)
+     )
+
+     # put model in eval mode
+     model.eval()
+
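+     # Left padding keeps prompt tokens adjacent to generated tokens, which is
+     # what decoder-only models need for batched generation.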
+     processor = AutoProcessor.from_pretrained(
+         model_config.model_name_or_path, padding_side="left"
+     )
+
+     return model, processor
+
+
+ # Move resource loading inside a function
+ def load_resources():
+     global eval_dataset, model, processor
+     eval_dataset = get_eval_dataset()
+     model, processor = load_model_and_tokenizer()
+
+
+ def show_random_example():
+     # Get a random example
+     random_idx = random.randint(0, len(eval_dataset) - 1)
+     example = eval_dataset[random_idx]
+
+     # Return image for display, mapping for state, and image for state
+     return example["image"], example["mapping"], example["image"]
+
+
+ def prepare_model_input(image, mapping, processor, submitted_word):
+     """
+     Prepare the input for the model using the mapping, processor, and submitted word.
+
+     Args:
+         image: The decoder image to use
+         mapping (dict): The mapping data from the dataset
+         processor: The model's processor/tokenizer
+         submitted_word (str): The word submitted by the user
+
+     Returns:
+         dict: The processed inputs ready for the model
+     """
+     decoded_message = submitted_word.lower()
+     print(f"Decoded message: {decoded_message}")
+
+     # reverse the decoder to encode the word
+     encoder = {v: k for k, v in mapping.items()}
+     print(f"Encoder: {encoder}")
+     # leaving the space as is
+     coded_message = [encoder[c] if c in encoder else c for c in decoded_message]
+     print(f"Coded message: {coded_message}")
+
+     # add spaces between each character to prevent tokenization issues
+     coded_message = " ".join(coded_message)
+
+     instruction = (
+         f'Use the decoder in the image to decode this coded message: "{coded_message}". '
+         "The decoded message will be one or more words. Underscore characters "
+         '("_") in the coded message should be mapped to a space (" ") when decoding.'
+     )
+
+     ending = (
+         "Show your work in <think> </think> tags and return the answer in <answer> </answer> tags. "
+         "While thinking, you must include a section with the decoded characters using <chars></chars> tags. "
+         "The <chars> section should include the decoded characters in the order they are decoded. It should include the "
+         "underscore character wherever there is a space in the decoded message. For example, if the coded message is "
+         "a b c _ d e f, the <chars> section might be <chars> c a t _ d o g </chars>. Once you are done thinking, "
+         "provide your answer in the <answer> section, e.g. <answer> cat dog </answer>."
+     )
+     instruction = f"{instruction} {ending}"
+
+     print(f"Instruction: {instruction}")
+
+     r1_messages = [
+         {
+             "role": "system",
+             "content": [
+                 {
+                     "type": "text",
+                     "text": "You are a helpful assistant. You first think about the reasoning process in the mind and then provide the user with the answer.",
+                 }
+             ],
+         },
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": instruction},
+             ],
+         },
+         {
+             "role": "assistant",
+             "content": [
+                 {"type": "text", "text": "Let me solve this step by step.\n<think>"}
+             ],
+         },
+     ]
+
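+     # continue_final_message=True leaves the final assistant turn open, so
+     # generation continues from the seeded "<think>" prefix above.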
+     texts = processor.apply_chat_template(
+         r1_messages, continue_final_message=True, tokenize=False
+     )
+
+     image_input, _ = process_vision_info(r1_messages)
+
+     image_input = [image_input]
+
+     batch = processor(
+         text=texts,
+         images=image_input,
+         padding=True,
+         return_tensors="pt",
+     )
+
+     return batch
+
+
+ def encode_word(word, mapping):
+     """
+     Encode a word using the given mapping.
+     """
+     if not word or not mapping:
+         return ""
+
+     word = word.lower()
+     # reverse the decoder to encode the word
+     encoder = {v: k for k, v in mapping.items()}
+     # leaving the space as is
+     coded_message = [encoder[c] if c in encoder else c for c in word]
+     return " ".join(coded_message)
+
+
+ def validate_and_submit(word, mapping):
+     # Check if input contains only letters
+     if not word.replace(" ", "").isalpha():
+         return (
+             gr.update(),  # word input
+             gr.update(),  # submit button
+             gr.update(interactive=False),  # run button - disable but keep visible
+             gr.update(visible=False),  # encoded word display
+         )
+
+     word = word.lower()
+     encoded_word = encode_word(word, mapping)
+
+     # Only enable run button if we have a valid encoded word
+     has_valid_encoded_word = bool(encoded_word.strip())
+
+     # Return updates for input, submit button, run button, and encoded word display
+     return (
+         gr.update(value=word, interactive=False, label="Submitted Word"),
+         gr.update(interactive=False),  # Disable submit button
+         gr.update(interactive=has_valid_encoded_word),  # Enable run button only if valid, but always visible
+         gr.update(value=f"Encoded word: {encoded_word}", visible=has_valid_encoded_word),  # Show encoded word
+     )
+
+
+ def prepare_for_inference():
+     """Setup function that runs before streaming starts"""
+     return (
+         gr.update(value="", visible=True),  # Clear and show output
+         gr.update(interactive=False),  # Disable run button
+         gr.update(visible=True),  # Show loading indicator
+     )
+
+
+ def run_inference(word, image, mapping):
+     """Main inference function, now focused just on generation"""
+     if not word or not image or not mapping:
+         raise gr.Error("Please submit a word and load a decoder first")
+
+     # Prepare model input and move it to the device the model was placed on
+     model_inputs = prepare_model_input(image, mapping, processor, word)
+     model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
+
+     # Initialize streamer; skip_prompt keeps the prompt out of the output box
+     streamer = TextIteratorStreamer(
+         tokenizer=processor,
+         skip_prompt=True,
+         skip_special_tokens=True,
+     )
+
+     # Set up generation parameters
+     generation_kwargs = dict(
+         **model_inputs,
+         max_new_tokens=512,
+         do_sample=True,
+         temperature=1.0,
+         streamer=streamer,
+     )
+
+     # Start generation in a separate thread with torch.no_grad()
+     def generate_with_no_grad():
+         with torch.no_grad():
+             model.generate(**generation_kwargs)
+
+     thread = Thread(target=generate_with_no_grad)
+     thread.start()
+
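+     # TextIteratorStreamer is itself an iterator: generate() feeds it tokens
+     # from the worker thread while the loop below yields the accumulated
+     # text, letting Gradio refresh the output box as tokens arrive.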
+     # Stream the output
+     generated_text = ""
+     for new_text in streamer:
+         generated_text += new_text
+         yield generated_text
+
+     thread.join()
+     return generated_text
+
+
+ # Create the Gradio interface
+ with gr.Blocks() as demo:
+     # Load resources when the app starts
+     load_resources()
+
+     gr.Markdown("# Message Decoding Demo")
+     current_mapping = gr.State()
+     current_image = gr.State()
+
+     with gr.Row():
+         # Image display component
+         image_output = gr.Image(label="Decoder")
+
+     # Button to load new random example
+     next_button = gr.Button("Generate Random Decoder")
+     next_button.click(
+         fn=show_random_example, outputs=[image_output, current_mapping, current_image]
+     )
+
+     # Text input for the word
+     word_input = gr.Textbox(
+         label="Enter a single word",
+         placeholder="Enter word here...",
+         max_lines=1,
+         show_copy_button=False,
+     )
+
+     # Add encoded word display
+     encoded_word_display = gr.Textbox(
+         label="Encoded Word",
+         interactive=False,
+         visible=False,
+         max_lines=1,
+         show_copy_button=True,
+     )
+
+     # Group submit and run buttons vertically
+     with gr.Column():  # Use Column instead of Row for vertical layout
+         submit_button = gr.Button("Submit Word")
+         run_button = gr.Button("Run Model", interactive=False)  # Initialize as visible but disabled
+
+     # Output area for model response
+     model_output = gr.Textbox(
+         label="Model Output",
+         interactive=False,
+         visible=False,
+         max_lines=10,
+         container=True,
+         show_copy_button=True,
+     )
+
+     # Add loading indicator
+     with gr.Row():
+         loading_indicator = gr.HTML(visible=False)
+
+     # Validate word on submit and update interface
+     submit_button.click(
+         fn=validate_and_submit,
+         inputs=[word_input, current_mapping],
+         outputs=[word_input, submit_button, run_button, encoded_word_display],
+     )
+
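+     # Chained events: prepare_for_inference readies the UI, run_inference
+     # streams tokens into model_output, and the final lambda restores the
+     # controls once generation finishes.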
+     # Run inference when run button is clicked
+     run_button.click(
+         fn=prepare_for_inference,
+         outputs=[model_output, run_button, loading_indicator],
+     ).then(
+         fn=run_inference,
+         inputs=[word_input, current_image, current_mapping],
+         outputs=model_output,
+         api_name=False,
+     ).then(
+         # Reset interface after generation
+         lambda: (
+             gr.update(interactive=False),  # Disable run button but keep visible
+             gr.update(visible=False),  # Hide loading indicator
+             gr.update(interactive=True, label="Enter a single word"),  # Re-enable word input
+             gr.update(interactive=True),  # Re-enable submit button
+             gr.update(visible=False),  # Hide encoded word display
+         ),
+         None,
+         [run_button, loading_indicator, word_input, submit_button, encoded_word_display],
+     )
 
 
+ if __name__ == "__main__":
+     demo.launch()
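
For reference, the inverse-mapping step shared by encode_word() and prepare_model_input() can be exercised on its own. A minimal sketch with a made-up mapping; a real mapping comes from each example's "mapping" field, which maps coded symbols to plaintext characters:

    # Toy stand-in for a dataset mapping (hypothetical values, not from the dataset).
    mapping = {"x": "c", "y": "a", "z": "t", "_": " "}
    # Invert the decoder to get an encoder, exactly as app.py does.
    encoder = {v: k for k, v in mapping.items()}
    word = "cat cat"
    coded = " ".join(encoder[c] if c in encoder else c for c in word.lower())
    print(coded)  # -> x y z _ x y z

Decoding the printed string with the original mapping (and "_" -> " ") recovers the submitted words, which is exactly the task the demo poses to the model.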