File size: 601 Bytes
08efd84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from transformers import PretrainedConfig
class VISTA3DConfig(PretrainedConfig):
    """Hugging Face configuration object holding the VISTA3D hyperparameters.

    Two model-specific values are stored on the instance; every other keyword
    argument is handed on unchanged to ``PretrainedConfig``.
    """

    model_type = "VISTA3D"

    def __init__(self, encoder_embed_dim: int = 48, input_channels: int = 1, **kwargs):
        """Record the VISTA3D hyperparameters on this config instance.

        Args:
            encoder_embed_dim: Embedding dimension of the VISTA3D encoder.
            input_channels: Number of channels in the input images.
            **kwargs: Any remaining options, forwarded as-is to
                ``PretrainedConfig.__init__``.
        """
        # Model-specific fields are set first, then the base class handles
        # the generic configuration options carried in ``kwargs``.
        self.encoder_embed_dim = encoder_embed_dim
        self.input_channels = input_channels
        super().__init__(**kwargs)
|