File size: 1,903 Bytes
df60ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from diffusers import DiffusionPipeline
import torch
from PIL import Image
import os

def test_lora(
    lcm_speedup=False,
    *,
    model_path="/mnt/workspace/Project/VideoGen/ckpt/stable-diffusion-v1-5",
    lora_path="/mnt/workspace/Project/VideoGen/TrainT2V/diffusers/examples/text_to_image/results",
    output_dir="images_result",
    num_samples_per_prompt=10,
):
    """Generate sample images from the base pipeline with the trained "pattern" LoRA.

    Loads the base pipeline in float16 on CUDA (safety checker disabled),
    attaches ``white_100epoch_lora.safetensors`` under the adapter name
    ``"pattern"``, and optionally stacks ``lcm_lora.safetensors`` for
    few-step sampling. Every prompt is sampled ``num_samples_per_prompt``
    times; images are written to ``{output_dir}/{prompt_idx}_{sample_idx}.png``.

    Args:
        lcm_speedup: If True, also load the LCM-LoRA as adapter ``"lcm"``,
            activate both adapters, switch to ``LCMScheduler`` and sample with
            8 steps / guidance 2.0 instead of 30 steps / guidance 7.5.
        model_path: Directory of the base diffusers checkpoint.
        lora_path: Directory containing the LoRA ``.safetensors`` files.
        output_dir: Directory the PNG files are written to (created if missing).
        num_samples_per_prompt: Number of images generated for each prompt.
    """
    pipe = DiffusionPipeline.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.to("cuda")

    pipe.load_lora_weights(
        pretrained_model_name_or_path_or_dict=lora_path,
        weight_name="white_100epoch_lora.safetensors",
        adapter_name="pattern",
    )

    if lcm_speedup:
        # Local import: the LCM scheduler is only needed on this path.
        from diffusers import LCMScheduler

        pipe.load_lora_weights(
            pretrained_model_name_or_path_or_dict=lora_path,
            weight_name="lcm_lora.safetensors",
            adapter_name="lcm",
        )
        # Adapter names must match the adapter_name values passed to
        # load_lora_weights above; an unknown name makes set_adapters fail.
        pipe.set_adapters(["pattern", "lcm"], adapter_weights=[1.0, 1.0])
        # The LCM-LoRA is distilled for the LCM sampler; with the checkpoint's
        # default scheduler, very few steps give poor images.
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        num_inference_steps = 8
        guidance_scale = 2.0
    else:
        num_inference_steps = 30
        guidance_scale = 7.5

    prompts = [
        "A central pure white artifact with a fluid, swirling form, resembling a twisted cloth. The artifact is highly detailed with intricate textures that catch the light, creating a sense of depth and realism. Set against a soft, out-of-focus black background that transitions from light to dark, the artifact stands out prominently."
    ]

    # Make sure the output directory exists (no-op if it already does).
    os.makedirs(output_dir, exist_ok=True)

    # Generate num_samples_per_prompt images for each prompt.
    for i, prompt in enumerate(prompts):
        for j in range(num_samples_per_prompt):
            image = pipe(
                prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            ).images[0]
            image.save(os.path.join(output_dir, f"{i}_{j}.png"))

if __name__ == "__main__":
    # Script entry point: run the standard (non-LCM) sampling pass.
    test_lora(lcm_speedup=False)