File size: 3,313 Bytes
29958a2
1da736d
 
29958a2
 
0d7e4cb
8565ee2
29958a2
eeddd9f
 
8b6e253
29958a2
fad62be
29958a2
e6cf7d1
29958a2
b1eac09
 
29958a2
 
 
 
fad62be
859c26b
4184307
29958a2
 
 
 
eeddd9f
 
29958a2
 
d3fc59d
1da736d
8dd4d65
1da736d
ce08539
a1fff46
0d7e4cb
29958a2
eeddd9f
 
29958a2
 
 
 
eeddd9f
 
 
f55a34e
3cab9bb
1da736d
cf23f16
1da736d
eeddd9f
 
 
 
 
 
 
 
 
 
 
 
29958a2
 
eeddd9f
 
 
 
1da736d
 
c848fb3
29958a2
c848fb3
 
29958a2
858501a
 
 
 
 
 
 
 
 
 
 
 
 
bfd5e5b
858501a
 
 
 
52c932d
858501a
 
29958a2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
from diffusers import UniPCMultistepScheduler
import torch
torch.backends.cuda.matmul.allow_tf32 = True
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F
from flax.jax_utils import replicate
from flax.training.common_utils import shard
#from torchvision.transforms import v2 as T2
import cv2
import PIL
from PIL import Image
import numpy as np

# Target size handed to RandomCrop for the conditioning image.
output_res = (768, 768)

# Preprocessing applied to the user-supplied conditioning image:
#   1. random crop to `output_res`, symmetrically padding inputs that are too small,
#   2. conversion to a tensor,
#   3. normalization with mean 0.5 and std 0.5.
conditioning_image_transforms = T.Compose(
    transforms=[
        T.RandomCrop(
            size=output_res,
            pad_if_needed=True,
            padding_mode="symmetric",
        ),
        T.ToTensor(),
        T.Normalize(mean=[0.5], std=[0.5]),
    ],
)

# Load the ControlNet model and its weights from the local checkpoint directory.
# NOTE(review): `from_flax` is kept as in the original call; confirm it is an
# argument this `from_pretrained` actually honors.
cnet, cnet_params = FlaxControlNetModel.from_pretrained(
    "./models/catcon-controlnet-wd",
    dtype=jnp.bfloat16,
    from_flax=True,
)

# Load the base Stable Diffusion pipeline (and its parameter tree) with the
# ControlNet attached.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "./models/wd-1-5-b2",
    controlnet=cnet,
    revision="flax",
    dtype=jnp.bfloat16,
    from_pt=True,
)

# The scheduler loaded with the Flax pipeline is kept as-is. Flax pipelines
# drive the scheduler functionally, passing the scheduler state stored in
# `params["scheduler"]` to `set_timesteps`/`step`; the PyTorch
# `UniPCMultistepScheduler` does not implement that interface, so swapping it
# in makes every call to `pipe(...)` fail.

def get_random(seed):
    """Return a JAX PRNG key created from ``seed``.

    Args:
        seed: Integer seed passed to ``jax.random.PRNGKey``.

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    # The key must be returned; callers assign the result and use it as a key.
    return jax.random.PRNGKey(seed)

# inference function takes prompt, negative prompt and image
def infer(prompt, negative_prompt, image):
    """Run the ControlNet pipeline for one prompt pair and conditioning image.

    Args:
        prompt: Text prompt for generation.
        negative_prompt: Negative text prompt.
        image: Conditioning image in a form accepted by ``Image.fromarray``.

    Returns:
        The images returned by ``pipe.numpy_to_pil``, one per device.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("A conditioning image is required.")

    # The ControlNet weights are loaded separately from the pipeline weights
    # and have to be placed in the parameter tree the pipeline receives.
    params["controlnet"] = cnet_params

    # shard() splits the leading (batch) axis across devices, so the batch
    # size has to be a multiple of the device count; use one sample per device.
    num_devices = jax.device_count()
    num_samples = num_devices

    inp = Image.fromarray(image)

    # NOTE(review): the transform normalizes the tensor before it is converted
    # back to a PIL image here - confirm this round trip yields the intended
    # pixel values for the pipeline's own image preprocessing.
    cond_input = conditioning_image_transforms(inp)
    cond_input = T.ToPILImage()(cond_input)

    # With jit=True every batched input must carry a leading device axis.
    cond_img_in = shard(pipe.prepare_image_inputs([cond_input] * num_samples))
    prompt_in = shard(pipe.prepare_text_inputs([prompt] * num_samples))
    n_prompt_in = shard(pipe.prepare_text_inputs([negative_prompt] * num_samples))

    # One PRNG key per device, derived from a fixed seed.
    rng = jax.random.split(jax.random.PRNGKey(0), num_devices)

    # Copy the parameters to every device.
    p_params = replicate(params)

    output = pipe(
        prompt_ids=prompt_in,
        image=cond_img_in,
        params=p_params,
        prng_seed=rng,
        neg_prompt_ids=n_prompt_in,
        num_inference_steps=20,
        jit=True,
    ).images

    # Collapse the device axis back into a flat batch before PIL conversion.
    flat_output = output.reshape((num_samples,) + output.shape[-3:])
    output_images = pipe.numpy_to_pil(np.asarray(flat_output))
    return output_images

# Input widgets, created in the order the inference function expects them:
# prompt, negative prompt, conditioning image.
prompt_box = gr.Textbox(
    label="Enter prompt",
    max_lines=1,
    placeholder="1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck",
)
negative_prompt_box = gr.Textbox(
    label="Enter negative prompt",
    max_lines=1,
    placeholder="low quality",
)
cond_image_input = gr.Image()

# Output gallery for the generated images.
result_gallery = gr.Gallery().style(grid=[2], height="auto")

# One pre-filled example row: prompt, negative prompt, conditioning image path.
example_rows = [
    [
        "1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck",
        "low quality",
        "wikipe_cond_1.png",
    ],
]

demo = gr.Interface(
    fn=infer,
    inputs=[prompt_box, negative_prompt_box, cond_image_input],
    outputs=result_gallery,
    title="Generate controlled outputs with Categorical Conditioning on Waifu Diffusion 1.5 beta 2.",
    description="This Space uses image examples as style conditioning.",
    examples=example_rows,
    allow_flagging=False,
)

# Start the app with request queueing enabled.
demo.launch(enable_queue=True)