import streamlit as st
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
st.write(f"Using device: {device}")


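# st.cache_resource caches the returned object across Streamlit reruns, so
# each heavy model is loaded once per server process rather than on every
# user interaction.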
@st.cache_resource
def load_text_model():
    try:
        st.write("⏳ Loading text model...")
        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

        tokenizer = AutoTokenizer.from_pretrained(model_name)

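        # fp16 halves GPU memory use; CPU stays on fp32, where half precision
        # is poorly supported. low_cpu_mem_usage streams weights in to lower
        # peak RAM while loading.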
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True
        ).to(device)

        st.write("✅ Text model loaded successfully!")
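        # device=0 targets the first CUDA device; -1 keeps the pipeline on CPU.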
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if device == "cuda" else -1,
        )
    except Exception as e:
        st.error(f"❌ Error loading text model: {e}")
        return None


story_generator = load_text_model()


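# sd-turbo is a distilled Stable Diffusion variant built to produce images
# in very few denoising steps, which keeps this demo responsive.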
@st.cache_resource
def load_image_model():
    try:
        st.write("⏳ Loading image model...")
        model_id = "stabilityai/sd-turbo"
        model = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        ).to(device)
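        # Attention slicing trades a little speed for a smaller peak memory
        # footprint, which helps on modest GPUs.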
        model.enable_attention_slicing()
        st.write("✅ Image model loaded successfully!")
        return model
    except Exception as e:
        st.error(f"❌ Error loading image model: {e}")
        return None


image_generator = load_image_model()


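# TinyLlama's chat variant was tuned on a chat template; the plain-text
# prompt below works, but tokenizer.apply_chat_template would usually give
# better-structured replies.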
def generate_story(prompt):
    if not story_generator:
        return "❌ Error: Story model not loaded."

    formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"

    try:
        st.write("⏳ Generating story...")
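        # temperature/top_k keep the output varied but coherent, and
        # max_new_tokens bounds only the generated continuation, so the
        # prompt length does not eat into the story budget.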
        story_output = story_generator(
            formatted_prompt,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_k=30,
            num_return_sequences=1,
            truncation=True
        )[0]["generated_text"]
        st.write("✅ Story generated successfully!")
        # The pipeline echoes the prompt, so strip it before display.
        return story_output.replace(formatted_prompt, "").strip()
    except Exception as e:
        st.error(f"❌ Error generating story: {e}")
        return "Error generating story."


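# --- Streamlit UI ---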
st.title("🦸‍♂️ AI Comic Story Generator")
st.write("Enter a prompt to generate a comic-style story and image!")

user_prompt = st.text_input("📝 Enter your story prompt:")

if user_prompt:
    st.subheader("📖 AI-Generated Story")
    generated_story = generate_story(user_prompt)
    st.write(generated_story)

    st.subheader("🖼️ AI-Generated Image")

    if not image_generator:
        st.error("❌ Error: Image model not loaded.")
    else:
        with st.spinner("⏳ Generating image..."):
            try:
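                # SD-Turbo is trained to run without classifier-free guidance,
                # so guidance_scale is pinned to 0.0 per its model card; a few
                # steps at a small resolution keep generation fast.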
                image = image_generator(
                    user_prompt,
                    num_inference_steps=8,
                    guidance_scale=0.0,
                    height=256, width=256
                ).images[0]
                st.write("✅ Image generated successfully!")
                st.image(image, caption="Generated Comic Image", use_container_width=True)
            except Exception as e:
                st.error(f"❌ Error generating image: {e}")