refactor/motifimage (#2)
opened by beomgyu-kim
checkpoints/checkpoint.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6eda25e5e3a73a1363eaf4aa98219af13237509f79195e6c4b37c0d1dd6b0d89
-size 23666150551
inference.py
CHANGED
@@ -8,7 +8,7 @@ import numpy as np
 import torch
 import pickle
 from configs.configuration_mmdit import MMDiTConfig
-from models.modeling_motif_vision import MotifVision
+from models.modeling_motifimage import MotifImage
 
 from safetensors.torch import load_file
 # from tools.motif_api import PromptRewriter
@@ -54,7 +54,7 @@ def main(args):
     config.height = args.resolution
     config.width = args.resolution
 
-    model = MotifVision(config)
+    model = MotifImage(config)
 
     # Load checkpoint
     try:
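Both hunks are a pure rename: only the import and the constructor call change, and the checkpoint-loading flow around them in main() stays the same. For orientation, here is a minimal sketch of how the renamed class would be instantiated and loaded; the checkpoint path, the fixed resolution, and the strict=False flag are illustrative assumptions and are not part of this diff.

from safetensors.torch import load_file

from configs.configuration_mmdit import MMDiTConfig
from models.modeling_motifimage import MotifImage

# Build the model from its config, as main() does after this change.
config = MMDiTConfig()   # constructor arguments are not shown in the diff
config.height = 1024     # main() actually uses args.resolution
config.width = 1024

model = MotifImage(config)

# "checkpoints/model.safetensors" is a hypothetical path for illustration.
# strict=False is an assumption: MotifImage.state_dict() excludes the frozen
# text encoders and VAE, so those keys are absent from saved checkpoints.
state_dict = load_file("checkpoints/model.safetensors")
model.load_state_dict(state_dict, strict=False)
model.eval()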
models/{modeling_motif_vision.py → modeling_motifimage.py}
RENAMED
@@ -28,9 +28,9 @@ def generate_intervals(steps, ratio, start=1.0):
     return intervals
 
 
-class MotifVision(nn.Module, FlowMixin):
+class MotifImage(nn.Module, FlowMixin):
     """
-    MotifVision Text-to-Image model.
+    MotifImage Text-to-Image model.
 
     This model combines a Diffusion transformer with a rectified flow loss and multiple text encoders.
     It uses a VAE (Variational Autoencoder) for image encoding and decoding.
@@ -128,7 +128,7 @@ class MotifVision(nn.Module, FlowMixin):
         self.text_encoders = [self.t5, self.clip_l, self.clip_g]
 
     def state_dict(self, destination=None, prefix="", keep_vars=False):
-        state_dict = super(MotifVision, self).state_dict(destination, prefix, keep_vars)
+        state_dict = super(MotifImage, self).state_dict(destination, prefix, keep_vars)
         exclude_keys = ["t5.", "clip_l.", "clip_g.", "vae."]
         for key in list(state_dict.keys()):
             if any(key.startswith(exclude_key) for exclude_key in exclude_keys):
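The state_dict override above follows a common pattern for models that wrap frozen pretrained components: serialize the full module tree with super().state_dict(), then drop the keys belonging to the text encoders and VAE so that saved checkpoints carry only the trainable diffusion weights. A self-contained sketch of that pattern, using placeholder submodules rather than the real MotifImage components:

import torch.nn as nn

class ExampleModel(nn.Module):
    # Illustrative stand-in mirroring the key-filtering idiom in MotifImage.state_dict().
    def __init__(self):
        super().__init__()
        self.t5 = nn.Linear(8, 8)           # placeholder for a frozen text encoder
        self.vae = nn.Linear(8, 8)          # placeholder for a frozen VAE
        self.transformer = nn.Linear(8, 8)  # placeholder for the trainable backbone

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        # Drop frozen submodules so checkpoints contain only trainable weights.
        exclude_keys = ["t5.", "vae."]
        for key in list(state_dict.keys()):
            if any(key.startswith(exclude_key) for exclude_key in exclude_keys):
                del state_dict[key]
        return state_dict

model = ExampleModel()
print(sorted(model.state_dict().keys()))  # prints only the transformer.* keys

Because the excluded weights never reach the checkpoint, loading one back requires strict=False or re-attaching the pretrained encoders separately, which is why the inference sketch earlier assumes that flag.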