# -*- coding: utf-8 -*-
import os
import re
from itertools import islice

import numpy as np
import pandas as pd
import faiss
import gradio as gr
from datasets import load_dataset
from gensim.models import Word2Vec
from tqdm import tqdm

# Load token from Hugging Face Secrets (kept available for authenticated Hub access)
HF_TOKEN = os.environ.get("RedditSemanticSearch")
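# The token is not required for this public dataset; if gated access were ever needed,
# it could be used to authenticate first (assumption, shown for completeness):
# from huggingface_hub import login
# login(token=HF_TOKEN)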
# Define target subreddits
target_subreddits = ["askscience", "gaming", "technology", "todayilearned", "programming"]

# Stream one subreddit split of the Reddit comments dataset from the HF Hub
def load_reddit_split(subreddit):
    return load_dataset("HuggingFaceGECLM/REDDIT_comments", split=subreddit, streaming=True)

# Take a bounded sample (20,000 comments per subreddit, 100,000 total) to limit memory,
# keeping the prefixed subreddit name (e.g. "r/askscience") alongside each comment body
comments = [
    {"body": ex["body"], "subreddit": ex["subreddit_name_prefixed"]}
    for sub in target_subreddits
    for ex in islice(load_reddit_split(sub), 20000)
]
# Convert to DataFrame
df = pd.DataFrame(comments)
# Clean text function
def clean_body(text):
    text = text.lower()
    text = re.sub(r"http\S+|www\S+|https\S+", "", text)
    text = re.sub(r"[^a-zA-Z\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()
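# Illustrative example (hypothetical input) of what clean_body produces:
#   clean_body("Check https://reddit.com NOW!!")  ->  "check now"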
# Apply cleaning
df["clean"] = df["body"].apply(clean_body)
# Chunk every 5 rows (5 consecutive comments form one retrieval unit)
chunk_size = 5
df["chunk_id"] = df.index // chunk_size
df_chunked = df.groupby("chunk_id").agg({"clean": " ".join, "subreddit": "first"}).reset_index()
df_chunked.rename(columns={"clean": "chunk_text"}, inplace=True)
# Final list for embedding
chunked_comments = df_chunked["chunk_text"].tolist()
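# Quick sanity check (illustrative): 100,000 sampled comments with chunk_size=5
# should yield 20,000 chunks.
# print(len(chunked_comments))  # expected: 20000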
# Subreddit label for each chunk (taken from the comments that make up the chunk)
subreddit_labels = df_chunked["subreddit"].tolist()
# Tokenize (chunks were already lowercased and stripped of URLs/punctuation by clean_body,
# so whitespace splitting is sufficient)
tokenized_chunks = [chunk.split() for chunk in tqdm(chunked_comments, desc="Tokenizing")]
# Train Word2Vec
model = Word2Vec(sentences=tokenized_chunks, vector_size=100, window=5, min_count=2, workers=4, sg=1)
model.save("reddit_word2vec.model")
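# Optional sanity check on the trained embeddings. The probe word is an assumption
# (it may be dropped by min_count filtering), hence the membership guard:
if "game" in model.wv:
    print(model.wv.most_similar("game", topn=5))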
# Embedding function: mean of the Word2Vec vectors of in-vocabulary tokens
# (falls back to a zero vector if no token is in the vocabulary)
def get_chunk_embedding(chunk_tokens, model):
    vectors = [model.wv[token] for token in chunk_tokens if token in model.wv]
    if not vectors:
        return np.zeros(model.vector_size)
    return np.mean(vectors, axis=0)

chunk_embeddings = [get_chunk_embedding(tokens, model) for tokens in tokenized_chunks]
embedding_matrix = np.array(chunk_embeddings).astype("float32")
# Build FAISS index
index = faiss.IndexFlatL2(model.vector_size)
index.add(embedding_matrix)
faiss.write_index(index, "reddit_faiss.index")
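# IndexFlatL2 performs exact (brute-force) L2 search, so no training step is needed.
# A quick sanity check is that the index holds one vector per chunk:
# print(index.ntotal, len(chunked_comments))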
# Load model and index for search API
model = Word2Vec.load("reddit_word2vec.model")
index = faiss.read_index("reddit_faiss.index")
subreddit_map = {i: label for i, label in enumerate(subreddit_labels)}
unique_subreddits = sorted(set(subreddit_labels))
# Embed a free-text query the same way the chunks were embedded
def embed_text(text):
    tokens = clean_body(text).split()
    vectors = [model.wv[token] for token in tokens if token in model.wv]
    return np.mean(vectors, axis=0) if vectors else np.zeros(model.vector_size)
# Search: over-fetch nearest chunks, then keep only those from the selected subreddit
def search_reddit(query, selected_subreddit, top_k=5):
    query_vec = embed_text(query).astype("float32").reshape(1, -1)
    D, I = index.search(query_vec, top_k * 20)
    results = []
    for idx in I[0]:
        if 0 <= idx < len(chunked_comments) and subreddit_map[idx] == selected_subreddit:
            results.append(f"🔸 {chunked_comments[idx]}")
            if len(results) >= top_k:
                break
    return "\n\n".join(results) if results else "⚠️ No relevant results found."
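# Example invocation outside the UI (query and label are illustrative; labels follow the
# dataset's prefixed form, e.g. "r/askscience"):
# print(search_reddit("how do vaccines train the immune system", "r/askscience", top_k=3))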
# Gradio UI
with gr.Blocks(theme=gr.themes.Base(primary_hue="orange")) as demo:
    gr.Image(value="https://1000logos.net/wp-content/uploads/2017/05/Reddit-Logo.png", show_label=False, height=100)
    gr.Markdown("## Reddit Semantic Search (Powered by Word2Vec + FAISS)\n_Disclaimer: Prototype, not affiliated with Reddit Inc._")
    with gr.Row():
        query = gr.Textbox(label="Enter Reddit-style query")
        subreddit_dropdown = gr.Dropdown(choices=unique_subreddits, label="Choose Subreddit")
    output = gr.Textbox(label="Matching Comments", lines=10)
    search_btn = gr.Button("🔍 Search")
    search_btn.click(fn=search_reddit, inputs=[query, subreddit_dropdown], outputs=output)

demo.launch(share=True)