jer233 commited on
Commit
00d8fb5
·
verified ·
1 Parent(s): 699e956

Create roberta_model_loader.py

Browse files
Files changed (1) hide show
  1. roberta_model_loader.py +25 -0
roberta_model_loader.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import RobertaTokenizer, RobertaModel
2
+ import torch
3
+
4
+
5
+ class RobertaModelLoader:
6
+ def __init__(
7
+ self,
8
+ model_name="roberta-base-openai-detector",
9
+ cache_dir=".cache",
10
+ ):
11
+ print("Roberta Model init")
12
+ self.model_name = model_name
13
+ self.cache_dir = cache_dir
14
+ self.tokenizer, self.model = self.load_base_model_and_tokenizer()
15
+
16
+ def load_base_model_and_tokenizer(self):
17
+ print("Load model: ", self.model_name)
18
+ return RobertaTokenizer.from_pretrained(
19
+ self.model_name, cache_dir=self.cache_dir
20
+ ), RobertaModel.from_pretrained(
21
+ self.model_name, output_hidden_states=True, cache_dir=self.cache_dir
22
+ )
23
+
24
+
25
+ roberta_model = RobertaModelLoader()