import streamlit as st
import pandas as pd
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
from span_marker import SpanMarkerModel
from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance, TextGeneration, PartOfSpeech
from torch import cuda
from spacy.cli import download
import transformers
from torch import bfloat16
import os
import scipy.cluster.hierarchy as sch  # for hierarchical topic linkage

# ------------------------------------------------------------------------------
# Helper: resolve the configuration for the selected language
# ------------------------------------------------------------------------------
def get_language_config(selected_language):
    """
    Return a configuration dictionary for the selected language.
    It includes the spaCy model, the linguistic model used for entity
    detection (SpanMarker), and the DataForSEO parameters.
    """
    language_options = {
        "English (US)": {
            "spacy_model": "en_core_web_sm",
            "linguistic_model": "nbroad/span-marker-xdistil-l12-h384-orgs-v3",
            "dataforseo_params": {"language": "en-us"}
        },
        "English (UK)": {
            "spacy_model": "en_core_web_sm",  # spaCy non ha un modello UK specifico, si usa quello standard
            "linguistic_model": "nbroad/span-marker-xdistil-l12-h384-orgs-v3",
            "dataforseo_params": {"language": "en-gb"}
        },
        "Italiano": {
            "spacy_model": "it_core_news_sm",
            "linguistic_model": "nbroad/span-marker-xdistil-l12-h384-orgs-v3",  # Sostituire con il modello appropriato se disponibile
            "dataforseo_params": {"language": "it-it"}
        },
        "Español": {
            "spacy_model": "es_core_news_sm",
            "linguistic_model": "nbroad/span-marker-xdistil-l12-h384-orgs-v3",
            "dataforseo_params": {"language": "es-es"}
        },
        "Deutsch": {
            "spacy_model": "de_core_news_sm",
            "linguistic_model": "nbroad/span-marker-xdistil-l12-h384-orgs-v3",
            "dataforseo_params": {"language": "de-de"}
        },
        "Français": {
            "spacy_model": "fr_core_news_sm",
            "linguistic_model": "nbroad/span-marker-xdistil-l12-h384-orgs-v3",
            "dataforseo_params": {"language": "fr-fr"}
        }
    }
    return language_options.get(selected_language, language_options["English (US)"])
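# Example: get_language_config("Italiano")["spacy_model"] returns
# "it_core_news_sm"; any unknown language falls back to the "English (US)" entry.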

# ------------------------------------------------------------------------------
# Page configuration
# ------------------------------------------------------------------------------
st.set_page_config(
    page_title="Keywords Cluster for SEO",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'https://www.linkedin.com/in/francisco-nardi-212b338b/',
        'Report a bug': "https://www.linkedin.com/in/francisco-nardi-212b338b/",
        'About': "# A simple keywords clustering tool for SEO purpose."
    }
)

# Session-state initialization (optional; these flags are not updated elsewhere in this file)
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False
if 'analysis_complete' not in st.session_state:
    st.session_state.analysis_complete = False
if 'current_step' not in st.session_state:
    st.session_state.current_step = 0

# Custom CSS styles
st.markdown("""
    <style>
        .stProgress > div > div > div > div {
            background-color: #1f77b4;
        }
        .success-message {
            padding: 1rem;
            border-radius: 0.5rem;
            background-color: #d4edda;
            color: #155724;
            border: 1px solid #c3e6cb;
            margin-bottom: 1rem;
        }
        .info-box {
            padding: 1rem;
            border-radius: 0.5rem;
            background-color: #e2f0fd;
            border: 1px solid #b8daff;
            margin-bottom: 1rem;
        }
        .sidebar .sidebar-content {
            width: 400px !important;
        }
    </style>
    """, unsafe_allow_html=True)

# ------------------------------------------------------------------------------
# 1) Model loading with cache_resource
# ------------------------------------------------------------------------------
@st.cache_resource
def load_models(language_config):
    """Carica i modelli necessari con caching (una sola volta)."""
    with st.spinner("Loading models... This may take a few minutes."):
        try:
            # Download the spaCy model for the selected language
            spacy_model_name = language_config["spacy_model"]
            download(spacy_model_name)
            
            # SpanMarker model: entity detection (Brand/Unbranded)
            linguistic_model_name = language_config["linguistic_model"]
            if cuda.is_available():
                model_filter = SpanMarkerModel.from_pretrained(linguistic_model_name).cuda()
            else:
                model_filter = SpanMarkerModel.from_pretrained(linguistic_model_name)
            
            # SentenceTransformer embedding model (the same for every language)
            embedding_model = SentenceTransformer("all-mpnet-base-v2")
            
            return model_filter, embedding_model
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            raise

# ------------------------------------------------------------------------------
# 2) CSV reading with cache_data
# ------------------------------------------------------------------------------
@st.cache_data
def load_csv(file, skiprows, nrows):
    """Carica il CSV con caching."""
    df = pd.read_csv(file, skiprows=skiprows, nrows=nrows)
    return df
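# Expected CSV layout (illustrative): only the 'Keyword' column is required,
# and a 'Volume' column is carried over to the results if present, e.g.
#     Keyword,Volume
#     running shoes,5400
#     best trail runners,880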

# ------------------------------------------------------------------------------
# 3) Brand/Unbranded labeling with cache_data
# ------------------------------------------------------------------------------
@st.cache_data(hash_funcs={SpanMarkerModel: lambda _: None})  # the torch model is unhashable
def process_keywords(df, model_filter):
    """
    Detect 'Brand' keywords using the SpanMarker model.
    Returns a list with a 'Brand' or 'Unbranded' label for each keyword.
    """
    results = []
    total = len(df)
    progress_text = "Processing keywords..."
    progress_bar = st.progress(0, text=progress_text)
    
    for i, keyword in enumerate(df['Keyword']):
        try:
            entities = model_filter.predict([keyword])
            label = (
                "Brand"
                if entities and isinstance(entities[0], list) and any(entity.get("label") == "ORG" for entity in entities[0])
                else "Unbranded"
            )
            results.append(label)
        except Exception as e:
            st.error(f"Error processing keyword '{keyword}': {str(e)}")
            results.append("Unbranded")
        
        progress_bar.progress((i + 1) / total, text=f"{progress_text} ({i+1}/{total})")
    
    progress_bar.empty()
    return results
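# For reference (illustrative output, not taken verbatim from the library docs):
# SpanMarkerModel.predict(["acme corp pricing"]) returns one entity list per
# input, e.g. [[{"span": "acme corp", "label": "ORG", "score": 0.98, ...}]],
# which is the shape the isinstance/any check above relies on.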

# ------------------------------------------------------------------------------
# 4) Topic-model construction
# ------------------------------------------------------------------------------
def create_topic_model(embedding_model, model_params, language_config):
    """Crea e configura il modello di topic modeling."""
    try:
        # Quantization configuration for Hugging Face (4-bit NF4 via bitsandbytes)
        bnb_config = transformers.BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=bfloat16
        )

        # UMAP configuration (dimensionality reduction)
        umap_model = UMAP(
            n_neighbors=model_params['umap_n_neighbors'],
            n_components=model_params['umap_n_components'],
            min_dist=model_params['umap_min_dist'],
            metric='cosine',
            random_state=42
        )
        
        # HDBSCAN configuration (density-based clustering)
        hdbscan_model = HDBSCAN(
            min_cluster_size=model_params['min_cluster_size'],
            min_samples=model_params['min_samples'],
            metric='euclidean',
            cluster_selection_method='eom',
            prediction_data=True
        )
        
        # CountVectorizer configuration
        # NOTE: scikit-learn only ships a built-in stop-word list for English;
        # for other languages, pass a custom list of stop words here.
        vectorizer_model = CountVectorizer(
            stop_words="english",
            min_df=model_params['min_df'],
            max_df=model_params['max_df'],
            ngram_range=(model_params['ngram_min'], model_params['ngram_max'])
        )

        # Llama 2 configuration
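        # NOTE: meta-llama/Llama-2-7b-chat-hf is a gated checkpoint; it requires
        # an approved access request and a Hugging Face login token on this
        # machine. The 4-bit setup above also assumes a CUDA GPU with the
        # bitsandbytes package installed.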
        model_id = 'meta-llama/Llama-2-7b-chat-hf'
        
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
        
        model = transformers.AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            quantization_config=bnb_config,
            device_map='auto',
        )
        model.eval()

        generator = transformers.pipeline(
            model=model, 
            tokenizer=tokenizer,
            task='text-generation',
            temperature=model_params['llama_temperature'],
            max_new_tokens=model_params['llama_max_tokens'],
            repetition_penalty=model_params['llama_repetition_penalty']
        )

        # Prompt configuration
        system_prompt = """
        <s>[INST] <<SYS>>
        You are a helpful, respectful and honest assistant for labeling topics.
        <</SYS>>
        """

        example_prompt = """
        I have a topic that contains the following documents:
        - Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
        - Meat, but especially beef, is the worst food in terms of emissions.
        - Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
        
        The topic is described by the following keywords: 'meat, beef, eat, eating, emissions, steak, food, health, processed, chicken'.
        
        Based on the information about the topic above, please create a short label of this topic. Make sure to only return the label and nothing more.
        
        [/INST] Environmental impacts of eating meat
        """

        main_prompt = """
        [INST]
        I have a topic that contains the following documents:
        [DOCUMENTS]
        
        The topic is described by the following keywords: '[KEYWORDS]'.
        
        Based on the information about the topic above, please create a **short label** of this topic. 
        **Return only the label** and avoid adding any explanations or extra text such as 'topic'.
        [/INST]
        """

        prompt = system_prompt + example_prompt + main_prompt
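        # BERTopic's TextGeneration representation substitutes the [DOCUMENTS]
        # and [KEYWORDS] placeholders with each topic's representative documents
        # and top keywords before calling the text-generation pipeline.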
        
        # Create representation models
        keybert_model = KeyBERTInspired()
        # Use the spaCy model matching the selected language
        pos_model = PartOfSpeech(language_config["spacy_model"])
        mmr_model = MaximalMarginalRelevance(diversity=model_params['diversity_factor'])
        llama2 = TextGeneration(generator, prompt=prompt)
        
        representation_model = {
            "KeyBERT": keybert_model,
            "Llama2": llama2,
            "MMR": mmr_model,
            "POS": pos_model
        }
        
        return BERTopic(
            embedding_model=embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            vectorizer_model=vectorizer_model,
            representation_model=representation_model,
            top_n_words=model_params['top_n_words'],
            verbose=True
        )
    except Exception as e:
        st.error(f"Error creating topic model: {str(e)}")
        raise

# ------------------------------------------------------------------------------
# 5) Main analysis (the final analysis results are cached)
# ------------------------------------------------------------------------------
@st.cache_data(hash_funcs={
    SpanMarkerModel: lambda _: None,         # skip hashing for SpanMarker
    SentenceTransformer: lambda _: None      # skip hashing for SentenceTransformer
})
def run_analysis(df, model_filter, embedding_model, model_params, exclude_brand_keywords, language_config):
    """
    - Etichetta (facoltativo) come 'Brand' o 'Unbranded'
    - Filtra i brand se richiesto
    - Crea embeddings
    - Esegue il topic modeling
    - Restituisce il modello e il DataFrame dei risultati
    """
    # Se l'utente sceglie di escludere i brand, etichettiamo e filtriamo
    if exclude_brand_keywords:
        df['Label'] = process_keywords(df, model_filter)
        filtered_df = df[df['Label'] == 'Unbranded']
    else:
        df['Label'] = "Unbranded"
        filtered_df = df

    filtered_keywords = filtered_df['Keyword'].tolist()
    
    if not filtered_keywords:
        st.warning("No keywords found for analysis (perhaps all were branded).")
        return None, None
    
    # Generate embeddings
    embeddings = embedding_model.encode(filtered_keywords, show_progress_bar=True)
    
    # Create and fit the topic model (passing the language configuration as well)
    topic_model = create_topic_model(embedding_model, model_params, language_config)
    topics, probs = topic_model.fit_transform(filtered_keywords, embeddings)
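    # topics holds one topic id per keyword (-1 marks HDBSCAN outliers);
    # probs holds the corresponding cluster-membership probabilities.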
    
    # The reduced embeddings are exposed later via topic_model.umap_model.embedding_,
    # so they are not stored separately here.

    # Use the Llama 2-generated labels as the final topic labels
    llama_topic_labels = {
        topic: "".join(list(zip(*values))[0]) 
        for topic, values in topic_model.topic_aspects_["Llama2"].items()
    }
    llama_topic_labels[-1] = "Outlier Topic"
    topic_model.set_topic_labels(llama_topic_labels)
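    # topic_aspects_["Llama2"] maps each topic id to a list of (text, weight)
    # tuples; with the TextGeneration representation the generated label is the
    # first tuple's text and the remaining entries are empty padding, so the
    # join above collapses to the label itself (an assumption about current
    # BERTopic behavior; verify on your installed version).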

    # Retrieve topic info
    topic_info = topic_model.get_topic_info()
    topic_labels = dict(zip(topic_info["Topic"], topic_info["CustomName"]))      
    
    # Also keep the default BERTopic topic names
    bert_labels = dict(zip(topic_info["Topic"], topic_info["Name"]))
    
    # Build the results DataFrame
    results_df = pd.DataFrame({
        "Keyword": filtered_keywords,
        "Topic ID": topics,
        "Confidence": probs
    })
    
    # Add the Llama and BERTopic labels
    results_df["Llama label"] = [
        topic_labels.get(topic, "Outlier Topic") for topic in topics
    ]
    results_df["BERT label"] = [
        bert_labels.get(topic, "Outlier Topic") for topic in topics
    ]
    
    # If the CSV has a 'Volume' column, carry it over
    if "Volume" in filtered_df.columns:
        results_df["Volume"] = filtered_df["Volume"].values

    return topic_model, results_df

# ------------------------------------------------------------------------------
# 6) Main Streamlit app
# ------------------------------------------------------------------------------
def main():
    st.title("🔍 Keywords Cluster for SEO")
    
    # ------------------------------------------------------------------------------
    # Sidebar: language selection and configuration
    # ------------------------------------------------------------------------------
    with st.sidebar:
        st.header("Configuration")
        
        # Language selection
        selected_language = st.selectbox(
            "Select Language",
            ["English (US)", "English (UK)", "Italiano", "Español", "Deutsch", "Français"],
            index=0,
            help="Seleziona la lingua per l'analisi. Questo imposterà il modello spaCy, il modello linguistico per il rilevamento e i parametri per DataForSEO."
        )
        language_config = get_language_config(selected_language)
        
        # File upload and row-range configuration
        uploaded_file = st.file_uploader(
            "Upload CSV file",
            type="csv",
            help="File must contain a 'Keyword' column"
        )
        
        with st.expander("CSV Reading Options"):
            min_rows = st.number_input(
                "Start reading from row",
                min_value=1,
                value=1,
                help="Define the first row of the CSV file from which data should be read."
            )
            max_rows = st.number_input(
                "Read up to row",
                min_value=1,
                value=5000,
                help="Define the last row of the CSV file to read (inclusive)."
            )
        
        # Option to exclude brand keywords
        exclude_brands = st.checkbox(
            "Exclude Organization keywords",
            value=False,
            help="If enabled, organization-labeled keywords are excluded from the analysis. (ex. company ltd)"
        )
            
        # UMAP parameters
        with st.expander("UMAP Parameters"):
            umap_n_neighbors = st.slider("N Neighbors", 2, 100, 10)
            umap_n_components = st.slider("N Components", 2, 50, 2)
            umap_min_dist = st.slider("Min Distance", 0.0, 1.0, 0.0, 0.01)
            
        # HDBSCAN parameters
        with st.expander("HDBSCAN Parameters"):
            min_cluster_size = st.slider("Min Cluster Size", 2, 50, 5)
            min_samples = st.slider("Min Samples", 1, 20, 5)
            
        # Vectorizer parameters
        with st.expander("Vectorizer Parameters"):
            min_df_type = st.radio(
                "Min Document Frequency Type",
                ["Absolute", "Relative"],
                help="Absolute: minimum count of documents, Relative: minimum fraction of documents"
            )
            
            if min_df_type == "Absolute":
                min_df = st.number_input("Min Document Count", 1, 100, 2)
            else:
                min_df = st.slider("Min Document Fraction", 0.0, 0.5, 0.1, 0.01)
            
            max_df = st.slider(
                "Max Document Fraction", 
                min_value=float(min_df) if isinstance(min_df, float) else 0.5,
                max_value=1.0,
                value=0.95,
                step=0.05
            )
            
            min_df_text = (
                f"{min_df} documents" if isinstance(min_df, int)
                else f"{min_df:.0%} of documents"
            )
            st.info(
                f"Terms must appear in at least {min_df_text} "
                f"and at most {max_df:.0%} of documents"
            )
            
            ngram_min = st.number_input("N-gram Min", 1, 3, 1)
            ngram_max = st.number_input("N-gram Max", 1, 3, 2)
            
        # Topic-model parameters
        with st.expander("Topic Model Parameters"):
            top_n_words = st.slider("Top N Words", 5, 30, 10)
            diversity_factor = st.slider("Topic Diversity", 0.0, 1.0, 0.3)
            
        # Llama 2 parameters
        with st.expander("Llama 2 Parameters"):
            llama_temperature = st.slider("Temperature", 0.0, 1.0, 0.1, 0.1)
            llama_max_tokens = st.slider("Max Tokens", 50, 200, 100)
            llama_repetition_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.1, 0.1)
            
        # Help section
        with st.expander("ℹ️ Help"):
            st.markdown("""
            **How to use this app:**
            1. Upload a CSV file with keywords
            2. Configure CSV reading options
            3. (Optionally) check "Exclude Organization keywords"
            4. Adjust model parameters if needed
            5. Click 'Start Analysis'
            6. Wait for results to appear
            
            **Advanced Parameters:**
            - UMAP: Controls dimensionality reduction
            - HDBSCAN: Controls clustering behavior
            - Vectorizer: Controls text preprocessing
            - Topic Model: Controls topic generation
            - Llama 2: Controls topic labeling
            
            **Language Selection:**
            Choosing a language sets:
            - The spaCy model to use (e.g. 'en_core_web_sm' for English or 'it_core_news_sm' for Italian)
            - The linguistic model for entity detection (SpanMarker); replace the placeholders with language-specific models if available
            - The DataForSEO parameters (e.g. a language code such as 'en-us', 'it-it', etc.)
            """)
    
    # ------------------------------------------------------------------------------
    # 7) Assemble the parameter dictionary for the topic model
    # ------------------------------------------------------------------------------
    model_params = {
        'umap_n_neighbors': umap_n_neighbors,
        'umap_n_components': umap_n_components,
        'umap_min_dist': umap_min_dist,
        'min_cluster_size': min_cluster_size,
        'min_samples': min_samples,
        'min_df': min_df,
        'max_df': max_df,
        'ngram_min': ngram_min,
        'ngram_max': ngram_max,
        'top_n_words': top_n_words,
        'diversity_factor': diversity_factor,
        'llama_temperature': llama_temperature,
        'llama_max_tokens': llama_max_tokens,
        'llama_repetition_penalty': llama_repetition_penalty
    }
    
    # ------------------------------------------------------------------------------
    # 8) If a file was uploaded, proceed
    # ------------------------------------------------------------------------------
    if uploaded_file is not None:
        try:
            # Load data with caching
            df = load_csv(
                file=uploaded_file,
                skiprows=min_rows - 1,
                nrows=max_rows - min_rows + 1
            )
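            # Example: min_rows=1, max_rows=5000 -> skip 0 data rows, read 5000 rows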
            
            if 'Keyword' not in df.columns:
                st.error("CSV must contain a 'Keyword' column")
                return
            
            # Data preview
            with st.expander("Preview Data", expanded=True):
                st.write(f"Reading rows {min_rows} to {max_rows}")
                st.dataframe(
                    df.head(),
                    use_container_width=True
                )
                st.write(f"Total rows loaded: {len(df)}")
            
            # Button to start the analysis
            if st.button("Start Analysis", type="primary"):
                try:
                    # Load the models (cache_resource) with the language configuration
                    with st.spinner("Loading models..."):
                        model_filter, embedding_model = load_models(language_config)
                    
                    # Run the analysis (cache_data)
                    with st.spinner("Processing data..."):
                        topic_model, results_df = run_analysis(
                            df,
                            model_filter,
                            embedding_model,
                            model_params,
                            exclude_brand_keywords=exclude_brands,
                            language_config=language_config
                        )
                        
                        if topic_model is None or results_df is None:
                            st.error("Analysis failed!")
                            return
                    
                    # Show the configuration summary
                    with st.expander("Configuration Summary", expanded=False):
                        st.subheader("Model Parameters")
                        st.json(model_params)
                        st.subheader("Language Configuration")
                        st.json(language_config)
                    
                    # ------------------------------------------------------------------------------
                    # 9) Display results
                    # ------------------------------------------------------------------------------
                    st.write("### Results Table")
                    st.dataframe(results_df, use_container_width=True, hide_index=True)
                    
                    # Display the interactive dashboard
                    st.write("### Interactive Topic Visualization")
                    try:
                        # Document map from the reduced embeddings
                        fig = topic_model.visualize_documents(
                            results_df['Keyword'].tolist(),
                            reduced_embeddings=topic_model.umap_model.embedding_,
                            hide_annotations=True,
                            hide_document_hover=False,
                            custom_labels=True
                        )
                        st.plotly_chart(fig, theme="streamlit", use_container_width=True)

                        # Topic visualization
                        st.write("### Topic Overview")
                        try:
                            topic_fig = topic_model.visualize_topics(custom_labels=True)
                            st.plotly_chart(topic_fig, theme="streamlit", use_container_width=True)
                        except Exception as e:
                            st.error(f"Error creating topic visualization: {str(e)}")

                        # Topic bar chart
                        st.write("### Topic Distribution")
                        try:
                            n_topics = len(topic_model.get_topic_info())
                            n_topics = min(50, max(1, n_topics - 1))  # -1 to exclude the outlier topic
                            
                            barchart_fig = topic_model.visualize_barchart(
                                top_n_topics=n_topics,
                                custom_labels=True
                            )
                            st.plotly_chart(barchart_fig, theme="streamlit", use_container_width=True)
                        except Exception as e:
                            st.error(f"Error creating barchart visualization: {str(e)}")
                        
                        # (A) Hierarchical topic visualization
                        st.write("### Hierarchical Topics")
                        try:
                            docs = results_df["Keyword"].tolist()
                            linkage_function = lambda x: sch.linkage(x, 'single', optimal_ordering=True)
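                            # 'single' linkage merges clusters through their closest
                            # members; optimal_ordering=True reorders the dendrogram
                            # leaves so adjacent topics are as similar as possible.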
                            hierarchical_topics = topic_model.hierarchical_topics(
                                docs,
                                linkage_function=linkage_function
                            )
                            # Hierarchy chart
                            fig_hierarchy = topic_model.visualize_hierarchy(
                                hierarchical_topics=hierarchical_topics,
                                custom_labels=True
                            )
                            st.plotly_chart(fig_hierarchy, theme="streamlit", use_container_width=True)
                        
                            # (B) Text rendering of the topic tree
                            st.write("### Hierarchical Topic Tree")
                            tree = topic_model.get_topic_tree(hierarchical_topics)
                            st.text(tree)  # or st.code(tree) for a formatted block
                        
                        except Exception as e:
                            st.error(f"Error creating hierarchical visualization: {str(e)}")
                    
                        # Download results as CSV
                        st.download_button(
                            label="Download Results",
                            data=results_df.to_csv(index=False),
                            file_name="keyword_analysis_results.csv",
                            mime="text/csv",
                            key="download_results"
                        )
                    except Exception as e:
                        st.error(f"An error occurred: {str(e)}")
                        
                except Exception as e:
                    st.error(f"An error occurred: {str(e)}")
        except Exception as e:
            st.error(f"Error reading file: {str(e)}")

    else:
        # Welcome message
        st.info("""
        👋 Welcome to the Keywords Cluster for SEO!
        
        1. Upload a CSV file with a column named **'Keyword'**.
        2. Adjust parameters in the sidebar if needed.
        3. Click **"Start Analysis"**.
        4. Explore the data.
        5. Download the results (this will refresh the page).
        """)


if __name__ == "__main__":
    main()
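
# To run locally (assuming this file is saved as app.py):
#     streamlit run app.py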