File size: 1,308 Bytes
08efd84 d86e712 08efd84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from transformers import pipeline
from vista3d_config import VISTA3DConfig
from vista3d_model import VISTA3DModel, register_my_model
from vista3d_pipeline import VISTA3DPipeline, register_simple_pipeline
class HuggingFacePipelineHelper:
    """Helper for registering and building the VISTA3D Hugging Face pipeline.

    Two entry points are offered:

    * ``get_pipeline`` calls the project's model/pipeline registration
      functions and then builds the pipeline through
      ``transformers.pipeline`` using ``pipeline_name`` as the task name.
    * ``init_pipeline`` loads a pretrained ``VISTA3DModel`` (optionally with
      configuration overrides) and wraps it in ``VISTA3DPipeline`` directly,
      without going through ``transformers.pipeline``.
    """

    def __init__(self, pipeline_name: str = "vista3d"):
        # Task name handed to ``transformers.pipeline`` by ``get_pipeline``.
        # NOTE(review): presumably must match the name used inside
        # register_simple_pipeline() — confirm.
        self.pipeline_name = pipeline_name

    def __model_register(self):
        # Thin wrapper around the project's register_my_model().
        register_my_model()

    def __pipeline_register(self):
        # Thin wrapper around the project's register_simple_pipeline().
        register_simple_pipeline()

    def get_pipeline(self, **kwargs):
        """Register the custom model/pipeline, then build it by task name.

        Args:
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``transformers.pipeline`` (e.g. ``device=``). Calling with no
                arguments reproduces the original behaviour.

        Returns:
            The object returned by
            ``transformers.pipeline(self.pipeline_name, **kwargs)``.
        """
        # Registration must happen before ``pipeline`` looks up the task name.
        self.__model_register()
        self.__pipeline_register()
        return pipeline(self.pipeline_name, **kwargs)

    def _update_config(self, config, config_dict):
        """Copy overrides from ``config_dict`` onto ``config`` in place.

        Only keys that already exist as attributes on ``config`` are applied;
        any other keys are ignored. Attributes whose current value already
        equals the override are left untouched.

        Args:
            config: Configuration object to mutate.
            config_dict: Mapping of attribute name -> new value, or a falsy
                value (``None`` / empty mapping) meaning "no overrides".

        Returns:
            The same ``config`` object, for convenience.
        """
        if config_dict:
            for key, value in config_dict.items():
                if hasattr(config, key) and getattr(config, key) != value:
                    setattr(config, key, value)
        return config

    def init_pipeline(self, pretrained_model_name_or_path: str, **kwargs):
        """Load a pretrained VISTA3D model and wrap it in ``VISTA3DPipeline``.

        Args:
            pretrained_model_name_or_path: Hub id or local directory passed to
                ``VISTA3DModel.from_pretrained``.
            **kwargs: May contain ``config_dict``, a mapping of configuration
                attribute overrides applied before the weights are loaded.
                All remaining keyword arguments are forwarded to
                ``VISTA3DPipeline``.

        Returns:
            A ``VISTA3DPipeline`` built around the loaded model.
        """
        config_dict = kwargs.pop("config_dict", None)
        # ``from_pretrained`` is a classmethod. Calling it on a throwaway
        # ``VISTA3DModel(config)`` instance (as before) allocated a full model
        # for nothing and silently discarded the overridden config, so
        # ``config_dict`` had no effect. Call it on the class instead and hand
        # the overridden config in explicitly.
        if config_dict:
            # Layer the overrides on top of the checkpoint's own configuration
            # rather than on library defaults, so unspecified fields still
            # match the saved weights.
            # NOTE(review): assumes VISTA3DConfig is a PretrainedConfig
            # subclass providing ``from_pretrained`` — confirm.
            config = VISTA3DConfig.from_pretrained(pretrained_model_name_or_path)
            self._update_config(config, config_dict)
            model = VISTA3DModel.from_pretrained(
                pretrained_model_name_or_path=pretrained_model_name_or_path,
                config=config,
            )
        else:
            # No overrides: same call the original code effectively made.
            model = VISTA3DModel.from_pretrained(
                pretrained_model_name_or_path=pretrained_model_name_or_path
            )
        return VISTA3DPipeline(model, **kwargs)
|