Rohit8y commited on
Commit
9668314
·
1 Parent(s): 20d7a51

Fixing model loading

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -93,11 +93,12 @@ def predict(img):
93
  models_list = ['clip-sd.pth', 'clip-glide.pth', 'clip-ld.pth']
94
  modality = "Image+Text"
95
  for model_path in models_list:
96
- model = get_model('clip-sd.pth', modality)
97
  tensor = preprocessing(img, 224)
98
  input_tensor = tensor.view(1, 3, 224, 224)
99
  with torch.no_grad():
100
- out = model(input_tensor)
 
101
  prediction.append(out)
102
 
103
  if prediction[0] > 0.5:
 
93
  models_list = ['clip-sd.pth', 'clip-glide.pth', 'clip-ld.pth']
94
  modality = "Image+Text"
95
  for model_path in models_list:
96
+ model = get_model(model_path, modality)
97
  tensor = preprocessing(img, 224)
98
  input_tensor = tensor.view(1, 3, 224, 224)
99
  with torch.no_grad():
100
+ out = model(input_tensor, caption)
101
+ print('------------>', out)
102
  prediction.append(out)
103
 
104
  if prediction[0] > 0.5: