fix output issue
Browse filesSigned-off-by: binliu <[email protected]>
- hugging_face_pipeline.py +1 -1
- vista3d_pipeline.py +2 -0
hugging_face_pipeline.py
CHANGED
@@ -32,7 +32,7 @@ class HuggingFacePipelineHelper:
|
|
32 |
config_dict = kwargs.pop("config_dict", None)
|
33 |
self._update_config(config, config_dict)
|
34 |
model = VISTA3DModel(config)
|
35 |
-
model.from_pretrained(
|
36 |
pretrained_model_name_or_path=pretrained_model_name_or_path
|
37 |
)
|
38 |
return VISTA3DPipeline(model, **kwargs)
|
|
|
32 |
config_dict = kwargs.pop("config_dict", None)
|
33 |
self._update_config(config, config_dict)
|
34 |
model = VISTA3DModel(config)
|
35 |
+
model = model.from_pretrained(
|
36 |
pretrained_model_name_or_path=pretrained_model_name_or_path
|
37 |
)
|
38 |
return VISTA3DPipeline(model, **kwargs)
|
vista3d_pipeline.py
CHANGED
@@ -433,6 +433,8 @@ class VISTA3DPipeline(Pipeline):
|
|
433 |
return outputs
|
434 |
|
435 |
def postprocess(self, outputs, **kwargs):
|
|
|
|
|
436 |
for key, value in kwargs.items():
|
437 |
if key not in self.POSTPROCESSING_EXTRA_ARGS:
|
438 |
logging.warning(f"Cannot set parameter {key} for postprocessing.")
|
|
|
433 |
return outputs
|
434 |
|
435 |
def postprocess(self, outputs, **kwargs):
|
436 |
+
outputs[Keys.IMAGE] = outputs[Keys.IMAGE].to(self.device)
|
437 |
+
outputs[Keys.PRED] = outputs[Keys.PRED].to(self.device)
|
438 |
for key, value in kwargs.items():
|
439 |
if key not in self.POSTPROCESSING_EXTRA_ARGS:
|
440 |
logging.warning(f"Cannot set parameter {key} for postprocessing.")
|