from typing import Any from transformers import AutoConfig, PretrainedConfig, LlamaConfig class VisionConfig(PretrainedConfig): model_type: str = "vision_model" def __init__( self, **kwargs: Any, ): super().__init__(**kwargs) class ConnectorConfig(PretrainedConfig): model_type: str = "connector" def __init__( self, **kwargs: Any, ): super().__init__(**kwargs) class VLMConfig(LlamaConfig): model_type: str = "vlm" _sub_config_classes: dict[str, type[PretrainedConfig]] = { "vision_config": VisionConfig, "connector_config": ConnectorConfig, } def __init__( self, vision_config_args: dict[str, Any] = None, connector_config_args: dict[str, Any] = None, lazy_load: bool = False, **kwargs: Any, ): final_vision_args = kwargs.pop("vision_config", vision_config_args) final_connector_args = kwargs.pop("connector_config", connector_config_args) self.vision_config: VisionConfig = VisionConfig(**(final_vision_args or {})) self.connector_config: ConnectorConfig = ConnectorConfig(**(final_connector_args or {})) self.lazy_load: bool = lazy_load # Initialize the base language model configuration part super().__init__(**kwargs) AutoConfig.register("vlm", VLMConfig)