import json
import argparse
import asyncio

from tasks.data.data_loaders import TextDataLoader
from tasks.models.text_classifiers import ModelFactory
from tasks.text import evaluate_text
from tasks.utils.evaluation import TextEvaluationRequest

def load_config(config_path):
    """Load and return the JSON configuration stored at ``config_path``.

    Args:
        config_path: Path to a JSON configuration file.

    Returns:
        The decoded JSON value. Callers in this script index it by key
        (``config["mode"]``), so it is expected to be a JSON object.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON text is UTF-8 by specification (RFC 8259); pin the encoding so
    # decoding does not depend on the platform's locale default.
    with open(config_path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)

async def train_model(config):
    """Train the model described by ``config`` and save it.

    Builds a ``TextDataLoader`` over a fresh ``TextEvaluationRequest`` using
    the full (non-light) dataset. It then instantiates the model via
    ``ModelFactory.create_model(config)``. Training and saving only happen
    when ``model.model`` is ``None``; otherwise training is skipped and the
    training split is never loaded.

    Args:
        config: Parsed configuration passed through to ``ModelFactory``.
    """
    # Loading data: training uses the full (non-light) dataset.
    text_request = TextEvaluationRequest()
    is_light_dataset = False
    data_loader = TextDataLoader(text_request, light=is_light_dataset)

    # Define model.
    model = ModelFactory.create_model(config)

    # NOTE(review): presumably ``model.model`` is non-None when the factory
    # already loaded trained weights — confirm against ModelFactory.
    if model.model is not None:
        print("Model already trained; skipping training.")
        return

    # Fetch the training split only when we actually train, so an existing
    # model does not pay for loading data it never uses.
    train_dataset = data_loader.get_train_dataset()
    model.train(train_dataset)
    model.save()
    print("Model training completed and saved.")

async def evaluate_model(config):
    """Evaluate the configured model and print its metrics to stdout.

    A fresh ``TextEvaluationRequest`` is passed, together with the model
    built by ``ModelFactory.create_model(config)``, to ``evaluate_text``.
    The returned results are printed as indented JSON, followed by the
    ``accuracy`` and ``energy_consumed_wh`` entries.

    Args:
        config: Parsed configuration passed through to ``ModelFactory``.
    """
    request = TextEvaluationRequest()
    # NOTE(review): this loader is never referenced below. It is kept in case
    # its constructor has side effects that evaluation relies on — confirm.
    loader = TextDataLoader(request)

    classifier = ModelFactory.create_model(config)

    metrics = await evaluate_text(request=request, model=classifier)

    # Full dump first, then the two headline numbers.
    pretty = json.dumps(metrics, indent=2)
    print(pretty)
    print(f"Achieved accuracy: {metrics['accuracy']}")
    print(f"Energy consumed: {metrics['energy_consumed_wh']} Wh")

def read_mode(config, config_path):
    """Return ``config["mode"]``, raising ValueError if it cannot be read.

    Args:
        config: Decoded configuration (expected to be a JSON object/dict).
        config_path: Path the config came from; used only in the message.

    Returns:
        The value stored under the ``"mode"`` key.

    Raises:
        ValueError: If the key is absent or ``config`` is not subscriptable
            by a string key (e.g. the JSON document is a list or null).
    """
    try:
        return config["mode"]
    except (KeyError, TypeError) as err:
        # Subscripting raises KeyError (missing key) or TypeError (non-object
        # document), never ValueError, so those are what must be caught here.
        raise ValueError(
            f"Missing mode in configuration file: {config_path}"
        ) from err


async def main():
    """Parse CLI arguments, load the config, and run training or evaluation.

    Raises:
        ValueError: If the config has no ``mode`` or an unsupported one.
    """
    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description="Train or evaluate the model.")
    parser.add_argument("--config", type=str, default="config.json", help="Path to the configuration file")
    args = parser.parse_args()

    # Load configuration.
    config_path = args.config
    config = load_config(config_path)

    mode = read_mode(config, config_path)

    if mode == "train":
        await train_model(config)
    elif mode == "evaluate":
        await evaluate_model(config)
    else:
        raise ValueError(f"Invalid mode in file '{config_path}': '{mode}'")

if __name__ == "__main__":
    # Script entry point: build the top-level coroutine and drive it to
    # completion on a new event loop.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)