from typing import Iterable

import cv2
import gradio as gr
import numpy as np
import torch
import intel_extension_for_pytorch as ipex
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from PIL import Image
from torch import nn
from torchvision import transforms

def conv(in_channels, out_channels):
    # Reflection-padded 3x3 conv + BatchNorm + ReLU (not used by ResnUnet below;
    # kept for reference).
    return nn.Sequential(
        nn.ReflectionPad2d(1),
        nn.Conv2d(in_channels, out_channels, kernel_size=3),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )
class resconv(nn.Module):
    """Residual block: two reflection-padded 3x3 convs with InstanceNorm.

    The skip connection adds the input directly, so it is only valid when
    in_features == out_features (which is how ResnUnet uses it).
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, out_features, 3),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(out_features, out_features, 3),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return x + self.block(x)

def up_conv(in_channels, out_channels):
    # Transposed-conv 2x upsampling block (not used by ResnUnet below; kept for reference).
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

class ResnUnet(nn.Module):
    """ResNet-style encoder/decoder producing one logit map per class."""
    def __init__(self, out_channels=32):
        super().__init__()
        out_features = 64
        # Encoder: 7x7 stem + three stride-2 stages, each followed by a residual
        # block. The original code passed `out_features` to every InstanceNorm2d
        # regardless of the conv width; the channel counts are corrected below.
        # InstanceNorm2d is affine-free by default (no learned parameters), so
        # this fix does not affect loading the saved weights.
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2),
        ]
        model += [resconv(out_features, out_features)]
        model += [
            nn.Conv2d(out_features, out_features * 2, 3, stride=2, padding=1),
            nn.InstanceNorm2d(out_features * 2),
            nn.ReLU(inplace=True),
        ]
        model += [resconv(out_features * 2, out_features * 2)]
        model += [
            nn.Conv2d(out_features * 2, out_features * 4, 3, stride=2, padding=1),
            nn.InstanceNorm2d(out_features * 4),
            nn.ReLU(inplace=True),
        ]
        model += [resconv(out_features * 4, out_features * 4)]
        model += [
            nn.Conv2d(out_features * 4, out_features * 8, 3, stride=2, padding=1),
            nn.InstanceNorm2d(out_features * 8),
            nn.ReLU(inplace=True),
        ]
        model += [resconv(out_features * 8, out_features * 8)]
        out_features *= 8
        input_features = out_features
        # Decoder: four 2x upsampling stages back to the input resolution.
        for _ in range(4):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(input_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            input_features = out_features
        model += [nn.ReflectionPad2d(3), nn.Conv2d(out_features, out_channels, 7)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
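
# Quick shape sanity check (illustrative, commented out): with the 224x224 inputs
# used below, the encoder bottoms out at 14x14 and the four 2x upsampling stages
# restore 224x224, so the output is (1, 32, 224, 224) -- one logit map per class.
# with torch.no_grad():
#     print(ResnUnet()(torch.randn(1, 3, 224, 224)).shape)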


model = ResnUnet().to("cpu")

# Load the state_dict; the 'module.' prefix on its keys indicates the weights
# were saved from a DataParallel-wrapped model.
state_dict = torch.load("real_model2_onnx_compat.pt", map_location="cpu")

# Strip the 'module.' prefix so the keys match the bare model.
new_state_dict = {key.replace("module.", ""): value for key, value in state_dict.items()}

model.load_state_dict(new_state_dict)
model.eval()

# CPU-side optimisation: IPEX kernel optimisations, then torch.compile with the
# IPEX backend.
model = ipex.optimize(model, weights_prepack=False)
model = torch.compile(model, backend="ipex")
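
# Optional warm-up (a sketch): the first call to a torch.compile'd model triggers
# the actual compilation, so one dummy forward pass up front keeps the first user
# request fast. The 1x3x224x224 shape matches what segment_image() feeds the model.
with torch.no_grad():
    model(torch.zeros(1, 3, 224, 224))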
    
# Alternative inference path (kept for reference): run the exported ONNX model
# with OpenVINO Runtime instead of IPEX.
# from openvino.runtime import Core
# core = Core()
# model_ir = core.read_model(model="Davinci_eye.onnx")
# compiled_model_ir = core.compile_model(model=model_ir, device_name="CPU")
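# To regenerate "Davinci_eye.onnx" for that path, the eager model could be
# exported right after load_state_dict()/eval(), before ipex.optimize and
# torch.compile (a sketch; the opset version and tensor names are assumptions,
# not taken from the original script):
# torch.onnx.export(model, torch.zeros(1, 3, 224, 224), "Davinci_eye.onnx",
#                   input_names=["input"], output_names=["logits"], opset_version=13)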

tfms = transforms.Compose([
    transforms.ToTensor(),  # uint8 HWC [0, 255] -> float CHW [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet stats
])
# RGB colour -> class id for the predicted masks.
color_map={
    (251,244,5): 1,
    (37,250,5):2,
    (0,21,209):3,
    (172,21,2): 4,
    (172,21,229): 5,
    (6,254,249): 6,
    (141,216,23):7,
    (96,13,13):8,
    (65,214,24):9,
    (124,3,252):10,
    (214,55,153):11,
    (48,61,173):12,
    (110,31,254):13,
    (249,37,14):14,
    (249,137,254):15,
    (34,255,113):16,
    (169,52,14):17,
    (124,49,176):18,
    (4,88,238):19,
    (115,214,178):20,
    (115,63,178):21,
    (115,214,235):22,
    (63,63,178): 23,
    (130,34,26):24,
    (220,158,161):25,
    (201,117,56):26,
    (121,16,40):27,
    (15,126,0):28,
    (0,50,70):29,
    # NOTE: (20,20,0) is duplicated for class ids 30 and 31. Python keeps the
    # last entry, so class 30 gets no colour in the reversed `colormap` below;
    # the colour originally intended for class 30 is unknown.
    (20,20,0):30,
    (20,20,0):31,
}
colormap = {v: list(k) for k, v in color_map.items()}  # class id -> RGB
items = {
    1: "HarmonicAce_Head",
    2: "HarmonicAce_Body",
    3: "MarylandBipolarForceps_Head",
    4: "MarylandBipolarForceps_Wrist",
    5: "MarylandBipolarForceps_Body",
    6: "CadiereForceps_Head",
    7: "CadiereForceps_Wrist",
    8: "CadiereForceps_Body",
    9: "CurvedAtraumaticGrasper_Head",
    10: "CurvedAtraumaticGrasper_Body",
    11: "Stapler_Head",
    12: "Stapler_Body",
    13: "MediumLargeClipApplier_Head",
    14: "MediumLargeClipApplier_Wrist",
    15: "MediumLargeClipApplier_Body",
    16: "SmallClipApplier_Head",
    17: "SmallClipApplier_Wrist",
    18: "SmallClipApplier_Body",
    19: "SuctionIrrigation",
    20: "Needle",
    21: "Endotip",
    22: "Specimenbag",
    23: "DrainTube",
    24: "Liver",
    25: "Stomach",
    26: "Pancreas",
    27: "Spleen",
    28: "Gallbladder",
    29:"Gauze",
    30:"TheOther_Instruments",
    31:"TheOther_Tissues",


}
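
# Consistency check (illustrative): every class id in `items` should have an RGB
# entry in `colormap`. Class 30 is reported because of the duplicated (20,20,0)
# key noted above.
missing_colours = [k for k in items if k not in colormap]
if missing_colours:
    print(f"Warning: class ids without a colour: {missing_colours}")
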
class Davinci_Eye(Base):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.stone,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )


davincieye = Davinci_Eye()


def convert_mask_to_rgb(pred_mask):
    """Map an (H, W) class-id mask to an (H, W, 3) RGB image via `colormap`."""
    rgb_mask = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
    for k, v in colormap.items():
        rgb_mask[pred_mask == k] = v
    return rgb_mask
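
# Equivalent vectorised variant (a sketch, assuming all class ids stay below 32):
# build a (32, 3) lookup table once and index it with the whole mask, avoiding
# the per-class boolean masking loop above.
# lut = np.zeros((32, 3), dtype=np.uint8)
# for k, v in colormap.items():
#     lut[k] = v
# rgb_mask = lut[pred_mask]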


def segment_image(filepath):
    # gr.Interface passes a single argument (the image filepath); the unused
    # `tag` parameter in the original signature would have raised a TypeError.
    image = cv2.imread(filepath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    # ToTensor already rescales uint8 input to [0, 1], so no manual /255 is needed.
    x = tfms(image)
    with torch.no_grad():
        mask = model(x.unsqueeze(0).float())
    _, pred_mask = torch.max(mask, dim=1)
    pred_mask = pred_mask[0].numpy().astype(np.uint8)
    color_mask = convert_mask_to_rgb(pred_mask)
    masked_image = cv2.addWeighted(image, 0.3, color_mask, 0.8, 0.2)
    # Unique non-background class ids -> human-readable instrument/tissue names.
    pred_keys = np.unique(pred_mask)
    surgery_items = [items[int(k)] for k in pred_keys if k != 0]
    surg = ", ".join(surgery_items)
    return Image.fromarray(masked_image), surg
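
# Local smoke test (optional, commented out) on one of the bundled example frames:
# overlay, labels = segment_image("R001_ch1_video_03_00-29-13-03.jpg")
# overlay.save("overlay_preview.png")
# print(labels)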


demo = gr.Interface(
    fn=segment_image,
    inputs=[gr.Image(type="filepath")],
    outputs=[gr.Image(type="pil"), gr.Text()],
    examples=["R001_ch1_video_03_00-29-13-03.jpg",
              "R002_ch1_video_01_01-07-25-19.jpg",
              "R003_ch1_video_05_00-22-42-23.jpg",
              "R004_ch1_video_01_01-12-22-00.jpg",
              "R005_ch1_video_03_00-19-10-11.jpg",
              "R006_ch1_video_01_00-45-02-10.jpg",
              "R013_ch1_video_03_00-40-17-11.jpg"],
    theme=davincieye.set(loader_color="#65aab1"),
    title="Davinci Eye (Quantized for CPU)",
)
demo.launch()