yliu279 committed
Commit fb97134 · verified · 1 Parent(s): b0a525a

Add Sentence Transformers integration (#7)


- Update README; modeling_gemma2.py; overwrite (979b19f0a88cf8efed7a354454d4e0b9c400df66)
- Undo weird unicode changes (8041ca86014930f5cde7da3f6614149530e85122)

Files changed (2)
  1. README.md +36 -1
  2. modeling_gemma2.py +3 -0
README.md CHANGED
````diff
@@ -1,5 +1,11 @@
 ---
 license: cc-by-nc-4.0
+pipeline_tag: feature-extraction
+tags:
+- transformers
+- sentence-transformers
+- code
+- retrieval
 ---
 <h1 align="center">Salesforce/SFR-Embedding-Code-2B_R</h1>
 
@@ -52,7 +58,7 @@ from transformers import AutoTokenizer, AutoModel
 query_instruction_example = "Given Code or Text, retrieval relevant content"
 queries = [
     "how to implement quick sort in Python?"
-]
+]
 
 # No instruction needed for retrieval passages
 passages = [
@@ -74,6 +80,35 @@ passage_embeddings = F.normalize(passage_embeddings, p=2, dim=1)
 
 scores = (query_embeddings @ passage_embeddings.T) * 100
 print(scores.tolist())
+# [[52.76957702636719, 26.118698120117188]]
+```
+
+#### Sentence Transformers
+
+```python
+from sentence_transformers import SentenceTransformer
+
+# Each query needs to be accompanied by a corresponding instruction describing the task.
+query_instruction_example = "Instruct: Given Code or Text, retrieval relevant content\nQuery: "
+queries = ["how to implement quick sort in Python?"]
+
+# No instruction needed for retrieval passages
+passages = [
+    "def quick_sort(arr):\n    if len(arr) <= 1:\n        return arr\n    pivot = arr[len(arr) // 2]\n    left = [x for x in arr if x < pivot]\n    middle = [x for x in arr if x == pivot]\n    right = [x for x in arr if x > pivot]\n    return quick_sort(left) + middle + quick_sort(right)",
+    "def bubble_sort(arr):\n    n = len(arr)\n    for i in range(n):\n        for j in range(0, n-i-1):\n            if arr[j] > arr[j+1]:\n                arr[j], arr[j+1] = arr[j+1], arr[j]\n    return arr"
+]
+
+# Load the Sentence Transformer model, including pooling
+model = SentenceTransformer('Salesforce/SFR-Embedding-Code-2B_R', trust_remote_code=True)
+
+# Compute the embeddings for both queries and passages. Use 'prompt' for queries only.
+query_embeddings = model.encode(queries, prompt=query_instruction_example)
+passage_embeddings = model.encode(passages)
+
+# Compute the similarities between the queries and passages
+similarities = model.similarity(query_embeddings, passage_embeddings)
+print(similarities)
+# tensor([[0.5277, 0.2612]])
 ```
 
 ### Citation
````
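The two README snippets should agree with each other: the raw Transformers path prints L2-normalized dot products scaled by 100 (52.77, 26.12), while `model.similarity` returns the unscaled cosines (0.5277, 0.2612). Below is a minimal sketch of that correspondence; it is not part of the commit, and it assumes the model loads as shown above and that `similarity` uses its default cosine metric.

```python
from sentence_transformers import SentenceTransformer

# Load as in the README; trust_remote_code pulls in the custom Gemma2 wrapper.
model = SentenceTransformer("Salesforce/SFR-Embedding-Code-2B_R", trust_remote_code=True)

prompt = "Instruct: Given Code or Text, retrieval relevant content\nQuery: "
queries = ["how to implement quick sort in Python?"]
passages = ["def quick_sort(arr): ..."]  # placeholder passage, for illustration only

# encode(prompt=...) prepends the prompt string to each query before tokenization,
# so it should produce the same embeddings as concatenating by hand.
q_prompted = model.encode(queries, prompt=prompt)
q_manual = model.encode([prompt + q for q in queries])

# Cosine similarity times 100 should match the raw-Transformers snippet, which
# L2-normalizes both sides and takes dot products scaled by 100.
p = model.encode(passages)
print(model.similarity(q_prompted, p) * 100)
```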
modeling_gemma2.py CHANGED
```diff
@@ -1350,6 +1350,9 @@ class CodeXEmbedModel2B(PreTrainedModel):
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.tokenizer.padding_side = 'right'
 
+    def forward(self, **kwargs):
+        return self.model(**kwargs)
+
     def last_token_pool(self, model_output, attention_mask):
         last_hidden_states = model_output.last_hidden_state
         sequence_lengths = attention_mask.sum(dim=1) - 1
```
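These three added lines are what make the Sentence Transformers loader work: its `Transformer` module invokes the loaded model as a plain callable, `model(**features)`, and `nn.Module` raises `NotImplementedError` unless a `forward` is defined, so the wrapper now simply delegates to the inner Gemma2 backbone. Below is a simplified sketch of the pooling step visible in the context lines, with made-up shapes (not the repo's exact code):

```python
import torch

def last_token_pool(last_hidden_states, attention_mask):
    # With right padding (padding_side = 'right' above), the last non-pad token
    # of each sequence sits at index attention_mask.sum(dim=1) - 1.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_indices = torch.arange(last_hidden_states.size(0))
    return last_hidden_states[batch_indices, sequence_lengths]

# Toy batch: 2 sequences, 5 positions, hidden size 4; the second is right-padded.
hidden = torch.randn(2, 5, 4)
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])
print(last_token_pool(hidden, mask).shape)  # torch.Size([2, 4]), one vector per sequence
```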