# MIND-states-LDA / streamlit_app_LDA.py
# Author: kleinay
# Last commit: "Update streamlit_app_LDA.py" (316bd95, verified)
import streamlit as st
import json
import random
import numpy as np
from gensim import corpora, models
import pyLDAvis.gensim_models as gensimvis
import pyLDAvis
import pandas as pd
import streamlit.components.v1 as components
from MIND_utils import df_to_self_states_json, element_short_desc_map
# ---------------------------
# Streamlit App Layout
# ---------------------------
st.set_page_config(layout="wide")
st.title("Prototypical Self-States via Topic Modeling")
uploaded_file = st.file_uploader("Upload your own data file (CSV)", type="csv")

st.header("Model Parameters")
lda_document_is = st.radio("A 'Document' in the topic model will correspond to a:", ("self-state", "segment"))
num_topics = st.slider("Number of Topics", min_value=2, max_value=20, value=5)
num_passes = st.slider("Number of Passes", min_value=5, max_value=50, value=10)
# The seed is passed to `random.seed`, `np.random.seed` and gensim's `random_state`;
# NumPy only accepts seeds in [0, 2**32 - 1], so bound the input instead of letting a
# negative or oversized value crash the app.
seed_value = st.number_input("Random Seed", value=42, min_value=0, max_value=2**32 - 1)

# Dirichlet concentration parameters must be strictly positive; 0 is degenerate for LDA,
# so custom Alpha/Beta values are bounded below by this small positive number.
MIN_DIRICHLET_PRIOR = 0.001

st.subheader("Beta -- dispersion of words in a topic - lower means less words in each topic")
is_set_beta = st.checkbox("Set custom Beta (default: 1 / num_topics)? ")
if is_set_beta:
    beta = st.number_input("Beta", min_value=MIN_DIRICHLET_PRIOR, max_value=1.0, value=1/num_topics, step=0.05, format="%.3f")
else:
    beta = 1 / num_topics

st.subheader("Alpha -- dispersion of topics in a document - lower means less topics in each document")
is_set_alpha = st.checkbox("Set custom Alpha (default: dynamic per document)? ")
if is_set_alpha:
    alpha = st.number_input("Alpha", min_value=MIN_DIRICHLET_PRIOR, max_value=1.0, value=1/num_topics, step=0.05, format="%.3f")
else:
    # NOTE(review): gensim documents alpha="auto" as learning an asymmetric prior from
    # the corpus; confirm that matches the "dynamic per document" wording in the label.
    alpha = "auto"

st.header("Display")
num_top_elements_to_show = st.slider("# top element to show in a topic", min_value=2, max_value=15, value=5)
show_long_elements = st.checkbox("Show full element name")
# ---------------------------
# Load Data
# ---------------------------
@st.cache_data
def load_data(csv):
    """Read a CSV into a DataFrame.

    `csv` is anything `pd.read_csv` accepts; here it is either the uploaded file
    object or the bundled default path. `st.cache_data` memoizes the result across
    Streamlit reruns.
    """
    frame = pd.read_csv(csv)
    return frame


# Prefer the user's upload; otherwise fall back to the bundled annotations file.
csv_source = uploaded_file or "clean_annotations_safe.csv"
df = load_data(csv_source)
# ---------------------------
# Preprocess Data: Build Documents
# ---------------------------
# Seed Python's and NumPy's global random generators with the user-chosen seed so that
# repeated runs with the same settings are reproducible.
for seed_global_rng in (random.seed, np.random.seed):
    seed_global_rng(seed_value)
# Functions to extract "words" (elements -- <dim>:<category>) from a segment / self-state
def extract_elements_from_selfstate(selfstate):
    """Return the "<dim>:<category>" elements annotated in one self-state.

    Every key of `selfstate` except "is_adaptive" is treated as a dimension. A
    dimension contributes an element only when its object has a "Category" entry
    whose value is not missing (per `pd.isna`). Order follows the dict's key order.
    """
    return [
        f"{dimension}:{annotation['Category']}"
        for dimension, annotation in selfstate.items()
        if dimension != "is_adaptive"
        and "Category" in annotation
        and not pd.isna(annotation["Category"])
    ]


def extract_elements_from_segment(segment):
    """Concatenate, in order, the elements of every self-state in `segment["self-states"]`."""
    return [
        element
        for selfstate in segment["self-states"]
        for element in extract_elements_from_selfstate(selfstate)
    ]
# Build the LDA "documents": one bag of elements per segment or per self-state,
# depending on `lda_document_is`. Each document also gets a human-readable ID that is
# only used to label documents in the output below. Empty documents are skipped.
lda_documents = []
lda_document_ids = []
for (doc_id, annotator), df_ in df.groupby(["document", "annotator"]):
    doc_json = df_to_self_states_json(df_, doc_id, annotator)
    # The data is grouped per (document, annotator), so the same document can appear
    # once per annotator; include the annotator so their LDA documents stay distinguishable.
    id_prefix = f"{doc_id}_{annotator}"
    for segment in doc_json["segments"]:
        segment_id = f"{id_prefix}_seg{segment['segment']}"
        if lda_document_is == "segment":
            # * Segment-level LDA-documents: all self-states of the segment pooled together.
            candidates = [(segment_id, extract_elements_from_segment(segment))]
        else:
            # * SelfState-level LDA-documents (the only other radio option): one per self-state.
            candidates = [
                (f"{segment_id}_state{i}", extract_elements_from_selfstate(selfstate))
                for i, selfstate in enumerate(segment["self-states"], start=1)
            ]
        for lda_doc_id, lda_doc in candidates:
            if lda_doc:  # only add if non-empty
                lda_documents.append(lda_doc)
                lda_document_ids.append(lda_doc_id)

# gensim cannot train on an empty vocabulary; explain the problem instead of crashing.
if not lda_documents:
    st.error("No annotated elements were found in the data, so there is nothing to model.")
    st.stop()

# Create a dictionary and corpus for LDA
dictionary = corpora.Dictionary(lda_documents)
corpus = [dictionary.doc2bow(doc) for doc in lda_documents]
# ---------------------------
# Run LDA Model
# ---------------------------
# gensim calls the topic-word prior `eta` (the "Beta" chosen above); `alpha` is either
# the user's float or the string "auto". The seed makes training reproducible.
lda_settings = dict(
    num_topics=num_topics,
    id2word=dictionary,
    passes=num_passes,
    eta=beta,
    alpha=alpha,
    random_state=seed_value,
)
lda_model = models.LdaModel(corpus, **lda_settings)
# ---------------------------
# Display Pretty Printed Topics
# ---------------------------
st.header("Pretty Printed Topics")

NUM_TOP_DOCS = 3  # how many highest-probability documents to list per topic

# Map each topic to a list of (document index, probability of that topic in the document).
topic_docs = {topic_id: [] for topic_id in range(lda_model.num_topics)}
for doc_index, doc_bow in enumerate(corpus):
    # minimum_probability=0 asks gensim to report (almost) every topic for every document
    for topic_id, prob in lda_model.get_document_topics(doc_bow, minimum_probability=0):
        topic_docs[topic_id].append((doc_index, prob))

# For each topic, keep the NUM_TOP_DOCS most probable documents (stable sort: ties keep corpus order).
top_docs = {
    topic_id: sorted(doc_list, key=lambda pair: pair[1], reverse=True)[:NUM_TOP_DOCS]
    for topic_id, doc_list in topic_docs.items()
}

# Build the report as a list of lines and join once, instead of repeated string +=.
report_parts = ["Identified Prototypical Self-States (Topics):\n\n"]
for topic_id in range(lda_model.num_topics):
    report_parts.append(f"Topic {topic_id}:\n")
    # show_topic yields (token, weight) pairs directly. This avoids re-parsing the
    # formatted print_topics string, which breaks if a token contains "*" or '"'.
    # Iterating range(num_topics) also avoids print_topics' default cap of 20 topics.
    for token, weight in lda_model.show_topic(topic_id, topn=num_top_elements_to_show):
        report_parts.append(f" {weight:.3f} -> {token}\n")
    report_parts.append(f" Top {NUM_TOP_DOCS} Documents (Segment Indices) for this topic:\n")
    for doc_index, prob in top_docs[topic_id]:
        # lda_document_ids is parallel to corpus, so the corpus index selects the readable ID
        report_parts.append(f" Doc {doc_index} ({lda_document_ids[doc_index]}) with probability {prob:.3f}\n")
    report_parts.append("-" * 60 + "\n")
output_str = "".join(report_parts)

st.text(output_str)
# ---------------------------
# Prepare and Display pyLDAvis Visualization
# ---------------------------
st.header("Interactive Topic Visualization")
if show_long_elements:
    vis_dictionary = dictionary
else:
    # Relabel every term with its short description. The rebuilt Dictionary assigns ids
    # sequentially in the order labels are added, so it must receive exactly one unique
    # label per original id, in id order, to stay aligned with the trained model.
    short_labels = []
    seen_labels = set()
    for token_id in range(len(dictionary)):
        token = dictionary[token_id]
        try:
            label = element_short_desc_map[token]
        except KeyError:
            # Elements missing from the map (e.g. categories that only appear in an
            # uploaded CSV) keep their raw "<dim>:<category>" name instead of crashing.
            label = token
        if label in seen_labels:
            # If two elements shared a short description they would collapse into one
            # id, shifting every later label onto the wrong term; disambiguate instead.
            label = f"{label} [{token}]"
        seen_labels.add(label)
        short_labels.append(label)
    vis_dictionary = corpora.Dictionary([[label] for label in short_labels])
vis_data = gensimvis.prepare(lda_model, corpus, vis_dictionary)
html_string = pyLDAvis.prepared_data_to_html(vis_data)
components.html(html_string, width=2300, height=800, scrolling=True)