GenAIDevTOProd committed
Commit 581cab8 · verified · 1 Parent(s): 09f28a9

Upload app.py

Files changed (1)
  1. app.py +287 -0
app.py ADDED
@@ -0,0 +1,287 @@
# -*- coding: utf-8 -*-
"""app.py

Automatically generated by Colab.

Original file is located at
https://colab.research.google.com/drive/1nLqIbyBDiBI96gDZ0TziLNX8I4uWnl9G
"""

# Requires: pip install datasets

"""Picking subreddits: we use split=<subreddit> because the Hugging Face dataset is split by subreddit rather than train/test/validation.

streaming=True, because we don't want to load all the data into local memory.

Load each subreddit split and combine the iterables together.
"""

from datasets import load_dataset, concatenate_datasets

target_subreddits = ["askscience", "gaming", "technology", "todayilearned", "programming"]

# Load and stream each subreddit split individually
datasets = [
    load_dataset("HuggingFaceGECLM/REDDIT_comments", split=sub, streaming=True)
    for sub in target_subreddits
]

# Combine into one iterable dataset
from itertools import chain
combined_dataset = chain(*datasets)
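
"""Editor's note (not part of the original pipeline): `chain` exhausts the subreddit
streams one after another, so a bounded sample is dominated by the first subreddits.
`datasets.interleave_datasets` is one way to mix the streams instead; it is shown here
only as an optional, commented-out alternative.
"""

# Optional alternative (same streaming splits as above):
# from datasets import interleave_datasets
# combined_dataset = interleave_datasets(datasets)  # alternates examples across the subreddit streams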

"""# Chunking Logic
- Group Reddit comments into small textual chunks to create a unit of meaning for embedding.

- Short Reddit comments are noisy and lack semantic depth. Chunking lets us:

  - Aggregate context across comments

  - Improve embedding quality for semantic search

  - Normalize input length for vector similarity

- We'll group n comments (3-5) per chunk, or cap chunk size by token count (~100 words); a token-count variant is sketched after the chunking code below.

**Use PySpark to handle the large concatenation of chunked data**
"""

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf, monotonically_increasing_id
from pyspark.sql.types import StringType
import re
from itertools import islice

spark = SparkSession.builder.getOrCreate()

# Materialize a bounded sample once so it can be reused for both text and labels
samples = list(islice(combined_dataset, 100000))
df = spark.createDataFrame([{"body": ex["body"]} for ex in samples])

# Clean text UDF
def clean_body(text):
    text = text.lower()
    text = re.sub(r"http\S+|www\S+|https\S+", "", text)
    text = re.sub(r"[^a-zA-Z\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()

clean_udf = udf(clean_body, StringType())
df_clean = df.withColumn("clean", clean_udf(col("body")))

# Add row numbers to chunk
df_indexed = df_clean.withColumn("row_num", monotonically_increasing_id())
chunk_size = 5
df_indexed = df_indexed.withColumn("chunk_id", (col("row_num") / chunk_size).cast("int"))

# Group and concatenate, ordered by chunk_id so chunk order matches the sample order
from pyspark.sql.functions import collect_list, concat_ws
df_chunked = df_indexed.groupBy("chunk_id").agg(concat_ws(" ", collect_list("clean")).alias("chunk_text")).orderBy("chunk_id")

chunked_comments = df_chunked.select("chunk_text").rdd.map(lambda x: x[0]).collect()
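
"""Editor's sketch (not in the original): the alternative mentioned above, chunking by
token count instead of a fixed number of comments. The helper below is hypothetical and
operates on the cleaned comment texts; it is shown for illustration only.
"""

def chunk_by_word_count(texts, max_words=100):
    # Greedily pack consecutive cleaned comments into chunks of roughly `max_words` words
    chunks, current, current_len = [], [], 0
    for text in texts:
        words = text.split()
        if current and current_len + len(words) > max_words:
            chunks.append(" ".join(current))
            current, current_len = [], 0
        current.extend(words)
        current_len += len(words)
    if current:
        chunks.append(" ".join(current))
    return chunks

# Usage (would replace the fixed chunk_size grouping above):
# chunked_comments = chunk_by_word_count(df_clean.select("clean").rdd.map(lambda r: r[0]).collect())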

# Approximate label per chunk: the subreddit of the first comment in each group of `chunk_size`
subreddit_labels = [
    samples[min(i * chunk_size, len(samples) - 1)]["subreddit_name_prefixed"]
    for i in range(len(chunked_comments))
]

"""Cleaner text = better embeddings. Noise like markdown or links pollutes meaning.

We'll use regex and basic string methods.

Normalize the text: remove URLs, HTML tags, Reddit-specific formatting, etc.
"""

# Requires: pip install gensim tqdm

from gensim.models import Word2Vec
from tqdm import tqdm
import re

def clean_text(text):
    # Lowercase, remove URLs, special chars
    text = text.lower()
    text = re.sub(r"http\S+|www\S+|https\S+", "", text, flags=re.MULTILINE)
    text = re.sub(r"[^a-zA-Z\s]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text

tokenized_chunks = []
for chunk in tqdm(chunked_comments):
    cleaned = clean_text(chunk)
    tokens = cleaned.split()  # Simple whitespace tokenizer
    tokenized_chunks.append(tokens)

"""Chunking + tokenizing: removing URLs, Reddit slang, and other noisy text.

Word2Vec hyperparameters:

- vector_size=100  # Size of word embeddings (dimensionality)

- window=5  # Context window size (how many words to look left/right)

- min_count=2  # Ignores words with frequency < 2 (reduces noise)

- workers=4  # Parallel training threads (CPU cores)

- sg=1  # 1 = Skip-Gram (better for rare words); 0 = CBOW
"""

model = Word2Vec(sentences=tokenized_chunks, vector_size=100, window=5, min_count=2, workers=4, sg=1)
model.save("reddit_word2vec.model")
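
"""Editor's note (not in the original): a quick, optional sanity check on the trained
vectors; "game" is just an example token and may not be in the vocabulary.
"""

# Optional sanity check on the trained embeddings:
# if "game" in model.wv:
#     print(model.wv.most_similar("game", topn=5))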

"""Training a custom Word2Vec model for embeddings.

Word2Vec learns dense vector representations (embeddings) for words by capturing their semantic context in a corpus. It enables semantic similarity, clustering, and search.

Skip-gram learns to predict surrounding words for a given center word. It performs better on small to medium-sized datasets and captures rare word semantics effectively.

- Word2Vec only generates vectors for individual words, not entire sentences or documents.

- Each word gets mapped to a dense vector (e.g., 100-dim) that captures its semantic relationships with other words.

# Why Averaging?
- It's a simple and surprisingly strong baseline:

  - Works well in low-resource or custom-trained embedding settings

  - Keeps computation cheap

  - Captures the "semantic center" of the chunk

Alternative strategies (a TF-IDF-weighted variant is sketched after the averaging code below):

- Weighted average (e.g., using TF-IDF or word frequency)

- Doc2Vec (learns doc embeddings directly)

- Transformers (e.g., BERT) for sentence embeddings (but heavier)
"""

import numpy as np

def get_chunk_embedding(chunk_tokens, model):
    vectors = []
    for token in chunk_tokens:
        if token in model.wv:
            vectors.append(model.wv[token])
    if not vectors:
        return np.zeros(model.vector_size)
    return np.mean(vectors, axis=0)

# Dense embedding for each chunk
chunk_embeddings = [get_chunk_embedding(tokens, model) for tokens in tokenized_chunks]

"""Converting variable-length chunks into fixed-length embeddings."""
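
"""Editor's sketch (not part of the original app): the TF-IDF-weighted average listed
among the alternatives above. The helper is hypothetical; the commented usage assumes
scikit-learn is installed and uses TfidfVectorizer's default tokenization.
"""

def get_weighted_chunk_embedding(chunk_tokens, model, idf):
    # Average token vectors weighted by inverse document frequency
    vectors, weights = [], []
    for token in chunk_tokens:
        if token in model.wv:
            vectors.append(model.wv[token])
            weights.append(idf.get(token, 1.0))  # fall back to weight 1.0 for unseen tokens
    if not vectors:
        return np.zeros(model.vector_size)
    return np.average(vectors, axis=0, weights=weights)

# Usage (would replace the plain averaging above; assumes scikit-learn is available):
# from sklearn.feature_extraction.text import TfidfVectorizer
# vectorizer = TfidfVectorizer().fit(" ".join(tokens) for tokens in tokenized_chunks)
# idf = dict(zip(vectorizer.get_feature_names_out(), vectorizer.idf_))
# chunk_embeddings = [get_weighted_chunk_embedding(tokens, model, idf) for tokens in tokenized_chunks]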

# Requires: pip install faiss-cpu

import faiss

# Convert embeddings to float32 numpy array
embedding_matrix = np.array(chunk_embeddings).astype("float32")

# Initialize FAISS index (L2 distance)
index = faiss.IndexFlatL2(model.vector_size)
index.add(embedding_matrix)

"""Building a FAISS index from the dense vectors produced by the averaging step above.

FAISS is optimized for fast nearest-neighbor search and is standard for semantic search pipelines; IndexFlatL2 used here is exact, while approximate indexes such as IVF scale better (an approximate variant is sketched after the index is saved below).

Indexing takes precomputed embeddings (vectors generated from text) and organizes them into a searchable format like FAISS, enabling fast similarity-based retrieval.
"""

# Embed each chunk using average Word2Vec token embeddings
def embed_chunk(text, model):
    tokens = text.lower().split()
    vectors = [model.wv[token] for token in tokens if token in model.wv]
    return np.mean(vectors, axis=0) if vectors else np.zeros(model.vector_size)

embeddings = np.array([embed_chunk(chunk, model) for chunk in chunked_comments]).astype("float32")

# Build and save FAISS index
index = faiss.IndexFlatL2(model.vector_size)
index.add(embeddings)
faiss.write_index(index, "reddit_faiss.index")
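
"""Editor's sketch (not part of the original): for much larger corpora, an approximate
index such as IndexIVFFlat trades a little recall for speed. The nlist and nprobe values
below are illustrative assumptions, not tuned settings.
"""

# Optional approximate-search variant (commented out; IndexFlatL2 above stays the default):
# nlist = 100  # number of coarse clusters (assumption)
# quantizer = faiss.IndexFlatL2(model.vector_size)
# ivf_index = faiss.IndexIVFFlat(quantizer, model.vector_size, nlist)
# ivf_index.train(embeddings)   # IVF indexes must be trained before adding vectors
# ivf_index.add(embeddings)
# ivf_index.nprobe = 10         # clusters probed at query time (speed/recall trade-off)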

def search(query, model, index, top_k=5):
    tokens = clean_text(query).split()
    query_vec = get_chunk_embedding(tokens, model).astype("float32").reshape(1, -1)

    distances, indices = index.search(query_vec, top_k)
    return indices[0], distances[0]

original_chunks = [" ".join(tokens) for tokens in tokenized_chunks]

query = "quantum physics experiments"
top_ids, top_distances = search(query, model, index)

for i, idx in enumerate(top_ids):
    print(f"Rank {i+1} | Distance: {top_distances[i]:.2f}")
    print(original_chunks[idx][:300], "...\n")

"""# **Reddit Semantic Search App**"""

import gradio as gr

# Load Word2Vec model and FAISS index
model = Word2Vec.load("reddit_word2vec.model")
index = faiss.read_index("reddit_faiss.index")

# Prepare embedding function
def embed_text(text):
    tokens = text.lower().split()
    vectors = [model.wv[token] for token in tokens if token in model.wv]
    if not vectors:
        return np.zeros(model.vector_size)
    return np.mean(vectors, axis=0)

# Build subreddit index
subreddit_map = {i: label for i, label in enumerate(subreddit_labels)}
unique_subreddits = sorted(set(subreddit_labels))  # for dropdown

# Semantic search function
def search_reddit(query, selected_subreddit, top_k=5):
    query_vec = embed_text(query).astype("float32")
    # Over-fetch candidates so filtering by subreddit can still return up to top_k results
    D, I = index.search(np.array([query_vec]), top_k * 10)

    results = []
    for idx in I[0]:
        if idx < len(chunked_comments) and subreddit_map.get(idx) == selected_subreddit:
            results.append(f"🔸 {chunked_comments[idx]}")
            if len(results) >= top_k:
                break

    if not results:
        return "⚠️ No relevant results found."
    return "\n\n".join(results)
270
+ # Gradio UI
271
+ with gr.Blocks(theme=gr.themes.Base(primary_hue="orange", secondary_hue="gray")) as demo:
272
+ gr.Image(
273
+ value="https://1000logos.net/wp-content/uploads/2017/05/Reddit-Logo.png",
274
+ show_label=False,
275
+ height=100
276
+ )
277
+ gr.Markdown("## Reddit Semantic Search (Powered by Word2Vec + FAISS)\n_Disclaimer: Exterimental prototype, not owned/developed by Reddit Inc_")
278
+
279
+ with gr.Row():
280
+ query = gr.Textbox(label="Enter your Reddit-like query", placeholder="e.g. What's new in AI?")
281
+
282
+ output = gr.Textbox(label="Top Matching Chunks", lines=10)
283
+ search_btn = gr.Button("🔍 Search")
284
+
285
+ search_btn.click(fn=search_reddit, inputs=[query, subreddit_dropdown], outputs=output)
286
+
287
+ demo.launch(share=True)