import os |
import sys |
import subprocess |
import io |
import torch.nn as nn |
from torch.nn import functional as F |
import torch |
import torchvision.transforms.functional as TF |
import torchvision.transforms as T |
import math |
import requests |
import cv2 |
from resize_right import resize |
from guided_diffusion.guided_diffusion.script_util import model_and_diffusion_defaults |
from types import SimpleNamespace |
from PIL import Image |
import argparse |
from guided_diffusion.guided_diffusion.unet import HFUNetModel |
from tqdm.notebook import tqdm |
from datetime import datetime |
from guided_diffusion.guided_diffusion.script_util import create_model_and_diffusion |
import clip |
from transformers import BertForSequenceClassification, BertTokenizer |
import gc |
import random |
# Directory containing this script; output and model folders are created beneath it.
PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))

# ---- Runtime / model selection -------------------------------------------
useCPU = False                    # when True, never use CUDA even if available
skip_augs = False                 # when True, cutouts are not randomly augmented
perlin_init = False
use_secondary_model = False
diffusion_model = "custom"        # only printed after the diffusion model loads
side_x = 512
side_y = 512
diffusion_sampling_mode = 'ddim'  # 'ddim'; any other value selects PLMS sampling
use_checkpoint = True             # forwarded to model_config['use_checkpoint']

# ---- CLIP models ----------------------------------------------------------
# Diffuser.model_setup reads only the four OpenAI ViT flags below (ViTB32,
# ViTB16, ViTL14, ViTL14_336px); the remaining flags are not referenced by the
# code that follows.
ViTB32 = False
ViTB16 = False
ViTL14 = True
ViTL14_336px = False
RN101 = False
RN50 = False
RN50x4 = False
RN50x16 = False
RN50x64 = False
ViTB32_laion2b_e16 = False
ViTB32_laion400m_e31 = False
ViTB32_laion400m_32 = False
ViTB32quickgelu_laion400m_e31 = False
ViTB32quickgelu_laion400m_e32 = False
ViTB16_laion400m_e31 = False
ViTB16_laion400m_e32 = False
RN50_yffcc15m = False
RN50_cc12m = False
RN50_quickgelu_yfcc15m = False
RN50_quickgelu_cc12m = False
RN101_yfcc15m = False
RN101_quickgelu_yfcc15m = False

# ---- Sampling / guidance ----------------------------------------------------
steps = 100
tv_scale = 0
range_scale = 150
sat_scale = 0
cutn_batches = 1
# (skip_augs is configured once, above.)
intermediate_saves = 0
intermediates_in_subfolder = True
perlin_mode = 'mixed'
set_seed = 'random_seed'          # 'random_seed' or an integer-like string
eta = 0.8
clamp_grad = True
clamp_max = 0.05
randomize_class = True
clip_denoised = False
fuzzy_prompt = False
rand_mag = 0.05

# Per-timestep cutout schedules: Python expressions that evaluate to
# 1000-element lists (e.g. [12]*400 + [4]*600), indexed by (1000 - t).
cut_overview = "[12]*400+[4]*600"
cut_innercut = "[4]*400+[12]*600"
cut_ic_pow = "[1]*1000"
cut_icgray_p = "[0.2]*400+[0]*600"

use_vertical_symmetry = False
use_horizontal_symmetry = False
transformation_percent = [0.09]
display_rate = 3
n_batches = 1
check_model_SHA = False
interp_spline = 'Linear'
resume_run = False
batch_size = 1
def createPath(filepath):
    """Ensure the directory *filepath* exists, creating missing parents.

    Calling it on an existing directory is a no-op.
    """
    target = filepath
    os.makedirs(target, mode=0o777, exist_ok=True)
def wget(url, outputdir):
    """Download *url* into directory *outputdir* with the ``wget`` CLI.

    The command's captured stdout is printed; its exit status is not checked.
    """
    command = ['wget', url, '-P', f'{outputdir}']
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    output_text = completed.stdout.decode('utf-8')
    print(output_text)
def alpha_sigma_to_t(alpha, sigma):
    """Map (alpha, sigma) noise-schedule values to a timestep in [0, 1].

    The angle ``atan2(sigma, alpha)`` is rescaled from [0, pi/2] to [0, 1].
    """
    angle = torch.atan2(sigma, alpha)
    return angle * 2 / math.pi
def interp(t):
    """Cubic smoothstep ``3t^2 - 2t^3``; works on numbers and tensors."""
    squared = t ** 2
    cubed = t ** 3
    return 3 * squared - 2 * cubed
def perlin(width, height, scale=10, device=None):
    """Single-octave 2-D gradient noise.

    Random gradients are drawn on a (width + 1) x (height + 1) lattice and each
    lattice cell is sampled at ``scale`` x ``scale`` points, so the result is a
    ``(width * scale, height * scale)`` tensor.
    """
    grad_x, grad_y = torch.randn(2, width + 1, height + 1, 1, 1, device=device)
    ticks = torch.linspace(0, 1, scale + 1)
    u = ticks[:-1, None].to(device)
    v = ticks[None, :-1].to(device)
    fade_u = 1 - interp(u)
    fade_v = 1 - interp(v)
    # Blend the contributions of the four corner gradients of every cell.
    noise = fade_u * fade_v * (grad_x[:-1, :-1] * u + grad_y[:-1, :-1] * v)
    noise = noise + (1 - fade_u) * fade_v * (-grad_x[1:, :-1] * (1 - u) + grad_y[1:, :-1] * v)
    noise = noise + fade_u * (1 - fade_v) * (grad_x[:-1, 1:] * u - grad_y[:-1, 1:] * (1 - v))
    noise = noise + (1 - fade_u) * (1 - fade_v) * (-grad_x[1:, 1:] * (1 - u) - grad_y[1:, 1:] * (1 - v))
    # (W, H, scale, scale) -> (W, scale, H, scale) -> flattened 2-D noise map.
    return noise.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)
def perlin_ms(octaves, width, height, grayscale, device=None):
    """Multi-octave noise: one plane per channel, concatenated along dim 0.

    Each plane starts at 0.5 and accumulates ``perlin`` noise weighted by the
    successive entries of *octaves*. Each octave doubles the lattice size and
    halves the per-cell scale, so every octave has the same output size.
    """
    channel_count = 1 if grayscale else 3
    planes = [0.5] * channel_count
    for channel in range(channel_count):
        cell_scale = 2 ** len(octaves)
        grid_w, grid_h = width, height
        for amplitude in octaves:
            noise = perlin(grid_w, grid_h, cell_scale, device)
            planes[channel] += noise * amplitude
            cell_scale //= 2
            grid_w *= 2
            grid_h *= 2
    return torch.cat(planes)
def fetch(url_or_path):
    """Return a binary file-like object for a local path or an http(s) URL.

    URLs are downloaded fully into memory; ``requests`` HTTP errors are raised.
    Local paths are opened in ``'rb'`` mode, and the caller must close them.
    """
    location = str(url_or_path)
    if not location.startswith(('http://', 'https://')):
        return open(url_or_path, 'rb')
    response = requests.get(url_or_path)
    response.raise_for_status()
    return io.BytesIO(response.content)
def read_image_workaround(path):
    """Read *path* with OpenCV and return the pixels in RGB channel order.

    OpenCV decodes images as BGR while Pillow works in RGB. Swapping the
    channels here avoids colour inversions when the two libraries are mixed.
    """
    bgr_pixels = cv2.imread(path)
    rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    return rgb_pixels
def parse_prompt(prompt):
    """Split ``"text:weight"`` into ``(text, float(weight))``.

    The weight defaults to 1 when absent. For http(s) URLs, the scheme's colon
    is kept as part of the text, and only a trailing ``:weight`` is split off.
    """
    if prompt.startswith(('http://', 'https://')):
        pieces = prompt.rsplit(':', 2)
        pieces = [pieces[0] + ':' + pieces[1], *pieces[2:]]
    else:
        pieces = prompt.rsplit(':', 1)
    defaults = ['', '1']
    pieces = pieces + defaults[len(pieces):]
    text, weight = pieces[0], pieces[1]
    return text, float(weight)
def sinc(x):
    """Normalised sinc, ``sin(pi x) / (pi x)``, with ``sinc(0) == 1``."""
    scaled = math.pi * x
    ones = x.new_ones([])
    return torch.where(x != 0, torch.sin(scaled) / scaled, ones)
def lanczos(x, a):
    """Lanczos window of order *a* evaluated at *x*, normalised to sum to 1.

    Samples outside the open interval (-a, a) are set to zero.
    """
    inside = torch.logical_and(x > -a, x < a)
    window = sinc(x) * sinc(x / a)
    kernel = torch.where(inside, window, x.new_zeros([]))
    return kernel / kernel.sum()
def ramp(ratio, width):
    """Symmetric sample positions spaced by *ratio* that cover roughly (-width, width).

    The positions are built by repeated addition (0, ratio, 2*ratio, ...),
    mirrored around 0, and stripped of the two outermost samples. The result
    has the default floating dtype.
    """
    count = math.ceil(width / ratio + 1)
    positions = []
    running = 0
    while len(positions) < count:
        positions.append(running)
        running += ratio
    half = torch.tensor(positions, dtype=torch.get_default_dtype())
    mirrored = torch.cat([-half[1:].flip([0]), half])
    return mirrored[1:-1]
def resample(input, size, align_corners=True):
    """Resize an (N, C, H, W) batch to ``size`` = (H', W'), anti-aliasing on shrink.

    Every axis that gets smaller is first low-pass filtered with a Lanczos-2
    kernel. Reflect padding of (k - 1) // 2 keeps the spatial size unchanged
    for the odd-length kernels ``ramp`` produces. Bicubic interpolation then
    produces the target size.
    """
    batch, channels, height, width = input.shape
    target_h, target_w = size
    # Fold channels into the batch so one single-channel conv filters every plane.
    planes = input.reshape([batch * channels, 1, height, width])
    if target_h < height:
        kernel = lanczos(ramp(target_h / height, 2), 2).to(planes.device, planes.dtype)
        half = (kernel.shape[0] - 1) // 2
        planes = F.conv2d(F.pad(planes, (0, 0, half, half), 'reflect'), kernel[None, None, :, None])
    if target_w < width:
        kernel = lanczos(ramp(target_w / width, 2), 2).to(planes.device, planes.dtype)
        half = (kernel.shape[0] - 1) // 2
        planes = F.conv2d(F.pad(planes, (half, half, 0, 0), 'reflect'), kernel[None, None, None, :])
    restored = planes.reshape([batch, channels, height, width])
    return F.interpolate(restored, size, mode='bicubic', align_corners=align_corners)
class MakeCutouts(nn.Module):
    """Random square cutouts of an image batch, each resampled to ``cut_size``.

    The input is zero-padded on every side by a quarter of its height. For
    cutout indices up to ``cutn - cutn // 4``, a random square is cropped. Its
    side is the shorter padded side times a normal(0.8, 0.3) sample, clipped
    to ``[cut_size / shorter, 1]``. Higher indices use the whole padded image.
    Unless ``skip_augs`` is set, each cutout is randomly augmented before
    resampling.
    """

    def __init__(self, cut_size, cutn, skip_augs=False):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.skip_augs = skip_augs

        def add_noise(img):
            # Small Gaussian jitter interleaved with the other augmentations.
            return img + torch.randn_like(img) * 0.01

        self.augs = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.Lambda(add_noise),
            T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
            T.Lambda(add_noise),
            T.RandomPerspective(distortion_scale=0.4, p=0.7),
            T.Lambda(add_noise),
            T.RandomGrayscale(p=0.15),
            T.Lambda(add_noise),
        ])

    def forward(self, input):
        padded = T.Pad(input.shape[2] // 4, fill=0)(input)
        side_y, side_x = padded.shape[2:4]
        shorter = min(side_x, side_y)
        full_view_after = self.cutn - self.cutn // 4
        crops = []
        for index in range(self.cutn):
            if index <= full_view_after:
                fraction = torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size / shorter), 1.)
                side = int(shorter * fraction)
                left = torch.randint(0, abs(side_x - side + 1), ())
                top = torch.randint(0, abs(side_y - side + 1), ())
                crop = padded[:, :, top:top + side, left:left + side]
            else:
                crop = padded.clone()
            if not self.skip_augs:
                crop = self.augs(crop)
            crops.append(resample(crop, (self.cut_size, self.cut_size)))
        return torch.cat(crops, dim=0)
class MakeCutoutsDango(nn.Module):
    """Build the batch of CLIP-sized cutouts used for guidance.

    ``forward`` returns ``Overview + InnerCrop`` cutouts concatenated along dim 0:

    * ``Overview`` views of the whole image, zero-padded towards a square. For
      values 1-4 these are, in order: plain, grayscale, horizontally flipped,
      flipped grayscale. Above 4, the plain view is repeated ``Overview`` times.
    * ``InnerCrop`` random square crops whose side lies between
      ``min(side, cut_size)`` and the shorter image side, skewed by
      ``IC_Size_Pow``. The first ``int(IC_Grey_P * InnerCrop)`` crops are
      converted to grayscale.

    Args:
        cut_size: side length of every output cutout (the CLIP input resolution).
        args: run configuration. Its ``skip_augs`` attribute, when true,
            disables the random augmentations. If the attribute is absent, the
            module-level ``skip_augs`` flag is used.
        Overview, InnerCrop, IC_Size_Pow, IC_Grey_P: see above.
    """

    def __init__(self, cut_size, args,
                 Overview=4,
                 InnerCrop=0, IC_Size_Pow=0.5, IC_Grey_P=0.2,
                 ):
        super().__init__()
        self.padargs = {}
        self.cutout_debug = False
        self.cut_size = cut_size
        self.Overview = Overview
        self.InnerCrop = InnerCrop
        self.IC_Size_Pow = IC_Size_Pow
        self.IC_Grey_P = IC_Grey_P
        # Honour the configuration actually passed in rather than a module global.
        self.skip_augs = getattr(args, 'skip_augs', skip_augs)
        self.augs = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation=T.InterpolationMode.BILINEAR),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomGrayscale(p=0.1),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        ])

    def forward(self, input):
        cutouts = []
        gray = T.Grayscale(3)
        sideY, sideX = input.shape[2:4]
        # NOTE: despite its name, max_size is the *shorter* side (largest square that fits).
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        # NOTE(review): batch 1 and 3 channels are hard-coded; confirm before using batch_size > 1.
        output_shape = [1, 3, self.cut_size, self.cut_size]
        # F.pad pads the last dim first: (left, right, top, bottom). Width grows by
        # (sideY - max_size) // 2 per side and height by (sideX - max_size) // 2,
        # so whichever side is shorter is padded towards a square.
        pad_w = (sideY - max_size) // 2
        pad_h = (sideX - max_size) // 2
        pad_input = F.pad(input, (pad_w, pad_w, pad_h, pad_h), **self.padargs)
        overview = resize(pad_input, out_shape=output_shape)

        if self.Overview > 4:
            cutouts.extend([overview] * self.Overview)
        elif self.Overview >= 1:
            cutouts.append(overview)
            if self.Overview >= 2:
                cutouts.append(gray(overview))
            if self.Overview >= 3:
                cutouts.append(TF.hflip(overview))
            if self.Overview == 4:
                cutouts.append(gray(TF.hflip(overview)))
        if self.cutout_debug and cutouts:
            TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save("cutout_overview0.jpg", quality=99)

        if self.InnerCrop > 0:
            grey_count = int(self.IC_Grey_P * self.InnerCrop)
            for i in range(self.InnerCrop):
                size = int(torch.rand([]) ** self.IC_Size_Pow * (max_size - min_size) + min_size)
                offsetx = torch.randint(0, sideX - size + 1, ())
                offsety = torch.randint(0, sideY - size + 1, ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
                # Strict '<' so exactly grey_count crops are grayscale (IC_Grey_P == 0 -> none).
                if i < grey_count:
                    cutout = gray(cutout)
                cutouts.append(resize(cutout, out_shape=output_shape))
            if self.cutout_debug:
                TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save("cutout_InnerCrop.jpg", quality=99)

        cutouts = torch.cat(cutouts)
        if not self.skip_augs:
            cutouts = self.augs(cutouts)
        return cutouts
def spherical_dist_loss(x, y):
    """Squared great-circle distance (times 2) between the directions of x and y.

    Both inputs are L2-normalised along the last dim. The chord length ``c``
    is then converted with ``2 * arcsin(c / 2) ** 2``.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    chord = (x_unit - y_unit).norm(dim=-1)
    return chord.div(2).arcsin().pow(2).mul(2)
def tv_loss(input):
    """L2 total variation loss (Mahendran et al.), one value per batch item.

    The image is replicate-padded by one pixel on the right and bottom, so the
    border differences are zero. Squared horizontal and vertical differences
    are averaged over (C, H, W).
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    anchor = padded[..., :-1, :-1]
    dx = padded[..., :-1, 1:] - anchor
    dy = padded[..., 1:, :-1] - anchor
    return (dx ** 2 + dy ** 2).mean([1, 2, 3])
def range_loss(input):
    """Mean squared amount by which values fall outside [-1, 1], per batch item."""
    overshoot = input - input.clamp(-1, 1)
    return overshoot.pow(2).mean([1, 2, 3])
def symmetry_transformation_fn(x, horizontal=None, vertical=None):
    """Mirror one half of an image batch onto the other half.

    This function is passed to the DDIM sampler as ``transformation_fn``.

    Args:
        x: tensor whose last two dims are (height, width), e.g. (n, c, h, w).
        horizontal: if true, mirror the left half onto the right half. Defaults
            to the module-level ``use_horizontal_symmetry`` flag. (Previously a
            local ``False`` shadowed that flag, so it could never take effect.)
        vertical: if true, mirror the top half onto the bottom half. Defaults
            to the module-level ``use_vertical_symmetry`` flag.

    Returns:
        A tensor with the same shape as ``x``; odd sizes are handled.
    """
    if horizontal is None:
        horizontal = use_horizontal_symmetry
    if vertical is None:
        vertical = use_vertical_symmetry
    if horizontal:
        w = x.size(-1)
        # Keep ceil(w/2) columns and mirror floor(w/2), so odd widths keep their size.
        x = torch.concat((x[..., :(w + 1) // 2], torch.flip(x[..., :w // 2], [-1])), -1)
        print("horizontal symmetry applied")
    if vertical:
        h = x.size(-2)
        x = torch.concat((x[..., :(h + 1) // 2, :], torch.flip(x[..., :h // 2, :], [-2])), -2)
        print("vertical symmetry applied")
    return x
""" |
other chaos settings |
""" |
outDirPath = f'{PROJECT_DIR}/images_out' |
createPath(outDirPath) |
model_path = f'{PROJECT_DIR}/models' |
createPath(model_path) |
DEVICE = torch.device('cuda:0' if (torch.cuda.is_available() and not useCPU) else 'cpu') |
print('Using device:', DEVICE) |
device = DEVICE |
if not useCPU: |
if torch.cuda.get_device_capability(DEVICE) == (8, 0): |
print('Disabling CUDNN for A100 gpu', file=sys.stderr) |
torch.backends.cudnn.enabled = False |
# DDIM respacing string and total diffusion steps. diffusion_steps is rounded
# down to a multiple of `steps` when steps < 1000.
timestep_respacing = f'ddim{steps}'
diffusion_steps = (1000 // steps) * steps if steps < 1000 else steps

# Guided-diffusion UNet/diffusion hyper-parameters passed to
# create_model_and_diffusion in Diffuser.model_setup.
model_config = model_and_diffusion_defaults()
model_config.update({
    'attention_resolutions': '32, 16, 8',
    'class_cond': False,
    'diffusion_steps': diffusion_steps,
    'rescale_timesteps': True,
    'timestep_respacing': timestep_respacing,
    'image_size': 512,
    'learn_sigma': True,
    'noise_schedule': 'linear',
    'num_channels': 256,
    'num_head_channels': 64,
    'num_res_blocks': 2,
    'resblock_updown': True,
    'use_checkpoint': use_checkpoint,
    # fp16 only when running on a CUDA device. `not useCPU` alone also enabled
    # it on the CPU fallback used when CUDA is unavailable.
    'use_fp16': DEVICE.type == 'cuda',
    'use_scale_shift_norm': True,
})
model_default = model_config['image_size']
# CLIP's image normalisation constants (per-channel RGB mean / std).
normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
steps_per_checkpoint = steps + 10
start_frame = 0
print('Starting Run:')
if set_seed == 'random_seed':
    random.seed()
    # randint is inclusive at both ends. Subtract 1 to stay within the unsigned
    # 32-bit range (0 .. 2**32 - 1) that e.g. numpy.random.seed accepts.
    seed = random.randint(0, 2**32 - 1)
else:
    seed = int(set_seed)

# Run configuration shared by the sampler and the guidance function.
# NOTE(security): the cut_* schedules are eval()'d. They are module constants
# here; never feed user-supplied text through this path.
args = SimpleNamespace(
    display_rate=display_rate,
    n_batches=n_batches,
    batch_size=batch_size,
    steps=steps,
    diffusion_sampling_mode=diffusion_sampling_mode,
    tv_scale=tv_scale,
    range_scale=range_scale,
    sat_scale=sat_scale,
    cutn_batches=cutn_batches,
    timestep_respacing=timestep_respacing,
    diffusion_steps=diffusion_steps,
    cut_overview=eval(cut_overview),
    cut_innercut=eval(cut_innercut),
    cut_ic_pow=eval(cut_ic_pow),
    cut_icgray_p=eval(cut_icgray_p),
    intermediate_saves=intermediate_saves,
    intermediates_in_subfolder=intermediates_in_subfolder,
    steps_per_checkpoint=steps_per_checkpoint,
    set_seed=set_seed,
    eta=eta,
    clamp_grad=clamp_grad,
    clamp_max=clamp_max,
    skip_augs=skip_augs,
    randomize_class=randomize_class,
    clip_denoised=clip_denoised,
    fuzzy_prompt=fuzzy_prompt,
    rand_mag=rand_mag,
    use_vertical_symmetry=use_vertical_symmetry,
    use_horizontal_symmetry=use_horizontal_symmetry,
    transformation_percent=transformation_percent,
)
class Diffuser:
    """CLIP-guided text-to-image sampler built on a Taiyi guided-diffusion UNet.

    Prompts are embedded with the Taiyi Chinese BERT model, whose
    classification logits serve as the text embedding. These embeddings are
    compared, via spherical distance, with CLIP image embeddings of cutouts of
    the current denoised estimate. The gradient of that loss steers the
    diffusion sampler.
    """

    def __init__(self, cutom_path='IDEA-CCNL/Taiyi-Diffusion-532M-Nature'):
        # ``cutom_path`` keeps its historical spelling so keyword callers keep working.
        self.model_setup(cutom_path)

    def model_setup(self, custom_path):
        """Load the diffusion UNet, the Taiyi text encoder and the enabled CLIP models."""
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
        print(f'Prepping model...model name: {custom_path}')
        # Only the diffusion object is kept; the UNet is loaded from `custom_path` below.
        __, self.diffusion = create_model_and_diffusion(**model_config)
        self.model = HFUNetModel.from_pretrained(custom_path)
        self.model.requires_grad_(False).eval().to(device)
        for name, param in self.model.named_parameters():
            if 'qkv' in name or 'norm' in name or 'proj' in name:
                param.requires_grad_()
        if model_config['use_fp16']:
            self.model.convert_to_fp16()
        print(f'Diffusion_model Loaded {diffusion_model}')
        print('Prepping model...model name: CLIP')
        self.taiyi_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese")
        self.taiyi_transformer = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese").eval().to(device)
        self.clip_models = []
        clip_choices = (
            (ViTB32, 'ViT-B/32'),
            (ViTB16, 'ViT-B/16'),
            (ViTL14, 'ViT-L/14'),
            (ViTL14_336px, 'ViT-L/14@336px'),
        )
        for enabled, clip_name in clip_choices:
            if enabled:
                self.clip_models.append(clip.load(clip_name, jit=False)[0].eval().requires_grad_(False).to(device))
        print('CLIP Loaded')

    def _build_model_stats(self, frame_prompt):
        """Encode the prompts once and pair them with every loaded CLIP model.

        Returns a list of dicts with keys ``clip_model``, ``target_embeds``
        (concatenated prompt embeddings) and ``weights``. The weights are
        normalised so that ``abs(sum)`` is 1.

        Raises:
            RuntimeError: if the prompt weights sum to (almost) zero.
        """
        encoded = []
        # The text embeddings are fixed targets; no autograd graph is needed.
        with torch.no_grad():
            for prompt in frame_prompt:
                txt, weight = parse_prompt(prompt)
                token_ids = self.taiyi_tokenizer(txt, return_tensors='pt')['input_ids'].to(device)
                encoded.append((self.taiyi_transformer(token_ids).logits, weight))

        model_stats = []
        for clip_model in self.clip_models:
            embeds, weights = [], []
            for embed, weight in encoded:
                if args.fuzzy_prompt:
                    for _ in range(25):
                        # randn_like keeps the noise on the embedding's device (no hard-coded .cuda()).
                        embeds.append((embed + torch.randn_like(embed) * args.rand_mag).clamp(0, 1))
                        weights.append(weight)
                else:
                    embeds.append(embed)
                    weights.append(weight)
            weight_tensor = torch.tensor(weights, device=device)
            if weight_tensor.sum().abs() < 1e-3:
                raise RuntimeError('The weights must not sum to 0.')
            weight_tensor /= weight_tensor.sum().abs()
            model_stats.append({
                "clip_model": clip_model,
                "target_embeds": torch.cat(embeds),
                "weights": weight_tensor,
            })
        return model_stats

    def generate(self,
                 input_text_prompts=['夕阳西下'],
                 init_image=None,
                 skip_steps=10,
                 clip_guidance_scale=7500,
                 init_scale=2000,
                 st_dynamic_image=None,
                 seed=None,
                 side_x=512,
                 side_y=512,
                 ):
        """Run CLIP-guided diffusion and return the last rendered image.

        Args:
            input_text_prompts: prompts, each optionally suffixed with
                ``:weight`` (see ``parse_prompt``). The list is only read.
            init_image: optional starting image. It is presumably a PIL image
                (``.resize`` with ``Image.LANCZOS`` is called on it) and is
                resized to (side_x, side_y).
            skip_steps: number of initial diffusion timesteps to skip.
            clip_guidance_scale: multiplier on the CLIP guidance loss.
            init_scale: weight of a perceptual loss against ``init_image``.
                It is applied only if an ``lpips_model`` attribute has been set
                on the instance; this class never creates one.
            st_dynamic_image: optional object with an ``image(img, use_column_width=...)``
                method (e.g. a Streamlit placeholder) that receives each displayed frame.
            seed: if not None, seeds ``random`` and ``torch`` for reproducible runs.
            side_x, side_y: output width and height in pixels.

        Returns:
            The last ``PIL.Image`` produced (also saved under ``outDirPath``),
            or None if no frame was rendered.
        """
        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)
        frame_num = 0
        loss_values = []
        frame_prompt = input_text_prompts
        print(f'Frame {frame_num} Prompt: {frame_prompt}')
        model_stats = self._build_model_stats(frame_prompt)

        init = None
        if init_image is not None:
            init = init_image.resize((side_x, side_y), Image.LANCZOS)
            init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)

        lpips_model = getattr(self, 'lpips_model', None)
        if init is not None and init_scale and lpips_model is None:
            print('No lpips_model set on Diffuser; skipping the init_scale perceptual loss.', file=sys.stderr)

        # Current (respaced) timestep index; updated by the sampling loop below
        # and read by cond_fn through its closure.
        cur_t = None

        def cond_fn(x, t, y=None):
            """Return the guidance gradient w.r.t. the noisy sample ``x``."""
            with torch.enable_grad():
                x_is_NaN = False
                x = x.detach().requires_grad_()
                n = x.shape[0]
                my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
                out = self.diffusion.p_mean_variance(self.model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
                fac = self.diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
                x_in = out['pred_xstart'] * fac + x * (1 - fac)
                x_in_grad = torch.zeros_like(x_in)
                # Schedules hold 1000 entries indexed from the start of sampling.
                schedule_idx = 1000 - (int(t.item()) + 1)
                n_cuts = args.cut_overview[schedule_idx] + args.cut_innercut[schedule_idx]
                for model_stat in model_stats:
                    input_resolution = model_stat["clip_model"].visual.input_resolution
                    for _ in range(args.cutn_batches):
                        cuts = MakeCutoutsDango(input_resolution,
                                                Overview=args.cut_overview[schedule_idx],
                                                InnerCrop=args.cut_innercut[schedule_idx],
                                                IC_Size_Pow=args.cut_ic_pow[schedule_idx],
                                                IC_Grey_P=args.cut_icgray_p[schedule_idx],
                                                args=args,
                                                )
                        clip_in = normalize(cuts(x_in.add(1).div(2)))
                        image_embeds = model_stat["clip_model"].encode_image(clip_in).float()
                        dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat["target_embeds"].unsqueeze(0))
                        dists = dists.view([n_cuts, n, -1])
                        losses = dists.mul(model_stat["weights"]).sum(2).mean(0)
                        loss_values.append(losses.sum().item())
                        x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / args.cutn_batches
                tv_losses = tv_loss(x_in)
                # NOTE(review): range_losses uses pred_xstart, which is not downstream of
                # x_in, so it appears to contribute no gradient below. Confirm whether intended.
                range_losses = range_loss(out['pred_xstart'])
                sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean()
                loss = tv_losses.sum() * args.tv_scale + range_losses.sum() * args.range_scale + sat_losses.sum() * args.sat_scale
                if init is not None and init_scale and lpips_model is not None:
                    init_losses = lpips_model(x_in, init)
                    loss = loss + init_losses.sum() * init_scale
                x_in_grad += torch.autograd.grad(loss, x_in)[0]
                if not torch.isnan(x_in_grad).any():
                    grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]
                else:
                    x_is_NaN = True
                    grad = torch.zeros_like(x)
            if args.clamp_grad and not x_is_NaN:
                # Rescale so the RMS of the gradient does not exceed clamp_max.
                magnitude = grad.square().mean().sqrt()
                return grad * magnitude.clamp(max=args.clamp_max) / magnitude
            return grad

        if args.diffusion_sampling_mode == 'ddim':
            sample_fn = self.diffusion.ddim_sample_loop_progressive
        else:
            sample_fn = self.diffusion.plms_sample_loop_progressive

        # Filename label from the last prompt, made safe to use as a file name.
        prompt_label = parse_prompt(frame_prompt[-1])[0] if frame_prompt else ''
        prompt_label = prompt_label.replace('/', '_').replace(os.sep, '_')

        image = None
        batch_bar = tqdm(range(args.n_batches), desc="Batches")
        for i in range(args.n_batches):
            current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')
            batch_bar.n = i
            batch_bar.refresh()
            gc.collect()
            torch.cuda.empty_cache()
            cur_t = self.diffusion.num_timesteps - skip_steps - 1
            common_kwargs = dict(
                clip_denoised=args.clip_denoised,
                model_kwargs={},
                cond_fn=cond_fn,
                progress=True,
                skip_timesteps=skip_steps,
                init_image=init,
                randomize_class=args.randomize_class,
            )
            shape = (args.batch_size, 3, side_y, side_x)
            if args.diffusion_sampling_mode == 'ddim':
                samples = sample_fn(
                    self.model,
                    shape,
                    eta=args.eta,
                    transformation_fn=symmetry_transformation_fn,
                    transformation_percent=args.transformation_percent,
                    **common_kwargs,
                )
            else:
                samples = sample_fn(self.model, shape, order=2, **common_kwargs)

            for j, sample in enumerate(samples):
                cur_t -= 1
                intermediate_step = False
                if args.steps_per_checkpoint is not None:
                    if j % args.steps_per_checkpoint == 0 and j > 0:
                        intermediate_step = True
                elif isinstance(args.intermediate_saves, (list, tuple, set)) and j in args.intermediate_saves:
                    intermediate_step = True
                display_step = j % args.display_rate == 0 or cur_t == -1
                if display_step or intermediate_step:
                    for frame in sample['pred_xstart']:
                        filename = f'{current_time}-{prompt_label}.png'
                        image = TF.to_pil_image(frame.add(1).div(2).clamp(0, 1))
                        if display_step:
                            image.save(f'{outDirPath}/{filename}')
                            if st_dynamic_image:
                                st_dynamic_image.image(image, use_column_width=True)
        return image
if __name__ == '__main__':
    # Command-line entry point: load the model once and render a single prompt.
    cli = argparse.ArgumentParser(description="setting")
    cli.add_argument('--prompt', type=str, required=True)
    cli.add_argument('--text_scale', type=int, default=5000)
    cli.add_argument('--model_path', type=str, default="IDEA-CCNL/Taiyi-Diffusion-532M-Nature")
    for dimension_flag in ('--width', '--height'):
        cli.add_argument(dimension_flag, type=int, default=512)
    options = cli.parse_args()
    diffuser = Diffuser(options.model_path)
    diffuser.generate(
        [options.prompt],
        clip_guidance_scale=options.text_scale,
        side_x=options.width,
        side_y=options.height,
    )