|
import logging |
|
import traceback |
|
|
|
from finetrainers import Trainer, parse_arguments |
|
from finetrainers.constants import FINETRAINERS_LOG_LEVEL |
|
|
|
|
|
# Logger registered under the "finetrainers" name; its level is taken from the
# FINETRAINERS_LOG_LEVEL constant imported from finetrainers.constants.
logger = logging.getLogger("finetrainers")
logger.setLevel(FINETRAINERS_LOG_LEVEL)
|
|
|
|
|
def _use_fork_start_method():
    """Request the "fork" multiprocessing start method.

    Any exception raised while importing ``multiprocessing`` or while setting
    the start method is logged at error level and otherwise ignored, so the
    caller always continues.
    """
    try:
        import multiprocessing

        multiprocessing.set_start_method("fork")
    except Exception as e:
        message = (
            f'Failed to set multiprocessing start method to "fork". This can lead to poor performance, high memory usage, or crashes. '
            f"See: https://pytorch.org/docs/stable/notes/multiprocessing.html\n"
            f"Error: {e}"
        )
        logger.error(message)


def main():
    """Run the training script end to end.

    First requests the "fork" multiprocessing start method (best effort), then
    parses the arguments, builds a ``Trainer`` from them, runs its preparation
    steps in order and finally calls ``train()``.

    A ``KeyboardInterrupt`` is reported at info level. Any other ``Exception``
    is logged at error level, followed by its formatted traceback, and is not
    re-raised.
    """
    _use_fork_start_method()

    try:
        trainer = Trainer(parse_arguments())

        # Preparation steps, invoked in this fixed order before training.
        trainer.prepare_dataset()
        trainer.prepare_models()
        trainer.prepare_precomputations()
        trainer.prepare_trainable_parameters()
        trainer.prepare_optimizer()
        trainer.prepare_for_training()
        trainer.prepare_trackers()

        trainer.train()
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt. Exiting...")
    except Exception as e:
        # Summary line first, then the full traceback as a separate record.
        logger.error(f"An error occurred during training: {e}")
        logger.error(traceback.format_exc())
|
|
|
|
|
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|