import os
import cv2
import torch
import random
import logging
import tempfile
import numpy as np
from copy import copy
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset
from utils.registry_class import DATASETS

@DATASETS.register_class()
class ImageDataset(Dataset):
    """Caption/image dataset backed by plain-text list files.

    Each line of every list file is expected to look like
    ``<relative image path>|||<caption>``; the relative path is joined onto the
    data directory paired with that list file. A sample is returned as
    ``(ref_frame, vit_frame, video_data, caption, img_key)`` where
    ``video_data`` holds ``max_frames`` frame slots of which at most the first
    is filled (one image per sample); the remaining slots stay zero.

    Loading is best-effort: unreadable images or failing transforms yield
    zero tensors (and are logged) instead of raising from ``__getitem__``.
    """

    def __init__(self,
                 data_list,
                 data_dir_list,
                 max_words=1000,
                 vit_resolution=(224, 224),
                 resolution=(384, 256),
                 max_frames=1,
                 transforms=None,
                 vit_transforms=None,
                 **kwargs):
        """Read every list file and build the flat sample index.

        Args:
            data_list: Paths of the list files to read.
            data_dir_list: Image root directories, paired element-wise with
                ``data_list`` (``zip`` stops at the shorter of the two).
            max_words: Accepted for interface compatibility; unused here.
            vit_resolution: Treated as ``(width, height)`` for the zero-filled
                fallback ``vit_frame`` (allocated as ``(3, height, width)``).
                A tuple default replaces the former mutable list default.
            resolution: Treated as ``(width, height)`` for each frame of
                ``video_data`` (allocated as ``(3, height, width)``).
            max_frames: Number of frame slots in ``video_data``.
            transforms: Callable applied to a list of RGB PIL images; its
                result is written into ``video_data[:n]`` (presumably an
                ``(n, 3, height, width)`` tensor -- confirm against callers).
            vit_transforms: Callable applied to the single RGB PIL image to
                produce ``vit_frame``.
            **kwargs: Ignored.
        """
        self.max_frames = max_frames
        self.resolution = resolution
        self.transforms = transforms
        self.vit_resolution = vit_resolution
        self.vit_transforms = vit_transforms

        image_list = []
        for item_path, data_dir in zip(data_list, data_dir_list):
            # ``with`` closes the list file (a bare ``open`` leaks the handle);
            # decode as UTF-8 so captions do not depend on the host locale.
            with open(item_path, 'r', encoding='utf-8') as f:
                image_list.extend([data_dir, line.strip()] for line in f)
        self.image_list = image_list

    def __len__(self):
        """Return the number of lines read from all list files."""
        return len(self.image_list)

    def __getitem__(self, index):
        """Return ``(ref_frame, vit_frame, video_data, caption, img_key)``.

        If loading raises (e.g. a line with no ``|||`` separator), the error is
        logged and zero tensors with an empty caption and key are returned.
        """
        data_dir, file_path = self.image_list[index]
        img_key = file_path.split('|||', 1)[0]
        try:
            ref_frame, vit_frame, video_data, caption = self._get_image_data(data_dir, file_path)
        except Exception as e:
            logging.info('{} get frames failed... with error: {}'.format(img_key, e))
            caption = ''
            img_key = ''
            ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
            vit_frame = self._blank_vit_frame()
            video_data = self._blank_video()
        return ref_frame, vit_frame, video_data, caption, img_key

    def _blank_video(self):
        """Return zeros of shape ``(max_frames, 3, resolution[1], resolution[0])``."""
        return torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])

    def _blank_vit_frame(self):
        """Return zeros of shape ``(3, vit_resolution[1], vit_resolution[0])``."""
        return torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])

    def _get_image_data(self, data_dir, file_path):
        """Load one image and build its tensors.

        Args:
            data_dir: Root directory the image path is relative to.
            file_path: List-file line ``<relative path>|||<caption>``. Only the
                first ``|||`` separates path from caption, so the caption may
                itself contain ``|||``.

        Returns:
            ``(ref_frame, vit_frame, video_data, caption)``. When the image
            cannot be read or the transforms fail, the affected tensors are
            zeros but the caption is still returned.

        Raises:
            ValueError: if ``file_path`` contains no ``|||`` separator.
        """
        img_key, caption = file_path.split('|||', 1)
        image_path = os.path.join(data_dir, img_key)

        frame_list = []
        for _ in range(5):  # up to 5 attempts to open and decode the image
            try:
                # Image.open is lazy; convert() forces the full decode inside
                # this try so decode errors are retried/logged here, and the
                # ``with`` releases the underlying file handle.
                with Image.open(image_path) as image:
                    frame_list.append(image.convert('RGB'))
                break
            except Exception as e:
                logging.info('{} read video frame failed with error: {}'.format(img_key, e))

        video_data = self._blank_video()
        vit_frame = self._blank_vit_frame()
        if frame_list:
            try:
                vit_frame = self.vit_transforms(frame_list[0])
                video_data[:len(frame_list), ...] = self.transforms(frame_list)
            except Exception as e:
                # Best-effort: fall back to a blank vit_frame, but log the
                # failure instead of swallowing it silently.
                logging.info('{} transform failed with error: {}'.format(img_key, e))
                vit_frame = self._blank_vit_frame()
        # clone() gives ref_frame its own storage; copy.copy of a tensor
        # rebuilds it over the same storage as video_data.
        ref_frame = video_data[0].clone()
        return ref_frame, vit_frame, video_data, caption