yuezhengrong commited on
Commit
df60ba7
·
verified ·
1 Parent(s): 7a9204e

Create generate_white.py

Browse files
Files changed (1) hide show
  1. generate_white.py +41 -0
generate_white.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline
2
+ import torch
3
+ from PIL import Image
4
+ import os
5
+
6
+ def test_lora(lcm_speedup=False):
7
+ pipe = DiffusionPipeline.from_pretrained("/mnt/workspace/Project/VideoGen/ckpt/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None, requires_safety_checker=False)
8
+ pipe.to("cuda")
9
+
10
+ lora_path = "/mnt/workspace/Project/VideoGen/TrainT2V/diffusers/examples/text_to_image/results"
11
+ pipe.load_lora_weights(pretrained_model_name_or_path_or_dict=lora_path, weight_name="white_100epoch_lora.safetensors", adapter_name="pattern")
12
+ if lcm_speedup:
13
+ pipe.load_lora_weights(pretrained_model_name_or_path_or_dict=lora_path, weight_name="lcm_lora.safetensors", adapter_name="lcm")
14
+ pipe.set_adapters(["artifacts", "lcm"], adapter_weights=[1.0, 1.0])
15
+
16
+ prompts = [
17
+ "A central pure white artifact with a fluid, swirling form, resembling a twisted cloth. The artifact is highly detailed with intricate textures that catch the light, creating a sense of depth and realism. Set against a soft, out-of-focus black background that transitions from light to dark, the artifact stands out prominently."
18
+ ]
19
+
20
+ # 设置生成参数
21
+ if lcm_speedup:
22
+ num_inference_steps = 8
23
+ guidance_scale = 2
24
+ else:
25
+ num_inference_steps = 30
26
+ guidance_scale = 7.5
27
+
28
+ num_samples_per_prompt = 10 # 生成10张图片
29
+
30
+ # 确保images文件夹存在
31
+ if not os.path.exists("images_result"):
32
+ os.makedirs("images_result")
33
+
34
+ # 为每个 prompt 生成 num_samples_per_prompt 张图片
35
+ for i, prompt in enumerate(prompts):
36
+ for j in range(num_samples_per_prompt):
37
+ image = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images[0]
38
+ image.save(f"images_result/{i}_{j}.png") # 保存图片
39
+
40
+ if __name__ == "__main__":
41
+ test_lora()