import asyncio
import logging
from typing import List, Optional

import torch
from detoxify import Detoxify
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

class Guardrail:
    """Prompt-injection detector built on ProtectAI's DeBERTa-v3 classifier."""

    def __init__(self):
        checkpoint = "ProtectAI/deberta-v3-base-prompt-injection"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        runtime_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Inputs longer than 512 tokens are truncated by the tokenizer.
        self.classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            device=runtime_device,
            truncation=True,
            max_length=512,
        )

    async def guard(self, prompt):
        """Run the classifier on *prompt* in a worker thread and return its raw output."""
        classifier = self.classifier
        return await run_in_threadpool(classifier, prompt)

    def determine_level(self, label, score):
        """Map a classifier (label, score) pair to a (level, severity name) tuple.

        ``"SAFE"`` always maps to ``(0, "safe")``. Any other label is bucketed
        by score using strict ``>`` comparisons; a score of 0.5 or below
        yields ``(1, "very low")``.
        """
        if label == "SAFE":
            return 0, "safe"
        # Checked from the highest threshold down; the first match wins.
        buckets = (
            (0.9, 4, "high"),
            (0.75, 3, "medium"),
            (0.5, 2, "low"),
        )
        for threshold, level, severity in buckets:
            if score > threshold:
                return level, severity
        return 1, "very low"

class TextPrompt(BaseModel):
    """Request body carrying one prompt string to classify."""

    prompt: str  # text forwarded unchanged to the classifier

class ClassificationResult(BaseModel):
    """Prompt-injection verdict: classifier output plus a derived severity bucket."""

    label: str  # label reported by the classifier pipeline, e.g. "SAFE"
    score: float  # pipeline score for `label`
    level: int  # 0 (safe) through 4 (high), see Guardrail.determine_level
    severity_label: str  # one of "safe", "very low", "low", "medium", "high"

class ToxicityResult(BaseModel):
    """Per-category toxicity scores, as returned by the toxicity endpoint."""

    toxicity: float
    severe_toxicity: float
    obscene: float
    threat: float
    insult: float
    identity_attack: float

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance from a category-to-score mapping.

        Every value is coerced with ``float()`` first, so numeric types that
        are not plain Python floats are accepted.
        """
        coerced = dict((category, float(score)) for category, score in data.items())
        return cls(**coerced)

class TopicBannerClassifier:
    """Zero-shot topic classifier backed by an NLI DeBERTa-v3 model.

    Each candidate label is substituted into ``hypothesis_template`` by the
    pipeline; the pipeline is always invoked with ``multi_label=False``.
    """

    def __init__(self):
        runtime_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classifier = pipeline(
            "zero-shot-classification",
            model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
            device=runtime_device,
        )
        self.hypothesis_template = "This text is about {}"

    async def classify(self, text, labels):
        """Score *text* against the candidate *labels* in a worker thread.

        Returns the pipeline's raw output unchanged.
        """
        call_options = {
            "hypothesis_template": self.hypothesis_template,
            "multi_label": False,
        }
        return await run_in_threadpool(self.classifier, text, labels, **call_options)

class TopicBannerRequest(BaseModel):
    """Request body for zero-shot topic classification."""

    prompt: str  # text to classify
    labels: List[str]  # candidate topics to score the prompt against

class TopicBannerResult(BaseModel):
    """Zero-shot topic classification output.

    ``labels`` and ``scores`` are parallel lists: ``scores[i]`` is the score
    the pipeline assigned to ``labels[i]``.
    """

    sequence: str  # the classified input text
    # Element types are declared (rather than a bare ``list``) so they are
    # validated and published in the OpenAPI schema.
    labels: List[str]
    scores: List[float]

class GuardrailsRequest(BaseModel):
    """Request body for running several guardrails on one prompt."""

    prompt: str  # text every selected guardrail evaluates
    # Codes of the checks to run: "pi" (prompt injection), "tox" (toxicity),
    # "top" (topic banner).
    guardrails: List[str]
    # Candidate topics; required when "top" is selected.
    labels: Optional[List[str]] = None

class GuardrailsResponse(BaseModel):
    """Combined guardrail results.

    A field is left as ``None`` when its guardrail was not requested or
    did not produce a result.
    """

    prompt_injection: Optional[ClassificationResult] = None  # from "pi"
    toxicity: Optional[ToxicityResult] = None  # from "tox"
    topic_banner: Optional[TopicBannerResult] = None  # from "top"

# ASGI application. The three classifiers below are constructed once, at
# import time, and shared by every request handler in this module.
app = FastAPI()
guardrail = Guardrail()  # prompt-injection classifier
toxicity_classifier = Detoxify('original')  # Detoxify model, 'original' variant
topic_banner_classifier = TopicBannerClassifier()  # zero-shot topic classifier

@app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
async def classify_toxicity(text_prompt: TextPrompt):
    """Score ``text_prompt.prompt`` with the Detoxify classifier.

    The blocking ``predict`` call runs in the thread pool so the event loop
    stays responsive.

    Raises:
        HTTPException: 500 if prediction or result conversion fails. The
            original exception is logged with its traceback and chained as
            the cause.
    """
    try:
        scores = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
        return ToxicityResult.from_dict(scores)
    except Exception as e:
        # The HTTP response carries only the message; keep the traceback in the logs.
        logging.getLogger(__name__).exception("Toxicity classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
async def classify_text(text_prompt: TextPrompt):
    """Detect prompt injection in ``text_prompt.prompt``.

    Returns the first prediction's label and score together with the
    severity bucket computed by ``guardrail.determine_level``.

    Raises:
        HTTPException: 500 if classification fails. The original exception
            is logged with its traceback and chained as the cause.
    """
    try:
        predictions = await guardrail.guard(text_prompt.prompt)
        # Pipeline output is indexed as a list of {"label", "score"} dicts.
        top_prediction = predictions[0]
        label = top_prediction['label']
        score = top_prediction['score']
        level, severity_label = guardrail.determine_level(label, score)
        return {"label": label, "score": score, "level": level, "severity_label": severity_label}
    except Exception as e:
        # The HTTP response carries only the message; keep the traceback in the logs.
        logging.getLogger(__name__).exception("Prompt-injection classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
async def classify_topic_banner(request: TopicBannerRequest):
    """Classify ``request.prompt`` against the candidate ``request.labels``.

    Raises:
        HTTPException: 400 if ``labels`` is empty, which is a client error
            and is rejected before the model runs. 500 if classification
            fails; the original exception is logged and chained as the cause.
    """
    if not request.labels:
        raise HTTPException(status_code=400, detail="Labels are required for topic banner classification")

    try:
        result = await topic_banner_classifier.classify(request.prompt, request.labels)
        return {
            "sequence": result["sequence"],
            "labels": result["labels"],
            "scores": result["scores"]
        }
    except Exception as e:
        # The HTTP response carries only the message; keep the traceback in the logs.
        logging.getLogger(__name__).exception("Topic banner classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/api/guardrails", response_model=GuardrailsResponse)
async def evaluate_guardrails(request: GuardrailsRequest):
    """Run the selected guardrails concurrently and merge their results.

    ``request.guardrails`` selects the checks:
    - ``"pi"``: prompt injection.
    - ``"tox"``: toxicity.
    - ``"top"``: topic banner, which requires ``request.labels``.

    Unrecognized codes are ignored, and duplicates run only once.

    This is best-effort. A check that raises is logged, and its response
    field stays ``None``; the other results are still returned.

    Raises:
        HTTPException: 400 if ``"top"`` is selected without labels.
    """
    selected = set(request.guardrails)

    # Validate before creating any coroutine, so a 400 never leaves
    # un-awaited coroutines behind.
    if "top" in selected and not request.labels:
        raise HTTPException(status_code=400, detail="Labels are required for topic banner classification")

    # Each pending check is keyed by the response field it fills. Results are
    # then matched to fields by name, not by their position in the
    # user-ordered (and possibly noisy) request.guardrails list.
    pending = {}
    if "pi" in selected:
        pending["prompt_injection"] = classify_text(TextPrompt(prompt=request.prompt))
    if "tox" in selected:
        pending["toxicity"] = classify_toxicity(TextPrompt(prompt=request.prompt))
    if "top" in selected:
        pending["topic_banner"] = classify_topic_banner(
            TopicBannerRequest(prompt=request.prompt, labels=request.labels)
        )

    outcomes = await asyncio.gather(*pending.values(), return_exceptions=True)

    fields = {}
    for field_name, outcome in zip(pending, outcomes):
        # gather() can also hand back CancelledError, which is a
        # BaseException but not an Exception.
        if isinstance(outcome, BaseException):
            logging.getLogger(__name__).warning("Guardrail %s failed: %r", field_name, outcome)
            continue
        fields[field_name] = outcome

    # Constructing the model validates the endpoint results into their typed fields.
    return GuardrailsResponse(**fields)

if __name__ == "__main__":
    # uvicorn is only needed when the module is executed directly.
    import uvicorn

    # Listen on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")