|
|
import os
import pickle

import joblib
import numpy as np
import torch
from transformers import PreTrainedModel

from .configuration_sm_subgroup_classifier import SmSubgroupClassifierConfig
|
|
|
|
|
|
|
|
class SmSubgroupClassifier(PreTrainedModel):
    """Dispatch pre-computed embeddings to a per-(language, model) classifier.

    Classifiers live on disk under ``model_dir``: each subdirectory named
    ``<language>_<model_name>`` that contains ``model.pkl``, ``scaler.pkl``
    and ``metadata.pkl`` is one selectable classifier. Classifiers are loaded
    lazily on first use and kept in memory afterwards.
    """

    config_class = SmSubgroupClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # model_key -> {"classifier", "scaler", "class_names"}
        self._loaded_classifiers = {}
        # Directory holding the per-model subdirectories; set by from_pretrained.
        self.model_dir = None
        # Cached result of _discover_available_models().
        self._available_models = None

    @property
    def available_models(self):
        """Sorted list of model keys found in ``model_dir`` (auto-discovered)."""
        if self._available_models is None:
            models = self._discover_available_models()
            if not self.model_dir:
                # Don't cache the empty result: model_dir may be assigned later
                # and a cached [] would otherwise hide every model forever.
                return models
            self._available_models = models
        return self._available_models

    def _discover_available_models(self):
        """Scan ``model_dir`` and return the sorted names of complete model subdirectories.

        A subdirectory counts only if it contains all of ``model.pkl``,
        ``scaler.pkl`` and ``metadata.pkl``. Returns ``[]`` when ``model_dir``
        is unset or is not a directory.
        """
        # isdir (not exists) so a stray file path doesn't reach os.listdir.
        if not self.model_dir or not os.path.isdir(self.model_dir):
            return []

        required_files = ("model.pkl", "scaler.pkl", "metadata.pkl")
        return sorted(
            item
            for item in os.listdir(self.model_dir)
            if os.path.isdir(os.path.join(self.model_dir, item))
            and all(
                os.path.exists(os.path.join(self.model_dir, item, name))
                for name in required_files
            )
        )

    def _load_classifier(self, model_key):
        """Load (or fetch from cache) the classifier for ``model_key``, e.g. ``'en_OP-ob'``.

        Returns:
            dict with keys ``"classifier"``, ``"scaler"`` and ``"class_names"``
            (the latter taken from ``metadata["class_names"]``).

        Raises:
            ValueError: ``model_key`` is not among :attr:`available_models`.
            FileNotFoundError: the model subdirectory has disappeared since discovery.
        """
        if model_key in self._loaded_classifiers:
            return self._loaded_classifiers[model_key]

        if model_key not in self.available_models:
            raise ValueError(
                f"Model '{model_key}' not available. Available: {self.available_models}"
            )

        classifier_path = os.path.join(self.model_dir, model_key)
        # available_models is cached, so re-check that the directory still exists.
        if not os.path.exists(classifier_path):
            raise FileNotFoundError(f"Classifier not found at {classifier_path}")

        # SECURITY NOTE(review): joblib.load / pickle.load execute arbitrary code
        # embedded in the files. Only load repositories you trust.
        classifier = joblib.load(os.path.join(classifier_path, "model.pkl"))
        scaler = joblib.load(os.path.join(classifier_path, "scaler.pkl"))
        with open(os.path.join(classifier_path, "metadata.pkl"), "rb") as f:
            metadata = pickle.load(f)

        classifier_info = {
            "classifier": classifier,
            "scaler": scaler,
            "class_names": metadata["class_names"],
        }
        self._loaded_classifiers[model_key] = classifier_info
        return classifier_info

    def forward(self, language, model_name, embeddings):
        """Classify pre-computed embeddings with the ``<language>_<model_name>`` classifier.

        Args:
            language: Language code (en, fi, sv).
            model_name: Model name (OP-ob, NA, etc.).
            embeddings: Pre-computed embeddings as a torch tensor, NumPy array or
                nested sequence; a single 1-D vector is treated as one sample.

        Returns:
            For a single sample, a dict ``{"label": str, "probabilities": {name: float}}``.
            For several samples, a list of such dicts.
        """
        model_key = f"{language}_{model_name}"

        if torch.is_tensor(embeddings):
            embeddings = embeddings.detach().cpu()
            # NumPy has no bfloat16 dtype, so .numpy() would raise; upcast first.
            if embeddings.dtype == torch.bfloat16:
                embeddings = embeddings.float()
            embeddings = embeddings.numpy()
        else:
            # Lists/tuples have no .ndim; ndarrays pass through unchanged.
            embeddings = np.asarray(embeddings)

        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)

        classifier_info = self._load_classifier(model_key)
        classifier = classifier_info["classifier"]
        class_names = classifier_info["class_names"]

        embeddings_scaled = classifier_info["scaler"].transform(embeddings)
        predictions = classifier.predict(embeddings_scaled)
        probabilities = classifier.predict_proba(embeddings_scaled)

        # predict_proba columns are ordered like the estimator's ``classes_``
        # (scikit-learn convention - assumed here, falls back to positional
        # indices when the attribute is absent). Using the label rather than
        # the column index keeps names right when classes_ isn't 0..k-1.
        column_labels = getattr(classifier, "classes_", None)

        results = []
        for pred, probs in zip(predictions, probabilities):
            labels = column_labels if column_labels is not None else range(len(probs))
            all_probs = {
                class_names[label]: float(prob) for label, prob in zip(labels, probs)
            }
            results.append(
                {
                    "label": class_names[pred],
                    "probabilities": all_probs,
                }
            )

        return results[0] if len(results) == 1 else results

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the model from a local directory or a Hugging Face Hub repo id.

        All ``kwargs`` go to the config loader. Hub-related ones (``revision``,
        ``cache_dir``, ``token``, ``local_files_only``, ``force_download``,
        ``proxies``) are also forwarded to ``snapshot_download``, so the config
        and the classifier files come from the same revision.
        """
        config = SmSubgroupClassifierConfig.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            # Local checkout: use it as-is. snapshot_download would reject
            # a filesystem path as an invalid / unknown repo id.
            model.model_dir = pretrained_model_name_or_path
            return model

        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            # No hub client available: treat the argument as a path.
            model.model_dir = pretrained_model_name_or_path
            return model

        hub_keys = (
            "revision",
            "cache_dir",
            "token",
            "local_files_only",
            "force_download",
            "proxies",
        )
        hub_kwargs = {key: kwargs[key] for key in hub_keys if key in kwargs}
        model.model_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
        return model
|
|
|