jpohhhh committed on
Commit
c5ac78c
·
1 Parent(s): a81bb3f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -12
handler.py CHANGED
@@ -8,10 +8,10 @@ import time
8
  import os
9
  import torch
10
 
11
- def max_pooling(model_output):
12
  # Get dimensions
13
  Z, Y = len(model_output[0]), len(model_output[0][0])
14
-
15
  # Initialize an empty list with length Y (384 in your case)
16
  output_array = [0.0] * Y
17
 
@@ -19,18 +19,15 @@ def max_pooling(model_output):
19
  for i in range(Z):
20
  # Loop over values in innermost arrays (Y)
21
  for j in range(Y):
22
- # If value is greater than current max, update max
23
- if model_output[0][i][j] > output_array[j]:
24
- output_array[j] = model_output[0][i][j]
 
 
25
 
26
  return output_array
27
 
28
- #Mean Pooling - Take attention mask into account for correct averaging
29
- def mean_pooling(model_output, attention_mask):
30
- token_embeddings = model_output[0] #First element of model_output contains all token embeddings
31
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
32
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
33
-
34
  class EndpointHandler():
35
  def __init__(self, path=""):
36
  print("HELLO THIS IS THE CWD:", os.getcwd())
@@ -73,6 +70,6 @@ class EndpointHandler():
73
  # embedding = mean_pooling(model_output, encoded_input['attention_mask'])
74
  print("F")
75
 
76
- sentence_embeddings.append(max_pooling(model_output))
77
  print("G")
78
  return sentence_embeddings
 
8
  import os
9
  import torch
10
 
11
+ def mean_pooling(model_output):
12
  # Get dimensions
13
  Z, Y = len(model_output[0]), len(model_output[0][0])
14
+
15
  # Initialize an empty list with length Y (384 in your case)
16
  output_array = [0.0] * Y
17
 
 
19
  for i in range(Z):
20
  # Loop over values in innermost arrays (Y)
21
  for j in range(Y):
22
+ # Accumulate values
23
+ output_array[j] += model_output[0][i][j]
24
+
25
+ # Compute mean
26
+ output_array = [val / Z for val in output_array]
27
 
28
  return output_array
29
 
30
+
 
 
 
 
 
31
  class EndpointHandler():
32
  def __init__(self, path=""):
33
  print("HELLO THIS IS THE CWD:", os.getcwd())
 
70
  # embedding = mean_pooling(model_output, encoded_input['attention_mask'])
71
  print("F")
72
 
73
+ sentence_embeddings.append(mean_pooling(model_output))
74
  print("G")
75
  return sentence_embeddings