xieqilenb commited on
Commit
a31b925
·
verified ·
1 Parent(s): b232514

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -46
app.py CHANGED
@@ -2,7 +2,6 @@ import streamlit as st
2
  from PIL import Image
3
  from transformers import pipeline
4
  from gtts import gTTS
5
- from diffusers import DiffusionPipeline
6
  import torch
7
  import os
8
 
@@ -50,9 +49,19 @@ def generate_story(prompt):
50
  def load_image_generator():
51
  """
52
  加载稳定扩散模型,使用 Diffusers 库生成插图
 
53
  """
54
  device = "cuda" if torch.cuda.is_available() else "cpu"
55
- pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
 
 
 
 
 
 
 
 
 
56
  pipe = pipe.to(device)
57
  return pipe
58
 
@@ -89,10 +98,9 @@ def text_to_speech(text, output_file="output.mp3"):
89
  # 5. 主函数:构建 Streamlit 交互式应用
90
  # ----------------------------
91
  def main():
92
- st.title("互动式故事生成与配图应用")
93
- st.write("上传一张图片,我们会基于该图片生成描述,并自动生成一个儿童故事。你可以选择继续扩展改故事,也可以结束互动。每个生成的故事段落都会搭配 AI 配图。")
94
 
95
- # 图片上传
96
  uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
97
 
98
  if uploaded_file is not None:
@@ -103,51 +111,26 @@ def main():
103
  # 生成图片描述
104
  with st.spinner("正在生成图片描述..."):
105
  caption = generate_caption(uploaded_file)
106
- st.write("图片描述:", caption)
107
 
108
- # 使用 session_state 保存生成的故事和插图历史
109
- if "story" not in st.session_state:
110
- # 生成初始故事段落(至少100个单词)
111
- with st.spinner("正在生成初始故事..."):
112
- initial_prompt = f"Based on the image caption: '{caption}', generate a complete fairy tale story for children with at least 100 words."
113
- story_segment = generate_story(initial_prompt)
114
- st.session_state.story = story_segment
115
- # 生成初始配图,使用初始故事的前200个字符作为提示
116
- with st.spinner("正在生成初始配图..."):
117
- illustration = generate_illustration(st.session_state.story[:200])
118
- st.session_state.illustrations = [illustration]
119
 
120
- st.write("### 生成的故事:")
121
- st.write(st.session_state.story)
 
 
122
 
123
  st.write("### 故事配图:")
124
- for idx, illus in enumerate(st.session_state.illustrations):
125
- st.image(illus, caption=f"配图段落 {idx+1}", use_column_width=True)
126
-
127
- st.write("---")
128
- st.write("是否继续生成故事?如果不再扩展,请点击“结束互动”。")
129
-
130
- # 接收用户输入的额外情节提示(可选)
131
- user_input = st.text_input("请输入你希望添加的故事情节(可选):", value="")
132
-
133
- col1, col2 = st.columns(2)
134
- if col1.button("继续生成故事"):
135
- # 使用现有故事作为上下文,并附加用户输入的提示语生成新段落
136
- additional_prompt = st.session_state.story + " " + (user_input if user_input.strip() != "" else "")
137
- with st.spinner("正在生成新的故事段落..."):
138
- new_segment = generate_story(additional_prompt)
139
- st.session_state.story += " " + new_segment
140
- # 为新段落生成配图,取新段落前200个字符作为提示
141
- with st.spinner("正在生成新的配图..."):
142
- new_illustration = generate_illustration(new_segment[:200])
143
- st.session_state.illustrations.append(new_illustration)
144
- st.experimental_rerun()
145
-
146
- if col2.button("结束互动"):
147
- with st.spinner("正在生成故事音频..."):
148
- audio_file = text_to_speech(st.session_state.story)
149
- st.write("故事生成完毕!请点击下方按钮播放故事音频。")
150
- st.audio(audio_file, format="audio/mp3")
151
 
 
 
 
 
 
152
  if __name__ == "__main__":
153
  main()
 
2
  from PIL import Image
3
  from transformers import pipeline
4
  from gtts import gTTS
 
5
  import torch
6
  import os
7
 
 
49
  def load_image_generator():
50
  """
51
  加载稳定扩散模型,使用 Diffusers 库生成插图
52
+ 使用 StableDiffusionPipeline 替代 DiffusionPipeline
53
  """
54
  device = "cuda" if torch.cuda.is_available() else "cpu"
55
+
56
+ # 导入 StableDiffusionPipeline
57
+ from diffusers import StableDiffusionPipeline
58
+
59
+ # 对于 GPU,采用 fp16 精度以加速推理;否则使用默认精度
60
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
61
+ pipe = StableDiffusionPipeline.from_pretrained(
62
+ "stabilityai/stable-diffusion-v1-5",
63
+ torch_dtype=torch_dtype
64
+ )
65
  pipe = pipe.to(device)
66
  return pipe
67
 
 
98
  # 5. 主函数:构建 Streamlit 交互式应用
99
  # ----------------------------
100
  def main():
101
+ st.title("儿童故事生成应用")
102
+ st.write("上传一张图片,我们将根据图片生成有趣的故事,并转换成语音播放!")
103
 
 
104
  uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
105
 
106
  if uploaded_file is not None:
 
111
  # 生成图片描述
112
  with st.spinner("正在生成图片描述..."):
113
  caption = generate_caption(uploaded_file)
114
+ st.write("图片描述:", caption)
115
 
116
+ # 根据图片描述生成完整故事
117
+ with st.spinner("正在生成故事..."):
118
+ story = generate_story(caption)
119
+ st.write("生成的故事:")
120
+ st.write(story)
 
 
 
 
 
 
121
 
122
+ # 生成配图
123
+ # 这里使用故事内容的前200个字符作为提示生成配图,实际中可以根据需要调整策略
124
+ with st.spinner("正在生成插图..."):
125
+ illustration = generate_illustration(story[:200])
126
 
127
  st.write("### 故事配图:")
128
+ st.image(illustration, caption="配图", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
+ # 文本转语音
131
+ with st.spinner("正在转换成语音..."):
132
+ audio_file = text_to_speech(story)
133
+ st.audio(audio_file, format="audio/mp3")
134
+
135
  if __name__ == "__main__":
136
  main()