import re
import json
from itertools import chain

import numpy as np
import faiss
import gradio as gr
from gensim.models import Word2Vec
from huggingface_hub import hf_hub_download
from tqdm import tqdm
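# Subreddits to index and the number of consecutive comments merged into each searchable chunk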
target_subreddits = ["askscience", "gaming", "technology", "todayilearned", "programming"]
chunk_size = 5
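# Stream each subreddit's comments from the HuggingFaceGECLM/REDDIT_comments dataset and merge them into a single list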
def load_reddit_split(subreddit_name):
    # REDDIT_comments is a dataset repo, so repo_type="dataset" is needed for hf_hub_download
    path = hf_hub_download(
        repo_id="HuggingFaceGECLM/REDDIT_comments",
        filename=f"{subreddit_name}.jsonl",
        repo_type="dataset",
    )
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            yield json.loads(line)

combined_dataset = list(chain.from_iterable(load_reddit_split(sub) for sub in tqdm(target_subreddits, desc="Loading subreddits")))
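# Normalise comment text, group consecutive comments into fixed-size chunks, and record each chunk's subreddit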
def clean_text(text):
    text = text.lower()
    text = re.sub(r"http\S+|www\S+|https\S+", "", text)
    text = re.sub(r"[^a-zA-Z\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()

# Keep only comments that actually have a body so comment indices stay aligned with their labels
comments_with_body = [c for c in combined_dataset if "body" in c]
cleaned_comments = [clean_text(c["body"]) for c in comments_with_body]
chunked_comments = [" ".join(cleaned_comments[i:i + chunk_size]) for i in range(0, len(cleaned_comments), chunk_size)]
# Label each chunk with the subreddit of its first comment
subreddit_labels = [comments_with_body[i]["subreddit_name_prefixed"] for i in range(0, len(cleaned_comments), chunk_size)]
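# Train a skip-gram Word2Vec model on the whitespace-tokenised chunks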
tokenized_chunks = [chunk.split() for chunk in chunked_comments]

model = Word2Vec(sentences=tokenized_chunks, vector_size=100, window=5, min_count=2, workers=4, sg=1)
model.save("reddit_word2vec.model")
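# Embed each chunk as the mean of its word vectors and index the embeddings with a flat L2 FAISS index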
def embed_tokens(tokens, model):
    vectors = [model.wv[token] for token in tokens if token in model.wv]
    return np.mean(vectors, axis=0) if vectors else np.zeros(model.vector_size)

embeddings = np.array([embed_tokens(chunk, model) for chunk in tokenized_chunks]).astype("float32")
index = faiss.IndexFlatL2(model.vector_size)
index.add(embeddings)
faiss.write_index(index, "reddit_faiss.index")
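# Reload the persisted model and index, and build a chunk-index -> subreddit lookup for filtering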
model = Word2Vec.load("reddit_word2vec.model")
index = faiss.read_index("reddit_faiss.index")
subreddit_map = {i: label for i, label in enumerate(subreddit_labels)}
unique_subreddits = sorted(set(subreddit_labels))
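# Embed an incoming query with the same cleaning and mean-pooling used for the corpus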
def embed_text(text):
    tokens = clean_text(text).split()
    return embed_tokens(tokens, model).astype("float32")
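# Retrieve the nearest chunks from FAISS and keep only those from the selected subreddit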
def search_reddit(query, selected_subreddit, top_k=5):
    query_vec = embed_text(query).reshape(1, -1)
    # Retrieve a larger candidate pool (20x top_k), since hits are filtered by subreddit below
    D, I = index.search(query_vec, top_k * 20)

    results = []
    for idx in I[0]:
        if 0 <= idx < len(chunked_comments) and subreddit_map.get(idx) == selected_subreddit:
            results.append(f"🔸 {chunked_comments[idx]}")
            if len(results) >= top_k:
                break

    if not results:
        return "⚠️ No relevant results found."
    return "\n\n".join(results)
with gr.Blocks(theme=gr.themes.Base(primary_hue="orange", secondary_hue="gray")) as demo:
    gr.Image(
        value="https://1000logos.net/wp-content/uploads/2017/05/Reddit-Logo.png",
        show_label=False,
        height=100,
    )
    gr.Markdown(
        "## 👾 Reddit Semantic Search (Powered by Word2Vec + FAISS)\n"
        "_Disclaimer: Experimental prototype, not affiliated with or developed by Reddit, Inc._"
    )
    with gr.Row():
        query = gr.Textbox(label="Enter your Reddit-like query", placeholder="e.g. What's new in AI?")
        subreddit_dropdown = gr.Dropdown(choices=unique_subreddits, label="Filter by Subreddit")
    output = gr.Textbox(label="Top Matching Chunks", lines=10)
    search_btn = gr.Button("🔍 Search")

    search_btn.click(fn=search_reddit, inputs=[query, subreddit_dropdown], outputs=output)

demo.launch(share=True)