sudhanshu746 committed on
Commit eb59f52 · verified · 1 Parent(s): 2373acb

Update README.md

Files changed (1): README.md +393 -3

---
license: apache-2.0
language:
- multilingual
pipeline_tag: text-classification
tags:
- transformers
- sentence-transformers
- text-embeddings-inference
---

This is the ONNX version of the quantized bge-reranker-v2-m3 model, created by Sudhanshu Sharma.

# Reranker

**For more details, please refer to our GitHub: [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding/tree/master).**

- [Model List](#model-list)
- [Usage](#usage)
- [Fine-tuning](#fine-tune)
- [Evaluation](#evaluation)
- [Citation](#citation)

Unlike an embedding model, a reranker takes a query and a document as input and directly outputs a similarity score instead of an embedding.
You can get a relevance score by feeding a query and a passage to the reranker, and the score can be mapped to a float value in [0, 1] with a sigmoid function.
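
Concretely, the mapping is `score ↦ 1 / (1 + e^(-score))`. A minimal sketch (the `normalize=True` option shown below applies the same function internally):

```python
import math

def sigmoid(x: float) -> float:
    # Map a raw reranker score to a relevance value in (0, 1)
    return 1 / (1 + math.exp(-x))

print(sigmoid(-5.65234375))  # ~0.0035, matching the normalize=True example below
```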

## Model List

| Model | Base model | Language | Layerwise | Feature |
|:------|:----------:|:--------:|:---------:|:--------|
| [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) | Chinese and English | - | Lightweight reranker model, easy to deploy, with fast inference. |
| [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | [xlm-roberta-large](https://huggingface.co/FacebookAI/xlm-roberta-large) | Chinese and English | - | Lightweight reranker model, easy to deploy, with fast inference. |
| [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) | [bge-m3](https://huggingface.co/BAAI/bge-m3) | Multilingual | - | Lightweight reranker model with strong multilingual capabilities, easy to deploy, with fast inference. |
| [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma) | [gemma-2b](https://huggingface.co/google/gemma-2b) | Multilingual | - | Suitable for multilingual contexts, performs well in both English proficiency and multilingual capabilities. |
| [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise) | [MiniCPM-2B-dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16) | Multilingual | 8-40 | Suitable for multilingual contexts, performs well in both English and Chinese, and lets you select which layers to use for output, enabling accelerated inference. |

You can select a model according to your scenario and resources:

- For **multilingual** use, choose [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) or [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma).

- For **Chinese or English**, choose [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) or [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise).

- For **efficiency**, choose [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) or the lower layers of [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise).

- For **best performance**, we recommend [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise) or [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma).

## Usage
### Using FlagEmbedding

```shell
pip install -U FlagEmbedding
```

#### For normal reranker (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3)

Get relevance scores (higher scores indicate more relevance):

```python
from FlagEmbedding import FlagReranker
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)  # Setting use_fp16 to True speeds up computation with a slight performance degradation

score = reranker.compute_score(['query', 'passage'])
print(score)  # -5.65234375

# You can map the scores to [0, 1] by setting normalize=True, which applies a sigmoid function to the score
score = reranker.compute_score(['query', 'passage'], normalize=True)
print(score)  # 0.003497010252573502

scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']])
print(scores)  # [-8.1875, 5.26171875]

# normalize=True works for batched pairs as well
scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']], normalize=True)
print(scores)  # [0.00027803096387751553, 0.9948403768236574]
```
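
A common pattern is to use these scores to reorder retrieved candidates. Below is a minimal sketch built on the same API; the candidate texts are made up for illustration:

```python
from FlagEmbedding import FlagReranker

reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)

query = 'what is panda?'
candidates = [
    'hi',
    'The giant panda (Ailuropoda melanoleuca) is a bear species endemic to China.',
    'pandas is a Python library for data analysis.',
]
# Score every (query, candidate) pair, then sort candidates by descending relevance
scores = reranker.compute_score([[query, c] for c in candidates], normalize=True)
ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
for text, score in ranked:
    print(f'{score:.4f}  {text}')
```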

#### For LLM-based reranker

```python
from FlagEmbedding import FlagLLMReranker
reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_fp16=True)  # Setting use_fp16 to True speeds up computation with a slight performance degradation
# reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_bf16=True)  # You can also set use_bf16=True to speed up computation with a slight performance degradation

score = reranker.compute_score(['query', 'passage'])
print(score)

scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']])
print(scores)
```

#### For LLM-based layerwise reranker

```python
from FlagEmbedding import LayerWiseFlagLLMReranker
reranker = LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise', use_fp16=True)  # Setting use_fp16 to True speeds up computation with a slight performance degradation
# reranker = LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise', use_bf16=True)  # You can also set use_bf16=True to speed up computation with a slight performance degradation

score = reranker.compute_score(['query', 'passage'], cutoff_layers=[28])  # Adjust 'cutoff_layers' to pick which layers are used for computing the score
print(score)

scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']], cutoff_layers=[28])
print(scores)
```
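
The `cutoff_layers` values should fall in this model's layerwise range (8-40, per the table above); picking an earlier layer trades a little accuracy for faster scoring.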

### Using Hugging Face Transformers

#### For normal reranker (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3)

Get relevance scores (higher scores indicate more relevance):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-m3')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-v2-m3')
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    scores = model(**inputs, return_dict=True).logits.view(-1).float()
    print(scores)
```
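
Since this repository hosts a quantized ONNX export of bge-reranker-v2-m3, you can also score pairs without PyTorch. Below is a minimal onnxruntime sketch, assuming the exported graph is saved as `model.onnx` and exposes `input_ids`/`attention_mask` inputs with one logit per pair as output; verify the actual names against the file with `session.get_inputs()` and `session.get_outputs()`:

```python
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# The original model's tokenizer also serves the ONNX export
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-m3')
session = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])  # file name is an assumption

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
inputs = tokenizer(pairs, padding=True, truncation=True, max_length=512, return_tensors='np')

# Feed only the inputs the graph actually declares; XLM-RoBERTa-style exports
# usually take int64 input_ids and attention_mask
feed = {i.name: inputs[i.name].astype(np.int64) for i in session.get_inputs()}
logits = session.run(None, feed)[0].reshape(-1)
print(1 / (1 + np.exp(-logits)))  # sigmoid-normalized relevance scores
```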

#### For LLM-based reranker

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
    if prompt is None:
        prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
    sep = "\n"
    prompt_inputs = tokenizer(prompt,
                              return_tensors=None,
                              add_special_tokens=False)['input_ids']
    sep_inputs = tokenizer(sep,
                           return_tensors=None,
                           add_special_tokens=False)['input_ids']
    inputs = []
    for query, passage in pairs:
        query_inputs = tokenizer(f'A: {query}',
                                 return_tensors=None,
                                 add_special_tokens=False,
                                 max_length=max_length * 3 // 4,
                                 truncation=True)
        passage_inputs = tokenizer(f'B: {passage}',
                                   return_tensors=None,
                                   add_special_tokens=False,
                                   max_length=max_length,
                                   truncation=True)
        item = tokenizer.prepare_for_model(
            [tokenizer.bos_token_id] + query_inputs['input_ids'],
            sep_inputs + passage_inputs['input_ids'],
            truncation='only_second',
            max_length=max_length,
            padding=False,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=False
        )
        item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
        item['attention_mask'] = [1] * len(item['input_ids'])
        inputs.append(item)
    return tokenizer.pad(
        inputs,
        padding=True,
        max_length=max_length + len(sep_inputs) + len(prompt_inputs),
        pad_to_multiple_of=8,
        return_tensors='pt',
    )

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-gemma')
model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2-gemma')
yes_loc = tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = get_inputs(pairs, tokenizer)
    scores = model(**inputs, return_dict=True).logits[:, -1, yes_loc].view(-1).float()
    print(scores)
```
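
Here the relevance score is the logit assigned to the 'Yes' token at the final position, so a higher score means the model judges the passage more likely to answer the query.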

#### For LLM-based layerwise reranker

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
    if prompt is None:
        prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
    sep = "\n"
    prompt_inputs = tokenizer(prompt,
                              return_tensors=None,
                              add_special_tokens=False)['input_ids']
    sep_inputs = tokenizer(sep,
                           return_tensors=None,
                           add_special_tokens=False)['input_ids']
    inputs = []
    for query, passage in pairs:
        query_inputs = tokenizer(f'A: {query}',
                                 return_tensors=None,
                                 add_special_tokens=False,
                                 max_length=max_length * 3 // 4,
                                 truncation=True)
        passage_inputs = tokenizer(f'B: {passage}',
                                   return_tensors=None,
                                   add_special_tokens=False,
                                   max_length=max_length,
                                   truncation=True)
        item = tokenizer.prepare_for_model(
            [tokenizer.bos_token_id] + query_inputs['input_ids'],
            sep_inputs + passage_inputs['input_ids'],
            truncation='only_second',
            max_length=max_length,
            padding=False,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=False
        )
        item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
        item['attention_mask'] = [1] * len(item['input_ids'])
        inputs.append(item)
    return tokenizer.pad(
        inputs,
        padding=True,
        max_length=max_length + len(sep_inputs) + len(prompt_inputs),
        pad_to_multiple_of=8,
        return_tensors='pt',
    )

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise', trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to('cuda')
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = get_inputs(pairs, tokenizer).to(model.device)
    all_scores = model(**inputs, return_dict=True, cutoff_layers=[28])
    all_scores = [scores[:, -1].view(-1).float() for scores in all_scores[0]]
    print(all_scores)
```
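
In the layerwise call above, `all_scores` should hold one tensor of per-pair scores for each requested cutoff layer, so `cutoff_layers=[28]` yields a single tensor of two scores for these pairs.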

## Fine-tune

### Data Format

Training data should be a JSON Lines file, where each line is a dict like this:

```
{"query": str, "pos": List[str], "neg": List[str], "prompt": str}
```

`query` is the query, `pos` is a list of positive texts, `neg` is a list of negative texts, and `prompt` indicates the relationship between the query and the texts. If you have no negative texts for a query, you can randomly sample some from the entire corpus as negatives.

See [toy_finetune_data.jsonl](https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/llm_reranker/toy_finetune_data.jsonl) for a toy data file.
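
For illustration, a single training record could look like this (a made-up example following the format above, reusing the default prompt from the inference code):

```
{"query": "what is panda?", "pos": ["The giant panda (Ailuropoda melanoleuca) is a bear species endemic to China."], "neg": ["hi", "pandas is a Python library for data analysis."], "prompt": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."}
```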

### Train

You can fine-tune the reranker with the following code:

**For llm-based reranker**

```shell
torchrun --nproc_per_node {number of gpus} \
-m FlagEmbedding.llm_reranker.finetune_for_instruction.run \
--output_dir {path to save model} \
--model_name_or_path google/gemma-2b \
--train_data ./toy_finetune_data.jsonl \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--dataloader_drop_last True \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 16 \
--logging_steps 1 \
--save_steps 2000 \
--save_total_limit 50 \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--deepspeed stage1.json \
--warmup_ratio 0.1 \
--bf16 \
--use_lora True \
--lora_rank 32 \
--lora_alpha 64 \
--use_flash_attn True \
--target_modules q_proj k_proj v_proj o_proj
```
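
As a rough guide to the settings above: with `per_device_train_batch_size 1` and `gradient_accumulation_steps 16`, each GPU accumulates 16 queries per optimizer step, and `train_group_size 16` controls how many passages are scored per query (in FlagEmbedding this is typically one positive plus negatives sampled from `neg`).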

**For llm-based layerwise reranker**

```shell
torchrun --nproc_per_node {number of gpus} \
-m FlagEmbedding.llm_reranker.finetune_for_layerwise.run \
--output_dir {path to save model} \
--model_name_or_path openbmb/MiniCPM-2B-dpo-bf16 \
--train_data ./toy_finetune_data.jsonl \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--dataloader_drop_last True \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 16 \
--logging_steps 1 \
--save_steps 2000 \
--save_total_limit 50 \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--deepspeed stage1.json \
--warmup_ratio 0.1 \
--bf16 \
--use_lora True \
--lora_rank 32 \
--lora_alpha 64 \
--use_flash_attn True \
--target_modules q_proj k_proj v_proj o_proj \
--start_layer 8 \
--head_multi True \
--head_type simple \
--lora_extra_parameters linear_head
```

Our rerankers are initialized from [google/gemma-2b](https://huggingface.co/google/gemma-2b) (for the llm-based reranker) and [openbmb/MiniCPM-2B-dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16) (for the llm-based layerwise reranker), and we train them on a mixture of multilingual datasets:

- [bge-m3-data](https://huggingface.co/datasets/Shitao/bge-m3-data)
- [quora train data](https://huggingface.co/datasets/quora)
- [fever train data](https://fever.ai/dataset/fever.html)

## Evaluation

- llama-index.

  ![image-20240317193909373](./assets/llama-index.png)

- BEIR.

  Rerank the top 100 results from bge-en-v1.5 large:

  ![image-20240317174633333](./assets/BEIR-bge-en-v1.5.png)

  Rerank the top 100 results from e5-mistral-7b-instruct:

  ![image-20240317172949713](./assets/BEIR-e5-mistral.png)

- CMTEB-retrieval.

  Rerank the top 100 results from bge-zh-v1.5 large:

  ![image-20240317173026235](./assets/CMTEB-retrieval-bge-zh-v1.5.png)

- MIRACL (multilingual).

  Rerank the top 100 results from bge-m3:

  ![image-20240317173117639](./assets/miracl-bge-m3.png)

## Citation

If you find this repository useful, please consider giving it a star and a citation:

```bibtex
@misc{li2023making,
  title={Making Large Language Models A Better Foundation For Dense Retrieval},
  author={Chaofan Li and Zheng Liu and Shitao Xiao and Yingxia Shao},
  year={2023},
  eprint={2312.15503},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
@misc{chen2024bge,
  title={BGE M3-Embedding: Multi-Lingual, Multi-Functionality, Multi-Granularity Text Embeddings Through Self-Knowledge Distillation},
  author={Jianlv Chen and Shitao Xiao and Peitian Zhang and Kun Luo and Defu Lian and Zheng Liu},
  year={2024},
  eprint={2402.03216},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```