import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, classification_report, roc_curve
from sklearn.model_selection import train_test_split
from pathlib import Path
from datetime import datetime
from loguru import logger

# Temporal Edge Features Function
def create_temporal_edge_features(time_since_src, time_since_tgt, user_i, user_j):
    """Encode the time gap between two edge endpoints as a 4-dim feature vector.

    The absolute gap ``|time_since_src - time_since_tgt|`` is interpreted in
    seconds (as the hour/day/week divisors imply) and encoded as:

    * ``sin(gap / hour)``, ``sin(gap / day)``, ``sin(gap / week)``: periodic
      encodings at three time scales;
    * a burst term ``[user_i == user_j] * exp(-gap / day)`` that is non-zero only
      when both user ids match and decays with the gap.

    Args:
        time_since_src: Tensor of times for the edge sources.
        time_since_tgt: Tensor of times for the edge targets (same shape).
        user_i: User ids compared element-wise with ``user_j``.
        user_j: User ids compared element-wise with ``user_i``.

    Returns:
        Float tensor of shape ``(*time_since_src.shape, 4)``.
    """
    seconds_per_hour = 3600
    seconds_per_day = 24 * seconds_per_hour
    seconds_per_week = 7 * seconds_per_day

    gap = torch.abs(time_since_src - time_since_tgt).float()
    periodic_terms = [
        torch.sin(gap / period)
        for period in (seconds_per_hour, seconds_per_day, seconds_per_week)
    ]
    users_match = (user_i == user_j).float()
    burst_term = users_match * torch.exp(-gap / seconds_per_day)
    return torch.stack(periodic_terms + [burst_term], dim=-1)

# Custom multi-head attention with an optional additive attention bias
class CustomMultiheadAttention(nn.Module):
    """Scaled dot-product multi-head attention with an optional additive logit bias.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of heads, each of width ``embed_dim // num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.scale = self.head_dim ** -0.5

    def _split_heads(self, x):
        """Reshape ``(batch, length, embed_dim)`` to ``(batch, num_heads, length, head_dim)``."""
        batch_size, length, _ = x.size()
        return x.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, attn_bias=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: ``(batch, q_len, embed_dim)``.
            key: ``(batch, kv_len, embed_dim)``; ``kv_len`` may differ from ``q_len``.
            value: ``(batch, kv_len, embed_dim)``.
            attn_bias: Optional additive bias on the pre-softmax scores, either
                ``(batch, q_len, kv_len)`` (shared by all heads) or
                ``(batch, num_heads, q_len, kv_len)`` (per head). ``-inf`` entries
                mask the corresponding key out.

        Returns:
            ``(out, attn)``: ``out`` is ``(batch, q_len, embed_dim)``; ``attn`` holds
            the softmax weights, ``(batch, num_heads, q_len, kv_len)``.
        """
        batch_size, q_len, embed_dim = query.size()
        # Split each tensor with its own length so key/value may be longer or
        # shorter than the query (cross-attention).
        q = self._split_heads(self.q_proj(query))
        k = self._split_heads(self.k_proj(key))
        v = self._split_heads(self.v_proj(value))

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if attn_bias is not None:
            if attn_bias.dim() == 3:
                # Broadcast a head-agnostic bias across all heads.
                attn_bias = attn_bias.unsqueeze(1)
            scores = scores + attn_bias
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, q_len, embed_dim)
        return self.out_proj(out), attn

# Graphormer-style classifier over the user/business/review heterogeneous graph
class HeteroGraphormer(nn.Module):
    """Graphormer-style binary classifier over a user/business/review heterogeneous graph.

    Node features of each type are embedded to ``hidden_dim``, concatenated in the
    fixed order user -> business -> review, and processed as one sequence by two
    attention + feed-forward layers. The attention logits are biased by
    ``-spatial_encoding`` (a ``+inf`` distance therefore becomes a ``-inf`` logit
    and masks that pair) plus a learned scalar per graph edge. Each review is
    scored from the concatenation of its own, its author's and its business's
    final embeddings.

    Args:
        hidden_dim: Width of every hidden representation; must be divisible by
            ``num_heads``.
        output_dim: Not used; the classifier always emits one probability per review.
        num_heads: Number of attention heads.
        edge_dim: Width of the raw edge-feature vectors passed to ``forward``.
    """

    def __init__(self, hidden_dim, output_dim, num_heads=4, edge_dim=4):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Input widths are fixed by the feature layout of each node type.
        self.embed_dict = nn.ModuleDict({
            'user': nn.Linear(14, hidden_dim),
            'business': nn.Linear(8, hidden_dim),
            'review': nn.Linear(16, hidden_dim)
        })

        self.edge_proj = nn.Linear(edge_dim, hidden_dim)

        self.gru_user = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.gru_business = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.gru_review = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

        self.attention1 = CustomMultiheadAttention(hidden_dim, num_heads)
        self.attention2 = CustomMultiheadAttention(hidden_dim, num_heads)

        self.ffn1 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        self.ffn2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        self.norm4 = nn.LayerNorm(hidden_dim)

        self.centrality_proj = nn.Linear(1, hidden_dim)

        # Input is [review | user | business] embeddings for one review.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 1)
        )

        self.dropout = nn.Dropout(0.1)

    def time_aware_aggregation(self, x, time_since, decay_rate=0.1):
        """Scale each row of ``x`` by ``exp(-decay_rate * time_since)`` (one weight per row)."""
        weights = torch.exp(-decay_rate * time_since.unsqueeze(-1))
        return x * weights

    def _edge_attention_bias(self, data, spatial_encoding, edge_features_dict, offsets, x):
        """Return the ``(1, N, N)`` attention bias: ``-spatial_encoding`` plus per-edge scores.

        ``data[edge_type].edge_index`` holds indices local to each node type, so
        both endpoints are shifted by their type's offset in the concatenated
        node sequence before the edge score is written at ``[src, tgt]``.
        """
        total_nodes = x.size(1)
        edge_bias = torch.zeros(total_nodes, total_nodes, device=x.device, dtype=x.dtype)
        for edge_type, raw_feats in edge_features_dict.items():
            src_type, _, tgt_type = edge_type
            edge_index = data[edge_type].edge_index
            rows = edge_index[0] + offsets[src_type]
            cols = edge_index[1] + offsets[tgt_type]
            # One learned scalar per edge (rows of raw_feats align with edge_index
            # columns); repeated (row, col) pairs accumulate.
            edge_scores = self.edge_proj(raw_feats).sum(dim=-1).to(edge_bias.dtype)
            edge_bias = edge_bias.index_put((rows, cols), edge_scores, accumulate=True)
        spatial = spatial_encoding.to(device=x.device, dtype=edge_bias.dtype)
        return (edge_bias - spatial).unsqueeze(0)

    def forward(self, data, spatial_encoding, centrality_encoding, node_type_map, time_since_dict, edge_features_dict):
        """Score every review in ``data``.

        Args:
            data: Heterogeneous graph with ``user``/``business``/``review`` node
                stores (``.x``) and ``('user', 'writes', 'review')`` /
                ``('review', 'about', 'business')`` edge stores (``.edge_index``).
            spatial_encoding: ``(N, N)`` distances between all N nodes, in
                user -> business -> review order.
            centrality_encoding: ``(N, 1)`` per-node value fed to ``centrality_proj``.
            node_type_map: Not used here.
            time_since_dict: Optional per-node-type tensors (presumably 1-D, one
                value per node of that type) used for exponential time decay.
            edge_features_dict: Maps an edge-type tuple to its raw
                ``(num_edges, edge_dim)`` features.

        Returns:
            ``(E, 1)`` probabilities, one per column of the
            ``('user', 'writes', 'review')`` edge index. NOTE(review): this assumes
            that edge list and the ``('review', 'about', 'business')`` list are
            aligned per review.
        """
        num_users = data['user'].x.size(0)
        num_businesses = data['business'].x.size(0)
        user_start = 0
        business_start = num_users
        review_start = num_users + num_businesses
        offsets = {'user': user_start, 'business': business_start, 'review': review_start}

        x_dict = {}
        for node_type in data.x_dict:
            x = self.embed_dict[node_type](data[node_type].x)
            if node_type in time_since_dict:
                x = self.time_aware_aggregation(x, time_since_dict[node_type])
            x_dict[node_type] = x

        x = torch.cat([x_dict['user'], x_dict['business'], x_dict['review']], dim=0)
        x = x + self.centrality_proj(centrality_encoding)
        x = x.unsqueeze(0)

        # Each node type's node list is run through its own GRU as a sequence,
        # so the result depends on node order.
        x_user, _ = self.gru_user(x[:, user_start:business_start, :])
        x_business, _ = self.gru_business(x[:, business_start:review_start, :])
        x_review, _ = self.gru_review(x[:, review_start:, :])
        x = torch.cat([x_user, x_business, x_review], dim=1)

        attn_bias = self._edge_attention_bias(data, spatial_encoding, edge_features_dict, offsets, x)

        residual = x
        x, _ = self.attention1(x, x, x, attn_bias=attn_bias)
        x = self.norm1(x + residual)
        x = self.dropout(x)

        residual = x
        x = self.ffn1(x)
        x = self.norm2(x + residual)
        x = self.dropout(x)

        residual = x
        x, _ = self.attention2(x, x, x, attn_bias=attn_bias)
        x = self.norm3(x + residual)
        x = self.dropout(x)

        residual = x
        x = self.ffn2(x)
        x = self.norm4(x + residual)
        x = self.dropout(x)

        x = x.squeeze(0)

        h_user = x[user_start:business_start]
        h_business = x[business_start:review_start]
        h_review = x[review_start:]

        user_indices = data['user', 'writes', 'review'].edge_index[0]
        business_indices = data['review', 'about', 'business'].edge_index[1]
        review_indices = data['user', 'writes', 'review'].edge_index[1]

        combined = torch.cat([
            h_review[review_indices],
            h_user[user_indices],
            h_business[business_indices],
        ], dim=-1)

        logits = self.classifier(combined)
        return torch.sigmoid(logits)

# Training / evaluation driver with metric logging and plotting
class GraphformerModel:
    """Train and evaluate :class:`HeteroGraphormer` on a tabular review frame.

    The input frame is split into train/test rows (``test_size`` fraction,
    ``random_state=42``). Training builds a user/business/review graph from the
    training rows only. Evaluation then appends the test rows to that same graph
    and scores every review in one forward pass (transductive evaluation). The
    weights, a metrics text file and diagnostic plots are written under
    ``<output_path>/GraphformerModel``.

    The frame must contain ``user_id``, ``review_id``, ``business_id``, the binary
    label ``fake``, and, in this order, the feature columns listed in
    ``USER_COLS``, ``BUSINESS_COLS`` and ``REVIEW_COLS``.
    """

    # Feature layout after dropping ID_LABEL_COLS. Features are sliced
    # positionally (user | business | review), so the column order matters.
    USER_COLS = ['hours', 'user_review_count', 'elite', 'friends', 'fans', 'average_stars',
                 'time_since_last_review_user', 'user_account_age', 'user_degree',
                 'user_review_burst_count', 'review_like_ratio', 'latest_checkin_hours',
                 'user_useful_funny_cool', 'rating_variance_user']
    BUSINESS_COLS = ['latitude', 'longitude', 'business_stars', 'business_review_count',
                     'time_since_last_review_business', 'business_degree',
                     'business_review_burst_count', 'rating_deviation_from_business_average']
    REVIEW_COLS = ['review_stars', 'tip_compliment_count', 'tip_count', 'average_time_between_reviews',
                   'temporal_similarity', 'pronoun_density', 'avg_sentence_length',
                   'excessive_punctuation_count', 'sentiment_polarity', 'good_severity',
                   'bad_severity', 'code_switching_flag', 'grammar_error_score',
                   'repetitive_words_count', 'similarity_to_other_reviews', 'review_useful_funny_cool']
    USER_SLICE = slice(0, 14)
    BUSINESS_SLICE = slice(14, 22)
    REVIEW_SLICE = slice(22, 38)
    ID_LABEL_COLS = ['user_id', 'review_id', 'business_id', 'fake']
    EDGE_TYPES = [('user', 'writes', 'review'), ('review', 'about', 'business')]
    NODE_ORDER = ['user', 'business', 'review']

    def __init__(self, df, output_path, epochs, test_size=0.3):
        """Split ``df``, seed RNGs, create the output directory, model and optimizer."""
        self.df_whole = df
        self.output_path = Path(output_path) / "GraphformerModel"
        self.epochs = epochs
        self.df, self.test_df = train_test_split(self.df_whole, test_size=test_size, random_state=42)

        torch.manual_seed(42)
        np.random.seed(42)

        Path(self.output_path).mkdir(parents=True, exist_ok=True)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = HeteroGraphormer(hidden_dim=64, output_dim=1, edge_dim=4).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.005)
        self.criterion = nn.BCELoss()

    def compute_graph_encodings(self, data):
        """Compute Graphormer structural encodings for ``data``.

        Nodes get global ids in user -> business -> review order. The directed
        graph has an edge for every ``user -> review`` and ``review -> business``
        pair.

        Returns:
            spatial_encoding: ``(N, N)`` float tensor of directed hop counts
                (0 on the diagonal, ``inf`` where no directed path exists).
            centrality_encoding: ``(N, 1)`` float tensor of in+out degree.
            node_type_map: dict mapping global id -> node type name.
        """
        offsets = {}
        num_nodes = 0
        node_type_map = {}
        for node_type in self.NODE_ORDER:
            offsets[node_type] = num_nodes
            count = data[node_type].x.size(0)
            for global_id in range(num_nodes, num_nodes + count):
                node_type_map[global_id] = node_type
            num_nodes += count

        G = nx.DiGraph()
        G.add_nodes_from(range(num_nodes))
        for src_type, rel, tgt_type in self.EDGE_TYPES:
            edge_index = data[src_type, rel, tgt_type].edge_index
            src_nodes = (edge_index[0] + offsets[src_type]).tolist()
            tgt_nodes = (edge_index[1] + offsets[tgt_type]).tolist()
            G.add_edges_from(zip(src_nodes, tgt_nodes))

        # One BFS per source (O(N * (N + E))) instead of a has_path +
        # shortest_path_length query per ordered pair; filled on the CPU and
        # moved to the device once.
        distances = np.full((num_nodes, num_nodes), np.inf, dtype=np.float32)
        for source, lengths in nx.all_pairs_shortest_path_length(G):
            for target, hops in lengths.items():
                distances[source, target] = hops
        spatial_encoding = torch.from_numpy(distances).to(self.device)

        centrality_encoding = torch.tensor([G.degree(i) for i in range(num_nodes)], dtype=torch.float, device=self.device).view(-1, 1)

        return spatial_encoding, centrality_encoding, node_type_map

    def compute_metrics(self, y_true, y_pred, y_prob, prefix=""):
        """Return classification metrics keyed ``<prefix><name>``.

        ``auc_roc`` is ``nan`` when ``y_true`` contains a single class, because
        ROC AUC is undefined in that case and ``roc_auc_score`` raises.
        """
        metrics = {}
        metrics[f"{prefix}accuracy"] = accuracy_score(y_true, y_pred)
        metrics[f"{prefix}precision"] = precision_score(y_true, y_pred, zero_division=0)
        metrics[f"{prefix}recall"] = recall_score(y_true, y_pred, zero_division=0)
        metrics[f"{prefix}f1"] = f1_score(y_true, y_pred, zero_division=0)
        try:
            metrics[f"{prefix}auc_roc"] = roc_auc_score(y_true, y_prob)
        except ValueError:
            metrics[f"{prefix}auc_roc"] = float('nan')
        metrics[f"{prefix}conf_matrix"] = confusion_matrix(y_true, y_pred)
        metrics[f"{prefix}class_report"] = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
        return metrics

    # ------------------------------------------------------------------ helpers

    def _frame_to_tensors(self, frame):
        """Return ``(features, labels, user_gap, business_gap)`` float tensors for ``frame``."""
        feature_frame = frame.drop(columns=self.ID_LABEL_COLS)
        expected_cols = self.USER_COLS + self.BUSINESS_COLS + self.REVIEW_COLS
        if list(feature_frame.columns) != expected_cols:
            logger.warning(
                "Feature columns differ from the expected user/business/review layout; "
                "positional slicing may mix feature groups. Got: {}", list(feature_frame.columns))

        def to_tensor(values):
            return torch.tensor(values, dtype=torch.float, device=self.device)

        return (
            to_tensor(feature_frame.values),
            to_tensor(frame['fake'].values),
            to_tensor(frame['time_since_last_review_user'].values),
            to_tensor(frame['time_since_last_review_business'].values),
        )

    @staticmethod
    def _extend_id_mapping(mapping, ids):
        """Return a copy of ``mapping`` with every unseen id in ``ids`` appended as the next dense index."""
        extended = dict(mapping)
        for key in ids:
            extended.setdefault(key, len(extended))
        return extended

    def _map_ids(self, ids, mapping):
        """Map each id in ``ids`` to its dense index as a long tensor."""
        return torch.tensor([mapping[key] for key in ids], dtype=torch.long, device=self.device)

    @staticmethod
    def _sum_rows_by_index(values, index, num_groups):
        """Return a ``(num_groups, values.size(1))`` tensor whose row g sums ``values[index == g]``."""
        return values.new_zeros(num_groups, values.size(1)).index_add(0, index, values)

    @staticmethod
    def _pad_rows(x, num_rows):
        """Append zero rows so ``x`` has ``num_rows`` rows (no-op if it already does)."""
        missing = num_rows - x.size(0)
        if missing <= 0:
            return x
        return torch.cat([x, x.new_zeros(missing, x.size(1))], dim=0)

    def _per_node_min_time(self, frame, id_col, time_col, ordered_ids):
        """Return, for each id in ``ordered_ids``, the minimum ``time_col`` over its rows (0 if absent)."""
        per_id = frame.groupby(id_col)[time_col].min()
        values = per_id.reindex(pd.Index(ordered_ids), fill_value=0).values
        return torch.tensor(values, dtype=torch.float, device=self.device)

    def _temporal_edge_features(self, graph, review_time_user, review_time_business,
                                user_node_time, business_node_time):
        """Build temporal edge features for both edge types.

        Review endpoints use the review's own time value; user/business
        endpoints use the per-node time (indexed by node id, never by review
        position). Both user arguments are the edge's own user, so the burst term
        is always active on ``writes`` edges. ``about`` edges pass zeros for both
        user arguments.
        """
        writes = graph['user', 'writes', 'review'].edge_index
        about = graph['review', 'about', 'business'].edge_index
        return {
            ('user', 'writes', 'review'): create_temporal_edge_features(
                user_node_time[writes[0]], review_time_user[writes[1]], writes[0], writes[0]),
            ('review', 'about', 'business'): create_temporal_edge_features(
                review_time_business[about[0]], business_node_time[about[1]],
                torch.zeros_like(about[0]), torch.zeros_like(about[0])),
        }

    # ------------------------------------------------------------- main driver

    def run_model(self):
        """Train on ``self.df``, save the weights, then evaluate on ``self.test_df``.

        Writes ``model_GraphformerModel_latest.pth``. When a test split exists it
        also writes ``metrics_<timestamp>.txt`` and PNG plots to
        ``self.output_path``. Leaves ``self.num_users`` / ``self.num_businesses``
        equal to the node counts of the last graph built.
        """
        # ---- Training graph ---------------------------------------------------
        features, y, time_since_user, time_since_business = self._frame_to_tensors(self.df)
        num_rows = len(self.df)

        user_mapping = self._extend_id_mapping({}, self.df['user_id'].unique())
        business_mapping = self._extend_id_mapping({}, self.df['business_id'].unique())
        self.num_users = len(user_mapping)
        self.num_businesses = len(business_mapping)

        user_indices = self._map_ids(self.df['user_id'], user_mapping)
        business_indices = self._map_ids(self.df['business_id'], business_mapping)
        review_indices = torch.arange(num_rows, dtype=torch.long, device=self.device)

        graph = HeteroData()
        # User/business node features are the sum of their slice over all their reviews.
        graph['user'].x = self._sum_rows_by_index(features[:, self.USER_SLICE], user_indices, self.num_users)
        graph['business'].x = self._sum_rows_by_index(features[:, self.BUSINESS_SLICE], business_indices, self.num_businesses)
        graph['review'].x = features[:, self.REVIEW_SLICE]
        graph['review'].y = y
        graph['user', 'writes', 'review'].edge_index = torch.stack([user_indices, review_indices], dim=0)
        graph['review', 'about', 'business'].edge_index = torch.stack([review_indices, business_indices], dim=0)

        user_node_time = self._per_node_min_time(self.df, 'user_id', 'time_since_last_review_user', list(user_mapping))
        business_node_time = self._per_node_min_time(self.df, 'business_id', 'time_since_last_review_business', list(business_mapping))
        edge_features_dict = self._temporal_edge_features(
            graph, time_since_user, time_since_business, user_node_time, business_node_time)
        time_since_dict = {'user': user_node_time}

        spatial_encoding, centrality_encoding, node_type_map = self.compute_graph_encodings(graph)

        # ---- Training with metrics history ------------------------------------
        self.model.train()
        train_metrics_history = []
        for epoch in range(self.epochs):
            self.optimizer.zero_grad()
            out = self.model(graph, spatial_encoding, centrality_encoding, node_type_map, time_since_dict, edge_features_dict)
            # reshape (not squeeze) keeps a 1-D tensor even with a single review.
            epoch_probs = out.reshape(-1)
            loss = self.criterion(epoch_probs, y)
            loss.backward()
            self.optimizer.step()

            pred_labels = (epoch_probs > 0.5).float()
            logger.debug(f"PREDICTED LABELS : {pred_labels}")
            probs = epoch_probs.detach().cpu().numpy()
            train_metrics = self.compute_metrics(y.cpu().numpy(), pred_labels.cpu().numpy(), probs, prefix="train_")
            train_metrics['loss'] = loss.item()
            train_metrics_history.append(train_metrics)

            if epoch % 10 == 0:
                logger.info(f"Epoch {epoch}, Loss: {loss.item():.4f}, Accuracy: {train_metrics['train_accuracy']:.4f}, F1: {train_metrics['train_f1']:.4f}")

        # ---- Save model -------------------------------------------------------
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_save_path = Path(self.output_path) / "model_GraphformerModel_latest.pth"
        torch.save(self.model.state_dict(), model_save_path)

        if self.test_df is None:
            return

        # ---- Extend the graph with the test reviews ---------------------------
        test_features, test_y, test_time_since_user, test_time_since_business = self._frame_to_tensors(self.test_df)
        num_test_rows = len(self.test_df)

        user_mapping = self._extend_id_mapping(user_mapping, self.test_df['user_id'].unique())
        business_mapping = self._extend_id_mapping(business_mapping, self.test_df['business_id'].unique())
        total_users = len(user_mapping)
        total_businesses = len(business_mapping)

        new_user_indices = self._map_ids(self.test_df['user_id'], user_mapping)
        new_business_indices = self._map_ids(self.test_df['business_id'], business_mapping)
        new_review_indices = torch.arange(num_rows, num_rows + num_test_rows, dtype=torch.long, device=self.device)

        # Unseen users/businesses get zero rows; every node then also accumulates
        # the features of its test reviews.
        graph['user'].x = self._pad_rows(graph['user'].x, total_users).index_add(
            0, new_user_indices, test_features[:, self.USER_SLICE])
        graph['business'].x = self._pad_rows(graph['business'].x, total_businesses).index_add(
            0, new_business_indices, test_features[:, self.BUSINESS_SLICE])
        graph['review'].x = torch.cat([graph['review'].x, test_features[:, self.REVIEW_SLICE]], dim=0)
        graph['review'].y = torch.cat([graph['review'].y, test_y], dim=0)

        graph['user', 'writes', 'review'].edge_index = torch.cat([
            graph['user', 'writes', 'review'].edge_index,
            torch.stack([new_user_indices, new_review_indices], dim=0)], dim=1)
        graph['review', 'about', 'business'].edge_index = torch.cat([
            graph['review', 'about', 'business'].edge_index,
            torch.stack([new_review_indices, new_business_indices], dim=0)], dim=1)

        # Per-node times cover train+test rows, ordered exactly like the id
        # mappings, so train-only users keep their values and new users line up.
        combined_df = pd.concat([self.df, self.test_df])
        user_node_time = self._per_node_min_time(combined_df, 'user_id', 'time_since_last_review_user', list(user_mapping))
        business_node_time = self._per_node_min_time(combined_df, 'business_id', 'time_since_last_review_business', list(business_mapping))
        edge_features_dict = self._temporal_edge_features(
            graph,
            torch.cat([time_since_user, test_time_since_user]),
            torch.cat([time_since_business, test_time_since_business]),
            user_node_time, business_node_time)
        time_since_dict = {'user': user_node_time}

        self.num_users = total_users
        self.num_businesses = total_businesses

        spatial_encoding, centrality_encoding, node_type_map = self.compute_graph_encodings(graph)

        # ---- Evaluation (reviews are ordered train rows, then test rows) -------
        self.model.eval()
        with torch.no_grad():
            out = self.model(graph, spatial_encoding, centrality_encoding, node_type_map, time_since_dict, edge_features_dict)
        probs = out.reshape(-1).cpu().numpy()
        pred_labels = (probs > 0.5).astype(np.float32)
        labels = graph['review'].y.cpu().numpy()
        test_metrics = self.compute_metrics(labels[num_rows:], pred_labels[num_rows:], probs[num_rows:], prefix="test_")
        train_metrics = self.compute_metrics(labels[:num_rows], pred_labels[:num_rows], probs[:num_rows], prefix="train_")
        logger.info(f"Test Accuracy: {test_metrics['test_accuracy']:.4f}, F1: {test_metrics['test_f1']:.4f}, AUC-ROC: {test_metrics['test_auc_roc']:.4f}")

        self._write_metrics_file(train_metrics, test_metrics, timestamp)
        self._save_plots(train_metrics_history, train_metrics, test_metrics, labels, probs, num_rows, timestamp)

        logger.info(f"All metrics, plots, and model saved to {self.output_path}")

    # ------------------------------------------------------------------ output

    def _write_metrics_file(self, train_metrics, test_metrics, timestamp):
        """Write both metric dicts as ``key: value`` lines to ``metrics_<timestamp>.txt``."""
        metrics_file = Path(self.output_path) / f"metrics_{timestamp}.txt"
        with open(metrics_file, 'w', encoding='utf-8') as f:
            f.write("Training Metrics (Final Epoch):\n")
            for k, v in train_metrics.items():
                f.write(f"{k}: {v}\n")
            f.write("\nTest Metrics:\n")
            for k, v in test_metrics.items():
                f.write(f"{k}: {v}\n")

    def _save_plots(self, history, train_metrics, test_metrics, labels, probs, num_rows, timestamp):
        """Save per-epoch training curves plus train/test confusion matrices and ROC curves."""
        out_dir = Path(self.output_path)

        def series(key):
            return [m[key] for m in history]

        self._save_line_plot([('Training Loss', series('loss'))],
                             'Loss', 'Training Loss Curve', out_dir / f"loss_curve_{timestamp}.png")
        self._save_line_plot([('Training Accuracy', series('train_accuracy'))],
                             'Accuracy', 'Training Accuracy Curve', out_dir / f"accuracy_curve_{timestamp}.png")
        self._save_line_plot([('Training Precision', series('train_precision')),
                              ('Training Recall', series('train_recall')),
                              ('Training F1-Score', series('train_f1'))],
                             'Score', 'Training Precision, Recall, and F1-Score Curves',
                             out_dir / f"prf1_curves_{timestamp}.png")
        self._save_line_plot([('Training AUC-ROC', series('train_auc_roc'))],
                             'AUC-ROC', 'Training AUC-ROC Curve', out_dir / f"auc_roc_curve_train_{timestamp}.png")

        self._save_confusion_matrix(test_metrics['test_conf_matrix'], 'Test Confusion Matrix',
                                    out_dir / f"confusion_matrix_test_{timestamp}.png")
        self._save_roc_curve(labels[num_rows:], probs[num_rows:],
                             f'Test ROC Curve (AUC = {test_metrics["test_auc_roc"]:.4f})',
                             'Test ROC Curve', out_dir / f"roc_curve_test_{timestamp}.png")
        self._save_confusion_matrix(train_metrics['train_conf_matrix'], 'Training Confusion Matrix (Final Epoch)',
                                    out_dir / f"confusion_matrix_train_{timestamp}.png")
        self._save_roc_curve(labels[:num_rows], probs[:num_rows],
                             f'Training ROC Curve (AUC = {train_metrics["train_auc_roc"]:.4f})',
                             'Training ROC Curve (Final Epoch)', out_dir / f"roc_curve_train_{timestamp}.png")

    @staticmethod
    def _save_line_plot(curves, ylabel, title, path):
        """Plot ``(label, values)`` curves against epoch and save the figure to ``path``."""
        plt.figure(figsize=(12, 8))
        for label, values in curves:
            plt.plot(values, label=label)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.savefig(path)
        plt.close()

    @staticmethod
    def _save_confusion_matrix(matrix, title, path):
        """Render a confusion matrix as an annotated heatmap and save it to ``path``."""
        plt.figure(figsize=(8, 6))
        sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues', cbar=False)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(title)
        plt.savefig(path)
        plt.close()

    @staticmethod
    def _save_roc_curve(y_true, y_prob, label, title, path):
        """Plot the ROC curve of ``y_prob`` against ``y_true`` with a chance diagonal and save it."""
        fpr, tpr, _ = roc_curve(y_true, y_prob)
        plt.figure(figsize=(10, 6))
        plt.plot(fpr, tpr, label=label)
        plt.plot([0, 1], [0, 1], 'k--', label='Random Guess')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.savefig(path)
        plt.close()