AJ-Gazin commited on
Commit
960b542
·
1 Parent(s): e7fe866

added streamlit app

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+ import streamlit.components.v1 as components
4
+ import pandas as pd
5
+ import torch
6
+ import requests
7
+ import random
8
+ from io import BytesIO
9
+ from PIL import Image
10
+ from torch_geometric.nn import SAGEConv, to_hetero, Linear
11
+ from dotenv import load_dotenv
12
+ import os
13
+
14
+ from IPython.display import HTML
15
+
16
+ import viz_utils
17
+ import model_def
18
+
19
+ load_dotenv() #load environment variables from .env file
20
+
21
+ ##no clue why this is necessary. But won't see subfolders without it. Just on my laptop.
22
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
23
+
24
+ API_KEY = os.getenv("HUGGINGFACE_API_KEY")
25
+ API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
26
+
27
+ # --- LOAD DATA AND MODEL ---
28
+ movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv") # Load your movie data
29
+ data = torch.load("./PyGdata.pt")
30
+
31
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
+ model = model_def.Model(hidden_channels=32).to(device)
33
+ model.load_state_dict(torch.load("PyGTrainedModelState.pt"))
34
+ model.eval()
35
+
36
+ # --- STREAMLIT APP ---
37
+ st.title("Movie Recommendation App")
38
+
39
+
40
+
41
+ # --- VISUALIZATIONS ---
42
+ #with open("umap_visualization.html", "r", encoding='utf-8') as f:
43
+ # umap_html = f.read()
44
+
45
+ #with open("tsne_visualization.html", "r") as f:
46
+ # tsne_html = f.read()
47
+
48
+ #with open("pca_visualization.html", "r") as f:
49
+ # pca_html = f.read()
50
+
51
+ tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])
52
+
53
+
54
+ with torch.no_grad():
55
+ a = model.encoder(data.x_dict,data.edge_index_dict)
56
+ user = pd.DataFrame(a['user'].detach().cpu())
57
+ movie = pd.DataFrame(a['movie'].detach().cpu())
58
+ embedding_df = pd.concat([user, movie], axis=0)
59
+
60
+ with tab1:
61
+ umap_expander = st.expander("UMAP Visualization")
62
+ with umap_expander:
63
+ st.subheader('UMAP Visualization')
64
+ umap_fig = viz_utils.visualize_embeddings_umap(embedding_df)
65
+ st.plotly_chart(umap_fig)
66
+ #components.html(umap_html, width=800, height=800)
67
+
68
+ tsne_expander = st.expander("TSNE Visualization")
69
+ with tsne_expander:
70
+ st.subheader('TSNE Visualization')
71
+ tsne_fig = viz_utils.visualize_embeddings_tsne(embedding_df)
72
+ st.plotly_chart(tsne_fig)
73
+ #components.html(tsne_html, width=800, height=800)
74
+
75
+ pca_expander = st.expander("PCA Visualization")
76
+ with pca_expander:
77
+ st.subheader('PCA Visualization')
78
+ pca_fig = viz_utils.visualize_embeddings_pca(embedding_df)
79
+ st.plotly_chart(pca_fig)
80
+ #components.html(pca_html, width=800, height=800)
81
+
82
+
83
+
84
+
85
+ def get_movie_recommendations(model, data, user_id, total_movies):
86
+ user_row = torch.tensor([user_id] * total_movies).to(device)
87
+ all_movie_ids = torch.arange(total_movies).to(device)
88
+ edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
89
+
90
+ pred = model(data.x_dict, data.edge_index_dict, edge_label_index).to('cpu')
91
+ top_five_indices = pred.topk(5).indices
92
+
93
+ recommended_movies = movies_df.iloc[top_five_indices]
94
+ return recommended_movies
95
+
96
+ def generate_poster(movie_title):
97
+ headers = {"Authorization": f"Bearer {API_KEY}"}
98
+
99
+ #creates random seed so movie poster changes on refresh even if same title.
100
+ seed = random.randint(0, 2**32 - 1)
101
+ payload = {
102
+ "inputs": movie_title,
103
+ # "parameters": {
104
+ # "seed": seed
105
+ # }
106
+ }
107
+
108
+ try:
109
+ response = requests.post(API_URL, headers=headers, json=payload)
110
+ response.raise_for_status() # Raise an error if the request fails
111
+
112
+ # Display the generated image
113
+ image = Image.open(BytesIO(response.content))
114
+ st.image(image, caption=movie_title)
115
+
116
+ except requests.exceptions.HTTPError as err:
117
+ st.error(f"Image generation failed: {err}")
118
+
119
+ with tab2:
120
+ user_id = st.number_input("Enter the User ID:", min_value=0)
121
+ if st.button("Get Recommendations"):
122
+ st.write("Top 5 Recommendations:")
123
+ try:
124
+ total_movies = data['movie'].num_nodes
125
+ recommended_movies = get_movie_recommendations(model, data, user_id, total_movies)
126
+ cols = st.columns(3)
127
+
128
+
129
+ for i, row in recommended_movies.iterrows():
130
+ with cols[i % 3]:
131
+ #st.write(f"{i+1}. {row['title']}")
132
+ try:
133
+ image = generate_poster(row['title'])
134
+ except requests.exceptions.HTTPError as err:
135
+ st.error(f"Image generation failed for {row['title']}: {err}")
136
+
137
+ except Exception as e:
138
+ st.error(f"An error occurred: {e}")