Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,6 @@ import streamlit as st
|
|
2 |
from PIL import Image
|
3 |
from transformers import pipeline
|
4 |
from gtts import gTTS
|
5 |
-
from diffusers import DiffusionPipeline
|
6 |
import torch
|
7 |
import os
|
8 |
|
@@ -50,9 +49,19 @@ def generate_story(prompt):
|
|
50 |
def load_image_generator():
|
51 |
"""
|
52 |
加载稳定扩散模型,使用 Diffusers 库生成插图
|
|
|
53 |
"""
|
54 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
pipe = pipe.to(device)
|
57 |
return pipe
|
58 |
|
@@ -89,10 +98,9 @@ def text_to_speech(text, output_file="output.mp3"):
|
|
89 |
# 5. 主函数:构建 Streamlit 交互式应用
|
90 |
# ----------------------------
|
91 |
def main():
|
92 |
-
st.title("
|
93 |
-
st.write("
|
94 |
|
95 |
-
# 图片上传
|
96 |
uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
|
97 |
|
98 |
if uploaded_file is not None:
|
@@ -103,51 +111,26 @@ def main():
|
|
103 |
# 生成图片描述
|
104 |
with st.spinner("正在生成图片描述..."):
|
105 |
caption = generate_caption(uploaded_file)
|
106 |
-
st.write("
|
107 |
|
108 |
-
#
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
story_segment = generate_story(initial_prompt)
|
114 |
-
st.session_state.story = story_segment
|
115 |
-
# 生成初始配图,使用初始故事的前200个字符作为提示
|
116 |
-
with st.spinner("正在生成初始配图..."):
|
117 |
-
illustration = generate_illustration(st.session_state.story[:200])
|
118 |
-
st.session_state.illustrations = [illustration]
|
119 |
|
120 |
-
|
121 |
-
|
|
|
|
|
122 |
|
123 |
st.write("### 故事配图:")
|
124 |
-
|
125 |
-
st.image(illus, caption=f"配图段落 {idx+1}", use_column_width=True)
|
126 |
-
|
127 |
-
st.write("---")
|
128 |
-
st.write("是否继续生成故事?如果不再扩展,请点击“结束互动”。")
|
129 |
-
|
130 |
-
# 接收用户输入的额外情节提示(可选)
|
131 |
-
user_input = st.text_input("请输入你希望添加的故事情节(可选):", value="")
|
132 |
-
|
133 |
-
col1, col2 = st.columns(2)
|
134 |
-
if col1.button("继续生成故事"):
|
135 |
-
# 使用现有故事作为上下文,并附加用户输入的提示语生成新段落
|
136 |
-
additional_prompt = st.session_state.story + " " + (user_input if user_input.strip() != "" else "")
|
137 |
-
with st.spinner("正在生成新的故事段落..."):
|
138 |
-
new_segment = generate_story(additional_prompt)
|
139 |
-
st.session_state.story += " " + new_segment
|
140 |
-
# 为新段落生成配图,取新段落前200个字符作为提示
|
141 |
-
with st.spinner("正在生成新的配图..."):
|
142 |
-
new_illustration = generate_illustration(new_segment[:200])
|
143 |
-
st.session_state.illustrations.append(new_illustration)
|
144 |
-
st.experimental_rerun()
|
145 |
-
|
146 |
-
if col2.button("结束互动"):
|
147 |
-
with st.spinner("正在生成故事音频..."):
|
148 |
-
audio_file = text_to_speech(st.session_state.story)
|
149 |
-
st.write("故事生成完毕!请点击下方按钮播放故事音频。")
|
150 |
-
st.audio(audio_file, format="audio/mp3")
|
151 |
|
|
|
|
|
|
|
|
|
|
|
152 |
if __name__ == "__main__":
|
153 |
main()
|
|
|
2 |
from PIL import Image
|
3 |
from transformers import pipeline
|
4 |
from gtts import gTTS
|
|
|
5 |
import torch
|
6 |
import os
|
7 |
|
|
|
49 |
def load_image_generator(model_id="stable-diffusion-v1-5/stable-diffusion-v1-5"):
    """Load a Stable Diffusion pipeline (Diffusers library) for illustrations.

    Uses ``StableDiffusionPipeline`` rather than the generic
    ``DiffusionPipeline``.

    Args:
        model_id: Hugging Face Hub repository id (or local path) of the
            checkpoint to load. Defaults to the Stable Diffusion v1.5 weights.

    Returns:
        The loaded ``StableDiffusionPipeline``, moved to CUDA when a GPU is
        available and to the CPU otherwise.
    """
    # Imported inside the function, as in the original change, so diffusers
    # is only imported when a pipeline is actually requested.
    from diffusers import StableDiffusionPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # On GPU use fp16 to speed up inference; otherwise keep full precision.
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    # NOTE: the v1.5 checkpoint is not published under the `stabilityai`
    # organization; the previous id "stabilityai/stable-diffusion-v1-5"
    # cannot be resolved on the Hub, hence the corrected default above.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
    )
    return pipe.to(device)
|
67 |
|
|
|
98 |
# 5. 主函数:构建 Streamlit 交互式应用
|
99 |
# ----------------------------
|
100 |
def main():
|
101 |
+
st.title("儿童故事生成应用")
|
102 |
+
st.write("上传一张图片,我们将根据图片生成有趣的故事,并转换成语音播放!")
|
103 |
|
|
|
104 |
uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
|
105 |
|
106 |
if uploaded_file is not None:
|
|
|
111 |
# 生成图片描述
|
112 |
with st.spinner("正在生成图片描述..."):
|
113 |
caption = generate_caption(uploaded_file)
|
114 |
+
st.write("图片描述:", caption)
|
115 |
|
116 |
+
# 根据图片描述生成完整故事
|
117 |
+
with st.spinner("正在生成故事..."):
|
118 |
+
story = generate_story(caption)
|
119 |
+
st.write("生成的故事:")
|
120 |
+
st.write(story)
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
+
# 生成配图
|
123 |
+
# 这里使用故事内容的前200个字符作为提示生成配图,实际中可以根据需要调整策略
|
124 |
+
with st.spinner("正在生成插图..."):
|
125 |
+
illustration = generate_illustration(story[:200])
|
126 |
|
127 |
st.write("### 故事配图:")
|
128 |
+
st.image(illustration, caption="配图", use_column_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
+
# 文本转语音
|
131 |
+
with st.spinner("正在转换成语音..."):
|
132 |
+
audio_file = text_to_speech(story)
|
133 |
+
st.audio(audio_file, format="audio/mp3")
|
134 |
+
|
135 |
if __name__ == "__main__":
|
136 |
main()
|