# -*- coding: utf-8 -*-
import os
import re
from itertools import chain, islice

import faiss
import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset
from gensim.models import Word2Vec
from huggingface_hub import login
from tqdm import tqdm

# Load token from Hugging Face Secrets and authenticate (the secret name is project-specific)
HF_TOKEN = os.environ.get("RedditSemanticSearch")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Target subreddits; the REDDIT_comments dataset is pre-split by subreddit,
# so each subreddit is loaded as its own streaming split.
target_subreddits = ["askscience", "gaming", "technology", "todayilearned", "programming"]


def load_reddit_split(sub):
    """Stream one subreddit split from the HF Hub, tagging each record with its subreddit."""
    split = load_dataset("HuggingFaceGECLM/REDDIT_comments", split=sub, streaming=True)
    for example in split:
        yield {"body": example["body"], "subreddit": sub}


# Take a bounded sample per subreddit (~100,000 comments total) to limit memory
combined_dataset = chain(*(islice(load_reddit_split(sub), 20000) for sub in target_subreddits))
comments = list(combined_dataset)

# Convert to DataFrame
df = pd.DataFrame(comments)


# Clean text: lowercase, strip URLs and non-letters, collapse whitespace
def clean_text(text):
    text = text.lower()
    text = re.sub(r"http\S+|www\S+|https\S+", "", text, flags=re.MULTILINE)
    text = re.sub(r"[^a-zA-Z\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()


# Apply cleaning
df["clean"] = df["body"].apply(clean_text)

# Chunk every 5 consecutive comments into one document; each chunk keeps the
# subreddit of its first comment (the per-subreddit sample size divides evenly
# by the chunk size, so chunks never straddle two subreddits)
chunk_size = 5
df["chunk_id"] = df.index // chunk_size
df_chunked = (
    df.groupby("chunk_id")
    .agg(chunk_text=("clean", " ".join), subreddit=("subreddit", "first"))
    .reset_index()
)

# Final lists for embedding and subreddit filtering
chunked_comments = df_chunked["chunk_text"].tolist()
subreddit_labels = df_chunked["subreddit"].tolist()

# Tokenize (chunks are already cleaned, so whitespace splitting is enough)
tokenized_chunks = [chunk.split() for chunk in tqdm(chunked_comments, desc="Tokenizing")]

# Train Word2Vec (skip-gram) on the tokenized chunks
model = Word2Vec(sentences=tokenized_chunks, vector_size=100, window=5, min_count=2, workers=4, sg=1)
model.save("reddit_word2vec.model")


# Embed a chunk as the mean of its in-vocabulary word vectors
def get_chunk_embedding(chunk_tokens, model):
    vectors = [model.wv[token] for token in chunk_tokens if token in model.wv]
    if not vectors:
        return np.zeros(model.vector_size)
    return np.mean(vectors, axis=0)


chunk_embeddings = [get_chunk_embedding(tokens, model) for tokens in tokenized_chunks]
embedding_matrix = np.array(chunk_embeddings).astype("float32")

# Build FAISS index (exact L2 search over chunk embeddings)
index = faiss.IndexFlatL2(model.vector_size)
index.add(embedding_matrix)
faiss.write_index(index, "reddit_faiss.index")

# Reload model and index for the search API
model = Word2Vec.load("reddit_word2vec.model")
index = faiss.read_index("reddit_faiss.index")

subreddit_map = {i: label for i, label in enumerate(subreddit_labels)}
unique_subreddits = sorted(set(subreddit_labels))


# Search functions
def embed_text(text):
    """Embed a free-text query with the same cleaning and mean-pooling used for the chunks."""
    tokens = clean_text(text).split()
    vectors = [model.wv[token] for token in tokens if token in model.wv]
    return np.mean(vectors, axis=0) if vectors else np.zeros(model.vector_size)


def search_reddit(query, selected_subreddit, top_k=5):
    query_vec = embed_text(query).astype("float32").reshape(1, -1)
    # Over-fetch so enough candidates remain after filtering by subreddit
    distances, indices = index.search(query_vec, top_k * 2)
    results = []
    for idx in indices[0]:
        idx = int(idx)
        if idx == -1:  # FAISS pads with -1 when there are fewer hits than requested
            continue
        if subreddit_map.get(idx) == selected_subreddit:
            results.append(f"🔸 {chunked_comments[idx]}")
        if len(results) >= top_k:
            break
    return "\n\n".join(results) if results else "⚠️ No relevant results found."


# Gradio UI
with gr.Blocks(theme=gr.themes.Base(primary_hue="orange")) as demo:
    gr.Image(
        value="https://1000logos.net/wp-content/uploads/2017/05/Reddit-Logo.png",
        show_label=False,
        height=100,
    )
    gr.Markdown(
        "## Reddit Semantic Search (Powered by Word2Vec + FAISS)\n"
        "_Disclaimer: Prototype, not affiliated with Reddit Inc._"
    )
    with gr.Row():
        query = gr.Textbox(label="Enter Reddit-style query")
        subreddit_dropdown = gr.Dropdown(choices=unique_subreddits, label="Choose Subreddit")
    output = gr.Textbox(label="Matching Comments", lines=10)
    search_btn = gr.Button("🔍 Search")
    search_btn.click(fn=search_reddit, inputs=[query, subreddit_dropdown], outputs=output)

demo.launch(share=True)
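
# Optional local sanity check, bypassing the UI (the query text and subreddit
# below are illustrative only):
#   print(search_reddit("why is the sky blue", "askscience", top_k=3))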