Update app.py
Browse files
app.py
CHANGED
@@ -5,14 +5,18 @@ import gradio as gr
|
|
5 |
|
6 |
# 模型和 LoRA 權重的 URL
|
7 |
base_model = "stabilityai/stable-diffusion-xl-base-1.0" # 基礎模型
|
8 |
-
lora_weights_url = "https://huggingface.co/hyder133/chiikawa_stype/resolve/main/tkw1.safetensors"
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# 加載基礎模型
|
11 |
-
pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=
|
12 |
-
pipe = pipe.to(
|
13 |
|
14 |
# 加載 LoRA 權重
|
15 |
-
pipe.load_lora_weights(lora_weights_url)
|
16 |
|
17 |
# 測試生成函數
|
18 |
def generate_image(prompt, width, height, steps, guidance_scale):
|
|
|
5 |
|
6 |
# 模型和 LoRA 權重的 URL
|
7 |
base_model = "stabilityai/stable-diffusion-xl-base-1.0" # 基礎模型
|
8 |
+
lora_weights_url = "https://huggingface.co/hyder133/chiikawa_stype/resolve/main/tkw1.safetensors"
|
9 |
+
|
10 |
+
# 嘗試加載到 GPU,如果不可用則回退到 CPU
|
11 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
12 |
+
dtype = torch.float16 if device == "cuda" else torch.float32
|
13 |
|
14 |
# 加載基礎模型
|
15 |
+
pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype)
|
16 |
+
pipe = pipe.to(device)
|
17 |
|
18 |
# 加載 LoRA 權重
|
19 |
+
pipe.load_lora_weights(lora_weights_url)
|
20 |
|
21 |
# 測試生成函數
|
22 |
def generate_image(prompt, width, height, steps, guidance_scale):
|