feat: add load checkpoint
Browse files
train.py
CHANGED
|
@@ -14,6 +14,7 @@ torch.set_float32_matmul_precision("high")
|
|
| 14 |
parser = argparse.ArgumentParser()
|
| 15 |
parser.add_argument("-d", "--devices", nargs="*", type=int, default=[0])
|
| 16 |
parser.add_argument("-b", "--single-batch-size", type=int, default=64)
|
|
|
|
| 17 |
|
| 18 |
args = parser.parse_args()
|
| 19 |
|
|
@@ -88,5 +89,5 @@ detector = FontDetector(
|
|
| 88 |
num_iters=num_iters,
|
| 89 |
)
|
| 90 |
|
| 91 |
-
trainer.fit(detector, datamodule=data_module)
|
| 92 |
trainer.test(detector, datamodule=data_module)
|
|
|
|
| 14 |
parser = argparse.ArgumentParser()
|
| 15 |
parser.add_argument("-d", "--devices", nargs="*", type=int, default=[0])
|
| 16 |
parser.add_argument("-b", "--single-batch-size", type=int, default=64)
|
| 17 |
+
parser.add_argument("-c", "--checkpoint", type=str, default=None)
|
| 18 |
|
| 19 |
args = parser.parse_args()
|
| 20 |
|
|
|
|
| 89 |
num_iters=num_iters,
|
| 90 |
)
|
| 91 |
|
| 92 |
+
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
| 93 |
trainer.test(detector, datamodule=data_module)
|