Update model_loader.py
Browse files- model_loader.py +45 -19
model_loader.py
CHANGED
@@ -1,54 +1,80 @@
|
|
1 |
# model_loader.py
|
2 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM
|
3 |
from sentence_transformers import SentenceTransformer
|
|
|
|
|
4 |
|
5 |
-
# Classifier Model (XLM-RoBERTa for toxicity classification)
|
6 |
class ClassifierModel:
|
7 |
def __init__(self):
|
8 |
self.model = None
|
9 |
self.tokenizer = None
|
10 |
-
self.
|
|
|
11 |
|
12 |
-
def
|
13 |
"""
|
14 |
-
Load the fine-tuned XLM-RoBERTa model and tokenizer for
|
15 |
"""
|
16 |
try:
|
17 |
model_name = "JanviMl/xlm-roberta-toxic-classifier-capstone"
|
|
|
18 |
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
19 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_name
|
|
|
|
|
|
|
20 |
except Exception as e:
|
21 |
-
|
|
|
|
|
|
|
22 |
|
23 |
-
# Paraphraser Model (Granite 3.2-2B-Instruct for paraphrasing)
|
24 |
class ParaphraserModel:
|
25 |
def __init__(self):
|
26 |
self.model = None
|
27 |
self.tokenizer = None
|
28 |
-
self.
|
|
|
29 |
|
30 |
-
def
|
31 |
"""
|
32 |
-
Load the Granite 3.2-2B-Instruct model and tokenizer for paraphrasing.
|
33 |
"""
|
34 |
try:
|
35 |
model_name = "ibm-granite/granite-3.2-2b-instruct"
|
|
|
36 |
self.model = AutoModelForCausalLM.from_pretrained(model_name)
|
37 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
except Exception as e:
|
39 |
-
|
|
|
|
|
|
|
40 |
|
41 |
-
# Metrics Models (Sentence-BERT only)
|
42 |
class MetricsModels:
|
43 |
def __init__(self):
|
44 |
-
self.
|
|
|
|
|
45 |
|
46 |
def load_sentence_bert(self):
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
# Singleton instances
|
52 |
-
classifier_model = ClassifierModel()
|
53 |
-
paraphraser_model = ParaphraserModel()
|
54 |
metrics_models = MetricsModels()
|
|
|
1 |
# model_loader.py
|
2 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM
|
3 |
from sentence_transformers import SentenceTransformer
|
4 |
+
import torch
|
5 |
+
import os
|
6 |
|
|
|
7 |
class ClassifierModel:
|
8 |
def __init__(self):
|
9 |
self.model = None
|
10 |
self.tokenizer = None
|
11 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
+
self.load_classifier_model()
|
13 |
|
14 |
+
def load_classifier_model(self):
|
15 |
"""
|
16 |
+
Load the fine-tuned XLM-RoBERTa model and tokenizer for toxicity classification.
|
17 |
"""
|
18 |
try:
|
19 |
model_name = "JanviMl/xlm-roberta-toxic-classifier-capstone"
|
20 |
+
print(f"Loading classifier model: {model_name}")
|
21 |
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
22 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
23 |
+
self.model.to(self.device)
|
24 |
+
self.model.eval()
|
25 |
+
print("Classifier model loaded successfully")
|
26 |
except Exception as e:
|
27 |
+
print(f"Error loading classifier model: {str(e)}")
|
28 |
+
raise
|
29 |
+
|
30 |
+
classifier_model = ClassifierModel()
|
31 |
|
|
|
32 |
class ParaphraserModel:
|
33 |
def __init__(self):
|
34 |
self.model = None
|
35 |
self.tokenizer = None
|
36 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
37 |
+
self.load_paraphraser_model()
|
38 |
|
39 |
+
def load_paraphraser_model(self):
|
40 |
"""
|
41 |
+
Load the fine-tuned Granite 3.2-2B-Instruct model and tokenizer for paraphrasing.
|
42 |
"""
|
43 |
try:
|
44 |
model_name = "ibm-granite/granite-3.2-2b-instruct"
|
45 |
+
print(f"Loading paraphraser model: {model_name}")
|
46 |
self.model = AutoModelForCausalLM.from_pretrained(model_name)
|
47 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
48 |
+
# Set a distinct pad token to avoid conflict with eos token
|
49 |
+
if self.tokenizer.pad_token is None or self.tokenizer.pad_token == self.tokenizer.eos_token:
|
50 |
+
self.tokenizer.pad_token = "<pad>"
|
51 |
+
self.model.config.pad_token_id = self.tokenizer.convert_tokens_to_ids("<pad>")
|
52 |
+
self.model.to(self.device)
|
53 |
+
self.model.eval()
|
54 |
+
print("Paraphraser model loaded successfully")
|
55 |
except Exception as e:
|
56 |
+
print(f"Error loading paraphraser model: {str(e)}")
|
57 |
+
raise
|
58 |
+
|
59 |
+
paraphraser_model = ParaphraserModel()
|
60 |
|
|
|
61 |
class MetricsModels:
|
62 |
def __init__(self):
|
63 |
+
self.sentence_bert = None
|
64 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
65 |
+
self.load_sentence_bert()
|
66 |
|
67 |
def load_sentence_bert(self):
|
68 |
+
"""
|
69 |
+
Load the Sentence-BERT model for computing semantic similarity.
|
70 |
+
"""
|
71 |
+
try:
|
72 |
+
model_name = "sentence-transformers/all-MiniLM-L6-v2"
|
73 |
+
print(f"Loading Sentence-BERT model: {model_name}")
|
74 |
+
self.sentence_bert = SentenceTransformer(model_name, device=self.device)
|
75 |
+
print("Sentence-BERT model loaded successfully")
|
76 |
+
except Exception as e:
|
77 |
+
print(f"Error loading Sentence-BERT model: {str(e)}")
|
78 |
+
raise
|
79 |
|
|
|
|
|
|
|
80 |
metrics_models = MetricsModels()
|