JanviMl committed on
Commit
986acc0
·
verified ·
1 Parent(s): ac0ca8d

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +45 -19
model_loader.py CHANGED
@@ -1,54 +1,80 @@
1
  # model_loader.py
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM
3
  from sentence_transformers import SentenceTransformer
 
 
4
 
5
- # Classifier Model (XLM-RoBERTa for toxicity classification)
6
  class ClassifierModel:
7
  def __init__(self):
8
  self.model = None
9
  self.tokenizer = None
10
- self.load_model()
 
11
 
12
- def load_model(self):
13
  """
14
- Load the fine-tuned XLM-RoBERTa model and tokenizer for toxic comment classification.
15
  """
16
  try:
17
  model_name = "JanviMl/xlm-roberta-toxic-classifier-capstone"
 
18
  self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
19
- self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
 
 
 
20
  except Exception as e:
21
- raise Exception(f"Error loading classifier model or tokenizer: {str(e)}")
 
 
 
22
 
23
- # Paraphraser Model (Granite 3.2-2B-Instruct for paraphrasing)
24
  class ParaphraserModel:
25
  def __init__(self):
26
  self.model = None
27
  self.tokenizer = None
28
- self.load_model()
 
29
 
30
- def load_model(self):
31
  """
32
- Load the Granite 3.2-2B-Instruct model and tokenizer for paraphrasing.
33
  """
34
  try:
35
  model_name = "ibm-granite/granite-3.2-2b-instruct"
 
36
  self.model = AutoModelForCausalLM.from_pretrained(model_name)
37
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
 
 
38
  except Exception as e:
39
- raise Exception(f"Error loading paraphrase model or tokenizer: {str(e)}")
 
 
 
40
 
41
- # Metrics Models (Sentence-BERT only)
42
  class MetricsModels:
43
  def __init__(self):
44
- self.sentence_bert_model = None
 
 
45
 
46
  def load_sentence_bert(self):
47
- if self.sentence_bert_model is None:
48
- self.sentence_bert_model = SentenceTransformer('all-MiniLM-L6-v2')
49
- return self.sentence_bert_model
 
 
 
 
 
 
 
 
50
 
51
- # Singleton instances
52
- classifier_model = ClassifierModel()
53
- paraphraser_model = ParaphraserModel()
54
  metrics_models = MetricsModels()
 
1
  # model_loader.py
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM
3
  from sentence_transformers import SentenceTransformer
4
+ import torch
5
+ import os
6
 
 
class ClassifierModel:
    """Container for the toxicity classifier and its tokenizer.

    Both are loaded eagerly when the object is constructed.
    """

    def __init__(self):
        # Prefer the GPU when torch can see one; otherwise run on CPU.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        self.tokenizer = None
        self.model = None
        self.load_classifier_model()

    def load_classifier_model(self):
        """
        Fetch the XLM-RoBERTa toxicity classifier and its tokenizer, move the
        model onto ``self.device`` and switch it to evaluation mode.

        Raises:
            Exception: whatever the underlying load raised; the error is
                printed first and then propagated unchanged.
        """
        model_name = "JanviMl/xlm-roberta-toxic-classifier-capstone"
        try:
            print(f"Loading classifier model: {model_name}")
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # ``Module.to`` hands back the module itself, so the two calls chain.
            self.model.to(self.device).eval()
            print("Classifier model loaded successfully")
        except Exception as err:
            print(f"Error loading classifier model: {str(err)}")
            raise


# Shared instance, created (and therefore loaded) at import time.
classifier_model = ClassifierModel()
31
 
 
class ParaphraserModel:
    """Container for the Granite causal LM used for paraphrasing and its tokenizer.

    Both are loaded eagerly when the object is constructed.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_paraphraser_model()

    def load_paraphraser_model(self):
        """
        Load the Granite 3.2-2B-Instruct model and tokenizer for paraphrasing.

        When the tokenizer has no pad token, or its pad token is the same as
        its EOS token, a dedicated ``<pad>`` token is registered so padding
        can be told apart from end-of-sequence. The model is then moved to
        ``self.device`` and put in evaluation mode.

        Raises:
            Exception: whatever the underlying load raised; the error is
                printed first and then propagated unchanged.
        """
        try:
            model_name = "ibm-granite/granite-3.2-2b-instruct"
            print(f"Loading paraphraser model: {model_name}")
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Set a distinct pad token to avoid conflict with eos token
            if self.tokenizer.pad_token is None or self.tokenizer.pad_token == self.tokenizer.eos_token:
                self._register_pad_token("<pad>")
            self.model.to(self.device)
            self.model.eval()
            print("Paraphraser model loaded successfully")
        except Exception as e:
            print(f"Error loading paraphraser model: {str(e)}")
            raise

    def _register_pad_token(self, pad_token):
        """Make ``pad_token`` a real vocabulary entry and point the model config at it."""
        # Assigning ``tokenizer.pad_token`` alone does not add the token to the
        # vocabulary, so its id could resolve to the unknown token (or None).
        # ``add_special_tokens`` adds it when missing and sets ``pad_token``.
        self.tokenizer.add_special_tokens({"pad_token": pad_token})

        # Grow the embedding matrix only if the vocabulary now exceeds it;
        # never shrink a matrix the checkpoint shipped with extra rows.
        embedding_rows = self.model.get_input_embeddings().num_embeddings
        if len(self.tokenizer) > embedding_rows:
            self.model.resize_token_embeddings(len(self.tokenizer))

        self.model.config.pad_token_id = self.tokenizer.pad_token_id


# Shared instance, created (and therefore loaded) at import time.
paraphraser_model = ParaphraserModel()
60
 
 
class MetricsModels:
    """Container for the Sentence-BERT model used for semantic-similarity metrics.

    The model is loaded eagerly when the object is constructed.
    """

    def __init__(self):
        self.sentence_bert = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_sentence_bert()

    def load_sentence_bert(self):
        """
        Load the Sentence-BERT model for computing semantic similarity.

        The model is loaded at most once; later calls reuse the instance
        already stored on ``self.sentence_bert``.

        Returns:
            The loaded ``SentenceTransformer`` instance, so callers that use
            the return value of this method keep working.

        Raises:
            Exception: whatever the underlying load raised; the error is
                printed first and then propagated unchanged.
        """
        # Already loaded (normally by __init__): do not load a second copy.
        if self.sentence_bert is not None:
            return self.sentence_bert

        try:
            model_name = "sentence-transformers/all-MiniLM-L6-v2"
            print(f"Loading Sentence-BERT model: {model_name}")
            # SentenceTransformer documents ``device`` as a string such as "cuda" or "cpu".
            self.sentence_bert = SentenceTransformer(model_name, device=str(self.device))
            print("Sentence-BERT model loaded successfully")
        except Exception as e:
            print(f"Error loading Sentence-BERT model: {str(e)}")
            raise
        return self.sentence_bert


# Shared instance, created (and therefore loaded) at import time.
metrics_models = MetricsModels()