import pickle
from minisom import MiniSom
import numpy as np
import cv2

import urllib.request
import uuid

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List

class InputData(BaseModel):
  data: str  # image url

app = FastAPI()

# Función para construir el modelo manualmente
def build_model():
  with open('somlucuma.pkl', 'rb') as fid:
    somecoli = pickle.load(fid)
  MM = np.loadtxt('matrizMM.txt', delimiter=" ")
  return somecoli,MM

som,MM = build_model()  # Construir el modelo al iniciar la aplicación


from scipy.ndimage import median_filter
from scipy.signal import convolve2d

def sobel(patron):
  gx = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=np.float32)
  gy = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=np.float32)

  Gx = convolve2d(patron, gx, mode='valid')
  Gy = convolve2d(patron, gy, mode='valid')

  return Gx, Gy

def medfilt2(G, d=3):
  return median_filter(G, size=d)

def orientacion(patron, w):
  Gx, Gy = sobel(patron)
  Gx = medfilt2(Gx)
  Gy = medfilt2(Gy)

  m, n = Gx.shape
  mOrientaciones = np.zeros((m // w, n // w), dtype=np.float32)

  for i in range(m // w):
    for j in range(n // w):
      Gx_patch = Gx[i*w:(i+1)*w, j*w:(j+1)*w]
      Gy_patch = Gy[i*w:(i+1)*w, j*w:(j+1)*w]

      YY = np.sum(2 * Gx_patch * Gy_patch)
      XX = np.sum(Gx_patch**2 - Gy_patch**2)

      mOrientaciones[i, j] = (0.5 * np.arctan2(YY, XX) + np.pi / 2.0) * (18.0 / np.pi)

  return mOrientaciones

def redimensionar(img, h, v):
  return cv2.resize(img, (h, v), interpolation=cv2.INTER_AREA)


def prediction(som, imgurl):
  archivo = f"/tmp/test-{uuid.uuid4()}.jpg"
  urllib.request.urlretrieve(imgurl, archivo)
  Xtest = redimensionar(cv2.imread(archivo),256,256)
  Xtest = np.array(Xtest)
  Xtest = cv2.cvtColor(Xtest, cv2.COLOR_BGR2GRAY)

  orientaciones = orientacion(Xtest, w=14)
  Xtest = Xtest.astype('float32') / 255.0
  orientaciones = orientaciones.reshape(-1)
  return som.winner(orientaciones)

# Ruta de predicción
@app.post("/predict/")
async def predict(data: InputData):
  print(f"Data: {data}")
  global som
  global MM
  try:
    # Convertir la lista de entrada a un array de NumPy para la predicción
    imgurl = data.data
    print(type(data.data))
    w = prediction(som, imgurl)
    return {"prediction": MM[w]}
  except Exception as e:
    raise HTTPException(status_code=500, detail=str(e))