Commit
·
0012f0c
1
Parent(s):
004d6aa
Refactor argument parsing and improve error handling in model loading
Browse files
Update .gitignore to exclude PNG files and remove unnecessary output images
- .gitignore +1 -0
- inference.py +2 -4
- models/modeling_dit.py +1 -1
.gitignore
CHANGED
|
@@ -1 +1,2 @@
|
|
| 1 |
*.pyc
|
|
|
|
|
|
| 1 |
*.pyc
|
| 2 |
+
*.png
|
inference.py
CHANGED
|
@@ -51,7 +51,6 @@ def main(args):
|
|
| 51 |
|
| 52 |
# Load model configuration and model
|
| 53 |
config = MMDiTConfig.from_json_file(args.model_config)
|
| 54 |
-
config.vae_type = args.vae_type # VAE overriding
|
| 55 |
config.height = args.resolution
|
| 56 |
config.width = args.resolution
|
| 57 |
|
|
@@ -135,11 +134,10 @@ if __name__ == "__main__":
|
|
| 135 |
# parser.add_argument("--slg", type=int, nargs="*", default=None, help="")
|
| 136 |
parser.add_argument("--steps", type=int, default=50, help="Number of steps for image generation")
|
| 137 |
parser.add_argument("--resolution", type=int, default=256, help="Resolution of output images")
|
| 138 |
-
parser.add_argument("--batch-size", type=int, default=32)
|
| 139 |
-
parser.add_argument("--streaming", action="store_true")
|
| 140 |
parser.add_argument("--noisy-pad", action="store_true")
|
| 141 |
parser.add_argument("--zero-masking", action="store_true")
|
| 142 |
-
parser.add_argument("--vae-type", type=str, default="SD3", help="Type of VAE")
|
| 143 |
parser.add_argument("--prompt-file", type=str, default="prompt_128.txt", help="Path to the prompt file")
|
| 144 |
parser.add_argument("--guidance-scales", type=float, nargs="*", default=None, help="List of guidance scales")
|
| 145 |
parser.add_argument("--output-dir", type=str, default="output", help="Base output directory for generated images")
|
|
|
|
| 51 |
|
| 52 |
# Load model configuration and model
|
| 53 |
config = MMDiTConfig.from_json_file(args.model_config)
|
|
|
|
| 54 |
config.height = args.resolution
|
| 55 |
config.width = args.resolution
|
| 56 |
|
|
|
|
| 134 |
# parser.add_argument("--slg", type=int, nargs="*", default=None, help="")
|
| 135 |
parser.add_argument("--steps", type=int, default=50, help="Number of steps for image generation")
|
| 136 |
parser.add_argument("--resolution", type=int, default=256, help="Resolution of output images")
|
| 137 |
+
parser.add_argument("--batch-size", type=int, default=32,help="Batch size for image generation")
|
| 138 |
+
parser.add_argument("--streaming", action="store_true", help="Enable streaming mode for intermediate steps")
|
| 139 |
parser.add_argument("--noisy-pad", action="store_true")
|
| 140 |
parser.add_argument("--zero-masking", action="store_true")
|
|
|
|
| 141 |
parser.add_argument("--prompt-file", type=str, default="prompt_128.txt", help="Path to the prompt file")
|
| 142 |
parser.add_argument("--guidance-scales", type=float, nargs="*", default=None, help="List of guidance scales")
|
| 143 |
parser.add_argument("--output-dir", type=str, default="output", help="Base output directory for generated images")
|
models/modeling_dit.py
CHANGED
|
@@ -13,7 +13,7 @@ try:
|
|
| 13 |
MotifRMSNorm = motif_ops.T5LayerNorm
|
| 14 |
ScaledDotProductAttention = None
|
| 15 |
MotifFlashAttention = motif_ops.flash_attention
|
| 16 |
-
except
|
| 17 |
MotifRMSNorm = None
|
| 18 |
ScaledDotProductAttention = None
|
| 19 |
MotifFlashAttention = None
|
|
|
|
| 13 |
MotifRMSNorm = motif_ops.T5LayerNorm
|
| 14 |
ScaledDotProductAttention = None
|
| 15 |
MotifFlashAttention = motif_ops.flash_attention
|
| 16 |
+
except Exception: # if motif_ops is not available
|
| 17 |
MotifRMSNorm = None
|
| 18 |
ScaledDotProductAttention = None
|
| 19 |
MotifFlashAttention = None
|