zaidmehdi commited on
Commit
f838a8b
·
1 Parent(s): 79dc319

returning numpy array of hidden state

Browse files
Files changed (1) hide show
  1. src/utils.py +1 -1
src/utils.py CHANGED
@@ -10,7 +10,7 @@ def extract_hidden_state(input_text, tokenizer, language_model):
10
  with torch.no_grad():
11
  outputs = language_model(**tokens)
12
 
13
- return outputs.last_hidden_state
14
 
15
 
16
  def get_metrics(y_true, y_preds):
 
10
  with torch.no_grad():
11
  outputs = language_model(**tokens)
12
 
13
+ return outputs.last_hidden_state[:,0].numpy()
14
 
15
 
16
  def get_metrics(y_true, y_preds):