lorenzoscottb committed
Commit 5ba717e · verified · 1 Parent(s): 9c0e723

Update README.md

Files changed (1)
  1. README.md +72 -0
README.md CHANGED
@@ -6,6 +6,78 @@ license: apache-2.0
  The repo contains the weights for the custom architecture presented in [Bertolini et al., 2023](https://arxiv.org/abs/2302.14828).
  A working example of how to load and use the model can be found in the [Git repo](https://github.com/lorenzoscottb/Dream_Reports_Annotation/tree/main/Experiments/Supervised_Learning).
 
+ #### Use
+
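+ The first snippet sets up the emotion labels and a couple of example dream reports to annotate. The all-zero targets are placeholders: the `Report_as_Multilabel` column is only there to satisfy the dataset class at inference time.
+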
+ ```py
+ import torch, os
+ import numpy as np
+ import pandas as pd
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ import transformers
+ from transformers import AutoModel
+ from transformers import AutoConfig
+ from transformers import BertTokenizerFast
+ from SL_utils import *
+
+ # Emotion codes and the emotions they stand for
+ Coding_emotions = {
+     "AN": "Anger",
+     "AP": "Apprehension",
+     "SD": "Sadness",
+     "CO": "Confusion",
+     "HA": "Happiness",
+ }
+
+ emotions_list = list(Coding_emotions.keys())
+
+ test_sentences = [
+     "In my dream I was followed by the scary monster.",
+     "I was walking in a forest, surrounded by singing birds. I was in calm and peace."
+ ]
+
+ # Placeholder all-zero multi-label targets, one per report
+ test_sentences_target = len(test_sentences) * [[0, 0, 0, 0, 0]]
+ test_sentences_df = pd.DataFrame.from_dict(
+     {
+         "report": test_sentences,
+         "Report_as_Multilabel": test_sentences_target
+     }
+ )
+ ```
+
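+ The second snippet tokenises the reports, wraps them in a standard PyTorch `DataLoader`, builds the custom architecture, and loads the released weights (`CustomDataset`, `BERT_PTM`, `validation`, and `decode_clean` come from the `SL_utils` module in the Git repo linked above). Replace `path/to/pytorch_model.bin` with the local path of the weights file downloaded from this repo, and swap `"cuda"` for `"cpu"` if no GPU is available.
+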
+ ```py
+ model_name = "bert-large-cased"
+ model_config = AutoConfig.from_pretrained(model_name)
+ tokenizer = BertTokenizerFast.from_pretrained(model_name, do_lower_case=False)
+ testing_set = CustomDataset(test_sentences_df, tokenizer, max_length=512)
+
+ test_params = {
+     'batch_size': 2,
+     'shuffle': True,
+     'num_workers': 0
+ }
+
+ testing_loader = DataLoader(testing_set, **test_params)
+
+ model = BERT_PTM(
+     model_config,
+     model_name=model_name,
+     n_classes=len(emotions_list),
+     freeze_BERT=False,
+ )
+
+ # Load the model's weights from the pre-trained checkpoint in this repo
+ model.load_state_dict(torch.load("path/to/pytorch_model.bin"))
+ model.to("cuda")
+ ```
+
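+ Finally, run the model over the loader and binarise the predicted scores with a 0.5 threshold, so that each report gets a 0/1 value for every emotion.
+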
+ ```py
+ # Run the model over the test loader
+ outputs, targets, ids = validation(model, testing_loader, device="cuda", return_inputs=True)
+
+ # Threshold the predicted scores at 0.5 to get one binary label per emotion
+ corr_outputs = np.array(outputs) >= 0.5
+ corr_outputs_df = pd.DataFrame(corr_outputs, columns=emotions_list)
+ corr_outputs_df = corr_outputs_df.astype(int)
+
+ # Recover the original text of each report from the token IDs
+ corr_outputs_df["report"] = [decode_clean(x, tokenizer) for x in tqdm(ids)]
+ ```
+
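+ The resulting `corr_outputs_df` holds one row per report, a binary column for each emotion code (`AN`, `AP`, `SD`, `CO`, `HA`), and a `report` column with the decoded input text.
+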
  ### Cite
  If you use the model, please cite the pre-print.
  ```bibtex