arad1367 committed
Commit 49da751 · verified · 1 Parent(s): 9b1a390

Update app.py

Files changed (1)
  1. app.py +163 -162
app.py CHANGED
@@ -1,162 +1,163 @@
- import os
- import matplotlib.pyplot as plt
- import matplotlib.patches as patches
- from PIL import Image
- import gradio as gr
- from transformers import AutoProcessor, AutoModelForCausalLM
- import torch
- import numpy as np
- import spaces
- import subprocess
-
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
-
- # Initialize Florence-2-large model and processor
- model_id = 'microsoft/Florence-2-large'
- model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
- processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
-
- # Function to resize and preprocess image
- def preprocess_image(image_path, max_size=(800, 800)):
-     image = Image.open(image_path).convert('RGB')
-     if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
-         image.thumbnail(max_size, Image.LANCZOS)
-
-     # Convert image to numpy array
-     image_np = np.array(image)
-
-     # Ensure the image is in the format [height, width, channels]
-     if image_np.ndim == 2: # Grayscale image
-         image_np = np.expand_dims(image_np, axis=-1)
-     elif image_np.shape[0] == 3: # Image in [channels, height, width] format
-         image_np = np.transpose(image_np, (1, 2, 0))
-
-     return image_np, image.size
-
- # Function to run Florence-2-large model
- @spaces.GPU
- def run_florence_model(image_np, image_size, task_prompt, text_input=None):
-     if text_input is None:
-         prompt = task_prompt
-     else:
-         prompt = task_prompt + text_input
-
-     inputs = processor(text=prompt, images=image_np, return_tensors="pt")
-
-     with torch.no_grad():
-         outputs = model.generate(
-             input_ids=inputs["input_ids"].cuda(),
-             pixel_values=inputs["pixel_values"].cuda(),
-             max_new_tokens=1024,
-             early_stopping=False,
-             do_sample=False,
-             num_beams=3,
-         )
-
-     generated_text = processor.batch_decode(outputs, skip_special_tokens=False)[0]
-     parsed_answer = processor.post_process_generation(
-         generated_text,
-         task=task_prompt,
-         image_size=image_size
-     )
-
-     return parsed_answer, generated_text
-
- # Function to plot image with bounding boxes
- def plot_image_with_bboxes(image_np, bboxes, labels=None):
-     fig, ax = plt.subplots(1)
-     ax.imshow(image_np)
-     colors = ['red', 'blue', 'green', 'yellow', 'purple', 'cyan']
-     for i, bbox in enumerate(bboxes):
-         color = colors[i % len(colors)]
-         x, y, width, height = bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]
-         rect = patches.Rectangle((x, y), width, height, linewidth=2, edgecolor=color, facecolor='none')
-         ax.add_patch(rect)
-         if labels and i < len(labels):
-             ax.text(x, y, labels[i], color=color, fontsize=8, bbox=dict(facecolor='white', alpha=0.7))
-     plt.axis('off')
-     return fig
-
- # Gradio function to process uploaded images
- @spaces.GPU
- def process_image(image_path):
-     image_np, image_size = preprocess_image(image_path)
-
-     # Convert image_np to float32
-     image_np = image_np.astype(np.float32)
-
-     # Image Captioning
-     caption_result, _ = run_florence_model(image_np, image_size, '<CAPTION>')
-     detailed_caption_result, _ = run_florence_model(image_np, image_size, '<DETAILED_CAPTION>')
-
-     # Object Detection
-     od_result, _ = run_florence_model(image_np, image_size, '<OD>')
-     od_bboxes = od_result['<OD>'].get('bboxes', [])
-     od_labels = od_result['<OD>'].get('labels', [])
-
-     # OCR
-     ocr_result, _ = run_florence_model(image_np, image_size, '<OCR>')
-
-     # Phrase Grounding
-     pg_result, _ = run_florence_model(image_np, image_size, '<CAPTION_TO_PHRASE_GROUNDING>', text_input=caption_result['<CAPTION>'])
-     pg_bboxes = pg_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('bboxes', [])
-     pg_labels = pg_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('labels', [])
-
-     # Cascaded Tasks (Detailed Caption + Phrase Grounding)
-     cascaded_result, _ = run_florence_model(image_np, image_size, '<CAPTION_TO_PHRASE_GROUNDING>', text_input=detailed_caption_result['<DETAILED_CAPTION>'])
-     cascaded_bboxes = cascaded_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('bboxes', [])
-     cascaded_labels = cascaded_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('labels', [])
-
-     # Create plots
-     od_fig = plot_image_with_bboxes(image_np, od_bboxes, od_labels)
-     pg_fig = plot_image_with_bboxes(image_np, pg_bboxes, pg_labels)
-     cascaded_fig = plot_image_with_bboxes(image_np, cascaded_bboxes, cascaded_labels)
-
-     # Prepare response
-     response = f"""
- Image Captioning:
- - Simple Caption: {caption_result['<CAPTION>']}
- - Detailed Caption: {detailed_caption_result['<DETAILED_CAPTION>']}
-
- Object Detection:
- - Detected {len(od_bboxes)} objects
-
- OCR:
- {ocr_result['<OCR>']}
-
- Phrase Grounding:
- - Grounded {len(pg_bboxes)} phrases from the simple caption
-
- Cascaded Tasks:
- - Grounded {len(cascaded_bboxes)} phrases from the detailed caption
- """
-
-     return response, od_fig, pg_fig, cascaded_fig
-
- # Gradio interface
- with gr.Blocks(theme='NoCrypt/miku') as demo:
-     gr.Markdown("""
- # Image Processing with Florence-2-large
- Upload an image to perform image captioning, object detection, OCR, phrase grounding, and cascaded tasks.
- """)
-
-     image_input = gr.Image(type="filepath")
-     text_output = gr.Textbox()
-     plot_output_1 = gr.Plot()
-     plot_output_2 = gr.Plot()
-     plot_output_3 = gr.Plot()
-
-     image_input.upload(process_image, inputs=[image_input], outputs=[text_output, plot_output_1, plot_output_2, plot_output_3])
-
-     footer = """
- <div style="text-align: center; margin-top: 20px;">
- <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
- <a href="https://github.com/arad1367" target="_blank">GitHub</a> |
- <a href="https://arad1367.pythonanywhere.com/" target="_blank">Live demo of my PhD defense</a>
- <br>
- Made with 💖 by Pejman Ebrahimi
- </div>
- """
-     gr.HTML(footer)
-
- demo.launch()
 
+ import os
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ from PIL import Image
+ import gradio as gr
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ import torch
+ import numpy as np
+ import subprocess
+
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+ # Initialize Florence-2-large model and processor
+ model_id = 'microsoft/Florence-2-large'
+ model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+ # Function to resize and preprocess image
+ def preprocess_image(image_path, max_size=(800, 800)):
+     image = Image.open(image_path).convert('RGB')
+     if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
+         image.thumbnail(max_size, Image.LANCZOS)
+
+     # Convert image to numpy array
+     image_np = np.array(image)
+
+     # Ensure the image is in the format [height, width, channels]
+     if image_np.ndim == 2: # Grayscale image
+         image_np = np.expand_dims(image_np, axis=-1)
+     elif image_np.shape[0] == 3: # Image in [channels, height, width] format
+         image_np = np.transpose(image_np, (1, 2, 0))
+
+     return image_np, image.size
+
+ # Function to run Florence-2-large model
+ def run_florence_model(image_np, image_size, task_prompt, text_input=None):
+     if text_input is None:
+         prompt = task_prompt
+     else:
+         prompt = task_prompt + text_input
+
+     inputs = processor(text=prompt, images=image_np, return_tensors="pt")
+
+     with torch.no_grad():
+         outputs = model.generate(
+             input_ids=inputs["input_ids"].cuda(),
+             pixel_values=inputs["pixel_values"].cuda(),
+             max_new_tokens=1024,
+             early_stopping=False,
+             do_sample=False,
+             num_beams=3,
+         )
+
+     generated_text = processor.batch_decode(outputs, skip_special_tokens=False)[0]
+     parsed_answer = processor.post_process_generation(
+         generated_text,
+         task=task_prompt,
+         image_size=image_size
+     )
+
+     return parsed_answer, generated_text
+
+ # Function to plot image with bounding boxes
+ def plot_image_with_bboxes(image_np, bboxes, labels=None):
+     # Normalize the image array to [0.0, 1.0]
+     if image_np.dtype == np.uint8:
+         image_np = image_np / 255.0
+
+     fig, ax = plt.subplots(1)
+     ax.imshow(image_np)
+     colors = ['red', 'blue', 'green', 'yellow', 'purple', 'cyan']
+     for i, bbox in enumerate(bboxes):
+         color = colors[i % len(colors)]
+         x, y, width, height = bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]
+         rect = patches.Rectangle((x, y), width, height, linewidth=2, edgecolor=color, facecolor='none')
+         ax.add_patch(rect)
+         if labels and i < len(labels):
+             ax.text(x, y, labels[i], color=color, fontsize=8, bbox=dict(facecolor='white', alpha=0.7))
+     plt.axis('off')
+     return fig
+
+ # Gradio function to process uploaded images
+ def process_image(image_path):
+     image_np, image_size = preprocess_image(image_path)
+
+     # Convert image_np to float32
+     image_np = image_np.astype(np.float32)
+
+     # Image Captioning
+     caption_result, _ = run_florence_model(image_np, image_size, '<CAPTION>')
+     detailed_caption_result, _ = run_florence_model(image_np, image_size, '<DETAILED_CAPTION>')
+
+     # Object Detection
+     od_result, _ = run_florence_model(image_np, image_size, '<OD>')
+     od_bboxes = od_result['<OD>'].get('bboxes', [])
+     od_labels = od_result['<OD>'].get('labels', [])
+
+     # OCR
+     ocr_result, _ = run_florence_model(image_np, image_size, '<OCR>')
+
+     # Phrase Grounding
+     pg_result, _ = run_florence_model(image_np, image_size, '<CAPTION_TO_PHRASE_GROUNDING>', text_input=caption_result['<CAPTION>'])
+     pg_bboxes = pg_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('bboxes', [])
+     pg_labels = pg_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('labels', [])
+
+     # Cascaded Tasks (Detailed Caption + Phrase Grounding)
+     cascaded_result, _ = run_florence_model(image_np, image_size, '<CAPTION_TO_PHRASE_GROUNDING>', text_input=detailed_caption_result['<DETAILED_CAPTION>'])
+     cascaded_bboxes = cascaded_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('bboxes', [])
+     cascaded_labels = cascaded_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('labels', [])
+
+     # Create plots
+     od_fig = plot_image_with_bboxes(image_np, od_bboxes, od_labels)
+     pg_fig = plot_image_with_bboxes(image_np, pg_bboxes, pg_labels)
+     cascaded_fig = plot_image_with_bboxes(image_np, cascaded_bboxes, cascaded_labels)
+
+     # Prepare response
+     response = f"""
+ Image Captioning:
+ - Simple Caption: {caption_result['<CAPTION>']}
+ - Detailed Caption: {detailed_caption_result['<DETAILED_CAPTION>']}
+
+ Object Detection:
+ - Detected {len(od_bboxes)} objects
+
+ OCR:
+ {ocr_result['<OCR>']}
+
+ Phrase Grounding:
+ - Grounded {len(pg_bboxes)} phrases from the simple caption
+
+ Cascaded Tasks:
+ - Grounded {len(cascaded_bboxes)} phrases from the detailed caption
+ """
+
+     return response, od_fig, pg_fig, cascaded_fig
+
+ # Gradio interface
+ with gr.Blocks(theme='NoCrypt/miku') as demo:
+     gr.Markdown("""
+ # Image Processing with Florence-2-large
+ Upload an image to perform image captioning, object detection, OCR, phrase grounding, and cascaded tasks.
+ """)
+
+     image_input = gr.Image(type="filepath")
+     text_output = gr.Textbox()
+     plot_output_1 = gr.Plot()
+     plot_output_2 = gr.Plot()
+     plot_output_3 = gr.Plot()
+
+     image_input.upload(process_image, inputs=[image_input], outputs=[text_output, plot_output_1, plot_output_2, plot_output_3])
+
+     footer = """
+ <div style="text-align: center; margin-top: 20px;">
+ <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
+ <a href="https://github.com/arad1367" target="_blank">GitHub</a> |
+ <a href="https://arad1367.pythonanywhere.com/" target="_blank">Live demo of my PhD defense</a>
+ <br>
+ Made with 💖 by Pejman Ebrahimi
+ </div>
+ """
+     gr.HTML(footer)
+
+ demo.launch()
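
A minimal smoke-test sketch for the updated file, assuming the functions above are defined in a Python session on a CUDA machine (with demo.launch() not called) and that a local image exists; the file name "sample.jpg" and the savefig targets are illustrative placeholders, not part of the commit:

    # Hypothetical smoke test: exercise the pipeline without the Gradio UI.
    summary, od_fig, pg_fig, cascaded_fig = process_image("sample.jpg")  # "sample.jpg" is a placeholder path
    print(summary)                                               # captions, OCR text, and box counts
    od_fig.savefig("od.png", bbox_inches="tight")                # object-detection boxes
    pg_fig.savefig("pg.png", bbox_inches="tight")                # grounding of the simple caption
    cascaded_fig.savefig("cascaded.png", bbox_inches="tight")    # grounding of the detailed caption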