# tts / app.py
# Last commit d289240 (verified) by dshamika: "Update app.py" — 2.44 kB
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from PIL import Image, ImageDraw, ImageFont
import io
# =====================
# Load AI model
# =====================
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# app still starts on CPU-only hardware. Half precision is only requested on
# CUDA: many CPU kernels do not implement float16, so CPU runs use float32.
# NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo may no longer be
# published; a mirror such as "stable-diffusion-v1-5/stable-diffusion-v1-5"
# may be required — confirm before deploying.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_DTYPE = torch.float16 if _DEVICE == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=_DTYPE
).to(_DEVICE)
# =====================
# Helper: Detect language
# =====================
def detect_language(text):
    """Guess the script of *text*: "Sinhala", "Tamil" or "English".

    The Sinhala block (U+0D80–U+0DFF) is checked first, so any Sinhala
    character wins. Otherwise any character in the Tamil block
    (U+0B80–U+0BFF) yields "Tamil". Everything else, including empty
    text, is reported as "English".
    """
    # (label, first code point, last code point), checked in priority order.
    script_ranges = (
        ("Sinhala", "\u0D80", "\u0DFF"),
        ("Tamil", "\u0B80", "\u0BFF"),
    )
    for label, lo, hi in script_ranges:
        for ch in text:
            if lo <= ch <= hi:
                return label
    return "English"
# =====================
# Generate AI post
# =====================
# Font file per detected language; any other language uses _DEFAULT_FONT_PATH.
# NOTE(review): FMAbhaya and Bamini look like legacy (non-Unicode) fonts that
# map ASCII code points to glyphs, so Unicode Sinhala/Tamil text drawn with
# them may render incorrectly. A Unicode font (e.g. Noto Sans Sinhala / Noto
# Sans Tamil) is probably needed — confirm against the files in fonts/.
_FONT_PATHS = {
    "Sinhala": "fonts/FMAbhaya.ttf",
    "Tamil": "fonts/Bamini.ttf",
}
_DEFAULT_FONT_PATH = "fonts/Roboto-Regular.ttf"


def _load_rgba(source, size):
    """Return *source* converted to RGBA and resized to *size* (w, h).

    *source* may be a PIL image or anything ``Image.open`` accepts
    (a file path or binary file object). Files are closed after loading.
    """
    if isinstance(source, Image.Image):
        return source.convert("RGBA").resize(size)
    with Image.open(source) as img:
        return img.convert("RGBA").resize(size)


def _load_font(language, size=40):
    """Return the caption font for *language*, or Pillow's default font."""
    path = _FONT_PATHS.get(language, _DEFAULT_FONT_PATH)
    try:
        return ImageFont.truetype(path, size)
    except (OSError, ImportError):
        # Missing/unreadable font file (or no FreeType support): fall back
        # to Pillow's built-in font instead of failing the whole request.
        return ImageFont.load_default()


def generate_post(text, main_image=None, logo=None):
    """Render a 512x512 social-media post over an AI-generated background.

    Args:
        text: Caption drawn in white, centred horizontally, 80 px above the
            bottom edge. ``None`` is treated as an empty caption.
        main_image: Optional picture resized to 200x200 and pasted at
            (50, 50). May be a PIL image or a path / binary file object.
        logo: Optional logo resized to 100x100 and pasted 120 px in from the
            bottom-right corner. Same accepted types as *main_image*.

    Returns:
        The composed PIL image produced by the module-level ``pipe``.
    """
    text = text or ""
    language = detect_language(text)
    # AI background generation
    prompt = "Beautiful social media post, modern design, colors, shapes"
    image = pipe(prompt, height=512, width=512).images[0]
    # Each overlay is passed as its own paste mask, so its alpha channel
    # lets the background show through transparent regions.
    if main_image is not None:
        main_img = _load_rgba(main_image, (200, 200))
        image.paste(main_img, (50, 50), main_img)
    if logo is not None:
        logo_img = _load_rgba(logo, (100, 100))
        image.paste(logo_img, (image.width - 120, image.height - 120), logo_img)
    # Caption. ImageDraw.textsize was removed in Pillow 10; measure the
    # rendered width with textbbox instead.
    draw = ImageDraw.Draw(image)
    font = _load_font(language)
    left, _, right, _ = draw.textbbox((0, 0), text, font=font)
    text_w = right - left
    draw.text(
        ((image.width - text_w) / 2, image.height - 80),
        text,
        fill=(255, 255, 255),
        font=font,
    )
    return image
# =====================
# Gradio UI
# =====================
with gr.Blocks() as demo:
    gr.Markdown("# AI Social Media Post Generator")
    with gr.Row():
        text_input = gr.Textbox(label="Enter Text")
        # type="filepath" hands generate_post a path, which it opens with
        # Image.open. Gradio image inputs are already optional (None when
        # empty), so no ``optional=`` keyword is passed; current Gradio
        # releases reject that argument. A path presumably also preserves a
        # PNG logo's transparency, unlike a PIL conversion — confirm for the
        # Gradio version in use.
        main_img_input = gr.Image(label="Main Image (Optional)", type="filepath")
        logo_input = gr.Image(label="Logo (Optional)", type="filepath")
    output = gr.Image(label="Generated Post")
    generate_btn = gr.Button("Generate")
    generate_btn.click(
        generate_post,
        inputs=[text_input, main_img_input, logo_input],
        outputs=output,
    )

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()