# -*- coding: utf-8 -*-
import re
import json
from itertools import chain
import numpy as np
from gensim.models import Word2Vec
from tqdm import tqdm
import faiss
import gradio as gr
from huggingface_hub import hf_hub_download

# --- CONFIGURATION ---
target_subreddits = ["askscience", "gaming", "technology", "todayilearned", "programming"]
chunk_size = 5  # number of consecutive comments joined into one searchable chunk

# --- LOAD REDDIT COMMENTS ---
def load_reddit_split(subreddit_name):
    # REDDIT_comments lives in a dataset repo, so hf_hub_download needs repo_type="dataset".
    path = hf_hub_download(
        repo_id="HuggingFaceGECLM/REDDIT_comments",
        filename=f"{subreddit_name}.jsonl",
        repo_type="dataset",
    )
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            yield json.loads(line)

combined_dataset = list(chain(*(load_reddit_split(sub) for sub in target_subreddits)))

# --- CLEAN + CHUNK ---
def clean_text(text):
    text = text.lower()
    text = re.sub(r"http\S+|www\S+|https\S+", "", text)
    text = re.sub(r"[^a-zA-Z\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()
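# For illustration: clean_text("Check https://example.com NOW!!!") returns "check now".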

# Pair each cleaned body with its subreddit so labels stay aligned after filtering,
# then label every chunk with the subreddit of its first comment.
cleaned_pairs = [(clean_text(c["body"]), c.get("subreddit_name_prefixed", "unknown"))
                 for c in combined_dataset if "body" in c]
chunk_starts = range(0, len(cleaned_pairs), chunk_size)
chunked_comments = [" ".join(text for text, _ in cleaned_pairs[i:i + chunk_size]) for i in chunk_starts]
subreddit_labels = [cleaned_pairs[i][1] for i in chunk_starts]

# --- TOKENIZE ---
tokenized_chunks = [chunk.split() for chunk in chunked_comments]

# --- TRAIN WORD2VEC ---
# Skip-gram (sg=1) Word2Vec over the chunk tokens; words seen fewer than min_count=2 times are dropped.
model = Word2Vec(sentences=tokenized_chunks, vector_size=100, window=5, min_count=2, workers=4, sg=1)
model.save("reddit_word2vec.model")
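# Optional sanity check on the trained vectors; "game" is only an illustrative term and must
# appear at least min_count times in the corpus, otherwise most_similar raises a KeyError:
# print(model.wv.most_similar("game", topn=3))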

# --- EMBEDDINGS + FAISS ---
def embed_tokens(tokens, model):
    vectors = [model.wv[token] for token in tokens if token in model.wv]
    return np.mean(vectors, axis=0) if vectors else np.zeros(model.vector_size)
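# e.g. embed_tokens(["machine", "learning"], model) returns the mean of the two word
# vectors (100-dim), or an all-zeros vector when no token is in the vocabulary.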

embeddings = np.array(
    [embed_tokens(chunk, model) for chunk in tqdm(tokenized_chunks, desc="Embedding chunks")]
).astype("float32")
index = faiss.IndexFlatL2(model.vector_size)
index.add(embeddings)
faiss.write_index(index, "reddit_faiss.index")
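# The index should hold exactly one vector per chunk.
assert index.ntotal == len(chunked_comments)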

# --- SEARCH LOGIC ---
# Reload the persisted artifacts (redundant within a single run, but mirrors how a
# separate serving process would start up).
model = Word2Vec.load("reddit_word2vec.model")
index = faiss.read_index("reddit_faiss.index")
subreddit_map = {i: label for i, label in enumerate(subreddit_labels)}
unique_subreddits = sorted(set(subreddit_labels))

def embed_text(text):
    tokens = clean_text(text).split()
    return embed_tokens(tokens, model).astype("float32")

def search_reddit(query, selected_subreddit, top_k=5):
    query_vec = embed_text(query).reshape(1, -1)
    # Over-fetch neighbours so that filtering by subreddit can still return up to top_k hits.
    D, I = index.search(query_vec, top_k * 20)

    results = []
    for idx in I[0]:
        if idx == -1:
            continue  # FAISS pads with -1 when fewer neighbours are available
        if subreddit_map.get(int(idx)) == selected_subreddit:
            results.append(f"🔸 {chunked_comments[idx]}")
        if len(results) >= top_k:
            break

    if not results:
        return "⚠️ No relevant results found."
    return "\n\n".join(results)
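# Example query (illustrative; subreddit labels follow the "subreddit_name_prefixed"
# format used above, e.g. "r/technology"):
# print(search_reddit("what's new in artificial intelligence", "r/technology", top_k=3))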

# --- GRADIO APP ---
with gr.Blocks(theme=gr.themes.Base(primary_hue="orange", secondary_hue="gray")) as demo:
    gr.Image(
        value="https://1000logos.net/wp-content/uploads/2017/05/Reddit-Logo.png",
        show_label=False,
        height=100
    )
    gr.Markdown("## 👾 Reddit Semantic Search (Powered by Word2Vec + FAISS)\n_Disclaimer: experimental prototype; not owned or developed by Reddit Inc._")
    with gr.Row():
        query = gr.Textbox(label="Enter your Reddit-like query", placeholder="e.g. What's new in AI?")
        subreddit_dropdown = gr.Dropdown(choices=unique_subreddits, label="Filter by Subreddit")
    output = gr.Textbox(label="Top Matching Chunks", lines=10)
    search_btn = gr.Button("🔍 Search")

    search_btn.click(fn=search_reddit, inputs=[query, subreddit_dropdown], outputs=output)

demo.launch(share=True)