# Integration tests for the Flax UNet2DConditionModel (Flax-vs-PyTorch output parity).
import gc
import unittest

from parameterized import parameterized

from diffusers import FlaxUNet2DConditionModel
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow


if is_flax_available():
    import jax
    import jax.numpy as jnp
@slow
@require_flax
class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
    """Integration tests comparing Flax ``UNet2DConditionModel`` outputs against
    reference output slices recorded from the PyTorch implementation.

    NOTE(review): the ``test_*`` methods take ``(seed, timestep, expected_slice)``
    arguments, so they are clearly meant to be driven by
    ``@parameterized.expand([...])`` tables (``parameterized`` is imported at the
    top of this file but unused). The decorator/data tables appear to have been
    lost from this chunk -- restore them before running under unittest, otherwise
    the runner will call these methods without arguments and fail.
    """

    def get_file_format(self, seed, shape):
        """Return the Hub filename for the reference noise tensor of this seed/shape."""
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()

    def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
        """Load a deterministic latent tensor from the Hub for the given seed/shape."""
        # Flax "fp16" checkpoints ship as bfloat16, hence jnp.bfloat16 (not float16).
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        return jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)

    def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
        """Return ``(model, params)`` for *model_id*; ``fp16`` selects the "bf16" revision."""
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        revision = "bf16" if fp16 else None
        model, params = FlaxUNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", dtype=dtype, revision=revision
        )
        return model, params

    def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
        """Load deterministic encoder hidden states (text embeddings) from the Hub."""
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        return jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)

    def _assert_flax_matches_torch(
        self, model_id, seed, timestep, expected_slice, latents_shape, hidden_states_shape
    ):
        """Shared body of the per-checkpoint comparison tests.

        Runs one UNet forward pass on reference inputs and asserts that a fixed
        slice of the output matches *expected_slice* (recorded from the PyTorch
        float16 model) within tolerance.
        """
        model, params = self.get_unet_model(model_id=model_id, fp16=True)
        latents = self.get_latents(seed, shape=latents_shape, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=hidden_states_shape, fp16=True)

        sample = model.apply(
            {"params": params},
            latents,
            jnp.array(timestep, dtype=jnp.int32),
            encoder_hidden_states=encoder_hidden_states,
        ).sample

        assert sample.shape == latents.shape

        # Compare in float32 so bfloat16 rounding does not dominate the tolerance.
        output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
        expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)

        # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
        assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)

    def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
        self._assert_flax_matches_torch(
            "CompVis/stable-diffusion-v1-4",
            seed,
            timestep,
            expected_slice,
            latents_shape=(4, 4, 64, 64),
            hidden_states_shape=(4, 77, 768),
        )

    def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
        self._assert_flax_matches_torch(
            "stabilityai/stable-diffusion-2",
            seed,
            timestep,
            expected_slice,
            latents_shape=(4, 4, 96, 96),
            hidden_states_shape=(4, 77, 1024),
        )