from flask import Flask, request, jsonify
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
import numpy as np
import pandas as pd
import networkx as nx
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, classification_report, roc_curve
from sklearn.model_selection import train_test_split
from pathlib import Path
from datetime import datetime
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from preprocessing_test import Preprocessor
from src.model import *
from main import start_pipelines

# WSGI application object; the /predict route below is registered on it.
app = Flask(import_name=__name__)

# Fallback value for every input column. ``predict`` uses it when a request omits
# a field, sends "" / null, or sends a value that cannot be coerced to the type
# listed in ``expected_types``. Keys follow the same order as ``expected_types``.
default_values = dict(
    # --- review ---
    review_id='KU_O5udG6zpxOg-VcAEodg',
    user_id='mh_-eMZ6K5RLWhZyISBhwA',
    business_id='XQfwVwDr-v0ZS3_CbbE5Xw',
    review_stars=0,
    review_useful=0,
    review_funny=0,
    review_cool=0,
    review_text='It was a moderate experience',
    review_date=1531001351000,  # presumably epoch milliseconds — confirm against Preprocessor
    # --- business ---
    business_name='Coffe at LA',
    address='1460 LA',
    city='LA',
    state='CA',
    postal_code='00000',
    latitude=0.0,
    longitude=0.0,
    business_stars=0.0,
    business_review_count=0,
    is_open=0,
    attributes='{}',  # JSON-encoded object
    categories='Restaurants',
    hours='{"Monday": "7:0-20:0", "Tuesday": "7:0-20:0", "Wednesday": "7:0-20:0", "Thursday": "7:0-20:0", "Friday": "7:0-21:0", "Saturday": "7:0-21:0", "Sunday": "7:0-21:0"}',  # JSON-encoded object
    # --- user ---
    user_name='default_user',
    user_review_count=0,
    yelping_since='2023-01-01 00:00:00',
    user_useful=0,
    user_funny=0,
    user_cool=0,
    elite='2024,2025',
    friends='',
    fans=0,
    average_stars=0.0,
    compliment_hot=0,
    compliment_more=0,
    compliment_profile=0,
    compliment_cute=0,
    compliment_list=0,
    compliment_note=0,
    compliment_plain=0,
    compliment_cool=0,
    compliment_funny=0,
    compliment_writer=0,
    compliment_photos=0,
    # --- check-in / tip aggregates ---
    checkin_date='2023-01-01 00:00:00',
    tip_compliment_count=0.0,
    tip_count=0.0,
)

# Expected Python type for every input column. ``predict`` coerces each request
# field with it, and the key order fixes the column order of the input DataFrame.
# ``attributes`` and ``hours`` map to ``dict`` but are never passed through
# ``dict()``: predict validates them as JSON objects and keeps them as JSON
# strings, so ``dict`` here only supplies the type name shown in warnings.
expected_types = dict(
    # --- review ---
    review_id=str,
    user_id=str,
    business_id=str,
    review_stars=int,
    review_useful=int,
    review_funny=int,
    review_cool=int,
    review_text=str,
    review_date=int,
    # --- business ---
    business_name=str,
    address=str,
    city=str,
    state=str,
    postal_code=str,
    latitude=float,
    longitude=float,
    business_stars=float,
    business_review_count=int,
    is_open=int,
    attributes=dict,
    categories=str,
    hours=dict,
    # --- user ---
    user_name=str,
    user_review_count=int,
    yelping_since=str,
    user_useful=int,
    user_funny=int,
    user_cool=int,
    elite=str,
    friends=str,
    fans=int,
    average_stars=float,
    compliment_hot=int,
    compliment_more=int,
    compliment_profile=int,
    compliment_cute=int,
    compliment_list=int,
    compliment_note=int,
    compliment_plain=int,
    compliment_cool=int,
    compliment_funny=int,
    compliment_writer=int,
    compliment_photos=int,
    # --- check-in / tip aggregates ---
    checkin_date=str,
    tip_compliment_count=float,
    tip_count=float,
)

# Request-flag values that select each branch of /predict. Membership uses ``==``
# (tuple ``in``), so e.g. the integer 1 also matches True.
_TRAIN_FLAG_VALUES = (True, 'true', 'True')
# "test" now also accepts 'true' like "train" does; the historical 'test' string
# is kept for backward compatibility.
_TEST_FLAG_VALUES = (True, 'true', 'True', 'test')

# Hugging Face Hub location of the inference checkpoint.
_MODEL_REPO_ID = "Askhedi/graphformermodel"
_MODEL_FILENAME = "model_GraphformerModel_latest.pth"

# Widths of the consecutive feature slices assigned to each node type. Assumes the
# Preprocessor emits user, business, then review features in that column order
# once the id columns are dropped — TODO confirm against Preprocessor.
_USER_FEAT_DIM = 14
_BUSINESS_FEAT_DIM = 8
_REVIEW_FEAT_DIM = 16
_MIN_FEATURE_COLUMNS = _USER_FEAT_DIM + _BUSINESS_FEAT_DIM + _REVIEW_FEAT_DIM

# Columns whose value must be a JSON object; stored as a JSON string for the Preprocessor.
_JSON_OBJECT_COLUMNS = ('attributes', 'hours')

# Holds at most one loaded model, keyed by (resolved checkpoint path, device).
_MODEL_CACHE = {}


def _load_model(device):
    """Return the inference model on *device*, in eval mode.

    ``hf_hub_download`` is still called on every request, exactly as before. Its
    returned path lives in a per-revision snapshot directory of the hub cache, so
    a new upload yields a new path and triggers a reload. Only model construction
    and ``load_state_dict`` are skipped while the path and device are unchanged.
    NOTE(review): the instance is now shared across requests; this assumes
    HeteroGraphormer's forward pass keeps no per-request state on the module.
    """
    model_path = hf_hub_download(repo_id=_MODEL_REPO_ID, filename=_MODEL_FILENAME)
    key = (str(model_path), str(device))
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = HeteroGraphormer(hidden_dim=64, output_dim=1, edge_dim=4).to(device)
        # SECURITY NOTE(review): torch.load unpickles the checkpoint and can execute
        # code embedded in the downloaded file. Since only a state_dict is expected,
        # consider weights_only=True (torch >= 1.13).
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        _MODEL_CACHE.clear()  # drop any previous checkpoint before caching the new one
        _MODEL_CACHE[key] = model
    return model


def _json_object_string(value):
    """Return *value* as a JSON string that encodes an object.

    A dict is serialized with ``json.dumps``; a string that decodes to a dict is
    returned unchanged. Anything else raises ValueError (``json.JSONDecodeError``,
    itself a ValueError, propagates for malformed strings).
    """
    if isinstance(value, dict):
        return json.dumps(value)
    if isinstance(value, str) and isinstance(json.loads(value), dict):
        return value
    raise ValueError("expected a JSON object")


def _parse_input_row(data):
    """Coerce request fields into one input row following ``expected_types``.

    Returns ``(row, input_warnings)``. A field that is missing, ``""`` or ``None``
    takes its ``default_values`` entry; a field that cannot be coerced also takes
    the default and adds a warning message. Keys are inserted in
    ``expected_types`` order, which fixes the DataFrame column order.
    """
    row = {}
    input_warnings = []
    for col, expected_type in expected_types.items():
        default = default_values[col]
        value = data.get(col, default)
        try:
            if value == "" or value is None:
                row[col] = default
            elif col in _JSON_OBJECT_COLUMNS:
                # Accept both a JSON object (what JSON clients naturally send) and a
                # JSON-encoded string; a dict used to be rejected and defaulted.
                row[col] = _json_object_string(value)
            else:
                row[col] = expected_type(value)
        except (ValueError, TypeError):  # json.JSONDecodeError is a ValueError
            row[col] = default
            input_warnings.append(f"Invalid input for '{col}' (expected {expected_type.__name__}), using default value: {default}")
    return row, input_warnings


def _build_model_inputs(processed_df, device):
    """Build the single-review heterogeneous graph and its encodings.

    Only the first row of *processed_df* supplies node features. Returns the
    positional model arguments ``(graph, spatial_encoding, centrality_encoding,
    node_type_map, time_since_dict, edge_features_dict)``.

    Raises:
        ValueError: if the preprocessed frame has no rows or fewer than
            ``_MIN_FEATURE_COLUMNS`` feature columns.
    """
    features = torch.tensor(
        processed_df.drop(columns=['user_id', 'review_id', 'business_id']).values,
        dtype=torch.float,
        device=device,
    )
    if features.shape[0] < 1 or features.shape[1] < _MIN_FEATURE_COLUMNS:
        raise ValueError(
            f"Preprocessed features have shape {tuple(features.shape)}; expected at least "
            f"1 row and {_MIN_FEATURE_COLUMNS} columns"
        )
    time_since_user = torch.tensor(processed_df['time_since_last_review_user'].values, dtype=torch.float, device=device)
    time_since_business = torch.tensor(processed_df['time_since_last_review_business'].values, dtype=torch.float, device=device)

    # One user, one business and one review node, each with local index 0.
    user_indices = torch.tensor([0], dtype=torch.long, device=device)
    business_indices = torch.tensor([0], dtype=torch.long, device=device)
    review_indices = torch.tensor([0], dtype=torch.long, device=device)

    user_end = _USER_FEAT_DIM
    business_end = user_end + _BUSINESS_FEAT_DIM
    review_end = business_end + _REVIEW_FEAT_DIM

    graph = HeteroData()
    # Slicing row 0 as 0:1 keeps a (1, dim) node-feature matrix per node type.
    graph['user'].x = features[0:1, :user_end].clone()
    graph['business'].x = features[0:1, user_end:business_end].clone()
    graph['review'].x = features[0:1, business_end:review_end].clone()
    graph['user', 'writes', 'review'].edge_index = torch.stack([user_indices, review_indices], dim=0)
    graph['review', 'about', 'business'].edge_index = torch.stack([review_indices, business_indices], dim=0)

    # networkx view of the same graph for the structural encodings:
    # node ids 0 / 1 / 2 are user / business / review.
    node_type_map = {0: 'user', 1: 'business', 2: 'review'}
    nx_graph = nx.DiGraph()
    nx_graph.add_nodes_from(node_type_map)
    nx_graph.add_edge(0, 2)  # user -> review
    nx_graph.add_edge(2, 1)  # review -> business

    num_nodes = len(node_type_map)
    # Directed shortest-path distance; the diagonal is 0, unreachable pairs stay +inf.
    spatial_encoding = torch.full((num_nodes, num_nodes), float('inf'), device=device)
    for src, lengths in nx.all_pairs_shortest_path_length(nx_graph):
        for dst, dist in lengths.items():
            spatial_encoding[src, dst] = dist

    # Total (in + out) degree per node, as a (num_nodes, 1) column.
    centrality_encoding = torch.tensor(
        [nx_graph.degree(i) for i in range(num_nodes)], dtype=torch.float, device=device
    ).view(-1, 1)

    user_writes_edge = graph['user', 'writes', 'review'].edge_index
    review_about_edge = graph['review', 'about', 'business'].edge_index
    edge_features_dict = {
        ('user', 'writes', 'review'): create_temporal_edge_features(
            time_since_user[user_writes_edge[0]], time_since_user[user_writes_edge[1]],
            # NOTE(review): both id arguments index with user_writes_edge[0]; the
            # second may have been meant as [1]. Every index is 0 in this graph, so
            # the values are identical either way — confirm intent.
            user_indices[user_writes_edge[0]], user_indices[user_writes_edge[0]],
        ),
        ('review', 'about', 'business'): create_temporal_edge_features(
            time_since_business[review_about_edge[0]], time_since_business[review_about_edge[1]],
            torch.zeros_like(review_about_edge[0]), torch.zeros_like(review_about_edge[0]),
        ),
    }

    # Shape (1,) float tensor holding the first user's time-since-last-review.
    time_since_dict = {'user': time_since_user[:1].clone()}

    return graph, spatial_encoding, centrality_encoding, node_type_map, time_since_dict, edge_features_dict


@app.route('/predict', methods=['POST'])
def predict():
    """Run the training pipelines, or classify one review as fake / not fake.

    Expects a JSON object body:
      * ``"train"`` truthy flag: runs ``start_pipelines`` with the optional
        ``"train_size"`` (default 0.1) and reports success.
      * otherwise ``"test"`` truthy flag: builds one input row from the other
        fields (see ``expected_types`` / ``default_values``), runs the model and
        returns ``{'warnings', 'prediction', 'probability'}``.

    Returns 400 for a missing or non-object body, a non-numeric ``train_size``,
    or when neither flag is set; 500 with ``{'error': ...}`` on any other failure.
    """
    try:
        # silent=True: a missing or malformed JSON body yields None (-> 400) rather
        # than a Flask HTTP error that the broad handler below reported as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'Request must contain JSON data'}), 400
        if not isinstance(data, dict):
            return jsonify({'error': 'Request JSON must be an object'}), 400

        train = data.get('train', False)
        test = data.get('test', False)

        if train in _TRAIN_FLAG_VALUES:
            try:
                train_size = float(data.get('train_size', 0.1))
            except (TypeError, ValueError):
                return jsonify({'error': "'train_size' must be a number"}), 400
            # NOTE(review): training runs synchronously inside the request, so long
            # runs may exceed client / proxy timeouts.
            start_pipelines(train_size=train_size)
            logger.info("PIPELINES FINISHED SUCCESSFULLY")
            return jsonify({
                'message': 'Training pipelines executed successfully',
                'train_size': train_size
            }), 200

        if test not in _TEST_FLAG_VALUES:
            return jsonify({
                'error': 'Either "train" or "test" must be set to true'
            }), 400

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = _load_model(device)

        row, input_warnings = _parse_input_row(data)
        processed_df = Preprocessor(pd.DataFrame([row])).run_pipeline()
        logger.info("PREPROCESSING COMPLETED VALUES ARE {}", processed_df)

        model_inputs = _build_model_inputs(processed_df, device)
        with torch.no_grad():
            out = model(*model_inputs)

        # NOTE(review): the 0.5 threshold assumes the model already returns a
        # probability (sigmoid applied inside HeteroGraphormer) — confirm; with raw
        # logits the threshold should be 0 or torch.sigmoid should be applied here.
        prob = float(out.squeeze().item())
        return jsonify({
            'warnings': input_warnings,
            'prediction': 'Fake' if prob > 0.5 else 'Not Fake',
            'probability': prob
        }), 200

    except Exception as e:
        # Top-level request boundary: keep the JSON error contract, but record the
        # traceback server-side instead of discarding it.
        logger.exception("Prediction request failed")
        return jsonify({'error': str(e)}), 500