import os
import sys
import re
import numpy as np
import cv2
import json
import yaml
import vis_utils as v_uts
import struct
from cv_base import (
    Faces, Aux, Obj, DEFAULT_MATERIAL
)

hasTorch = True
try:
    import torch
except:
    hasTorch = False

import functools
import pandas as pd
from tqdm import tqdm
from PIL import Image

try:
    from plyfile import PlyData
except:
    "no ply"

import pdb
b=pdb.set_trace

def default(x, val):
    """Return ``x``, or ``val`` when ``x`` is None."""
    if x is None:
        return val
    return x


class IOShop:
    """Facade that selects an IO handler by name and forwards calls to it."""

    def __init__(self, name, **kwargs):
        """Instantiate the handler registered under ``name`` with ``kwargs``.

        An unknown ``name`` raises KeyError.
        """
        registry = {
            'depth': DepthIO,
            'image': ImageIO,
            'flow': FlowIO,
            'segment': SegmentIO,
            'prob': ProbIO,
            'video': VideoIO,
        }
        handler_cls = registry[name]
        self.io = handler_cls(**kwargs)

    def load(self, file_name, **kwargs):
        """Forward to the selected handler's ``load`` and return its result."""
        loaded = self.io.load(file_name, **kwargs)
        return loaded

    def dump(self, file_name, file, **kwargs):
        """Forward to the selected handler's ``dump``; returns nothing."""
        self.io.dump(file_name, file, **kwargs)


class BaseIO:
    """Basic image-file IO; ``appex`` is the extension appended to file names."""

    def __init__(self, appex='jpg'):
        self.appex = appex
        self.type = 'image'

    def load(self, file_name):
        """Read ``<file_name>.<appex>`` with OpenCV in IMREAD_UNCHANGED mode.

        Raises AssertionError when OpenCV returns None for the path.
        """
        full_name = '%s.%s' % (file_name, self.appex)
        image = cv2.imread(full_name, cv2.IMREAD_UNCHANGED)
        assert image is not None, '%s not exists' % full_name
        return image

    def dump(self, file_name, file):
        """Write ``file`` to ``<file_name>.<appex>``, creating the folder if needed."""
        v_uts.mkdir_if_need(os.path.dirname(file_name))
        target = '%s.%s' % (file_name, self.appex)
        cv2.imwrite(target, file)


class ImageIO(BaseIO):
    """Image IO with a separate decoding path for HEIC files."""

    def __init__(self, appex='jpg'):
        super().__init__(appex=appex)
        self.type = 'image'

    def load(self, file_name):
        """Load an image.

        Names ending in ``heic``/``HEIC`` are read as-is and decoded through
        ``decodeImage``; any other name goes through ``BaseIO.load``, which
        appends the configured extension.
        """
        if file_name.endswith(('heic', 'HEIC')):
            return decodeImage(read2byte(file_name))
        return super().load(file_name)

    @staticmethod
    def imwrite(file_name, data, order='rgb'):
        """Write ``data`` with its last axis reversed.

        NOTE(review): ``order`` is accepted but never consulted; the channel
        axis is always reversed before writing.
        """
        reversed_channels = data[:, :, ::-1]
        cv2.imwrite(file_name, reversed_channels)


class SegmentIO(BaseIO):
    """Segmentation-map IO; files use the ``png`` extension."""

    def __init__(self):
        super().__init__(appex='png')
        self.type = 'segment'


class ProbIO(BaseIO):
    """IO for multi-channel probability maps packed into one 16-bit image."""

    def __init__(self):
        super(ProbIO, self).__init__()
        self.type = 'prob'
        self.max_class = 4

    def load(self, file_name, channels=None):
        """Read a packed probability image.

        NOTE(review): decoding was never implemented. The method reads the
        file and allocates an output buffer but returns None, exactly as the
        original stub did. ``file_name`` is used as given (no extension added).
        """
        image = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
        channels = default(channels, self.max_class)
        output = np.zeros(image.shape[:2])
        # TODO: unpack ``channels`` planes from ``image`` into ``output``.

    def dump(self, file_name, file):
        """Pack a (height, width, channel) array into a uint16 image and write it.

        ``file_name`` is used as given (no extension added).
        """
        # The shape must be unpacked before the buffer is allocated; the
        # previous code referenced height/width before defining them.
        h, w, c = file.shape
        output = np.zeros((h, w), dtype=np.uint16)
        for i in range(c):
            output = output + np.uint16(file[:, :, i] * 255) + i * 256

        cv2.imwrite(file_name, output.astype('uint16'))

        

class MeshIO(BaseIO):
    """Mesh IO built on the module-level ``load_obj`` / ``export_obj`` helpers."""

    def __init__(self):
        super().__init__(appex='obj')
        self.type = 'mesh'

    def dump_obj(self, filename, obj):
        """Export ``obj`` through ``export_obj`` using its default options."""
        export_obj(filename, obj)

    def load_obj(self, filename):
        """Return the object parsed by the module-level ``load_obj``."""
        mesh = load_obj(filename)
        return mesh


def normalize_normal(mat):
    """Map values from [0, 255] to [-1, 1] and scale each pixel to unit length.

    The norm is taken over axis 2; a small epsilon guards against division
    by zero. The first three channels are divided in place on the float32
    copy, which is returned.
    """
    mat = (mat / 255.0 * 2.0 - 1.0).astype('float32')
    length = np.linalg.norm(mat, axis=2)
    for channel in (0, 1, 2):
        mat[:, :, channel] /= (length + 1e-9)
    return mat


class NormalIO(BaseIO):
    """Reads surface-normal maps stored as images.

    NOTE(review): another class named ``NormalIO`` is defined later in this
    file and replaces this one at import time - confirm which is intended.
    """

    def __init__(self, xyz='rgb'):
        """
        rgb: means the normal saved in the order of x: r ...
        """
        # BaseIO.__init__ is not called, so ``appex`` and ``type`` are unset.
        self._xyz = xyz

    def read(self, filename):
        """Read ``filename``, reverse the channel axis for 'rgb', and normalize."""
        raw = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
        if self._xyz == 'rgb':
            raw = raw[:, :, ::-1]
        return normalize_normal(raw)


class DepthIO(BaseIO):
    """Depth-map IO: PFM read/write plus PNG and 8UC3 helpers."""

    def __init__(self, bit=8):
        """``bit`` (8 or 16) selects the bit depth of the preview PNG."""
        super(DepthIO, self).__init__(appex='pfm')
        assert bit in [8, 16]
        scale = {8: 1, 16: 2}
        # Number of bytes per pixel used by dump_visualize.
        self.bits = scale[bit]
        self.dump_vis = True

    def load(self, path):
        """Read pfm file.

        Args:
            path (str): path to file, without the extension

        Returns:
            tuple: (data, scale)

        Raises:
            Exception: if the header is not a valid PFM header.
        """
        path = '%s.%s' % (path, self.appex)
        with open(path, "rb") as file:
            header = file.readline().rstrip().decode("ascii")
            if header == "PF":
                color = True
            elif header == "Pf":
                color = False
            else:
                raise Exception("Not a PFM file: " + path)

            dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
            if not dim_match:
                raise Exception("Malformed PFM header.")
            width, height = map(int, dim_match.groups())

            # A negative scale marks little-endian data.
            scale = float(file.readline().decode("ascii").rstrip())
            if scale < 0:
                endian = "<"
                scale = -scale
            else:
                endian = ">"

            data = np.fromfile(file, endian + "f")
            shape = (height, width, 3) if color else (height, width)

            # Rows are flipped on write, so flip them back on read.
            data = np.flipud(np.reshape(data, shape))

            return data, scale

    def dump(self, path, image, scale=1):
        """Write pfm file.

        Args:
            path (str): path to file, without the extension
            image (array): float32 data, H x W, H x W x 1 or H x W x 3
            scale (int, optional): Scale. Defaults to 1.

        Raises:
            Exception: if the dtype is not float32 or the shape is unsupported.
        """
        v_uts.mkdir_if_need(os.path.dirname(path))
        path = path + '.pfm'

        # Validate before opening so a rejected input leaves no empty file.
        if image.dtype.name != "float32":
            raise Exception("Image dtype must be float32.")

        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
            color = True
        elif len(image.shape) == 2 or (
            len(image.shape) == 3 and image.shape[2] == 1
        ):  # greyscale
            color = False
        else:
            raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")

        # Rows are written flipped; keep ``image`` itself untouched so the
        # preview below is rendered upright.
        flipped = np.flipud(image)

        endian = image.dtype.byteorder
        if endian == "<" or (endian == "=" and sys.byteorder == "little"):
            scale = -scale

        with open(path, "wb") as file:
            # The header must be bytes; the previous conditional encoded only
            # the greyscale tag and failed with TypeError for color images.
            file.write(b"PF\n" if color else b"Pf\n")
            file.write(("%d %d\n" % (image.shape[1], image.shape[0])).encode())
            file.write(("%f\n" % scale).encode())
            flipped.tofile(file)

        if self.dump_vis:
            self.dump_visualize(path[:-4], image, self.bits)

    @staticmethod
    def to8UC3(depth, scale=1000):
        """
        Convert depth image to 8UC3 format.

        ``depth * scale`` is clipped to 24 bits and split into three bytes,
        most significant byte in channel 0.
        """
        h, w = depth.shape
        max_depth = (256.0 ** 3 - 1) / scale

        # Clip depth values exceeding the maximum depth
        depth = np.clip(depth, 0, max_depth)

        # Scale the depth values
        value = depth * scale

        # Split the depth values into three channels
        ch = np.zeros((h, w, 3), dtype=np.uint8)
        ch[:, :, 0] = np.uint8(value / (256 ** 2))
        ch[:, :, 1] = np.uint8((value % (256 ** 2)) / 256)
        ch[:, :, 2] = np.uint8(value % 256)

        return ch

    @staticmethod
    def read8UC3(depth, scale=1000):
        """
        Convert 8UC3 image to scaled depth representation.

        ``depth`` is either an array or a path that is read with OpenCV.
        """
        if isinstance(depth, str):
            depth = cv2.imread(depth, cv2.IMREAD_UNCHANGED)

        # Widen first: multiplying uint8 data by 65536 overflows or raises,
        # depending on the NumPy version.
        wide = depth.astype(np.uint32)
        merged = wide[:, :, 0] * (256 ** 2) + wide[:, :, 1] * 256 + wide[:, :, 2]

        # Convert depth to the scaled representation
        return merged.astype(np.float32) / scale

    @staticmethod
    def dump_visualize(path, depth, bits=1):
        """Write a min-max normalized preview of ``depth`` to ``<path>.png``.

        ``bits`` is the number of bytes per pixel (1 or 2); other values
        write nothing.
        """
        depth_min = depth.min()
        depth_max = depth.max()

        max_val = (2**(8*bits))-1

        if depth_max - depth_min > np.finfo("float").eps:
            out = max_val * (depth - depth_min) / (depth_max - depth_min)
        else:
            # Constant depth: an all-zero array (the previous int 0 had no
            # ``astype`` and raised AttributeError).
            out = np.zeros(depth.shape, dtype=depth.dtype)

        if bits == 1:
            cv2.imwrite(path + ".png", out.astype("uint8"))
        elif bits == 2:
            cv2.imwrite(path + ".png", out.astype("uint16"))

        return

    @staticmethod
    def load_png(path):
        """Read ``path`` with OpenCV in IMREAD_UNCHANGED mode."""
        depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        return depth

    @staticmethod
    def dump_png(path, depth, bits=2, max_depth=20.0):
        """Scale ``depth`` by ``max_val / max_depth`` and write it as uint16 PNG."""
        assert (path.endswith(".png"))
        max_val = (2**(8*bits))-1
        depth = depth / max_depth * max_val
        cv2.imwrite(path, depth.astype("uint16"))

    @staticmethod
    def read_depth(filename, scale=6000, sz=None, is_disparity=False):
        """Read a depth image as a torch tensor, or None when torch is missing.

        Args:
            filename: image path read with OpenCV.
            scale: divisor applied to the raw pixel values.
            sz: optional (h, w) to resize to with nearest-neighbour sampling.
            is_disparity: when True, return the clamped reciprocal.
        """
        if not hasTorch:
            return None

        depth = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
        depth = np.float32(depth) / scale
        if sz:
            h, w = sz
            depth = cv2.resize(depth, (w, h),
                               interpolation=cv2.INTER_NEAREST)

        depth = torch.from_numpy(depth)

        if is_disparity:   # convert to depth
            depth = 1.0 / torch.clamp(depth, min=1e-10)

        return depth

def write_depth(path, depth, grayscale, bits=1):
    """Write a min-max normalized depth map to ``<path>.png``.

    Args:
        path (str): filepath without extension
        depth (array): depth
        grayscale (bool): when False, the INFERNO colormap is applied and
            ``bits`` is forced to 1
        bits (int): bytes per pixel (1 or 2); other values write nothing
    """
    if not grayscale:
        bits = 1

    # Non-finite entries are replaced by zero and reported.
    if not np.isfinite(depth).all():
        depth=np.nan_to_num(depth, nan=0.0, posinf=0.0, neginf=0.0)
        print("WARNING: Non-finite depth values present")

    lowest = depth.min()
    highest = depth.max()
    max_val = (2**(8*bits))-1

    if highest - lowest > np.finfo("float").eps:
        out = max_val * (depth - lowest) / (highest - lowest)
    else:
        out = np.zeros(depth.shape, dtype=depth.dtype)

    if not grayscale:
        out = cv2.applyColorMap(np.uint8(out), cv2.COLORMAP_INFERNO)

    target = path + ".png"
    if bits == 1:
        cv2.imwrite(target, out.astype("uint8"))
    elif bits == 2:
        cv2.imwrite(target, out.astype("uint16"))

    return


class NormalIO(BaseIO):
    """Normal-map IO; files use the ``npy`` extension."""

    def __init__(self):
        super().__init__(appex='npy')
        self.dump_vis = False

    @staticmethod
    def read_normal(filename, sz=None, to_torch=False):
        """Read a normal image into a torch tensor scaled to [-1, 1].

        Returns None when torch is not installed. When ``filename`` does not
        exist, a constant (h, w, 3) tensor of 0.3 is returned, which requires
        ``sz``. ``to_torch`` is accepted but not used.
        """
        if not hasTorch:
            return None

        if not os.path.exists(filename):
            height, width = sz
            return torch.ones((height, width, 3)) * 0.3

        # Reverse the channel axis of the OpenCV result, then rescale.
        pixels = np.float32(cv2.imread(filename)[:, :, ::-1])
        normal = (pixels / 127.5 - 1)
        if sz:
            height, width = sz
            normal = cv2.resize(normal, (width, height),
                                interpolation=cv2.INTER_NEAREST)

        return torch.from_numpy(normal)

    def to8UC3(self, normal):
        """Map values from [-1, 1] to uint8 [0, 255]."""
        shifted = (normal + 1) * 127.5
        return np.uint8(shifted)


class FlowIO(BaseIO):
    """Optical-flow IO; flow is stored as ``.npy`` divided by image size."""

    def __init__(self):
        super(FlowIO, self).__init__(appex='npy')
        self.dump_vis = False

    def normalize(self, flow, shape=None):
        """Divide channel 0 by width and channel 1 by height, in place.

        ``shape`` is (height, width) and defaults to ``flow.shape[:2]``.
        The same (modified) array is returned.
        """
        if shape is None:
            shape = flow.shape[:2]

        flow[:, :, 0] /= shape[1]
        flow[:, :, 1] /= shape[0]
        return flow

    def denormalize(self, flow, shape=None):
        """Inverse of ``normalize``: multiply by width/height, in place."""
        if shape is None:
            shape = flow.shape[:2]

        flow[:, :, 0] *= shape[1]
        flow[:, :, 1] *= shape[0]
        return flow

    def visualization(self, flow):
        """Placeholder; does nothing."""
        pass

    def load(self, path, shape=None):
        """Load ``<path>.npy`` and return the denormalized flow."""
        path = path + '.npy'
        flow = np.load(path)
        assert flow is not None
        return self.denormalize(flow, shape)

    def dump(self, path, flow):
        """Save the normalized flow to ``<path>.npy``.

        The caller's array is left untouched: normalization runs on a copy,
        because ``normalize`` works in place.
        """
        v_uts.mkdir_if_need(os.path.dirname(path))
        path = path + '.npy'
        flow = self.normalize(flow.copy())
        np.save(path, flow)

        if self.dump_vis:
            self.dump_visualize(path[:-4], flow)

    def dump_visualize(self, path, flow):
        """Write the colour rendering from ``v_uts.flow2color`` to ``<path>.png``."""
        _, flow_c = v_uts.flow2color(flow)
        cv2.imwrite(path + '.png', flow_c)


class VideoIO(BaseIO):
    """Video reading/writing helpers built on skvideo, OpenCV and moviepy."""

    def __init__(self, longside_len=None):
        """``longside_len``: when set, frames go through ``v_uts.resize2maxsize``."""
        super(VideoIO, self).__init__()
        self.longside_len = longside_len

    def get_fps(self, path):
        """Return the frame rate OpenCV reports for ``path``."""
        vidcap = cv2.VideoCapture(path)
        try:
            return vidcap.get(cv2.CAP_PROP_FPS)
        finally:
            # Release the capture handle instead of leaving it to the GC.
            vidcap.release()

    def load_first_frame(self, path):
        """Return the first frame of the video, resized when configured."""
        import skvideo.io as vio
        video = vio.vreader(path)
        frame = next(video)
        if self.longside_len is not None:
            frame = v_uts.resize2maxsize(frame, self.longside_len)

        return frame

    def load(self, path, sample_rate=1, max_len=1e10,
             load_to_dir=False,
             dir_name=None,
             pre_len=5,
             save_transform=None):
        """Read frames from ``path``.

        Args:
            path: video file.
            sample_rate: keep every ``sample_rate``-th frame.
            max_len: stop once this many frames have been read.
            load_to_dir: when True, write kept frames to ``dir_name`` as PNG
                and return None; otherwise return them as a list.
            dir_name: output folder for ``load_to_dir``.
            pre_len: accepted but not used.
            save_transform: callable applied to a frame before it is written.
                Defaults to reversing the channel axis of 3-channel frames.
                NOTE(review): this assumes the reader yields RGB while
                ``cv2.imwrite`` expects BGR, as ``dump_to_images`` also does.
                The previous default was the identity, which left this
                helper unused.
        """
        import skvideo.io as vio

        def default_transform(x):
            # Only 3-channel frames have a channel order to reverse.
            if x.ndim == 3 and x.shape[2] == 3:
                return x[:, :, ::-1]
            return x

        frames = []
        reader = vio.vreader(path)

        if load_to_dir:
            v_uts.mkdir(dir_name)

        if save_transform is None:
            save_transform = default_transform

        for count, frame in enumerate(reader):
            if count == max_len:
                break
            if count % sample_rate != 0:
                continue
            if self.longside_len is not None:
                frame = v_uts.resize2maxsize(frame, self.longside_len)
            if load_to_dir:
                img_file = f"{dir_name}/{count:05}.png"
                cv2.imwrite(img_file, save_transform(frame))
            else:
                frames.append(frame)

        if not load_to_dir:
            return frames

    def load_till_end(self, path, sample_rate=1):
        """Read frames until the reader stops or fails; return those collected."""
        import skvideo.io as vio
        frames = []
        reader = vio.vreader(path)
        count = 0
        while True:
            try:
                frame = next(reader)
            except Exception:
                # Best effort: any reader error ends the loop, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                break

            if count % sample_rate == 0:
                if self.longside_len is not None:
                    frame = v_uts.resize2maxsize(
                        frame, self.longside_len)
                frames.append(frame)
            count += 1

        return frames

    def load_w_cv(self, path, out_dir, sample_rate = 1, ext="jpg"):
        """Extract frames to ``out_dir`` through ``v_uts.video_to_frame``.

        NOTE(review): ``longside_len`` is passed as ``max_len`` - confirm
        that this is what the helper expects.
        """
        v_uts.video_to_frame(path,
                             out_dir,
                             max_len=self.longside_len,
                             sample_rate=sample_rate,
                             ext=ext)

    def dump_to_images(self, frames, image_path):
        """Write frames as ``NNNN.jpg`` into ``image_path``, channel axis reversed."""
        v_uts.mkdir_if_need(image_path)
        for count, frame in tqdm(enumerate(frames)):
            image_file = '%s/%04d.jpg' % (image_path, count)
            cv2.imwrite(image_file, frame[:, :, ::-1])

    def dump(self, path, frames, fps=30, lossless=False):
        """Write frames (arrays or image paths) to a video with moviepy.

        With ``lossless`` the path must end in ``avi`` and the ``png`` codec
        is used.
        """
        from moviepy.editor import ImageSequenceClip
        if isinstance(frames[0], str):
            # Paths: read each image and reverse its channel axis.
            frames = [
                cv2.imread(frame, cv2.IMREAD_UNCHANGED)[:, :, ::-1]
                for frame in tqdm(frames)
            ]

        clip = ImageSequenceClip(frames, fps)
        if lossless:
            assert path.endswith('avi')
            clip.write_videofile(path, codec='png')
        else:
            clip.write_videofile(path, fps=fps)

    def dump_skv(self, path, frames, fps=30):
        """Write frames through ``v_uts.frame_to_video_simple``.

        2-D frames are expanded to three channels; other frames get their
        channel axis reversed.
        """
        if frames[0].ndim == 2:
            frames = [cv2.cvtColor(frame,cv2.COLOR_GRAY2RGB) for frame in frames]
        else:
            frames = [frame[:, :, ::-1] for frame in frames]
        v_uts.frame_to_video_simple(frames, fps, video_name=path)

    def resave_video(self, video_file, start, end,
                     outvideo_file):
        """Cut ``video_file`` to [start, end) and write it to ``outvideo_file``.

        :param start: sec start
        :param end: sec end
        :return:
        """
        fps = self.get_fps(video_file)
        frames = self.load(video_file)
        start_frame = int(start * fps)
        end_frame = int(end * fps)
        frames = frames[start_frame:end_frame]
        self.dump_skv(outvideo_file, frames, fps)

    def frame2video(self, folder, output, ext=".jpg"):
        """Build a video at ``output`` from the images listed in ``folder``."""
        image_files = v_uts.list_all_files(folder, exts=[ext])
        frames = [cv2.imread(name)[:, :, ::-1] for name in tqdm(image_files)]

        self.dump(output, frames)


class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to Python types."""

    def default(self, obj):
        # Arrays become nested lists; NumPy scalars become int/float.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        for numpy_kind, convert in ((np.integer, int), (np.floating, float)):
            if isinstance(obj, numpy_kind):
                return convert(obj)
        # Anything else goes to the base class, which raises TypeError.
        return super().default(obj)


def read2byte(filename):
    """Return the full content of ``filename`` as bytes."""
    with open(filename, 'rb') as stream:
        return stream.read()


def decodeImage(bytesIo):
    """Decode HEIC/AVIF bytes into an array with the channel axis reversed.

    Returns None when ``whatimage`` identifies any other format.
    """
    import whatimage
    import pyheif
    from PIL import Image

    if whatimage.identify_image(bytesIo) not in ['heic', 'avif']:
        return None

    heif = pyheif.read_heif(bytesIo)
    # Rebuild a PIL image from the decoded buffer, then convert to an array.
    pil_image = Image.frombytes(
        mode=heif.mode, size=heif.size, data=heif.data)
    return np.asarray(pil_image)[:, :, ::-1]  # to BGR


def image2Normal(imagePath):
    """Read an image with skimage and map its values from [0, 255] to [-1, 1]."""
    from skimage import io
    pixels = np.float32(io.imread(imagePath))
    return ((pixels / 255.0) * 2 - 1.0)

def normal2Image(normal):
    """Map values from [-1, 1] to uint8 [0, 255] (fractions are truncated)."""
    unit_range = (normal + 1.) / 2.
    return np.uint8(unit_range*255.)


def dump_normal(filename, normal):
    """Convert ``normal`` ([-1, 1]) to a uint8 image and write ``<filename>.png``."""
    image = normal2Image(normal)
    # The previous code wrote an undefined name ``array`` instead of the
    # converted image.
    cv2.imwrite(filename + '.png', image)


def dump_prob2image(filename, array):
    """Dump a probability map to a PNG image.

    Args:
        filename: output path; ``.png`` is appended when missing.
        array: [x, height, width] values in [0, 1]. With four or more
            channels only the first three are saved (a warning is printed).

    Raises:
        ValueError: for two-channel input, which is not implemented.
        AssertionError: if the file is missing after writing (only checked
            when the extension was appended, as before).
    """
    class_num = array.shape[0]
    if class_num >= 4 :
        print('warning: only save the first 3 channels')
        array = array[:3, :, :]

    if class_num == 2:
        raise ValueError('not implement')

    # Channels-first floats -> channels-last uint8.
    array = np.transpose(np.uint8(array * 255), (1, 2, 0))
    if filename.endswith('.png'):
        cv2.imwrite(filename, array)
        return

    out_file = filename + '.png'
    cv2.imwrite(out_file, array)
    # Check the path that was actually written, not the extension-less name.
    assert os.path.exists(out_file)


def load_image2prob(filename):
    """Read a PNG and return it channels-first, divided by 255.

    ``.png`` is appended to ``filename`` when it is missing.
    """
    png_name = filename if filename.endswith('.png') else filename + '.png'
    pixels = cv2.imread(png_name, cv2.IMREAD_UNCHANGED)
    return np.transpose(pixels, (2, 0, 1)) / 255


def shape_match(images):
    """Return True when every image has the same first two dimensions as the first.

    Requires at least two images (AssertionError otherwise).
    """
    assert len(images) > 1
    reference = np.array(images[0].shape[:2])
    for candidate in images[1:]:
        difference = reference - np.array(candidate.shape[:2])
        # Any non-zero absolute difference means a mismatch.
        if np.sum(np.abs(difference)):
            return False

    return True

def append_apex(filename, appex):
    """Insert ``_<appex>`` before the last extension of ``filename``.

    Everything after the last dot counts as the extension; a name without
    a dot is treated as extension only.
    """
    prefix, _, filetype = filename.rpartition('.')
    return f'{prefix}_{appex}.{filetype}'

def load_json(json_file):
    """Parse ``json_file`` and return the decoded object."""
    with open(json_file) as stream:
        return json.load(stream)

def dump_numpy(filename, x: np.ndarray):
    """Save ``x`` as space-separated text with six decimals per value."""
    np.savetxt(filename, x, fmt='%1.6f', delimiter=' ')

def dump_json(filename, odgt, w_np=False):
    """Write ``odgt`` to ``filename`` as JSON with a 4-space indent.

    With ``w_np`` set, NumPy values are serialized through ``NpEncoder``.
    """
    # ``cls=None`` makes json use its standard encoder.
    encoder_cls = NpEncoder if w_np else None
    with open(filename, 'w') as stream:
        json.dump(odgt, stream, indent=4, cls=encoder_cls)

def dump_jsonl(filename, odgt):
    """Write each entry of ``odgt`` as one JSON document per line."""
    with open(filename, 'w') as stream:
        for record in odgt:
            stream.write(json.dumps(record))
            stream.write('\n')

def dump_pair_data(image_list,
                   label_list,
                   outfile,
                   root='',
                   data_type='txt',
                   fields=None):
    """Write paired image/label paths to ``outfile``.

    Args:
        image_list, label_list: paths, paired by position.
        outfile: destination file.
        root: when non-empty, every occurrence of it in a path is replaced
            by ``.``. An empty root leaves paths unchanged.
        data_type: ``'txt'`` writes ``"<image> <label>"`` lines; ``'odgt'``
            writes a JSON list of dicts through ``dump_json``. Other values
            write nothing.
        fields: the two dict keys for ``'odgt'``; defaults to
            ``["image", "segment"]``.
    """
    if fields is None:
        fields = ["image", "segment"]

    def _strip_root(path):
        # str.replace('', '.') would put a dot between every character, so
        # the substitution only happens for a non-empty root.
        return path.replace(root, '.') if root else path

    pairs = [(_strip_root(imagefile), _strip_root(labelfile))
             for imagefile, labelfile in zip(image_list, label_list)]

    if data_type == 'txt':
        with open(outfile, 'w') as fp:
            for imagefile, labelfile in pairs:
                fp.write('%s %s\n' % (imagefile, labelfile))

    elif data_type == "odgt":
        odgt = [{fields[0]: imagefile, fields[1]: labelfile}
                for imagefile, labelfile in pairs]
        dump_json(outfile, odgt)


def save_xlsx(filename, dicts, sheets=None):
    """
      Save a list of dicts to an xlsx file.

      Without ``sheets`` the data goes to a single default sheet; otherwise
      ``dicts[sheet]`` is written to a sheet of that name for each entry.
    """
    with pd.ExcelWriter(filename, mode='w') as writer:
        if sheets is None:
            pd.DataFrame(dicts).to_excel(writer, index=False)
        else:
            for sheet in sheets:
                frame = pd.DataFrame(dicts[sheet])
                frame.to_excel(writer, sheet_name=sheet, index=False)

def load_xlsx(filename, sheets=None):
    """Load an xlsx file into column-name -> list-of-values dicts.

    Without ``sheets`` a single dict for the default sheet is returned;
    otherwise a dict of such dicts keyed by sheet name. The keys of every
    dict built are printed. Raises AssertionError if the file is missing.
    """
    assert os.path.exists(filename) , f"File not found: {filename}"

    def _columns_to_lists(frame):
        return {column: frame[column].tolist() for column in frame.columns}

    if sheets is None:
        result = _columns_to_lists(pd.read_excel(filename))
    else:
        result = {}
        for sheet in sheets:
            per_sheet = _columns_to_lists(
                pd.read_excel(filename, sheet_name=sheet))
            print(per_sheet.keys())
            result[sheet] = per_sheet
    print(result.keys())
    return result

def dump_lines(filename, file_list):
    """Write ``file_list`` to ``filename``, one entry per line.

    Tuple/list entries are joined with single spaces; string entries are
    written as-is. No newline follows the last entry.

    Raises:
        TypeError: for an entry that is neither a string nor a tuple/list.
            Previously this reused the preceding line or failed with
            NameError.
    """
    last_index = len(file_list) - 1
    # ``with`` closes the file even when an entry is rejected.
    with open(filename, 'w') as f:
        for i, elements in enumerate(tqdm(file_list)):
            if isinstance(elements, (tuple, list)):
                line = ' '.join(elements)
            elif isinstance(elements, str):
                line = elements
            else:
                raise TypeError(
                    'unsupported entry type: %s' % type(elements).__name__)
            appex = '' if i == last_index else '\n'
            f.write('%s%s' % (line, appex))


def load_lines(txt_file):
    """Return the lines of ``txt_file`` with surrounding whitespace stripped."""
    # ``with`` closes the handle; the bare open() in a comprehension leaked it.
    with open(txt_file, 'r') as f:
        return [line.strip() for line in f]


def load_jsonl(jsonl_file):
    """Parse a JSON-lines file and return the decoded objects as a list."""
    with open(jsonl_file, 'r') as stream:
        # Each line holds one JSON document.
        return [json.loads(row) for row in stream]

def load_yaml(yaml_file):
    """Parse ``yaml_file`` with ``yaml.safe_load`` and return the result."""
    with open(yaml_file, "r") as stream:
        return yaml.safe_load(stream)


def load_odgt(odgt):
    """Load a sample list from an odgt file.

    Every line is first parsed as its own JSON document and the first one
    is used. If any line fails to parse, or the file is empty, the whole
    file is parsed as one JSON document through ``load_json`` instead.
    The keys of the first sample are printed.
    """
    try:
        # ``with`` closes the handle that the bare open() used to leak.
        with open(odgt, 'r') as f:
            samples = [json.loads(x.rstrip()) for x in f][0]
    except (ValueError, IndexError):
        # ValueError covers JSON and decoding errors; IndexError covers an
        # empty file. Other errors (e.g. a missing file) propagate.
        samples = load_json(odgt)

    print(samples[0].keys())
    return samples

def fuse_odgt(odgt_files):
    """
        odgt_files: iterable of odgt paths; their sample lists are
        concatenated in order and returned as one list.
    """
    merged = []
    for path in odgt_files:
        merged = merged + load_odgt(path)

    return merged


def load_video_first_frame(video_name):
    """Return the first frame of ``video_name`` as read by OpenCV.

    Raises:
        ValueError: if the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_name)
    try:
        if not cap.isOpened():
            raise ValueError("can not read %s" % video_name)
        ret, frame = cap.read()
    finally:
        # Always release the capture; it used to be left open.
        cap.release()

    return frame


def load_lines(txt_file):
    """Return the lines of ``txt_file`` with surrounding whitespace stripped.

    NOTE(review): this repeats an identical definition earlier in the file.
    """
    # ``with`` closes the handle; the bare open() in a comprehension leaked it.
    with open(txt_file, 'r') as f:
        return [line.strip() for line in f]


def load_csv(csv_file):
    """Return the rows of a comma-separated file, without the first row."""
    import csv
    with open(csv_file) as stream:
        rows = list(csv.reader(stream, delimiter=','))
    # The first row is treated as a header and dropped.
    return rows[1:]


# Concatenate multiple files into a single file.
def cat_files(files, output):
    """Write the stripped lines of every file in ``files`` to ``output``, in order."""
    combined = []
    for path in files:
        combined = combined + load_lines(path)
    dump_lines(output, combined)


class SkipExist:
    """Run ``processor`` only when its output file does not exist yet.

    Args:
        processor: callable producing the result to store.
        ioType: handler name passed to ``IOShop``.
        need_res: when True, calls return the result (loaded or computed).
        rerun: when True, always recompute and overwrite.
    """

    def __init__(self,
                 processor,
                 ioType='image',
                 need_res=False,
                 rerun=False):
        self.ioType = ioType
        self.io = IOShop(self.ioType).io
        self.processor = processor
        self.rerun = rerun
        self.need_res = need_res

    def __call__(self, *args, **kwargs):
        """Process unless ``<filename>.<appex>`` exists.

        ``filename`` is a required keyword argument; it is removed before
        the remaining arguments are forwarded to the processor. Returns the
        result when ``need_res`` is set, otherwise None.
        """
        assert 'filename' in kwargs
        filename = kwargs.pop('filename')
        true_file = '%s.%s' % (filename, self.io.appex)

        # ``rerun`` was stored but never consulted before.
        if os.path.exists(true_file) and not self.rerun:
            return self.io.load(filename) if self.need_res else None

        res = self.processor(*args, **kwargs)
        self.io.dump(filename, res)
        # A freshly computed result is returned too when requested;
        # previously it was dropped.
        return res if self.need_res else None


def dump_pkl(filename, data):
    """Pickle ``data`` into ``filename``."""
    import pickle as pkl
    with open(filename, "wb") as stream:
        pkl.dump(data, stream)


def load_pkl(filename):
    """Unpickle and return the object stored in ``filename``.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    import pickle as pkl
    with open(filename, 'rb') as stream:
        return pkl.load(stream)


def write_pointcloud(filename, xyz_points, faces=None, rgb_points=None):
    """Write a coloured point cloud, optionally with faces, to a binary PLY file.

    Args:
        filename: output path.
        xyz_points: N x 3 array of coordinates, stored as float32.
        faces: optional iterable of ``(count, indices)`` pairs, written as a
            ``uchar`` count followed by ``count`` 32-bit ints.
        rgb_points: optional N x 3 colours, cast to uint8; white by default.

    Raises:
        AssertionError: if the point or colour array has the wrong shape.
    """
    assert xyz_points.shape[1] == 3,'Input XYZ points should be Nx3 float array'
    if rgb_points is None:
        rgb_points = np.full(xyz_points.shape, 255, dtype=np.uint8)
    else:
        rgb_points = rgb_points.astype(np.uint8)

    assert xyz_points.shape == rgb_points.shape, (
        f'Input RGB colors should be Nx3 {rgb_points.shape} float array '
        f'and have same size as input XYZ points {xyz_points.shape}')

    num_points = xyz_points.shape[0]
    header = [
        'ply',
        'format binary_little_endian 1.0',
        'element vertex %d' % num_points,
        'property float x',
        'property float y',
        'property float z',
        'property uchar red',
        'property uchar green',
        'property uchar blue',
    ]
    if faces is not None:
        faces = list(faces)
        # Face records are only readable when the header declares them.
        header.append('element face %d' % len(faces))
        header.append('property list uchar int vertex_indices')
    header.append('end_header')

    # Packed 15-byte little-endian records matching the header; replaces the
    # per-point struct.pack loop, which relied on ``ndarray.tostring`` and
    # native byte order.
    vertex = np.empty(num_points, dtype=[
        ('x', '<f4'), ('y', '<f4'), ('z', '<f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
    ])
    vertex['x'] = xyz_points[:, 0]
    vertex['y'] = xyz_points[:, 1]
    vertex['z'] = xyz_points[:, 2]
    vertex['red'] = rgb_points[:, 0]
    vertex['green'] = rgb_points[:, 1]
    vertex['blue'] = rgb_points[:, 2]

    with open(filename, 'wb') as fid:
        fid.write(('\n'.join(header) + '\n').encode('utf-8'))
        fid.write(vertex.tobytes())
        if faces is not None:
            for face in faces:
                fid.write(struct.pack("<B", face[0]))
                fid.write(struct.pack("<{}i".format(face[0]), *face[1]))


def read_ply(filename):
    """Return the vertex coordinates of a PLY file as an N x 3 array."""
    vertices = PlyData.read(filename)['vertex']

    # Stack the x, y, z columns as rows, then transpose to one row per point.
    coords = (vertices['x'], vertices['y'], vertices['z'])
    return np.vstack(coords).T


def load_obj(file_path):
    """Load a single-material, textured OBJ file into an ``Obj``.

    Vertices, normals and UVs come from ``v``/``vn``/``vt`` lines and
    faces from ``f`` lines. Face indices are stored as written in the file.
    The texture is the ``map_Kd`` image named in the referenced ``.mtl``.

    Raises:
        ValueError: if the OBJ has no ``mtllib`` line, the MTL has no
            ``map_Kd`` line, or the texture image cannot be read.
        AssertionError: unless exactly one ``usemtl`` statement is present.
    """
    verts = []
    normals = []
    uvs = []
    texture_images = []
    faces_verts = []
    faces_textures = []
    mtl_name = None

    with open(file_path, 'r') as file:
        for line in file:
            if line.startswith('v '):
                verts.append([float(v) for v in line.split()[1:]])
            elif line.startswith('vn '):
                normals.append([float(n) for n in line.split()[1:]])
            elif line.startswith('vt '):
                uvs.append([float(u) for u in line.split()[1:]])
            elif line.startswith("mtllib "):
                mtl_name = line.split()[1]
            elif line.startswith('usemtl '):
                texture_images.append(line.split()[1])
            elif line.startswith('f '):
                face_verts = []
                face_textures = []
                for corner in line.split()[1:]:
                    # Corner forms: v, v/vt, v//vn, v/vt/vn.
                    parts = corner.split('/')
                    face_verts.append(int(parts[0]))
                    # The texture index is field 2 whenever it is present;
                    # the previous parser dropped it for v/vt/vn.
                    if len(parts) >= 2 and parts[1]:
                        face_textures.append(int(parts[1]))
                faces_verts.append(face_verts)
                faces_textures.append(face_textures)

    if mtl_name is None:
        raise ValueError("no mtllib statement found in %s" % file_path)

    # os.path.join also works when the OBJ path has no directory part.
    base_dir = os.path.dirname(file_path)
    image_name = None
    with open(os.path.join(base_dir, mtl_name), 'r') as file:
        for line in file:
            if line.startswith("map_Kd"):
                image_name = line.split()[1]
                break
    if image_name is None:
        raise ValueError("no map_Kd statement found in %s" % mtl_name)

    assert len(texture_images) == 1
    texture_name = texture_images[0]

    image = cv2.imread(os.path.join(base_dir, image_name))
    if image is None:
        raise ValueError("can not read texture image %s" % image_name)

    properties = Aux(
        normals=np.array(normals),
        verts_uvs=np.array(uvs),
        material_colors=DEFAULT_MATERIAL,
        texture_images={texture_name: np.float32(image)/ 255.0},
        texture_atlas=None)

    faces_verts = np.array(faces_verts)
    num_faces = faces_verts.shape[0]
    faces = Faces(
        verts_idx=faces_verts,
        # Per-face normal indices are not kept; -1 marks them as absent.
        normals_idx=np.ones(faces_verts.shape) * -1,
        textures_idx=np.array(faces_textures),
        materials_idx=np.zeros(num_faces))

    obj = Obj(np.array(verts), faces, properties)
    return obj


def export_obj(filename, obj,
               include_normals=False,
               include_textures=True):
    """
    Export the given object to an .obj file, optionally with its texture.

    With textures enabled, ``<filename>.mtl`` and ``<filename>.png`` are
    written next to the ``.obj``. Face indices are written as stored.

    Args:
        filename (str): Path to the output .obj file (without the extension).
        obj (namedtuple): Object containing vertices, faces, and properties.
        include_normals (bool): Accepted but currently unused; normals are
            never written.
        include_textures (bool): Flag to include textures in the .obj file.
    """
    material_name = list(obj.properties.texture_images.keys())[0]

    # Write obj file
    name = os.path.basename(filename)
    with open(filename + ".obj", "w") as f:
        f.write("\n")

        if include_textures:
            f.write(f"mtllib {name}.mtl\n")
            f.write("\n")

        for x, y, z in obj.verts:
            f.write(f"v {x} {y} {z}\n")

        if include_textures:
            for u, v in obj.properties.verts_uvs:
                f.write(f"vt {u} {v}\n")
            f.write(f"usemtl {material_name}\n")

        num_faces = obj.faces.verts_idx.shape[0]
        for i in range(num_faces):
            f0, f1, f2 = obj.faces.verts_idx[i]
            if include_textures:
                t0, t1, t2 = obj.faces.textures_idx[i]
                # -1 marks a face without texture coordinates.
                if t0 != -1:
                    f.write(f"f {f0}/{t0} {f1}/{t1} {f2}/{t2}\n")
                    continue
            f.write(f"f {f0} {f1} {f2}\n")

    # Write mtl file
    if include_textures:
        # os.path.join keeps the outputs relative when ``filename`` has no
        # directory part; f"{dir}/{name}" pointed at the filesystem root.
        output_dir = os.path.dirname(filename)
        with open(os.path.join(output_dir, f"{name}.mtl"), "w") as f:
            f.write(f"newmtl {material_name}\n")
            f.write(f"map_Kd {name}.png\n")

            material_colors = obj.properties.material_colors[material_name]
            for key, tag in (("ambient_color", "Ka"),
                             ("diffuse_color", "Kd"),
                             ("specular_color", "Ks")):
                red, green, blue = material_colors[key]
                f.write(f"{tag} {red} {green} {blue}\n")
            s = material_colors["shininess"]
            f.write(f"Ns {s}\n")

        # Save texture image
        image = obj.properties.texture_images[material_name] * 255
        texture_img = os.path.join(output_dir, f"{name}.png")
        cv2.imwrite(texture_img, image)

    return


def resave_to_video(folder="/Users/peng/Downloads/DenseAR/Mesh/",
                    vname="0037438511",
                    image_num=125,
                    crop=(0, 650, 1080, 1270),
                    fps=24,
                    max_d=50):
    """Pack a numbered frame sequence into a colour video and a depth video.

    For every ``i`` in ``range(image_num)`` the files
    ``{folder}/{vname}/{i}.jpg`` and ``{folder}/{vname}/{i}.tiff`` are read,
    cropped, and collected. The tiff values are reciprocated (clamped at
    1e-10 to avoid division by zero) and colour-mapped with
    ``p_uts.depth2color`` before being stored. The two frame lists are then
    written to ``{folder}/{vname}.mp4`` and ``{folder}/{vname}_d.mp4``.

    Calling the function without arguments reproduces the previously
    hard-coded behaviour.

    Args:
        folder: Root directory containing the sequence folder.
        vname: Name of the sequence folder (also the output video stem).
        image_num: Number of frame indices to visit, starting at 0.
        crop: ``(row_start, col_start, row_end, col_end)`` crop window
            applied to both the image and the tiff.
        fps: Frame rate passed to ``dump_skv``.
        max_d: ``max_d`` argument forwarded to ``p_uts.depth2color``.
    """
    # NOTE(review): `p_uts` and `io_uts` are not among the imports visible at
    # the top of this file -- confirm they are in scope before running this.
    row_start, col_start, row_end, col_end = crop
    frames = []
    d_frames = []
    for i in tqdm(range(image_num)):
        name = f"{folder}/{vname}/{i}.jpg"
        d_name = f"{folder}/{vname}/{i}.tiff"
        # The former `img is None` guard could never trigger because
        # np.array(Image.open(...)) raises on a missing file instead of
        # returning None. Check for the files explicitly so that gaps in the
        # sequence are skipped as intended.
        if not (os.path.isfile(name) and os.path.isfile(d_name)):
            continue
        img = np.array(Image.open(name))
        depth = np.array(Image.open(d_name))
        img = img[row_start:row_end, col_start:col_end]
        depth = depth[row_start:row_end, col_start:col_end]
        depth = 1.0 / np.maximum(depth, 1e-10)
        depth = p_uts.depth2color(depth, max_d=max_d)
        frames.append(img)
        d_frames.append(depth)

    vio = io_uts.VideoIO()
    video_file = f"{folder}/{vname}.mp4"
    d_video_file = f"{folder}/{vname}_d.mp4"
    vio.dump_skv(video_file, frames, fps=fps)
    vio.dump_skv(d_video_file, d_frames, fps=fps)


def test_depth_8uc3_encode():
    """Round-trip check for the ``DepthIO`` 8UC3 depth encoding.

    A random 480x640 array with values in [0, 200) is encoded with
    ``to8UC3`` and decoded with ``read8UC3``; the mean absolute difference
    between the input and the decoded result must stay below 1e-3.
    """
    rows, cols = 480, 640
    original = np.random.rand(rows, cols) * 200

    codec = DepthIO()
    encoded = codec.to8UC3(original)
    restored = codec.read8UC3(encoded)
    print(original, restored)

    mean_abs_error = np.sum(np.abs(original - restored)) / (rows * cols)
    assert mean_abs_error < 1e-3


########### copy from gta code ################
@functools.lru_cache()
def build_mesh(w, h):
    """Return a cached grid of normalised coordinates of shape ``(2, h, w)``.

    Channel 0 holds the horizontal coordinate, running from -1 (first
    column) to 1 (last column); channel 1 holds the vertical coordinate,
    running from 1 (first row) to -1 (last row). Values are float32.

    The result is memoised per ``(w, h)``, so repeated calls hand back the
    very same array object -- callers must not modify it in place.
    """
    xs = np.linspace(-1.0, 1.0, num=w, dtype=np.float32)
    ys = np.linspace(1.0, -1.0, num=h, dtype=np.float32)
    grid_x, grid_y = np.meshgrid(xs, ys)
    return np.stack((grid_x, grid_y), axis=0)


def build_proj_matrix(fov, aspect):
    """Assemble the 4x4 projection matrix used to un-project z-buffers.

    Args:
        fov: Field of view in degrees (converted with ``np.radians``).
        aspect: Aspect ratio that divides the horizontal scale term.

    Returns:
        A ``(4, 4)`` float64 array. Entries ``[2, 2]`` and ``[2, 3]`` are
        fixed constants that were reverse-engineered from the shader.
    """
    scale_y = 1.0 / np.tan(np.radians(fov / 2))
    scale_x = scale_y / aspect
    depth_scale = 0.00001502  # reverse-engineered get from shader
    depth_offset = 0.15000225  # reverse-engineered get from shader

    return np.array(
        [
            [scale_x, 0.0, 0.0, 0.0],
            [0.0, scale_y, 0.0, 0.0],
            [0.0, 0.0, depth_scale, depth_offset],
            [0.0, 0.0, -1.0, 0.0],
        ],
        dtype=np.float64,
    )


def zbuffer_to_depth(zbuffer, fov):
    """Un-project a z-buffer through the inverse projection matrix.

    Args:
        zbuffer: Array whose first two dimensions are ``(height, width)``.
        fov: Field of view in degrees, forwarded to ``build_proj_matrix``.

    Returns:
        ``(depth, focal_cv)`` where ``depth`` has shape ``(height, width)``
        and ``focal_cv`` is ``proj[0, 0] * width / 2``.
    """
    rows, cols = zbuffer.shape[:2]
    aspect_ratio = cols / rows
    ndc_xy = build_mesh(cols, rows)

    # NOTE(review): a 3-D input is passed through untouched here, but the
    # concatenation below needs a (1, rows, cols) layout, which disagrees
    # with how rows/cols were read above -- 3-D input looks unsupported.
    if zbuffer.ndim != 3:
        zbuffer = zbuffer[np.newaxis, ...]

    # Stack x, y, z and a homogeneous 1 per pixel, then flatten the pixels.
    homogeneous = np.concatenate(
        (ndc_xy, zbuffer, np.ones_like(zbuffer)), 0
    ).reshape(4, rows * cols)

    projection = build_proj_matrix(fov, aspect_ratio)
    unprojected = np.linalg.inv(projection) @ homogeneous

    depth_map = (-unprojected[2] / unprojected[3]).reshape(rows, cols)
    focal_px = projection[0, 0] * cols / 2.0
    return depth_map, focal_px

def test_zbuffer_to_depth():
    """Manual smoke test: convert one stored z-buffer frame and print it.

    Reads ``<sample>.json`` for the ``fov`` entry and ``<sample>.tiff`` for
    the z-buffer, both from a hard-coded local path.
    """
    # Alternative sample: "E:/Dataset/GTA/Stereo_0/" + "1-130423915874"
    sample = "E:/depth_video/0036696165/1"
    fov = load_json(sample + ".json")["fov"]
    zbuffer = cv2.imread(sample + ".tiff", cv2.IMREAD_UNCHANGED)
    depth, _ = zbuffer_to_depth(zbuffer, fov)
    print(depth)

def fuse_frames_of_depth_video():
    """Re-pack per-frame sequences under ``E:/depth_video/`` into videos.

    Every sub-folder (except the first entry returned by
    ``v_uts.list_all_folders``) is treated as one sequence of
    ``{i}.jpg`` / ``{i}.tiff`` / ``{i}.json`` triples. For each sequence an
    RGB ``.mp4``, an 8UC3-encoded lossless depth ``_d.avi`` and a ``.json``
    list of per-frame camera dicts are written to ``E:/depth_video_resave/``.
    """
    src_root = "E:/depth_video/"
    dst_root = "E:/depth_video_resave/"

    def frame_to_video(video_dir, video_name):
        """Fuse the frames in ``video_dir`` into ``video_name``.* outputs."""
        frame_files = v_uts.list_all_files(video_dir, exts=['jpg'])

        depth_io = DepthIO()
        rgb_frames, depth_frames, cameras = [], [], []
        print("seq len:", len(frame_files))

        # Only the number of listed jpgs matters; frames are addressed by
        # their running index, not by the listed file names.
        for idx, _ in tqdm(enumerate(frame_files)):
            rgb = np.array(Image.open(f"{video_dir}/{idx}.jpg"))
            zbuffer = np.array(Image.open(f"{video_dir}/{idx}.tiff"))
            camera = load_json(f"{video_dir}/{idx}.json")

            metric_depth, focal = zbuffer_to_depth(zbuffer, camera['fov'])
            camera['focal'] = focal

            rgb_frames.append(rgb)
            depth_frames.append(depth_io.to8UC3(metric_depth))
            cameras.append(camera)

        writer = VideoIO()
        writer.dump(f"{video_name}.mp4", rgb_frames)
        writer.dump(f"{video_name}_d.avi", depth_frames, lossless=True)
        dump_json(f"{video_name}.json", cameras)

    v_uts.mkdir_if_need(dst_root)
    sequence_dirs = v_uts.list_all_folders(src_root)
    # The first listed entry is skipped -- presumably it is the root folder
    # itself; confirm against v_uts.list_all_folders.
    for seq_dir in tqdm(sequence_dirs[1:]):
        seq_dir = seq_dir.replace('\\', '/')
        # Second-to-last path component is used as the sequence name
        # (assumes the listed paths end with a separator -- TODO confirm).
        vid_name = seq_dir.split('/')[-2]
        print(seq_dir, vid_name)
        frame_to_video(seq_dir, video_name=f"{dst_root}/{vid_name}")


def save_xlsx(filename, dicts, sheets=None):
    """Write tabular data to an Excel workbook, overwriting ``filename``.

    Args:
        filename: Path of the workbook to create.
        dicts: Data accepted by ``pd.DataFrame``. When ``sheets`` is given
            it must instead map each sheet name to such data.
        sheets: Optional iterable of sheet names; one sheet is written per
            name. When ``None`` a single default sheet is written.
    """
    with pd.ExcelWriter(filename, mode='w') as writer:
        if sheets is not None:
            for sheet_name in sheets:
                frame = pd.DataFrame(dicts[sheet_name])
                frame.to_excel(writer, sheet_name=sheet_name, index=False)
        else:
            pd.DataFrame(dicts).to_excel(writer, index=False)

def load_xlsx(filename, sheets=None):
    """Load an Excel workbook into plain dictionaries of column lists.

    Args:
        filename: Path of an existing workbook (asserted to exist).
        sheets: Optional iterable of sheet names to read.

    Returns:
        With ``sheets=None``: ``{column: [values...]}`` for the sheet that
        ``pd.read_excel`` reads by default. Otherwise
        ``{sheet: {column: [values...]}}`` for every requested sheet.
        The resulting keys are printed as a side effect.
    """
    assert os.path.exists(filename) , f"File not found: {filename}"

    def columns_as_lists(frame):
        # One list per column, keyed by the column label.
        return {column: frame[column].tolist() for column in frame.columns}

    if sheets is None:
        table = columns_as_lists(pd.read_excel(filename))
    else:
        table = {}
        for sheet in sheets:
            sheet_table = columns_as_lists(
                pd.read_excel(filename, sheet_name=sheet))
            print(sheet_table.keys())
            table[sheet] = sheet_table

    print(table.keys())
    return table


def get_sheet_list(dict, sheets=None, key="url"):
    """Collect the ``key`` column of one table or several sheets into a list.

    Args:
        dict: With ``sheets=None`` a ``{column: values}`` mapping; otherwise
            a ``{sheet: {column: values}}`` mapping. (The parameter name
            shadows the builtin but is kept for keyword callers.)
        sheets: Optional iterable of sheet names whose ``key`` columns are
            concatenated in order.
        key: Column to collect from each table.

    Returns:
        A new flat list with the values of every selected column. The size
        of each contributing column is printed as a side effect.
    """
    if sheets is None:
        # The previous implementation zipped against ``sheets`` and raised
        # TypeError here; there is no sheet name, so use a fixed label.
        labelled_columns = [("default", dict[key])]
    else:
        labelled_columns = [
            (sheet_name, dict[sheet_name][key]) for sheet_name in sheets
        ]

    images_full = []
    for label, images in labelled_columns:
        print(f"{label}: {len(images)}")
        # extend() avoids rebuilding the accumulated list on every sheet.
        images_full.extend(images)

    return images_full

def test_load_save_obj():
    """Manual round-trip check: load a sample .obj and export it again.

    Reads ``./unit_test/000000243355_zebra.obj`` with ``load_obj`` and hands
    the result to ``export_obj`` under the ``_resave`` name.
    """
    image_name = "000000243355_zebra"
    mesh = load_obj(f"./unit_test/{image_name}.obj")
    export_obj(f"./unit_test/{image_name}_resave", mesh)




if __name__ == '__main__':
    # Manual entry point: runs the OBJ load/export round trip below.
    # Earlier ad-hoc experiments are kept commented out for reference.

    # test = [(1,2), (3,4)]
    # dump_pkl('test.pkl', test)
    # print(load_pkl('test.pkl'))

    # xyz = np.random.rand(1000, 3)
    # write_pointcloud("test.ply", xyz)

    # test_depth_8uc3_encode()
    # test_zbuffer_to_depth()
    # fuse_frames_of_depth_video()
    test_load_save_obj()