# EasyRec / app.py
# Balaji S
# First commit
# d87c6b1 verified
import streamlit as st
import pandas as pd
import torch
import yaml
from embeddings import load_model, compute_embeddings
@st.cache_data
def _load_movie_data(csv_path):
    """Read the processed movie dataset at *csv_path* into a DataFrame.

    Streamlit re-executes this script on every widget interaction, so the
    parsed CSV is cached with ``st.cache_data`` instead of being re-read from
    disk each time. The cache is keyed on *csv_path*; restart the app (or
    clear the cache) if the file changes on disk.
    """
    return pd.read_csv(csv_path)


# Load configuration from the YAML file. The encoding is given explicitly so
# that parsing does not depend on the host's locale default.
with open("configs.yaml", "r", encoding="utf-8") as file:
    configs = yaml.safe_load(file)

# Load the processed movie dataset; its path is the "processed_dataset"
# entry of the configuration.
movie_data = _load_movie_data(configs["processed_dataset"])
@st.cache_resource
def _load_model_cached(model_name):
    """Return ``load_model(model_name)``, cached per model name.

    Streamlit re-runs the whole script on every interaction (including each
    change of the text input below). ``st.cache_resource`` keeps one
    (model, tokenizer) pair per model name alive across reruns so the model
    is not reloaded every time.
    """
    return load_model(model_name)


# Streamlit app
st.title("🎬 EasyRec Movie Recommender")

# Dropdown for model selection.
# NOTE(review): assumes configs['hf_models'] is a list of model names in
# configs.yaml -- confirm against the config file.
model_names = configs['hf_models']
selected_model_name = st.selectbox("Select a model:", model_names)

# Load (or fetch from cache) the model and tokenizer for the selected name.
model, tokenizer = _load_model_cached(selected_model_name)

# Free-text description of the kind of movie the user wants.
user_description = st.text_input(
    "Enter a description of the type of movie you're interested in:",
    placeholder="e.g. A romantic comedy with a twist...",
)
@st.cache_resource
def _load_movie_embeddings(path):
    """Load the precomputed movie embeddings stored in the .pt file at *path*.

    ``map_location="cpu"`` lets a file that was saved from a GPU be opened on
    a CPU-only host. The result is cached per path so the file is not
    re-read on every Streamlit rerun; callers must not modify it in place.
    """
    return torch.load(path, map_location="cpu")


def _format_runtime(runtime):
    """Format a runtime given in minutes as ``"2h 5m"`` or ``"45m"``.

    Missing values (NaN/None) are rendered as ``"N/A"``. Float values, which
    pandas produces when the column contains missing entries, are truncated
    to whole minutes so the output never reads like ``"2.0h 5.0m"``.
    """
    if pd.isna(runtime):
        return "N/A"
    hours, minutes = divmod(int(runtime), 60)
    return f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"


if user_description:
    # Locate and load the precomputed movie embeddings for the chosen model.
    embedding_dir_path = f"{configs['movie_embeddings']}/{selected_model_name}"
    embedding_file_path = f"{embedding_dir_path}/{configs['movie_embeddings']}.pt"
    movie_embeddings = _load_movie_embeddings(embedding_file_path)

    # Compute the embedding for the user input by passing it as a list.
    user_embedding = compute_embeddings([user_description], model, tokenizer)

    # Keep both operands on the same device before multiplying.
    movie_embeddings = movie_embeddings.to(user_embedding.device)

    # NOTE(review): presumably movie_embeddings is (num_movies, dim) and
    # user_embedding is (1, dim), giving one score per movie -- confirm
    # against compute_embeddings.
    similarity_scores = torch.matmul(movie_embeddings, user_embedding.T).flatten()

    # Set the number of top recommendations to display.
    K = 5
    # topk avoids sorting every score; k is bounded by the number of scores
    # because topk rejects k larger than the input.
    k = min(K, similarity_scores.numel())
    top_k_indices = torch.topk(similarity_scores, k).indices.tolist()

    # Display recommendations
    st.write("## 🎉 Top Recommendations:")
    for rank, movie_id in enumerate(top_k_indices, start=1):
        # Embedding rows are matched to dataset rows by position.
        movie = movie_data.iloc[movie_id]

        # Render the movie as a sequence of Markdown lines. Columns are read
        # by key so a column name can never collide with a Series attribute.
        st.markdown(f"### {rank}. {movie['title']}")
        st.markdown(
            f"**Release Date:** {movie['release_date']}    "
            f"**Runtime:** {_format_runtime(movie['runtime'])}"
        )
        st.markdown(f"⭐ {movie['vote_average']} ({movie['vote_count']} votes)")
        st.markdown(f"**Overview:** {movie['overview']}")
        st.markdown(f"**Genres:** {movie['genres']}")
        st.markdown(f"**Production Companies:** {movie['production_companies']}")
        st.markdown(f"**Production Countries:** {movie['production_countries']}")
        st.markdown("---")
# Additional styling (optional): CSS applied to Streamlit text inputs.
# The <style> element is emitted as raw HTML, which st.markdown only does
# when unsafe_allow_html=True is passed.
_TEXT_INPUT_CSS = """
<style>
.stTextInput > div > input {
background-color: #f0f0f5; /* Light gray background */
border: 1px solid #ccc; /* Gray border */
border-radius: 5px; /* Rounded corners */
}
</style>
"""
st.markdown(_TEXT_INPUT_CSS, unsafe_allow_html=True)