juanpablomesa commited on
Commit
e6c2ed8
·
1 Parent(s): 094ee39

Fixed normalization error when only 1 image is sent to endpoint

Browse files
Files changed (1) hide show
  1. handler.py +6 -5
handler.py CHANGED
@@ -219,13 +219,14 @@ class EndpointHandler:
219
  # get image embeddings
220
  batch_emb = self.model_clip.get_image_features(pixel_values=batch)
221
  # detach text emb from graph, move to CPU, and convert to numpy array
222
- batch_emb = batch_emb.squeeze(0)
 
 
 
223
  batch_emb = batch_emb.cpu().detach().numpy()
224
  # NORMALIZE
225
- if batch_emb.ndim == 1:
226
- batch_emb = batch_emb / np.linalg.norm(batch_emb, axis=0)
227
- else:
228
- batch_emb = batch_emb.T / np.linalg.norm(batch_emb, axis=1)
229
  # transpose back to (21, 512)
230
  batch_emb = batch_emb.T.tolist()
231
  embedding_end_time = timeit.default_timer()
 
219
  # get image embeddings
220
  batch_emb = self.model_clip.get_image_features(pixel_values=batch)
221
  # detach text emb from graph, move to CPU, and convert to numpy array
222
+ # Check the shape of the tensor before squeezing
223
+ if batch_emb.dim() > 1 and batch_emb.shape[0] == 1:
224
+ batch_emb = batch_emb.squeeze(0)
225
+
226
  batch_emb = batch_emb.cpu().detach().numpy()
227
  # NORMALIZE
228
+
229
+ batch_emb = batch_emb.T / np.linalg.norm(batch_emb, axis=1)
 
 
230
  # transpose back to (21, 512)
231
  batch_emb = batch_emb.T.tolist()
232
  embedding_end_time = timeit.default_timer()