from fastapi import FastAPI, Request
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
from pydantic import BaseModel
from typing import Optional
from sklearn.feature_extraction.text import CountVectorizer
import yake


# ASGI application instance (served by e.g. uvicorn).
app = FastAPI()


class InputText(BaseModel):
    """Request payload shared by all POST endpoints of this service."""

    # Raw input text to analyse.
    text: str
    # Minimum score a prediction must exceed to appear in the response.
    threshold: float = 0.0


# Multilingual sentiment classifier (labels: negative / neutral / positive).
model_name = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
sentiment_model = AutoModelForSequenceClassification.from_pretrained(model_name)
sentiment_tokenizer = AutoTokenizer.from_pretrained(model_name)
# Register a 4th label for the synthetic "mixed" logit that the
# /sentiment_score endpoint appends to the model's raw logits.
sentiment_model.config.id2label[3] = "mixed"

# 51-language identification model. NOTE: `model_name` is reused (rebound)
# here; after import it refers to the language model, not the sentiment one.
model_name = 'qanastek/51-languages-classifier'
language_model = AutoModelForSequenceClassification.from_pretrained(model_name)
language_tokenizer = AutoTokenizer.from_pretrained(model_name)


# YAKE keyword-extractor configuration.
language = "id"  # ISO-639-1 code: Indonesian
max_ngram_size = 3  # keyphrases up to trigrams
deduplication_threshold = 0.6  # similarity above which candidates are deduplicated
deduplication_algo = 'seqm'  # sequence-matcher based deduplication
windowSize = 3  # co-occurrence window size
numOfKeywords = 20  # maximum number of candidates returned

# NOTE(review): `kw_extractor` is currently only referenced by the
# commented-out /key_phrase_extraction implementation below; the live
# endpoint uses the NER pipeline instead. (yake's parameter really is
# spelled `windowsSize`.)
kw_extractor = yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold, dedupFunc=deduplication_algo, windowsSize=windowSize, top=numOfKeywords, features=None)


# Indonesian BERT token-classification (NER) pipeline. "simple" aggregation
# merges sub-word tokens into whole-entity spans with an aggregated score,
# exposing `word`, `entity_group` and `score` per entity.
ner_model = "syafiqfaray/indobert-model-ner"
ner = pipeline(
    "ner", 
    ner_model, 
    aggregation_strategy="simple",
)



@app.get("/")
def greet_json():
    """Health-check root endpoint returning a static greeting."""
    greeting = {"Hello": "World!"}
    return greeting


@app.post("/key_phrase_extraction")
async def key_phrase_extraction(inp: InputText):
    """Extract key phrases from ``inp.text`` via the NER pipeline.

    Keeps entities scoring above ``inp.threshold``, drops the "CRD"
    (cardinal/number) group, and returns ``{"label", "score"}`` dicts
    in pipeline order.
    """
    results = []
    for entity in ner(inp.text):
        if entity["score"] > inp.threshold and entity["entity_group"] != "CRD":
            results.append({"label": entity["word"], "score": float(entity["score"])})
    return results

# @app.post("/key_phrase_extraction")
# async def key_phrase_extraction(inp:InputText):
#     def merge_keyphrases(keyphrases):
#         new_merged = keyphrases
#         while True:
#           merged = [new_merged[0]]
#           for i in range(1, len(keyphrases)):
#               keys = keyphrases[i]
#               keys_prev = keyphrases[i-1]
#               label = keys["label"]
#               score = keys["score"]
#               vectorizer = CountVectorizer(ngram_range=( 1,len(label.split(" ")) ), lowercase=False)
#               analyzer = vectorizer.build_analyzer()
#               for key in analyzer(label)[::-1]:
#                   key_prev = keys_prev["label"][::-1]
#                   if key == key_prev[:len(key)][::-1].strip():
#                       label = key_prev[len(key):][::-1].strip() + " " + label#.replace(key, "")
#                       score = max(keys_prev["score"],keys["score"])
#                       merged.pop()
#                       break
#               merged.append({"label":label.strip(), "score":score})
#           if new_merged == merged:
#             break
#           else:
#             new_merged = merged
#         return merged
    
#     keywords = kw_extractor.extract_keywords(inp.text)
    
#     return merge_keyphrases([{"label":key[0], "score":1-key[1]} for key in keywords if 1-key[1]>inp.threshold])


@app.post("/language_detection")
async def language_detection(inp: InputText):
    """Classify the language of ``inp.text`` with the 51-language model.

    Returns ``{"label", "score"}`` dicts for every class whose score
    reaches ``inp.threshold``, sorted by descending score.
    """
    encoded = language_tokenizer(inp.text, return_tensors='pt')
    with torch.no_grad():
        logits = language_model(**encoded).logits

    # NOTE(review): this applies a sigmoid (independent per-class scores),
    # not a softmax, despite the original variable name — confirm intended.
    probs = torch.sigmoid(logits)

    # Boolean mask of classes meeting the threshold.
    keep = probs >= inp.threshold

    # Scores that survived the threshold, then their descending order.
    kept_scores = probs[keep]
    order = torch.argsort(kept_scores, descending=True)

    # Map the sorted positions back to original class indices.
    class_ids = torch.nonzero(keep, as_tuple=True)[1][order]

    return [
        {
            "label": language_model.config.id2label[cid.tolist()],
            "score": probs[0, cid].tolist(),
        }
        for cid in class_ids
    ]


@app.post("/sentiment_score")
async def sentiment_score(inp: InputText):
    """Score ``inp.text`` with the sentiment model plus a synthetic "mixed" class.

    Appends a heuristic 4th logit (labelled "mixed" via the patched
    ``id2label``) to the model's negative/neutral/positive logits, then
    returns all labels with their softmax scores, sorted descending.
    """
    text = inp.text
    # Truncate raw characters to stay well within the model's sequence budget.
    inputs = sentiment_tokenizer(text[:2500], return_tensors='pt')

    with torch.no_grad():
        logits = sentiment_model(**inputs).logits

    # Shift logits so the neutral logit (index 1) is anchored at/above zero
    # before the "mixed" heuristic below uses it.
    logits = logits + logits[0, 1].abs()

    # Heuristic "mixed" logit: large when negative (index 0) and positive
    # (index -1) logits are close in magnitude, damped by how dominant the
    # shifted neutral logit is relative to the polar logits.
    # NOTE(review): empirically tuned formula (incl. floor division) — the
    # exact expression is preserved as-is; confirm against original intent.
    logits = torch.cat(
        (
            logits,
            (
                1 - torch.abs(logits[0, 0] - logits[0, -1])
                * (2 + (logits[0, 1] // torch.max(torch.abs(logits[0, ::2]))))
            ).unsqueeze(0).unsqueeze(0),
        ),
        dim=-1,
    )

    softmax = torch.nn.functional.softmax(
        logits, 
        dim=-1
    )

    return [{"label":sentiment_model.config.id2label[predicted_class_id.tolist()], "score":softmax[0, predicted_class_id].tolist()} for predicted_class_id in softmax.argsort(dim=-1, descending=True)[0]]