import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from loguru import logger

class FeatureAnalyzer:
    """Render exploratory plots that contrast genuine and fake reviews.

    Each ``plot_*`` method draws one figure from ``df``, writes it as a
    300-dpi PNG into ``output_path`` (created on demand), closes the figure
    and logs where it was saved. ``run_pipeline`` produces all of them.

    ``df`` is expected to hold a ``fake`` column with values 0 (genuine) or
    1 (fake), the columns listed in ``KEY_FEATURES``, and the
    ``code_switching_flag`` column used by ``plot_count_code_switching``.
    """

    # Features compared across the two classes by the bar, violin, box and
    # density plots (each plot uses the slice noted in its method).
    KEY_FEATURES = (
        'review_stars', 'business_stars', 'business_review_count', 'user_review_count',
        'friends', 'fans', 'average_stars', 'tip_compliment_count', 'tip_count',
        'time_since_last_review_user', 'user_account_age', 'pronoun_density',
        'grammar_error_score', 'repetitive_words_count', 'similarity_to_other_reviews',
        'review_useful_funny_cool', 'user_useful_funny_cool', 'sentiment_polarity',
    )
    # Display name and colour for each value of the ``fake`` column.
    LABEL_NAMES = {0: 'Genuine (0)', 1: 'Fake (1)'}
    LABEL_COLORS = {0: 'skyblue', 1: 'salmon'}

    def __init__(self, df, output_path):
        """Store the data to analyse and the directory that receives the PNGs.

        Args:
            df: pandas DataFrame with a 0/1 ``fake`` column and the feature columns.
            output_path: Directory (str or Path) for the generated images.
        """
        self.df = df
        self.output_path = output_path

    # ------------------------------------------------------------------ helpers

    def _class_palette(self):
        """Return the genuine/fake colours in label order (0 first, then 1)."""
        return list(self.LABEL_COLORS.values())

    def _numeric_feature_columns(self):
        """Return the numeric columns of ``df`` excluding the ``fake`` target.

        ``errors='ignore'`` keeps this working when ``fake`` has a non-numeric
        dtype (e.g. bool), which ``select_dtypes(include=[np.number])`` omits.
        """
        numeric = self.df.select_dtypes(include=[np.number]).columns
        return numeric.drop('fake', errors='ignore')

    def _stats_by_fake(self, stat):
        """Aggregate ``KEY_FEATURES`` per ``fake`` label.

        Args:
            stat: Name of a groupby aggregation, e.g. ``'mean'`` or ``'var'``.

        Returns:
            DataFrame with one row per feature and one column per label present,
            named via ``LABEL_NAMES``. Renaming by label (rather than assigning a
            fixed two-name list) stays correct when only one class is present.
        """
        stats = self.df.groupby('fake')[list(self.KEY_FEATURES)].agg(stat).T
        return stats.rename(columns=self.LABEL_NAMES)

    def _fake_counts(self):
        """Return row counts per ``fake`` label, ordered by label rather than frequency.

        ``value_counts`` sorts by frequency, so its raw order would not line up
        with names/colours listed in 0/1 order.
        """
        return self.df['fake'].value_counts().sort_index()

    def _save_figure(self, fig, filename, description):
        """Write *fig* to ``output_path/filename`` at 300 dpi, close it, and log the path.

        The figure is closed even if saving fails, so repeated calls do not
        accumulate open matplotlib figures.
        """
        output_dir = Path(self.output_path)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / filename
        try:
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)
        logger.info(f"Saved {description} to {output_file}")

    def _plot_stat_bar(self, stats, title, ylabel, filename, description):
        """Draw *stats* (features x labels) as a grouped bar chart and save it."""
        # Pass ``ax`` explicitly: DataFrame.plot without ``ax`` opens a figure of
        # its own, which would orphan a pre-created one and ignore its size.
        fig, ax = plt.subplots(figsize=(12, 8))
        stats.plot(kind='bar', color=self._class_palette(), width=0.8, ax=ax)
        ax.set_title(title, fontsize=16)
        ax.set_xlabel('Features', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.legend(title='Fake Label')
        fig.tight_layout()
        self._save_figure(fig, filename, description)

    def _plot_class_distributions(self, plot_func, features, filename, description):
        """Draw one seaborn categorical plot per feature (split by label) on a 2x3 grid."""
        fig = plt.figure(figsize=(14, 10))
        for position, feature in enumerate(features, start=1):
            ax = fig.add_subplot(2, 3, position)
            plot_func(x='fake', y=feature, data=self.df, palette=self._class_palette(), ax=ax)
            ax.set_title(f'{feature} Distribution', fontsize=12)
            ax.set_xlabel('Fake (0/1)', fontsize=10)
        fig.tight_layout()
        self._save_figure(fig, filename, description)

    # -------------------------------------------------------------------- plots

    def plot_correlation_heatmap(self):
        """Save a heatmap of pairwise correlations between numeric non-target columns."""
        correlation_matrix = self.df[self._numeric_feature_columns()].corr()
        fig, ax = plt.subplots(figsize=(14, 12))
        sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', vmin=-1, vmax=1, center=0, ax=ax)
        ax.set_title('Correlation Heatmap of Numeric Features', fontsize=16)
        fig.tight_layout()
        self._save_figure(fig, 'correlation_heatmap.png', 'correlation heatmap')

    def plot_mean_by_fake_bar(self):
        """Save a grouped bar chart of each key feature's mean per label."""
        self._plot_stat_bar(
            self._stats_by_fake('mean'),
            title='Mean Feature Values by Fake Label',
            ylabel='Mean Value',
            filename='mean_by_fake_bar.png',
            description='mean by fake bar plot',
        )

    def plot_violin_plots(self):
        """Save violin plots of the first six key features, split by label."""
        self._plot_class_distributions(
            sns.violinplot, self.KEY_FEATURES[:6], 'violin_plots.png', 'violin plots'
        )

    def plot_box_plots(self):
        """Save box plots of key features 7-11, split by label."""
        self._plot_class_distributions(
            sns.boxplot, self.KEY_FEATURES[6:11], 'box_plots.png', 'box plots'
        )

    def plot_scatter_review_grammar(self):
        """Save a scatter of review stars vs grammar error score, coloured by label."""
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.scatterplot(x='review_stars', y='grammar_error_score', hue='fake', data=self.df,
                        palette=['blue', 'red'], alpha=0.5, ax=ax)
        ax.set_title('Review Stars vs Grammar Error Score by Fake Label', fontsize=16)
        ax.set_xlabel('Review Stars', fontsize=12)
        ax.set_ylabel('Grammar Error Score', fontsize=12)
        ax.legend(title='Fake')
        fig.tight_layout()
        self._save_figure(fig, 'scatter_review_grammar.png', 'scatter plot')

    def plot_density_plots(self):
        """Save overlaid per-label KDE curves for the first four key features."""
        # The per-label row subsets do not depend on the feature; build them once.
        subsets = {label: self.df[self.df['fake'] == label] for label in (0, 1)}
        fig = plt.figure(figsize=(14, 10))
        for position, feature in enumerate(self.KEY_FEATURES[:4], start=1):
            ax = fig.add_subplot(2, 2, position)
            for label, subset in subsets.items():
                sns.kdeplot(subset[feature], label=f'Fake={label}', fill=True, alpha=0.5, ax=ax)
            ax.set_title(f'{feature} Density', fontsize=12)
            ax.set_xlabel(feature, fontsize=10)
            ax.legend()
        fig.tight_layout()
        self._save_figure(fig, 'density_plots.png', 'density plots')

    def plot_stacked_bar_similarity(self):
        """Save the genuine/fake proportion within 10 equal-width similarity bins."""
        bins = pd.cut(self.df['similarity_to_other_reviews'], bins=10)
        # observed=False keeps empty bins on the axis (pandas' long-standing
        # default); passed explicitly because that default is deprecated.
        counts = self.df.groupby([bins, 'fake'], observed=False).size().unstack(fill_value=0)
        proportions = counts.div(counts.sum(axis=1), axis=0).rename(columns=self.LABEL_NAMES)
        fig, ax = plt.subplots(figsize=(12, 8))
        proportions.plot(kind='bar', stacked=True, color=self._class_palette(), width=0.8, ax=ax)
        ax.set_title('Proportion of Fake by Similarity to Other Reviews Bins', fontsize=16)
        ax.set_xlabel('Similarity Bins', fontsize=12)
        ax.set_ylabel('Proportion', fontsize=12)
        # Legend entries come from the renamed columns, so they always match the bars.
        ax.legend(title='Fake Label')
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        fig.tight_layout()
        self._save_figure(fig, 'stacked_bar_similarity.png', 'stacked bar plot')

    def plot_pie_fake_distribution(self):
        """Save a pie chart of the share of rows per label."""
        fake_counts = self._fake_counts()
        # Derive names/colours from the actual labels so each slice is named correctly
        # regardless of which class is more frequent.
        labels = [self.LABEL_NAMES.get(label, str(label)) for label in fake_counts.index]
        colors = [self.LABEL_COLORS.get(label, 'lightgrey') for label in fake_counts.index]
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.pie(fake_counts, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
        ax.set_title('Distribution of Fake Labels', fontsize=16)
        ax.axis('equal')
        self._save_figure(fig, 'pie_fake_distribution.png', 'pie chart')

    def plot_count_code_switching(self):
        """Save row counts per ``code_switching_flag`` value, split by label."""
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.countplot(x='code_switching_flag', hue='fake', data=self.df,
                      palette=self._class_palette(), ax=ax)
        ax.set_title('Count of Fake by Code Switching Flag', fontsize=16)
        ax.set_xlabel('Code Switching Flag (0/1)', fontsize=12)
        ax.set_ylabel('Count', fontsize=12)
        ax.legend(title='Fake Label')
        fig.tight_layout()
        self._save_figure(fig, 'count_code_switching.png', 'count plot')

    def plot_variance_by_fake_bar(self):
        """Save a grouped bar chart of each key feature's sample variance per label."""
        self._plot_stat_bar(
            self._stats_by_fake('var'),
            title='Feature Variance by Fake Label',
            ylabel='Variance',
            filename='variance_by_fake_bar.png',
            description='variance bar plot',
        )

    def run_pipeline(self):
        """Apply the seaborn whitegrid theme and generate every plot in turn.

        Note: this also sets the global matplotlib default figure size.
        """
        sns.set(style="whitegrid")
        plt.rcParams['figure.figsize'] = (12, 8)
        self.plot_correlation_heatmap()
        self.plot_mean_by_fake_bar()
        self.plot_violin_plots()
        self.plot_box_plots()
        self.plot_scatter_review_grammar()
        self.plot_density_plots()
        self.plot_stacked_bar_similarity()
        self.plot_pie_fake_distribution()
        self.plot_count_code_switching()
        self.plot_variance_by_fake_bar()