saeecl commited on
Commit
dc2c5e4
·
1 Parent(s): b7b00e2

add api routes

Browse files
Files changed (1) hide show
  1. app.py +104 -7
app.py CHANGED
@@ -138,6 +138,7 @@ def image_to_3d(
138
  dict: The information of the generated 3D model.
139
  str: The path to the video of the 3D model.
140
  """
 
141
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
142
  if not is_multiimage:
143
  outputs = pipeline.run(
@@ -180,6 +181,43 @@ def image_to_3d(
180
  return state, video_path
181
 
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  @spaces.GPU(duration=90)
184
  def extract_glb(
185
  state: dict,
@@ -198,10 +236,11 @@ def extract_glb(
198
  Returns:
199
  str: The path to the extracted GLB file.
200
  """
201
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
202
  gs, mesh = unpack_state(state)
203
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
204
- glb_path = os.path.join(user_dir, 'sample.glb')
 
205
  glb.export(glb_path)
206
  torch.cuda.empty_cache()
207
  return glb_path, glb_path
@@ -254,6 +293,57 @@ def split_image(image: Image.Image) -> List[Image.Image]:
254
  images.append(Image.fromarray(image[:, s:e+1]))
255
  return [preprocess_image(image) for image in images]
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  with gr.Blocks(delete_cache=(600, 600)) as demo:
259
  gr.Markdown("""
@@ -405,10 +495,17 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
405
 
406
  # Launch the Gradio app
407
  if __name__ == "__main__":
 
 
 
408
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
409
  pipeline.cuda()
410
- try:
411
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
412
- except:
413
- pass
414
- demo.launch()
 
 
 
 
 
138
  dict: The information of the generated 3D model.
139
  str: The path to the video of the 3D model.
140
  """
141
+
142
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
143
  if not is_multiimage:
144
  outputs = pipeline.run(
 
181
  return state, video_path
182
 
183
 
184
+ @spaces.GPU
185
+ def image_to_3d2(
186
+ image: Image.Image,
187
+ seed: int,
188
+ ss_guidance_strength: float,
189
+ ss_sampling_steps: int,
190
+ slat_guidance_strength: float,
191
+ slat_sampling_steps: int,
192
+ ) -> Tuple[dict, str]:
193
+
194
+ outputs = pipeline.run(
195
+ image,
196
+ seed=seed,
197
+ formats=["gaussian", "mesh"],
198
+ preprocess_image=False,
199
+ sparse_structure_sampler_params={
200
+ "steps": ss_sampling_steps,
201
+ "cfg_strength": ss_guidance_strength,
202
+ },
203
+ slat_sampler_params={
204
+ "steps": slat_sampling_steps,
205
+ "cfg_strength": slat_guidance_strength,
206
+ },
207
+ )
208
+
209
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
210
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
211
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
212
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
213
+ video_path = temp_video.name
214
+ imageio.mimsave(video_path, video, fps=15)
215
+
216
+ torch.cuda.empty_cache()
217
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
218
+ return state, video_path
219
+
220
+ import random
221
  @spaces.GPU(duration=90)
222
  def extract_glb(
223
  state: dict,
 
236
  Returns:
237
  str: The path to the extracted GLB file.
238
  """
239
+ user_dir = TMP_DIR
240
  gs, mesh = unpack_state(state)
241
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
242
+ glb_path = os.path.join(user_dir, f"test_{random.random()}.glb")
243
+
244
  glb.export(glb_path)
245
  torch.cuda.empty_cache()
246
  return glb_path, glb_path
 
293
  images.append(Image.fromarray(image[:, s:e+1]))
294
  return [preprocess_image(image) for image in images]
295
 
296
+ from fastapi import FastAPI, File, UploadFile, HTTPException
297
+ from fastapi.responses import JSONResponse, FileResponse
298
+ import tempfile
299
+
300
+ app = FastAPI()
301
+
302
+ @app.post("/generatee")
303
+ async def generate_3d(image: UploadFile = File(...)):
304
+ if not image:
305
+ raise HTTPException(status_code=400, detail="No image provided")
306
+
307
+ try:
308
+ image_data = Image.open(image.file)
309
+ image_data = image_data.convert("RGBA")
310
+ # image_data = torch.tensor(np.array(image_data)).to('cuda')
311
+
312
+ seed = 42
313
+ ss_guidance_strength = 7.5
314
+ ss_sampling_steps = 12
315
+ slat_guidance_strength = 3.0
316
+ slat_sampling_steps = 12
317
+
318
+ print("STARTING TO GENERATE VIDEO ...")
319
+ state, _ = image_to_3d2(
320
+ image_data,
321
+ seed=seed,
322
+ ss_guidance_strength=ss_guidance_strength,
323
+ ss_sampling_steps=ss_sampling_steps,
324
+ slat_guidance_strength=slat_guidance_strength,
325
+ slat_sampling_steps=slat_sampling_steps,
326
+ )
327
+ mesh_simplify = 0.95
328
+ texture_size = 1024
329
+ print("STARTING TO GENERATE GLB FILE ...")
330
+ glb_path, _ = extract_glb(
331
+ state=state,
332
+ mesh_simplify=mesh_simplify,
333
+ texture_size=texture_size,
334
+ req=None # Assuming req is not needed here
335
+ )
336
+ print("DONE")
337
+ return FileResponse(glb_path, media_type='application/octet-stream', filename='model.glb')
338
+
339
+ except Exception as e:
340
+ print("ERROR IN GENERATING 3D FILE : " , e)
341
+ raise HTTPException(status_code=500, detail=str(e))
342
+
343
+
344
+ @app.get("/test")
345
+ async def test_():
346
+ return JSONResponse(content={"TEST": "DONE!"})
347
 
348
  with gr.Blocks(delete_cache=(600, 600)) as demo:
349
  gr.Markdown("""
 
495
 
496
  # Launch the Gradio app
497
  if __name__ == "__main__":
498
+ import uvicorn
499
+ from gradio import mount_gradio_app
500
+
501
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
502
  pipeline.cuda()
503
+
504
+ app = mount_gradio_app(app, demo, path="/")
505
+ uvicorn.run(app, host="0.0.0.0", port=7860)
506
+
507
+ # try:
508
+ # pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
509
+ # except:
510
+ # pass
511
+ # demo.launch()