Update README.md
Browse files
README.md
CHANGED
@@ -6,6 +6,78 @@ license: apache-2.0
|
|
6 |
The repo contains the weights for the custom architecture presented in [Bertolini et al., 2023](https://arxiv.org/abs/2302.14828).
|
7 |
A working example of how to load and use the model can be found in the [Git repo](https://github.com/lorenzoscottb/Dream_Reports_Annotation/tree/main/Experiments/Supervised_Learning).
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
### Cite
|
10 |
If you use the model, please cite the pre-print.
|
11 |
```bibtex
|
|
|
6 |
The repo contains the weights for the custom architecture presented in [Bertolini et al., 2023](https://arxiv.org/abs/2302.14828).
|
7 |
A working example of how to load and use the model can be found in the [Git repo](https://github.com/lorenzoscottb/Dream_Reports_Annotation/tree/main/Experiments/Supervised_Learning).
|
8 |
|
9 |
+
#### Use
|
10 |
+
|
11 |
+
```py
|
12 |
+
import torch, os
|
13 |
+
import pandas as pd
|
14 |
+
from tqdm import tqdm
|
15 |
+
import transformers
|
16 |
+
from transformers import AutoModel
|
17 |
+
from transformers import AutoConfig
|
18 |
+
from transformers import BertTokenizerFast
|
19 |
+
from SL_utils import *
|
20 |
+
|
21 |
+
Coding_emotions = {
|
22 |
+
"AN": "Anger",
|
23 |
+
"AP": "Apprehension",
|
24 |
+
"SD": "Sadness",
|
25 |
+
"CO": "Confusion",
|
26 |
+
"HA": "Happiness",
|
27 |
+
}
|
28 |
+
|
29 |
+
emotions_list = list(Coding_emotions.keys())
|
30 |
+
|
31 |
+
test_sentences = [
|
32 |
+
"In my dream I was followed by the scary monster.",
|
33 |
+
"I was walking in a forest, surrounded by singing birds. I was in calm and peace."
|
34 |
+
]
|
35 |
+
|
36 |
+
test_sentences_target = len(test_sentences)*[[0, 0, 0, 0, 0]]
|
37 |
+
test_sentences_df = pd.DataFrame.from_dict(
|
38 |
+
{
|
39 |
+
"report":test_sentences,
|
40 |
+
"Report_as_Multilabel":test_sentences_target
|
41 |
+
}
|
42 |
+
)
|
43 |
+
```
|
44 |
+
|
45 |
+
```py
|
46 |
+
model_name = "bert-large-cased"
|
47 |
+
model_config = AutoConfig.from_pretrained(model_name)
|
48 |
+
tokenizer = BertTokenizerFast.from_pretrained(model_name, do_lower_case=False)
|
49 |
+
testing_set = CustomDataset(test_sentences_df, tokenizer, max_length=512)
|
50 |
+
|
51 |
+
test_params = {
|
52 |
+
'batch_size': 2,
|
53 |
+
'shuffle': True,
|
54 |
+
'num_workers': 0
|
55 |
+
}
|
56 |
+
|
57 |
+
testing_loader = DataLoader(testing_set, **test_params)
|
58 |
+
|
59 |
+
model = BERT_PTM(
|
60 |
+
model_config,
|
61 |
+
model_name=model_name,
|
62 |
+
n_classes=len(emotions_list),
|
63 |
+
freeze_BERT=False,
|
64 |
+
)
|
65 |
+
|
66 |
+
# Load the model's weights from the pre-trained model
|
67 |
+
model.load_state_dict(torch.load("path/to/pytorch_model.bin"))
|
68 |
+
model.to("cuda")
|
69 |
+
```
|
70 |
+
|
71 |
+
```py
|
72 |
+
outputs, targets, ids = validation(model, testing_loader, device="cuda", return_inputs=True)
|
73 |
+
|
74 |
+
corr_outputs = np.array(outputs) >= 0.5
|
75 |
+
corr_outputs_df = pd.DataFrame(corr_outputs, columns=emotions_list)
|
76 |
+
corr_outputs_df = corr_outputs_df.astype(int)
|
77 |
+
|
78 |
+
corr_outputs_df["report"] = decoded_ids = [decode_clean(x, tokenizer) for x in tqdm(ids)]
|
79 |
+
```
|
80 |
+
|
81 |
### Cite
|
82 |
If you use the model, please cite the pre-print.
|
83 |
```bibtex
|