# comicVala / app.py
# bh4vay's picture
# Upload app.py
# 91e4459 verified
import streamlit as st
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
from PIL import Image, ImageDraw, ImageFont
# Choose the compute device once at import time: the GPU when PyTorch reports
# that CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Load the text generation model (TinyLlama)
@st.cache_resource
def load_text_model():
    """Build the TinyLlama text-generation pipeline on the selected device.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the loaded pipeline instead of rebuilding it on every script rerun.

    Returns:
        A ``transformers`` "text-generation" pipeline backed by
        ``TinyLlama/TinyLlama-1.1B-Chat-v1.0``.
    """
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    # Pass the device explicitly: the pipeline places the tokenized inputs on
    # its own device, which must match where the model weights live.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
# Module-level text-generation pipeline; generate_story() calls it directly.
story_generator = load_text_model()
# Load the image generation model (Stable Diffusion v1.5)
@st.cache_resource
def load_image_model():
    """Return the Stable Diffusion v1.5 pipeline, moved to ``device``.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded
    pipeline instead of rebuilding it on every script rerun.
    """
    sd_pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5"
    )
    return sd_pipeline.to(device)
# Module-level Stable Diffusion pipeline; the UI code calls it with the prompt.
image_generator = load_image_model()
# Function to generate a short story
def generate_story(prompt, max_new_tokens=250):
    """Generate a short comic-style story about *prompt*.

    Args:
        prompt: Topic entered by the user. It is embedded in an instruction
            template before being passed to ``story_generator``.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt (the prompt itself does not count against it).

    Returns:
        The generated continuation with leading/trailing whitespace removed.
        The instruction template is not part of the returned text.
    """
    formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"
    outputs = story_generator(
        formatted_prompt,
        # Budget only the continuation. ``max_length`` also counts the prompt
        # tokens, so a long prompt could leave no room for the story at all.
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        num_return_sequences=1,
        truncation=True,
        # Ask the pipeline for the continuation only, rather than removing
        # the prompt afterwards with a string replace.
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
# Function to add a speech bubble to the image
def add_speech_bubble(image, text, position=(50, 50)):
    """Draw *text* inside a white, black-outlined elliptical bubble on *image*.

    The drawing is done directly on the image that is passed in, and that
    same object is returned.

    Args:
        image: PIL image to draw on.
        text: Caption to place inside the bubble. When it is empty, ``None``
            or whitespace-only, nothing is drawn.
        position: ``(x, y)`` of the top-left corner of the bubble's bounding
            box, in pixels.

    Returns:
        The image object that was passed in.
    """
    # An empty caption would only produce a small blank ellipse (and ``None``
    # would fail while measuring the text), so skip drawing entirely.
    if not text or not text.strip():
        return image

    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # Arial is not installed on every system; use Pillow's built-in font.
        font = ImageFont.load_default()

    # Measure the rendered text so the bubble can be sized around it.
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_width = right - left
    text_height = bottom - top

    # 15 px horizontal and 10 px vertical padding on each side of the text.
    bubble_width, bubble_height = text_width + 30, text_height + 20
    bubble_x, bubble_y = position

    draw.ellipse(
        [bubble_x, bubble_y, bubble_x + bubble_width, bubble_y + bubble_height],
        fill="white",
        outline="black",
    )
    draw.text((bubble_x + 15, bubble_y + 10), text, font=font, fill="black")
    return image
# Streamlit UI
# The emoji below are written as real Unicode characters; the file must be
# saved as UTF-8 for them to render correctly.
st.title("🦸‍♂️ AI Comic Story Generator")
st.write("Enter a prompt to generate a comic-style story and image!")

# User input
user_prompt = st.text_input("📝 Enter your story prompt:")
# Run the (expensive) generation only when the prompt contains something
# other than whitespace.
if user_prompt and user_prompt.strip():
    st.subheader("📖 AI-Generated Story")
    # Story generation is slow as well, so show progress feedback here too.
    with st.spinner("Generating story..."):
        generated_story = generate_story(user_prompt)
    st.write(generated_story)

    st.subheader("🖼️ AI-Generated Image")
    with st.spinner("Generating image..."):
        image = image_generator(user_prompt, num_inference_steps=30).images[0]

    # Use the text before the first period, capped at 50 characters, as the
    # caption drawn in the speech bubble.
    speech_text = generated_story.split(".")[0][:50]
    image_with_bubble = add_speech_bubble(image, speech_text, position=(50, 50))
    st.image(image_with_bubble, caption="Generated Comic Image", use_container_width=True)