# -*- coding: utf-8 -*-
import os
import torch
import argparse
import numpy as np
import open3d as o3d
from huggingface_hub import hf_hub_download, HfFolder

from segment import seg_point, seg_box, seg_mask
import sam2point.dataset as dataset
import sam2point.configs as configs
from sam2point.voxelizer import Voxelizer
from sam2point.utils import cal
import matplotlib.pyplot as plt
import plotly.graph_objects as go

# Report at import time whether PyTorch can see a CUDA device.
print("Torch CUDA:", torch.cuda.is_available())
# Enter a CUDA bfloat16 autocast context by calling __enter__ directly rather
# than through a `with` block; no matching __exit__ call is visible in this
# file, so the context stays active after this line runs.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

_DATASETS = ('S3DIS', 'ScanNet', 'Objaverse', 'KITTI', 'Semantic3D')
_PROMPT_TYPES = ('point', 'box', 'mask')
# Datasets whose voxel size is read from the sample config instead of the
# ``voxel_size`` argument passed to run_demo.
_CONFIG_VOXEL_DATASETS = ('Objaverse', 'KITTI', 'Semantic3D')
_CACHE_REPO_ID = "ZiyuG/Cache"


def _cache_key(args):
    """Return the file stem shared by the cached result and prompt arrays."""
    return '_'.join([args.dataset, "sample" + str(args.sample_idx), args.prompt_type + "-prompt" + str(args.prompt_idx)])


def _load_cached(name):
    """Return ``(new_color, PROMPT)`` from the hub or local cache, or ``None`` on a miss."""
    token = os.getenv('HF_TOKEN')
    try:
        # NOTE(review): newer huggingface_hub releases prefer ``token=`` over
        # ``use_auth_token=`` -- confirm against the pinned version before changing.
        result_file = hf_hub_download(repo_id=_CACHE_REPO_ID, filename="cache_results/" + name + '.npy', use_auth_token=token, repo_type='dataset')
        prompt_file = hf_hub_download(repo_id=_CACHE_REPO_ID, filename="cache_prompt/" + name + '.npy', use_auth_token=token, repo_type='dataset')
        new_color = np.load(result_file)
        PROMPT = np.load(prompt_file)
        return new_color, PROMPT
    except Exception:
        # Deliberate best effort: any failure of the hub lookup (network, auth,
        # missing entry, ...) falls back to the local cache written by run_demo.
        local_result = "./cache_results/" + name + '.npy'
        local_prompt = "./cache_prompt/" + name + '.npy'
        if os.path.exists(local_result) and os.path.exists(local_prompt):
            return np.load(local_result), np.load(local_prompt)
    return None


def _gather_per_point(volume, locs, inds_reconstruct):
    """Look up *volume* at the voxel coordinate of every original point."""
    point_locs = locs[inds_reconstruct]
    return volume[point_locs[:, 0], point_locs[:, 1], point_locs[:, 2]]


def _highlight(color, point_mask, rgb):
    """Return *color* with the rows selected by *point_mask* replaced by *rgb*.

    *point_mask* must support ``unsqueeze``, ``~`` and ``numpy()``
    (presumably a boolean torch tensor -- confirm against ``segment``).
    """
    point_mask = point_mask.unsqueeze(-1)
    point_mask_not = ~point_mask
    return color * point_mask_not.numpy() + (color * 0 + np.array([rgb])) * point_mask.numpy()


def run_demo(dataset_name, prompt_type, sample_idx, prompt_idx, voxel_size, theta, mode, ret_prompt):
    """Segment one sample with one prompt and return the recoloured points.

    A cached result (first the ``ZiyuG/Cache`` hub dataset, then the local
    ``./cache_results`` / ``./cache_prompt`` folders) is returned when
    available; otherwise the sample is loaded, voxelized, segmented and the
    outcome is written to the local cache.

    Args:
        dataset_name: One of 'S3DIS', 'ScanNet', 'Objaverse', 'KITTI', 'Semantic3D'.
        prompt_type: One of 'point', 'box', 'mask'.
        sample_idx: Index of the scene/object in the dataset's sample config.
        prompt_idx: Index of the prompt within the sample's prompt list.
        voxel_size: Voxel size; replaced by the per-sample config value for
            Objaverse, KITTI and Semantic3D.
        theta: Forwarded unchanged to the segmentation helpers via ``args``.
        mode: Forwarded unchanged to the segmentation helpers via ``args``.
        ret_prompt: When true, return only the prompt.

    Returns:
        ``(new_color, PROMPT)`` when ``ret_prompt`` is false, where
        ``new_color`` is the per-point colour array with segmented points set
        to ``(0, 1, 0)``. When ``ret_prompt`` is true, only the prompt: a list
        for point/box prompts, a colour array with the prompt mask set to
        ``(1, 0, 0)`` for mask prompts, or the cached array on a cache hit.

    Raises:
        ValueError: On a cache miss with an unknown ``dataset_name`` or
            ``prompt_type``.
    """
    # Built directly instead of via ArgumentParser.parse_args(): every parsed
    # value was overwritten anyway, and parsing the host process's sys.argv
    # made this function exit when the host was launched with unrelated flags.
    args = argparse.Namespace(
        dataset=dataset_name,
        prompt_type=prompt_type,
        sample_idx=sample_idx,
        prompt_idx=prompt_idx,
        voxel_size=voxel_size,
        theta=theta,
        mode=mode,
        ret_prompt=ret_prompt,
    )
    print(args)

    name = _cache_key(args)

    # use cache result for speeding up
    cached = _load_cached(name)
    if cached is not None:
        new_color, PROMPT = cached
        return PROMPT if args.ret_prompt else (new_color, PROMPT)

    if args.dataset not in _DATASETS:
        raise ValueError("unknown dataset %r; expected one of %s" % (args.dataset, ', '.join(_DATASETS)))
    if args.prompt_type not in _PROMPT_TYPES:
        raise ValueError("unknown prompt type %r; expected one of %s" % (args.prompt_type, ', '.join(_PROMPT_TYPES)))

    # The name was validated above, so this resolves to e.g. configs.S3DIS_samples.
    info = getattr(configs, args.dataset + '_samples')[args.sample_idx]

    # Point and box prompts come straight from the config, so they can be
    # returned before any point cloud is loaded.
    if args.ret_prompt and args.prompt_type in ('point', 'box'):
        return list(np.array(info[args.prompt_type + '_prompts'])[args.prompt_idx])

    if args.dataset == 'Semantic3D':
        # The Semantic3D loader is the only one that also takes the sample index.
        point, color = dataset.load_Semantic3D_sample(info['path'], args.sample_idx)
    else:
        # Resolves to e.g. dataset.load_S3DIS_sample.
        point, color = getattr(dataset, 'load_' + args.dataset + '_sample')(info['path'])
    if args.dataset in _CONFIG_VOXEL_DATASETS:
        args.voxel_size = info[configs.VOXEL[args.prompt_type]][args.prompt_idx]

    point_color = np.concatenate([point, color], axis=1)
    voxelizer = Voxelizer(voxel_size=args.voxel_size, clip_bound=None)

    # The first coordinate column is passed as the label input; the labels
    # returned by the voxelizer are not used afterwards.
    labels_in = point[:, :1].astype(int)
    locs, feats, _labels, inds_reconstruct = voxelizer.voxelize(point, color, labels_in)

    if args.prompt_type == 'point':
        mask = seg_point(locs, feats, info['point_prompts'], args)
        PROMPT = list(np.array(info['point_prompts'])[args.prompt_idx])
    elif args.prompt_type == 'box':
        mask = seg_box(locs, feats, info['box_prompts'], args)
        PROMPT = list(np.array(info['box_prompts'])[args.prompt_idx])
    else:  # 'mask'
        # Fall back to the point prompts without writing into the shared config.
        mask_prompts = info['mask_prompts'] if 'mask_prompts' in info else info['point_prompts']
        mask, prompt_mask = seg_mask(locs, feats, mask_prompts, args)
        PROMPT = _highlight(color, _gather_per_point(prompt_mask, locs, inds_reconstruct), [1., 0., 0.])
        if args.ret_prompt:
            return PROMPT

    # The colours are taken from the concatenated array, as before, so the
    # result keeps the dtype produced by np.concatenate.
    color = point_color[:, 3:]
    new_color = _highlight(color, _gather_per_point(mask, locs, inds_reconstruct), [0., 1., 0.])

    #cache for speeding up
    os.makedirs("cache_results", exist_ok=True)
    os.makedirs("cache_prompt", exist_ok=True)
    np.save("./cache_results/" + name + '.npy', new_color)
    np.save("./cache_prompt/" + name + '.npy', PROMPT)
    return new_color, PROMPT

def create_box(prompt):
    """Build the twelve plotly line traces outlining an axis-aligned box.

    Args:
        prompt: Six values ``(x_min, y_min, z_min, x_max, y_max, z_max)``.

    Returns:
        A list of ``go.Scatter3d`` line traces, one per box edge. Only the
        first trace is named "Box Prompt" and shown in the legend, so the box
        appears there once.
    """
    x_min, y_min, z_min, x_max, y_max, z_max = tuple(prompt)

    # Corners 0-3 lie on the z_min face, corners 4-7 on the z_max face.
    corners = np.array([
        [x_min, y_min, z_min],
        [x_max, y_min, z_min],
        [x_max, y_max, z_min],
        [x_min, y_max, z_min],
        [x_min, y_min, z_max],
        [x_max, y_min, z_max],
        [x_max, y_max, z_max],
        [x_min, y_max, z_max],
    ])

    bottom_face = [(0, 1), (1, 2), (2, 3), (3, 0)]
    top_face = [(4, 5), (5, 6), (6, 7), (7, 4)]
    vertical = [(0, 4), (1, 5), (2, 6), (3, 7)]

    traces = []
    for position, (a, b) in enumerate(bottom_face + top_face + vertical):
        is_first = position == 0
        segment = go.Scatter3d(
            x=[corners[a, 0], corners[b, 0]],
            y=[corners[a, 1], corners[b, 1]],
            z=[corners[a, 2], corners[b, 2]],
            mode='lines',
            line=dict(color='rgb(220, 20, 60)', width=6),
            name="Box Prompt" if is_first else "",
            showlegend=is_first,
        )
        traces.append(segment)
    return traces