BinLiunls commited on
Commit
d86e712
·
1 Parent(s): 3601075

fix output issue

Browse files

Signed-off-by: binliu <[email protected]>

Files changed (2) hide show
  1. hugging_face_pipeline.py +1 -1
  2. vista3d_pipeline.py +2 -0
hugging_face_pipeline.py CHANGED
@@ -32,7 +32,7 @@ class HuggingFacePipelineHelper:
32
  config_dict = kwargs.pop("config_dict", None)
33
  self._update_config(config, config_dict)
34
  model = VISTA3DModel(config)
35
- model.from_pretrained(
36
  pretrained_model_name_or_path=pretrained_model_name_or_path
37
  )
38
  return VISTA3DPipeline(model, **kwargs)
 
32
  config_dict = kwargs.pop("config_dict", None)
33
  self._update_config(config, config_dict)
34
  model = VISTA3DModel(config)
35
+ model = model.from_pretrained(
36
  pretrained_model_name_or_path=pretrained_model_name_or_path
37
  )
38
  return VISTA3DPipeline(model, **kwargs)
vista3d_pipeline.py CHANGED
@@ -433,6 +433,8 @@ class VISTA3DPipeline(Pipeline):
433
  return outputs
434
 
435
  def postprocess(self, outputs, **kwargs):
 
 
436
  for key, value in kwargs.items():
437
  if key not in self.POSTPROCESSING_EXTRA_ARGS:
438
  logging.warning(f"Cannot set parameter {key} for postprocessing.")
 
433
  return outputs
434
 
435
  def postprocess(self, outputs, **kwargs):
436
+ outputs[Keys.IMAGE] = outputs[Keys.IMAGE].to(self.device)
437
+ outputs[Keys.PRED] = outputs[Keys.PRED].to(self.device)
438
  for key, value in kwargs.items():
439
  if key not in self.POSTPROCESSING_EXTRA_ARGS:
440
  logging.warning(f"Cannot set parameter {key} for postprocessing.")