import streamlit as st
import json
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from modelling_cnn import CNNForNER, SentimentCNNModel
import pandas as pd
import altair as alt

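# Previous custom Yoruba CNN NER model, kept for reference only (inactive: the pretrained MasakhaNER model loaded below is used instead)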
# ner_model_name = "./my_model/pytorch_model.bin"
# model_ner = "Testys/cnn_yor_ner"
# ner_tokenizer = AutoTokenizer.from_pretrained(model_ner)
# with open("./my_model/config.json", "r") as f:
#     ner_config = json.load(f)

# ner_model = CNNForNER(
#                       pretrained_model_name=ner_config["pretrained_model_name"],
#                       num_classes=ner_config["num_classes"]
#                       )
# ner_model.load_state_dict(torch.load(ner_model_name, map_location=torch.device('cpu')))
# ner_model.eval()

# Streamlit re-executes this script from top to bottom on every widget
# interaction, so loading checkpoints directly at module level re-reads them
# on every rerun. st.cache_resource keeps one loaded instance per process.
# show_spinner=False: this call happens before st.set_page_config() in main(),
# and emitting a spinner element first could conflict with that call.
@st.cache_resource(show_spinner=False)
def _load_ner_model(model_id="masakhane/afroxlmr-large-ner-masakhaner-1.0_2.0"):
    """Load the token-classification NER model and its tokenizer.

    Args:
        model_id: Hugging Face Hub id (or local path) of the checkpoint.

    Returns:
        ``(model, tokenizer)`` with the model already switched to eval mode.
    """
    model = AutoModelForTokenClassification.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model.eval()  # inference only: disables dropout
    return model, tokenizer


# Module-level names kept so the rest of the script can use them unchanged.
ner_model, ner_tokenizer = _load_ner_model()
ner_config = ner_model.config


# Load the Yoruba sentiment analysis model
sentiment_model_name = "./sent_model/sent_pytorch_model.bin"
model_sent = "Testys/cnn_sent_yor"


# Cached for the same reason as the NER model: Streamlit reruns the whole
# script on every interaction, and reloading the weights each time is wasteful.
@st.cache_resource(show_spinner=False)
def _load_sentiment_model(
    weights_path=sentiment_model_name,
    tokenizer_id=model_sent,
    config_path="./sent_model/config.json",
):
    """Load the sentiment tokenizer, its JSON config, and the CNN classifier.

    Args:
        weights_path: Path to the saved ``state_dict`` of ``SentimentCNNModel``.
        tokenizer_id: Hugging Face Hub id (or local path) of the tokenizer.
        config_path: JSON file providing ``pretrained_model_name``,
            ``num_classes`` (used here) and ``id2label`` (used by analyze_text).

    Returns:
        ``(tokenizer, config_dict, model)`` with the model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    model = SentimentCNNModel(
        transformer_model_name=config["pretrained_model_name"],
        num_classes=config["num_classes"],
    )
    # map_location lets GPU-saved weights load on a CPU-only host.
    model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
    model.eval()  # inference only: disables dropout
    return tokenizer, config, model


# Module-level names kept so the rest of the script can use them unchanged.
sentiment_tokenizer, sentiment_config, sentiment_model = _load_sentiment_model()


def analyze_text(text):
    """Run named-entity recognition and sentiment classification on ``text``.

    Args:
        text: Raw input string entered by the user.

    Returns:
        A ``(ner_labels, sentiment)`` tuple:

        * ``ner_labels`` -- list of ``"<token>: <label>"`` strings, one per
          sub-word token produced by ``ner_tokenizer``. The tokenizer's special
          tokens (e.g. sequence start/end markers) are omitted, because they
          are not part of the user's text.
        * ``sentiment`` -- the label string found in
          ``sentiment_config["id2label"]`` for the arg-max class.
    """
    # --- Named Entity Recognition -----------------------------------------
    # truncation=True caps the sequence at the tokenizer's model_max_length so
    # very long inputs cannot overflow the model's position embeddings.
    ner_inputs = ner_tokenizer(text, truncation=True, return_tensors="pt")
    tokens = ner_tokenizer.convert_ids_to_tokens(ner_inputs["input_ids"][0])

    with torch.no_grad():
        ner_logits = ner_model(**ner_inputs).logits

    # One predicted class id per token of the single sequence (batch index 0).
    predicted_ids = torch.argmax(ner_logits, dim=-1)[0].tolist()

    special_tokens = set(ner_tokenizer.all_special_tokens)
    ner_labels = [
        f"{token}: {ner_config.id2label[label_id]}"
        for token, label_id in zip(tokens, predicted_ids)
        if token not in special_tokens
    ]

    # --- Sentiment analysis -----------------------------------------------
    # NOTE(review): max_length=514 with padding="max_length" is kept as-is in
    # case the CNN head expects a fixed sequence length. For RoBERTa/XLM-R
    # style encoders, 514 *real* tokens can exceed the position-embedding
    # range. Confirm against how SentimentCNNModel was trained.
    sentiment_inputs = sentiment_tokenizer(
        text,
        max_length=514,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    with torch.no_grad():
        sentiment_outputs = sentiment_model(**sentiment_inputs)

    # argmax yields a class index (not a probability); id2label keys in the
    # JSON config are strings, hence str().
    sentiment_id = int(torch.argmax(sentiment_outputs, dim=1)[0])
    sentiment = sentiment_config["id2label"][str(sentiment_id)]

    return ner_labels, sentiment

def _render_entities(ner_labels):
    """Show per-token NER predictions as a table, highlighting entity rows.

    Args:
        ner_labels: List of ``"<token>: <label>"`` strings from analyze_text().
    """
    st.header("Named Entities")

    # Split on the LAST ": " -- entity labels never contain it, so this stays
    # correct even if a token text ever did.
    rows = [item.rsplit(": ", 1) for item in ner_labels]
    ner_df = pd.DataFrame(rows, columns=["Token", "Entity"])

    def _highlight_entity_rows(row):
        # Assumes BIO-style tags where "O" marks tokens outside any entity.
        # Highlighting the actual entities replaces the original
        # highlight_max, which on string columns only highlighted the
        # lexicographically largest value.
        style = "background-color: lightblue" if row["Entity"] != "O" else ""
        return [style] * len(row)

    st.dataframe(ner_df.style.apply(_highlight_entity_rows, axis=1))


def _render_sentiment(sentiment):
    """Show the predicted sentiment as a bar on a [-1, 1] axis plus a label.

    Args:
        sentiment: Label string returned by analyze_text().
    """
    st.header("Sentiment Analysis")

    # Fixed display score per class (not a model confidence). Compare
    # case-insensitively so labels such as "Positive" also map correctly.
    normalized = sentiment.lower()
    if normalized == "positive":
        sentiment_score = 0.8
    elif normalized == "negative":
        sentiment_score = -0.8
    else:
        sentiment_score = 0

    sentiment_df = pd.DataFrame({'sentiment': [sentiment_score]})
    chart = alt.Chart(sentiment_df).mark_bar().encode(
        x=alt.X('sentiment', scale=alt.Scale(domain=(-1, 1))),
        color=alt.condition(
            alt.datum.sentiment > 0,
            alt.value("green"),
            alt.value("red")
        )
    ).properties(width=600, height=100)

    st.altair_chart(chart)
    st.write(f"Sentiment: {sentiment.capitalize()}")


def main():
    """Build the Streamlit page: input box, analysis results, and notes."""
    # Must be the first Streamlit command executed on the page.
    st.set_page_config(page_title="YorubaCNN for NER and Sentiment Analysis", layout="wide")

    st.title("YorubaCNN Models for NER and Sentiment Analysis")

    text = st.text_area("Enter Yoruba text", "")

    if st.button("Analyze"):
        # Whitespace-only input is treated as empty.
        if text.strip():
            ner_labels, sentiment = analyze_text(text)
            _render_entities(ner_labels)
            _render_sentiment(sentiment)
        else:
            st.warning("Please enter some Yoruba text to analyze.")

    # Explanatory section.
    # NOTE(review): the active NER model is the pretrained
    # "masakhane/afroxlmr-large-ner-masakhaner-1.0_2.0" checkpoint rather than
    # a CNN -- the description below may need updating to match.
    with st.expander("About this analysis"):
        st.write("""
        This tool uses YorubaCNN models to perform two types of analysis on Yoruba text:
        
        1. **Named Entity Recognition (NER)**: Identifies and classifies named entities (e.g., person names, organizations) in the text.
        2. **Sentiment Analysis**: Determines the overall emotional tone of the text (positive, negative, or neutral).
        
        The models used are based on Convolutional Neural Networks (CNN) and are specifically trained for the Yoruba language.
        """)

    # Page-level CSS tweaks for alerts and the results table.
    st.markdown("""
        <style>
        .stAlert > div {
            padding-top: 20px;
            padding-bottom: 20px;
        }
        .stDataFrame {
            padding: 10px;
            border-radius: 5px;
            background-color: #f0f2f6;
        }
        </style>
        """, unsafe_allow_html=True)


if __name__ == "__main__":
    main()