"""Command-line entry point to train or evaluate a text classification model.

The run is driven by a JSON configuration file (``--config``, default
``config.json``). Its ``"mode"`` key selects ``"train"`` or ``"evaluate"``;
the whole configuration dict is also passed to ``ModelFactory.create_model``.
"""

import argparse
import asyncio
import json

from tasks.data.data_loaders import TextDataLoader
from tasks.models.text_classifiers import ModelFactory
from tasks.text import evaluate_text
from tasks.utils.evaluation import TextEvaluationRequest


def load_config(config_path):
    """Read and return the JSON configuration stored at *config_path*.

    Args:
        config_path: Path to a JSON file.

    Returns:
        The decoded JSON value (expected to be a dict for the callers below).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding so behaviour does not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)


async def train_model(config):
    """Build the model described by *config*, train it if needed, and save it.

    Training is skipped when ``model.model`` is already set. Presumably the
    factory loads previously saved weights when they exist (TODO confirm).
    In either case, ``model.save()`` is called afterwards.

    Args:
        config: Configuration dict forwarded to ``ModelFactory.create_model``.
    """
    # Loading data: the full (non-light) dataset is used for training.
    text_request = TextEvaluationRequest()
    is_light_dataset = False
    data_loader = TextDataLoader(text_request, light=is_light_dataset)

    # Define model.
    model = ModelFactory.create_model(config)

    # Train model only when no underlying model object is present yet.
    train_dataset = data_loader.get_train_dataset()
    if model.model is None:
        model.train(train_dataset)
    model.save()
    print("Model training completed and saved.")


async def evaluate_model(config):
    """Build the model described by *config*, evaluate it, and print results.

    Results from ``evaluate_text`` are printed as indented JSON, followed by
    the ``accuracy`` and ``energy_consumed_wh`` entries.

    Args:
        config: Configuration dict forwarded to ``ModelFactory.create_model``.
    """
    # Loading data.
    text_request = TextEvaluationRequest()
    # NOTE(review): ``data_loader`` is never used below. It is kept in case the
    # TextDataLoader constructor has required side effects (e.g. preparing
    # data). Confirm, and remove it if not.
    data_loader = TextDataLoader(text_request)  # noqa: F841

    # Define model.
    model = ModelFactory.create_model(config)

    # Run the evaluation.
    results = await evaluate_text(request=text_request, model=model)

    # Print the results.
    print(json.dumps(results, indent=2))
    print(f"Achieved accuracy: {results['accuracy']}")
    print(f"Energy consumed: {results['energy_consumed_wh']} Wh")


async def main():
    """Parse CLI arguments, load the configuration, and dispatch on its mode.

    Raises:
        ValueError: If the configuration has no ``"mode"`` key, or if the mode
            is neither ``"train"`` nor ``"evaluate"``.
    """
    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description="Train or evaluate the model.")
    parser.add_argument(
        "--config",
        type=str,
        default="config.json",
        help="Path to the configuration file",
    )
    args = parser.parse_args()

    # Load configuration.
    config_path = args.config
    config = load_config(config_path)

    # A missing dict key raises KeyError (not ValueError). Translate it into the
    # documented ValueError with a helpful message.
    try:
        mode = config["mode"]
    except KeyError:
        raise ValueError(
            f"Missing mode in configuration file: {config_path}"
        ) from None

    if mode == "train":
        await train_model(config)
    elif mode == "evaluate":
        await evaluate_model(config)
    else:
        raise ValueError(f"Invalid mode in file '{config_path}': '{mode}'")


if __name__ == "__main__":
    asyncio.run(main())