import os
from pathlib import Path
from flask import Flask, render_template, request, jsonify
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import warnings

# Suppress all FutureWarnings process-wide.
warnings.filterwarnings("ignore", category=FutureWarning)

# Point the transformers cache at a writable directory (modify this path if needed).
# NOTE(review): `transformers` is already imported above, and it reads
# TRANSFORMERS_CACHE when it is imported, so this assignment most likely does
# not change where files are cached. Move it above the transformers import
# (or pass cache_dir= to from_pretrained) for it to take effect. Recent
# transformers releases also deprecate this variable in favour of HF_HOME.
os.environ["TRANSFORMERS_CACHE"] = "./cache"

app = Flask(__name__)

# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512  # tokenizer truncation limit, in tokens
# Hugging Face tokenizer identifiers, one per model family.
BERT_TOKENIZER = 'bert-base-uncased'
ROBERTA_TOKENIZER = 'jcblaise/roberta-tagalog-base'
ELECTRA_TOKENIZER = 'google/electra-base-discriminator'

# Class names, indexed by the classifier's predicted class id.
LABELS = ["fake", "real"]

class Classifier:
    """Wrap a fine-tuned sequence-classification model and its tokenizer.

    Args:
        model_path: Directory holding the saved model. It is loaded with
            ``local_files_only=True``, so it is never downloaded.
        device: Torch device string that the model and inputs are moved to.
        tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, model_path, device, tokenizer_name):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Load, then move with .to(): passing device_map= to from_pretrained
        # requires the optional `accelerate` package and fails without it.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            local_files_only=True,
        ).to(device)
        self.model.eval()  # inference only: disable dropout etc.

    def predict(self, text):
        """Make a prediction for a single text.

        Args:
            text: Input string; truncated to ``MAX_LENGTH`` tokens.

        Returns:
            dict with ``'predicted_class'`` (an entry of ``LABELS``) and
            ``'confidence_scores'`` mapping each label to its softmax
            probability.
        """
        inputs = self.tokenizer(
            text,
            truncation=True,
            max_length=MAX_LENGTH,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # One input text -> batch of one, so row 0 holds its class probabilities.
        # NOTE(review): assumes logit index i corresponds to LABELS[i]
        # (id 0 = "fake", 1 = "real"); confirm against the model's config.id2label.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
        predicted_class = torch.argmax(probabilities).item()

        return {
            'predicted_class': LABELS[predicted_class],
            'confidence_scores': dict(zip(LABELS, probabilities.tolist())),
        }

@app.route('/')
def home():
    """Render the ``index.html`` template for the site root."""
    page = render_template('index.html')
    return page

# Client-facing model id -> model directory under huggingface/webapp.
_MODEL_DIRS = {
    'nonaug-bert': 'bert-nonaug',
    'aug-bert': 'bert-aug',
    'nonaug-tagbert': 'tagbert-nonaug',
    'aug-tagbert': 'tagbert-aug',
    'nonaug-electra': 'electra-nonaug',
    'aug-electra': 'electra-aug',
}

# Loaded classifiers keyed by model id. Loading a model from disk on every
# request is slow, so each one is loaded once per process and then reused.
# Keys are restricted to _MODEL_DIRS, so at most len(_MODEL_DIRS) are held.
_CLASSIFIERS = {}


def _get_classifier(model_chosen):
    """Return the cached Classifier for ``model_chosen``, loading it on first use.

    ``model_chosen`` must already be validated as a key of ``_MODEL_DIRS``.
    """
    classifier = _CLASSIFIERS.get(model_chosen)
    if classifier is None:
        # The family ('bert', 'tagbert', 'electra') is the part after the '-'.
        family = model_chosen.split("-")[1]
        tokenizer_name = {
            'bert': BERT_TOKENIZER,
            'tagbert': ROBERTA_TOKENIZER,
            'electra': ELECTRA_TOKENIZER,
        }[family]
        # Resolved relative to the current working directory.
        model_path = Path("huggingface", "webapp", _MODEL_DIRS[model_chosen])
        classifier = Classifier(model_path, DEVICE, tokenizer_name)
        # Concurrent first requests for the same model may each load it; the
        # last assignment wins and the other copies are simply discarded.
        _CLASSIFIERS[model_chosen] = classifier
    return classifier


def _error_response(message, status):
    """Build the JSON error payload and HTTP status returned by /detect."""
    return jsonify({
        'status': 'error',
        'message': message
    }), status


@app.route('/detect', methods=['POST'])
def detect():
    """Classify the posted news text with the requested model.

    Expects a JSON object body ``{"text": <str>, "model": <model id>}``, where
    the model id is a key of ``_MODEL_DIRS``.

    Returns:
        On success: JSON with ``status='success'``, ``prediction`` and
        ``confidence``.
        400 with ``{'status': 'error', 'message': ...}`` for an invalid request.
        500 with the same shape if loading or running the model fails.
    """
    # silent=True returns None instead of raising on a missing/malformed body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return _error_response('Request body must be a JSON object', 400)

    news_text = data.get('text')
    model_chosen = data.get('model')
    app.logger.debug("Requested model: %r", model_chosen)

    if not news_text:
        return _error_response('No text provided', 400)
    if not isinstance(news_text, str):
        return _error_response('Text must be a string', 400)
    # The isinstance check comes first: unhashable JSON values (lists, objects)
    # would make the dict membership test raise TypeError.
    if not isinstance(model_chosen, str) or model_chosen not in _MODEL_DIRS:
        return _error_response(
            f"Unknown model: {model_chosen!r}; expected one of "
            f"{', '.join(_MODEL_DIRS)}",
            400,
        )

    try:
        result = _get_classifier(model_chosen).predict(news_text)
    except Exception as e:  # request boundary: report server-side failures
        app.logger.exception("Prediction with model %r failed", model_chosen)
        return _error_response(str(e), 500)

    app.logger.debug("Confidence scores: %s", result['confidence_scores'])

    if result['predicted_class'] == "fake":
        out = "News Needs Further Validation"
    else:
        out = "News is Real"

    return jsonify({
        'status': 'success',
        'prediction': out,
        'confidence': result['confidence_scores']
    })

if __name__ == '__main__':
    # Werkzeug's interactive debugger can execute arbitrary code, so it can be
    # switched off with FLASK_DEBUG=0/false/no. It stays on by default for
    # local development.
    debug_enabled = os.environ.get('FLASK_DEBUG', '1').strip().lower() not in ('0', 'false', 'no')
    app.run(debug=debug_enabled)