Upload model
Browse files
model.py
CHANGED
|
@@ -65,6 +65,9 @@ def create_model_from_args(args) -> nn.Module:
|
|
| 65 |
elif args.input_size is not None:
|
| 66 |
in_chans = args.input_size[0]
|
| 67 |
|
|
|
|
|
|
|
|
|
|
| 68 |
model = create_model(
|
| 69 |
args.model,
|
| 70 |
pretrained=args.pretrained,
|
|
@@ -78,6 +81,7 @@ def create_model_from_args(args) -> nn.Module:
|
|
| 78 |
bn_eps=args.bn_eps,
|
| 79 |
scriptable=args.torchscript,
|
| 80 |
checkpoint_path=args.initial_checkpoint,
|
|
|
|
| 81 |
**args.model_kwargs,
|
| 82 |
)
|
| 83 |
|
|
|
|
| 65 |
elif args.input_size is not None:
|
| 66 |
in_chans = args.input_size[0]
|
| 67 |
|
| 68 |
+
# Skip weight initialization unless it's explicitly requested.
|
| 69 |
+
weight_init = args.model_kwargs.pop("weight_init", "skip")
|
| 70 |
+
|
| 71 |
model = create_model(
|
| 72 |
args.model,
|
| 73 |
pretrained=args.pretrained,
|
|
|
|
| 81 |
bn_eps=args.bn_eps,
|
| 82 |
scriptable=args.torchscript,
|
| 83 |
checkpoint_path=args.initial_checkpoint,
|
| 84 |
+
weight_init=weight_init,
|
| 85 |
**args.model_kwargs,
|
| 86 |
)
|
| 87 |
|