"""Streamlit app that turns a user prompt into a short comic-style story and image.

The story comes from TinyLlama-1.1B-Chat via a ``transformers`` text-generation
pipeline; the image comes from ``stabilityai/sd-turbo`` via ``diffusers``.
"""

import streamlit as st
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
st.write(f"Using device: {device}")  # Debug message

# Half precision on GPU to save memory; full precision on CPU.
_DTYPE = torch.float16 if device == "cuda" else torch.float32


# Load text model (TinyLlama)
@st.cache_resource
def _build_text_pipeline():
    """Load TinyLlama and wrap it in a text-generation pipeline.

    Deliberately lets exceptions propagate: ``st.cache_resource`` does not
    memoize a call that raises, so a failed load is retried on the next rerun
    instead of a ``None`` being cached for the lifetime of the process.
    """
    st.write("⏳ Loading text model...")
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # FP16 weights on GPU, FP32 on CPU (no quantization is applied).
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=_DTYPE,
        low_cpu_mem_usage=True,
    ).to(device)
    st.write("✅ Text model loaded successfully!")
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if device == "cuda" else -1,
    )


def load_text_model():
    """Return the cached text-generation pipeline, or ``None`` if loading fails.

    Any loading error is shown in the UI via ``st.error``.
    """
    try:
        return _build_text_pipeline()
    except Exception as e:
        st.error(f"❌ Error loading text model: {e}")
        return None


story_generator = load_text_model()


# Load image model (Stable Diffusion Turbo)
@st.cache_resource
def _build_image_pipeline():
    """Load the sd-turbo diffusion pipeline onto ``device``.

    Raises on failure for the same reason as ``_build_text_pipeline``: a
    raised exception is not cached, so the next rerun retries the load.
    """
    st.write("⏳ Loading image model...")
    model_id = "stabilityai/sd-turbo"
    model = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=_DTYPE,
    ).to(device)
    # Compute attention in slices: lower peak memory at some speed cost.
    model.enable_attention_slicing()
    st.write("✅ Image model loaded successfully!")
    return model


def load_image_model():
    """Return the cached diffusion pipeline, or ``None`` if loading fails.

    Any loading error is shown in the UI via ``st.error``.
    """
    try:
        return _build_image_pipeline()
    except Exception as e:
        st.error(f"❌ Error loading image model: {e}")
        return None


image_generator = load_image_model()


def generate_story(prompt):
    """Generate a short comic-style story about ``prompt``.

    Args:
        prompt: Free-text story idea entered by the user.

    Returns:
        The generated story text (the instruction prompt itself is not
        included), or an error string if the model is not loaded or
        generation fails.
    """
    if not story_generator:
        return "❌ Error: Story model not loaded."

    formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"
    try:
        st.write("⏳ Generating story...")
        story_output = story_generator(
            formatted_prompt,
            # Budget for generated tokens only; unlike max_length it does not
            # shrink as the prompt gets longer.
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_k=30,
            num_return_sequences=1,
            truncation=True,
            # Return only the continuation, so the prompt need not be
            # stripped out of the decoded text by string replacement.
            return_full_text=False,
        )[0]["generated_text"]
        st.write("✅ Story generated successfully!")
        return story_output.strip()
    except Exception as e:
        st.error(f"❌ Error generating story: {e}")
        return "Error generating story."


# Streamlit UI
st.title("🦸‍♂️ AI Comic Story Generator")
st.write("Enter a prompt to generate a comic-style story and image!")

# User input
user_prompt = st.text_input("📝 Enter your story prompt:")

if user_prompt:
    st.subheader("📖 AI-Generated Story")
    generated_story = generate_story(user_prompt)
    st.write(generated_story)

    st.subheader("🖼️ AI-Generated Image")
    if not image_generator:
        st.error("❌ Error: Image model not loaded.")
    else:
        with st.spinner("⏳ Generating image..."):
            try:
                image = image_generator(
                    user_prompt,
                    num_inference_steps=8,  # Few steps: turbo models are distilled for short schedules
                    # sd-turbo is meant to run without classifier-free guidance;
                    # the default (7.5) doubles UNet work and degrades output.
                    guidance_scale=0.0,
                    # NOTE(review): sd-turbo is documented at 512x512; 256x256
                    # trades quality for memory — confirm this is acceptable.
                    height=256,
                    width=256,
                ).images[0]
                st.write("✅ Image generated successfully!")
                st.image(image, caption="Generated Comic Image", use_container_width=True)
            except Exception as e:
                st.error(f"❌ Error generating image: {e}")