import streamlit as st
import random
import numpy as np
from gensim import corpora, models
import pyLDAvis.gensim_models as gensimvis
import pyLDAvis
import pandas as pd
import streamlit.components.v1 as components
from MIND_utils import df_to_self_states_json, element_short_desc_map
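# MIND_utils is a project-local helper module: df_to_self_states_json rebuilds the
# nested self-state JSON from the annotation rows, and element_short_desc_map maps
# long element names to short display labels.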
# ---------------------------
# Streamlit App Layout
# ---------------------------
st.set_page_config(layout="wide")
st.title("Prototypical Self-States via Topic Modeling")

uploaded_file = st.file_uploader("Upload your own data file (CSV)", type="csv")

st.header("Model Parameters")
lda_document_is = st.radio("A 'Document' in the topic model will correspond to a:", ("self-state", "segment"))
num_topics = st.slider("Number of Topics", min_value=2, max_value=20, value=5)
num_passes = st.slider("Number of Passes", min_value=5, max_value=50, value=10)
seed_value = st.number_input("Random Seed", value=42)
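# seed_value is reused below to seed Python's and NumPy's RNGs and as the
# LdaModel's random_state.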
st.subheader("Beta -- dispersion of words in a topic - lower means less words in each topic") | |
is_set_beta = st.checkbox("Set custom Beta (default: 1 / num_topics)? ") | |
if is_set_beta: | |
beta = st.number_input("Beta", min_value=0.0, max_value=1.0, value=1/num_topics, step=0.05, format="%.3f") | |
else: | |
beta = 1 / num_topics | |
st.subheader("Alpha -- dispersion of topics in a document - lower means less topics in each document") | |
is_set_alpha = st.checkbox("Set custom Alpha (default: dynamic per document)? ") | |
if is_set_alpha: | |
alpha = st.number_input("Alpha", min_value=0.0, max_value=1.0, value=1/num_topics, step=0.05, format="%.3f") | |
else: | |
alpha = "auto" | |
st.header("Display") | |
num_top_elements_to_show = st.slider("# top element to show in a topic", min_value=2, max_value=15, value=5) | |
show_long_elements = st.checkbox("Show full element name") | |
# ---------------------------
# Load Data
# ---------------------------
def load_data(csv):
    return pd.read_csv(csv)

df = load_data(uploaded_file or "clean_annotations_safe.csv")
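# If no file is uploaded, load_data falls back to the default
# clean_annotations_safe.csv in the working directory.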
# ---------------------------
# Preprocess Data: Build Documents
# ---------------------------
# Set random seeds for reproducibility
random.seed(seed_value)
np.random.seed(seed_value)
# Functions to extract "words" (elements -- <dim>:<category>) from a self-state / segment
def extract_elements_from_selfstate(selfstate):
    words = []
    for dim, dim_obj in selfstate.items():
        if dim == "is_adaptive":
            continue
        if "Category" in dim_obj and not pd.isna(dim_obj["Category"]):
            word = f"{dim}:{dim_obj['Category']}"
            words.append(word)
    return words

def extract_elements_from_segment(segment):
    words = []
    for selfstate in segment["self-states"]:
        words += extract_elements_from_selfstate(selfstate)
    return words
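# Each extracted "word" is a <dimension>:<category> pair; the exact dimension
# and category labels depend on the annotation scheme in the CSV.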
# Build the list of LDA "documents" (one per segment or per self-state,
# depending on the granularity selected above)
lda_documents = []
lda_document_ids = []
for (doc_id, annotator), df_ in df.groupby(["document", "annotator"]):
    doc_json = df_to_self_states_json(df_, doc_id, annotator)
    ### * for Segment-level LDA-documents:
    if lda_document_is == "segment":
        for segment in doc_json["segments"]:
            lda_doc = extract_elements_from_segment(segment)
            if lda_doc:  # only add if non-empty
                lda_documents.append(lda_doc)
                lda_document_ids.append(f"{doc_id}_seg{segment['segment']}")
    ### * for SelfState-level LDA-documents:
    elif lda_document_is == "self-state":
        for segment in doc_json["segments"]:
            for i, selfstate in enumerate(segment["self-states"]):
                lda_doc = extract_elements_from_selfstate(selfstate)
                if lda_doc:  # only add if non-empty
                    lda_documents.append(lda_doc)
                    lda_document_ids.append(f"{doc_id}_seg{segment['segment']}_state{i+1}")
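# lda_document_ids runs parallel to lda_documents and is used later to label
# each topic's top documents in the printed output.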
# Create a dictionary and corpus for LDA
dictionary = corpora.Dictionary(lda_documents)
corpus = [dictionary.doc2bow(doc) for doc in lda_documents]
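# doc2bow turns each token list into a sparse bag-of-words vector of
# (token_id, count) pairs, the input format gensim's LdaModel expects.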
# ---------------------------
# Run LDA Model
# ---------------------------
lda_model = models.LdaModel(corpus,
                            num_topics=num_topics,
                            id2word=dictionary,
                            passes=num_passes,
                            eta=beta,
                            alpha=alpha,
                            random_state=seed_value)
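# Note: "eta" is gensim's name for the topic-word prior, i.e. the Beta
# parameter set above; alpha is the document-topic prior.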
# ---------------------------
# Display Pretty Printed Topics
# ---------------------------
st.header("Pretty Printed Topics")

# Build a mapping from each topic to a list of (document index, topic probability)
topic_docs = {topic_id: [] for topic_id in range(lda_model.num_topics)}

# Iterate over the corpus to get the topic distribution of each document
for i, doc_bow in enumerate(corpus):
    # Get the full topic distribution (minimum_probability=0 so every topic is included)
    doc_topics = lda_model.get_document_topics(doc_bow, minimum_probability=0)
    for topic_id, prob in doc_topics:
        topic_docs[topic_id].append((i, prob))

# For each topic, sort the documents by probability in descending order and keep the top 3
top_docs = {}
for topic_id, doc_list in topic_docs.items():
    sorted_docs = sorted(doc_list, key=lambda x: x[1], reverse=True)
    top_docs[topic_id] = sorted_docs[:3]
# Aggregate output into a single string
output_str = "Identified Prototypical Self-States (Topics):\n\n"
for topic_id, topic_str in lda_model.print_topics(num_words=num_top_elements_to_show):
    output_str += f"Topic {topic_id}:\n"
    # print_topics returns strings like '0.123*"dim:Category" + 0.098*"..."'
    terms = topic_str.split(" + ")
    for term in terms:
        weight, token = term.split("*")
        token = token.strip().replace('"', '')
        output_str += f"  {float(weight):.3f} -> {token}\n"
    output_str += "  Top 3 Documents (Segment Indices) for this topic:\n"
    for doc_index, prob in top_docs[topic_id]:
        # lda_document_ids (built above) maps document indices to identifiers
        output_str += f"    Doc {doc_index} ({lda_document_ids[doc_index]}) with probability {prob:.3f}\n"
    output_str += "-" * 60 + "\n"

# Display the aggregated string in Streamlit
st.text(output_str)
# ---------------------------
# Prepare and Display pyLDAvis Visualization
# ---------------------------
st.header("Interactive Topic Visualization")
if not show_long_elements:
    vis_dict = {i: element_short_desc_map[v] for i, v in dictionary.items()}
    vis_dictionary = corpora.dictionary.Dictionary([[new_token] for new_token in vis_dict.values()])
    vis_data = gensimvis.prepare(lda_model, corpus, vis_dictionary)
else:
    vis_data = gensimvis.prepare(lda_model, corpus, dictionary)

html_string = pyLDAvis.prepared_data_to_html(vis_data)
components.html(html_string, width=2300, height=800, scrolling=True)