from typing import List
import json
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionXLPipeline, DDIMScheduler

from mm_story_agent.modality_agents.llm import QwenAgent
from mm_story_agent.prompts_en import role_extract_system, role_review_system, \
    story_to_image_reviser_system, story_to_image_review_system
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
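

# Standard PyTorch 2.0 scaled-dot-product attention processor; used for every
# attention layer that does not get the consistent self-attention treatment.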
class AttnProcessor(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor requires PyTorch 2.0; please upgrade PyTorch to use it.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
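

# Build the random cross-image attention masks used by the consistent
# self-attention processors. For each frame, a token attends to a random
# subset of the tokens of the `id_length` identity frames (sampled with
# probability `sa32` at the 32x-downsampled resolution and `sa64` at the
# 16x-downsampled resolution), always attends to all tokens of its own frame,
# and never attends to tokens of other non-identity frames.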
def cal_attn_mask_xl(total_length,
                     id_length,
                     sa32,
                     sa64,
                     height,
                     width,
                     device="cuda",
                     dtype=torch.float16):
    nums_1024 = (height // 32) * (width // 32)
    nums_4096 = (height // 16) * (width // 16)
    bool_matrix1024 = torch.rand((1, total_length * nums_1024), device=device, dtype=dtype) < sa32
    bool_matrix4096 = torch.rand((1, total_length * nums_4096), device=device, dtype=dtype) < sa64
    bool_matrix1024 = bool_matrix1024.repeat(total_length, 1)
    bool_matrix4096 = bool_matrix4096.repeat(total_length, 1)
    for i in range(total_length):
        bool_matrix1024[i:i + 1, id_length * nums_1024:] = False
        bool_matrix4096[i:i + 1, id_length * nums_4096:] = False
        bool_matrix1024[i:i + 1, i * nums_1024:(i + 1) * nums_1024] = True
        bool_matrix4096[i:i + 1, i * nums_4096:(i + 1) * nums_4096] = True
    mask1024 = bool_matrix1024.unsqueeze(1).repeat(1, nums_1024, 1).reshape(-1, total_length * nums_1024)
    mask4096 = bool_matrix4096.unsqueeze(1).repeat(1, nums_4096, 1).reshape(-1, total_length * nums_4096)
    return mask1024, mask4096


class SpatialAttnProcessor2_0(torch.nn.Module):
    r"""
    Consistent (paired) self-attention processor for StoryDiffusion, built on
    PyTorch 2.0 scaled dot-product attention.

    In "write" mode it caches the hidden states of the identity images at each
    denoising step; in "read" mode it lets subsequent images attend to those
    cached identity tokens so that characters stay consistent across pages.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        id_length (`int`, defaults to 4):
            The number of identity images whose tokens are cached and shared.
        sa32 / sa64 (`float`, defaults to 0.5):
            Sampling probabilities for the random attention masks at the 32x-
            and 16x-downsampled resolutions (see `cal_attn_mask_xl`).
    """

    def __init__(self,
                 global_attn_args,
                 hidden_size=None,
                 cross_attention_dim=None,
                 id_length=4,
                 device="cuda",
                 dtype=torch.float16,
                 height=1280,
                 width=720,
                 sa32=0.5,
                 sa64=0.5,
                 ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("SpatialAttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")
        self.device = device
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.total_length = id_length + 1
        self.id_length = id_length
        self.id_bank = {}
        self.height = height
        self.width = width
        self.sa32 = sa32
        self.sa64 = sa64
        self.write = True
        self.global_attn_args = global_attn_args

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None
    ):
        total_count = self.global_attn_args["total_count"]
        attn_count = self.global_attn_args["attn_count"]
        cur_step = self.global_attn_args["cur_step"]
        mask1024 = self.global_attn_args["mask1024"]
        mask4096 = self.global_attn_args["mask4096"]
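        # Write phase (identity images): cache their hidden states for this
        # denoising step. Read phase (subsequent pages): concatenate the
        # cached identity tokens with the current hidden states so the page
        # being generated can attend to the identity images.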
        if self.write:
            self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]
        else:
            encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),
                                               hidden_states[:1],
                                               self.id_bank[cur_step][1].to(self.device),
                                               hidden_states[1:]))
        # skip consistent self-attention in the early denoising steps
        if cur_step < 5:
            hidden_states = self.__call2__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
        else:
            # afterwards, apply it stochastically: with probability 0.7 for
            # steps 5-19 and 0.9 from step 20 on; otherwise fall back to
            # vanilla attention
            random_number = random.random()
            if cur_step < 20:
                rand_num = 0.3
            else:
                rand_num = 0.1
            if random_number > rand_num:
                if not self.write:
                    if hidden_states.shape[1] == (self.height // 32) * (self.width // 32):
                        attention_mask = mask1024[mask1024.shape[0] // self.total_length * self.id_length:]
                    else:
                        attention_mask = mask4096[mask4096.shape[0] // self.total_length * self.id_length:]
                else:
                    if hidden_states.shape[1] == (self.height // 32) * (self.width // 32):
                        attention_mask = mask1024[:mask1024.shape[0] // self.total_length * self.id_length,
                                                  :mask1024.shape[0] // self.total_length * self.id_length]
                    else:
                        attention_mask = mask4096[:mask4096.shape[0] // self.total_length * self.id_length,
                                                  :mask4096.shape[0] // self.total_length * self.id_length]
                hidden_states = self.__call1__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
            else:
                hidden_states = self.__call2__(attn, hidden_states, None, attention_mask, temb)
        attn_count += 1
        if attn_count == total_count:
            attn_count = 0
            cur_step += 1
            mask1024, mask4096 = cal_attn_mask_xl(self.total_length,
                                                  self.id_length,
                                                  self.sa32,
                                                  self.sa64,
                                                  self.height,
                                                  self.width,
                                                  device=self.device,
                                                  dtype=self.dtype)
            self.global_attn_args["mask1024"] = mask1024
            self.global_attn_args["mask4096"] = mask4096
        self.global_attn_args["attn_count"] = attn_count
        self.global_attn_args["cur_step"] = cur_step
        return hidden_states
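
    # Cross-image attention: queries come from all frames at once, while keys
    # and values are drawn from the concatenated identity + current tokens,
    # restricted by the random masks from `cal_attn_mask_xl`.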
    def __call1__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            total_batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2)
        total_batch_size, nums_token, channel = hidden_states.shape
        img_nums = total_batch_size // 2
        hidden_states = hidden_states.view(-1, img_nums, nums_token, channel).reshape(-1, img_nums * nums_token, channel)
        batch_size, sequence_length, _ = hidden_states.shape
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # B, N, C
        else:
            encoder_hidden_states = encoder_hidden_states.view(-1, self.id_length + 1, nums_token, channel).reshape(
                -1, (self.id_length + 1) * nums_token, channel)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
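
    # Vanilla within-image attention (no cached identity tokens); used in the
    # early denoising steps and when consistent attention is randomly skipped.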
    def __call2__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, channel = hidden_states.shape
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # B, N, C
        else:
            encoder_hidden_states = encoder_hidden_states.view(-1, self.id_length + 1, sequence_length, channel).reshape(
                -1, (self.id_length + 1) * sequence_length, channel)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
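

# End-to-end SDXL synthesizer: loads the pipeline, swaps the self-attention
# processors of the up-blocks for SpatialAttnProcessor2_0, and exposes a
# `call()` that first generates the identity images (write phase) and then
# the remaining pages (read phase).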
class StoryDiffusionSynthesizer:

    def __init__(self,
                 num_pages: int,
                 height: int,
                 width: int,
                 device: str,
                 model_name: str = "stabilityai/stable-diffusion-xl-base-1.0",
                 model_path: str = None,
                 id_length: int = 4,
                 num_steps: int = 50):
        self.attn_args = {
            "attn_count": 0,
            "cur_step": 0,
            "total_count": 0,
        }
        self.sa32 = 0.5
        self.sa64 = 0.5
        self.id_length = id_length
        self.total_length = num_pages
        self.height = height
        self.width = width
        self.device = device
        self.dtype = torch.float16
        self.num_steps = num_steps
        # (positive prompt template, negative prompt suffix) per style
        self.styles = {
            '(No style)': (
                '{prompt}',
                ''),
            'Japanese Anime': (
                'anime artwork illustrating {prompt}. created by japanese anime studio. highly emotional. best quality, high resolution, (Anime Style, Manga Style:1.3), Low detail, sketch, concept art, line art, webtoon, manhua, hand drawn, defined lines, simple shades, minimalistic, High contrast, Linear compositions, Scalable artwork, Digital art, High Contrast Shadows',
                'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
            'Digital/Oil Painting': (
                '{prompt} . (Extremely Detailed Oil Painting:1.2), glow effects, godrays, Hand drawn, render, 8k, octane render, cinema 4d, blender, dark, atmospheric 4k ultra detailed, cinematic sensual, Sharp focus, humorous illustration, big depth of field',
                'anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
            'Pixar/Disney Character': (
                'Create a Disney Pixar 3D style illustration on {prompt} . The scene is vibrant, motivational, filled with vivid colors and a sense of wonder.',
                'lowres, bad anatomy, bad hands, text, bad eyes, bad arms, bad legs, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, blurry, grayscale, noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo'),
            'Photographic': (
                'cinematic photo {prompt} . Hyperrealistic, Hyperdetailed, detailed skin, matte skin, soft lighting, realistic, best quality, ultra realistic, 8k, golden ratio, Intricate, High Detail, film photography, soft focus',
                'drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
            'Comic book': (
                'comic {prompt} . graphic illustration, comic art, graphic novel art, vibrant, highly detailed',
                'photograph, deformed, glitch, noisy, realistic, stock photo, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
            'Line art': (
                'line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics',
                'anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
            'Black and White Film Noir': (
                '{prompt} . (b&w, Monochromatic, Film Photography:1.3), film noir, analog style, soft lighting, subsurface scattering, realistic, heavy shadow, masterpiece, best quality, ultra realistic, 8k',
                'anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
            'Isometric Rooms': (
                'Tiny cute isometric {prompt} . in a cutaway box, soft smooth lighting, soft colors, 100mm lens, 3d blender render',
                'anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
            'Storybook': (
                "Cartoon style, cute illustration of {prompt}.",
                'realism, photo, realistic, lowres, bad hands, bad eyes, bad arms, bad legs, error, missing fingers, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, grayscale, noisy, sloppy, messy, grainy, ultra textured'
            )
        }
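
        # Load the SDXL pipeline in half precision, enable FreeU, and switch
        # to a DDIM scheduler with the configured number of inference steps.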
        pipe = StableDiffusionXLPipeline.from_pretrained(
            model_path if model_path is not None else model_name,
            torch_dtype=torch.float16,
            use_safetensors=True
        )
        pipe = pipe.to(self.device)
        pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.scheduler.set_timesteps(num_steps)
        unet = pipe.unet
        attn_procs = {}
        # Insert paired (consistent) self-attention into the self-attention
        # layers of the up-blocks; every other layer keeps the default processor.
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None and name.startswith("up_blocks"):
                attn_procs[name] = SpatialAttnProcessor2_0(
                    id_length=self.id_length,
                    device=self.device,
                    height=self.height,
                    width=self.width,
                    sa32=self.sa32,
                    sa64=self.sa64,
                    global_attn_args=self.attn_args
                )
                self.attn_args["total_count"] += 1
            else:
                attn_procs[name] = AttnProcessor()
        print("successfully loaded consistent self-attention")
        print(f"number of processors: {self.attn_args['total_count']}")
        unet.set_attn_processor(attn_procs)
        mask1024, mask4096 = cal_attn_mask_xl(
            self.total_length,
            self.id_length,
            self.sa32,
            self.sa64,
            self.height,
            self.width,
            device=self.device,
            dtype=torch.float16,
        )
        self.attn_args.update({
            "mask1024": mask1024,
            "mask4096": mask4096
        })
        self.pipe = pipe
        self.negative_prompt = "naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, " \
                               "extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating " \
                               "limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation"
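
    # Toggle write mode (True: identity images are generated and cached) vs.
    # read mode (False: cached identity tokens are reused) on every consistent
    # self-attention processor.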
    def set_attn_write(self,
                       value: bool):
        unet = self.pipe.unet
        for name, processor in unet.attn_processors.items():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if cross_attention_dim is None:
                if name.startswith("up_blocks"):
                    assert isinstance(processor, SpatialAttnProcessor2_0)
                    processor.write = value

    def apply_style(self, style_name: str, positives: list, negative: str = ""):
        p, n = self.styles.get(style_name, self.styles["(No style)"])
        return [p.replace("{prompt}", positive) for positive in positives], n + ' ' + negative

    def apply_style_positive(self, style_name: str, positive: str):
        p, n = self.styles.get(style_name, self.styles["(No style)"])
        return p.replace("{prompt}", positive)

    def call(self,
             prompts: List[str],
             input_id_images=None,
             start_merge_step=None,
             style_name: str = "Pixar/Disney Character",
             guidance_scale: float = 5.0,
             seed: int = 2047):
        assert len(prompts) == self.total_length, "The number of prompts should be equal to the number of pages."
        setup_seed(seed)
        generator = torch.Generator(device=self.device).manual_seed(seed)
        torch.cuda.empty_cache()
        id_prompts = prompts[:self.id_length]
        real_prompts = prompts[self.id_length:]
        # write phase: generate the identity images in a single batch
        self.set_attn_write(True)
        self.attn_args.update({
            "cur_step": 0,
            "attn_count": 0
        })
        id_prompts, negative_prompt = self.apply_style(style_name, id_prompts, self.negative_prompt)
        id_images = self.pipe(
            id_prompts,
            input_id_images=input_id_images,
            start_merge_step=start_merge_step,
            num_inference_steps=self.num_steps,
            guidance_scale=guidance_scale,
            height=self.height,
            width=self.width,
            negative_prompt=negative_prompt,
            generator=generator).images
        # read phase: generate the remaining pages one at a time, attending to
        # the cached identity tokens
        self.set_attn_write(False)
        real_images = []
        for real_prompt in real_prompts:
            self.attn_args["cur_step"] = 0
            real_prompt = self.apply_style_positive(style_name, real_prompt)
            real_images.append(self.pipe(
                real_prompt,
                num_inference_steps=self.num_steps,
                guidance_scale=guidance_scale,
                height=self.height,
                width=self.width,
                negative_prompt=negative_prompt,
                generator=generator).images[0]
            )
        images = id_images + real_images
        return images
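

# LLM-driven agent: extracts character (role) descriptions from the story,
# rewrites each page into an image prompt, substitutes role descriptions into
# the prompts, then renders every page with StoryDiffusionSynthesizer.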
class StoryDiffusionAgent:

    def __init__(self, config, llm_type="qwen2") -> None:
        self.config = config
        if llm_type == "qwen2":
            self.LLM = QwenAgent

    def call(self, pages: List, device: str, save_path: str):
        save_path = Path(save_path)
        role_dict = self.extract_role_from_story(pages, **self.config["revise_cfg"])
        image_prompts = self.generate_image_prompt_from_story(pages, **self.config["revise_cfg"])
        image_prompts_with_role_desc = []
        for image_prompt in image_prompts:
            for role, role_desc in role_dict.items():
                if role in image_prompt:
                    image_prompt = image_prompt.replace(role, role_desc)
            image_prompts_with_role_desc.append(image_prompt)
        generation_agent = StoryDiffusionSynthesizer(
            num_pages=len(pages),
            device=device,
            **self.config["obj_cfg"]
        )
        images = generation_agent.call(
            image_prompts_with_role_desc,
            **self.config["call_cfg"]
        )
        for idx, image in enumerate(images):
            image.save(save_path / f"p{idx + 1}.png")
        return {
            "prompts": image_prompts_with_role_desc,
            "modality": "image",
            "generation_results": images,
        }
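
    # Iteratively extract role descriptions with one LLM and review them with
    # another, stopping early once the reviewer reports "Check passed."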
    def extract_role_from_story(
        self,
        pages: List,
        num_turns: int = 3
    ):
        role_extractor = self.LLM(role_extract_system, track_history=False)
        role_reviewer = self.LLM(role_review_system, track_history=False)
        roles = {}
        review = ""
        for turn in range(num_turns):
            roles, success = role_extractor.run(json.dumps({
                "story_content": pages,
                "previous_result": roles,
                "improvement_suggestions": review,
            }, ensure_ascii=False))
            # strip a possible markdown code fence from the LLM response
            roles = json.loads(roles.strip("```json").strip("```"))
            review, success = role_reviewer.run(json.dumps({
                "story_content": pages,
                "role_descriptions": roles
            }, ensure_ascii=False))
            if review == "Check passed.":
                break
        return roles
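
    # Same revise/review loop, but per page: rewrite each page into an image
    # description and have a reviewer LLM critique it.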
    def generate_image_prompt_from_story(
        self,
        pages: List,
        num_turns: int = 3
    ):
        image_prompt_rewriter = self.LLM(story_to_image_reviser_system, track_history=False)
        image_prompt_reviewer = self.LLM(story_to_image_review_system, track_history=False)
        image_prompts = []
        for page in pages:
            review = ""
            image_prompt = ""
            for turn in range(num_turns):
                image_prompt, success = image_prompt_rewriter.run(json.dumps({
                    "all_pages": pages,
                    "current_page": page,
                    "previous_result": image_prompt,
                    "improvement_suggestions": review,
                }, ensure_ascii=False))
                if image_prompt.startswith("Image description:"):
                    image_prompt = image_prompt[len("Image description:"):]
                review, success = image_prompt_reviewer.run(json.dumps({
                    "all_pages": pages,
                    "current_page": page,
                    "image_description": image_prompt
                }, ensure_ascii=False))
                if review == "Check passed.":
                    break
            image_prompts.append(image_prompt)
        return image_prompts
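

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# drives StoryDiffusionSynthesizer directly; the prompt list, resolution, and
# output directory below are assumptions, not values taken from the source.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pages = [
        "A little fox wakes up in a sunlit forest.",      # identity pages ...
        "The fox meets a wise old owl on a branch.",
        "The fox and the owl study a glowing map.",
        "They set off together along a mossy path.",
        "At dusk, the fox finds the hidden waterfall.",   # ... then a read-phase page
    ]
    synthesizer = StoryDiffusionSynthesizer(
        num_pages=len(pages),
        height=768,
        width=768,
        device="cuda",
        id_length=4,  # the first 4 prompts define the character identity
        num_steps=50,
    )
    images = synthesizer.call(pages, style_name="Storybook", seed=2047)
    out_dir = Path("output")
    out_dir.mkdir(parents=True, exist_ok=True)
    for idx, image in enumerate(images):
        image.save(out_dir / f"p{idx + 1}.png")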