import streamlit as st from PIL import Image from transformers import pipeline from gtts import gTTS import torch import os from diffusers import DiffusionPipeline # ---------------------------- # 1. 图像描述生成函数 # ---------------------------- def generate_caption(image_file): """ 使用 Hugging Face pipeline 的 image-to-text 模型生成图片描述 参数: image_file: 上传的图片文件(文件对象或文件路径) 返回: caption: 生成的图片描述文本 """ image = Image.open(image_file) caption_generator = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base") caption_results = caption_generator(image) caption = caption_results[0]['generated_text'] # 取返回结果的第一个描述 return caption # ---------------------------- # 2. 故事生成函数 # ---------------------------- def generate_story(prompt): """ 基于提示语生成故事段落,要求至少100个单词,如果生成的文本字数不够,则再次补充 参数: prompt: 文本生成的提示语 返回: story: 生成的故事文本片段 """ story_generator = pipeline("text-generation", model="gpt2") result = story_generator(prompt, max_length=300, num_return_sequences=1) story = result[0]['generated_text'] if len(story.split()) < 100: additional = story_generator(prompt, max_length=350, num_return_sequences=1)[0]['generated_text'] story += " " + additional return story # ---------------------------- # 3. 图像生成(配图)相关函数 # ---------------------------- @st.cache_resource def load_image_generator(): """ 加载稳定扩散模型,使用 Diffusers 库生成插图 使用 StableDiffusionPipeline 替代 DiffusionPipeline """ device = "cuda" if torch.cuda.is_available() else "cpu" # 导入 StableDiffusionPipeline # 对于 GPU,采用 fp16 精度以加速推理;否则使用默认精度 torch_dtype = torch.float16 if device == "cuda" else torch.float32 pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") pipe = pipe.to(device) return pipe def generate_illustration(prompt): """ 基于输入的提示语生成一张配图 参数: prompt: 用于生成图像的文本提示 返回: generated_image: 生成的 PIL Image 图像 """ pipe = load_image_generator() image_result = pipe(prompt) generated_image = image_result.images[0] return generated_image # ---------------------------- # 4. 文字转语音 (TTS) 函数 # ---------------------------- def text_to_speech(text, output_file="output.mp3"): """ 将输入文本转换为语音,并保存为 mp3 文件 参数: text: 要转换的文本 output_file: 保存的音频文件名 返回: output_file: 转换后生成的音频文件路径 """ tts = gTTS(text=text, lang="en") # 如需中文,lang 可设置为 "zh-cn" tts.save(output_file) return output_file # ---------------------------- # 5. 主函数:构建 Streamlit 交互式应用 # ---------------------------- def main(): st.title("儿童故事生成应用") st.write("上传一张图片,我们将根据图片生成有趣的故事,并转换成语音播放!") uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"]) if uploaded_file is not None: # 显示上传的图片 image = Image.open(uploaded_file) st.image(image, caption="上传的图片", use_column_width=True) # 生成图片描述 with st.spinner("正在生成图片描述..."): caption = generate_caption(uploaded_file) st.write("图片描述:", caption) # 根据图片描述生成完整故事 with st.spinner("正在生成故事..."): story = generate_story(caption) st.write("生成的故事:") st.write(story) # 生成配图 # 这里使用故事内容的前200个字符作为提示生成配图,实际中可以根据需要调整策略 with st.spinner("正在生成插图..."): illustration = generate_illustration(story[:200]) st.write("### 故事配图:") st.image(illustration, caption="配图", use_column_width=True) # 文本转语音 with st.spinner("正在转换成语音..."): audio_file = text_to_speech(story) st.audio(audio_file, format="audio/mp3") if __name__ == "__main__": main()