Cisconardi committed · Commit 93d3c59 · verified · 1 Parent(s): 0c5c9fa

Update app.py

Files changed (1):
  1. app.py +188 -159
app.py CHANGED
@@ -25,7 +25,7 @@ st.set_page_config(
     }
 )
 
-# Session initialization
+# Session initialization (optional, useful if you want to track extra state)
 if 'model_loaded' not in st.session_state:
     st.session_state.model_loaded = False
 if 'analysis_complete' not in st.session_state:
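The hunk above relies on Streamlit's rerun model: the whole script runs top to bottom on every interaction, so session-state keys must be seeded only when absent. A minimal standalone sketch of that pattern (the `counter` key is a toy example, not from app.py):

```python
import streamlit as st

# Seed the key only once; later reruns keep the stored value.
if 'counter' not in st.session_state:
    st.session_state.counter = 0

if st.button("Increment"):
    st.session_state.counter += 1  # persists across the rerun the click triggers

st.write(f"counter = {st.session_state.counter}")
```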
@@ -60,43 +60,69 @@ st.markdown("""
     </style>
 """, unsafe_allow_html=True)
 
-# Utility functions
+
+#
+# 1) Model loading with cache_resource
+#
 @st.cache_resource
 def load_models():
-    """Load the required models, with caching."""
+    """Load the required models, with caching (only once)."""
     with st.spinner("Loading models... This may take a few minutes."):
         try:
+            # Download en_core_web_sm if not already present (for PartOfSpeech)
             download("en_core_web_sm")
-            model_filter = SpanMarkerModel.from_pretrained(
-                "nbroad/span-marker-xdistil-l12-h384-orgs-v3"
-            ).cuda() if cuda.is_available() else SpanMarkerModel.from_pretrained(
-                "nbroad/span-marker-xdistil-l12-h384-orgs-v3")
+
+            # SpanMarker model: entity detection (Brand/Unbranded)
+            if cuda.is_available():
+                model_filter = SpanMarkerModel.from_pretrained(
+                    "nbroad/span-marker-xdistil-l12-h384-orgs-v3"
+                ).cuda()
+            else:
+                model_filter = SpanMarkerModel.from_pretrained(
+                    "nbroad/span-marker-xdistil-l12-h384-orgs-v3"
+                )
+
+            # SentenceTransformer embedding model
             embedding_model = SentenceTransformer("all-mpnet-base-v2")
+
             return model_filter, embedding_model
         except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            raise
 
+
+#
+# 2) CSV loading with cache_data
+#
 @st.cache_data
-def process_keywords(df, _model_filter):
+def load_csv(file, skiprows, nrows):
+    """Load the CSV with caching, so that if the user reruns the app or
+    downloads the results, Streamlit does not re-read the file from scratch
+    (as long as it has not changed)."""
+    df = pd.read_csv(file, skiprows=skiprows, nrows=nrows)
+    return df
+
+
+#
+# 3) Brand/Unbranded labelling with cache_data
+#
+@st.cache_data
+def process_keywords(df, _model_filter):  # leading "_" keeps the model out of the cache key
     """
     Detect any 'Brand' keywords using the SpanMarker model.
     Return a list of 'Brand' or 'Unbranded' labels, one per keyword.
     """
     results = []
     total = len(df)
-
     progress_text = "Processing keywords..."
     progress_bar = st.progress(0, text=progress_text)
 
     for i, keyword in enumerate(df['Keyword']):
         try:
             entities = _model_filter.predict([keyword])
             label = "Brand" if entities and isinstance(entities[0], list) and \
                 any(entity.get("label") == "ORG" for entity in entities[0]) else "Unbranded"
             results.append(label)
         except Exception as e:
-            # If entity detection fails, label as 'Unbranded' by default
             st.error(f"Error processing keyword '{keyword}': {str(e)}")
             results.append("Unbranded")
 
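One detail worth flagging in the hunk above: `@st.cache_data` builds its cache key by hashing every argument, and it raises `UnhashableParamError` for objects it cannot hash, such as a loaded model. A leading underscore in the parameter name tells Streamlit to skip that argument, which is why `_model_filter` keeps its underscore. A self-contained sketch (the `label_one` helper is hypothetical):

```python
import streamlit as st

@st.cache_data
def label_one(keyword: str, _model):
    # keyword is hashed into the cache key; _model is skipped entirely,
    # so repeating the same keyword hits the cache instead of the model.
    return _model.predict([keyword])
```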
@@ -105,8 +131,13 @@ def process_keywords(df, _model_filter):
     progress_bar.empty()
     return results
 
+
+#
+# 4) Topic model creation
+#
 def create_topic_model(embedding_model, model_params):
-    """Create and configure the topic modeling model."""
+    """Create and configure the topic model (not cached, since it
+    can depend on many parameters)."""
     try:
         # Quantization configuration for Hugging Face
         bnb_config = transformers.BitsAndBytesConfig(
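The hunk cuts off right at the `BitsAndBytesConfig` call, so the committed quantization arguments are not visible here. For reference, a typical 4-bit configuration for a Llama 2 labeling model looks roughly like this (an assumption, not the values actually in app.py):

```python
import torch
import transformers

# A common 4-bit setup; app.py's real arguments are not shown in this hunk.
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bfloat16
)
```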
@@ -224,129 +255,87 @@ def create_topic_model(embedding_model, model_params):
         st.error(f"Error creating topic model: {str(e)}")
         raise
 
-def process_data(df, model_filter, embedding_model, model_params, exclude_brand_keywords=False):
+
+#
+# 5) Main analysis (the final analysis results are cached)
+#
+@st.cache_data
+def run_analysis(df, _model_filter, _embedding_model, model_params, exclude_brand_keywords):
     """
-    Process the data and build the topic model.
-
-    If exclude_brand_keywords is True, apply the "Brand"/"Unbranded" labelling and
-    remove the 'Brand' keywords from the dataset before clustering.
-    Otherwise, include all keywords.
+    - Optionally label keywords as 'Brand' or 'Unbranded'
+    - Filter out brand keywords if requested
+    - Build embeddings
+    - Run the topic modeling
+    - Return the model + results_df
     """
-    try:
-        # If the user chooses to exclude brands, label and filter
-        if exclude_brand_keywords:
-            df['Label'] = process_keywords(df, model_filter)
-            filtered_df = df[df['Label'] == 'Unbranded']
-        else:
-            # Skip the classification when it is not needed
-            df['Label'] = "Unbranded"
-            filtered_df = df
-
-        filtered_keywords = filtered_df['Keyword'].tolist()
-
-        if not filtered_keywords:
-            st.warning("No keywords found for analysis (perhaps all were branded).")
-            return None, None
-
-        # Generate embeddings
-        embeddings = embedding_model.encode(filtered_keywords, show_progress_bar=True)
-
-        # Create and fit the topic model
-        topic_model = create_topic_model(embedding_model, model_params)
-        topics, probs = topic_model.fit_transform(filtered_keywords, embeddings)
-
-        # Get the reduced embeddings for the visualization
-        reduced_embeddings = topic_model.umap_model.embedding_
-
-        # Use the labels generated by Llama 2 (TextGeneration) as the final labels
-        llama_topic_labels = {
-            topic: "".join(list(zip(*values))[0])
-            for topic, values in topic_model.topic_aspects_["Llama2"].items()
-        }
-        llama_topic_labels[-1] = "Outlier Topic"
-        topic_model.set_topic_labels(llama_topic_labels)
-
-        # Get the topic information
-        topic_info = topic_model.get_topic_info()
-        topic_labels = dict(zip(topic_info["Topic"], topic_info["CustomName"]))
-
-        # Get the default BERT labels
-        bert_labels = dict(zip(topic_info["Topic"], topic_info["Name"]))
-
-        # Build the results DataFrame
-        results_df = pd.DataFrame({
-            "Keyword": filtered_keywords,
-            "Topic ID": topics,
-            "Confidence": probs
-        })
-
-        # Add the Llama and BERT labels
-        results_df["Llama label"] = [
-            topic_labels[topic] if topic in topic_labels else "Outlier Topic"
-            for topic in topics
-        ]
-        results_df["BERT label"] = [
-            bert_labels[topic] if topic in bert_labels else "Outlier Topic"
-            for topic in topics
-        ]
-
-        # If the CSV has a 'Volume' column, add it
-        if "Volume" in filtered_df.columns:
-            results_df["Volume"] = filtered_df["Volume"].values
-
-        # Show results
-        st.write("### Results Table")
-        st.dataframe(results_df, use_container_width=True, hide_index=True)
-
-        # Show the interactive dashboard
-        st.write("### Interactive Topic Visualization")
-        try:
-            fig = topic_model.visualize_documents(
-                filtered_keywords,
-                reduced_embeddings=reduced_embeddings,
-                hide_annotations=True,
-                hide_document_hover=False,
-                custom_labels=True
-            )
-            st.plotly_chart(fig, theme="streamlit", use_container_width=True)
-
-            # Topic visualization
-            st.write("### Topic Overview")
-            try:
-                topic_fig = topic_model.visualize_topics(custom_labels=True)
-                st.plotly_chart(topic_fig, theme="streamlit", use_container_width=True)
-            except Exception as e:
-                st.error(f"Error creating topic visualization: {str(e)}")
-
-            # Topic barchart visualization
-            st.write("### Topic Distribution")
-            try:
-                # Compute the number of topics to display
-                n_topics = len(topic_model.get_topic_info())
-                n_topics = min(50, max(1, n_topics - 1))  # -1 to exclude the outlier topic if present
-
-                barchart_fig = topic_model.visualize_barchart(
-                    top_n_topics=n_topics,
-                    custom_labels=True
-                )
-                st.plotly_chart(barchart_fig, theme="streamlit", use_container_width=True)
-            except Exception as e:
-                st.error(f"Error creating barchart visualization: {str(e)}")
-
-        except Exception as e:
-            st.error(f"Error creating visualization: {str(e)}")
-
-        return topic_model, results_df
-
-    except Exception as e:
-        st.error(f"Error processing data: {str(e)}")
-        return None, None
+    # If the user chooses to exclude brands, label and filter
+    if exclude_brand_keywords:
+        df['Label'] = process_keywords(df, _model_filter)
+        filtered_df = df[df['Label'] == 'Unbranded']
+    else:
+        df['Label'] = "Unbranded"
+        filtered_df = df
+
+    filtered_keywords = filtered_df['Keyword'].tolist()
+
+    if not filtered_keywords:
+        st.warning("No keywords found for analysis (perhaps all were branded).")
+        return None, None
+
+    # Generate embeddings
+    embeddings = _embedding_model.encode(filtered_keywords, show_progress_bar=True)
+
+    # Create and fit the topic model
+    topic_model = create_topic_model(_embedding_model, model_params)
+    topics, probs = topic_model.fit_transform(filtered_keywords, embeddings)
+
+    # Get the reduced embeddings for the visualization
+    reduced_embeddings = topic_model.umap_model.embedding_
+
+    # Use the labels generated by Llama 2 as the final labels
+    llama_topic_labels = {
+        topic: "".join(list(zip(*values))[0])
+        for topic, values in topic_model.topic_aspects_["Llama2"].items()
+    }
+    llama_topic_labels[-1] = "Outlier Topic"
+    topic_model.set_topic_labels(llama_topic_labels)
+
+    # Get the topic information
+    topic_info = topic_model.get_topic_info()
+    topic_labels = dict(zip(topic_info["Topic"], topic_info["CustomName"]))
+
+    # Get the default BERT labels
+    bert_labels = dict(zip(topic_info["Topic"], topic_info["Name"]))
+
+    # Build the results DataFrame
+    results_df = pd.DataFrame({
+        "Keyword": filtered_keywords,
+        "Topic ID": topics,
+        "Confidence": probs
+    })
+
+    # Add the Llama and BERT labels
+    results_df["Llama label"] = [
+        topic_labels[topic] if topic in topic_labels else "Outlier Topic"
+        for topic in topics
+    ]
+    results_df["BERT label"] = [
+        bert_labels[topic] if topic in bert_labels else "Outlier Topic"
+        for topic in topics
+    ]
+
+    # If the CSV has a 'Volume' column, add it
+    if "Volume" in filtered_df.columns:
+        results_df["Volume"] = filtered_df["Volume"].values
+
+    return topic_model, results_df
 
 
+#
+# 6) Main Streamlit App
+#
 def main():
-    st.title("🔍 NLP Keyword Analysis")
-
-    topic_model = None  # initialize topic_model here
+    st.title("🔍 NLP Keyword Analysis with Cache")
 
     # Sidebar with configuration
     with st.sidebar:
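For readers unfamiliar with the BERTopic calls that `run_analysis` strings together, here is a stripped-down, runnable sketch of the same flow: encode once, reuse the embeddings in `fit_transform`, then override the generated topic names. The sample documents and UMAP settings are illustrative only, and the Llama 2 representation model is left out:

```python
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
from umap import UMAP

docs = ["red running shoes", "trail running shoes", "marathon training plan",
        "espresso machine", "drip coffee maker", "cold brew recipe"] * 20

embedder = SentenceTransformer("all-mpnet-base-v2")
embeddings = embedder.encode(docs, show_progress_bar=False)

# Small-data UMAP settings so this toy corpus clusters at all.
umap_model = UMAP(n_neighbors=10, n_components=5, min_dist=0.0,
                  init="random", random_state=42)
topic_model = BERTopic(embedding_model=embedder, umap_model=umap_model)

topics, probs = topic_model.fit_transform(docs, embeddings)  # embeddings reused

# Same mechanism the app uses for its Llama-generated names.
topic_model.set_topic_labels({-1: "Outlier Topic"})
print(topic_model.get_topic_info()[["Topic", "Count", "Name"]])
```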
@@ -439,9 +428,12 @@ def main():
         - Vectorizer: Controls text preprocessing
         - Topic Model: Controls topic generation
         - Llama 2: Controls topic labeling
+
+        **Caching:**
+        - With the `@st.cache_data` and `@st.cache_resource` decorators, expensive recomputation is avoided when the app reloads.
         """)
 
-    # Collect the parameters into a dictionary
+    # 7) Prepare the parameter dictionary
     model_params = {
         'umap_n_neighbors': umap_n_neighbors,
         'umap_n_components': umap_n_components,
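Passing `model_params` into the cached `run_analysis` is what makes the sidebar sliders cache-aware: `st.cache_data` folds every hashable argument into the cache key, so any changed value forces a recomputation while identical settings hit the cache. A toy sketch of that behavior (the `expensive` function is hypothetical):

```python
import streamlit as st

@st.cache_data
def expensive(params: dict) -> str:
    # Recomputed only when the params dict (part of the cache key) changes.
    return f"computed with {sorted(params.items())}"

st.write(expensive({'umap_n_neighbors': 15}))
```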
@@ -459,10 +451,15 @@ def main():
         'llama_repetition_penalty': llama_repetition_penalty
     }
 
+    # 8) If a file has been uploaded, proceed
     if uploaded_file is not None:
         try:
-            # Load data with the specified number of rows
-            df = pd.read_csv(uploaded_file, skiprows=min_rows - 1, nrows=max_rows - min_rows + 1)
+            # Load data with caching
+            df = load_csv(
+                file=uploaded_file,
+                skiprows=min_rows - 1,
+                nrows=max_rows - min_rows + 1
+            )
 
             if 'Keyword' not in df.columns:
                 st.error("CSV must contain a 'Keyword' column")
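A caveat on the row window passed to `load_csv` above: with an integer `skiprows`, pandas drops the first `min_rows - 1` physical lines, header included, so for `min_rows > 1` a data row silently becomes the header. If the intent is to skip data rows while keeping the header, a range does it (a sketch, assuming the same 1-based `min_rows`/`max_rows` as the app):

```python
import pandas as pd

def load_window(file, min_rows: int, max_rows: int) -> pd.DataFrame:
    # range(1, min_rows) skips physical lines 1..min_rows-1 and keeps
    # line 0, the header; nrows then bounds the data rows that follow.
    return pd.read_csv(
        file,
        skiprows=range(1, min_rows),
        nrows=max_rows - min_rows + 1,
    )
```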
@@ -473,27 +470,21 @@ def main():
             st.write(f"Reading rows {min_rows} to {max_rows}")
             st.dataframe(
                 df.head(),
-                use_container_width=True,
-                column_config={
-                    "Keyword": st.column_config.TextColumn(
-                        "Keyword",
-                        help="Input keywords for analysis"
-                    )
-                }
+                use_container_width=True
             )
             st.write(f"Total rows loaded: {len(df)}")
 
-            # Analysis
+            # Button to start the analysis
             if st.button("Start Analysis", type="primary"):
                 try:
-                    # Load models
+                    # Load the models (cache_resource)
                     with st.status("Loading models...", expanded=True) as status:
                         model_filter, embedding_model = load_models()
                         status.update(label="Models loaded successfully!", state="complete")
 
-                    # Process data
+                    # Run the analysis (cache_data)
                     with st.status("Processing data...", expanded=True) as status:
-                        topic_model, results_df = process_data(
+                        topic_model, results_df = run_analysis(
                             df,
                             model_filter,
                             embedding_model,
@@ -511,15 +502,56 @@ def main():
                     with st.expander("Configuration Summary", expanded=False):
                         st.json(model_params)
 
-                    # Download results
-                    if results_df is not None:
-                        st.download_button(
-                            label="Download Results",
-                            data=results_df.to_csv(index=False),
-                            file_name="keyword_analysis_results.csv",
-                            mime="text/csv",
-                            key="download_results"
-                        )
+                    # 9) Show results
+                    st.write("### Results Table")
+                    st.dataframe(results_df, use_container_width=True, hide_index=True)
+
+                    # Show the interactive dashboard
+                    st.write("### Interactive Topic Visualization")
+                    try:
+                        # Reduced embeddings
+                        fig = topic_model.visualize_documents(
+                            results_df['Keyword'].tolist(),
+                            reduced_embeddings=topic_model.umap_model.embedding_,
+                            hide_annotations=True,
+                            hide_document_hover=False,
+                            custom_labels=True
+                        )
+                        st.plotly_chart(fig, theme="streamlit", use_container_width=True)
+
+                        # Topic visualization
+                        st.write("### Topic Overview")
+                        try:
+                            topic_fig = topic_model.visualize_topics(custom_labels=True)
+                            st.plotly_chart(topic_fig, theme="streamlit", use_container_width=True)
+                        except Exception as e:
+                            st.error(f"Error creating topic visualization: {str(e)}")
+
+                        # Topic barchart visualization
+                        st.write("### Topic Distribution")
+                        try:
+                            n_topics = len(topic_model.get_topic_info())
+                            n_topics = min(50, max(1, n_topics - 1))  # -1 for the outlier topic
+
+                            barchart_fig = topic_model.visualize_barchart(
+                                top_n_topics=n_topics,
+                                custom_labels=True
+                            )
+                            st.plotly_chart(barchart_fig, theme="streamlit", use_container_width=True)
+                        except Exception as e:
+                            st.error(f"Error creating barchart visualization: {str(e)}")
+
+                    except Exception as e:
+                        st.error(f"Error creating visualization: {str(e)}")
+
+                    # Download results as CSV
+                    st.download_button(
+                        label="Download Results",
+                        data=results_df.to_csv(index=False),
+                        file_name="keyword_analysis_results.csv",
+                        mime="text/csv",
+                        key="download_results"
+                    )
 
                 except Exception as e:
                     st.error(f"An error occurred during analysis: {str(e)}")
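The new download button re-serializes `results_df` on every rerun. The pattern from the Streamlit docs caches that conversion as well, which fits this commit's caching theme (a sketch; `to_csv_bytes` is not in app.py):

```python
import pandas as pd
import streamlit as st

@st.cache_data
def to_csv_bytes(df: pd.DataFrame) -> bytes:
    # Cached: re-serialized only when df changes.
    return df.to_csv(index=False).encode("utf-8")

df = pd.DataFrame({"Keyword": ["a", "b"], "Topic ID": [0, 1]})
st.download_button("Download Results", data=to_csv_bytes(df),
                   file_name="keyword_analysis_results.csv", mime="text/csv")
```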
@@ -527,21 +559,18 @@
             st.error(f"Error reading file: {str(e)}")
 
     else:
+        # Initial message
         st.info("""
-        👋 Welcome to the NLP Keyword Analysis tool!
+        👋 Welcome to the NLP Keyword Analysis tool (with caching)!
 
-        Please upload a CSV file containing your keywords to get started.
-        The file should have a column named 'Keyword'.
+        1. Upload a CSV file with a column named **'Keyword'**.
+        2. Adjust parameters in the sidebar if needed.
+        3. Click **"Start Analysis"**.
+        4. Download the results.
 
-        You can configure:
-        - Number of rows to read from the CSV
-        - (Optionally) Exclude brand-labeled keywords
-        - UMAP parameters for dimensionality reduction
-        - HDBSCAN parameters for clustering
-        - Vectorizer parameters for text preprocessing
-        - Topic model parameters
-        - Llama 2 parameters for topic labeling
+        *Note: Caching helps avoid re-running expensive computations when the app reloads.*
         """)
 
+
 if __name__ == "__main__":
     main()