|
import streamlit as st |
|
from PIL import Image |
|
from transformers import pipeline |
|
from gtts import gTTS |
|
import torch |
|
import os |
|
from diffusers import DiffusionPipeline |
|
|
|
|
|
|
|
@st.cache_resource
def _load_caption_generator():
    """Build the BLIP image-captioning pipeline once and reuse it across reruns."""
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


def generate_caption(image_file):
    """Generate a text description of an image with a Hugging Face image-to-text model.

    Args:
        image_file: The uploaded image, given as a file object or a file path.

    Returns:
        The caption text generated for the image.
    """
    image = Image.open(image_file)
    # The pipeline is fetched from a cached loader so the model is not
    # re-instantiated on every call (Streamlit reruns the script on each
    # interaction).
    caption_generator = _load_caption_generator()
    caption_results = caption_generator(image)
    return caption_results[0]["generated_text"]
|
|
|
|
|
|
|
|
|
@st.cache_resource
def _load_story_generator():
    """Build the GPT-2 text-generation pipeline once and reuse it across reruns."""
    return pipeline("text-generation", model="gpt2")


def generate_story(prompt):
    """Generate a story passage from a prompt, aiming for at least 100 words.

    If the first generation is shorter than 100 words, one continuation pass
    is made and appended. The 100-word target is best effort: only a single
    extra pass is attempted.

    Args:
        prompt: The text prompt that seeds the generation.

    Returns:
        The generated story text (it begins with the prompt itself).
    """
    story_generator = _load_story_generator()
    result = story_generator(prompt, max_length=300, num_return_sequences=1)
    story = result[0]["generated_text"]

    if len(story.split()) < 100:
        # Continue from the text produced so far and keep only the newly
        # generated part. Re-running the bare prompt would yield a second,
        # unrelated text that repeats the prompt in the middle of the story.
        continuation = story_generator(
            story,
            max_new_tokens=150,
            num_return_sequences=1,
            return_full_text=False,
        )[0]["generated_text"]
        # No separator is inserted: the continuation already carries
        # whatever leading whitespace the model generated.
        story += continuation
    return story
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_image_generator():
    """Load the Stable Diffusion v1.5 pipeline through Diffusers' DiffusionPipeline.

    The result is cached by ``st.cache_resource`` so the model is loaded once
    rather than on every script rerun.

    Returns:
        The loaded pipeline, moved to CUDA when available and to CPU otherwise.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 is selected only on CUDA; CPU stays in float32.
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    # The dtype must be handed to from_pretrained, otherwise the selection
    # above has no effect on how the weights are loaded.
    pipe = DiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        torch_dtype=torch_dtype,
    )
    return pipe.to(device)
|
|
|
def generate_illustration(prompt):
    """Produce a single illustration for the given text prompt.

    Args:
        prompt: Text prompt describing the image to generate.

    Returns:
        The first image in the pipeline's result.
    """
    # load_image_generator() hands back the shared pipeline; only the first
    # generated image is used.
    return load_image_generator()(prompt).images[0]
|
|
|
|
|
|
|
|
|
def text_to_speech(text, output_file="output.mp3"):
    """Convert text to English speech and save it as an mp3 file.

    Args:
        text: The text to convert. Must contain at least one non-whitespace
            character.
        output_file: File name (or path) the audio is written to.

    Returns:
        The path of the audio file that was written (``output_file``).

    Raises:
        ValueError: If ``text`` is None, empty, or only whitespace.
    """
    # Fail early with a clear message instead of erroring somewhere inside
    # the TTS library when there is nothing to speak.
    if not text or not text.strip():
        raise ValueError("text_to_speech requires non-empty text")

    tts = gTTS(text=text, lang="en")
    tts.save(output_file)
    return output_file
|
|
|
|
|
|
|
|
|
def main():
    """Render the Streamlit page and run the image-to-story flow.

    After an image is uploaded, the page shows the image, its generated
    caption, a story based on that caption, an illustration generated from
    the first 200 characters of the story, and an audio player for the
    spoken story.
    """
    st.title("儿童故事生成应用")
    st.write("上传一张图片,我们将根据图片生成有趣的故事,并转换成语音播放!")

    uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])

    # Nothing to do until the user has provided an image.
    if uploaded_file is None:
        return

    # Echo the uploaded picture back to the user.
    preview = Image.open(uploaded_file)
    st.image(preview, caption="上传的图片", use_column_width=True)

    # Step 1: describe the picture.
    with st.spinner("正在生成图片描述..."):
        description = generate_caption(uploaded_file)
    st.write("图片描述:", description)

    # Step 2: expand the description into a story.
    with st.spinner("正在生成故事..."):
        story_text = generate_story(description)
    st.write("生成的故事:")
    st.write(story_text)

    # Step 3: illustrate the story; only its first 200 characters are used
    # as the image prompt.
    with st.spinner("正在生成插图..."):
        picture = generate_illustration(story_text[:200])
    st.write("### 故事配图:")
    st.image(picture, caption="配图", use_column_width=True)

    # Step 4: read the story aloud.
    with st.spinner("正在转换成语音..."):
        audio_path = text_to_speech(story_text)
    st.audio(audio_path, format="audio/mp3")
|
|
|
# Entry point: build the page only when this file is run as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()