Update pipeline.py
Browse files- pipeline.py +22 -7
pipeline.py
CHANGED
|
@@ -7,7 +7,7 @@ from huggingface_hub import hf_hub_download
|
|
| 7 |
from safetensors.torch import load_file
|
| 8 |
import os
|
| 9 |
from .vae import AutoencoderKL
|
| 10 |
-
from .mar import
|
| 11 |
|
| 12 |
# inheriting from DiffusionPipeline for HF
|
| 13 |
class MARModel(DiffusionPipeline):
|
|
@@ -33,12 +33,27 @@ class MARModel(DiffusionPipeline):
|
|
| 33 |
model_type = kwargs.get("model_type", "mar_base")
|
| 34 |
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
# download and load the model weights (.safetensors or .pth)
|
| 43 |
model_checkpoint_path = hf_hub_download(
|
| 44 |
repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
|
|
|
|
| 7 |
from safetensors.torch import load_file
|
| 8 |
import os
|
| 9 |
from .vae import AutoencoderKL
|
| 10 |
+
from .mar import mar_base, mar_large, mar_huge
|
| 11 |
|
| 12 |
# inheriting from DiffusionPipeline for HF
|
| 13 |
class MARModel(DiffusionPipeline):
|
|
|
|
| 33 |
model_type = kwargs.get("model_type", "mar_base")
|
| 34 |
|
| 35 |
|
| 36 |
+
if model_type == "mar_base":
|
| 37 |
+
self.model = mar_base(
|
| 38 |
+
buffer_size=buffer_size,
|
| 39 |
+
diffloss_d=diffloss_d,
|
| 40 |
+
diffloss_w=diffloss_w,
|
| 41 |
+
num_sampling_steps=str(num_sampling_steps)
|
| 42 |
+
).to(device)
|
| 43 |
+
elif model_type == "mar_large":
|
| 44 |
+
self.model = mar_large(
|
| 45 |
+
buffer_size=buffer_size,
|
| 46 |
+
diffloss_d=diffloss_d,
|
| 47 |
+
diffloss_w=diffloss_w,
|
| 48 |
+
num_sampling_steps=str(num_sampling_steps)
|
| 49 |
+
).to(device)
|
| 50 |
+
elif model_type == "mar_huge":
|
| 51 |
+
self.model = mar_huge(
|
| 52 |
+
buffer_size=buffer_size,
|
| 53 |
+
diffloss_d=diffloss_d,
|
| 54 |
+
diffloss_w=diffloss_w,
|
| 55 |
+
num_sampling_steps=str(num_sampling_steps)
|
| 56 |
+
).to(device)
|
| 57 |
# download and load the model weights (.safetensors or .pth)
|
| 58 |
model_checkpoint_path = hf_hub_download(
|
| 59 |
repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
|