Update app.py
Browse files
app.py
CHANGED
@@ -1,32 +1,70 @@
|
|
1 |
import gradio as gr
|
2 |
-
from
|
3 |
-
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from diffusers import StableDiffusionPipeline
|
3 |
+
import torch
|
4 |
+
from PIL import Image, ImageDraw, ImageFont
|
5 |
+
import io
|
6 |
|
7 |
+
# =====================
|
8 |
+
# Load AI model
|
9 |
+
# =====================
|
10 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
11 |
+
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
12 |
+
).to("cuda") # use "cpu" if GPU unavailable
|
13 |
|
14 |
+
# =====================
|
15 |
+
# Helper: Detect language
|
16 |
+
# =====================
|
17 |
+
def detect_language(text):
|
18 |
+
if any("\u0D80" <= c <= "\u0DFF" for c in text): # Sinhala range
|
19 |
+
return "Sinhala"
|
20 |
+
elif any("\u0B80" <= c <= "\u0BFF" for c in text): # Tamil range
|
21 |
+
return "Tamil"
|
22 |
+
else:
|
23 |
+
return "English"
|
24 |
|
25 |
+
# =====================
|
26 |
+
# Generate AI post
|
27 |
+
# =====================
|
28 |
+
def generate_post(text, main_image=None, logo=None):
|
29 |
+
language = detect_language(text)
|
30 |
+
|
31 |
+
# AI background generation
|
32 |
+
prompt = f"Beautiful social media post, modern design, colors, shapes"
|
33 |
+
image = pipe(prompt, height=512, width=512).images[0]
|
34 |
|
35 |
+
# Overlay main image
|
36 |
+
if main_image is not None:
|
37 |
+
main_img = Image.open(main_image).convert("RGBA").resize((200,200))
|
38 |
+
image.paste(main_img, (50,50), main_img)
|
39 |
+
|
40 |
+
# Overlay logo
|
41 |
+
if logo is not None:
|
42 |
+
logo_img = Image.open(logo).convert("RGBA").resize((100,100))
|
43 |
+
image.paste(logo_img, (image.width-120,image.height-120), logo_img)
|
44 |
+
|
45 |
+
# Add text
|
46 |
+
draw = ImageDraw.Draw(image)
|
47 |
+
font_path = "fonts/FMAbhaya.ttf" if language=="Sinhala" else "fonts/Bamini.ttf" if language=="Tamil" else "fonts/Roboto-Regular.ttf"
|
48 |
+
try:
|
49 |
+
font = ImageFont.truetype(font_path, 40)
|
50 |
+
except:
|
51 |
+
font = ImageFont.load_default()
|
52 |
+
text_w, text_h = draw.textsize(text, font=font)
|
53 |
+
draw.text(((image.width-text_w)/2, image.height-80), text, fill=(255,255,255), font=font)
|
54 |
+
|
55 |
+
return image
|
56 |
|
57 |
+
# =====================
|
58 |
+
# Gradio UI
|
59 |
+
# =====================
|
60 |
+
with gr.Blocks() as demo:
|
61 |
+
gr.Markdown("# AI Social Media Post Generator")
|
62 |
+
with gr.Row():
|
63 |
+
text_input = gr.Textbox(label="Enter Text")
|
64 |
+
main_img_input = gr.Image(label="Main Image (Optional)", type="pil", optional=True)
|
65 |
+
logo_input = gr.Image(label="Logo (Optional)", type="pil", optional=True)
|
66 |
+
output = gr.Image(label="Generated Post")
|
67 |
+
generate_btn = gr.Button("Generate")
|
68 |
+
generate_btn.click(generate_post, inputs=[text_input, main_img_input, logo_input], outputs=output)
|
69 |
|
70 |
demo.launch()
|