ShaoRun committed
Commit 7854c68 · verified · 1 Parent(s): b2a1d7a

Update app.py

Files changed (1)
  1. app.py +219 -145
app.py CHANGED
@@ -1,154 +1,228 @@
- import gradio as gr
- import numpy as np
  import random
-
- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
  import torch
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
-
-     return image, seed
-
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
-
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0,  # Replace with defaults that work for your model
                  )

-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2,  # Replace with defaults that work for your model
-                 )

-         gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
          inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
          ],
-         outputs=[result, seed],
      )

- if __name__ == "__main__":
-     demo.launch()
 
 
+ import sys
+ from io import BytesIO
+ import os
  import random
  import torch
+ import gradio as gr
+ sys.path.append("../")
+ from mm_models import AllSparkForCausalLM
+ from transformers import AutoImageProcessor, AutoTokenizer
+ from PIL import Image
+ import numpy as np
+ from plyfile import PlyData
+ import plotly.graph_objects as go
+ from mm_datasets.data_utils import point_preprocess, load_pts, process_pts
+ import matplotlib.pyplot as plt
+ from utils import SYSTEM_PROMPT
+
+ system_prompt = SYSTEM_PROMPT
+
+
+ def show_pointcloud(point_input, background='rgb(50,50,50)'):
+     if point_input is None:
+         return None
+     data = load_pts(point_input)
+     data = process_pts(data, 8192, True).numpy()
+     points = data[:, :3]
+     colors = data[:, 3:6]
+
+     if colors is not None:
+         # * if colors in range(0-1)
+         if np.max(colors) <= 1:
+             color_data = np.multiply(colors, 255).astype(int)  # Convert float values (0-1) to integers (0-255)
+         # * if colors in range(0-255)
+         elif np.max(colors) <= 255:
+             color_data = colors.astype(int)
+     else:
+         color_data = np.zeros_like(points).astype(int)  # Default to black color if RGB information is not available
+     colors = color_data.astype(np.float32) / 255  # model input is (0-1)
+
+     color_strings = ['rgb({},{},{})'.format(r, g, b) for r, g, b in color_data]
+
+     fig = go.Figure(
+         data=[
+             go.Scatter3d(
+                 x=points[:, 0], y=points[:, 1], z=points[:, 2],
+                 mode='markers',
+                 marker=dict(
+                     size=1.2,
+                     color=color_strings,  # Use the list of RGB strings for the marker colors
                  )
+             )
+         ],
+         layout=dict(
+             scene=dict(
+                 xaxis=dict(visible=False),
+                 yaxis=dict(visible=False),
+                 zaxis=dict(visible=False)
+             ),
+             paper_bgcolor='rgb(50,50,50)' if background is None else background  # Set the background color to dark gray 50, 50, 50
+         ),
+     )
+     # convert to PIL image
+     img_bytes = fig.to_image(format="png", engine="kaleido")
+     img = Image.open(BytesIO(img_bytes))
+     return img
+
+
+ # load model
+ model_path = "[path/to/model]"
+
+ try:
+
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+     model = AllSparkForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda()
+     img_processor = AutoImageProcessor.from_pretrained(model_path)
+     modal_place_token = dict()
+     for modal_cfg in model.config.modal_configs:
+         modal_place_token[modal_cfg['modal_tag']] = modal_cfg['modal_placeholder_token']
+
+ except:
+     model = None
+
+ MARKDOWN = """
+ # AllSpark V2🔥
+ <div>
+ <a href="https://arxiv.org/pdf/2408.00203">
+ <img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
+ </a>
+ </div>
+ AllSparkv2 is a language-centric progressive omni-modal learning framework
+ """

+ @torch.inference_mode()
+ # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+ def process(
+     image_input,
+     point_input,
+     text_input
+ ):
+     if model is None:
+         return 'Please load the model first'
+
+     # no user input
+     if text_input is None:
+         return 'Please enter your question'
+
+     # only natural language
+     if image_input is None and point_input is None:
+         messages = [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": text_input}
+         ]
+         inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors='pt').to(model.device)
+
+         outputs = model.generate(inputs,
+                                  do_sample=True,
+                                  temperature=0.6,
+                                  eos_token_id=tokenizer.eos_token_id,
+                                  pad_token_id=tokenizer.pad_token_id,
+                                  max_new_tokens=512)
+
+         text_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         # the Submit handler is wired to a single text output component
+         return text_output
+
+     # language - vision
+     if point_input is None:
+         # Text+Vision
+         img = image_input.convert("RGB")
+
+         img = img_processor(images=img, return_tensors="pt").pixel_values.to("cuda").squeeze().to(model.dtype)
+
+         modal_inputs = [('vision', img)]
+
+         question = modal_place_token['vision'] + "\n" + text_input
+
+         messages = [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": question}
+         ]
+         inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors='pt').to(model.device)
+
+         outputs = model.generate(
+             inputs,
+             modal_inputs=[modal_inputs],
+             do_sample=True,
+             eos_token_id=tokenizer.eos_token_id,
+             pad_token_id=tokenizer.pad_token_id,
+             max_new_tokens=1024)
+
+         output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return output
+
+     # language - point
+     point_cloud = load_pts(point_input)
+     point_cloud = process_pts(point_cloud, 8192, True)
+
+     # preview the uploaded point file (show_pointcloud expects the raw upload, and the upload event already renders it)
+     show_pointcloud(point_input, background='rgb(50,50,50)')
+
+     point_cloud = point_cloud.to(model.device).squeeze().to(model.dtype)
+     modal_inputs = [('point', point_cloud)]
+
+     question = modal_place_token['point'] + "\n" + text_input
+
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": question}
+     ]
+     inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors='pt').to(model.device)
+
+     outputs = model.generate(
+         inputs,
+         modal_inputs=[modal_inputs],
+         do_sample=True,
+         temperature=0.6,
+         eos_token_id=tokenizer.eos_token_id,
+         pad_token_id=tokenizer.pad_token_id,
+         max_new_tokens=1024)
+
+     output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     return output
+
+ with gr.Blocks() as demo:
+     gr.Markdown(MARKDOWN)
+     with gr.Row():
+         with gr.Column():
+             image_input_component = gr.Image(
+                 type='pil', label='Upload image')
+             point_input_component = gr.File(
+                 label="Upload point data",
+                 file_types=['.npy'],
+                 file_count='single')
+             text_input_component = gr.Textbox(label="Text input", placeholder="Chat with AllSparkv2...")
+             submit_button_component = gr.Button(
+                 value='Submit', variant='primary')
+         with gr.Column():
+             image_output_component = gr.Image(type='pil', label='Image Output')
+             text_output_component = gr.Textbox(label='Answer', placeholder='Text Output')
+
+     # automatically visualize the point cloud data once uploaded
+     point_input_component.change(
+         fn=show_pointcloud,
+         inputs=point_input_component,
+         outputs=image_output_component
+     )

+     submit_button_component.click(
+         fn=process,
          inputs=[
+             image_input_component,
+             point_input_component,
+             text_input_component
+         ],
+         outputs=text_output_component
+     )
+
+     gr.Examples(
+         examples=[
+             ["How do you explain to an elementary school student: why the sun rises in the east and sets in the west?", None, None],
+             ["What does this picture mean for max?", "inference/demo_assets/image2.png", None],
+             ["What is it?", None, "inference/demo_assets/e393be9a47a24a7cae6142e13f5686d1_8192.npy"]
          ],
+         inputs=[text_input_component, image_input_component, point_input_component]
      )

+ # demo.launch(debug=False, show_error=True, share=True)
+ # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
+ demo.queue().launch(share=True)
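
The point uploader accepts a single .npy file, which show_pointcloud and process feed through load_pts and process_pts(data, 8192, True) before slicing it into data[:, :3] coordinates and data[:, 3:6] colors. A minimal sketch of writing a compatible file follows; it assumes load_pts can read a plain np.save()'d N x 6 array (the helper lives in mm_datasets.data_utils and is not part of this commit), and the helper name and file name below are illustrative only.

import numpy as np

def save_demo_pointcloud(xyz, rgb, path="sample_8192.npy"):
    # Stack (N, 3) coordinates and (N, 3) colors into the (N, 6) layout the app slices.
    # Colors may be in 0-1 or 0-255; show_pointcloud() handles both ranges.
    pts = np.concatenate(
        [np.asarray(xyz, dtype=np.float32), np.asarray(rgb, dtype=np.float32)],
        axis=1,
    )
    np.save(path, pts)  # process_pts(data, 8192, True) presumably resamples to 8192 points
    return path

# Example: a random 8192-point cloud with colors in the 0-1 range
save_demo_pointcloud(np.random.rand(8192, 3), np.random.rand(8192, 3))

Once demo.queue().launch(share=True) is running, the Submit handler can also be exercised programmatically with gradio_client; the URL and api_name below are assumptions (Gradio usually derives the route name from fn=process when no explicit api_name is set).

from gradio_client import Client

client = Client("http://127.0.0.1:7860/")  # or the share URL printed by launch()
answer = client.predict(
    None,                    # image_input: image filepath/PIL image, or None
    None,                    # point_input: path to an .npy point cloud, or None
    "What is AllSparkv2?",   # text_input
    api_name="/process",     # assumed route name for the Submit click handler
)
print(answer)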