xieqilenb commited on
Commit
6b8bebd
·
verified ·
1 Parent(s): fa2b8af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -37
app.py CHANGED
@@ -1,9 +1,13 @@
1
  import streamlit as st
2
  from PIL import Image
3
  from transformers import pipeline
 
 
 
 
4
 
5
  # ----------------------------
6
- # 生成图像描述函数
7
  # ----------------------------
8
  def generate_caption(image_file):
9
  """
@@ -13,67 +17,85 @@ def generate_caption(image_file):
13
  返回:
14
  caption: 生成的图片描述文本
15
  """
16
- # 打开图片(如果上传的是文件流,可以直接传给 pipeline)
17
  image = Image.open(image_file)
18
- # 利用 image-to-text pipeline 加载 Salesforce/blip-image-captioning-base 模型
19
  caption_generator = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
20
- # 直接将图片传入 pipeline,返回结果是一个列表,每个元素是一个字典
21
  caption_results = caption_generator(image)
22
- caption = caption_results[0]['generated_text'] # 取第一个结果
23
  return caption
24
 
25
  # ----------------------------
26
- # 基于图片描述生成完整故事的函数
27
  # ----------------------------
28
- def generate_story(caption):
29
  """
30
- 基于图片描述生成完整故事,确保生成的故事至少包含100个单词。
31
  参数:
32
- caption: 图片描述文本
33
  返回:
34
- story: 生成的故事文本
35
  """
36
- # 使用 text-generation pipeline 加载 GPT-2 模型
37
  story_generator = pipeline("text-generation", model="gpt2")
38
- # 构建生成故事的提示语
39
- prompt = f"Based on the following image caption: '{caption}', generate a complete fairy tale story for children with at least 100 words. "
40
-
41
- # 生成故事文本
42
  result = story_generator(prompt, max_length=300, num_return_sequences=1)
43
  story = result[0]['generated_text']
44
 
45
- # 简单检查生成的故事单词数是否达到100,否则再生成部分文本补充
46
  if len(story.split()) < 100:
47
  additional = story_generator(prompt, max_length=350, num_return_sequences=1)[0]['generated_text']
48
  story += " " + additional
49
  return story
50
 
51
  # ----------------------------
52
- # 文字转语音 (TTS) 函数
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  # ----------------------------
54
  def text_to_speech(text, output_file="output.mp3"):
55
  """
56
- 将文本转换为语音并保存为 mp3 文件
57
  参数:
58
  text: 要转换的文本
59
  output_file: 保存的音频文件名
60
  返回:
61
- output_file: 转换后的音频文件路径
62
  """
63
- from gtts import gTTS
64
- # 这里语言参数设为英语 "en",
65
- # 如需中文可修改 lang="zh-cn",但对应文本生成模型也需生成中文
66
- tts = gTTS(text=text, lang="en")
67
  tts.save(output_file)
68
  return output_file
69
 
70
  # ----------------------------
71
- # 主函数:构建 Streamlit 界面
72
  # ----------------------------
73
  def main():
74
- st.title("儿童故事生成应用")
75
- st.write("上传一张图片,我们将根据图片生成有趣的故事,并转换成语音播放!")
76
 
 
77
  uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
78
 
79
  if uploaded_file is not None:
@@ -84,18 +106,51 @@ def main():
84
  # 生成图片描述
85
  with st.spinner("正在生成图片描述..."):
86
  caption = generate_caption(uploaded_file)
87
- st.write("图片描述:", caption)
88
 
89
- # 根据图片描述生成完整故事
90
- with st.spinner("正在生成故事..."):
91
- story = generate_story(caption)
92
- st.write("生成的故事:")
93
- st.write(story)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- # 文本转语音
96
- with st.spinner("正在转换成语音..."):
97
- audio_file = text_to_speech(story)
98
- st.audio(audio_file, format="audio/mp3")
99
-
100
  if __name__ == "__main__":
101
  main()
 
1
  import streamlit as st
2
  from PIL import Image
3
  from transformers import pipeline
4
+ from gtts import gTTS
5
+ from diffusers import StableDiffusionPipeline
6
+ import torch
7
+ import os
8
 
9
  # ----------------------------
10
+ # 1. 图像描述生成函数
11
  # ----------------------------
12
  def generate_caption(image_file):
13
  """
 
17
  返回:
18
  caption: 生成的图片描述文本
19
  """
 
20
  image = Image.open(image_file)
 
21
  caption_generator = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
 
22
  caption_results = caption_generator(image)
23
+ caption = caption_results[0]['generated_text'] # 取返回结果的第一个描述
24
  return caption
25
 
26
  # ----------------------------
27
+ # 2. 故事生成函数
28
  # ----------------------------
29
+ def generate_story(prompt):
30
  """
31
+ 基于提示语生成故事段落,要求至少100个单词,如果生成的文本字数不够,则再次补充
32
  参数:
33
+ prompt: 文本生成的提示语
34
  返回:
35
+ story: 生成的故事文本片段
36
  """
 
37
  story_generator = pipeline("text-generation", model="gpt2")
 
 
 
 
38
  result = story_generator(prompt, max_length=300, num_return_sequences=1)
39
  story = result[0]['generated_text']
40
 
 
41
  if len(story.split()) < 100:
42
  additional = story_generator(prompt, max_length=350, num_return_sequences=1)[0]['generated_text']
43
  story += " " + additional
44
  return story
45
 
46
  # ----------------------------
47
+ # 3. 图像生成(配图)相关函数
48
+ # ----------------------------
49
+ @st.cache_resource
50
+ def load_image_generator():
51
+ """
52
+ 加载稳定扩散模型,使用 Diffusers 库生成插图
53
+ """
54
+ device = "cuda" if torch.cuda.is_available() else "cpu"
55
+ pipe = StableDiffusionPipeline.from_pretrained(
56
+ "stabilityai/stable-diffusion-v1-5",
57
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
58
+ )
59
+ pipe = pipe.to(device)
60
+ return pipe
61
+
62
+ def generate_illustration(prompt):
63
+ """
64
+ 基于输入的提示语生成一张配图
65
+ 参数:
66
+ prompt: 用于生成图像的文本提示
67
+ 返回:
68
+ generated_image: 生成的 PIL Image 图像
69
+ """
70
+ pipe = load_image_generator()
71
+ image_result = pipe(prompt)
72
+ generated_image = image_result.images[0]
73
+ return generated_image
74
+
75
+ # ----------------------------
76
+ # 4. 文字转语音 (TTS) 函数
77
  # ----------------------------
78
  def text_to_speech(text, output_file="output.mp3"):
79
  """
80
+ 将输入文本转换为语音,并保存为 mp3 文件
81
  参数:
82
  text: 要转换的文本
83
  output_file: 保存的音频文件名
84
  返回:
85
+ output_file: 转换后生成的音频文件路径
86
  """
87
+ tts = gTTS(text=text, lang="en") # 如需中文,lang 可设置为 "zh-cn"
 
 
 
88
  tts.save(output_file)
89
  return output_file
90
 
91
  # ----------------------------
92
+ # 5. 主函数:构建 Streamlit 交互式应用
93
  # ----------------------------
94
  def main():
95
+ st.title("互动式故事生成与配图应用")
96
+ st.write("上传一张图片,我们会基于该图片生成描述,并自动生成一个儿童故事。你可以选择继续扩展改故事,也可以结束互动。每个生成的故事段落都会搭配 AI 配图。")
97
 
98
+ # 图片上传
99
  uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
100
 
101
  if uploaded_file is not None:
 
106
  # 生成图片描述
107
  with st.spinner("正在生成图片描述..."):
108
  caption = generate_caption(uploaded_file)
109
+ st.write("图片描述:", caption)
110
 
111
+ # 使用 session_state 保存生成的故事和插图历史
112
+ if "story" not in st.session_state:
113
+ # 生成初始故事段落(至少100个单词)
114
+ with st.spinner("正在生成初始故事..."):
115
+ initial_prompt = f"Based on the image caption: '{caption}', generate a complete fairy tale story for children with at least 100 words."
116
+ story_segment = generate_story(initial_prompt)
117
+ st.session_state.story = story_segment
118
+ # 生成初始配图,使用初始故事的前200个字符作为提示
119
+ with st.spinner("正在生成初始配图..."):
120
+ illustration = generate_illustration(st.session_state.story[:200])
121
+ st.session_state.illustrations = [illustration]
122
+
123
+ st.write("### 生成的故事:")
124
+ st.write(st.session_state.story)
125
+
126
+ st.write("### 故事配图:")
127
+ for idx, illus in enumerate(st.session_state.illustrations):
128
+ st.image(illus, caption=f"配图段落 {idx+1}", use_column_width=True)
129
+
130
+ st.write("---")
131
+ st.write("是否继续生成故事?如果不再扩展,请点击“结束互动”。")
132
+
133
+ # 接收用户输入的额外情节提示(可选)
134
+ user_input = st.text_input("请输入你希望添加的故事情节(可选):", value="")
135
+
136
+ col1, col2 = st.columns(2)
137
+ if col1.button("继续生成故事"):
138
+ # 使用现有故事作为上下文,并附加用户输入的提示语生成新段落
139
+ additional_prompt = st.session_state.story + " " + (user_input if user_input.strip() != "" else "")
140
+ with st.spinner("正在生成新的故事段落..."):
141
+ new_segment = generate_story(additional_prompt)
142
+ st.session_state.story += " " + new_segment
143
+ # 为新段落生成配图,取新段落前200个字符作为提示
144
+ with st.spinner("正在生成新的配图..."):
145
+ new_illustration = generate_illustration(new_segment[:200])
146
+ st.session_state.illustrations.append(new_illustration)
147
+ st.experimental_rerun()
148
+
149
+ if col2.button("结束互动"):
150
+ with st.spinner("正在生成故事音频..."):
151
+ audio_file = text_to_speech(st.session_state.story)
152
+ st.write("故事生成完毕!请点击下方按钮播放故事音频。")
153
+ st.audio(audio_file, format="audio/mp3")
154
 
 
 
 
 
 
155
  if __name__ == "__main__":
156
  main()