Spaces:
Runtime error
Runtime error
Refactor
Browse files- app.py +22 -20
- prompts.py +5 -3
app.py
CHANGED
|
@@ -7,7 +7,6 @@ from bertopic import BERTopic
|
|
| 7 |
import gradio as gr
|
| 8 |
from bertopic.representation import (
|
| 9 |
KeyBERTInspired,
|
| 10 |
-
MaximalMarginalRelevance,
|
| 11 |
TextGeneration,
|
| 12 |
)
|
| 13 |
from umap import UMAP
|
|
@@ -19,8 +18,7 @@ from transformers import (
|
|
| 19 |
AutoModelForCausalLM,
|
| 20 |
pipeline,
|
| 21 |
)
|
| 22 |
-
from prompts import
|
| 23 |
-
from umap import UMAP
|
| 24 |
from hdbscan import HDBSCAN
|
| 25 |
from sklearn.feature_extraction.text import CountVectorizer
|
| 26 |
|
|
@@ -36,7 +34,6 @@ logging.basicConfig(
|
|
| 36 |
session = requests.Session()
|
| 37 |
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 38 |
keybert = KeyBERTInspired()
|
| 39 |
-
mmr = MaximalMarginalRelevance(diversity=0.3)
|
| 40 |
vectorizer_model = CountVectorizer(stop_words="english")
|
| 41 |
|
| 42 |
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
|
@@ -52,7 +49,6 @@ bnb_config = BitsAndBytesConfig(
|
|
| 52 |
|
| 53 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 54 |
|
| 55 |
-
# Llama 2 Model
|
| 56 |
model = AutoModelForCausalLM.from_pretrained(
|
| 57 |
model_id,
|
| 58 |
trust_remote_code=True,
|
|
@@ -68,13 +64,11 @@ generator = pipeline(
|
|
| 68 |
max_new_tokens=500,
|
| 69 |
repetition_penalty=1.1,
|
| 70 |
)
|
| 71 |
-
prompt = system_prompt + example_prompt + main_prompt
|
| 72 |
|
| 73 |
-
llama2 = TextGeneration(generator, prompt=
|
| 74 |
representation_model = {
|
| 75 |
"KeyBERT": keybert,
|
| 76 |
"Llama2": llama2,
|
| 77 |
-
# "MMR": mmr,
|
| 78 |
}
|
| 79 |
|
| 80 |
umap_model = UMAP(
|
|
@@ -132,9 +126,9 @@ def fit_model(base_model, docs, embeddings):
|
|
| 132 |
verbose=True,
|
| 133 |
min_topic_size=15,
|
| 134 |
)
|
| 135 |
-
logging.
|
| 136 |
new_model.fit(docs, embeddings)
|
| 137 |
-
logging.
|
| 138 |
|
| 139 |
if base_model is None:
|
| 140 |
return new_model, new_model
|
|
@@ -157,35 +151,43 @@ def generate_topics(dataset, config, split, column, nested_column):
|
|
| 157 |
offset = 0
|
| 158 |
base_model = None
|
| 159 |
all_docs = []
|
| 160 |
-
|
| 161 |
-
|
|
|
|
| 162 |
docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
|
|
|
|
|
|
|
|
|
|
| 163 |
logging.info(
|
| 164 |
-
f"
|
| 165 |
)
|
|
|
|
| 166 |
embeddings = calculate_embeddings(docs)
|
| 167 |
-
offset = offset + chunk_size
|
| 168 |
-
if not docs or offset >= limit:
|
| 169 |
-
break
|
| 170 |
base_model, _ = fit_model(base_model, docs, embeddings)
|
| 171 |
llama2_labels = [
|
| 172 |
label[0][0].split("\n")[0]
|
| 173 |
for label in base_model.get_topics(full=True)["Llama2"].values()
|
| 174 |
]
|
| 175 |
-
logging.info(f"Topics: {llama2_labels}")
|
| 176 |
base_model.set_topic_labels(llama2_labels)
|
| 177 |
|
| 178 |
reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
|
|
|
|
| 179 |
|
| 180 |
all_docs.extend(docs)
|
| 181 |
-
|
| 182 |
topics_info = base_model.get_topic_info()
|
| 183 |
topic_plot = base_model.visualize_documents(
|
| 184 |
-
all_docs,
|
|
|
|
|
|
|
| 185 |
)
|
| 186 |
-
|
|
|
|
|
|
|
| 187 |
yield topics_info, topic_plot
|
| 188 |
|
|
|
|
|
|
|
| 189 |
logging.info("Finished processing all data")
|
| 190 |
return base_model.get_topic_info(), base_model.visualize_topics()
|
| 191 |
|
|
|
|
| 7 |
import gradio as gr
|
| 8 |
from bertopic.representation import (
|
| 9 |
KeyBERTInspired,
|
|
|
|
| 10 |
TextGeneration,
|
| 11 |
)
|
| 12 |
from umap import UMAP
|
|
|
|
| 18 |
AutoModelForCausalLM,
|
| 19 |
pipeline,
|
| 20 |
)
|
| 21 |
+
from prompts import REPRESENTATION_PROMPT
|
|
|
|
| 22 |
from hdbscan import HDBSCAN
|
| 23 |
from sklearn.feature_extraction.text import CountVectorizer
|
| 24 |
|
|
|
|
| 34 |
session = requests.Session()
|
| 35 |
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 36 |
keybert = KeyBERTInspired()
|
|
|
|
| 37 |
vectorizer_model = CountVectorizer(stop_words="english")
|
| 38 |
|
| 39 |
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
|
|
|
| 49 |
|
| 50 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 51 |
|
|
|
|
| 52 |
model = AutoModelForCausalLM.from_pretrained(
|
| 53 |
model_id,
|
| 54 |
trust_remote_code=True,
|
|
|
|
| 64 |
max_new_tokens=500,
|
| 65 |
repetition_penalty=1.1,
|
| 66 |
)
|
|
|
|
| 67 |
|
| 68 |
+
llama2 = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)
|
| 69 |
representation_model = {
|
| 70 |
"KeyBERT": keybert,
|
| 71 |
"Llama2": llama2,
|
|
|
|
| 72 |
}
|
| 73 |
|
| 74 |
umap_model = UMAP(
|
|
|
|
| 126 |
verbose=True,
|
| 127 |
min_topic_size=15,
|
| 128 |
)
|
| 129 |
+
logging.debug("Fitting new model")
|
| 130 |
new_model.fit(docs, embeddings)
|
| 131 |
+
logging.debug("End fitting new model")
|
| 132 |
|
| 133 |
if base_model is None:
|
| 134 |
return new_model, new_model
|
|
|
|
| 151 |
offset = 0
|
| 152 |
base_model = None
|
| 153 |
all_docs = []
|
| 154 |
+
reduced_embeddings_list = []
|
| 155 |
+
|
| 156 |
+
while offset < limit:
|
| 157 |
docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
|
| 158 |
+
if not docs:
|
| 159 |
+
break
|
| 160 |
+
|
| 161 |
logging.info(
|
| 162 |
+
f"----> Processing chunk: {offset=} {chunk_size=} with {len(docs)} docs"
|
| 163 |
)
|
| 164 |
+
|
| 165 |
embeddings = calculate_embeddings(docs)
|
|
|
|
|
|
|
|
|
|
| 166 |
base_model, _ = fit_model(base_model, docs, embeddings)
|
| 167 |
llama2_labels = [
|
| 168 |
label[0][0].split("\n")[0]
|
| 169 |
for label in base_model.get_topics(full=True)["Llama2"].values()
|
| 170 |
]
|
|
|
|
| 171 |
base_model.set_topic_labels(llama2_labels)
|
| 172 |
|
| 173 |
reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
|
| 174 |
+
reduced_embeddings_list.append(reduced_embeddings)
|
| 175 |
|
| 176 |
all_docs.extend(docs)
|
| 177 |
+
|
| 178 |
topics_info = base_model.get_topic_info()
|
| 179 |
topic_plot = base_model.visualize_documents(
|
| 180 |
+
all_docs,
|
| 181 |
+
reduced_embeddings=np.vstack(reduced_embeddings_list),
|
| 182 |
+
custom_labels=True,
|
| 183 |
)
|
| 184 |
+
|
| 185 |
+
logging.info(f"Topics: {llama2_labels}")
|
| 186 |
+
|
| 187 |
yield topics_info, topic_plot
|
| 188 |
|
| 189 |
+
offset += chunk_size
|
| 190 |
+
|
| 191 |
logging.info("Finished processing all data")
|
| 192 |
return base_model.get_topic_info(), base_model.visualize_topics()
|
| 193 |
|
prompts.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
-
|
| 2 |
<s>[INST] <<SYS>>
|
| 3 |
You are a helpful, respectful and honest assistant for labeling topics.
|
| 4 |
<</SYS>>
|
| 5 |
"""
|
| 6 |
|
| 7 |
-
|
| 8 |
I have a topic that contains the following documents:
|
| 9 |
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
|
| 10 |
- Meat, but especially beef, is the word food in terms of emissions.
|
|
@@ -17,7 +17,7 @@ Based on the information about the topic above, please create a short label of t
|
|
| 17 |
[/INST] Environmental impacts of eating meat
|
| 18 |
"""
|
| 19 |
|
| 20 |
-
|
| 21 |
[INST]
|
| 22 |
I have a topic that contains the following documents:
|
| 23 |
[DOCUMENTS]
|
|
@@ -27,3 +27,5 @@ The topic is described by the following keywords: '[KEYWORDS]'.
|
|
| 27 |
Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
|
| 28 |
[/INST]
|
| 29 |
"""
|
|
|
|
|
|
|
|
|
| 1 |
+
SYSTEM_PROMPT = """
|
| 2 |
<s>[INST] <<SYS>>
|
| 3 |
You are a helpful, respectful and honest assistant for labeling topics.
|
| 4 |
<</SYS>>
|
| 5 |
"""
|
| 6 |
|
| 7 |
+
EXAMPLE_PROMPT = """
|
| 8 |
I have a topic that contains the following documents:
|
| 9 |
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
|
| 10 |
- Meat, but especially beef, is the word food in terms of emissions.
|
|
|
|
| 17 |
[/INST] Environmental impacts of eating meat
|
| 18 |
"""
|
| 19 |
|
| 20 |
+
MAIN_PROMPT = """
|
| 21 |
[INST]
|
| 22 |
I have a topic that contains the following documents:
|
| 23 |
[DOCUMENTS]
|
|
|
|
| 27 |
Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
|
| 28 |
[/INST]
|
| 29 |
"""
|
| 30 |
+
|
| 31 |
+
REPRESENTATION_PROMPT = SYSTEM_PROMPT + EXAMPLE_PROMPT + MAIN_PROMPT
|