stevengrove committed on
Commit
11e6da8
·
verified ·
1 Parent(s): ccb39b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -32
app.py CHANGED
@@ -12,30 +12,14 @@ NEGATIVE_PROMPT = '''
12
  low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.
13
  '''
14
 
15
-
16
- def parse_args():
17
- args = argparse.ArgumentParser(description='MindOmni')
18
- args.add_argument('--device', type=str, default='cuda')
19
- args.add_argument('--dtype', type=str, default='bf16')
20
- args.add_argument('--model_path', type=str,
21
- default='EasonXiao-888/MindOmni')
22
- args = args.parse_args()
23
- return args
24
-
25
-
26
- def build_model(args):
27
- device = args.device
28
- MindOmni_model = MindOmni.from_pretrained(args.model_path)
29
- if args.dtype == "bf16":
30
- dtype = torch.bfloat16
31
- MindOmni_model.to(device=device, dtype=dtype)
32
- MindOmni_model.eval()
33
- return MindOmni_model
34
 
35
 
36
  @spaces.GPU
37
  def understand_func(
38
- MindOmni_model, text, do_sample, temperature,
39
  max_new_tokens, input_llm_images):
40
  if input_llm_images is not None and not isinstance(input_llm_images, list):
41
  input_llm_images = [input_llm_images]
@@ -47,7 +31,7 @@ def understand_func(
47
 
48
  @spaces.GPU
49
  def generate_func(
50
- MindOmni_model, text, use_cot, height, width, guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model, max_input_image_size, randomize_seed, save_images, do_sample, temperature, max_new_tokens, input_llm_images, only_understand):
51
  if input_llm_images is not None and not isinstance(input_llm_images, list):
52
  input_llm_images = [input_llm_images]
53
 
@@ -76,7 +60,7 @@ def generate_func(
76
  return img, prompt_, seed
77
 
78
 
79
- def build_gradio(args, MindOmni_model):
80
  with gr.Blocks() as demo:
81
  gr.Markdown("## 🪄 MindOmni Demo")
82
 
@@ -133,7 +117,7 @@ def build_gradio(args, MindOmni_model):
133
  )
134
 
135
  g_btn.click(
136
- partial(generate_func, MindOmni_model),
137
  inputs=[g_prompt, g_use_cot, g_height, g_width, g_scale, g_steps,
138
  g_seed, g_sep_cfg, g_offload, g_max_img, g_rand, g_save,
139
  g_do_sample, g_temperature, g_max_new_tok,
@@ -156,7 +140,7 @@ def build_gradio(args, MindOmni_model):
156
  u_answer = gr.Textbox(label="Answer", lines=8)
157
 
158
  u_btn.click(
159
- partial(understand_func, MindOmni_model),
160
  inputs=[u_prompt, u_do_sample,
161
  u_temperature, u_max_new_tok, u_image],
162
  outputs=u_answer)
@@ -164,12 +148,5 @@ def build_gradio(args, MindOmni_model):
164
  demo.launch()
165
 
166
 
167
- def main():
168
- args = parse_args()
169
- print(f'running args: {args}')
170
- MindOmni_model = build_model(args)
171
- build_gradio(args, MindOmni_model)
172
-
173
-
174
  if __name__ == '__main__':
175
- main()
 
12
  low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.
13
  '''
14
 
15
+ MindOmni_model = MindOmni.from_pretrained('EasonXiao-888/MindOmni')
16
+ MindOmni_model.to(device='cuda', dtype=torch.bfloat16)
17
+ MindOmni_model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  @spaces.GPU
21
  def understand_func(
22
+ text, do_sample, temperature,
23
  max_new_tokens, input_llm_images):
24
  if input_llm_images is not None and not isinstance(input_llm_images, list):
25
  input_llm_images = [input_llm_images]
 
31
 
32
  @spaces.GPU
33
  def generate_func(
34
+ text, use_cot, height, width, guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model, max_input_image_size, randomize_seed, save_images, do_sample, temperature, max_new_tokens, input_llm_images, only_understand):
35
  if input_llm_images is not None and not isinstance(input_llm_images, list):
36
  input_llm_images = [input_llm_images]
37
 
 
60
  return img, prompt_, seed
61
 
62
 
63
+ def build_gradio():
64
  with gr.Blocks() as demo:
65
  gr.Markdown("## 🪄 MindOmni Demo")
66
 
 
117
  )
118
 
119
  g_btn.click(
120
+ generate_func,
121
  inputs=[g_prompt, g_use_cot, g_height, g_width, g_scale, g_steps,
122
  g_seed, g_sep_cfg, g_offload, g_max_img, g_rand, g_save,
123
  g_do_sample, g_temperature, g_max_new_tok,
 
140
  u_answer = gr.Textbox(label="Answer", lines=8)
141
 
142
  u_btn.click(
143
+ understand_func,
144
  inputs=[u_prompt, u_do_sample,
145
  u_temperature, u_max_new_tok, u_image],
146
  outputs=u_answer)
 
148
  demo.launch()
149
 
150
 
 
 
 
 
 
 
 
151
  if __name__ == '__main__':
152
+ build_gradio()