import os

from PIL import Image
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPImageProcessor,
    PreTrainedModel,
    PretrainedConfig,
)

# Prompt-tuned backbones (local modules)
from models.modeling_clipPT import CLIPVisionTransformer   # CLIP vision tower with learnable prompts
from models.modeling_qwen2 import Qwen2Model                # Qwen2 text encoder with learnable prompts
from models.modeling_timer import TimerForPrediction        # Timer forecaster with learnable prompts


class MultiModalTimerConfig(PretrainedConfig):
    """Configuration for MultiModalTimerModel.

    forecasting_length: number of future steps supervised by the loss.
    vision_model_name:  currently only 'CLIP' is supported (None disables the vision branch).
    text_model_name:    currently only 'Qwen' is supported (None disables the text branch).
    *_prompt_len:       number of learnable prompt tokens inserted into each backbone.
    """

    def __init__(
        self,
        forecasting_length=None,
        vision_model_name=None,
        text_model_name=None,
        vision_model_prompt_len=None,
        text_model_prompt_len=None,
        timer_prompt_len=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.forecasting_length = forecasting_length
        self.vision_model_name = vision_model_name
        self.text_model_name = text_model_name
        self.vision_model_prompt_len = vision_model_prompt_len if vision_model_prompt_len is not None else 10
        self.text_model_prompt_len = text_model_prompt_len if text_model_prompt_len is not None else 4
        self.timer_prompt_len = timer_prompt_len if timer_prompt_len is not None else 4


class MultiModalTimerModel(PreTrainedModel):
    """Prompt-tuned multimodal wrapper around the Timer forecasting backbone.

    Optional image and text inputs are encoded by prompt-tuned CLIP and Qwen2
    backbones, projected into the Timer hidden space by linear interaction
    layers, and passed to a prompt-tuned Timer that produces the forecast.
    Only the prompt parameters and the interaction layers remain trainable.
    """

    config_class = MultiModalTimerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Vision model: load pretrained CLIP weights into the prompt-tuned vision tower.
        if config.vision_model_name is None:
            pass
        elif config.vision_model_name == 'CLIP':
            vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32").vision_model
            state_dict = {k: v.to(torch.bfloat16) for k, v in vision_model.state_dict().items()}
            self.vision_model = CLIPVisionTransformer(vision_model.config, config.vision_model_prompt_len)
            self.vision_model.load_state_dict(state_dict, strict=False)
            for name, param in self.vision_model.named_parameters():
                # Freeze all layers other than the prompt parameters.
                param.requires_grad = "encoder.prompts" in name
        else:
            pass  # unknown vision model names are silently ignored

        # Text model: load pretrained Qwen2 weights into the prompt-tuned text encoder.
        if config.text_model_name is None:
            pass
        elif config.text_model_name == 'Qwen':
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
            text_model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen2-1.5B-Instruct",
                torch_dtype=torch.bfloat16,
                device_map="cpu",
                attn_implementation="sdpa",
            ).model
            self.text_model = Qwen2Model(text_model.config, config.text_model_prompt_len)
            self.text_model.load_state_dict(text_model.state_dict(), strict=False)
            for name, param in self.text_model.named_parameters():
                # Freeze all layers other than the prompt parameters.
                param.requires_grad = "prompts" in name
        else:
            pass  # unknown text model names are silently ignored

        # Timer: load the pretrained forecasting backbone into the prompt-tuned variant.
        timer = AutoModelForCausalLM.from_pretrained('thuml/timer-base-84m', trust_remote_code=True)
        self.timer = TimerForPrediction(timer.config, config.timer_prompt_len)
        self.timer.load_state_dict(timer.state_dict(), strict=False)
        for name, param in self.timer.named_parameters():
            # Freeze all layers other than the prompt parameters.
            param.requires_grad = "model.prompts" in name

        # Interaction layers project vision/text embeddings into the Timer hidden space.
        if config.vision_model_name is not None:
            self.vision_interaction_layer = nn.Linear(self.vision_model.config.hidden_size, self.timer.config.hidden_size)
        if config.text_model_name is not None:
            self.text_interaction_layer = nn.Linear(self.text_model.config.hidden_size, self.timer.config.hidden_size)

    def forward(self, input_ids=None, images=None, texts=None, labels=None):
        # Vision branch: skipped only when no vision model is configured and no images are given.
        if self.config.vision_model_name is None and images is None:
            vision_embedding = None
        else:
            vision_embedding = self.vision_model(images).pooler_output
            vision_embedding = self.vision_interaction_layer(vision_embedding)

        # Text branch: skipped only when no text model is configured and all texts are None.
        if self.config.text_model_name is None and all(x is None for x in texts):
            text_embedding = None
        else:
            tokenized_texts = self.tokenizer(texts, return_tensors="pt").to(input_ids.device)
            text_embedding = self.text_model(**tokenized_texts).last_hidden_state[:, 0, :]
            text_embedding = self.text_interaction_layer(text_embedding)

        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
        out = out["logits"]

        if labels is not None:
            if self.config.forecasting_length == out.shape[-1]:
                loss = torch.mean(torch.square(out - labels))  # MSE
            else:
                # The pretrained Timer predicts 96 steps; truncate for shorter forecasting lengths.
                # A forecasting_length greater than 96 will raise an error here.
                loss = torch.mean(torch.square(out[:, :self.config.forecasting_length] - labels))
        else:
            loss = None

        return {
            "loss": loss,
            "logits": out,
        }
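
# The freezing above is prompt tuning: only the inserted prompt parameters (plus the two
# interaction layers) should receive gradients. The helper below is an illustrative sketch,
# not part of the original training code; its name and output format are assumptions.
def summarize_trainable_parameters(model: nn.Module) -> None:
    """Print how many parameters remain trainable after the prompt-only freezing."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable parameters: {trainable:,} / {total:,} ({100.0 * trainable / total:.4f}%)")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"  {name}: {tuple(param.shape)}")
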
class MultiModalTimerDataset(Dataset):  # TODO: refactor
    """Pairs each time-series window with an optional image and a dataset-level text description.

    Expects `inputs.pt` and `targets_{forecasting_length}.pt` (dicts keyed by image file name)
    under `dataset_path`, with the corresponding images in `<dataset_path>/img`.
    """

    def __init__(self, dataset_path, vision_model_name=None, dataset_text=None, forecasting_length: int = 96):
        self.dataset_path = dataset_path
        self.vision_model_name = vision_model_name
        self.dataset_text = dataset_text

        if vision_model_name is None:
            pass
        elif vision_model_name == 'CLIP':
            self.processor = CLIPImageProcessor()
        else:
            pass  # unknown vision model names are silently ignored

        self.inputs = torch.load(os.path.join(dataset_path, "inputs.pt"))
        if forecasting_length:
            self.targets = torch.load(os.path.join(dataset_path, f"targets_{forecasting_length}.pt"))
        else:
            self.targets = torch.load(os.path.join(dataset_path, "targets.pt"))
        self.keys = list(self.targets.keys())

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        img_name = self.keys[idx]

        if self.vision_model_name is None:
            images = None
        else:
            img_path = os.path.join(self.dataset_path, 'img', img_name)
            images = Image.open(img_path).convert("RGB")
            images = self.processor.preprocess(images)['pixel_values'][0]

        input_tensor = self.inputs[img_name].float().squeeze()
        target_tensor = self.targets[img_name].float().squeeze()

        return {
            "input_ids": input_tensor,
            "images": images,
            "texts": self.dataset_text,
            "labels": target_tensor,
        }
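

if __name__ == "__main__":
    # Minimal usage sketch (assumed, not part of the original pipeline): the dataset path,
    # the text description, and the DataLoader settings below are placeholders; the directory
    # is expected to contain inputs.pt, targets_96.pt, and an img/ folder as described above.
    from torch.utils.data import DataLoader

    config = MultiModalTimerConfig(
        forecasting_length=96,
        vision_model_name="CLIP",
        text_model_name="Qwen",
    )
    model = MultiModalTimerModel(config)
    summarize_trainable_parameters(model)

    dataset = MultiModalTimerDataset(
        dataset_path="data/example_dataset",                          # placeholder path
        vision_model_name="CLIP",
        dataset_text="Hourly readings of a single sensor channel.",   # placeholder description
        forecasting_length=96,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    batch = next(iter(loader))
    with torch.no_grad():
        out = model(
            input_ids=batch["input_ids"],
            images=batch["images"],
            texts=batch["texts"],
            labels=batch["labels"],
        )
    print("loss:", out["loss"], "forecast shape:", tuple(out["logits"].shape))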