# sm-subgroup-classifier / modeling_sm_subgroup_classifier.py
# Uploaded by erikhenriksson using huggingface_hub (commit f71be9a, verified).
import os
import pickle

import joblib
import numpy as np
import torch
from transformers import PreTrainedModel

from .configuration_sm_subgroup_classifier import SmSubgroupClassifierConfig
class SmSubgroupClassifier(PreTrainedModel):
    """Route pre-computed embeddings to per-(language, model) classifiers.

    ``model_dir`` holds one sub-directory per classifier, named
    ``<language>_<model_name>`` (e.g. ``en_OP-ob``). Each contains
    ``model.pkl``, ``scaler.pkl`` and ``metadata.pkl``. Classifiers are loaded
    lazily on first use and cached on the instance. No PyTorch weights are
    loaded.
    """

    config_class = SmSubgroupClassifierConfig

    # Files every valid model sub-directory must contain.
    _REQUIRED_FILES = ("model.pkl", "scaler.pkl", "metadata.pkl")
    # ``from_pretrained`` kwargs that are also forwarded to
    # ``huggingface_hub.snapshot_download``, so the config and the model files
    # come from the same revision and cache.
    _SNAPSHOT_DOWNLOAD_KWARGS = ("revision", "cache_dir", "token", "local_files_only")

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Loaded classifiers, keyed by "<language>_<model_name>".
        self._loaded_classifiers = {}
        # Directory with one sub-directory per model; set by from_pretrained.
        self.model_dir = None
        # Lazily discovered, sorted names of valid model sub-directories.
        self._available_models = None

    @property
    def available_models(self):
        """Sorted names of the valid model sub-directories in ``model_dir``."""
        if self._available_models is None:
            models = self._discover_available_models()
            # Cache only once a model directory is known. Otherwise a lookup
            # made before ``model_dir`` is assigned would pin [] forever.
            if not self.model_dir:
                return models
            self._available_models = models
        return self._available_models

    def _discover_available_models(self):
        """Scan ``model_dir`` for sub-directories containing all required files.

        Returns:
            Sorted list of sub-directory names. Returns [] when ``model_dir``
            is unset or does not exist.
        """
        if not self.model_dir or not os.path.exists(self.model_dir):
            return []
        models = []
        for item in os.listdir(self.model_dir):
            item_path = os.path.join(self.model_dir, item)
            if os.path.isdir(item_path) and all(
                os.path.exists(os.path.join(item_path, name))
                for name in self._REQUIRED_FILES
            ):
                models.append(item)
        return sorted(models)

    def _load_classifier(self, model_key):
        """Load the classifier for ``model_key``, or return the cached one.

        Args:
            model_key: ``"<language>_<model_name>"``, e.g. ``"en_OP-ob"``.

        Returns:
            dict with keys ``"classifier"``, ``"scaler"`` and ``"class_names"``.

        Raises:
            ValueError: if ``model_key`` is not among ``available_models``.
            FileNotFoundError: if the model directory is missing on disk.
        """
        if model_key in self._loaded_classifiers:
            return self._loaded_classifiers[model_key]
        if model_key not in self.available_models:
            raise ValueError(
                f"Model '{model_key}' not available. Available: {self.available_models}"
            )
        classifier_path = os.path.join(self.model_dir, model_key)
        if not os.path.exists(classifier_path):
            raise FileNotFoundError(f"Classifier not found at {classifier_path}")

        # SECURITY NOTE: joblib/pickle deserialization can execute arbitrary
        # code. Only load model directories from sources you trust.
        classifier = joblib.load(os.path.join(classifier_path, "model.pkl"))
        scaler = joblib.load(os.path.join(classifier_path, "scaler.pkl"))
        with open(os.path.join(classifier_path, "metadata.pkl"), "rb") as f:
            metadata = pickle.load(f)

        classifier_info = {
            "classifier": classifier,
            "scaler": scaler,
            "class_names": metadata["class_names"],
        }
        self._loaded_classifiers[model_key] = classifier_info
        return classifier_info

    def forward(self, language, model_name, embeddings):
        """Classify pre-computed embeddings with the ``{language}_{model_name}`` model.

        Args:
            language: Language code (e.g. en, fi, sv).
            model_name: Model name (e.g. OP-ob, NA).
            embeddings: Pre-computed embeddings, given as a torch tensor, a
                NumPy array or a nested sequence. A 1-D input is treated as a
                single row.

        Returns:
            For a single row, a dict ``{"label": name, "probabilities":
            {class_name: prob}}``. For several rows, a list of such dicts.

        Raises:
            ValueError: if the requested model is not available.
        """
        model_key = f"{language}_{model_name}"

        if torch.is_tensor(embeddings):
            embeddings = embeddings.detach().cpu()
            # NumPy has no bfloat16, so .numpy() raises on such tensors.
            if embeddings.dtype == torch.bfloat16:
                embeddings = embeddings.float()
            embeddings = embeddings.numpy()
        # Also accept plain lists / sequences, which have no ``.ndim``.
        embeddings = np.asarray(embeddings)
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)

        classifier_info = self._load_classifier(model_key)
        classifier = classifier_info["classifier"]
        class_names = classifier_info["class_names"]

        embeddings_scaled = classifier_info["scaler"].transform(embeddings)
        predictions = classifier.predict(embeddings_scaled)
        probabilities = classifier.predict_proba(embeddings_scaled)

        # predict_proba columns follow ``classifier.classes_``, which is the
        # same label space that ``predict`` returns. Map through it rather
        # than assuming the labels are exactly 0..k-1. This is identical when
        # they are, and correct when they are not.
        column_labels = getattr(classifier, "classes_", None)
        if column_labels is None:
            column_labels = range(probabilities.shape[1])
        column_names = [class_names[label] for label in column_labels]

        results = [
            {
                "label": class_names[pred],
                "probabilities": {
                    name: float(prob) for name, prob in zip(column_names, probs)
                },
            }
            for pred, probs in zip(predictions, probabilities)
        ]
        # Kept for backward compatibility: a single row returns a bare dict.
        return results[0] if len(results) == 1 else results

    @staticmethod
    def _resolve_model_dir(pretrained_model_name_or_path, kwargs):
        """Return a local directory that holds the model sub-directories.

        A path to an existing local directory is used as-is. Anything else is
        treated as a Hub repo id and resolved with ``snapshot_download``,
        forwarding ``_SNAPSHOT_DOWNLOAD_KWARGS`` found in ``kwargs``.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            # A local checkout. snapshot_download would treat it as a repo id.
            return pretrained_model_name_or_path
        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            # Without huggingface_hub, the argument can only be a path.
            return pretrained_model_name_or_path
        download_kwargs = {
            key: kwargs[key]
            for key in SmSubgroupClassifier._SNAPSHOT_DOWNLOAD_KWARGS
            if kwargs.get(key) is not None
        }
        return snapshot_download(pretrained_model_name_or_path, **download_kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the model from a local directory or a Hugging Face Hub repo.

        No PyTorch weights are loaded. The model only records ``model_dir``,
        and the ``.pkl`` artifacts are loaded lazily by ``forward``.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            **kwargs: Forwarded to the config's ``from_pretrained``. A ready
                ``config`` instance (as passed by ``AutoModel``) is used
                directly. ``revision``, ``cache_dir``, ``token`` and
                ``local_files_only`` are also forwarded to
                ``snapshot_download``.

        Returns:
            A ``SmSubgroupClassifier`` with ``model_dir`` set.
        """
        config = kwargs.pop("config", None)
        if not isinstance(config, cls.config_class):
            config = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, **kwargs
            )

        model = cls(config)
        model.model_dir = cls._resolve_model_dir(pretrained_model_name_or_path, kwargs)
        # Drop any discovery result cached before model_dir was known.
        model._available_models = None
        return model