Spaces:
Sleeping
Sleeping

Update content with the text model from Thomas's repository: https://huggingface.co/spaces/tombou/frugal-ai-challenge
42b7ac6
import random | |
import numpy as np | |
import pytest | |
from main import load_config | |
from tasks.data.data_loaders import TextDataLoader | |
from tasks.models.text_classifiers import DistilBERTModel, ModelFactory, TextEmbedder, MLModel, EmbeddingMLModel, \ | |
TfIdfEmbedder | |
from tasks.utils.evaluation import TextEvaluationRequest | |
@pytest.fixture
def data_loader():
    """Provide a ``TextDataLoader`` built from a default ``TextEvaluationRequest``.

    The loader is created with ``light=True``; presumably this selects a
    reduced dataset so the tests run quickly — confirm against
    ``TextDataLoader``.
    """
    # define text request (constructed with no arguments, i.e. its defaults)
    text_request = TextEvaluationRequest()
    return TextDataLoader(text_request, light=True)
@pytest.fixture
def train_dataset(data_loader):
    """Provide the split returned by ``data_loader.get_train_dataset()``."""
    return data_loader.get_train_dataset()
@pytest.fixture
def test_dataset(data_loader):
    """Provide the split returned by ``data_loader.get_test_dataset()``.

    The ``@pytest.fixture`` decorator matters beyond injection here: without
    it, pytest would collect this ``test_``-prefixed function as a test
    instead of treating it as a fixture.
    """
    return data_loader.get_test_dataset()
class TestDistilBERTModel:
    """Tests for the DistilBERT classifier created through ``ModelFactory``."""

    @pytest.fixture
    def distilBERT_model(self):
        """Build the model described by ``config_training_test.json``."""
        config = load_config("config_training_test.json")
        return ModelFactory.create_model(config)

    def test_trained_distilBERT(self, train_dataset, distilBERT_model, test_dataset):
        """Train on the training split, then check every test prediction is in ``range(8)``."""
        assert "DistilBERT" in distilBERT_model.description

        # train model
        distilBERT_model.train(train_dataset)

        # inference: one prediction per quote of the test split
        predictions = [distilBERT_model.predict(quote) for quote in test_dataset["quote"]]
        for prediction in predictions:
            # range(8) presumably covers the label classes — confirm against the dataset
            assert prediction in range(8)

    def test_data_preprocessing(self, train_dataset, distilBERT_model):
        """Check ``pre_process_data`` returns train/test subsets with the expected columns."""
        pre_processed_data = distilBERT_model.pre_process_data(train_dataset)
        assert pre_processed_data is not None
        # Row counts depend on the light training set; presumably an 8/2 split of it — confirm
        assert pre_processed_data["train"].num_rows == 8
        assert pre_processed_data["test"].num_rows == 2
        # Both subsets keep the raw columns and gain input_ids / attention_mask
        for subset in ["train", "test"]:
            for feature_name in ['quote', 'label', 'input_ids', 'attention_mask']:
                assert feature_name in pre_processed_data[subset].features.keys()
class DummyEmbedder(TextEmbedder):
    """Stand-in ``TextEmbedder`` that ignores its input and returns random noise."""

    def encode(self, text: str) -> np.ndarray:
        """Return a fresh 42-element array of uniform samples in [0, 1).

        ``text`` is not used; the output only has to look like an embedding.
        """
        vector_length = 42
        return np.random.random_sample(size=vector_length)
class DummyMLModel(MLModel):
    """Stand-in ``MLModel`` that learns nothing and guesses a random label."""

    def fit(self, X, y):
        """Accept training data and discard it."""
        return None

    def predict(self, X):
        """Return one integer drawn uniformly from 0..7, ignoring ``X``."""
        candidate_labels = range(8)
        return random.choice(candidate_labels)
class TestEmbeddingMLModel:
    """Tests for ``EmbeddingMLModel`` (an embedder paired with a classic ML model)."""

    @pytest.fixture
    def embeddingML(self):
        """Build an ``EmbeddingMLModel`` from ``config_training_embedding_test.json``.

        The ``"model"`` key is overwritten so the factory always builds an
        ``EmbeddingMLModel`` regardless of what the file specifies.
        """
        config = load_config("config_training_embedding_test.json")
        config["model"] = "EmbeddingMLModel"
        return ModelFactory.create_model(config)

    def test_EmbeddingML(self, train_dataset, embeddingML):
        """Train the configured model and check a single prediction is in ``range(8)``."""
        assert "EmbeddingMLModel" in embeddingML.description

        # train model
        embeddingML.train(train_dataset)

        # inference
        assert embeddingML.predict("a quote") in range(8)

    def test_dummy_train_EmbeddingML(self, train_dataset):
        """Exercise the train/predict plumbing with dummy embedder and model."""
        dummy_model = EmbeddingMLModel(embedder=DummyEmbedder(),
                                       ml_model=DummyMLModel())
        dummy_model.train(train_dataset)
        assert dummy_model.predict("dummy") in range(8)
class TestEmbedders:
    """Tests for concrete ``TextEmbedder`` implementations."""

    def test_tf_idf(self):
        """Encoding five short texts with ``TfIdfEmbedder`` yields a (5, 11) array.

        11 is presumably the vocabulary size after tokenization (e.g. the
        one-letter word "a" being dropped) — confirm against ``TfIdfEmbedder``.
        """
        embedder = TfIdfEmbedder()
        corpus = [
            "hello world",
            "world hello",
            "yet another text",
            "this is a test",
            "this one as well",
        ]
        expected_shape = (len(corpus), 11)
        assert embedder.encode(corpus).shape == expected_shape