import threading

import streamlit as st
import torch
import pandas as pd
import numpy as np
from flask import Flask, request, jsonify
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

flask_app = Flask(__name__)

class ModeratelySimplifiedGATConvModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # The two attention heads are concatenated, so conv2 sees hidden_channels * 2 inputs
        self.conv1 = GATConv(in_channels, hidden_channels, heads=2)
        self.dropout1 = torch.nn.Dropout(0.45)
        self.conv2 = GATConv(hidden_channels * 2, out_channels, heads=1)

    def forward(self, x, edge_index, edge_attr=None):
        x = self.conv1(x, edge_index, edge_attr)
        x = torch.relu(x)
        x = self.dropout1(x)
        x = self.conv2(x, edge_index, edge_attr)
        return x
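
# Illustrative shape check (a minimal sketch; in_channels=384 is an assumed
# value for demonstration, the real one comes from data.x below):
#   m = ModeratelySimplifiedGATConvModel(in_channels=384, hidden_channels=32, out_channels=768)
#   x = torch.randn(5, 384)                            # 5 nodes, 384 features each
#   edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # 3 directed edges
#   m(x, edge_index).shape                             # torch.Size([5, 768])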

# Load the precomputed graph dataset (a torch_geometric Data object)
data = torch.load("graph_data.pt", map_location=torch.device("cpu"))

# Correct the state dictionary's key names: the checkpoint stores GATConv
# weights under "lin.weight", while the installed PyG version expects separate
# "lin_src"/"lin_dst" keys, so the same tensor is reused for both projections
original_state_dict = torch.load("graph_model.pth", map_location=torch.device("cpu"))
corrected_state_dict = {}
for key, value in original_state_dict.items():
    if "lin.weight" in key:
        corrected_state_dict[key.replace("lin.weight", "lin_src.weight")] = value
        corrected_state_dict[key.replace("lin.weight", "lin_dst.weight")] = value
    else:
        corrected_state_dict[key] = value
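
# For example, with the model above this maps:
#   "conv1.lin.weight" -> "conv1.lin_src.weight" and "conv1.lin_dst.weight"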

# Initialize the GATConv model with the corrected state dictionary
gatconv_model = ModeratelySimplifiedGATConvModel(
    in_channels=data.x.shape[1], hidden_channels=32, out_channels=768
)
gatconv_model.load_state_dict(corrected_state_dict)
gatconv_model.eval()  # Disable dropout for deterministic inference

# Load the BERT-based sentence transformer model (its 768-dimensional
# embeddings match the GNN's out_channels)
model_bert = SentenceTransformer("all-mpnet-base-v2")

# Ensure the DataFrame is loaded properly; halt the app otherwise, since
# everything below depends on it
try:
    df = pd.read_json("combined_data.json.gz", orient="records", lines=True, compression="gzip")
except Exception as e:
    st.error(f"Error reading JSON file: {e}")
    st.stop()

# Generate GNN-based embeddings
with torch.no_grad():
    all_video_embeddings = gatconv_model(data.x, data.edge_index, data.edge_attr).cpu()
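# all_video_embeddings has shape [num_videos, 768]; row i is the GNN embedding
# of video i, used below for dot-product recommendations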

# Function to find the most similar video and recommend the top 10 based on GNN embeddings
def get_similar_and_recommend(input_text):
    # Find the most similar video based on cosine similarity
    embeddings_matrix = np.array(df["embeddings"].tolist())
    input_embedding = model_bert.encode([input_text])[0]
    similarities = cosine_similarity([input_embedding], embeddings_matrix)[0]

    most_similar_index = np.argmax(similarities)  # Find the most similar video

    # Get all features of the most similar video
    most_similar_video_features = df.iloc[most_similar_index].to_dict()

    # Drop internal fields that should not appear in the API output
    most_similar_video_features.pop("text_for_embedding", None)
    most_similar_video_features.pop("embeddings", None)

    # Recommend the top 10 videos by dot product in GNN embedding space
    def recommend_top_10(given_video_index, all_video_embeddings):
        scores = all_video_embeddings @ all_video_embeddings[given_video_index]
        scores[given_video_index] = -float("inf")  # Exclude the query video itself
        top_k = min(10, scores.numel())
        top_10_indices = torch.topk(scores, k=top_k).indices.tolist()
        return [df.iloc[idx].to_dict() for idx in top_10_indices]

    top_10_recommended_videos_features = recommend_top_10(most_similar_index, all_video_embeddings)

    # Apply search context: boost the weight when a query keyword appears in a title
    user_keywords = input_text.split()
    weight = 1.0  # Base weight factor

    for keyword in user_keywords:
        # Case-insensitive substring match against all video titles
        if df["title"].str.lower().str.contains(keyword.lower(), regex=False).any():
            weight += 0.1  # Increase the weight for each matching keyword

    # Every GNN-based recommendation shares the same query-level weight
    video_weights = [weight] * len(top_10_recommended_videos_features)

    # Drop the same internal fields from each recommendation
    for recommended_video in top_10_recommended_videos_features:
        recommended_video.pop("text_for_embedding", None)
        recommended_video.pop("embeddings", None)

    # Create the output JSON with the most similar video, final recommendations, and weights
    output = {
        "search_context": {
            "input_text": input_text,  # What the user provided
            "weights": video_weights,  # Weights for each GNN-based recommendation
        },
        "most_similar_video": most_similar_video_features,
        "final_recommendations": top_10_recommended_videos_features  # Top 10 recommended videos
    }

    return output

# Streamlit UI: enter text to get the most similar video and the top 10 recommendations
user_input = st.text_input("Enter text to find the most similar video")

if user_input:
    recommendations = get_similar_and_recommend(user_input)
    st.json(recommendations)



@flask_app.route("/recommend", methods=["POST"])
def recommend():
    # Validate the JSON payload before running the recommender
    payload = request.get_json(silent=True) or {}
    input_text = payload.get("input_text")
    if not input_text:
        return jsonify({"error": "Missing 'input_text' in request body"}), 400

    recommendations = get_similar_and_recommend(input_text)
    return jsonify(recommendations)


# Create a simple Streamlit interface with instructions
st.title("Video Recommendation API")
st.write("Use POST requests to `/recommend` with JSON data {'input_text': '<your text>'}")

if __name__ == "__main__":
    # Run the Flask API in a daemon thread on its own port so it neither blocks
    # the Streamlit script nor clashes with Streamlit's default port 8501
    threading.Thread(target=lambda: flask_app.run(host="0.0.0.0", port=5000), daemon=True).start()