
Update content with the text model from Thomas's repository: https://huggingface.co/spaces/tombou/frugal-ai-challenge
42b7ac6
import json
import argparse
import asyncio

from tasks.data.data_loaders import TextDataLoader
from tasks.models.text_classifiers import ModelFactory
from tasks.text import evaluate_text
from tasks.utils.evaluation import TextEvaluationRequest
def load_config(config_path):
    with open(config_path, 'r') as config_file:
        config = json.load(config_file)
    return config
async def train_model(config):
    # Load the training data
    text_request = TextEvaluationRequest()
    is_light_dataset = False
    data_loader = TextDataLoader(text_request, light=is_light_dataset)

    # Build the classifier from the configuration
    model = ModelFactory.create_model(config)

    # Train and save the model only if no pretrained weights were loaded
    train_dataset = data_loader.get_train_dataset()
    if model.model is None:
        model.train(train_dataset)
        model.save()
    print("Model training completed and saved.")
async def evaluate_model(config):
    # Load the evaluation data
    text_request = TextEvaluationRequest()
    data_loader = TextDataLoader(text_request)

    # Build the classifier from the configuration
    model = ModelFactory.create_model(config)

    # Run the evaluation pipeline
    results = await evaluate_text(request=text_request, model=model)

    # Report the results
    print(json.dumps(results, indent=2))
    print(f"Achieved accuracy: {results['accuracy']}")
    print(f"Energy consumed: {results['energy_consumed_wh']} Wh")
async def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Train or evaluate the model.")
    parser.add_argument("--config", type=str, default="config.json", help="Path to the configuration file")
    args = parser.parse_args()

    # Load configuration
    config_path = args.config
    config = load_config(config_path)

    # A missing "mode" key raises KeyError, not ValueError
    try:
        mode = config["mode"]
    except KeyError:
        raise ValueError(f"Missing mode in configuration file: {config_path}")

    if mode == "train":
        await train_model(config)
    elif mode == "evaluate":
        await evaluate_model(config)
    else:
        raise ValueError(f"Invalid mode in file '{config_path}': '{mode}'")


if __name__ == "__main__":
    asyncio.run(main())
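
A minimal config.json sketch, assuming only what this script itself reads: the "mode" key. Any model-specific keys consumed by ModelFactory.create_model are defined elsewhere in the repository and are not shown here.

{
  "mode": "train"
}

Typical invocation (the entry-point filename main.py is an assumption):

python main.py --config config.json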