AJ-Gazin commited on
Commit
f709b0a
·
1 Parent(s): 7211ca7

Forced model to load to CPU (map_location)

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -25,12 +25,13 @@ API_KEY = os.getenv("HUGGINGFACE_API_KEY")
25
  API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
26
 
27
  # --- LOAD DATA AND MODEL ---
 
28
  movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv") # Load your movie data
29
- data = torch.load("./PyGdata.pt")
30
 
31
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
  model = model_def.Model(hidden_channels=32).to(device)
33
- model.load_state_dict(torch.load("PyGTrainedModelState.pt"))
34
  model.eval()
35
 
36
  # --- STREAMLIT APP ---
 
25
  API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
26
 
27
  # --- LOAD DATA AND MODEL ---
28
+ # map_location forces the model to be loaded on the CPU for huggingface compatibility
29
  movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv") # Load your movie data
30
+ data = torch.load("./PyGdata.pt", map_location=device('cpu'))
31
 
32
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
  model = model_def.Model(hidden_channels=32).to(device)
34
+ model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device('cpu')))
35
  model.eval()
36
 
37
  # --- STREAMLIT APP ---