# modules/model.py
import os
import torch
# AutoencoderKL is provided by diffusers, not transformers
from diffusers import AutoencoderKL, StableDiffusionPipeline, StableDiffusionXLPipeline

def get_checkpoints(folder):
    checkpoints = []
    for file in os.listdir(folder):
        if file.endswith(('.safetensors', '.ckpt', '.pt', '.pth')):
            checkpoints.append(file)
    return checkpoints

def load_model(checkpoint, vae, checkpoint_folder, vae_folder):
    # Choose the pipeline class based on the checkpoint name
    if "sdxl" in checkpoint.lower():
        pipeline_class = StableDiffusionXLPipeline
    else:
        pipeline_class = StableDiffusionPipeline

    # Load the checkpoint: try single-file loading first for local files and
    # direct URLs, falling back to from_pretrained (e.g. for Hub repo IDs)
    if checkpoint in get_checkpoints(checkpoint_folder):
        checkpoint_path = os.path.join(checkpoint_folder, checkpoint)
        try:
            model = pipeline_class.from_single_file(checkpoint_path, torch_dtype=torch.float16)
        except Exception:
            model = pipeline_class.from_pretrained(checkpoint_path, torch_dtype=torch.float16)
    else:
        if checkpoint.startswith("http"):
            try:
                model = pipeline_class.from_single_file(checkpoint, torch_dtype=torch.float16)
            except Exception:
                model = pipeline_class.from_pretrained(checkpoint, torch_dtype=torch.float16)
        else:
            model = pipeline_class.from_pretrained(checkpoint, torch_dtype=torch.float16)

    # Load and attach a custom VAE unless "none" is selected
    if vae != "none":
        if vae in get_checkpoints(vae_folder):
            # Local VAE weights are single files, so use from_single_file
            vae_path = os.path.join(vae_folder, vae)
            vae_model = AutoencoderKL.from_single_file(vae_path, torch_dtype=torch.float16)
        else:
            # Remote VAEs may be Hub repo IDs or direct file URLs
            try:
                vae_model = AutoencoderKL.from_pretrained(vae, torch_dtype=torch.float16)
            except Exception:
                vae_model = AutoencoderKL.from_single_file(vae, torch_dtype=torch.float16)
        model.vae = vae_model

    return model

def get_model_and_vae_options():
    checkpoint_folder = "../models/checkpoint/"
    vae_folder = "../models/vae/"
    model_file = "../models/models.py"

    # Read the curated `diffusers` and `vae` lists from models/models.py.
    # exec() writes into an explicit namespace because assignments made by
    # exec() inside a function do not become local variables.
    namespace = {}
    with open(model_file) as f:
        exec(f.read(), namespace)
    diffusers_models = namespace.get("diffusers", [])
    vae_urls = namespace.get("vae", [])

    # List the checkpoints and VAEs found in the local folders
    checkpoints = get_checkpoints(checkpoint_folder)
    vae_files = get_checkpoints(vae_folder)

    # Combine local checkpoints, Diffusers models, and VAEs
    all_models = checkpoints + diffusers_models
    all_vaes = ["none"] + vae_files + vae_urls

    # Format the dropdown entries: keep URLs as-is, otherwise show the base name
    formatted_models = [model if model.startswith("http") else os.path.basename(model) for model in all_models]
    formatted_vaes = [v if v.startswith("http") else os.path.basename(v) for v in all_vaes]

    return formatted_models, formatted_vaes

# Wrapper for the generate_image function used by the text2img module
def generate_image(text, checkpoint, vae):
    checkpoint_folder = "../models/checkpoint/"
    vae_folder = "../models/vae/"
    model = load_model(checkpoint, vae, checkpoint_folder, vae_folder)
    if torch.cuda.is_available():
        # Weights are loaded in float16, which is intended for GPU inference
        model = model.to("cuda")
    # Diffusers pipelines return an output object; generated images are in .images
    image = model(text).images[0]
    return image
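
# Minimal usage sketch, kept as a __main__ guard so importing the module has no
# side effects. Assumptions not taken from this file: a checkpoint file such as
# "example-sdxl.safetensors" exists in ../models/checkpoint/, and a CUDA GPU is
# available since the weights are loaded in float16. The checkpoint name and
# output path below are hypothetical.
if __name__ == "__main__":
    models, vaes = get_model_and_vae_options()
    print("Available models:", models)
    print("Available VAEs:", vaes)

    image = generate_image(
        "a watercolor painting of a lighthouse at dusk",
        checkpoint="example-sdxl.safetensors",  # hypothetical file name
        vae="none",
    )
    image.save("output.png")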