jpohhhh committed
Commit 917cd83 · 1 Parent(s): 8751910

Update handler.py

Files changed (1): handler.py  +5 -17
handler.py CHANGED
@@ -3,22 +3,10 @@ from transformers import AutoTokenizer, AutoModel
 import torch
 
 #Mean Pooling - Take attention mask into account for correct averaging
-def max_pooling(model_output):
-    # Get dimensions
-    Z, Y = len(model_output[0]), len(model_output[0][0])
-
-    # Initialize an empty list with length Y (384 in your case)
-    output_array = [0] * Y
-
-    # Loop over secondary arrays (Z)
-    for i in range(Z):
-        # Loop over values in innermost arrays (Y)
-        for j in range(Y):
-            # If value is greater than current max, update max
-            if model_output[0][i][j] > output_array[j]:
-                output_array[j] = model_output[0][i][j]
-
-    return output_array
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
 class EndpointHandler():
     def __init__(self, path=""):
@@ -44,5 +32,5 @@ class EndpointHandler():
         model_output = self.model(**encoded_input)
 
         # Perform pooling. In this case, max pooling.
-        sentence_embeddings = max_pooling(model_output, encoded_input['attention_mask'])
+        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
         return sentence_embeddings.tolist()
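
For reference, the new mean_pooling helper is the standard attention-mask-weighted average over token embeddings (note the comment above the call still reads "max pooling", while the call now performs mean pooling). Below is a minimal standalone sketch of the same pooling outside the endpoint handler; the model name sentence-transformers/all-MiniLM-L6-v2, the example sentences, and the padding/truncation settings are assumptions for illustration only, since the handler itself loads whatever model is stored at its path argument.

# Standalone sketch of the mean pooling used after this commit.
# Assumptions: model name, example sentences, and tokenizer settings are illustrative.
import torch
from transformers import AutoTokenizer, AutoModel

def mean_pooling(model_output, attention_mask):
    # Average token embeddings, counting only positions where the attention mask is 1.
    token_embeddings = model_output[0]  # first element holds the token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

model_name = "sentence-transformers/all-MiniLM-L6-v2"  # hypothetical choice for this sketch
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

sentences = ["This is an example sentence.", "Each sentence becomes one fixed-size vector."]
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    model_output = model(**encoded_input)

sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
print(sentence_embeddings.shape)  # (batch_size, hidden_size), e.g. (2, 384) for a 384-dim model

Unlike the removed max_pooling, which took a per-dimension maximum over token positions and ignored the attention mask, mean_pooling averages only over real (non-padding) tokens, so padded batches produce the same embedding a single unpadded sentence would.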