# -*- coding: utf-8 -*-


"""## hugging face funcs"""

import io
import matplotlib.pyplot as plt
import requests
import inflect
from PIL import Image

def load_image_from_url(url, timeout=30):
    """Download an image over HTTP(S) and open it with PIL.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot hang the caller indefinitely.

    Returns:
        A ``PIL.Image.Image`` opened from the streamed response body.

    Raises:
        requests.HTTPError: If the server answers with an error status,
            instead of handing an error page to PIL.
    """
    response = requests.get(url, stream=True, timeout=timeout)
    response.raise_for_status()
    # ``.raw`` is the undecoded socket stream; ask it to undo any
    # Content-Encoding (e.g. gzip) so PIL receives the actual image bytes.
    response.raw.decode_content = True
    return Image.open(response.raw)

def render_results_in_image(in_pil_img, in_results):
    """Draw detection boxes and labels on an image and return the result.

    Args:
        in_pil_img: Image to annotate (anything ``imshow`` accepts).
        in_results: Iterable of prediction dicts, each with a ``'box'`` dict
            (``xmin``, ``ymin``, ``xmax``, ``ymax``), a ``'label'`` and a
            ``'score'`` that is printed multiplied by 100 as a percentage.

    Returns:
        A new ``PIL.Image.Image`` holding the rendered figure (PNG, axes
        hidden, tight bounding box).
    """
    fig = plt.figure(figsize=(16, 10))
    try:
        ax = fig.gca()
        ax.imshow(in_pil_img)

        for prediction in in_results:
            box = prediction['box']
            x, y = box['xmin'], box['ymin']
            w = box['xmax'] - box['xmin']
            h = box['ymax'] - box['ymin']

            ax.add_patch(plt.Rectangle((x, y),
                                       w,
                                       h,
                                       fill=False,
                                       color="green",
                                       linewidth=2))
            ax.text(
                x,
                y,
                f"{prediction['label']}: {round(prediction['score']*100, 1)}%",
                color='red'
            )

        ax.axis("off")

        # Render the figure into an in-memory PNG and read it back with PIL.
        img_buf = io.BytesIO()
        fig.savefig(img_buf, format='png',
                    bbox_inches='tight',
                    pad_inches=0)
        img_buf.seek(0)
        modified_image = Image.open(img_buf)
        # Decode now (Image.open is lazy) so the returned image is complete.
        modified_image.load()
    finally:
        # Close exactly this figure, even if drawing or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

    return modified_image

def summarize_predictions_natural_language(predictions):
    """Summarize detection results as one English sentence.

    Labels are counted in first-seen order. Counts are spelled out and
    nouns pluralized with ``inflect``. The phrases are then joined with
    commas and a final "and", e.g.
    ``"In this image, there are two cats, one dog, and three people."``

    Args:
        predictions: Iterable of dicts, each with a ``'label'`` key.

    Returns:
        The summary sentence. When ``predictions`` is empty, a fixed
        sentence saying nothing was detected.
    """
    from collections import Counter

    p = inflect.engine()
    # Counter is a dict subclass, so iteration keeps first-seen label order.
    summary = Counter(prediction['label'] for prediction in predictions)

    if not summary:
        return "In this image, there are no detected objects."

    # plural_noun returns the label unchanged when count is 1 and handles
    # irregular plurals ("person" -> "people") that appending "s" gets wrong.
    phrases = [
        f"{p.number_to_words(count)} {p.plural_noun(label, count)}"
        for label, count in summary.items()
    ]
    # join yields "a", "a and b", or "a, b, and c".
    return f"In this image, there are {p.join(phrases)}."


##### To ignore warnings #####
import warnings
import logging
from transformers import logging as hf_logging

def ignore_warnings():
    """Silence known noisy warnings emitted while loading models.

    Adds ``warnings`` filters for three specific messages. Configures the
    root logger at ERROR level via the stdlib ``logging.basicConfig``. Sets
    the transformers logger verbosity to ERROR.
    """
    # Import the stdlib module locally. Further down, this file runs
    # ``from transformers.utils import logging``, which rebinds the global
    # name ``logging`` to a different module. The global therefore cannot be
    # trusted to be the stdlib one when this function is called.
    import logging as std_logging

    # Ignore specific Python warnings
    warnings.filterwarnings("ignore", message="Some weights of the model checkpoint")
    warnings.filterwarnings("ignore", message="Could not find image processor class")
    warnings.filterwarnings("ignore", message="The `max_size` parameter is deprecated")

    # Adjust logging for libraries using the logging module
    std_logging.basicConfig(level=std_logging.ERROR)
    hf_logging.set_verbosity_error()

########

import numpy as np
import torch
import matplotlib.pyplot as plt


def show_mask(mask, ax, random_color=False):
    """Overlay a mask on ``ax`` as a translucent RGBA layer.

    Args:
        mask: Array-like whose last two dimensions are (height, width). It
            must hold exactly height * width elements so it can be reshaped.
        ax: Matplotlib axes (anything with an ``imshow`` method).
        random_color: If True, tint with a random RGB colour. Otherwise use
            the fixed blue (30, 144, 255). Alpha is 0.6 either way.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    pixels = mask.reshape(height, width, 1)
    tint = rgba[np.newaxis, np.newaxis, :]
    # (H, W, 1) * (1, 1, 4) broadcasts to an (H, W, 4) RGBA image.
    ax.imshow(pixels * tint)


def show_box(box, ax):
    """Outline one box on ``ax`` as a green, unfilled 2-px rectangle.

    ``box`` is indexed as (x0, y0, x1, y1): two opposite corners, with the
    width and height taken as x1 - x0 and y1 - y0.
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

def show_boxes_on_image(raw_image, boxes):
    """Display ``raw_image`` in a new 10x10 figure with every box outlined.

    Each entry of ``boxes`` is drawn with ``show_box``.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    current_axes = plt.gca()
    for single_box in boxes:
        show_box(single_box, current_axes)
    plt.axis('on')
    plt.show()

def show_points_on_image(raw_image, input_points, input_labels=None):
    """Display ``raw_image`` in a new 10x10 figure with prompt points on top.

    Args:
        raw_image: Image passed to ``plt.imshow``.
        input_points: Point coordinates. Converted with ``np.array`` and
            forwarded to ``show_points``.
        input_labels: Optional per-point labels forwarded to ``show_points``.
            Defaults to all ones.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    points = np.array(input_points)
    labels = (np.ones_like(points[:, 0])
              if input_labels is None
              else np.array(input_labels))
    show_points(points, labels, plt.gca())
    plt.axis('on')
    plt.show()

def show_points_and_boxes_on_image(raw_image,
                                   boxes,
                                   input_points,
                                   input_labels=None):
    """Display ``raw_image`` in a new 10x10 figure with points and boxes.

    Args:
        raw_image: Image passed to ``plt.imshow``.
        boxes: Iterable of boxes, each drawn with ``show_box``.
        input_points: Point coordinates. Converted with ``np.array`` and
            drawn with ``show_points``.
        input_labels: Optional per-point labels for ``show_points``.
            Defaults to all ones.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    ax = plt.gca()

    input_points = np.array(input_points)
    if input_labels is None:
        labels = np.ones_like(input_points[:, 0])
    else:
        labels = np.array(input_labels)
    show_points(input_points, labels, ax)

    for box in boxes:
        show_box(box, ax)
    plt.axis('on')
    plt.show()


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as white-edged star markers.

    Points labelled 1 are drawn green, then points labelled 0 are drawn red.
    Points with any other label are not drawn.

    Args:
        coords: NumPy array of points. Boolean-indexed by ``labels``, then
            column 0 is plotted as x and column 1 as y.
        labels: NumPy array with one label per point.
        ax: Matplotlib axes (anything with a ``scatter`` method).
        marker_size: Marker area, passed to ``scatter`` as ``s``.
    """
    common = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **common)


def fig2img(fig):
    """Render a Matplotlib figure and return it as a ``PIL.Image.Image``.

    The figure is written with ``fig.savefig`` (default options) into an
    in-memory buffer, which PIL then opens from the start.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    buffer.seek(0)
    return Image.open(buffer)


def show_mask_on_image(raw_image, mask, return_image=False):
    """Show ``mask`` overlaid on ``raw_image`` in a new 15x15-inch figure.

    Args:
        raw_image: Image converted with ``np.array`` and drawn first.
        mask: A ``torch.Tensor`` or anything ``torch.Tensor(...)`` accepts.
            A 4-D mask has all its size-1 dimensions squeezed away.
        return_image: If True, also return the rendered figure as an image.

    Returns:
        The rendered ``PIL.Image.Image`` when ``return_image`` is True,
        otherwise None.
    """
    if not isinstance(mask, torch.Tensor):
        mask = torch.Tensor(mask)

    if len(mask.shape) == 4:
        mask = mask.squeeze()

    fig, axes = plt.subplots(1, 1, figsize=(15, 15))
    try:
        mask = mask.cpu().detach()
        axes.imshow(np.array(raw_image))
        show_mask(mask, axes)
        axes.axis("off")

        # Render from this figure before plt.show(). Depending on the
        # backend, show() may close figures, so plt.gcf() afterwards could
        # be a different (even empty) figure.
        image = fig2img(fig) if return_image else None
        plt.show()
    finally:
        # Release the figure so repeated calls do not keep figures open.
        plt.close(fig)

    return image




def show_pipe_masks_on_image(raw_image, outputs, return_image=False):
    """Overlay every mask from a mask-generation result on ``raw_image``.

    Args:
        raw_image: Image converted with ``np.array`` and drawn first.
        outputs: Mapping whose ``"masks"`` entry is iterated. Each mask is
            drawn with ``show_mask`` in a random colour.
        return_image: If True, also return the rendered figure as an image.

    Returns:
        The rendered ``PIL.Image.Image`` when ``return_image`` is True,
        otherwise None.
    """
    # Use a dedicated figure. Drawing on pyplot's *current* figure would
    # paint over whatever an earlier call left open and pile masks up
    # across calls.
    fig, ax = plt.subplots()
    try:
        ax.imshow(np.array(raw_image))
        for mask in outputs["masks"]:
            show_mask(mask, ax=ax, random_color=True)
        ax.axis("off")

        # Render from this figure before plt.show(), which may close figures
        # on some backends.
        image = fig2img(fig) if return_image else None
        plt.show()
    finally:
        plt.close(fig)

    return image

"""## imports"""

from transformers import pipeline
from transformers import SamModel, SamProcessor
from transformers import BlipForImageTextRetrieval
from transformers import AutoProcessor

# Bind transformers' logging module under its own name. Importing it as
# ``logging`` would shadow the stdlib ``logging`` module that earlier code in
# this file (e.g. ``logging.basicConfig`` in ignore_warnings) expects.
from transformers.utils import logging as transformers_logging
transformers_logging.set_verbosity_error()
#ignore_warnings()

import io
import matplotlib.pyplot as plt
import requests
import inflect
from PIL import Image

import os
import gradio as gr

import time

"""# Object detection

## hugging face model ("facebook/detr-resnet-50"). 167MB
"""

# Object-detection pipelines: DETR is the default model, and YOLOS-small is
# the alternative picked when the dropdown value names the "chosen-model".
od_pipe = pipeline(task="object-detection", model="facebook/detr-resnet-50")
chosen_model = pipeline(task="object-detection", model="hustvl/yolos-small")

"""## gradio funcs"""

def get_object_detection_prediction(model_name, raw_image):
    """Gradio handler: detect objects in ``raw_image`` and draw them.

    Args:
        model_name: Dropdown value. Names containing ``"chosen-model"``
            select ``chosen_model`` (YOLOS-small). Anything else, including
            None, selects ``od_pipe`` (DETR).
        raw_image: Uploaded PIL image, or None when nothing was uploaded.

    Returns:
        ``[annotated_image, elapsed_message]`` for the two output components.

    Raises:
        gr.Error: If no image was uploaded, so the UI shows a readable
            message instead of a stack trace.
    """
    if raw_image is None:
        raise gr.Error("Please upload an image first.")

    model = chosen_model if model_name and "chosen-model" in model_name else od_pipe

    # perf_counter is monotonic and high-resolution, unlike time.time.
    start = time.perf_counter()
    pipeline_output = model(raw_image)
    end = time.perf_counter()

    elapsed_result = f'{model_name} object detection elapsed {end-start} seconds'
    print(elapsed_result)
    processed_image = render_results_in_image(raw_image, pipeline_output)
    return [processed_image, elapsed_result]

"""# Image segmentation

## hugging face models: Zigeng/SlimSAM-uniform-77(segmentation) 39MB, Intel/dpt-hybrid-midas(depth) 490MB
"""

# SlimSAM is loaded both as a mask-generation pipeline and as a raw
# model/processor pair for point-prompted segmentation. DPT-hybrid MiDaS is
# the default depth estimator.
_slimsam_checkpoint = "Zigeng/SlimSAM-uniform-77"
hugging_face_segmentation_pipe = pipeline(task="mask-generation", model=_slimsam_checkpoint)
hugging_face_segmentation_model = SamModel.from_pretrained(_slimsam_checkpoint)
hugging_face_segmentation_processor = SamProcessor.from_pretrained(_slimsam_checkpoint)
hugging_face_depth_estimator = pipeline(task="depth-estimation",
                                        model="Intel/dpt-hybrid-midas")

"""## chosen models: facebook/sam-vit-base(segmentation) 375MB, LiheYoung/depth-anything-small-hf(depth) 100MB"""

# The "chosen" alternatives: SAM ViT-base (pipeline plus model/processor
# pair) and Depth Anything small for depth estimation.
chosen_name = "facebook/sam-vit-base"
chosen_segmentation_pipe = pipeline(task="mask-generation", model=chosen_name)
chosen_segmentation_model, chosen_segmentation_processor = (
    SamModel.from_pretrained(chosen_name),
    SamProcessor.from_pretrained(chosen_name),
)
chosen_depth_estimator = pipeline(task="depth-estimation",
                                  model="LiheYoung/depth-anything-small-hf")

"""## gradio funcs"""

# Default point prompt (x=1600, y=700) for segment_image_pretrained;
# presumably nested as [image][point][xy] for SamProcessor(input_points=...).
input_points = [[[1600, 700]]]


def _points_inside_image(raw_image):
    """Return ``input_points``, or the image centre if that point is off-image."""
    size = getattr(raw_image, "size", None)
    if not (isinstance(size, tuple) and len(size) == 2):
        # Not a PIL-style (width, height) image: keep the configured prompt.
        return input_points
    width, height = size
    x, y = input_points[0][0]
    if 0 <= x < width and 0 <= y < height:
        return input_points
    # A prompt outside the image cannot select anything meaningful.
    return [[[width // 2, height // 2]]]


def segment_image_pretrained(model_name, raw_image):
    """Gradio handler: point-prompted SAM segmentation of ``raw_image``.

    Args:
        model_name: Dropdown value. Names containing ``"chosen"`` select the
            chosen SAM model/processor. Anything else, including None,
            selects the SlimSAM pair.
        raw_image: Uploaded PIL image, or None when nothing was uploaded.

    Returns:
        Three rendered overlays (``predicted_mask[:, 0]`` to ``[:, 2]``)
        followed by the elapsed-time message, matching the four outputs.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if raw_image is None:
        raise gr.Error("Please upload an image first.")

    if model_name and "chosen" in model_name:
        processor = chosen_segmentation_processor
        model = chosen_segmentation_model
    else:
        processor = hugging_face_segmentation_processor
        model = hugging_face_segmentation_model

    points = _points_inside_image(raw_image)

    start = time.perf_counter()
    inputs = processor(raw_image,
                       input_points=points,
                       return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Scale the predicted masks back to the original image size.
    predicted_masks = processor.image_processor.post_process_masks(
        outputs.pred_masks,
        inputs["original_sizes"],
        inputs["reshaped_input_sizes"])
    predicted_mask = predicted_masks[0]
    end = time.perf_counter()

    elapsed_result = f'{model_name} pretrained image segmentation elapsed {end-start} seconds'
    print(elapsed_result)

    results = [show_mask_on_image(raw_image, predicted_mask[:, i], return_image=True)
               for i in range(3)]
    results.append(elapsed_result)
    return results

def segment_image(model_name, raw_image):
    """Gradio handler: run a mask-generation pipeline on the whole image.

    Args:
        model_name: Dropdown value. Names containing ``"chosen"`` select
            ``chosen_segmentation_pipe``. Anything else, including None,
            selects the SlimSAM pipeline.
        raw_image: Uploaded PIL image, or None when nothing was uploaded.

    Returns:
        ``[overlay_image, elapsed_message]`` for the two output components.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if raw_image is None:
        raise gr.Error("Please upload an image first.")

    model = hugging_face_segmentation_pipe
    if model_name and "chosen" in model_name:
        print("chosen model used")
        model = chosen_segmentation_pipe

    start = time.perf_counter()
    output = model(raw_image, points_per_batch=32)
    end = time.perf_counter()

    elapsed_result = f'{model_name} raw image segmentation elapsed {end-start} seconds'
    print(elapsed_result)
    return [show_pipe_masks_on_image(raw_image, output, return_image=True), elapsed_result]

def depth_image(model_name, input_image):
    """Gradio handler: estimate depth and return it as a grayscale image.

    Args:
        model_name: Dropdown value. Names containing ``"chosen"`` select
            ``chosen_depth_estimator``. Anything else, including None,
            selects the DPT estimator.
        input_image: Uploaded PIL image, or None when nothing was uploaded.

    Returns:
        ``[depth_image, elapsed_message]``. ``depth_image`` is an 8-bit PIL
        image resized to ``input_image`` and scaled so the largest predicted
        value maps to 255 (all zeros if the prediction has no positive value).

    Raises:
        gr.Error: If no image was uploaded.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    print(model_name)
    depth_estimator = hugging_face_depth_estimator
    if model_name and "chosen" in model_name:
        print("chosen model used")
        depth_estimator = chosen_depth_estimator

    start = time.perf_counter()
    out = depth_estimator(input_image)
    predicted_depth = out["predicted_depth"]
    # Bicubic interpolate needs a 4-D (N, C, H, W) tensor. Add leading
    # singleton dims whether the tensor arrived as (H, W) or (1, H, W).
    while predicted_depth.dim() < 4:
        predicted_depth = predicted_depth.unsqueeze(0)
    prediction = torch.nn.functional.interpolate(
        predicted_depth,
        size=input_image.size[::-1],  # PIL size is (W, H); torch wants (H, W)
        mode="bicubic",
        align_corners=False,
    )
    end = time.perf_counter()

    elapsed_result = f'{model_name} Depth Estimation elapsed {end-start} seconds'
    print(elapsed_result)

    output = prediction.squeeze().numpy()
    peak = np.max(output)
    if peak > 0:
        formatted = (output * 255 / peak).astype("uint8")
    else:
        # Avoid dividing by zero (NaN pixels) for an all-zero prediction.
        formatted = np.zeros(output.shape, dtype="uint8")
    depth = Image.fromarray(formatted)
    return [depth, elapsed_result]

"""# Image retrieval

## hugging face model: Salesforce/blip-itm-base-coco 900MB
"""

# Default image-text matching model and its processor (BLIP, COCO checkpoint).
_coco_itm_checkpoint = "Salesforce/blip-itm-base-coco"
hugging_face_retrieval_model = BlipForImageTextRetrieval.from_pretrained(_coco_itm_checkpoint)
hugging_face_retrieval_processor = AutoProcessor.from_pretrained(_coco_itm_checkpoint)

"""## chosen model: Salesforce/blip-itm-base-flickr 900MB"""

# "Chosen" image-text matching model and its processor (BLIP, Flickr checkpoint).
_flickr_itm_checkpoint = "Salesforce/blip-itm-base-flickr"
chosen_retrieval_model = BlipForImageTextRetrieval.from_pretrained(_flickr_itm_checkpoint)
chosen_retrieval_processor = AutoProcessor.from_pretrained(_flickr_itm_checkpoint)

"""## gradion func"""

def retrieve_image(model_name, raw_image, predict_text):
    """Gradio handler: score how well ``predict_text`` describes ``raw_image``.

    Args:
        model_name: Dropdown value. Names containing ``"chosen"`` select the
            chosen (Flickr) BLIP model/processor. Anything else, including
            None, selects the COCO pair.
        raw_image: Uploaded PIL image, or None when nothing was uploaded.
        predict_text: Text to match against the image.

    Returns:
        ``[match_message, elapsed_message]``. The message reports the
        softmax probability at index 1 of the model's first output.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if raw_image is None:
        raise gr.Error("Please upload an image first.")

    if model_name and "chosen" in model_name:
        processor = chosen_retrieval_processor
        model = chosen_retrieval_model
    else:
        processor = hugging_face_retrieval_processor
        model = hugging_face_retrieval_model

    start = time.perf_counter()
    inputs = processor(images=raw_image,
                       text=predict_text,
                       return_tensors="pt")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        itm_scores = model(**inputs)[0]
    # Stop the clock after the forward pass so the time includes inference.
    end = time.perf_counter()

    elapsed_result = f"{model_name} image retrieval elapsed {end-start} seconds"
    print(elapsed_result)

    itm_score = torch.nn.functional.softmax(itm_scores, dim=1)
    return [
        f"The image and text are matched with a probability of {itm_score[0][1]:.4f}",
        elapsed_result,
    ]

"""# gradio"""

with gr.Blocks() as object_detection_tab:
    gr.Markdown("# Detect objects on image")
    gr.Markdown("Upload an image, choose model, press button.")

    # Model choices. The first one is preselected so the handler never
    # receives None from an untouched dropdown.
    _od_model_choices = ["hugging-face(facebook/detr-resnet-50)",
                         "chosen-model(hustvl/yolos-small)"]

    with gr.Row():
        with gr.Column():
            # Input components
            input_image = gr.Image(label="Upload Image", type="pil")
            model_selector = gr.Dropdown(_od_model_choices,
                                         value=_od_model_choices[0],
                                         label="Select Model")

        with gr.Column():
            # Output image
            elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
            output_image = gr.Image(label="Output Image", type="pil")

    # Process button
    process_btn = gr.Button("Detect objects")

    # Connect the input components to the processing function
    process_btn.click(
        fn=get_object_detection_prediction,
        inputs=[
            model_selector,
            input_image
        ],
        outputs=[output_image, elapsed_result]
    )

with gr.Blocks() as image_segmentation_detection_tab:
    gr.Markdown("# Image segmentation")
    gr.Markdown("Upload an image, choose model, press button.")

    # Model choices. The first entry of each list is preselected so the
    # handlers never receive None from an untouched dropdown.
    _segmentation_choices = ["hugging-face(Zigeng/SlimSAM-uniform-77)",
                             "chosen-model(facebook/sam-vit-base)"]
    _depth_choices = ["hugging-face(Intel/dpt-hybrid-midas)",
                      "chosen-model(LiheYoung/depth-anything-small-hf)"]

    with gr.Row():
        with gr.Column():
            # Input components
            input_image = gr.Image(label="Upload Image", type="pil")
            model_selector = gr.Dropdown(_segmentation_choices,
                                         value=_segmentation_choices[0],
                                         label="Select Model")

        with gr.Column():
            elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
            # Output image
            output_image = gr.Image(label="Segmented image", type="pil")
    with gr.Row():
        with gr.Column():
            segment_btn = gr.Button("Segment image(not pretrained)")

    # Point-prompted segmentation: three candidate masks plus timing.
    with gr.Row():
        elapsed_result_pretrained_segment = gr.Textbox(label="Seconds elapsed", lines=1)
        with gr.Column():
            segment_pretrained_output_image_1 = gr.Image(label="Segmented image by pretrained model", type="pil")
        with gr.Column():
            segment_pretrained_output_image_2 = gr.Image(label="Segmented image by pretrained model", type="pil")
        with gr.Column():
            segment_pretrained_output_image_3 = gr.Image(label="Segmented image by pretrained model", type="pil")
    with gr.Row():
        with gr.Column():
            segment_pretrained_model_selector = gr.Dropdown(_segmentation_choices,
                                                            value=_segmentation_choices[0],
                                                            label="Select Model")
            segment_pretrained_btn = gr.Button("Segment image(pretrained)")

    # Depth estimation.
    with gr.Row():
        with gr.Column():
            depth_output_image = gr.Image(label="Depth image", type="pil")
            elapsed_result_depth = gr.Textbox(label="Seconds elapsed", lines=1)
    with gr.Row():
        with gr.Column():
            depth_model_selector = gr.Dropdown(_depth_choices,
                                               value=_depth_choices[0],
                                               label="Select Model")
            depth_btn = gr.Button("Get image depth")

    segment_btn.click(
        fn=segment_image,
        inputs=[
            model_selector,
            input_image
        ],
        outputs=[output_image, elapsed_result]
    )
    segment_pretrained_btn.click(
        fn=segment_image_pretrained,
        inputs=[
            segment_pretrained_model_selector,
            input_image
        ],
        outputs=[segment_pretrained_output_image_1,
                 segment_pretrained_output_image_2,
                 segment_pretrained_output_image_3,
                 elapsed_result_pretrained_segment]
    )

    depth_btn.click(
        fn=depth_image,
        inputs=[
            depth_model_selector,
            input_image,
        ],
        outputs=[depth_output_image, elapsed_result_depth]
    )

with gr.Blocks() as image_retrieval_tab:
    gr.Markdown("# Check if text describes image")
    gr.Markdown("Upload an image, describe it, choose model, press button.")

    # Model choices. The first one is preselected so the handler never
    # receives None from an untouched dropdown.
    _retrieval_choices = ["hugging-face(Salesforce/blip-itm-base-coco)",
                          "chosen-model(Salesforce/blip-itm-base-flickr)"]

    with gr.Row():
        with gr.Column():
            # Input components
            input_image = gr.Image(label="Upload Image", type="pil")
            text_prediction = gr.TextArea(label="Describe image")
            model_selector = gr.Dropdown(_retrieval_choices,
                                         value=_retrieval_choices[0],
                                         label="Select Model")

        with gr.Column():
            # Output text
            output_result = gr.Textbox(label="Probability result", lines=3)
            elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)

    # Process button (this tab matches text to an image; it does not detect objects)
    process_btn = gr.Button("Check match")

    # Connect the input components to the processing function
    process_btn.click(
        fn=retrieve_image,
        inputs=[
            model_selector,
            input_image,
            text_prediction
        ],
        outputs=[output_result, elapsed_result]
    )

# Combine the three feature tabs into one app; titles pair up by position.
_app_tabs = [object_detection_tab,
             image_segmentation_detection_tab,
             image_retrieval_tab]
_app_tab_titles = ["Object detection",
                   "Image segmentation",
                   "Retrieve image"]

with gr.Blocks() as app:
    gr.TabbedInterface(_app_tabs, _app_tab_titles)

# Start the Gradio app. Always close it afterwards, even if launch() exits
# with an exception (e.g. KeyboardInterrupt from Ctrl+C).
try:
    app.launch(share=True, debug=True)
finally:
    app.close()