# submission/main.py
# Last commit by pierre-loic (42b7ac6): update content with the text model
# from Thomas's repository https://huggingface.co/spaces/tombou/frugal-ai-challenge
#
import json
import argparse
import asyncio
from tasks.data.data_loaders import TextDataLoader
from tasks.models.text_classifiers import ModelFactory
from tasks.text import evaluate_text
from tasks.utils.evaluation import TextEvaluationRequest
def load_config(config_path):
    """Load and return the JSON configuration stored at ``config_path``.

    Args:
        config_path: Path (str or path-like) of the JSON configuration file.

    Returns:
        The decoded JSON document (``main`` expects a dict with a ``"mode"`` key).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259); be explicit instead of relying on the
    # platform's locale-dependent default encoding.
    with open(config_path, encoding="utf-8") as config_file:
        return json.load(config_file)
async def train_model(config):
    """Train the configured text model and save it, unless it is already trained.

    Args:
        config: Configuration dict loaded from the JSON config file; passed
            unchanged to ``ModelFactory.create_model``.
    """
    # Training always uses the full (non-light) dataset.
    text_request = TextEvaluationRequest()
    is_light_dataset = False
    data_loader = TextDataLoader(text_request, light=is_light_dataset)

    model = ModelFactory.create_model(config)

    # NOTE(review): ``model.model is None`` appears to mean "the factory did not
    # load trained weights" — confirm against ModelFactory. When a trained model
    # is already present, skip fetching data, training and re-saving, and do
    # not claim that training happened.
    if model.model is not None:
        print("Model already trained; skipping training.")
        return

    train_dataset = data_loader.get_train_dataset()
    model.train(train_dataset)
    model.save()
    print("Model training completed and saved.")
async def evaluate_model(config):
    """Evaluate the configured model on the text task and print a report.

    Args:
        config: Configuration dict loaded from the JSON config file; forwarded
            to ``ModelFactory.create_model``.

    Prints the full results as indented JSON, followed by the ``accuracy``
    and ``energy_consumed_wh`` entries of the results mapping.
    """
    request = TextEvaluationRequest()
    # NOTE(review): this loader is never read below. It is kept only because
    # constructing it may have side effects (e.g. fetching data); confirm
    # before removing.
    _loader = TextDataLoader(request)

    classifier = ModelFactory.create_model(config)

    results = await evaluate_text(request=request, model=classifier)

    report = json.dumps(results, indent=2)
    print(report)
    accuracy = results['accuracy']
    print(f"Achieved accuracy: {accuracy}")
    energy_wh = results['energy_consumed_wh']
    print(f"Energy consumed: {energy_wh} Wh")
async def main(argv=None):
    """Command-line entry point: load the config and dispatch to train/evaluate.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. ``None`` (the default) makes argparse read ``sys.argv[1:]``.

    Raises:
        ValueError: If the configuration has no ``"mode"`` entry, or the mode
            is neither ``"train"`` nor ``"evaluate"``.
    """
    parser = argparse.ArgumentParser(description="Train or evaluate the model.")
    parser.add_argument("--config", type=str, default="config.json", help="Path to the configuration file")
    args = parser.parse_args(argv)

    config_path = args.config
    config = load_config(config_path)

    # Looking up a missing dict key raises KeyError, not ValueError, so
    # KeyError is what must be caught to report a missing "mode" entry.
    try:
        mode = config["mode"]
    except KeyError:
        raise ValueError(f"Missing mode in configuration file: {config_path}") from None

    if mode == "train":
        await train_model(config)
    elif mode == "evaluate":
        await evaluate_model(config)
    else:
        raise ValueError(f"Invalid mode in file '{config_path}': '{mode}'")


if __name__ == "__main__":
    asyncio.run(main())