# comic/app.py
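"""Streamlit app: generates a short comic-style story with TinyLlama and a
matching image with SD-Turbo. Assumes streamlit, torch, transformers, and
diffusers are installed (low_cpu_mem_usage also expects accelerate)."""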
import streamlit as st
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
st.write(f"Using device: {device}") # Debug message
# Load text model (TinyLlama) with optimizations
@st.cache_resource
def load_text_model():
    try:
        st.write("⏳ Loading text model...")
        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Load in FP16 on GPU (FP32 on CPU) to cut memory use;
        # see the 8-bit quantization sketch below for an alternative
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
        ).to(device)
        st.write("✅ Text model loaded successfully!")
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if device == "cuda" else -1,
        )
    except Exception as e:
        st.error(f"❌ Error loading text model: {e}")
        return None
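# An 8-bit quantized variant would cut memory further than FP16. A minimal
# sketch, untested here; assumes bitsandbytes and accelerate are installed
# alongside transformers:
#
#     from transformers import BitsAndBytesConfig
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name,
#         quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#         device_map="auto",  # let accelerate/bitsandbytes place the weights
#     )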
story_generator = load_text_model()
# Load image model (Stable Diffusion) with optimizations
@st.cache_resource
def load_image_model():
    try:
        st.write("⏳ Loading image model...")
        model_id = "stabilityai/sd-turbo"
        model = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)
        model.enable_attention_slicing()  # Trade a little speed for lower peak memory
        st.write("✅ Image model loaded successfully!")
        return model
    except Exception as e:
        st.error(f"❌ Error loading image model: {e}")
        return None
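# If memory is still tight, diffusers offers further opt-ins. A sketch of two
# common ones (enable_model_cpu_offload assumes accelerate is installed and
# should be called instead of .to(device), not after it):
#
#     model.enable_vae_slicing()        # decode images through the VAE one at a time
#     model.enable_model_cpu_offload()  # keep submodules on CPU until they are needed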
image_generator = load_image_model()
# Function to generate a short story
def generate_story(prompt):
    if not story_generator:
        return "❌ Error: Story model not loaded."
    formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"
    try:
        st.write("⏳ Generating story...")
        story_output = story_generator(
            formatted_prompt,
            max_new_tokens=100,  # Cap generated tokens for speed (max_length would count the prompt too)
            do_sample=True,
            temperature=0.7,
            top_k=30,
            num_return_sequences=1,
            truncation=True,
        )[0]["generated_text"]
        st.write("✅ Story generated successfully!")
        # The pipeline echoes the prompt; strip it so only the story remains
        return story_output.replace(formatted_prompt, "").strip()
    except Exception as e:
        st.error(f"❌ Error generating story: {e}")
        return "Error generating story."
# Streamlit UI
st.title("🦸‍♂️ AI Comic Story Generator")
st.write("Enter a prompt to generate a comic-style story and image!")
# User input
user_prompt = st.text_input("📝 Enter your story prompt:")
if user_prompt:
    st.subheader("📖 AI-Generated Story")
    generated_story = generate_story(user_prompt)
    st.write(generated_story)
    st.subheader("🖼️ AI-Generated Image")
    if not image_generator:
        st.error("❌ Error: Image model not loaded.")
    else:
        with st.spinner("⏳ Generating image..."):
            try:
                image = image_generator(
                    user_prompt,
                    num_inference_steps=8,  # Few steps for speed (SD-Turbo is distilled for this)
                    guidance_scale=0.0,  # SD-Turbo's model card recommends disabling CFG
                    height=256, width=256,  # Smaller canvas to reduce memory usage
                ).images[0]
                st.write("✅ Image generated successfully!")
                st.image(image, caption="Generated Comic Image", use_container_width=True)
            except Exception as e:
                st.error(f"❌ Error generating image: {e}")
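# Launch locally with:  streamlit run app.py
# (On a Hugging Face Space using the Streamlit SDK, this file is run automatically.)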