import gradio as gr
import pandas as pd
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
import torch
from torch import nn
from torch.autograd import Variable

# Small fully-connected GAN used by the financial risk analysis pipeline below
class GANRiskAnalyzer:
    """Small fully-connected GAN trained full-batch on tabular feature rows.

    The generator maps ``input_dim``-dimensional Gaussian noise to
    ``output_dim``-dimensional samples. The discriminator scores
    ``output_dim``-dimensional rows with a probability of being real.

    Args:
        input_dim: Size of the noise vector fed to the generator.
        hidden_dim: Width of the single hidden layer in both networks.
        output_dim: Size of each generated sample; the discriminator expects
            training rows with this many columns.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        # Remembered so the noise fed to the generator always matches its
        # first layer, independently of the width of the training data.
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.generator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh(),  # generated values are bounded to [-1, 1]
        )
        self.discriminator = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # probability that the input row is real
        )
        self.loss = nn.BCELoss()
        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)

    def train(self, data, epochs=100):
        """Run ``epochs`` full-batch adversarial updates on ``data``.

        Args:
            data: 2-D float tensor of real rows with ``output_dim`` columns.
            epochs: Number of discriminator+generator update rounds.
        """
        batch_size = data.size(0)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        for _ in range(epochs):
            # --- Discriminator step: real rows -> 1, generated rows -> 0 ---
            self.discriminator_optimizer.zero_grad()
            real_loss = self.loss(self.discriminator(data), real_labels)

            # Noise must have the generator's input width, not the data width.
            z = torch.randn(batch_size, self.input_dim)
            fake_data = self.generator(z)
            # detach(): the discriminator update must not backprop into the generator
            fake_loss = self.loss(self.discriminator(fake_data.detach()), fake_labels)

            d_loss = real_loss + fake_loss
            d_loss.backward()
            self.discriminator_optimizer.step()

            # --- Generator step: try to make the discriminator output 1 ---
            self.generator_optimizer.zero_grad()
            g_loss = self.loss(self.discriminator(fake_data), real_labels)
            g_loss.backward()
            self.generator_optimizer.step()

    def generate(self, n_samples, input_dim=None):
        """Sample ``n_samples`` rows from the generator.

        Args:
            n_samples: Number of rows to generate.
            input_dim: Noise width. Optional; defaults to the ``input_dim``
                given at construction. Kept as a parameter so existing
                callers that pass it explicitly keep working.

        Returns:
            A NumPy array of shape ``(n_samples, output_dim)``.
        """
        if input_dim is None:
            input_dim = self.input_dim
        # No gradients are needed for sampling.
        with torch.no_grad():
            z = torch.randn(n_samples, input_dim)
            generated_data = self.generator(z)
        return generated_data.detach().numpy()

def _read_uploaded_table(file):
    """Load an uploaded CSV/XLSX file into a DataFrame.

    ``file`` may be an object exposing its path as ``.name`` or a plain path
    string; anything not ending in ``.xlsx`` (case-insensitive) is read as CSV.
    """
    path = str(getattr(file, "name", file))
    if path.lower().endswith(".xlsx"):
        return pd.read_excel(path)
    return pd.read_csv(path, encoding="utf-8", on_bad_lines="skip")


def _encode_target(y, n_bins=5):
    """Convert the label column into integer class codes for classification."""
    if pd.api.types.is_bool_dtype(y):
        return y.astype(int)
    if not pd.api.types.is_numeric_dtype(y):
        # Text / categorical labels -> integer codes
        return y.astype("category").cat.codes
    if pd.api.types.is_integer_dtype(y):
        return y
    if y.nunique() <= n_bins:
        # Float column with no more distinct values than bins (e.g. 0.0/1.0):
        # quantile binning cannot help, so treat each value as its own class.
        return y.astype("category").cat.codes
    # Continuous target: discretize into quantile bins. duplicates="drop"
    # avoids a ValueError when several quantile edges coincide.
    return pd.qcut(y, q=n_bins, labels=False, duplicates="drop")


def analyze_financial_data(file):
    """Train a classifier (and a GAN) on an uploaded table and summarize it.

    The last column of the table is treated as the label; all other columns
    are features. Rows with any missing value are discarded.

    Args:
        file: The uploaded file, either an object with a ``.name`` path
            attribute or a path string.

    Returns:
        On success, a text summary containing the test accuracy. On any
        input problem, a dict of the form ``{"error": "<message>"}``.
    """
    if file is None:
        return {"error": "No file was uploaded."}

    try:
        data = _read_uploaded_table(file)
    except Exception as e:
        return {"error": f"Failed to read file: {str(e)}"}

    if data.empty:
        return {"error": "The uploaded file is empty or has an invalid structure."}

    label_column = data.columns[-1]
    # Drop incomplete rows ONCE on the whole table. Dropping NaNs from the
    # features and the labels separately leaves them with different rows,
    # which breaks train_test_split.
    complete_rows = data.dropna()
    try:
        X = complete_rows.drop(columns=[label_column])
        y = complete_rows[label_column]
    except Exception:
        return {"error": "Invalid data format. Please ensure the last column contains labels."}

    if X.empty or y.empty:
        return {"error": "The data contains missing values or invalid rows after cleaning."}

    if len(X) < 2:
        return {"error": "At least two complete rows are required for analysis."}

    # One-hot encode categorical feature columns
    X = pd.get_dummies(X, drop_first=True)
    if X.shape[1] == 0:
        return {"error": "No usable feature columns were found in the data."}

    y = _encode_target(y)

    X_scaled = StandardScaler().fit_transform(X)

    # Dimensionality reduction; PCA cannot return more components than
    # there are features or samples.
    n_components = min(2, X_scaled.shape[0], X_scaled.shape[1])
    X_pca = PCA(n_components=n_components).fit_transform(X_scaled)

    # Train-Test Split
    X_train, X_test, y_train, y_test = train_test_split(X_pca, y, test_size=0.2, random_state=42)

    if y_train.nunique() < 2:
        return {"error": "The label column must contain at least two distinct classes."}

    # Gradient Boosting Classifier
    model = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, max_depth=5)
    model.fit(X_train, y_train)
    accuracy = accuracy_score(y_test, model.predict(X_test))

    # GAN trained on the reduced features.
    # NOTE(review): the generated samples are not used in the returned text;
    # confirm whether they should feed an anomaly score or be removed.
    gan = GANRiskAnalyzer(input_dim=X_pca.shape[1], hidden_dim=128, output_dim=X_pca.shape[1])
    gan.train(torch.tensor(X_pca, dtype=torch.float32), epochs=200)
    anomalies = gan.generate(n_samples=5, input_dim=X_pca.shape[1])

    insights = f"The analysis reveals an accuracy of {accuracy * 100:.2f}%. "
    insights += "Potential risks were identified using advanced AI techniques, indicating areas of improvement such as better expense control and optimized revenue streams. "
    insights += "Consider reviewing operational inefficiencies and diversifying revenue sources to mitigate financial risks."

    return insights


# Gradio Interface
# One row with two columns: the file upload and "Analyze" button in the first,
# the result textbox in the second.
with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
    gr.Markdown("# **AI Risk Analyst Agent**")
    gr.Markdown("Analyze your financial risks and identify anomalies using AI models.")
    with gr.Row():
        with gr.Column():
            # Upload widget restricted to .csv and .xlsx files
            data_file = gr.File(label="Upload Financial Data (CSV/XLSX)", file_types=[".csv", ".xlsx"])
            submit_button = gr.Button("Analyze")
        with gr.Column():
            output = gr.Textbox(label="Risk Analysis Insights")

    # Clicking the button passes the uploaded file to analyze_financial_data
    # and writes its return value to the textbox.
    # NOTE(review): the handler returns a dict on error paths and a string on
    # success; confirm how the Textbox renders the dict.
    submit_button.click(analyze_financial_data, inputs=data_file, outputs=output)

# NOTE(review): this runs whenever the module is executed or imported; consider
# an `if __name__ == "__main__":` guard if the module is ever imported elsewhere.
interface.launch()