File size: 1,493 Bytes
af7c068 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax
if is_flax_available():
import jax
@require_flax
class FlaxModelTesterMixin:
    """Shared forward-pass shape checks for Flax model test cases.

    The host test class must provide:

    * ``model_class`` -- the model type, instantiated as ``model_class(**init_dict)``.
    * ``prepare_init_args_and_inputs_for_common()`` -- returns ``(init_dict, inputs_dict)``
      where ``inputs_dict`` has at least the keys ``"prng_key"`` and ``"sample"``.
    * the assertion methods ``assertIsNotNone`` and ``assertEqual``.
    """

    def _check_output_shape_matches_input(self, init_dict, inputs_dict):
        """Build the model, run one forward pass and compare output/input shapes.

        Args:
            init_dict: Keyword arguments used to construct ``self.model_class``.
            inputs_dict: Mapping providing ``"prng_key"`` (passed to ``model.init``)
                and ``"sample"`` (passed to both ``model.init`` and ``model.apply``).

        Fails the test if the output is ``None`` or its shape differs from the
        shape of ``inputs_dict["sample"]``.
        """
        sample = inputs_dict["sample"]

        model = self.model_class(**init_dict)
        variables = model.init(inputs_dict["prng_key"], sample)
        # stop_gradient returns a new value instead of modifying its argument,
        # so the result has to be rebound for the call to have any effect.
        variables = jax.lax.stop_gradient(variables)

        output = model.apply(variables, sample)

        # Dict-like model outputs expose the result under the `sample` attribute.
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        self.assertEqual(output.shape, sample.shape, "Input and output shapes do not match")

    def test_output(self):
        """The model output must have the same shape as the input sample."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        self._check_output_shape_matches_input(init_dict, inputs_dict)

    def test_forward_with_norm_groups(self):
        """Same shape check with ``norm_num_groups`` and ``block_out_channels`` overridden."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        self._check_output_shape_matches_input(init_dict, inputs_dict)
|