import streamlit as st
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
from PIL import Image, ImageDraw, ImageFont

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Load the text generation model (TinyLlama)
@st.cache_resource
def load_text_model():
    """Build the TinyLlama text-generation pipeline.

    Decorated with ``st.cache_resource`` so Streamlit reuses the returned
    pipeline instead of rebuilding it on every script rerun.

    Returns:
        A ``transformers`` "text-generation" pipeline placed on ``device``.
    """
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Pass the device to the pipeline itself so that the model AND the
    # tokenized inputs end up on the same device. Moving only the model
    # can leave the inputs on the CPU on transformers versions where the
    # pipeline does not infer its device from the model.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )


story_generator = load_text_model()


# Load the image generation model (Stable Diffusion v1.5)
@st.cache_resource
def load_image_model():
    """Load the Stable Diffusion v1.5 pipeline and move it to ``device``.

    Decorated with ``st.cache_resource`` so the pipeline is created once
    and shared across script reruns.

    Returns:
        A ``StableDiffusionPipeline`` instance on ``device``.
    """
    model_id = "runwayml/stable-diffusion-v1-5"
    return StableDiffusionPipeline.from_pretrained(model_id).to(device)


image_generator = load_image_model()


# Function to generate a short story
def generate_story(prompt):
    """Generate a short comic-style story for *prompt*.

    Args:
        prompt: Free-text topic supplied by the user.

    Returns:
        The generated continuation only (the instruction prompt is not
        included), with surrounding whitespace stripped.
    """
    formatted_prompt = (
        f"Write a short comic-style story about: {prompt}\n\nStory:"
    )
    outputs = story_generator(
        formatted_prompt,
        # Budget the *generated* tokens only. ``max_length`` also counts
        # the prompt, so a long prompt would leave little or no room for
        # the story itself.
        max_new_tokens=250,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        num_return_sequences=1,
        truncation=True,
        # Ask the pipeline for the continuation alone instead of removing
        # the prompt afterwards with str.replace, which would also delete
        # any later repetition of the prompt inside the story.
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()


# Function to add a speech bubble to the image
def add_speech_bubble(image, text, position=(50, 50)):
    """Draw a white elliptical speech bubble containing *text* on *image*.

    The image is modified in place and also returned.

    Args:
        image: PIL image to draw on.
        text: Text to place inside the bubble. When empty, nothing is
            drawn and the image is returned unchanged.
        position: ``(x, y)`` of the bubble's top-left corner.

    Returns:
        The same ``image`` object that was passed in.
    """
    if not text:
        # Nothing to say: avoid drawing an empty bubble.
        return image

    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # Arial is not installed on every system; use Pillow's built-in font.
        font = ImageFont.load_default()

    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_width = right - left
    text_height = bottom - top

    # 15 px horizontal and 10 px vertical padding around the text.
    bubble_width, bubble_height = text_width + 30, text_height + 20
    bubble_x, bubble_y = position

    draw.ellipse(
        [bubble_x, bubble_y, bubble_x + bubble_width, bubble_y + bubble_height],
        fill="white",
        outline="black",
    )
    # The bounding box may start at a non-zero offset from the drawing
    # origin; subtract it so the padding is even on every side.
    draw.text(
        (bubble_x + 15 - left, bubble_y + 10 - top),
        text,
        font=font,
        fill="black",
    )
    return image


# Streamlit UI
# Page header and short usage hint.
st.title("🦸‍♂️ AI Comic Story Generator")
st.write("Enter a prompt to generate a comic-style story and image!")

# User input
user_prompt = st.text_input("📝 Enter your story prompt:")

# Ignore whitespace-only input so the models are not run on an empty prompt.
prompt = user_prompt.strip()

if prompt:
    st.subheader("📖 AI-Generated Story")
    with st.spinner("Generating story..."):
        generated_story = generate_story(prompt)
    st.write(generated_story)

    st.subheader("🖼️ AI-Generated Image")
    with st.spinner("Generating image..."):
        image = image_generator(prompt, num_inference_steps=30).images[0]

    # Use the first sentence of the story, capped at 50 characters, as the
    # speech-bubble caption.
    speech_text = generated_story.split(".")[0][:50].strip()
    if speech_text:
        image = add_speech_bubble(image, speech_text, position=(50, 50))

    st.image(image, caption="Generated Comic Image", use_container_width=True)