import logging
import os

# Point the Hugging Face cache at a custom directory. This must be set before
# transformers is imported, because the library resolves the cache path at import time.
os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'

import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer, pipeline

logging.basicConfig(
    format='%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s',
    level=logging.DEBUG,
    datefmt='%Y-%m-%d %H:%M:%S'
)
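
# Zero-shot classification pipeline loaded from a local checkpoint; use_fast=False
# selects the slow (pure-Python) tokenizer implementation.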
classifier = pipeline("zero-shot-classification", model="models/classificator", use_fast=False)


def mean_pooling(model_output, attention_mask):
    """Average token embeddings into a single sentence embedding, ignoring padding tokens."""
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum embeddings of non-padding tokens and divide by their count (clamped to avoid division by zero).
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# Tokenizer and encoder for sentence embeddings (all-MiniLM-L6-v2), used by the /similarity endpoint.
tokenizer = AutoTokenizer.from_pretrained('models/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('models/all-MiniLM-L6-v2')

app = FastAPI()


class RequestData(BaseModel):
    """Request body for both endpoints: a sequence to score against candidate labels."""
    multiLabel: bool
    sequence: str
    labels: list[str]


class ResponseData(BaseModel):
    """Response body: the input sequence with one score per label."""
    sequence: str
    labels: list[str]
    scores: list[float]


def classify(data: RequestData):
    """Run zero-shot classification; the pipeline returns a dict with sequence, labels and scores."""
    return classifier(data.sequence, data.labels, multi_label=data.multiLabel)


def similarity(data: RequestData):
    """Score each label by cosine similarity between its embedding and the sequence embedding."""
    sentences = [data.sequence]
    sentences.extend(data.labels)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Encode all sentences in one batch; no gradients are needed for inference.
    with torch.no_grad():
        model_output = model(**encoded_input)

    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # L2-normalize so that the dot product below equals cosine similarity.
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # Row 0 is the sequence; the remaining rows are the labels.
    similarities = sentence_embeddings[:1] @ sentence_embeddings[1:].T
    return similarities.tolist()[0]


@app.post("/classify", response_model=ResponseData, tags=["Classificator"])
async def classify_text(data: RequestData):
    result = classify(data)
    logging.info(result)
    return result


@app.post("/similarity", response_model=ResponseData, tags=["Similarity"])
async def classify_text(data: RequestData):
    result = similarity(data)
    logging.info(result)
    return ResponseData.model_validate({
        "sequence": data.sequence,
        "labels": data.labels,
        "scores": result
    })


@app.get("/ping", tags=["TEST"])
async def ping():
    return "pong"


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)
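
# Alternatively, the app can be served with the uvicorn CLI (assuming this file is
# saved as main.py): uvicorn main:app --host 127.0.0.1 --port 8000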