wilwork committed
Commit a3401dd · verified · 1 Parent(s): 22e781b

Update app.py

Files changed (1)
  1. app.py +10 -1
app.py CHANGED
@@ -13,7 +13,16 @@ def get_embedding(text):
     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
         output = model(**inputs)
-    return output.last_hidden_state[:, 0, :].squeeze()  # Use CLS token embedding
+
+    # Mean pooling over token embeddings
+    embeddings = output.last_hidden_state  # Shape: (batch_size, seq_len, hidden_dim)
+    attention_mask = inputs["attention_mask"].unsqueeze(-1)  # Shape: (batch_size, seq_len, 1)
+
+    # Apply mean pooling: Sum(token_embeddings * mask) / Sum(mask)
+    pooled_embedding = (embeddings * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
+
+    # Normalize embedding
+    return F.normalize(pooled_embedding, p=2, dim=1).squeeze()
 
 def get_similarity_and_excerpt(query, paragraph1, paragraph2, paragraph3, threshold_weight):
     paragraphs = [p for p in [paragraph1, paragraph2, paragraph3] if p.strip()]
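
Note that the new return path calls F.normalize, which requires `import torch.nn.functional as F` at the top of app.py; that import is not visible in this hunk, so it is assumed here. For context, below is a minimal self-contained sketch of the updated function. The model name and all imports are assumptions for illustration, not taken from the commit.

# Minimal sketch, not part of the commit: model name and imports are assumed.
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # hypothetical choice
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

def get_embedding(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        output = model(**inputs)
    # Mean pooling: average token embeddings, masking out padding positions
    embeddings = output.last_hidden_state                    # (batch, seq_len, hidden)
    attention_mask = inputs["attention_mask"].unsqueeze(-1)  # (batch, seq_len, 1)
    pooled = (embeddings * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    # L2-normalize so cosine similarity reduces to a dot product
    return F.normalize(pooled, p=2, dim=1).squeeze()

# With unit-length embeddings, cosine similarity is a plain dot product:
score = torch.dot(get_embedding("a query"), get_embedding("a candidate paragraph"))

Unlike the old CLS-token embedding, mean pooling uses every non-padding token, which tends to give more stable sentence-level similarity scores for encoders that were not trained with a CLS-based sentence objective.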