from typing import Any, Dict, Iterable, List, Optional

import numpy as np
import pandas as pd
import PIL
import PIL.Image
import plotly.express as px
import plotly.graph_objects as go
import wordcloud
from pydantic import BaseModel, Field
from sklearn.feature_extraction.text import TfidfVectorizer


class WordCloudExtractor(BaseModel):
    """Render a word cloud for a corpus, weighting each word by its mean TF-IDF score."""

    # Upper bound on the vocabulary kept by the TF-IDF vectorizer (and hence drawn).
    max_words: int = 50
    # Extra keyword arguments forwarded verbatim to ``wordcloud.WordCloud``.
    wordcloud_params: Dict[str, Any] = Field(default_factory=dict)
    # Extra keyword arguments forwarded to ``TfidfVectorizer``; by default only
    # English stop words are removed.
    tfidf_params: Dict[str, Any] = Field(default_factory=lambda: {"stop_words": "english"})

    def extract_wordcloud_image(self, texts: Iterable[str]) -> PIL.Image.Image:
        """Build a word cloud from ``texts`` and return it as a PIL image.

        Args:
            texts: Text documents to summarise.

        Returns:
            The rendered word cloud as a ``PIL.Image.Image``.
        """
        frequencies = self._extract_frequencies(texts, self.max_words, tfidf_params=self.tfidf_params)
        cloud = wordcloud.WordCloud(**self.wordcloud_params).generate_from_frequencies(frequencies)
        return cloud.to_image()

    @classmethod
    def _extract_frequencies(
        cls,
        texts: Iterable[str],
        max_words: int = 100,
        tfidf_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, float]:
        """Compute per-word weights for a word cloud from a corpus via TF-IDF.

        Each word's weight is its TF-IDF score averaged over all documents.

        Args:
            texts: Text documents to vectorize.
            max_words: Maximum number of words (TF-IDF ``max_features``) to keep.
                Takes precedence over a ``max_features`` key in ``tfidf_params``.
            tfidf_params: Extra keyword arguments for ``TfidfVectorizer``.
                ``None`` means no extra arguments.

        Returns:
            Mapping of word (plain ``str``) to mean TF-IDF score (plain ``float``),
            suitable for ``WordCloud.generate_from_frequencies``.

        Raises:
            ValueError: propagated from ``TfidfVectorizer`` when the resulting
                vocabulary is empty (e.g. no documents, or only stop words).
        """
        # Build a fresh dict so the caller's params are never touched. Setting
        # max_features last avoids "got multiple values for keyword argument"
        # when tfidf_params also contains a max_features key.
        vectorizer_params = {**(tfidf_params or {}), "max_features": max_words}
        tfidf = TfidfVectorizer(**vectorizer_params)

        # Sparse (n_documents, n_words) matrix of TF-IDF scores.
        tfidf_matrix = tfidf.fit_transform(texts)
        feature_names = tfidf.get_feature_names_out()

        # Column means are returned as a 1 x n_words matrix; flatten to 1-D.
        mean_tfidf = np.asarray(tfidf_matrix.mean(axis=0)).ravel()

        # Convert numpy scalars to builtins so the result matches the annotation.
        return {str(word): float(score) for word, score in zip(feature_names, mean_tfidf)}


class EmbeddingVisualizer(BaseModel):
    """Plotly scatterplots of (UMAP) embeddings of tasks and repository representations.

    ``display_df`` is expected to contain at least the ``representation``,
    ``is_task`` and ``name`` columns, plus the x/y columns named in
    ``plot_kwargs`` (``"x"`` and ``"y"`` by default).
    """

    display_df: pd.DataFrame
    # Keyword arguments forwarded to every ``px.scatter`` call (axes, ranges,
    # figure size, template).
    plot_kwargs: Dict[str, Any] = Field(default_factory=lambda: dict(
        range_x=(3, 16.5),
        range_y=(-3, 11),
        width=1200,
        height=800,
        x="x",
        y="y",
        template="plotly_white",
    ))

    def make_embedding_plots(
        self,
        color_col=None,
        hover_data: Optional[List[str]] = None,
        filter_df_fn=None,
    ):
        """Plot UMAP embeddings, one scatterplot per family of representations.

        Every plot includes the ``"task"`` rows alongside its repository
        representations.

        Args:
            color_col: Column used to colour points (passed to ``px.scatter``).
            hover_data: Extra columns shown on hover; ``None`` means ``["name"]``.
            filter_df_fn: Optional callable applied to ``display_df`` before plotting.

        Returns:
            Dict mapping a plot title to its plotly ``Figure``, in the order
            "READMEs", "Basic representations",
            "Dependency graph based representations".
        """
        if hover_data is None:
            hover_data = ["name"]

        display_df = self.display_df
        if filter_df_fn is not None:
            display_df = filter_df_fn(display_df)
        display_df = display_df.sort_values("representation", ascending=False)

        # Plot title -> representation values shown in that plot.
        representation_groups = {
            "READMEs": ["readme", "generated_readme", "task"],
            "Basic representations": ["dependency_signature", "selected_code", "task"],
            "Dependency graph based representations": [
                "repository_signature", "dependency_signature", "generated_tasks", "task",
            ],
        }
        return {
            title: self._make_task_and_repos_scatterplot(
                display_df[display_df["representation"].isin(representations)],
                hover_data,
                color_col,
            )
            for title, representations in representation_groups.items()
        }

    def _make_task_and_repos_scatterplot(self, df, hover_data, color_col):
        """Draw one scatterplot of ``df`` and return the plotly ``Figure``.

        ``size`` (0.25 for tasks, 0.1 otherwise) and ``symbol`` (``int(is_task)``)
        helper columns are added so ``plot_kwargs`` can reference them
        (e.g. ``size="size"``); the default ``plot_kwargs`` do not use them.
        """
        # Build a new frame instead of assigning in place: ``df`` is usually a
        # boolean-mask selection of display_df, so in-place assignment triggers
        # SettingWithCopyWarning and would mutate the caller's data.
        df = df.assign(
            size=df["is_task"].apply(lambda is_task: 0.25 if is_task else 0.1),
            symbol=df["is_task"].apply(int),
        )

        combined_fig = px.scatter(
            df,
            hover_name="name",
            hover_data=hover_data,
            color=color_col,
            color_discrete_sequence=px.colors.qualitative.Set1,
            opacity=0.5,
            **self.plot_kwargs
        )
        # Plotly renders traces in list order; reversing puts the first trace
        # px created on top.
        combined_fig.data = combined_fig.data[::-1]

        return combined_fig

    def make_task_area_scatterplot(self, n_areas=6, tasks_csv_path="data/paperswithcode_tasks.csv"):
        """Scatterplot of task points coloured by research area.

        Args:
            n_areas: Keep only tasks belonging to the ``n_areas`` most frequent areas.
            tasks_csv_path: CSV joined to the task rows on its ``task`` column
                (matched against ``name``); presumably supplies the ``area`` column.

        Returns:
            The plotly ``Figure``.
        """
        display_df = self.display_df
        # No sort needed here: every remaining row has representation == "task".
        tasks_df = display_df[display_df["representation"] == "task"]
        tasks_df = tasks_df.merge(pd.read_csv(tasks_csv_path), left_on="name", right_on="task")

        top_areas = tasks_df["area"].value_counts().head(n_areas).index
        tasks_df = tasks_df[tasks_df["area"].isin(top_areas)]

        tasks_fig = px.scatter(tasks_df, color="area", hover_data=["name"], opacity=0.7, **self.plot_kwargs)
        print("N DISPLAYED TASKS", len(tasks_df))
        return tasks_fig

    class Config:
        # Needed so pydantic accepts the pd.DataFrame field.
        arbitrary_types_allowed = True