blabla / app.py
xieqilenb's picture
Update app.py
c3409d1 verified
raw
history blame
4.67 kB
import streamlit as st
from PIL import Image
from transformers import pipeline
from gtts import gTTS
import torch
import os
from diffusers import DiffusionPipeline
# ----------------------------
# 1. 图像描述生成函数
# ----------------------------
def generate_caption(image_file):
"""
使用 Hugging Face pipeline 的 image-to-text 模型生成图片描述
参数:
image_file: 上传的图片文件(文件对象或文件路径)
返回:
caption: 生成的图片描述文本
"""
image = Image.open(image_file)
caption_generator = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
caption_results = caption_generator(image)
caption = caption_results[0]['generated_text'] # 取返回结果的第一个描述
return caption
# ----------------------------
# 2. 故事生成函数
# ----------------------------
def generate_story(prompt):
"""
基于提示语生成故事段落,要求至少100个单词,如果生成的文本字数不够,则再次补充
参数:
prompt: 文本生成的提示语
返回:
story: 生成的故事文本片段
"""
story_generator = pipeline("text-generation", model="gpt2")
result = story_generator(prompt, max_length=300, num_return_sequences=1)
story = result[0]['generated_text']
if len(story.split()) < 100:
additional = story_generator(prompt, max_length=350, num_return_sequences=1)[0]['generated_text']
story += " " + additional
return story
# ----------------------------
# 3. 图像生成(配图)相关函数
# ----------------------------
@st.cache_resource
def load_image_generator():
"""
加载稳定扩散模型,使用 Diffusers 库生成插图
使用 StableDiffusionPipeline 替代 DiffusionPipeline
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
# 导入 StableDiffusionPipeline
# 对于 GPU,采用 fp16 精度以加速推理;否则使用默认精度
torch_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
pipe = pipe.to(device)
return pipe
def generate_illustration(prompt):
"""
基于输入的提示语生成一张配图
参数:
prompt: 用于生成图像的文本提示
返回:
generated_image: 生成的 PIL Image 图像
"""
pipe = load_image_generator()
image_result = pipe(prompt)
generated_image = image_result.images[0]
return generated_image
# ----------------------------
# 4. 文字转语音 (TTS) 函数
# ----------------------------
def text_to_speech(text, output_file="output.mp3"):
"""
将输入文本转换为语音,并保存为 mp3 文件
参数:
text: 要转换的文本
output_file: 保存的音频文件名
返回:
output_file: 转换后生成的音频文件路径
"""
tts = gTTS(text=text, lang="en") # 如需中文,lang 可设置为 "zh-cn"
tts.save(output_file)
return output_file
# ----------------------------
# 5. 主函数:构建 Streamlit 交互式应用
# ----------------------------
def main():
st.title("儿童故事生成应用")
st.write("上传一张图片,我们将根据图片生成有趣的故事,并转换成语音播放!")
uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
# 显示上传的图片
image = Image.open(uploaded_file)
st.image(image, caption="上传的图片", use_column_width=True)
# 生成图片描述
with st.spinner("正在生成图片描述..."):
caption = generate_caption(uploaded_file)
st.write("图片描述:", caption)
# 根据图片描述生成完整故事
with st.spinner("正在生成故事..."):
story = generate_story(caption)
st.write("生成的故事:")
st.write(story)
# 生成配图
# 这里使用故事内容的前200个字符作为提示生成配图,实际中可以根据需要调整策略
with st.spinner("正在生成插图..."):
illustration = generate_illustration(story[:200])
st.write("### 故事配图:")
st.image(illustration, caption="配图", use_column_width=True)
# 文本转语音
with st.spinner("正在转换成语音..."):
audio_file = text_to_speech(story)
st.audio(audio_file, format="audio/mp3")
if __name__ == "__main__":
main()