|
import streamlit as st
|
|
import torch
|
|
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
|
|
from diffusers import StableDiffusionPipeline
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
|
|
# Select the torch device string used when the models are moved with .to():
# "cuda" when torch reports CUDA as available, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
|
@st.cache_resource
def load_text_model():
    """Build the text-generation pipeline used for story writing.

    Loads the TinyLlama chat checkpoint and its tokenizer, moves the model to
    the module-level ``device`` and wraps both in a ``"text-generation"``
    pipeline. The function is decorated with ``st.cache_resource``.

    Returns:
        The pipeline object returned by ``transformers.pipeline``.
    """
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    # Name the device explicitly rather than relying on the pipeline's own
    # default, so the pipeline targets the same device the model was moved to.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
|
|
|
|
# Module-level text-generation pipeline; generate_story() calls this global.
story_generator = load_text_model()
|
|
|
|
|
|
@st.cache_resource
def load_image_model():
    """Load the Stable Diffusion v1.5 pipeline and move it to ``device``.

    Returns:
        The object returned by calling ``.to(device)`` on the pipeline
        created by ``StableDiffusionPipeline.from_pretrained``.
    """
    sd_pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5"
    )
    return sd_pipeline.to(device)
|
|
|
|
# Module-level image pipeline; the Streamlit UI code calls this global to render the picture.
image_generator = load_image_model()
|
|
|
|
|
|
def generate_story(prompt):
    """Generate a short comic-style story about ``prompt``.

    Args:
        prompt: Topic text inserted into the instruction sent to the model.

    Returns:
        The generated continuation with leading and trailing whitespace
        stripped; the instruction text itself is not included.
    """
    formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"
    outputs = story_generator(
        formatted_prompt,
        max_length=250,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        num_return_sequences=1,
        truncation=True,
        # Ask the pipeline for the continuation only. Stripping the prompt
        # afterwards with str.replace() removes *every* occurrence of the
        # instruction, including one the model happens to repeat mid-story.
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
|
|
|
|
|
|
def add_speech_bubble(image, text, position=(50, 50)):
    """Draw a white elliptical speech bubble containing ``text`` on ``image``.

    The drawing is done directly on ``image`` (via ``ImageDraw.Draw``), and the
    same object is returned for convenience.

    Args:
        image: Image object accepted by ``ImageDraw.Draw``.
        text: Caption to render. When empty, nothing is drawn.
        position: ``(x, y)`` of the top-left corner of the bubble's bounding box.

    Returns:
        The ``image`` object that was passed in.
    """
    # An empty caption would otherwise leave a blank bubble on the picture.
    if not text:
        return image

    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # The Arial font file could not be loaded; use Pillow's default font.
        font = ImageFont.load_default()

    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_width = right - left
    text_height = bottom - top

    # 15px horizontal and 10px vertical padding on each side of the text.
    bubble_width, bubble_height = text_width + 30, text_height + 20
    bubble_x, bubble_y = position

    draw.ellipse(
        [bubble_x, bubble_y, bubble_x + bubble_width, bubble_y + bubble_height],
        fill="white",
        outline="black",
    )
    # The bounding box measured from (0, 0) may start at a non-zero offset;
    # subtract it so the glyphs occupy exactly the padded area the bubble was
    # sized for instead of being shifted right/down by that offset.
    draw.text(
        (bubble_x + 15 - left, bubble_y + 10 - top),
        text,
        font=font,
        fill="black",
    )

    return image
|
|
|
|
|
|
# Streamlit page body: title, prompt input, then story and image output.
# NOTE(review): the leading characters in several labels below look like
# mis-encoded emoji — confirm the file's encoding before editing them.
st.title("π¦ΈββοΈ AI Comic Story Generator")
st.write("Enter a prompt to generate a comic-style story and image!")

user_prompt = st.text_input("π Enter your story prompt:")

# Nothing is generated until the text input holds a non-empty prompt.
if user_prompt:
    st.subheader("π AI-Generated Story")
    generated_story = generate_story(user_prompt)
    st.write(generated_story)

    st.subheader("πΌοΈ AI-Generated Image")
    with st.spinner("Generating image..."):
        # The raw user prompt (not the generated story) drives the image model.
        image = image_generator(user_prompt, num_inference_steps=30).images[0]

    # Bubble caption: the text before the first period, capped at 50 characters.
    speech_text = generated_story.split(".")[0][:50]
    image_with_bubble = add_speech_bubble(image, speech_text, position=(50, 50))

    st.image(image_with_bubble, caption="Generated Comic Image", use_container_width=True)
|
|
|