submission / test_text_classifiers.py
pierre-loic's picture
Update content with the text model from Thomas's repository: https://huggingface.co/spaces/tombou/frugal-ai-challenge
42b7ac6
import random
import numpy as np
import pytest
from main import load_config
from tasks.data.data_loaders import TextDataLoader
from tasks.models.text_classifiers import DistilBERTModel, ModelFactory, TextEmbedder, MLModel, EmbeddingMLModel, \
TfIdfEmbedder
from tasks.utils.evaluation import TextEvaluationRequest
@pytest.fixture()
def data_loader():
    """Provide a ``TextDataLoader`` built from a default request with ``light=True``."""
    request = TextEvaluationRequest()
    loader = TextDataLoader(request, light=True)
    return loader
@pytest.fixture()
def train_dataset(data_loader):
    """Provide whatever the ``data_loader`` fixture returns from ``get_train_dataset()``."""
    dataset = data_loader.get_train_dataset()
    return dataset
@pytest.fixture()
def test_dataset(data_loader):
    """Provide whatever the ``data_loader`` fixture returns from ``get_test_dataset()``."""
    dataset = data_loader.get_test_dataset()
    return dataset
class TestDistilBERTModel:
    """Tests for the model that ``ModelFactory`` builds from ``config_training_test.json``."""

    @pytest.fixture()
    def distilBERT_model(self):
        """Build the model under test from the training test configuration."""
        return ModelFactory.create_model(load_config("config_training_test.json"))

    def test_trained_distilBERT(self, train_dataset, distilBERT_model, test_dataset):
        """Train the model, then check each test-quote prediction lies in ``range(8)``."""
        assert "DistilBERT" in distilBERT_model.description

        # Training step.
        distilBERT_model.train(train_dataset)

        # Inference step: collect every prediction first, then validate each one.
        quotes = test_dataset["quote"]
        predicted_labels = [distilBERT_model.predict(text) for text in quotes]
        allowed_labels = range(8)
        for label in predicted_labels:
            assert label in allowed_labels

    def test_data_preprocessing(self, train_dataset, distilBERT_model):
        """Check split sizes and feature names of the pre-processed data."""
        processed = distilBERT_model.pre_process_data(train_dataset)
        assert processed is not None

        assert processed["train"].num_rows == 8
        assert processed["test"].num_rows == 2

        expected_features = ('quote', 'label', 'input_ids', 'attention_mask')
        for split_name in ("train", "test"):
            available = processed[split_name].features.keys()
            for feature in expected_features:
                assert feature in available
class DummyEmbedder(TextEmbedder):
    """Test double for ``TextEmbedder`` that returns random vectors."""

    def encode(self, text: str) -> np.ndarray:
        """Return a new random array of length 42; ``text`` is not used."""
        random_vector = np.random.rand(42)
        return random_vector
class DummyMLModel(MLModel):
    """Test double for ``MLModel``: fitting does nothing, predictions are random."""

    def fit(self, X, y):
        """Accept training data and do nothing with it."""
        return None

    def predict(self, X):
        """Return a label picked at random from ``range(8)``; ``X`` is not used."""
        candidate_labels = range(8)
        return random.choice(candidate_labels)
class TestEmbeddingMLModel:
    """Tests for ``EmbeddingMLModel``, built via the factory and assembled by hand."""

    @pytest.fixture()
    def embeddingML(self):
        """Load the embedding test config, set its model name, and build the model."""
        settings = load_config("config_training_embedding_test.json")
        settings["model"] = "EmbeddingMLModel"
        model = ModelFactory.create_model(settings)
        return model

    def test_EmbeddingML(self, train_dataset, embeddingML):
        """Train the factory-built model and check one prediction is in ``range(8)``."""
        assert "EmbeddingMLModel" in embeddingML.description

        # Training step.
        embeddingML.train(train_dataset)

        # Inference step.
        predicted_label = embeddingML.predict("a quote")
        assert predicted_label in range(8)

    def test_dummy_train_EmbeddingML(self, train_dataset):
        """Train a model wired with the dummy embedder and dummy ML model, then predict once."""
        embedder = DummyEmbedder()
        classifier = DummyMLModel()
        model = EmbeddingMLModel(embedder=embedder, ml_model=classifier)

        model.train(train_dataset)

        predicted_label = model.predict("dummy")
        assert predicted_label in range(8)
class TestEmbedders:
    """Tests for standalone text embedders."""

    def test_tf_idf(self):
        """Encode five short texts with ``TfIdfEmbedder`` and check the output shape."""
        corpus = [
            "hello world",
            "world hello",
            "yet another text",
            "this is a test",
            "this one as well",
        ]
        tf_idf = TfIdfEmbedder()
        encoded = tf_idf.encode(corpus)

        expected_shape = (5, 11)
        assert encoded.shape == expected_shape