---
tags:
- feature-extraction
pipeline_tag: feature-extraction
---
DRAGON-RoBERTa is a BERT-base-sized dense retriever initialized from [RoBERTa](https://huggingface.co/roberta-base) and further trained on data augmented from the MS MARCO corpus, following the approach described in [How to Train Your DRAGON: Diverse Augmentation Towards Generalizable Dense Retrieval](\url). The associated GitHub repository is available at https://github.com/facebookresearch/dpr-scale/tree/dragon. The model is an asymmetric dual encoder: queries and contexts are encoded by two separately parameterized encoders.
The following models are also available:
Model | Initialization | Query Encoder Path | Context Encoder Path
|---|---|---|---
DRAGON-RoBERTa | roberta-base | facebook/dragon-roberta-query-encoder | facebook/dragon-roberta-context-encoder

## Usage (HuggingFace Transformers)
The model can be used directly with HuggingFace Transformers.

```python
import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('facebook/dragon-roberta-query-encoder')
query_encoder = AutoModel.from_pretrained('facebook/dragon-roberta-query-encoder')
context_encoder = AutoModel.from_pretrained('facebook/dragon-roberta-context-encoder')

# We use msmarco query and passages as an example
query =  "Where was Marie Curie born?"
contexts = [
    "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
    "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
]
# Apply tokenizer
query_input = tokenizer(query, return_tensors='pt')
ctx_input = tokenizer(contexts, padding=True, truncation=True, return_tensors='pt')
# Compute embeddings: take the last-layer hidden state of the first token (<s>, RoBERTa's equivalent of [CLS])
query_emb = query_encoder(**query_input).last_hidden_state[:, 0, :]
ctx_emb = context_encoder(**ctx_input).last_hidden_state[:, 0, :]
# Compute similarity scores using dot product
score1 = query_emb @ ctx_emb[0]  # 385.1422
score2 = query_emb @ ctx_emb[1]  # 383.6051
```
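
Because relevance is scored with a plain dot product, ranking a set of passages for a query amounts to sorting them by that score. The following is a minimal, framework-free sketch of that step; the helper names `dot` and `rank_contexts` and the 3-dimensional toy vectors are hypothetical stand-ins for illustration, not part of the released code or real model outputs.

```python
from typing import List, Sequence, Tuple


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same dimension")
    return sum(a * b for a, b in zip(u, v))


def rank_contexts(
    query_vec: Sequence[float],
    context_vecs: Sequence[Sequence[float]],
) -> List[Tuple[int, float]]:
    """Return (context index, score) pairs sorted by descending dot product."""
    scores = [(i, dot(query_vec, c)) for i, c in enumerate(context_vecs)]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)


# Hypothetical toy vectors standing in for query and context embeddings.
toy_query = [1.0, 0.0, 2.0]
toy_contexts = [
    [0.5, 1.0, 0.0],
    [1.0, 0.0, 1.0],
    [0.0, 2.0, 0.5],
]

ranking = rank_contexts(toy_query, toy_contexts)
for index, score in ranking:
    print(f"context {index}: score {score}")
```

With the tensors from the usage example above, the same scores for all passages can be obtained in one step as `query_emb @ ctx_emb.T`, which yields a tensor of shape `(1, number_of_contexts)`.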