SoLProject / app.py
kambris's picture
Update app.py
6bd6b44 verified
raw
history blame
5.29 kB
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
from bertopic import BERTopic
import torch
import numpy as np
from collections import Counter
# Load the AraBERT tokenizer/encoder (used for embeddings) and the
# text-classification pipeline (used for emotions).
#
# Streamlit re-executes this whole script on every widget interaction, so
# loading the models at plain module level would re-read them on every rerun
# (e.g. each time `top_n` is changed). `st.cache_resource` keeps one shared
# instance alive across reruns instead.
@st.cache_resource
def _load_models():
    """Load and return (tokenizer, encoder model, classifier model, classifier pipeline)."""
    tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")
    encoder = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")
    # NOTE(review): this loads the same checkpoint as the encoder. If it is a
    # plain pre-trained (not emotion-fine-tuned) checkpoint, the
    # classification head is freshly initialised and the predicted labels
    # carry no emotion meaning — confirm, or swap in a fine-tuned emotion model.
    classifier_model = AutoModelForSequenceClassification.from_pretrained("aubmindlab/bert-base-arabertv2")
    classifier = pipeline("text-classification", model=classifier_model, tokenizer=tokenizer)
    return tokenizer, encoder, classifier_model, classifier


# Module-level names kept for the rest of the script.
bert_tokenizer, bert_model, emotion_model, emotion_classifier = _load_models()
# Sentence embeddings via AraBERT mean pooling
def generate_embeddings(texts):
    """Embed each text by averaging AraBERT's last hidden layer over its tokens.

    Each text is encoded on its own and truncated to 512 tokens. Because a
    single sequence is padded only to its own length, no padding tokens enter
    the average.

    Args:
        texts: iterable of strings.

    Returns:
        numpy array holding one embedding vector per input text, in order.
    """
    def _embed_one(text):
        encoded = bert_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            hidden = bert_model(**encoded).last_hidden_state
        # Average over the token axis, then drop the size-1 batch axis.
        return hidden.mean(dim=1).numpy()[0]

    return np.array([_embed_one(text) for text in texts])
# Emotion classification with token-level truncation
def classify_emotions(texts, max_length=512):
    """Predict one emotion label per text with the classification pipeline.

    Long texts are truncated by the tokenizer to ``max_length`` tokens
    (special tokens included), so only the beginning of a long poem is
    classified. This is the same "first chunk" policy as before, but
    measured in tokens rather than characters.

    Args:
        texts: iterable of strings.
        max_length: maximum number of tokens fed to the model; defaults to
            512, the sequence limit this file uses for AraBERT.

    Returns:
        list of label strings, in the same order as ``texts``.
    """
    # Slicing by characters (text[:512]) does not bound the token count: a
    # 512-character slice can still tokenize to more than 512 tokens plus
    # [CLS]/[SEP] and overflow the model. Letting the tokenizer truncate
    # guarantees the input fits, and avoids tokenizing each text twice.
    return [
        emotion_classifier(text, truncation=True, max_length=max_length)[0]['label']
        for text in texts
    ]
# Read the uploaded table and summarise topics/emotions per country
def process_and_summarize(uploaded_file, top_n=50):
    """Summarise topics and emotions of the uploaded poems, per country.

    Args:
        uploaded_file: uploaded file object with a ``name`` attribute, readable
            by pandas, holding a ``.csv`` or ``.xlsx`` table with ``country``
            and ``poem`` columns.
        top_n: how many of the most frequent topics / emotions to keep per
            country.

    Returns:
        ``(summaries, topic_model)``. ``summaries`` is a list of dicts with
        keys ``country``, ``total_poems``, ``top_topics`` and ``top_emotions``
        (countries whose topic fit fails are skipped with a warning).
        ``topic_model`` is a BERTopic model fitted on the poems of all
        countries together. ``(None, None)`` when the file format or columns
        are unusable; an error is shown in the UI in that case.
    """
    # Case-insensitive so that e.g. "POEMS.CSV" is accepted as well.
    file_name = uploaded_file.name.lower()
    if file_name.endswith(".csv"):
        df = pd.read_csv(uploaded_file)
    elif file_name.endswith(".xlsx"):
        df = pd.read_excel(uploaded_file)
    else:
        st.error("Unsupported file format.")
        return None, None

    required_columns = ['country', 'poem']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        st.error(f"Missing columns: {', '.join(missing_columns)}")
        return None, None

    # Drop incomplete rows first, then coerce to str: `.str.strip()` raises on
    # an all-numeric column, and the tokenizer cannot handle non-string poems.
    df = df.dropna(subset=['country', 'poem']).assign(
        country=lambda frame: frame['country'].astype(str).str.strip(),
        poem=lambda frame: frame['poem'].astype(str),
    )

    summaries = []
    all_texts = []
    all_embeddings = []
    for country, group in df.groupby('country'):
        st.info(f"Processing poems for {country}...")
        texts = group['poem'].tolist()

        st.info(f"Classifying emotions for {country}...")
        emotions = classify_emotions(texts)

        st.info(f"Generating embeddings and topics for {country}...")
        embeddings = generate_embeddings(texts)
        all_texts.extend(texts)
        all_embeddings.append(embeddings)

        # A fresh model per country: refitting one shared instance would
        # overwrite the previous country's fit on every iteration.
        country_model = BERTopic(language="arabic")
        try:
            topics, _ = country_model.fit_transform(texts, embeddings)
        except Exception as e:  # best-effort: skip countries whose fit fails
            st.warning(f"Could not generate topics for {country}: {str(e)}")
            continue

        summaries.append({
            'country': country,
            'total_poems': len(texts),
            'top_topics': Counter(topics).most_common(top_n),
            'top_emotions': Counter(emotions).most_common(top_n),
        })

    # The UI presents this model as the global topic overview, so fit it on
    # every poem at once, reusing the embeddings computed above.
    topic_model = BERTopic(language="arabic")
    if all_texts:
        try:
            topic_model.fit_transform(all_texts, np.vstack(all_embeddings))
        except Exception as e:
            st.warning(f"Could not generate global topics: {str(e)}")
    return summaries, topic_model
# Streamlit App Interface
def _render_results(summaries, topic_model, top_n):
    """Write the success banner, one section per country, then the global topic table."""
    st.success("Data successfully processed!")
    for entry in summaries:
        country_name = entry['country']
        poem_count = entry['total_poems']
        st.write(f"### {country_name}")
        st.write(f"Total Poems: {poem_count}")
        sections = (
            (f"Top {top_n} Topics:", 'top_topics'),
            (f"Top {top_n} Emotions:", 'top_emotions'),
        )
        for heading, key in sections:
            st.write(heading)
            st.write(entry[key])
    # Overall topics, shown once after all country sections.
    st.write("### Global Topic Information:")
    st.write(topic_model.get_topic_info())


st.title("Arabic Poem Topic Modeling & Emotion Classification")
st.write("Upload a CSV or Excel file containing Arabic poems with columns `country` and `poem`.")
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
if uploaded_file is not None:
    try:
        top_n = st.number_input(
            "Select the number of top topics/emotions to display:",
            min_value=1,
            max_value=100,
            value=50,
        )
        summaries, topic_model = process_and_summarize(uploaded_file, top_n=top_n)
        if summaries is not None:
            _render_results(summaries, topic_model, top_n)
    except Exception as err:
        st.error(f"Error: {err}")