import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
from sklearn.feature_extraction.text import CountVectorizer
from bertopic import BERTopic
import torch
import numpy as np
from collections import Counter
import os
from wordcloud import WordCloud
import matplotlib.pyplot as plt
import folium
import gc
def clear_memory():
    """Free cached GPU memory and run Python garbage collection."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    
current_dir = os.path.dirname(os.path.abspath(__file__))
font_path = os.path.join(current_dir, "ArabicR2013-J25x.ttf")
ARABIC_STOP_WORDS = {
    'في', 'من', 'إلى', 'على', 'علي', 'عن', 'مع', 'خلال', 'حتي', 'حتى', 'إذا',
    
    'ثم', 'أو', 'و', 'ل', 'ب', 'ك', 'لل', 'ال', 'هذا', 
    'هذه', 'ذلك', 'تلك', 'هؤلاء', 'هم', 'هن', 'هو', 'هي','هنا', 'نحن',
    'انت', 'انتم', 'كان', 'كانت', 'يكون', 'تكون', 'اي', 'كل',
    'بعض', 'غير', 'حول', 'عند', 'قد', 'لقد', 'لم', 'لن', 'لو',
    'ما', 'ماذا', 'متى', 'كيف', 'اين', 'لماذا', 'الذي', 'التي',
    'الذين', 'اللاتي', 'اللواتي', 'الان', 'بين', 'فوق', 'تحت',
    'امام', 'خلف', 'حين', 'قبل', 'بعد', 'أن', 'له', 'كما', 'لها',
    'منذ', 'نفس', 'حيث', 'هناك', 'جدا', 'ذات', 'ضمن', 'انه', 'لدى',
    'عليه', 'مثل', 'أما', 'لدي', 'فيه', 'كلم', 'لكن', 'ايضا', 'لازم',
     'يجب', 'صار', 'صارت', 'ضد', 'يا', 'لا', 'اما',
    'بها', 'ان', 'به', 'الي', 'لما', 'انا', 'اليك', 'لي', 'لك','اذا','بلا','او','لديك','لديه','اني','كنت','ليس','ايها', 'قلت',
    'وثم', 'وأو', 'ول', 'وب', 'وك', 'ولل', 'وال', 
    'وهذا', 'وهذه', 'وذلك', 'وتلك', 'وهؤلاء', 'وهم', 'وهن', 'وهو', 'وهي', 'ونحن',
    'وانت', 'وانتم', 'وكان', 'وكانت', 'ويكون', 'وتكون', 'واي', 'وكل',
    'وبعض', 'وغير', 'وحول', 'وعند', 'وقد', 'ولقد', 'ولم', 'ولن', 'ولو',
    'وما', 'وماذا', 'ومتى', 'وكيف', 'واين', 'ولماذا', 'والذي', 'والتي',
    'والذين', 'واللاتي', 'واللواتي', 'والان', 'وبين', 'وفوق','وهنا', 'وتحت',
    'وامام', 'وخلف', 'وحين', 'وقبل', 'وبعد', 'وأن', 'وله', 'وكما', 'ولها',
    'ومنذ', 'ونفس', 'وحيث', 'وهناك', 'وجدا', 'وذات', 'وضمن', 'وانه', 'ولدى',
    'وعليه', 'ومثل', 'وأما', 'وفيه', 'وكلم', 'ولكن', 'وايضا', 'ولازم',
     'ويجب', 'وصار', 'وصارت', 'وضد', 'ويا', 'ولا', 'واما',
    'وبها', 'وان', 'وبه', 'والي', 'ولما', 'وانا', 'واليك', 'ولي', 'ولك', 'وقلت',
    
    'وفي', 'ومن', 'وعلى', 'وعلي', 'وعن', 'ومع', 'وحتى', 'وإذا',
    'منه', 'الا', 'فيها', 'فلا', 'وكم', 'يكن', 'عليك', 'منها', 'فما', 'لهم',
    'واني', 'هل', 'فهل', 'بي', 'نحو', 'كي', 'سوف', 'كنا', 'لنا', 'معا',
    'كلما', 'عنه', 'إذ', 'كم', 'بل', 'هكذا',
    'واحد', 'اثنان', 'ثلاثة', 'أربعة', 'خمسة', 'ستة', 'سبعة', 
    'ثمانية', 'تسعة', 'عشرة',
    'الأول', 'الثاني', 'الثالث', 'الرابع', 'الخامس', 'السادس', 
    'السابع', 'الثامن', 'التاسع', 'العاشر'
}
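# No stemming or orthographic normalization is applied before filtering, so the
# set above explicitly lists common clitic-prefixed variants (e.g. forms with a
# leading 'و') alongside their base forms.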
COUNTRY_MAPPING = {
    'مصر': 'Egypt',
    'السعودية': 'Saudi Arabia',
    'الإمارات': 'UAE',
    'الكويت': 'Kuwait',
    'العراق': 'Iraq',
    'سوريا': 'Syria',
    'لبنان': 'Lebanon',
    'الأردن': 'Jordan',
    'فلسطين': 'Palestine',
    'اليمن': 'Yemen',
    'عمان': 'Oman',
    'قطر': 'Qatar',
    'البحرين': 'Bahrain',
    'السودان': 'Sudan',
    'ليبيا': 'Libya',
    'تونس': 'Tunisia',
    'الجزائر': 'Algeria',
    'المغرب': 'Morocco',
    'موريتانيا': 'Mauritania'
}
st.set_page_config(
    page_title="Contemporary Arabic Poetry Analysis",
    page_icon="📚",
    layout="wide"
)
@st.cache_resource
def load_models():
    """Load and cache the models to prevent reloading"""
    tokenizer = AutoTokenizer.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sentiment")
    bert_model = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")
    emotion_model = AutoModelForSequenceClassification.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sentiment")
    emotion_tokenizer = AutoTokenizer.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sentiment")
    emotion_classifier = pipeline(
        "sentiment-analysis",
        model=emotion_model,
        tokenizer=emotion_tokenizer,
        return_all_scores=True
    )
    return tokenizer, bert_model, emotion_classifier
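# st.cache_resource keeps a single shared copy of these models per server
# process, so reruns triggered by widget interactions reuse them instead of
# reloading from disk.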
def split_text(text, max_length=512):
    """Split text into chunks of at most max_length words, preserving word boundaries.

    Word count is a cheap proxy for token count; the tokenizer later truncates
    any chunk that still exceeds the model's 512-token limit.
    """
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0

    for word in words:
        if current_length + 1 > max_length:
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            current_chunk = [word]
            current_length = 1
        else:
            current_chunk.append(word)
            current_length += 1

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks
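# Illustrative example of the word-count chunking:
#   split_text("w1 w2 w3 w4 w5", max_length=2) -> ["w1 w2", "w3 w4", "w5"]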
    
def get_country_coordinates():
    """Returns dictionary of Arab country coordinates"""
    return {
        'Egypt': [26.8206, 30.8025],
        'Saudi Arabia': [23.8859, 45.0792],
        'UAE': [23.4241, 53.8478],
        'Kuwait': [29.3117, 47.4818],
        'Iraq': [33.2232, 43.6793],
        'Syria': [34.8021, 38.9968],
        'Lebanon': [33.8547, 35.8623],
        'Jordan': [30.5852, 36.2384],
        'Palestine': [31.9522, 35.2332],
        'Yemen': [15.5527, 48.5164],
        'Oman': [21.4735, 55.9754],
        'Qatar': [25.3548, 51.1839],
        'Bahrain': [26.0667, 50.5577],
        'Sudan': [12.8628, 30.2176],
        'Libya': [26.3351, 17.2283],
        'Tunisia': [33.8869, 9.5375],
        'Algeria': [28.0339, 1.6596],
        'Morocco': [31.7917, -7.0926],
        'Mauritania': [21.0079, -10.9408]
    }
def create_topic_map(summaries):
    """Plot one circle marker per country, colored by its dominant sentiment."""
    coordinates = get_country_coordinates()
    m = folium.Map(location=[27.0, 42.0], zoom_start=5)

    sentiment_colors = {
        'LABEL_1': 'green',  # Positive
        'LABEL_0': 'red',    # Negative
        'LABEL_2': 'blue'    # Neutral
    }

    # Maps the display names produced by format_emotions back to raw labels.
    reverse_emotion_labels = {
        'positive': 'LABEL_1',
        'negative': 'LABEL_0',
        'neutral': 'LABEL_2'
    }
    
    for summary in summaries:
        country_en = COUNTRY_MAPPING.get(summary['country'])
        if country_en and country_en in coordinates:
            # format_emotions capitalizes labels ('Positive'), so compare lowercased.
            dominant_emotion = summary['top_emotions'][0]['emotion'] if summary['top_emotions'] else "neutral"
            dominant_label = reverse_emotion_labels.get(dominant_emotion.lower(), 'LABEL_2')
            circle_color = sentiment_colors.get(dominant_label, 'gray')

            emotion_lines = '<br>'.join(
                f"• {e['emotion']}: {e['count']}" for e in summary['top_emotions'][:3]
            )
            top_topic = summary['top_topics'][0]['topic'] if summary['top_topics'] else 'No topics'
            popup_content = f"""
                <b>{country_en}</b><br>
                <b>Sentiment Distribution:</b><br>
                {emotion_lines}<br>
                <b>Top Topic:</b> {top_topic}<br>
                <b>Total Poems:</b> {summary['total_poems']}
            """

            folium.CircleMarker(
                location=coordinates[country_en],
                radius=10,
                popup=folium.Popup(popup_content, max_width=300),
                color=circle_color,
                fill=True
            ).add_to(m)
    
    legend_html = """
    <div style="position: fixed; bottom: 30px; left: 30px; z-index: 9999;
                background: white; padding: 8px 12px; border: 1px solid grey;">
        <b>Sentiment:</b><br>
        <span style="color: green;">●</span> Positive<br>
        <span style="color: red;">●</span> Negative<br>
        <span style="color: blue;">●</span> Neutral
    </div>
    """
    m.get_root().html.add_child(folium.Element(legend_html))
    
    return m
def create_arabic_wordcloud(text, title):
    """Generate a word cloud figure for Arabic text using the bundled font."""
    wordcloud = WordCloud(
        width=1200, 
        height=600,
        background_color='white',
        font_path=font_path,
        max_words=200,
        stopwords=ARABIC_STOP_WORDS
    ).generate(text)
    
    fig, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(wordcloud, interpolation='bilinear')
    ax.axis('off')
    ax.set_title(title, fontsize=16, pad=20)
    return fig
def clean_arabic_text(text):
    """Remove stop words and single-character tokens from Arabic text."""
    words = text.split()
    cleaned_words = [word for word in words if word not in ARABIC_STOP_WORDS and len(word) > 1]
    return ' '.join(cleaned_words)
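# Example: clean_arabic_text("في البيت قصيدة") -> "البيت قصيدة"
# ('في' is a stop word, and single-character tokens are dropped).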
def classify_emotion(text, classifier):
    """Classify emotion for complete text with proper token handling."""
    try:
        words = text.split()
        chunks = []
        current_chunk = []
        current_length = 0
        
        for word in words:
            # Count subword tokens without per-call special tokens, keeping two
            # slots free for [CLS]/[SEP].
            word_tokens = len(classifier.tokenizer.encode(word, add_special_tokens=False))
            if current_length + word_tokens > 510:
                if current_chunk:
                    chunks.append(' '.join(current_chunk))
                current_chunk = [word]
                current_length = word_tokens
            else:
                current_chunk.append(word)
                current_length += word_tokens
        
        if current_chunk:
            chunks.append(' '.join(current_chunk))
        
        if not chunks:
            chunks = [text]
        
        all_scores = []
        for chunk in chunks:
            try:
                # The pipeline tokenizes internally; truncation is a final safety net.
                result = classifier(chunk, truncation=True, max_length=512)
                scores = result[0]
                all_scores.append(scores)
            except Exception as chunk_error:
                st.warning(f"Skipping chunk due to error: {str(chunk_error)}")
                continue
        
        if all_scores:
            label_scores = {}
            count = len(all_scores)
            
            for scores in all_scores:
                for score in scores:
                    label = score['label']
                    if label not in label_scores:
                        label_scores[label] = 0
                    label_scores[label] += score['score']
            
            avg_scores = {label: score/count for label, score in label_scores.items()}
            final_emotion = max(avg_scores.items(), key=lambda x: x[1])[0]
            return final_emotion
        
        return "LABEL_2"
        
    except Exception as e:
        st.warning(f"Error in emotion classification: {str(e)}")
        return "LABEL_2"
def get_embedding_for_text(text, tokenizer, model):
    """Get embedding for complete text."""
    chunks = split_text(text)
    chunk_embeddings = []
    successful_chunks = []  # kept in step with chunk_embeddings for weighting

    for chunk in chunks:
        try:
            inputs = tokenizer(
                chunk,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model(**inputs)

            # Use the [CLS] token's hidden state as the chunk representation.
            embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            chunk_embeddings.append(embedding[0])
            successful_chunks.append(chunk)
        except Exception as e:
            st.warning(f"Error processing chunk: {str(e)}")
            continue

    if chunk_embeddings:
        # Weight chunks by word count so longer chunks contribute more.
        weights = np.array([len(chunk.split()) for chunk in successful_chunks], dtype=float)
        weights = weights / weights.sum()
        weighted_embedding = np.average(chunk_embeddings, axis=0, weights=weights)
        return weighted_embedding
    return np.zeros(model.config.hidden_size)
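# The document embedding is a length-weighted average of chunk embeddings:
#   e_doc = sum_i (n_i / N) * e_i
# where n_i is the word count of chunk i and N the total across kept chunks.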
def format_topics(topic_model, topic_counts):
    """Format topics for display."""
    formatted_topics = []
    for topic_num, count in topic_counts:
        if topic_num == -1:
            topic_label = "Miscellaneous"
        else:
            words = topic_model.get_topic(topic_num)
            topic_label = " | ".join([word for word, _ in words[:5]])
        
        formatted_topics.append({
            'topic': topic_label,
            'count': count
        })
    return formatted_topics
def format_emotions(emotion_counts):
    """Format emotions for display."""
    EMOTION_LABELS = {
        'LABEL_0': 'Negative',
        'LABEL_1': 'Positive',
        'LABEL_2': 'Neutral'
    }
    
    formatted_emotions = []
    for label, count in emotion_counts:
        emotion = EMOTION_LABELS.get(label, label)
        formatted_emotions.append({
            'emotion': emotion,
            'count': count
        })
    return formatted_emotions
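# Example: format_emotions([('LABEL_1', 12), ('LABEL_0', 5)])
#   -> [{'emotion': 'Positive', 'count': 12}, {'emotion': 'Negative', 'count': 5}]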
    
def process_and_summarize(df, bert_tokenizer, bert_model, emotion_classifier, top_n=50, topic_strategy="Auto", n_topics=None, min_topic_size=3):
    """Process the data and generate summaries with flexible topic configuration."""
    summaries = []
    
    topic_model_params = {
        "language": "arabic",
        "calculate_probabilities": True,
        "min_topic_size": min_topic_size,  # use the caller-supplied value
        "n_gram_range": (1, 1),
        "top_n_words": 15,
        "verbose": True,
    }
    st.write(f"Total documents: {len(df)}")
    st.write(f"Topic strategy: {topic_strategy}")
    st.write(f"Min topic size: {min_topic_size}")
    
    if topic_strategy == "Manual":
        topic_model_params["nr_topics"] = n_topics
    else:
        topic_model_params["nr_topics"] = "auto"
    
    # Embeddings are precomputed per document below, so no embedding model is
    # attached to BERTopic itself (a raw AutoModel is not a valid backend).
    vectorizer = CountVectorizer(stop_words=list(ARABIC_STOP_WORDS),
                                 min_df=1,
                                 max_df=1.0)
    topic_model = BERTopic(
        embedding_model=None,
        vectorizer_model=vectorizer,
        **topic_model_params)
    
    for country, group in df.groupby('country'):
        progress_text = f"Processing poems for {country}..."
        progress_bar = st.progress(0, text=progress_text)
        
        texts = [clean_arabic_text(poem) for poem in group['poem'].dropna()]
        all_emotions = []

        embeddings = []
        valid_texts = []  # texts kept in step with their embeddings
        clear_memory()

        for i, text in enumerate(texts):
            try:
                embedding = get_embedding_for_text(text, bert_tokenizer, bert_model)
                if embedding is not None and not np.isnan(embedding).any():
                    embeddings.append(embedding)
                    valid_texts.append(text)
                else:
                    st.warning(f"Invalid embedding generated for text {i+1} in {country}")
            except Exception as e:
                st.warning(f"Error generating embedding for text {i+1} in {country}: {str(e)}")
            if i % 10 == 0:
                clear_memory()

            progress = (i + 1) / len(texts) * 0.4
            progress_bar.progress(progress, text=f"Generated embeddings for {i+1}/{len(texts)} poems...")

        # Keep only the texts whose embeddings succeeded so both lists stay aligned.
        texts = valid_texts
        embeddings = np.array(embeddings)
        
        clear_memory()
        for i, text in enumerate(texts):
            emotion = classify_emotion(text, emotion_classifier)
            all_emotions.append(emotion)
            if i % 10 == 0:
                clear_memory()
            progress = 0.4 + ((i + 1) / len(texts) * 0.3)
            progress_bar.progress(progress, text=f"Classified emotions for {i+1}/{len(texts)} poems...")
        try:
            if len(texts) < min_topic_size:
                st.warning(f"Not enough documents for {country} to generate meaningful topics (minimum {min_topic_size} required)")
                continue

            # The shared topic_model is refit for every country, so the model
            # returned at the end reflects the most recently processed country.
            topics, probs = topic_model.fit_transform(texts, embeddings)

            topic_counts = Counter(topics)
            
            top_topics = format_topics(topic_model, topic_counts.most_common(top_n))
            top_emotions = format_emotions(Counter(all_emotions).most_common(top_n))
            
            summaries.append({
                'country': country,
                'total_poems': len(texts),
                'top_topics': top_topics,
                'top_emotions': top_emotions
            })
            progress_bar.progress(1.0, text="Processing complete!")
            
        except Exception as e:
            st.warning(f"Could not generate topics for {country}: {str(e)}")
            continue
    return summaries, topic_model
try:
    bert_tokenizer, bert_model, emotion_classifier = load_models()
    st.success("Models loaded successfully!")
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
    st.stop()
# Main app interface
st.title("📚 Contemporary Arabic Poetry Analysis")
st.write("Upload a CSV or Excel file containing Arabic poems with columns `country` and `poem`.")
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
if uploaded_file is not None:
    try:
        if uploaded_file.name.endswith('.csv'):
            df = pd.read_csv(uploaded_file)
        else:
            df = pd.read_excel(uploaded_file)
        
        required_columns = ['country', 'poem']
        if not all(col in df.columns for col in required_columns):
            st.error("File must contain 'country' and 'poem' columns.")
            st.stop()
        
        df['country'] = df['country'].str.strip()
        df = df.dropna(subset=['country', 'poem'])
        # Cap the analysis at 20 poems per country to bound runtime.
        sampled_df = df.groupby('country').head(20).reset_index(drop=True)
        
        st.subheader("Topic Modeling Settings")
        col1, col2 = st.columns(2)
        
        with col1:
            topic_strategy = st.radio(
                "Topic Number Strategy",
                ["Auto", "Manual"],
                help="Choose whether to let the model determine the optimal number of topics or set it manually"
            )
            
            if topic_strategy == "Manual":
                n_documents = len(df)
                max_topics = 500
                min_topics = 5
                default_topics = 20
                
                n_topics = st.slider(
                    "Number of Topics",
                    min_value=min_topics,
                    max_value=max_topics,
                    value=default_topics,
                    help=f"Select the desired number of topics (max {max_topics} based on dataset size)"
                )
                
                st.info(f"""
                    💡 For your dataset of {n_documents:,} documents:
                    - Available topic range: {min_topics}-{max_topics}
                    - Recommended range: {max_topics//10}-{max_topics//3} for optimal coherence
                    """)
        
        with col2:
            top_n = st.number_input(
                "Number of top topics/emotions to display:", 
                min_value=1, 
                max_value=100, 
                value=10
            )
        if st.button("Process Data"):
            with st.spinner("Processing your data..."):
                summaries, topic_model = process_and_summarize(
                    sampled_df,
                    bert_tokenizer,
                    bert_model,
                    emotion_classifier,
                    top_n=top_n,
                    topic_strategy=topic_strategy,
                    n_topics=n_topics if topic_strategy == "Manual" else None,
                    min_topic_size=3
                )
                                
                if summaries:
                    st.success("Analysis complete!")
                    
                    tab1, tab2, tab3 = st.tabs(["Country Summaries", "Global Topics", "Topic Map"])
                    
                    with tab1:
                        for summary in summaries:
                            with st.expander(f"📍 {summary['country']} ({summary['total_poems']} poems)"):
                                col1, col2 = st.columns(2)
                                
                                with col1:
                                    st.subheader("Top Topics")
                                    for topic in summary['top_topics']:
                                        st.write(f"• {topic['topic']}: {topic['count']} poems")
                                
                                with col2:
                                    st.subheader("Emotions")
                                    for emotion in summary['top_emotions']:
                                        st.write(f"• {emotion['emotion']}: {emotion['count']} poems")
                                st.subheader("Word Cloud Visualization")
                                country_poems = df[df['country'] == summary['country']]['poem']
                                combined_text = ' '.join(country_poems)
                                wordcloud_fig = create_arabic_wordcloud(combined_text, f"Most Common Words in {summary['country']} Poems")
                                st.pyplot(wordcloud_fig)                                
                    
                    with tab2:
                        st.subheader("Global Topic Distribution")
                        topic_info = topic_model.get_topic_info()
                        for _, row in topic_info.iterrows():
                            if row['Topic'] == -1:
                                topic_name = "Miscellaneous"
                            else:
                                words = topic_model.get_topic(row['Topic'])
                                topic_name = " | ".join([word for word, _ in words[:5]])
                            st.write(f"• Topic {row['Topic']}: {topic_name} ({row['Count']} poems)")
                    with tab3:
                        st.subheader("Topic and Sentiment Distribution Map")
                        topic_map = create_topic_map(summaries)
                        st.components.v1.html(topic_map._repr_html_(), height=600)
    
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
else:
    st.info("👆 Upload a file to get started!")
    
    st.write("### Expected File Format:")
    example_df = pd.DataFrame({
        'country': ['Egypt', 'Palestine'],
        'poem': ['قصيدة مصرية', 'قصيدة فلسطينية']
    })
    st.dataframe(example_df)