Upload pipeline.py
Browse files- pipeline.py +0 -99
pipeline.py
CHANGED
|
@@ -825,105 +825,6 @@ class FluxTransformer2DModelWithMasking(
|
|
| 825 |
|
| 826 |
return Transformer2DModelOutput(sample=output)
|
| 827 |
|
| 828 |
-
|
| 829 |
-
if __name__ == "__main__":
|
| 830 |
-
dtype = torch.bfloat16
|
| 831 |
-
bsz = 2
|
| 832 |
-
img = torch.rand((bsz, 16, 64, 64)).to("cuda", dtype=dtype)
|
| 833 |
-
timestep = torch.tensor([0.5, 0.5]).to("cuda", dtype=torch.float32)
|
| 834 |
-
pooled = torch.rand(bsz, 768).to("cuda", dtype=dtype)
|
| 835 |
-
text = torch.rand((bsz, 512, 4096)).to("cuda", dtype=dtype)
|
| 836 |
-
attn_mask = torch.tensor([[1.0] * 384 + [0.0] * 128] * bsz).to(
|
| 837 |
-
"cuda", dtype=dtype
|
| 838 |
-
) # Last 128 positions are masked
|
| 839 |
-
|
| 840 |
-
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 841 |
-
latents = latents.view(
|
| 842 |
-
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
|
| 843 |
-
)
|
| 844 |
-
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 845 |
-
latents = latents.reshape(
|
| 846 |
-
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
|
| 847 |
-
)
|
| 848 |
-
|
| 849 |
-
return latents
|
| 850 |
-
|
| 851 |
-
def _prepare_latent_image_ids(
|
| 852 |
-
batch_size, height, width, device="cuda", dtype=dtype
|
| 853 |
-
):
|
| 854 |
-
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
| 855 |
-
latent_image_ids[..., 1] = (
|
| 856 |
-
latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
| 857 |
-
)
|
| 858 |
-
latent_image_ids[..., 2] = (
|
| 859 |
-
latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
| 860 |
-
)
|
| 861 |
-
|
| 862 |
-
latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
|
| 863 |
-
latent_image_ids.shape
|
| 864 |
-
)
|
| 865 |
-
|
| 866 |
-
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
| 867 |
-
latent_image_ids = latent_image_ids.reshape(
|
| 868 |
-
batch_size,
|
| 869 |
-
latent_image_id_height * latent_image_id_width,
|
| 870 |
-
latent_image_id_channels,
|
| 871 |
-
)
|
| 872 |
-
|
| 873 |
-
return latent_image_ids.to(device=device, dtype=dtype)
|
| 874 |
-
|
| 875 |
-
txt_ids = torch.zeros(bsz, text.shape[1], 3).to(device="cuda", dtype=dtype)
|
| 876 |
-
|
| 877 |
-
vae_scale_factor = 16
|
| 878 |
-
height = 2 * (int(512) // vae_scale_factor)
|
| 879 |
-
width = 2 * (int(512) // vae_scale_factor)
|
| 880 |
-
img_ids = _prepare_latent_image_ids(bsz, height, width)
|
| 881 |
-
img = _pack_latents(img, img.shape[0], 16, height, width)
|
| 882 |
-
|
| 883 |
-
# Gotta go fast
|
| 884 |
-
transformer = FluxTransformer2DModelWithMasking.from_config(
|
| 885 |
-
{
|
| 886 |
-
"attention_head_dim": 128,
|
| 887 |
-
"guidance_embeds": True,
|
| 888 |
-
"in_channels": 64,
|
| 889 |
-
"joint_attention_dim": 4096,
|
| 890 |
-
"num_attention_heads": 24,
|
| 891 |
-
"num_layers": 4,
|
| 892 |
-
"num_single_layers": 8,
|
| 893 |
-
"patch_size": 1,
|
| 894 |
-
"pooled_projection_dim": 768,
|
| 895 |
-
}
|
| 896 |
-
).to("cuda", dtype=dtype)
|
| 897 |
-
|
| 898 |
-
guidance = torch.tensor([2.0], device="cuda")
|
| 899 |
-
guidance = guidance.expand(bsz)
|
| 900 |
-
|
| 901 |
-
with torch.no_grad():
|
| 902 |
-
no_mask = transformer(
|
| 903 |
-
img,
|
| 904 |
-
encoder_hidden_states=text,
|
| 905 |
-
pooled_projections=pooled,
|
| 906 |
-
timestep=timestep,
|
| 907 |
-
img_ids=img_ids,
|
| 908 |
-
txt_ids=txt_ids,
|
| 909 |
-
guidance=guidance,
|
| 910 |
-
)
|
| 911 |
-
mask = transformer(
|
| 912 |
-
img,
|
| 913 |
-
encoder_hidden_states=text,
|
| 914 |
-
pooled_projections=pooled,
|
| 915 |
-
timestep=timestep,
|
| 916 |
-
img_ids=img_ids,
|
| 917 |
-
txt_ids=txt_ids,
|
| 918 |
-
guidance=guidance,
|
| 919 |
-
attention_mask=attn_mask,
|
| 920 |
-
)
|
| 921 |
-
|
| 922 |
-
assert torch.allclose(no_mask.sample, mask.sample) is False
|
| 923 |
-
print("Attention masking test ran OK. Differences in output were detected.")
|
| 924 |
-
|
| 925 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 926 |
-
|
| 927 |
EXAMPLE_DOC_STRING = """
|
| 928 |
Examples:
|
| 929 |
```py
|
|
|
|
| 825 |
|
| 826 |
return Transformer2DModelOutput(sample=output)
|
| 827 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 828 |
EXAMPLE_DOC_STRING = """
|
| 829 |
Examples:
|
| 830 |
```py
|