Spaces:
Sleeping
Sleeping
Balaji S
commited on
First commit
Browse files- .gitattributes +35 -35
- .gitignore +1 -0
- README.md +18 -14
- app.py +75 -0
- configs.yaml +10 -0
- embeddings.py +64 -0
- loss_utils.py +63 -0
- model.py +325 -0
- preprocessing.py +56 -0
- requirements.txt +4 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__/
|
README.md
CHANGED
@@ -1,14 +1,18 @@
|
|
1 |
-
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
-
sdk: streamlit
|
7 |
-
sdk_version: 1.39.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned:
|
10 |
-
license:
|
11 |
-
short_description:
|
12 |
-
---
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: DA626
|
3 |
+
emoji: 🐠
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.39.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
license: mit
|
11 |
+
short_description: Deployment of EasyRec
|
12 |
+
---
|
13 |
+
|
14 |
+
The [EasyRec](https://github.com/HKUDS/EasyRec) and [EasyRec-Forked](https://github.com/jibala-1022/EasyRec) models is used to recommend movies from [TMDB 5000 Movies Dataset](https://www.kaggle.com/datasets/tmdb/tmdb-movie-metadata?select=tmdb_5000_movies.csv)
|
15 |
+
|
16 |
+
`requirements.txt` contains packages needed during deployment only. Thus `json` is omitted.
|
17 |
+
|
18 |
+
Execute `preprocessing.py` to compute movie embeddings before deploying the app.
|
app.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import torch
|
4 |
+
import yaml
|
5 |
+
from embeddings import load_model, compute_embeddings
|
6 |
+
|
7 |
+
# Load configuration from YAML file
|
8 |
+
with open("configs.yaml", "r") as file:
|
9 |
+
configs = yaml.safe_load(file)
|
10 |
+
|
11 |
+
# Load the processed movie dataset
|
12 |
+
movie_data = pd.read_csv(configs["processed_dataset"])
|
13 |
+
|
14 |
+
# Streamlit app
|
15 |
+
st.title("🎬 EasyRec Movie Recommender")
|
16 |
+
|
17 |
+
# Dropdown for model selection
|
18 |
+
model_names = configs['hf_models'] # Assuming this is a list of model names in your configs.yaml
|
19 |
+
selected_model_name = st.selectbox("Select a model:", model_names)
|
20 |
+
|
21 |
+
# Load the model based on user selection
|
22 |
+
model, tokenizer = load_model(selected_model_name)
|
23 |
+
|
24 |
+
# User input for movie description
|
25 |
+
user_description = st.text_input("Enter a description of the type of movie you're interested in:",
|
26 |
+
placeholder="e.g. A romantic comedy with a twist...")
|
27 |
+
|
28 |
+
if user_description:
|
29 |
+
# Load the precomputed movie embeddings from a .pt file
|
30 |
+
embedding_dir_path = f"{configs['movie_embeddings']}/{selected_model_name}"
|
31 |
+
embedding_file_path = f"{embedding_dir_path}/{configs['movie_embeddings']}.pt"
|
32 |
+
movie_embeddings = torch.load(embedding_file_path) # Load the .pt file
|
33 |
+
|
34 |
+
# Compute the embedding for the user input by passing it as a list
|
35 |
+
user_embedding = compute_embeddings([user_description], model, tokenizer)
|
36 |
+
|
37 |
+
similarity_scores = torch.matmul(movie_embeddings, user_embedding.T).flatten()
|
38 |
+
|
39 |
+
# Set the number of top recommendations to display
|
40 |
+
K = 5
|
41 |
+
top_k_indices = torch.argsort(similarity_scores, descending=True)[:K].tolist() # Get indices of top K
|
42 |
+
|
43 |
+
# Display recommendations
|
44 |
+
st.write("## 🎉 Top Recommendations:")
|
45 |
+
|
46 |
+
for rank, movie_id in enumerate(top_k_indices, start=1):
|
47 |
+
movie = movie_data.iloc[movie_id]
|
48 |
+
|
49 |
+
# Convert runtime from minutes to hours and minutes
|
50 |
+
hours = movie.runtime // 60
|
51 |
+
minutes = movie.runtime % 60
|
52 |
+
|
53 |
+
# Construct an HTML card for displaying the movie information
|
54 |
+
st.markdown(f"### {rank}. {movie.title}")
|
55 |
+
st.markdown(f"**Release Date:** {movie.release_date} **Runtime:** {f'{hours}h {minutes}m' if hours > 0 else f'{minutes}m'}")
|
56 |
+
st.markdown(f"⭐ {movie.vote_average} ({movie.vote_count} votes)")
|
57 |
+
st.markdown(f"**Overview:** {movie.overview}")
|
58 |
+
st.markdown(f"**Genres:** {movie.genres}")
|
59 |
+
st.markdown(f"**Production Companies:** {movie.production_companies}")
|
60 |
+
st.markdown(f"**Production Countries:** {movie.production_countries}")
|
61 |
+
st.markdown("---")
|
62 |
+
|
63 |
+
# Additional styling (optional)
|
64 |
+
st.markdown(
|
65 |
+
"""
|
66 |
+
<style>
|
67 |
+
.stTextInput > div > input {
|
68 |
+
background-color: #f0f0f5; /* Light gray background */
|
69 |
+
border: 1px solid #ccc; /* Gray border */
|
70 |
+
border-radius: 5px; /* Rounded corners */
|
71 |
+
}
|
72 |
+
</style>
|
73 |
+
""",
|
74 |
+
unsafe_allow_html=True
|
75 |
+
)
|
configs.yaml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
hf_models:
|
2 |
+
- "hkuds/easyrec-roberta-small"
|
3 |
+
- "hkuds/easyrec-roberta-base"
|
4 |
+
- "hkuds/easyrec-roberta-large"
|
5 |
+
- "jibala-1022/easyrec-small"
|
6 |
+
- "jibala-1022/easyrec-base"
|
7 |
+
- "jibala-1022/easyrec-large"
|
8 |
+
dataset: "data/tmdb_5000_movies.csv"
|
9 |
+
processed_dataset: "data/tmdb_5000_movies_processed.csv"
|
10 |
+
movie_embeddings: "movie_embeddings"
|
embeddings.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from transformers import AutoConfig, AutoTokenizer
|
6 |
+
from model import Easyrec
|
7 |
+
|
8 |
+
|
9 |
+
def load_model(model_path: str) -> Tuple[Easyrec, AutoTokenizer]:
|
10 |
+
"""
|
11 |
+
Load the pre-trained model and tokenizer from the specified path.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
model_path: The path to the pre-trained huggingface model or local directory.
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
tuple: A tuple containing the model and tokenizer.
|
18 |
+
"""
|
19 |
+
|
20 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
+
|
22 |
+
config = AutoConfig.from_pretrained(model_path)
|
23 |
+
model = Easyrec.from_pretrained(model_path, config=config).to(device)
|
24 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
25 |
+
|
26 |
+
return model, tokenizer
|
27 |
+
|
28 |
+
|
29 |
+
def compute_embeddings(
|
30 |
+
sentences: List[str],
|
31 |
+
model: Easyrec,
|
32 |
+
tokenizer: AutoTokenizer,
|
33 |
+
batch_size: int = 8) -> torch.Tensor:
|
34 |
+
"""
|
35 |
+
Compute embeddings for a list of sentences using the specified model and tokenizer.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
sentences: A list of sentences for which to compute embeddings.
|
39 |
+
model: The pre-trained model used for generating embeddings.
|
40 |
+
tokenizer: The tokenizer used to preprocess the sentences.
|
41 |
+
batch_size: The number of sentences to process in each batch (default is 8).
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
torch.Tensor: A tensor containing the normalized embeddings for the input sentences.
|
45 |
+
"""
|
46 |
+
|
47 |
+
embeddings = []
|
48 |
+
count_sentences = len(sentences)
|
49 |
+
device = next(model.parameters()).device # Get the device on which the model is located
|
50 |
+
|
51 |
+
for start in range(0, count_sentences, batch_size):
|
52 |
+
end = start + batch_size
|
53 |
+
batch_sentences = sentences[start:end]
|
54 |
+
|
55 |
+
inputs = tokenizer(batch_sentences, padding=True, truncation=True, max_length=512, return_tensors="pt")
|
56 |
+
inputs = {key: val.to(device) for key, val in inputs.items()} # Move input tensors to the same device as the model
|
57 |
+
|
58 |
+
with torch.inference_mode():
|
59 |
+
outputs = model.encode(inputs['input_ids'], inputs['attention_mask'])
|
60 |
+
batch_embeddings = F.normalize(outputs.pooler_output.detach().float(), dim=-1)
|
61 |
+
|
62 |
+
embeddings.append(batch_embeddings.cpu())
|
63 |
+
|
64 |
+
return torch.cat(embeddings, dim=0) # Concatenate all computed embeddings into a single tensor
|
loss_utils.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch as t
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
def cal_bpr_loss(anc_embeds, pos_embeds, neg_embeds):
|
5 |
+
pos_preds = (anc_embeds * pos_embeds).sum(-1)
|
6 |
+
neg_preds = (anc_embeds * neg_embeds).sum(-1)
|
7 |
+
return t.sum(F.softplus(neg_preds - pos_preds))
|
8 |
+
|
9 |
+
|
10 |
+
def reg_pick_embeds(embeds_list):
|
11 |
+
reg_loss = 0
|
12 |
+
for embeds in embeds_list:
|
13 |
+
reg_loss += embeds.square().sum()
|
14 |
+
return reg_loss
|
15 |
+
|
16 |
+
|
17 |
+
def cal_infonce_loss(embeds1, embeds2, all_embeds2, temp=1.0):
|
18 |
+
normed_embeds1 = embeds1 / t.sqrt(1e-8 + embeds1.square().sum(-1, keepdim=True))
|
19 |
+
normed_embeds2 = embeds2 / t.sqrt(1e-8 + embeds2.square().sum(-1, keepdim=True))
|
20 |
+
normed_all_embeds2 = all_embeds2 / t.sqrt(1e-8 + all_embeds2.square().sum(-1, keepdim=True))
|
21 |
+
nume_term = -(normed_embeds1 * normed_embeds2 / temp).sum(-1)
|
22 |
+
deno_term = t.log(t.sum(t.exp(normed_embeds1 @ normed_all_embeds2.T / temp), dim=-1))
|
23 |
+
cl_loss = (nume_term + deno_term).sum()
|
24 |
+
return cl_loss
|
25 |
+
|
26 |
+
|
27 |
+
def cal_infonce_loss_spec_nodes(embeds1, embeds2, nodes, temp):
|
28 |
+
embeds1 = F.normalize(embeds1 + 1e-8, p=2)
|
29 |
+
embeds2 = F.normalize(embeds2 + 1e-8, p=2)
|
30 |
+
pckEmbeds1 = embeds1[nodes]
|
31 |
+
pckEmbeds2 = embeds2[nodes]
|
32 |
+
nume = t.exp(t.sum(pckEmbeds1 * pckEmbeds2, dim=-1) / temp)
|
33 |
+
deno = t.exp(pckEmbeds1 @ embeds2.T / temp).sum(-1) + 1e-8
|
34 |
+
return -t.log(nume / deno).mean()
|
35 |
+
|
36 |
+
|
37 |
+
def cal_sce_loss(x, y, alpha):
|
38 |
+
x = F.normalize(x, p=2, dim=-1)
|
39 |
+
y = F.normalize(y, p=2, dim=-1)
|
40 |
+
loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)
|
41 |
+
loss = loss.mean()
|
42 |
+
return loss
|
43 |
+
|
44 |
+
|
45 |
+
def cal_rank_loss(stu_anc_emb, stu_pos_emb, stu_neg_emb, tea_anc_emb, tea_pos_emb, tea_neg_emb):
|
46 |
+
stu_pos_score = (stu_anc_emb * stu_pos_emb).sum(dim=-1)
|
47 |
+
stu_neg_score = (stu_anc_emb * stu_neg_emb).sum(dim=-1)
|
48 |
+
stu_r_score = F.sigmoid(stu_pos_score - stu_neg_score)
|
49 |
+
|
50 |
+
tea_pos_score = (tea_anc_emb * tea_pos_emb).sum(dim=-1)
|
51 |
+
tea_neg_score = (tea_anc_emb * tea_neg_emb).sum(dim=-1)
|
52 |
+
tea_r_score = F.sigmoid(tea_pos_score - tea_neg_score)
|
53 |
+
|
54 |
+
rank_loss = -(tea_r_score * t.log(stu_r_score + 1e-8) + (1 - tea_r_score) * t.log(1 - stu_r_score + 1e-8)).mean()
|
55 |
+
|
56 |
+
return rank_loss
|
57 |
+
|
58 |
+
|
59 |
+
def reg_params(model):
|
60 |
+
reg_loss = 0
|
61 |
+
for W in model.parameters():
|
62 |
+
reg_loss += W.norm(2).square()
|
63 |
+
return reg_loss
|
model.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import torch.nn as nn
|
5 |
+
from tqdm import tqdm
|
6 |
+
import scipy.sparse as sp
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torch.distributed as dist
|
9 |
+
|
10 |
+
import transformers
|
11 |
+
from transformers import RobertaTokenizer
|
12 |
+
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead
|
13 |
+
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead
|
14 |
+
from transformers.activations import gelu
|
15 |
+
from transformers.file_utils import (
|
16 |
+
add_code_sample_docstrings,
|
17 |
+
add_start_docstrings,
|
18 |
+
add_start_docstrings_to_model_forward,
|
19 |
+
replace_return_docstrings,
|
20 |
+
)
|
21 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
|
22 |
+
|
23 |
+
from loss_utils import *
|
24 |
+
|
25 |
+
init = nn.init.xavier_uniform_
|
26 |
+
uniformInit = nn.init.uniform
|
27 |
+
|
28 |
+
|
29 |
+
"""
|
30 |
+
EasyRec
|
31 |
+
"""
|
32 |
+
def dot_product_scores(q_vectors, ctx_vectors):
|
33 |
+
r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1))
|
34 |
+
return r
|
35 |
+
|
36 |
+
class MLPLayer(nn.Module):
|
37 |
+
"""
|
38 |
+
Head for getting sentence representations over RoBERTa/BERT's CLS representation.
|
39 |
+
"""
|
40 |
+
def __init__(self, config):
|
41 |
+
super().__init__()
|
42 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
43 |
+
self.activation = nn.Tanh()
|
44 |
+
|
45 |
+
def forward(self, features, **kwargs):
|
46 |
+
x = self.dense(features)
|
47 |
+
x = self.activation(x)
|
48 |
+
return x
|
49 |
+
|
50 |
+
|
51 |
+
class Pooler(nn.Module):
|
52 |
+
"""
|
53 |
+
Parameter-free poolers to get the sentence embedding
|
54 |
+
'cls': [CLS] representation with BERT/RoBERTa's MLP pooler.
|
55 |
+
'cls_before_pooler': [CLS] representation without the original MLP pooler.
|
56 |
+
'avg': average of the last layers' hidden states at each token.
|
57 |
+
'avg_top2': average of the last two layers.
|
58 |
+
'avg_first_last': average of the first and the last layers.
|
59 |
+
"""
|
60 |
+
def __init__(self, pooler_type):
|
61 |
+
super().__init__()
|
62 |
+
self.pooler_type = pooler_type
|
63 |
+
assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type
|
64 |
+
|
65 |
+
def forward(self, attention_mask, outputs):
|
66 |
+
last_hidden = outputs.last_hidden_state
|
67 |
+
pooler_output = outputs.pooler_output
|
68 |
+
hidden_states = outputs.hidden_states
|
69 |
+
|
70 |
+
if self.pooler_type in ['cls_before_pooler', 'cls']:
|
71 |
+
return last_hidden[:, 0]
|
72 |
+
elif self.pooler_type == "avg":
|
73 |
+
return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1))
|
74 |
+
elif self.pooler_type == "avg_first_last":
|
75 |
+
first_hidden = hidden_states[1]
|
76 |
+
last_hidden = hidden_states[-1]
|
77 |
+
pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
|
78 |
+
return pooled_result
|
79 |
+
elif self.pooler_type == "avg_top2":
|
80 |
+
second_last_hidden = hidden_states[-2]
|
81 |
+
last_hidden = hidden_states[-1]
|
82 |
+
pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
|
83 |
+
return pooled_result
|
84 |
+
else:
|
85 |
+
raise NotImplementedError
|
86 |
+
|
87 |
+
|
88 |
+
class Similarity(nn.Module):
|
89 |
+
"""
|
90 |
+
Dot product or cosine similarity
|
91 |
+
"""
|
92 |
+
def __init__(self, temp):
|
93 |
+
super().__init__()
|
94 |
+
self.temp = temp
|
95 |
+
self.cos = nn.CosineSimilarity(dim=-1)
|
96 |
+
|
97 |
+
def forward(self, x, y):
|
98 |
+
return self.cos(x, y) / self.temp
|
99 |
+
|
100 |
+
|
101 |
+
class Easyrec(RobertaPreTrainedModel):
|
102 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
103 |
+
|
104 |
+
def __init__(self, config, *model_args, **model_kargs):
|
105 |
+
super().__init__(config)
|
106 |
+
try:
|
107 |
+
self.model_args = model_kargs["model_args"]
|
108 |
+
self.roberta = RobertaModel(config, add_pooling_layer=False)
|
109 |
+
if self.model_args.pooler_type == "cls":
|
110 |
+
self.mlp = MLPLayer(config)
|
111 |
+
if self.model_args.do_mlm:
|
112 |
+
self.lm_head = RobertaLMHead(config)
|
113 |
+
"""
|
114 |
+
Contrastive learning class init function.
|
115 |
+
"""
|
116 |
+
self.pooler_type = self.model_args.pooler_type
|
117 |
+
self.pooler = Pooler(self.pooler_type)
|
118 |
+
self.sim = Similarity(temp=self.model_args.temp)
|
119 |
+
self.init_weights()
|
120 |
+
except:
|
121 |
+
self.roberta = RobertaModel(config, add_pooling_layer=False)
|
122 |
+
self.mlp = MLPLayer(config)
|
123 |
+
self.lm_head = RobertaLMHead(config)
|
124 |
+
self.pooler_type = 'cls'
|
125 |
+
self.pooler = Pooler(self.pooler_type)
|
126 |
+
self.init_weights()
|
127 |
+
|
128 |
+
def forward(self,
|
129 |
+
user_input_ids=None,
|
130 |
+
user_attention_mask=None,
|
131 |
+
pos_item_input_ids=None,
|
132 |
+
pos_item_attention_mask=None,
|
133 |
+
neg_item_input_ids=None,
|
134 |
+
neg_item_attention_mask=None,
|
135 |
+
token_type_ids=None,
|
136 |
+
position_ids=None,
|
137 |
+
head_mask=None,
|
138 |
+
inputs_embeds=None,
|
139 |
+
labels=None,
|
140 |
+
output_attentions=None,
|
141 |
+
output_hidden_states=None,
|
142 |
+
return_dict=None,
|
143 |
+
mlm_input_ids=None,
|
144 |
+
mlm_attention_mask=None,
|
145 |
+
mlm_labels=None,
|
146 |
+
):
|
147 |
+
"""
|
148 |
+
Contrastive learning forward function.
|
149 |
+
"""
|
150 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
151 |
+
batch_size = user_input_ids.size(0)
|
152 |
+
|
153 |
+
# Get user embeddings
|
154 |
+
user_outputs = self.roberta(
|
155 |
+
input_ids=user_input_ids,
|
156 |
+
attention_mask=user_attention_mask,
|
157 |
+
token_type_ids=None,
|
158 |
+
position_ids=None,
|
159 |
+
head_mask=None,
|
160 |
+
inputs_embeds=None,
|
161 |
+
output_attentions=output_attentions,
|
162 |
+
output_hidden_states=output_hidden_states,
|
163 |
+
return_dict=return_dict,
|
164 |
+
)
|
165 |
+
|
166 |
+
# Get positive item embeddings
|
167 |
+
pos_item_outputs = self.roberta(
|
168 |
+
input_ids=pos_item_input_ids,
|
169 |
+
attention_mask=pos_item_attention_mask,
|
170 |
+
token_type_ids=None,
|
171 |
+
position_ids=None,
|
172 |
+
head_mask=None,
|
173 |
+
inputs_embeds=None,
|
174 |
+
output_attentions=output_attentions,
|
175 |
+
output_hidden_states=output_hidden_states,
|
176 |
+
return_dict=return_dict,
|
177 |
+
)
|
178 |
+
|
179 |
+
# Get negative item embeddings
|
180 |
+
neg_item_outputs = self.roberta(
|
181 |
+
input_ids=neg_item_input_ids,
|
182 |
+
attention_mask=neg_item_attention_mask,
|
183 |
+
token_type_ids=None,
|
184 |
+
position_ids=None,
|
185 |
+
head_mask=None,
|
186 |
+
inputs_embeds=None,
|
187 |
+
output_attentions=output_attentions,
|
188 |
+
output_hidden_states=output_hidden_states,
|
189 |
+
return_dict=return_dict,
|
190 |
+
)
|
191 |
+
|
192 |
+
# MLM auxiliary objective
|
193 |
+
if mlm_input_ids is not None:
|
194 |
+
mlm_outputs = self.roberta(
|
195 |
+
input_ids=mlm_input_ids,
|
196 |
+
attention_mask=mlm_attention_mask,
|
197 |
+
token_type_ids=None,
|
198 |
+
position_ids=None,
|
199 |
+
head_mask=None,
|
200 |
+
inputs_embeds=None,
|
201 |
+
output_attentions=output_attentions,
|
202 |
+
output_hidden_states=output_hidden_states,
|
203 |
+
return_dict=return_dict,
|
204 |
+
)
|
205 |
+
|
206 |
+
# Pooling
|
207 |
+
user_pooler_output = self.pooler(user_attention_mask, user_outputs)
|
208 |
+
pos_item_pooler_output = self.pooler(pos_item_attention_mask, pos_item_outputs)
|
209 |
+
neg_item_pooler_output = self.pooler(neg_item_attention_mask, neg_item_outputs)
|
210 |
+
|
211 |
+
# If using "cls", we add an extra MLP layer
|
212 |
+
# (same as BERT's original implementation) over the representation.
|
213 |
+
if self.pooler_type == "cls":
|
214 |
+
user_pooler_output = self.mlp(user_pooler_output)
|
215 |
+
pos_item_pooler_output = self.mlp(pos_item_pooler_output)
|
216 |
+
neg_item_pooler_output = self.mlp(neg_item_pooler_output)
|
217 |
+
|
218 |
+
# Gather all item embeddings if using distributed training
|
219 |
+
if dist.is_initialized() and self.training:
|
220 |
+
# Dummy vectors for allgather
|
221 |
+
user_list = [torch.zeros_like(user_pooler_output) for _ in range(dist.get_world_size())]
|
222 |
+
pos_item_list = [torch.zeros_like(pos_item_pooler_output) for _ in range(dist.get_world_size())]
|
223 |
+
neg_item_list = [torch.zeros_like(neg_item_pooler_output) for _ in range(dist.get_world_size())]
|
224 |
+
# Allgather
|
225 |
+
dist.all_gather(tensor_list=user_list, tensor=user_pooler_output.contiguous())
|
226 |
+
dist.all_gather(tensor_list=pos_item_list, tensor=pos_item_pooler_output.contiguous())
|
227 |
+
dist.all_gather(tensor_list=neg_item_list, tensor=neg_item_pooler_output.contiguous())
|
228 |
+
|
229 |
+
# Since allgather results do not have gradients, we replace the
|
230 |
+
# current process's corresponding embeddings with original tensors
|
231 |
+
user_list[dist.get_rank()] = user_pooler_output
|
232 |
+
pos_item_list[dist.get_rank()] = pos_item_pooler_output
|
233 |
+
neg_item_list[dist.get_rank()] = neg_item_pooler_output
|
234 |
+
|
235 |
+
# Get full batch embeddings
|
236 |
+
user_pooler_output = torch.cat(user_list, dim=0)
|
237 |
+
pos_item_pooler_output = torch.cat(pos_item_list, dim=0)
|
238 |
+
neg_item_pooler_output = torch.cat(neg_item_list, dim=0)
|
239 |
+
|
240 |
+
cos_sim = self.sim(user_pooler_output.unsqueeze(1), pos_item_pooler_output.unsqueeze(0))
|
241 |
+
neg_sim = self.sim(user_pooler_output.unsqueeze(1), neg_item_pooler_output.unsqueeze(0))
|
242 |
+
cos_sim = torch.cat([cos_sim, neg_sim], 1)
|
243 |
+
|
244 |
+
labels = torch.arange(cos_sim.size(0)).long().to(self.device)
|
245 |
+
loss_fct = nn.CrossEntropyLoss()
|
246 |
+
|
247 |
+
loss = loss_fct(cos_sim, labels)
|
248 |
+
|
249 |
+
# Calculate loss for MLM
|
250 |
+
if mlm_outputs is not None and mlm_labels is not None and self.model_args.do_mlm:
|
251 |
+
mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
|
252 |
+
prediction_scores = self.lm_head(mlm_outputs.last_hidden_state)
|
253 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
|
254 |
+
loss = loss + self.model_args.mlm_weight * masked_lm_loss
|
255 |
+
|
256 |
+
if not return_dict:
|
257 |
+
raise NotImplementedError
|
258 |
+
|
259 |
+
return SequenceClassifierOutput(
|
260 |
+
loss=loss,
|
261 |
+
logits=cos_sim,
|
262 |
+
)
|
263 |
+
|
264 |
+
def encode(self,
|
265 |
+
input_ids=None,
|
266 |
+
attention_mask=None,
|
267 |
+
token_type_ids=None,
|
268 |
+
position_ids=None,
|
269 |
+
head_mask=None,
|
270 |
+
inputs_embeds=None,
|
271 |
+
labels=None,
|
272 |
+
output_attentions=None,
|
273 |
+
output_hidden_states=None,
|
274 |
+
return_dict=None,
|
275 |
+
):
|
276 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
277 |
+
outputs = self.roberta(
|
278 |
+
input_ids=input_ids,
|
279 |
+
attention_mask=attention_mask,
|
280 |
+
token_type_ids=None,
|
281 |
+
position_ids=None,
|
282 |
+
head_mask=None,
|
283 |
+
inputs_embeds=None,
|
284 |
+
output_attentions=output_attentions,
|
285 |
+
output_hidden_states=output_hidden_states,
|
286 |
+
return_dict=return_dict,
|
287 |
+
)
|
288 |
+
pooler_output = self.pooler(attention_mask, outputs)
|
289 |
+
if self.pooler_type == "cls":
|
290 |
+
pooler_output = self.mlp(pooler_output)
|
291 |
+
if not return_dict:
|
292 |
+
return (outputs[0], pooler_output) + outputs[2:]
|
293 |
+
|
294 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
295 |
+
pooler_output=pooler_output,
|
296 |
+
last_hidden_state=outputs.last_hidden_state,
|
297 |
+
hidden_states=outputs.hidden_states,
|
298 |
+
)
|
299 |
+
|
300 |
+
def inference(self,
|
301 |
+
user_profile_list,
|
302 |
+
item_profile_list,
|
303 |
+
dataset_name,
|
304 |
+
tokenizer,
|
305 |
+
infer_batch_size=128
|
306 |
+
):
|
307 |
+
n_user = len(user_profile_list)
|
308 |
+
profiles = user_profile_list + item_profile_list
|
309 |
+
n_batch = math.ceil(len(profiles) / infer_batch_size)
|
310 |
+
text_embeds = []
|
311 |
+
for i in tqdm(range(n_batch), desc=f'Encoding Text {dataset_name}'):
|
312 |
+
batch_profiles = profiles[i * infer_batch_size: (i + 1) * infer_batch_size]
|
313 |
+
inputs = tokenizer(batch_profiles, padding=True, truncation=True, max_length=512, return_tensors="pt")
|
314 |
+
for k in inputs:
|
315 |
+
inputs[k] = inputs[k].to(self.device)
|
316 |
+
with torch.inference_mode():
|
317 |
+
embeds = self.encode(
|
318 |
+
input_ids=inputs.input_ids,
|
319 |
+
attention_mask=inputs.attention_mask
|
320 |
+
)
|
321 |
+
text_embeds.append(embeds.pooler_output.detach().cpu())
|
322 |
+
text_embeds = torch.concat(text_embeds, dim=0).cuda()
|
323 |
+
user_embeds = F.normalize(text_embeds[: n_user], dim=-1)
|
324 |
+
item_embeds = F.normalize(text_embeds[n_user: ], dim=-1)
|
325 |
+
return user_embeds, item_embeds
|
preprocessing.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
import yaml
|
7 |
+
|
8 |
+
from embeddings import compute_embeddings, load_model
|
9 |
+
|
10 |
+
# Load configurations
|
11 |
+
with open("configs.yaml", "r") as file:
|
12 |
+
configs = yaml.safe_load(file)
|
13 |
+
|
14 |
+
# Load and process the movie dataset
|
15 |
+
movies_data = pd.read_csv(configs['dataset'])
|
16 |
+
|
17 |
+
# Define columns to drop that are not needed
|
18 |
+
columns_drop = ['budget', 'homepage', 'id', 'original_language', 'original_title',
|
19 |
+
'popularity', 'revenue', 'spoken_languages', 'status', 'tagline']
|
20 |
+
movies_data.drop(columns=columns_drop, axis=1, inplace=True)
|
21 |
+
movies_data.dropna(inplace=True) # Drop rows with missing values
|
22 |
+
|
23 |
+
# Convert JSON string columns to a comma-separated string of names
|
24 |
+
columns_json_to_csv = ['genres', 'keywords', 'production_companies', 'production_countries']
|
25 |
+
for col in columns_json_to_csv:
|
26 |
+
movies_data[col] = movies_data[col].apply(
|
27 |
+
lambda json_str: ', '.join([item["name"] for item in json.loads(json_str)])
|
28 |
+
)
|
29 |
+
|
30 |
+
# Extract the year from 'release_date'
|
31 |
+
movies_data['release_date'] = pd.to_datetime(movies_data['release_date']).dt.year
|
32 |
+
|
33 |
+
# Convert 'runtime' to integers
|
34 |
+
movies_data['runtime'] = movies_data['runtime'].astype(int)
|
35 |
+
|
36 |
+
# Combine 'overview', 'genres', and 'keywords' into a single string for each movie
|
37 |
+
movies_data_processed = movies_data[['overview', 'genres', 'keywords']].apply(
|
38 |
+
lambda row: '. '.join([f"{col.capitalize()}: {val}" for col, val in row.items()]),
|
39 |
+
axis=1
|
40 |
+
).tolist()
|
41 |
+
|
42 |
+
# Save the processed dataset
|
43 |
+
movies_data.to_csv(configs['processed_dataset'], index=False)
|
44 |
+
|
45 |
+
# Process embeddings for each model
|
46 |
+
for model_name in configs['hf_models']:
|
47 |
+
model, tokenizer = load_model(model_name)
|
48 |
+
movie_embeddings = compute_embeddings(movies_data_processed, model, tokenizer)
|
49 |
+
|
50 |
+
embedding_dir_path = f"{configs['movie_embeddings']}/{model_name}"
|
51 |
+
embedding_file_path = f"{embedding_dir_path}/{configs['movie_embeddings']}.pt"
|
52 |
+
os.makedirs(embedding_dir_path, exist_ok=True)
|
53 |
+
|
54 |
+
torch.save(movie_embeddings, embedding_file_path)
|
55 |
+
print(f"Saved embeddings for {model_name}")
|
56 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
torch
|
3 |
+
transformers
|
4 |
+
pandas
|