Bils commited on
Commit
6842006
·
verified ·
1 Parent(s): 172e437

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -14
app.py CHANGED
@@ -64,9 +64,10 @@ def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_s
64
  generated_tokens = torch.zeros((parallel_size, 576), dtype=torch.int, device=device)
65
 
66
  pkv = None
67
- for i in range(576):
68
- if progress:
69
- progress((i + 1) / 576, desc="Generating image tokens")
 
70
 
71
  outputs = vl_gpt.language_model.model(
72
  inputs_embeds=inputs_embeds,
@@ -114,7 +115,8 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=
114
  if not prompt.strip():
115
  raise gr.Error("Please enter a valid prompt.")
116
 
117
- progress(0, desc="Initializing...")
 
118
  torch.cuda.empty_cache()
119
 
120
  # Seed management
@@ -135,7 +137,9 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=
135
  ) + vl_chat_processor.image_start_tag
136
 
137
  input_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long, device=device)
138
- progress(0.1, desc="Generating image tokens...")
 
 
139
 
140
  generated_tokens = generate(
141
  input_ids,
@@ -147,7 +151,8 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=
147
  progress=progress
148
  )
149
 
150
- progress(0.9, desc="Processing images...")
 
151
  patches = vl_gpt.gen_vision_model.decode_code(
152
  generated_tokens.to(dtype=torch.int),
153
  shape=[PARALLEL_SIZE, 8, DEFAULT_WIDTH // PATCH_SIZE, DEFAULT_HEIGHT // PATCH_SIZE]
@@ -157,8 +162,11 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=
157
  return images
158
 
159
  except Exception as e:
160
- logger.error(f"Generation failed: {str(e)}")
161
- raise gr.Error(f"Image generation failed: {str(e)}")
 
 
 
162
 
163
  def create_interface():
164
  with gr.Blocks(title="Janus-Pro-7B Image Generator", theme=gr.themes.Soft()) as demo:
@@ -169,7 +177,11 @@ def create_interface():
169
 
170
  with gr.Row():
171
  with gr.Column(scale=3):
172
- prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate...", lines=3)
 
 
 
 
173
  generate_btn = gr.Button("Generate Images", variant="primary")
174
 
175
  with gr.Accordion("Advanced Settings", open=False):
@@ -178,22 +190,36 @@ def create_interface():
178
  label="Seed",
179
  value=None,
180
  precision=0,
181
- info="Leave empty for random seed" # Fixed parameter
182
  )
183
  guidance_slider = gr.Slider(
184
- 3, 10, value=5, step=0.5,
185
  label="CFG Guidance Weight",
 
 
 
 
186
  info="Higher values = more prompt adherence, lower values = more creativity"
187
  )
188
  temp_slider = gr.Slider(
189
- 0.1, 1.0, value=1.0, step=0.1,
190
  label="Temperature",
 
 
 
 
191
  info="Higher values = more randomness, lower values = more deterministic"
192
  )
193
 
194
  with gr.Column(scale=2):
195
- output_gallery = gr.Gallery(label="Generated Images", columns=2, height=600, preview=True)
196
- status = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
 
 
 
197
 
198
  gr.Examples(
199
  examples=[
 
64
  generated_tokens = torch.zeros((parallel_size, 576), dtype=torch.int, device=device)
65
 
66
  pkv = None
67
+ total_steps = 576
68
+ for i in range(total_steps):
69
+ if progress is not None:
70
+ progress((i + 1) / total_steps, desc="Generating image tokens")
71
 
72
  outputs = vl_gpt.language_model.model(
73
  inputs_embeds=inputs_embeds,
 
115
  if not prompt.strip():
116
  raise gr.Error("Please enter a valid prompt.")
117
 
118
+ if progress is not None:
119
+ progress(0, desc="Initializing...")
120
  torch.cuda.empty_cache()
121
 
122
  # Seed management
 
137
  ) + vl_chat_processor.image_start_tag
138
 
139
  input_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long, device=device)
140
+
141
+ if progress is not None:
142
+ progress(0.1, desc="Generating image tokens...")
143
 
144
  generated_tokens = generate(
145
  input_ids,
 
151
  progress=progress
152
  )
153
 
154
+ if progress is not None:
155
+ progress(0.9, desc="Processing images...")
156
  patches = vl_gpt.gen_vision_model.decode_code(
157
  generated_tokens.to(dtype=torch.int),
158
  shape=[PARALLEL_SIZE, 8, DEFAULT_WIDTH // PATCH_SIZE, DEFAULT_HEIGHT // PATCH_SIZE]
 
162
  return images
163
 
164
  except Exception as e:
165
+ logger.error(f"Generation failed: {str(e)}", exc_info=True)
166
+ if "index out of range" in str(e).lower():
167
+ raise gr.Error("Image generation failed due to internal error. Please try again with different parameters.")
168
+ else:
169
+ raise gr.Error(f"Image generation failed: {str(e)}")
170
 
171
  def create_interface():
172
  with gr.Blocks(title="Janus-Pro-7B Image Generator", theme=gr.themes.Soft()) as demo:
 
177
 
178
  with gr.Row():
179
  with gr.Column(scale=3):
180
+ prompt_input = gr.Textbox(
181
+ label="Prompt",
182
+ placeholder="Describe the image you want to generate...",
183
+ lines=3
184
+ )
185
  generate_btn = gr.Button("Generate Images", variant="primary")
186
 
187
  with gr.Accordion("Advanced Settings", open=False):
 
190
  label="Seed",
191
  value=None,
192
  precision=0,
193
+ info="Leave empty for random seed"
194
  )
195
  guidance_slider = gr.Slider(
 
196
  label="CFG Guidance Weight",
197
+ minimum=3,
198
+ maximum=10,
199
+ value=5,
200
+ step=0.5,
201
  info="Higher values = more prompt adherence, lower values = more creativity"
202
  )
203
  temp_slider = gr.Slider(
 
204
  label="Temperature",
205
+ minimum=0.1,
206
+ maximum=1.0,
207
+ value=1.0,
208
+ step=0.1,
209
  info="Higher values = more randomness, lower values = more deterministic"
210
  )
211
 
212
  with gr.Column(scale=2):
213
+ output_gallery = gr.Gallery(
214
+ label="Generated Images",
215
+ columns=2,
216
+ height=600,
217
+ preview=True
218
+ )
219
+ status = gr.Textbox(
220
+ label="Status",
221
+ interactive=False
222
+ )
223
 
224
  gr.Examples(
225
  examples=[