import pandas as pd
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from umap import UMAP
from Functionalities import NLP_Helper
# Visualization
import plotly.graph_objects as go


class TopicClustering:
    """Cluster keyword texts into topics with BERTopic and export the result.

    Intended flow: construct, call :meth:`topic_cluster_bert`, optionally fill
    ``topic_name_mapping`` (old label -> new label) and call
    :meth:`update_topic_names`, then visualize or export.

    Attributes set by :meth:`topic_cluster_bert`:
        topic_model: the fitted ``BERTopic`` instance.
        embeddings: document embeddings, in the order the documents were fitted.
        documents: the fitted texts as a plain list, same order as ``embeddings``.
        topic_names: list with one label per topic.
        keyword_df: the input rows joined with BERTopic's topic info.
    """

    def __init__(self, keyword_df: pd.DataFrame, text_col: str, representation_model, sentence_model: str):
        """Store the inputs and load the embedding / representation models.

        Args:
            keyword_df: frame holding the texts to cluster.
            text_col: name of the column in ``keyword_df`` that holds the texts.
            representation_model: forwarded to
                ``NLP_Helper.get_bertopic_representation``; accepted values are
                defined by that helper.
            sentence_model: forwarded to ``SentenceTransformer(...)``.
        """
        self.topic_names = None
        self.topic_model = None
        self.embeddings = None
        # Texts in the exact order they were embedded/fitted; keyword_df gets
        # re-ordered by the merge in topic_cluster_bert, so keep them separately.
        self.documents = None
        self.topic_name_mapping = {}
        self.keyword_df, self.text_col = keyword_df, text_col
        self.sentence_model = SentenceTransformer(sentence_model)
        self.representation_model = NLP_Helper.get_bertopic_representation(representation_model)

    def topic_cluster_bert(self) -> None:
        """Embed the texts, fit BERTopic and attach topic info to ``keyword_df``.

        After the call ``keyword_df`` contains BERTopic's topic-info columns
        (with ``Name`` renamed to ``Topic Name``) followed by the original
        columns, with rows grouped in topic-info order.
        """
        # A plain list avoids label-based indexing surprises when the frame has
        # a non-default index and is what both libraries document as input.
        documents = self.keyword_df[self.text_col].tolist()
        self.documents = documents
        self.embeddings = self.sentence_model.encode(documents, show_progress_bar=True)

        self.topic_model = BERTopic(representation_model=self.representation_model,
                                    embedding_model=self.sentence_model,
                                    n_gram_range=(1, 3), top_n_words=2)
        # Reuse the embeddings computed above instead of letting BERTopic
        # encode every document a second time.
        topics, _ = self.topic_model.fit_transform(documents, embeddings=self.embeddings)
        topic_labels = self.topic_model.generate_topic_labels(nr_words=1, topic_prefix=False)

        # When the first topic-info row is the outlier topic (-1), give it a
        # readable label.
        if self.topic_model.get_topic_info()['Topic'].values[0] == -1:
            topic_labels[0] = 'Unknown'
        self.topic_model.set_topic_labels(topic_labels)

        topic_info = self.topic_model.get_topic_info()
        topic_info['Name'] = topic_labels

        # assign() returns a new frame, so the DataFrame handed to __init__ is
        # not modified behind the caller's back.
        documents_with_topics = self.keyword_df.assign(Topic=topics)
        merged = pd.merge(topic_info, documents_with_topics, on=['Topic'])
        merged = merged.rename(columns={'Name': 'Topic Name'})
        # CustomName duplicates the labels set above.
        merged = merged.drop(columns=['CustomName'], errors='ignore')

        self.keyword_df = merged
        self.topic_names = topic_labels

    def visualize_documents(self, n_neighbors) -> "go.Figure":
        """Return BERTopic's 2-D document plot using a cosine UMAP projection.

        Args:
            n_neighbors: forwarded to ``UMAP(n_neighbors=...)``.
        """
        reduced_embeddings = UMAP(n_neighbors=n_neighbors, n_components=2,
                                  min_dist=0.0, metric='cosine').fit_transform(self.embeddings)

        # Use the documents in fit order: keyword_df was re-ordered by the merge
        # in topic_cluster_bert and no longer lines up with the embeddings.
        return self.topic_model.visualize_documents(self.documents,
                                                    reduced_embeddings=reduced_embeddings,
                                                    custom_labels=True)

    def visualize_topic_distribution(self) -> "go.Figure":
        """Return BERTopic's bar chart for the top 5 topics with custom labels."""
        return self.topic_model.visualize_barchart(custom_labels=True, top_n_topics=5, n_words=20,
                                                   title='Topic Distribution')

    def update_topic_names(self) -> None:
        """Apply ``topic_name_mapping`` to ``keyword_df`` and ``topic_names``.

        All renames are applied simultaneously, so swapping two labels
        (``{'a': 'b', 'b': 'a'}``) works. Labels without an entry in the mapping
        are kept unchanged. The mapping is reset to an empty dict afterwards.
        """
        renames = dict(self.topic_name_mapping)
        self.topic_name_mapping = {}
        if not renames:
            return

        # Whole-column assignment instead of chained indexing, which pandas may
        # apply to a temporary copy (and always does under Copy-on-Write).
        self.keyword_df['Topic Name'] = self.keyword_df['Topic Name'].map(
            lambda name: renames.get(name, name))

        if self.topic_names is not None:
            self.topic_names = [renames.get(name, name) for name in self.topic_names]

    def get_df_in_google_ads_format(self, campaign_name) -> pd.DataFrame:
        """Build a keyword sheet with one row per keyword for ``campaign_name``.

        Columns, in order: ``Action`` ('Add'), ``Keyword status`` ('Enabled'),
        ``Campaign``, ``Ad group`` (the topic name), ``Keyword`` (the text
        column) and ``Match Type`` ('Phrase').
        """
        return pd.DataFrame(
            {
                'Action': 'Add',
                'Keyword status': 'Enabled',
                'Campaign': campaign_name,
                'Ad group': self.keyword_df['Topic Name'],
                'Keyword': self.keyword_df[self.text_col],
                'Match Type': 'Phrase',
            },
            index=self.keyword_df.index,
        )