Balaji S commited on
Commit
d87c6b1
·
verified ·
1 Parent(s): b4cd0b8

First commit

Browse files
Files changed (10) hide show
  1. .gitattributes +35 -35
  2. .gitignore +1 -0
  3. README.md +18 -14
  4. app.py +75 -0
  5. configs.yaml +10 -0
  6. embeddings.py +64 -0
  7. loss_utils.py +63 -0
  8. model.py +325 -0
  9. preprocessing.py +56 -0
  10. requirements.txt +4 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
README.md CHANGED
@@ -1,14 +1,18 @@
1
- ---
2
- title: EasyRec Movie Recommender
3
- emoji: 🚀
4
- colorFrom: yellow
5
- colorTo: gray
6
- sdk: streamlit
7
- sdk_version: 1.39.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: DA626 Final Project
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
+ ---
2
+ title: DA626
3
+ emoji: 🐠
4
+ colorFrom: blue
5
+ colorTo: yellow
6
+ sdk: streamlit
7
+ sdk_version: 1.39.0
8
+ app_file: app.py
9
+ pinned: true
10
+ license: mit
11
+ short_description: Deployment of EasyRec
12
+ ---
13
+
14
+ The [EasyRec](https://github.com/HKUDS/EasyRec) and [EasyRec-Forked](https://github.com/jibala-1022/EasyRec) models is used to recommend movies from [TMDB 5000 Movies Dataset](https://www.kaggle.com/datasets/tmdb/tmdb-movie-metadata?select=tmdb_5000_movies.csv)
15
+
16
+ `requirements.txt` contains packages needed during deployment only. Thus `json` is omitted.
17
+
18
+ Execute `preprocessing.py` to compute movie embeddings before deploying the app.
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import torch
4
+ import yaml
5
+ from embeddings import load_model, compute_embeddings
6
+
7
+ # Load configuration from YAML file
8
+ with open("configs.yaml", "r") as file:
9
+ configs = yaml.safe_load(file)
10
+
11
+ # Load the processed movie dataset
12
+ movie_data = pd.read_csv(configs["processed_dataset"])
13
+
14
+ # Streamlit app
15
+ st.title("🎬 EasyRec Movie Recommender")
16
+
17
+ # Dropdown for model selection
18
+ model_names = configs['hf_models'] # Assuming this is a list of model names in your configs.yaml
19
+ selected_model_name = st.selectbox("Select a model:", model_names)
20
+
21
+ # Load the model based on user selection
22
+ model, tokenizer = load_model(selected_model_name)
23
+
24
+ # User input for movie description
25
+ user_description = st.text_input("Enter a description of the type of movie you're interested in:",
26
+ placeholder="e.g. A romantic comedy with a twist...")
27
+
28
+ if user_description:
29
+ # Load the precomputed movie embeddings from a .pt file
30
+ embedding_dir_path = f"{configs['movie_embeddings']}/{selected_model_name}"
31
+ embedding_file_path = f"{embedding_dir_path}/{configs['movie_embeddings']}.pt"
32
+ movie_embeddings = torch.load(embedding_file_path) # Load the .pt file
33
+
34
+ # Compute the embedding for the user input by passing it as a list
35
+ user_embedding = compute_embeddings([user_description], model, tokenizer)
36
+
37
+ similarity_scores = torch.matmul(movie_embeddings, user_embedding.T).flatten()
38
+
39
+ # Set the number of top recommendations to display
40
+ K = 5
41
+ top_k_indices = torch.argsort(similarity_scores, descending=True)[:K].tolist() # Get indices of top K
42
+
43
+ # Display recommendations
44
+ st.write("## 🎉 Top Recommendations:")
45
+
46
+ for rank, movie_id in enumerate(top_k_indices, start=1):
47
+ movie = movie_data.iloc[movie_id]
48
+
49
+ # Convert runtime from minutes to hours and minutes
50
+ hours = movie.runtime // 60
51
+ minutes = movie.runtime % 60
52
+
53
+ # Construct an HTML card for displaying the movie information
54
+ st.markdown(f"### {rank}. {movie.title}")
55
+ st.markdown(f"**Release Date:** {movie.release_date}    **Runtime:** {f'{hours}h {minutes}m' if hours > 0 else f'{minutes}m'}")
56
+ st.markdown(f"⭐ {movie.vote_average} ({movie.vote_count} votes)")
57
+ st.markdown(f"**Overview:** {movie.overview}")
58
+ st.markdown(f"**Genres:** {movie.genres}")
59
+ st.markdown(f"**Production Companies:** {movie.production_companies}")
60
+ st.markdown(f"**Production Countries:** {movie.production_countries}")
61
+ st.markdown("---")
62
+
63
+ # Additional styling (optional)
64
+ st.markdown(
65
+ """
66
+ <style>
67
+ .stTextInput > div > input {
68
+ background-color: #f0f0f5; /* Light gray background */
69
+ border: 1px solid #ccc; /* Gray border */
70
+ border-radius: 5px; /* Rounded corners */
71
+ }
72
+ </style>
73
+ """,
74
+ unsafe_allow_html=True
75
+ )
configs.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ hf_models:
2
+ - "hkuds/easyrec-roberta-small"
3
+ - "hkuds/easyrec-roberta-base"
4
+ - "hkuds/easyrec-roberta-large"
5
+ - "jibala-1022/easyrec-small"
6
+ - "jibala-1022/easyrec-base"
7
+ - "jibala-1022/easyrec-large"
8
+ dataset: "data/tmdb_5000_movies.csv"
9
+ processed_dataset: "data/tmdb_5000_movies_processed.csv"
10
+ movie_embeddings: "movie_embeddings"
embeddings.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from transformers import AutoConfig, AutoTokenizer
6
+ from model import Easyrec
7
+
8
+
9
+ def load_model(model_path: str) -> Tuple[Easyrec, AutoTokenizer]:
10
+ """
11
+ Load the pre-trained model and tokenizer from the specified path.
12
+
13
+ Args:
14
+ model_path: The path to the pre-trained huggingface model or local directory.
15
+
16
+ Returns:
17
+ tuple: A tuple containing the model and tokenizer.
18
+ """
19
+
20
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+
22
+ config = AutoConfig.from_pretrained(model_path)
23
+ model = Easyrec.from_pretrained(model_path, config=config).to(device)
24
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
25
+
26
+ return model, tokenizer
27
+
28
+
29
+ def compute_embeddings(
30
+ sentences: List[str],
31
+ model: Easyrec,
32
+ tokenizer: AutoTokenizer,
33
+ batch_size: int = 8) -> torch.Tensor:
34
+ """
35
+ Compute embeddings for a list of sentences using the specified model and tokenizer.
36
+
37
+ Args:
38
+ sentences: A list of sentences for which to compute embeddings.
39
+ model: The pre-trained model used for generating embeddings.
40
+ tokenizer: The tokenizer used to preprocess the sentences.
41
+ batch_size: The number of sentences to process in each batch (default is 8).
42
+
43
+ Returns:
44
+ torch.Tensor: A tensor containing the normalized embeddings for the input sentences.
45
+ """
46
+
47
+ embeddings = []
48
+ count_sentences = len(sentences)
49
+ device = next(model.parameters()).device # Get the device on which the model is located
50
+
51
+ for start in range(0, count_sentences, batch_size):
52
+ end = start + batch_size
53
+ batch_sentences = sentences[start:end]
54
+
55
+ inputs = tokenizer(batch_sentences, padding=True, truncation=True, max_length=512, return_tensors="pt")
56
+ inputs = {key: val.to(device) for key, val in inputs.items()} # Move input tensors to the same device as the model
57
+
58
+ with torch.inference_mode():
59
+ outputs = model.encode(inputs['input_ids'], inputs['attention_mask'])
60
+ batch_embeddings = F.normalize(outputs.pooler_output.detach().float(), dim=-1)
61
+
62
+ embeddings.append(batch_embeddings.cpu())
63
+
64
+ return torch.cat(embeddings, dim=0) # Concatenate all computed embeddings into a single tensor
loss_utils.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as t
2
+ import torch.nn.functional as F
3
+
4
+ def cal_bpr_loss(anc_embeds, pos_embeds, neg_embeds):
5
+ pos_preds = (anc_embeds * pos_embeds).sum(-1)
6
+ neg_preds = (anc_embeds * neg_embeds).sum(-1)
7
+ return t.sum(F.softplus(neg_preds - pos_preds))
8
+
9
+
10
+ def reg_pick_embeds(embeds_list):
11
+ reg_loss = 0
12
+ for embeds in embeds_list:
13
+ reg_loss += embeds.square().sum()
14
+ return reg_loss
15
+
16
+
17
+ def cal_infonce_loss(embeds1, embeds2, all_embeds2, temp=1.0):
18
+ normed_embeds1 = embeds1 / t.sqrt(1e-8 + embeds1.square().sum(-1, keepdim=True))
19
+ normed_embeds2 = embeds2 / t.sqrt(1e-8 + embeds2.square().sum(-1, keepdim=True))
20
+ normed_all_embeds2 = all_embeds2 / t.sqrt(1e-8 + all_embeds2.square().sum(-1, keepdim=True))
21
+ nume_term = -(normed_embeds1 * normed_embeds2 / temp).sum(-1)
22
+ deno_term = t.log(t.sum(t.exp(normed_embeds1 @ normed_all_embeds2.T / temp), dim=-1))
23
+ cl_loss = (nume_term + deno_term).sum()
24
+ return cl_loss
25
+
26
+
27
+ def cal_infonce_loss_spec_nodes(embeds1, embeds2, nodes, temp):
28
+ embeds1 = F.normalize(embeds1 + 1e-8, p=2)
29
+ embeds2 = F.normalize(embeds2 + 1e-8, p=2)
30
+ pckEmbeds1 = embeds1[nodes]
31
+ pckEmbeds2 = embeds2[nodes]
32
+ nume = t.exp(t.sum(pckEmbeds1 * pckEmbeds2, dim=-1) / temp)
33
+ deno = t.exp(pckEmbeds1 @ embeds2.T / temp).sum(-1) + 1e-8
34
+ return -t.log(nume / deno).mean()
35
+
36
+
37
+ def cal_sce_loss(x, y, alpha):
38
+ x = F.normalize(x, p=2, dim=-1)
39
+ y = F.normalize(y, p=2, dim=-1)
40
+ loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)
41
+ loss = loss.mean()
42
+ return loss
43
+
44
+
45
+ def cal_rank_loss(stu_anc_emb, stu_pos_emb, stu_neg_emb, tea_anc_emb, tea_pos_emb, tea_neg_emb):
46
+ stu_pos_score = (stu_anc_emb * stu_pos_emb).sum(dim=-1)
47
+ stu_neg_score = (stu_anc_emb * stu_neg_emb).sum(dim=-1)
48
+ stu_r_score = F.sigmoid(stu_pos_score - stu_neg_score)
49
+
50
+ tea_pos_score = (tea_anc_emb * tea_pos_emb).sum(dim=-1)
51
+ tea_neg_score = (tea_anc_emb * tea_neg_emb).sum(dim=-1)
52
+ tea_r_score = F.sigmoid(tea_pos_score - tea_neg_score)
53
+
54
+ rank_loss = -(tea_r_score * t.log(stu_r_score + 1e-8) + (1 - tea_r_score) * t.log(1 - stu_r_score + 1e-8)).mean()
55
+
56
+ return rank_loss
57
+
58
+
59
+ def reg_params(model):
60
+ reg_loss = 0
61
+ for W in model.parameters():
62
+ reg_loss += W.norm(2).square()
63
+ return reg_loss
model.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ from tqdm import tqdm
6
+ import scipy.sparse as sp
7
+ import torch.nn.functional as F
8
+ import torch.distributed as dist
9
+
10
+ import transformers
11
+ from transformers import RobertaTokenizer
12
+ from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead
13
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead
14
+ from transformers.activations import gelu
15
+ from transformers.file_utils import (
16
+ add_code_sample_docstrings,
17
+ add_start_docstrings,
18
+ add_start_docstrings_to_model_forward,
19
+ replace_return_docstrings,
20
+ )
21
+ from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
22
+
23
+ from loss_utils import *
24
+
25
+ init = nn.init.xavier_uniform_
26
+ uniformInit = nn.init.uniform
27
+
28
+
29
+ """
30
+ EasyRec
31
+ """
32
+ def dot_product_scores(q_vectors, ctx_vectors):
33
+ r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1))
34
+ return r
35
+
36
+ class MLPLayer(nn.Module):
37
+ """
38
+ Head for getting sentence representations over RoBERTa/BERT's CLS representation.
39
+ """
40
+ def __init__(self, config):
41
+ super().__init__()
42
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
43
+ self.activation = nn.Tanh()
44
+
45
+ def forward(self, features, **kwargs):
46
+ x = self.dense(features)
47
+ x = self.activation(x)
48
+ return x
49
+
50
+
51
+ class Pooler(nn.Module):
52
+ """
53
+ Parameter-free poolers to get the sentence embedding
54
+ 'cls': [CLS] representation with BERT/RoBERTa's MLP pooler.
55
+ 'cls_before_pooler': [CLS] representation without the original MLP pooler.
56
+ 'avg': average of the last layers' hidden states at each token.
57
+ 'avg_top2': average of the last two layers.
58
+ 'avg_first_last': average of the first and the last layers.
59
+ """
60
+ def __init__(self, pooler_type):
61
+ super().__init__()
62
+ self.pooler_type = pooler_type
63
+ assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type
64
+
65
+ def forward(self, attention_mask, outputs):
66
+ last_hidden = outputs.last_hidden_state
67
+ pooler_output = outputs.pooler_output
68
+ hidden_states = outputs.hidden_states
69
+
70
+ if self.pooler_type in ['cls_before_pooler', 'cls']:
71
+ return last_hidden[:, 0]
72
+ elif self.pooler_type == "avg":
73
+ return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1))
74
+ elif self.pooler_type == "avg_first_last":
75
+ first_hidden = hidden_states[1]
76
+ last_hidden = hidden_states[-1]
77
+ pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
78
+ return pooled_result
79
+ elif self.pooler_type == "avg_top2":
80
+ second_last_hidden = hidden_states[-2]
81
+ last_hidden = hidden_states[-1]
82
+ pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
83
+ return pooled_result
84
+ else:
85
+ raise NotImplementedError
86
+
87
+
88
+ class Similarity(nn.Module):
89
+ """
90
+ Dot product or cosine similarity
91
+ """
92
+ def __init__(self, temp):
93
+ super().__init__()
94
+ self.temp = temp
95
+ self.cos = nn.CosineSimilarity(dim=-1)
96
+
97
+ def forward(self, x, y):
98
+ return self.cos(x, y) / self.temp
99
+
100
+
101
+ class Easyrec(RobertaPreTrainedModel):
102
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
103
+
104
+ def __init__(self, config, *model_args, **model_kargs):
105
+ super().__init__(config)
106
+ try:
107
+ self.model_args = model_kargs["model_args"]
108
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
109
+ if self.model_args.pooler_type == "cls":
110
+ self.mlp = MLPLayer(config)
111
+ if self.model_args.do_mlm:
112
+ self.lm_head = RobertaLMHead(config)
113
+ """
114
+ Contrastive learning class init function.
115
+ """
116
+ self.pooler_type = self.model_args.pooler_type
117
+ self.pooler = Pooler(self.pooler_type)
118
+ self.sim = Similarity(temp=self.model_args.temp)
119
+ self.init_weights()
120
+ except:
121
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
122
+ self.mlp = MLPLayer(config)
123
+ self.lm_head = RobertaLMHead(config)
124
+ self.pooler_type = 'cls'
125
+ self.pooler = Pooler(self.pooler_type)
126
+ self.init_weights()
127
+
128
+ def forward(self,
129
+ user_input_ids=None,
130
+ user_attention_mask=None,
131
+ pos_item_input_ids=None,
132
+ pos_item_attention_mask=None,
133
+ neg_item_input_ids=None,
134
+ neg_item_attention_mask=None,
135
+ token_type_ids=None,
136
+ position_ids=None,
137
+ head_mask=None,
138
+ inputs_embeds=None,
139
+ labels=None,
140
+ output_attentions=None,
141
+ output_hidden_states=None,
142
+ return_dict=None,
143
+ mlm_input_ids=None,
144
+ mlm_attention_mask=None,
145
+ mlm_labels=None,
146
+ ):
147
+ """
148
+ Contrastive learning forward function.
149
+ """
150
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
151
+ batch_size = user_input_ids.size(0)
152
+
153
+ # Get user embeddings
154
+ user_outputs = self.roberta(
155
+ input_ids=user_input_ids,
156
+ attention_mask=user_attention_mask,
157
+ token_type_ids=None,
158
+ position_ids=None,
159
+ head_mask=None,
160
+ inputs_embeds=None,
161
+ output_attentions=output_attentions,
162
+ output_hidden_states=output_hidden_states,
163
+ return_dict=return_dict,
164
+ )
165
+
166
+ # Get positive item embeddings
167
+ pos_item_outputs = self.roberta(
168
+ input_ids=pos_item_input_ids,
169
+ attention_mask=pos_item_attention_mask,
170
+ token_type_ids=None,
171
+ position_ids=None,
172
+ head_mask=None,
173
+ inputs_embeds=None,
174
+ output_attentions=output_attentions,
175
+ output_hidden_states=output_hidden_states,
176
+ return_dict=return_dict,
177
+ )
178
+
179
+ # Get negative item embeddings
180
+ neg_item_outputs = self.roberta(
181
+ input_ids=neg_item_input_ids,
182
+ attention_mask=neg_item_attention_mask,
183
+ token_type_ids=None,
184
+ position_ids=None,
185
+ head_mask=None,
186
+ inputs_embeds=None,
187
+ output_attentions=output_attentions,
188
+ output_hidden_states=output_hidden_states,
189
+ return_dict=return_dict,
190
+ )
191
+
192
+ # MLM auxiliary objective
193
+ if mlm_input_ids is not None:
194
+ mlm_outputs = self.roberta(
195
+ input_ids=mlm_input_ids,
196
+ attention_mask=mlm_attention_mask,
197
+ token_type_ids=None,
198
+ position_ids=None,
199
+ head_mask=None,
200
+ inputs_embeds=None,
201
+ output_attentions=output_attentions,
202
+ output_hidden_states=output_hidden_states,
203
+ return_dict=return_dict,
204
+ )
205
+
206
+ # Pooling
207
+ user_pooler_output = self.pooler(user_attention_mask, user_outputs)
208
+ pos_item_pooler_output = self.pooler(pos_item_attention_mask, pos_item_outputs)
209
+ neg_item_pooler_output = self.pooler(neg_item_attention_mask, neg_item_outputs)
210
+
211
+ # If using "cls", we add an extra MLP layer
212
+ # (same as BERT's original implementation) over the representation.
213
+ if self.pooler_type == "cls":
214
+ user_pooler_output = self.mlp(user_pooler_output)
215
+ pos_item_pooler_output = self.mlp(pos_item_pooler_output)
216
+ neg_item_pooler_output = self.mlp(neg_item_pooler_output)
217
+
218
+ # Gather all item embeddings if using distributed training
219
+ if dist.is_initialized() and self.training:
220
+ # Dummy vectors for allgather
221
+ user_list = [torch.zeros_like(user_pooler_output) for _ in range(dist.get_world_size())]
222
+ pos_item_list = [torch.zeros_like(pos_item_pooler_output) for _ in range(dist.get_world_size())]
223
+ neg_item_list = [torch.zeros_like(neg_item_pooler_output) for _ in range(dist.get_world_size())]
224
+ # Allgather
225
+ dist.all_gather(tensor_list=user_list, tensor=user_pooler_output.contiguous())
226
+ dist.all_gather(tensor_list=pos_item_list, tensor=pos_item_pooler_output.contiguous())
227
+ dist.all_gather(tensor_list=neg_item_list, tensor=neg_item_pooler_output.contiguous())
228
+
229
+ # Since allgather results do not have gradients, we replace the
230
+ # current process's corresponding embeddings with original tensors
231
+ user_list[dist.get_rank()] = user_pooler_output
232
+ pos_item_list[dist.get_rank()] = pos_item_pooler_output
233
+ neg_item_list[dist.get_rank()] = neg_item_pooler_output
234
+
235
+ # Get full batch embeddings
236
+ user_pooler_output = torch.cat(user_list, dim=0)
237
+ pos_item_pooler_output = torch.cat(pos_item_list, dim=0)
238
+ neg_item_pooler_output = torch.cat(neg_item_list, dim=0)
239
+
240
+ cos_sim = self.sim(user_pooler_output.unsqueeze(1), pos_item_pooler_output.unsqueeze(0))
241
+ neg_sim = self.sim(user_pooler_output.unsqueeze(1), neg_item_pooler_output.unsqueeze(0))
242
+ cos_sim = torch.cat([cos_sim, neg_sim], 1)
243
+
244
+ labels = torch.arange(cos_sim.size(0)).long().to(self.device)
245
+ loss_fct = nn.CrossEntropyLoss()
246
+
247
+ loss = loss_fct(cos_sim, labels)
248
+
249
+ # Calculate loss for MLM
250
+ if mlm_outputs is not None and mlm_labels is not None and self.model_args.do_mlm:
251
+ mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
252
+ prediction_scores = self.lm_head(mlm_outputs.last_hidden_state)
253
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
254
+ loss = loss + self.model_args.mlm_weight * masked_lm_loss
255
+
256
+ if not return_dict:
257
+ raise NotImplementedError
258
+
259
+ return SequenceClassifierOutput(
260
+ loss=loss,
261
+ logits=cos_sim,
262
+ )
263
+
264
+ def encode(self,
265
+ input_ids=None,
266
+ attention_mask=None,
267
+ token_type_ids=None,
268
+ position_ids=None,
269
+ head_mask=None,
270
+ inputs_embeds=None,
271
+ labels=None,
272
+ output_attentions=None,
273
+ output_hidden_states=None,
274
+ return_dict=None,
275
+ ):
276
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
277
+ outputs = self.roberta(
278
+ input_ids=input_ids,
279
+ attention_mask=attention_mask,
280
+ token_type_ids=None,
281
+ position_ids=None,
282
+ head_mask=None,
283
+ inputs_embeds=None,
284
+ output_attentions=output_attentions,
285
+ output_hidden_states=output_hidden_states,
286
+ return_dict=return_dict,
287
+ )
288
+ pooler_output = self.pooler(attention_mask, outputs)
289
+ if self.pooler_type == "cls":
290
+ pooler_output = self.mlp(pooler_output)
291
+ if not return_dict:
292
+ return (outputs[0], pooler_output) + outputs[2:]
293
+
294
+ return BaseModelOutputWithPoolingAndCrossAttentions(
295
+ pooler_output=pooler_output,
296
+ last_hidden_state=outputs.last_hidden_state,
297
+ hidden_states=outputs.hidden_states,
298
+ )
299
+
300
+ def inference(self,
301
+ user_profile_list,
302
+ item_profile_list,
303
+ dataset_name,
304
+ tokenizer,
305
+ infer_batch_size=128
306
+ ):
307
+ n_user = len(user_profile_list)
308
+ profiles = user_profile_list + item_profile_list
309
+ n_batch = math.ceil(len(profiles) / infer_batch_size)
310
+ text_embeds = []
311
+ for i in tqdm(range(n_batch), desc=f'Encoding Text {dataset_name}'):
312
+ batch_profiles = profiles[i * infer_batch_size: (i + 1) * infer_batch_size]
313
+ inputs = tokenizer(batch_profiles, padding=True, truncation=True, max_length=512, return_tensors="pt")
314
+ for k in inputs:
315
+ inputs[k] = inputs[k].to(self.device)
316
+ with torch.inference_mode():
317
+ embeds = self.encode(
318
+ input_ids=inputs.input_ids,
319
+ attention_mask=inputs.attention_mask
320
+ )
321
+ text_embeds.append(embeds.pooler_output.detach().cpu())
322
+ text_embeds = torch.concat(text_embeds, dim=0).cuda()
323
+ user_embeds = F.normalize(text_embeds[: n_user], dim=-1)
324
+ item_embeds = F.normalize(text_embeds[n_user: ], dim=-1)
325
+ return user_embeds, item_embeds
preprocessing.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import pandas as pd
5
+ import torch
6
+ import yaml
7
+
8
+ from embeddings import compute_embeddings, load_model
9
+
10
+ # Load configurations
11
+ with open("configs.yaml", "r") as file:
12
+ configs = yaml.safe_load(file)
13
+
14
+ # Load and process the movie dataset
15
+ movies_data = pd.read_csv(configs['dataset'])
16
+
17
+ # Define columns to drop that are not needed
18
+ columns_drop = ['budget', 'homepage', 'id', 'original_language', 'original_title',
19
+ 'popularity', 'revenue', 'spoken_languages', 'status', 'tagline']
20
+ movies_data.drop(columns=columns_drop, axis=1, inplace=True)
21
+ movies_data.dropna(inplace=True) # Drop rows with missing values
22
+
23
+ # Convert JSON string columns to a comma-separated string of names
24
+ columns_json_to_csv = ['genres', 'keywords', 'production_companies', 'production_countries']
25
+ for col in columns_json_to_csv:
26
+ movies_data[col] = movies_data[col].apply(
27
+ lambda json_str: ', '.join([item["name"] for item in json.loads(json_str)])
28
+ )
29
+
30
+ # Extract the year from 'release_date'
31
+ movies_data['release_date'] = pd.to_datetime(movies_data['release_date']).dt.year
32
+
33
+ # Convert 'runtime' to integers
34
+ movies_data['runtime'] = movies_data['runtime'].astype(int)
35
+
36
+ # Combine 'overview', 'genres', and 'keywords' into a single string for each movie
37
+ movies_data_processed = movies_data[['overview', 'genres', 'keywords']].apply(
38
+ lambda row: '. '.join([f"{col.capitalize()}: {val}" for col, val in row.items()]),
39
+ axis=1
40
+ ).tolist()
41
+
42
+ # Save the processed dataset
43
+ movies_data.to_csv(configs['processed_dataset'], index=False)
44
+
45
+ # Process embeddings for each model
46
+ for model_name in configs['hf_models']:
47
+ model, tokenizer = load_model(model_name)
48
+ movie_embeddings = compute_embeddings(movies_data_processed, model, tokenizer)
49
+
50
+ embedding_dir_path = f"{configs['movie_embeddings']}/{model_name}"
51
+ embedding_file_path = f"{embedding_dir_path}/{configs['movie_embeddings']}.pt"
52
+ os.makedirs(embedding_dir_path, exist_ok=True)
53
+
54
+ torch.save(movie_embeddings, embedding_file_path)
55
+ print(f"Saved embeddings for {model_name}")
56
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers
4
+ pandas