Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
e4f398a
1
Parent(s):
2a9063b
update app
Browse files
app.py
CHANGED
@@ -1,7 +1,335 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
def greet(name):
    """Return the greeting ``Hello <name>!!`` for the given name string."""
    return "".join(("Hello ", name, "!!"))
|
5 |
|
6 |
-
|
7 |
-
demo.launch()
|
|
|
1 |
+
import random
|
2 |
+
from threading import Thread
|
3 |
+
|
4 |
import gradio as gr
|
5 |
+
import torch # Need this for torch.no_grad()
|
6 |
+
from datasets import load_dataset
|
7 |
+
from qwen_vl_utils import process_vision_info
|
8 |
+
from transformers import (
|
9 |
+
AutoProcessor,
|
10 |
+
Qwen2_5_VLForConditionalGeneration,
|
11 |
+
TextIteratorStreamer,
|
12 |
+
)
|
13 |
+
from trl import ModelConfig
|
14 |
+
|
15 |
+
# run with:
|
16 |
+
# CUDA_VISIBLE_DEVICES=0 uv run gradio demo/demo.py
|
17 |
+
|
18 |
+
|
19 |
+
def get_eval_dataset():
    """Return the held-out evaluation split of the message-decoding dataset.

    The dataset's "train" split is shuffled with seed 42 and then divided
    90/10 with seed 42; the 10% "test" portion is returned.
    """
    shuffled = load_dataset("sunildkumar/message-decoding-words-and-sequences")[
        "train"
    ].shuffle(seed=42)

    # The split seed matches the one used in the training script.
    return shuffled.train_test_split(test_size=0.1, seed=42)["test"]
|
28 |
+
|
29 |
+
|
30 |
+
def load_model_and_tokenizer():
    """Load the message-decoding model and its processor.

    Returns:
        tuple: ``(model, processor)`` -- the model switched to eval mode and
        a processor created with ``padding_side="left"``.
    """
    config = ModelConfig(
        model_name_or_path="Groundlight/message-decoding-r1",
        torch_dtype="bfloat16",
        use_peft=False,
    )
    checkpoint = config.model_name_or_path

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        pretrained_model_name_or_path=checkpoint,
        torch_dtype=config.torch_dtype,
        use_cache=False,
        # "auto" lets the loader decide device placement; it does not pin
        # the model to CPU.
        device_map="auto",
    )

    # Inference only: switch the model to eval mode.
    model.eval()

    processor = AutoProcessor.from_pretrained(checkpoint, padding_side="left")

    return model, processor
|
52 |
+
|
53 |
+
|
54 |
+
# Move resource loading inside a function
|
55 |
+
def load_resources():
    """Populate the module-level ``eval_dataset``, ``model`` and ``processor``."""
    global eval_dataset, model, processor

    eval_dataset = get_eval_dataset()
    loaded_model, loaded_processor = load_model_and_tokenizer()
    model, processor = loaded_model, loaded_processor
|
59 |
+
|
60 |
+
|
61 |
+
def show_random_example():
    """Pick a uniformly random example from ``eval_dataset``.

    Returns:
        tuple: ``(image, mapping, image)`` -- the image for display, the
        mapping for state, and the image again for state.
    """
    sample = eval_dataset[random.randrange(len(eval_dataset))]
    decoder_image = sample["image"]
    return decoder_image, sample["mapping"], decoder_image
|
68 |
+
|
69 |
+
|
70 |
+
def prepare_model_input(image, mapping, processor, submitted_word):
    """Build the processed model inputs for decoding an encoded word.

    The submitted word is lower-cased and encoded with the inverse of
    ``mapping``; the encoded characters are placed in a chat prompt together
    with the decoder image, and the prompt is run through ``processor``.

    Args:
        image: The decoder image to use.
        mapping (dict): The decoder mapping from the dataset; it is inverted
            here to encode the word.
        processor: The model's processor/tokenizer.
        submitted_word (str): The word submitted by the user.

    Returns:
        The output of calling ``processor`` on the prompt text and image
        (padded, ``return_tensors="pt"``).
    """
    plain_text = submitted_word.lower()
    print(f"Decoded message: {plain_text}")

    # Invert the decoder so it can be used to encode the word.
    encoder = {plain: coded for coded, plain in mapping.items()}
    print(f"Encoder: {encoder}")

    # Characters without an entry in the encoder (e.g. spaces) pass through.
    coded_chars = [encoder.get(ch, ch) for ch in plain_text]
    print(f"Coded message: {coded_chars}")

    # Space-separate the characters to prevent tokenization issues.
    spaced_code = " ".join(coded_chars)

    task = (
        f'Use the decoder in the image to decode this coded message: "{spaced_code}". '
        "The decoded message will be one or more words. Underscore characters "
        '("_") in the coded message should be mapped to a space (" ") when decoding.'
    )
    ending = (
        "Show your work in <think> </think> tags and return the answer in <answer> </answer> tags. "
        "While thinking, you must include a section with the decoded characters using <chars></chars> tags. "
        "The <chars> section should include the decoded characters in the order they are decoded. It should include the "
        "underscore character wherever there is a space in the decoded message. For example, if the coded message is "
        "a b c _ d e f, the <chars> section might be <chars> c a t _ d o g </chars>. Once you are done thinking, "
        "provide your answer in the <answer> section, e.g. <answer> cat dog </answer>."
    )
    instruction = " ".join((task, ending))

    print(f"Instruction: {instruction}")

    system_turn = {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "You are a helpful assistant. You first think about the reasoning process in the mind and then provide the user with the answer.",
            }
        ],
    }
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": instruction},
        ],
    }
    # The assistant turn is pre-filled; the chat template is asked to
    # continue this final message rather than start a new one.
    assistant_turn = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Let me solve this step by step.\n<think>"}
        ],
    }
    r1_messages = [system_turn, user_turn, assistant_turn]

    prompt_text = processor.apply_chat_template(
        r1_messages, continue_final_message=True, tokenize=False
    )

    image_input, _ = process_vision_info(r1_messages)

    return processor(
        text=prompt_text,
        images=[image_input],
        padding=True,
        return_tensors="pt",
    )
|
155 |
+
|
156 |
+
|
157 |
+
def encode_word(word, mapping):
    """Encode ``word`` with the inverse of the decoder ``mapping``.

    Args:
        word (str): The word to encode; it is lower-cased first.
        mapping (dict): The decoder mapping; it is inverted to encode.

    Returns:
        str: The encoded characters joined by single spaces, or ``""`` when
        either argument is empty or missing. Characters with no entry in the
        inverted mapping (e.g. spaces) are kept unchanged.
    """
    if not (word and mapping):
        return ""

    # Invert the decoder so it can be used to encode.
    encoder = {plain: coded for coded, plain in mapping.items()}
    return " ".join(encoder.get(ch, ch) for ch in word.lower())
|
170 |
+
|
171 |
+
|
172 |
+
def validate_and_submit(word, mapping):
    """Validate the submitted word and return the matching UI updates.

    A submission is accepted only when the word consists of letters (spaces
    allowed) and encoding it with ``mapping`` yields a non-empty result.

    Args:
        word (str): Text from the word input box.
        mapping (dict): Decoder mapping of the currently loaded example; may
            be ``None`` when no decoder has been loaded yet.

    Returns:
        tuple: Updates for, in order, the word input, the submit button, the
        run button and the encoded-word display.
    """
    is_alphabetic = bool(word) and word.replace(" ", "").isalpha()
    encoded_word = encode_word(word, mapping) if is_alphabetic else ""

    # Reject without locking anything. Previously a valid word submitted
    # before a decoder was loaded disabled the word box and the submit button
    # while the run button stayed disabled, leaving the interface stuck.
    if not encoded_word.strip():
        return (
            gr.update(),  # word input: unchanged
            gr.update(),  # submit button: unchanged
            gr.update(interactive=False),  # run button: disabled but visible
            gr.update(visible=False),  # encoded word display: hidden
        )

    word = word.lower()
    return (
        gr.update(value=word, interactive=False, label="Submitted Word"),
        gr.update(interactive=False),  # disable submit button
        gr.update(interactive=True),  # enable run button
        gr.update(value=f"Encoded word: {encoded_word}", visible=True),
    )
|
195 |
+
|
196 |
+
|
197 |
+
def prepare_for_inference():
    """Return the UI updates applied right before streaming starts.

    Returns:
        tuple: Three updates, in order: clear and show the output, disable
        the run button, show the loading indicator.
    """
    cleared_output = gr.update(value="", visible=True)
    disabled_run_button = gr.update(interactive=False)
    visible_loading = gr.update(visible=True)
    return cleared_output, disabled_run_button, visible_loading
|
204 |
+
|
205 |
+
|
206 |
+
def run_inference(word, image, mapping):
    """Run generation for the submitted word and stream the output text.

    Args:
        word (str): The submitted word.
        image: The decoder image of the current example.
        mapping (dict): The decoder mapping of the current example.

    Yields:
        str: The accumulated streamer text so far, once per streamed chunk.

    Raises:
        gr.Error: If the word, image or mapping is missing.
    """
    # `image is None` rather than `not image`: truthiness is not a reliable
    # emptiness test for image objects.
    if not word or image is None or not mapping:
        raise gr.Error("Please submit a word and load a decoder first")

    # Prepare model input and move it to the model's own device. The model
    # is loaded with device_map="auto", so hard-coding "cuda" breaks whenever
    # the model ends up elsewhere.
    model_inputs = prepare_model_input(image, mapping, processor, word)
    target_device = model.device
    model_inputs = {k: v.to(target_device) for k, v in model_inputs.items()}

    # Extra keyword arguments of TextIteratorStreamer are forwarded to the
    # tokenizer's decode call, so skip_special_tokens is passed directly.
    streamer = TextIteratorStreamer(
        tokenizer=processor,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        **model_inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=1.0,
        streamer=streamer,
    )

    # Generate in a separate thread so this generator can consume the
    # streamer while tokens are still being produced.
    def generate_with_no_grad():
        with torch.no_grad():
            model.generate(**generation_kwargs)

    thread = Thread(target=generate_with_no_grad)
    thread.start()

    # Stream the output, yielding the cumulative text each time.
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

    thread.join()
    return generated_text
|
247 |
+
|
248 |
+
|
249 |
+
# Create the Gradio interface
|
250 |
+
# Build the Gradio interface.
# NOTE(review): the original indentation was lost in this view, so the
# contents of the gr.Row()/gr.Column() containers below are reconstructed --
# confirm the nesting against the real file before merging.
with gr.Blocks() as demo:
    # Load the dataset, model and processor while the app is being built.
    load_resources()

    gr.Markdown("# Message Decoding Demo")

    # Per-session state: mapping and image of the currently shown decoder.
    current_mapping = gr.State()
    current_image = gr.State()

    with gr.Row():
        # Image display component
        image_output = gr.Image(label="Decoder")

    # Button that loads a new random example
    next_button = gr.Button("Generate Random Decoder")
    next_button.click(
        fn=show_random_example,
        outputs=[image_output, current_mapping, current_image],
    )

    # Text input for the word
    word_input = gr.Textbox(
        label="Enter a single word",
        placeholder="Enter word here...",
        max_lines=1,
        show_copy_button=False,
    )

    # Read-only display of the encoded word, hidden until a word is accepted
    encoded_word_display = gr.Textbox(
        label="Encoded Word",
        interactive=False,
        visible=False,
        max_lines=1,
        show_copy_button=True,
    )

    # Submit and run buttons stacked vertically
    with gr.Column():
        submit_button = gr.Button("Submit Word")
        # Visible from the start, but disabled
        run_button = gr.Button("Run Model", interactive=False)

    # Output area for the model response
    model_output = gr.Textbox(
        label="Model Output",
        interactive=False,
        visible=False,
        max_lines=10,
        container=True,
        show_copy_button=True,
    )

    # Loading indicator
    with gr.Row():
        loading_indicator = gr.HTML(visible=False)

    # Validate the word on submit and update the interface
    submit_button.click(
        fn=validate_and_submit,
        inputs=[word_input, current_mapping],
        outputs=[word_input, submit_button, run_button, encoded_word_display],
    )

    # Run button: set up the UI, stream the generation, then reset the UI
    run_button.click(
        fn=prepare_for_inference,
        outputs=[model_output, run_button, loading_indicator],
    ).then(
        fn=run_inference,
        inputs=[word_input, current_image, current_mapping],
        outputs=model_output,
        api_name=False,
    ).then(
        fn=lambda: (
            gr.update(interactive=False),  # run button: disabled, still visible
            gr.update(visible=False),  # loading indicator: hidden
            gr.update(interactive=True, label="Enter a single word"),  # word input: re-enabled
            gr.update(interactive=True),  # submit button: re-enabled
            gr.update(visible=False),  # encoded word display: hidden
        ),
        inputs=None,
        outputs=[
            run_button,
            loading_indicator,
            word_input,
            submit_button,
            encoded_word_display,
        ],
    )


if __name__ == "__main__":
    demo.launch()
|