# abstractive.py

"""
Módulo de resúmenes 'abstractive.py'

Contiene implementaciones de diferentes técnicas de resumen de texto:
- TF-IDF Summarizer
- TextRank Summarizer
- Combined Summarizer (que combina TF-IDF y TextRank para extraer palabras clave)
- BERT Summarizer (extractivo basado en un modelo BERT preentrenado)
"""

import numpy as np
import networkx as nx
from typing import List
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from summarizer import Summarizer


class TFIDFSummarizer:
    """Genera resúmenes usando el modelo TF-IDF."""

    @staticmethod
    def summarize(sentences: List[str], preprocessed_sentences: List[str], num_sentences: int = 1) -> str:
        """
        Genera un resumen basado en TF-IDF seleccionando las oraciones mejor puntuadas.

        :param sentences: Lista de oraciones originales (sin procesar).
        :param preprocessed_sentences: Lista de oraciones preprocesadas (por ejemplo, tokenizadas o normalizadas).
        :param num_sentences: Número de oraciones a devolver en el resumen.
        :return: Un string que contiene el resumen formado por las oraciones más relevantes.
        """
        vectorizer = TfidfVectorizer()
        tfidf_matrix = vectorizer.fit_transform(preprocessed_sentences)
        sentence_scores = np.sum(tfidf_matrix.toarray(), axis=1)
        ranked_indices = np.argsort(sentence_scores)[::-1]
        selected = [sentences[i] for i in ranked_indices[:num_sentences]]
        return ' '.join(selected)


class TextRankSummarizer:
    """Genera resúmenes usando el algoritmo TextRank."""

    @staticmethod
    def summarize(sentences: List[str], preprocessed_sentences: List[str], num_sentences: int = 1) -> str:
        """
        Genera un resumen usando el algoritmo de grafos TextRank.

        :param sentences: Lista de oraciones originales.
        :param preprocessed_sentences: Lista de oraciones preprocesadas.
        :param num_sentences: Número de oraciones a devolver en el resumen.
        :return: Un string que contiene el resumen.
        """
        vectorizer = TfidfVectorizer()
        tfidf_matrix = vectorizer.fit_transform(preprocessed_sentences)
        similarity_matrix = cosine_similarity(tfidf_matrix)
        nx_graph = nx.from_numpy_array(similarity_matrix)
        scores = nx.pagerank(nx_graph)
        # Ordena los nodos (oraciones) por puntaje descendente
        ranked_indices = sorted(((scores[node], node) for node in nx_graph.nodes), reverse=True)
        selected = [sentences[i] for _, i in ranked_indices[:num_sentences]]
        return ' '.join(selected)


class CombinedSummarizer:
    """Genera resúmenes combinando palabras clave TF-IDF y TextRank."""

    def __init__(self, top_n_keywords: int = 10):
        """
        :param top_n_keywords: Número de palabras clave a extraer de cada método (TF-IDF y TextRank).
        """
        self.top_n_keywords = top_n_keywords

    def extract_keywords_tfidf(self, preprocessed_sentences: List[str]) -> List[str]:
        """
        Extrae palabras clave basadas en TF-IDF.

        :param preprocessed_sentences: Lista de oraciones preprocesadas.
        :return: Lista con las palabras clave más relevantes según TF-IDF.
        """
        vectorizer = TfidfVectorizer()
        tfidf_matrix = vectorizer.fit_transform(preprocessed_sentences)
        tfidf_scores = zip(vectorizer.get_feature_names_out(), tfidf_matrix.toarray().sum(axis=0))
        sorted_scores = sorted(tfidf_scores, key=lambda x: x[1], reverse=True)
        return [word for word, _ in sorted_scores[:self.top_n_keywords]]

    def extract_keywords_textrank(self, preprocessed_sentences: List[str]) -> List[str]:
        """
        Extrae palabras clave basadas en TextRank (a través de la co-ocurrencia de palabras).

        :param preprocessed_sentences: Lista de oraciones preprocesadas.
        :return: Lista con las palabras clave más relevantes según TextRank.
        """
        words = ' '.join(preprocessed_sentences).split()
        co_occurrence_graph = nx.Graph()
        for i in range(len(words) - 1):
            word_pair = (words[i], words[i + 1])
            if co_occurrence_graph.has_edge(*word_pair):
                co_occurrence_graph[word_pair[0]][word_pair[1]]['weight'] += 1
            else:
                co_occurrence_graph.add_edge(word_pair[0], word_pair[1], weight=1)

        ranks = nx.pagerank(co_occurrence_graph, weight='weight')
        sorted_ranks = sorted(ranks.items(), key=lambda x: x[1], reverse=True)
        return [word for word, _ in sorted_ranks[:self.top_n_keywords]]

    def combined_keywords(self, preprocessed_sentences: List[str]) -> List[str]:
        """
        Combina las palabras clave obtenidas tanto por TF-IDF como por TextRank
        y devuelve la intersección de ambas listas.

        :param preprocessed_sentences: Lista de oraciones preprocesadas.
        :return: Lista con las palabras clave en común (intersección).
        """
        tfidf_keywords = self.extract_keywords_tfidf(preprocessed_sentences)
        textrank_keywords = self.extract_keywords_textrank(preprocessed_sentences)
        return list(set(tfidf_keywords) & set(textrank_keywords))

    def summarize(self, sentences: List[str], preprocessed_sentences: List[str], num_sentences: int = 1) -> str:
        """
        Genera un resumen basado en la frecuencia de palabras clave combinadas (TF-IDF & TextRank).

        :param sentences: Lista de oraciones originales.
        :param preprocessed_sentences: Lista de oraciones preprocesadas.
        :param num_sentences: Número de oraciones a devolver en el resumen.
        :return: Un string con las oraciones más relevantes.
        """
        keywords = self.combined_keywords(preprocessed_sentences)
        sentence_scores = []
        for i, sentence in enumerate(preprocessed_sentences):
            score = sum(1 for word in sentence.split() if word in keywords)
            sentence_scores.append((score, i))
        # Ordena las oraciones por la cantidad de palabras clave presentes
        ranked_sentences = sorted(sentence_scores, key=lambda x: x[0], reverse=True)
        selected = [sentences[i] for _, i in ranked_sentences[:num_sentences]]
        return ' '.join(selected)


class BERTSummarizer:
    def __init__(self):
        self.model = Summarizer()

    def summarize(self, text, num_sentences):
        return self.model(text, num_sentences=num_sentences)