Pichayada commited on
Commit
316b182
·
verified ·
1 Parent(s): d6e8c46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -22
app.py CHANGED
@@ -1,30 +1,43 @@
1
- from diffusers import StableDiffusionPipeline
2
- from PIL import Image
3
  import torch
 
 
4
 
5
- def generate_image(prompt):
6
- """
7
- สร้างภาพจากข้อความที่กำหนดโดยใช้ Stable Diffusion.
 
 
 
8
 
9
- Args:
10
- prompt (str): ข้อความที่ต้องการแปลงเป็นภาพ.
11
 
12
- Returns:
13
- PIL.Image.Image: รูปภาพที่สร้างขึ้น.
14
- """
15
- # ตรวจสอบว่ามี CUDA (GPU) หรือไม่ และเลือกอุปกรณ์ที่เหมาะสม
16
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
17
 
18
- # โหลดโมเดล Stable Diffusion ที่เทรนไว้ล่วงหน้า
19
- model_id = "runwayml/stable-diffusion-v1-5"
20
- pipeline = StableDiffusionPipeline.from_pretrained(model_id).to(device)
21
 
22
- # สร้างภาพ
23
- image = pipeline(prompt).images[0]
 
 
24
  return image
25
 
26
- if __name__ == "__main__":
27
- prompt = "A cozy cat sleeping by a fireplace on a snowy evening."
28
- generated_image = generate_image(prompt)
29
- generated_image.save("cozy_cat.png")
30
- print("สร้างภาพเรียบร้อยแล้ว: cozy_cat.png")
 
 
 
 
 
 
 
1
+ import gradio as gr
 
2
  import torch
3
+ from diffusers import StableDiffusionPipeline
4
+ import subprocess
5
 
6
+ # ติดตั้ง flash-attn แม้จะไม่ได้ใช้โดยตรง (ข้าม build CUDA)
7
+ subprocess.run(
8
+ 'pip install flash-attn --no-build-isolation',
9
+ env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
10
+ shell=True
11
+ )
12
 
13
+ # ใช้ CPU
14
+ device = "cpu"
15
 
16
+ # โหลดโมเดลเบา: sd-turbo
17
+ pipe = StableDiffusionPipeline.from_pretrained(
18
+ "stabilityai/sd-turbo",
19
+ torch_dtype=torch.float32
20
+ )
21
+ pipe = pipe.to(device)
22
+ pipe.safety_checker = None # ปิด safety checker เพื่อความเร็ว
23
 
24
+ # เปิด attention_slicing (แม้บน CPU ก็ช่วยเรื่องหน่วยความจำ)
25
+ pipe.enable_attention_slicing()
 
26
 
27
+ # ฟังก์ชันสร้างภาพ
28
+ def generate_image(prompt):
29
+ result = pipe(prompt, num_inference_steps=10, guidance_scale=3.0)
30
+ image = result.images[0]
31
  return image
32
 
33
+ # Gradio UI
34
+ io = gr.Interface(
35
+ fn=generate_image,
36
+ inputs=[gr.Textbox(label="Enter your prompt")],
37
+ outputs=[gr.Image(label="Generated Image")],
38
+ theme="Yntec/HaleyCH_Theme_Orange",
39
+ description="⚠ Running on CPU using sd-turbo. Optimized for speed with low inference steps."
40
+ )
41
+
42
+ # เปิด Gradio ด้วย queue ป้องกันค้างถ้ามีหลายคำสั่ง
43
+ io.queue().launch(debug=True)