Spaces:
Runtime error
Runtime error
half only if cuda is available
Browse files- climategan_wrapper.py +7 -4
climategan_wrapper.py
CHANGED
|
@@ -15,6 +15,8 @@ from skimage.transform import resize
|
|
| 15 |
|
| 16 |
from climategan.trainer import Trainer
|
| 17 |
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def concat_events(output_dict, events, i=None, axis=1):
|
| 20 |
"""
|
|
@@ -136,7 +138,8 @@ class ClimateGAN:
|
|
| 136 |
inference=True,
|
| 137 |
new_exp=None,
|
| 138 |
)
|
| 139 |
-
|
|
|
|
| 140 |
|
| 141 |
def _setup_stable_diffusion(self):
|
| 142 |
"""
|
|
@@ -150,8 +153,8 @@ class ClimateGAN:
|
|
| 150 |
try:
|
| 151 |
self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
| 152 |
"runwayml/stable-diffusion-inpainting",
|
| 153 |
-
revision="fp16",
|
| 154 |
-
torch_dtype=torch.float16,
|
| 155 |
safety_checker=None,
|
| 156 |
use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
|
| 157 |
).to(self.trainer.device)
|
|
@@ -308,7 +311,7 @@ class ClimateGAN:
|
|
| 308 |
images,
|
| 309 |
numpy=True,
|
| 310 |
bin_value=0.5,
|
| 311 |
-
half=
|
| 312 |
ignore_event=ignore_event,
|
| 313 |
return_masks=True,
|
| 314 |
)
|
|
|
|
| 15 |
|
| 16 |
from climategan.trainer import Trainer
|
| 17 |
|
| 18 |
+
CUDA = torch.cuda.is_available()
|
| 19 |
+
|
| 20 |
|
| 21 |
def concat_events(output_dict, events, i=None, axis=1):
|
| 22 |
"""
|
|
|
|
| 138 |
inference=True,
|
| 139 |
new_exp=None,
|
| 140 |
)
|
| 141 |
+
if CUDA:
|
| 142 |
+
self.trainer.G.half()
|
| 143 |
|
| 144 |
def _setup_stable_diffusion(self):
|
| 145 |
"""
|
|
|
|
| 153 |
try:
|
| 154 |
self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
| 155 |
"runwayml/stable-diffusion-inpainting",
|
| 156 |
+
revision="fp16" if CUDA else "main",
|
| 157 |
+
torch_dtype=torch.float16 if CUDA else torch.float32,
|
| 158 |
safety_checker=None,
|
| 159 |
use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
|
| 160 |
).to(self.trainer.device)
|
|
|
|
| 311 |
images,
|
| 312 |
numpy=True,
|
| 313 |
bin_value=0.5,
|
| 314 |
+
half=CUDA,
|
| 315 |
ignore_event=ignore_event,
|
| 316 |
return_masks=True,
|
| 317 |
)
|