UofTearsBotAPI / IllnessClassifier.py
42Cummer's picture
Uploaded files from Cursor
22d76f2 verified
raw
history blame
597 Bytes
from transformers import ( # pylint: disable=import-error
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForCausalLM,
pipeline
)
import logging
class IllnessClassifier:
    """Classify free text into a mental-health label using a Hugging Face pipeline.

    Wraps a ``transformers`` ``text-classification`` pipeline and reports the
    label and score of the first result that pipeline returns.
    """

    # Hub model used when no explicit model is supplied; identical to the
    # previously hard-coded value so existing callers behave the same.
    DEFAULT_MODEL = "dsuram/distilbert-mentalhealth-classifier"

    # Named module logger. The module-level ``logging.info`` helper implicitly
    # calls ``logging.basicConfig()`` when the root logger has no handlers,
    # which would silently reconfigure the host application's logging.
    _logger = logging.getLogger(__name__)

    def __init__(self, model: str = DEFAULT_MODEL):
        """Build the text-classification pipeline.

        Args:
            model: Hugging Face model id (or local path) handed to
                ``pipeline``. Defaults to ``DEFAULT_MODEL``.
        """
        self.classifier = pipeline("text-classification", model=model)

    def forward(self, text: str) -> tuple[str, float]:
        """Classify ``text`` and return its predicted label and confidence.

        Args:
            text: Input text to classify.

        Returns:
            A ``(label, score)`` pair taken from the first result produced
            by the pipeline.
        """
        # truncation=True: inputs longer than the model's maximum sequence
        # length can otherwise make the forward pass fail outright.
        output = self.classifier(text, truncation=True)[0]
        disorder = output["label"]
        confidence = output["score"]
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        self._logger.info("Disorder: %s, Confidence: %s", disorder, confidence)
        return disorder, confidence