Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,13 @@
|
|
1 |
import streamlit as st
|
2 |
from PIL import Image
|
3 |
from transformers import pipeline
|
|
|
|
|
|
|
|
|
4 |
|
5 |
# ----------------------------
|
6 |
-
#
|
7 |
# ----------------------------
|
8 |
def generate_caption(image_file):
|
9 |
"""
|
@@ -13,67 +17,85 @@ def generate_caption(image_file):
|
|
13 |
返回:
|
14 |
caption: 生成的图片描述文本
|
15 |
"""
|
16 |
-
# 打开图片(如果上传的是文件流,可以直接传给 pipeline)
|
17 |
image = Image.open(image_file)
|
18 |
-
# 利用 image-to-text pipeline 加载 Salesforce/blip-image-captioning-base 模型
|
19 |
caption_generator = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
|
20 |
-
# 直接将图片传入 pipeline,返回结果是一个列表,每个元素是一个字典
|
21 |
caption_results = caption_generator(image)
|
22 |
-
caption = caption_results[0]['generated_text'] #
|
23 |
return caption
|
24 |
|
25 |
# ----------------------------
|
26 |
-
#
|
27 |
# ----------------------------
|
28 |
-
def generate_story(
|
29 |
"""
|
30 |
-
|
31 |
参数:
|
32 |
-
|
33 |
返回:
|
34 |
-
story:
|
35 |
"""
|
36 |
-
# 使用 text-generation pipeline 加载 GPT-2 模型
|
37 |
story_generator = pipeline("text-generation", model="gpt2")
|
38 |
-
# 构建生成故事的提示语
|
39 |
-
prompt = f"Based on the following image caption: '{caption}', generate a complete fairy tale story for children with at least 100 words. "
|
40 |
-
|
41 |
-
# 生成故事文本
|
42 |
result = story_generator(prompt, max_length=300, num_return_sequences=1)
|
43 |
story = result[0]['generated_text']
|
44 |
|
45 |
-
# 简单检查生成的故事单词数是否达到100,否则再生成部分文本补充
|
46 |
if len(story.split()) < 100:
|
47 |
additional = story_generator(prompt, max_length=350, num_return_sequences=1)[0]['generated_text']
|
48 |
story += " " + additional
|
49 |
return story
|
50 |
|
51 |
# ----------------------------
|
52 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
# ----------------------------
|
54 |
def text_to_speech(text, output_file="output.mp3"):
|
55 |
"""
|
56 |
-
|
57 |
参数:
|
58 |
text: 要转换的文本
|
59 |
output_file: 保存的音频文件名
|
60 |
返回:
|
61 |
-
output_file:
|
62 |
"""
|
63 |
-
|
64 |
-
# 这里语言参数设为英语 "en",
|
65 |
-
# 如需中文可修改 lang="zh-cn",但对应文本生成模型也需生成中文
|
66 |
-
tts = gTTS(text=text, lang="en")
|
67 |
tts.save(output_file)
|
68 |
return output_file
|
69 |
|
70 |
# ----------------------------
|
71 |
-
# 主函数:构建 Streamlit
|
72 |
# ----------------------------
|
73 |
def main():
|
74 |
-
st.title("
|
75 |
-
st.write("
|
76 |
|
|
|
77 |
uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
|
78 |
|
79 |
if uploaded_file is not None:
|
@@ -84,18 +106,51 @@ def main():
|
|
84 |
# 生成图片描述
|
85 |
with st.spinner("正在生成图片描述..."):
|
86 |
caption = generate_caption(uploaded_file)
|
87 |
-
st.write("
|
88 |
|
89 |
-
#
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
-
# 文本转语音
|
96 |
-
with st.spinner("正在转换成语音..."):
|
97 |
-
audio_file = text_to_speech(story)
|
98 |
-
st.audio(audio_file, format="audio/mp3")
|
99 |
-
|
100 |
if __name__ == "__main__":
|
101 |
main()
|
|
|
1 |
import streamlit as st
|
2 |
from PIL import Image
|
3 |
from transformers import pipeline
|
4 |
+
from gtts import gTTS
|
5 |
+
from diffusers import StableDiffusionPipeline
|
6 |
+
import torch
|
7 |
+
import os
|
8 |
|
9 |
# ----------------------------
|
10 |
+
# 1. 图像描述生成函数
|
11 |
# ----------------------------
|
12 |
@st.cache_resource
def _load_caption_generator():
    """Build the BLIP image-captioning pipeline once and reuse it.

    Constructing the pipeline inside ``generate_caption`` reloaded the model
    on every call; ``st.cache_resource`` keeps a single shared instance.
    """
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


def generate_caption(image_file):
    """Generate a text caption describing an image.

    Args:
        image_file: Anything accepted by ``PIL.Image.open``; the app passes
            the object returned by ``st.file_uploader``.

    Returns:
        caption: The generated description text of the image.
    """
    image = Image.open(image_file)
    caption_results = _load_caption_generator()(image)
    # The pipeline returns a list of dicts; use the first description.
    return caption_results[0]['generated_text']
|
25 |
|
26 |
# ----------------------------
|
27 |
+
# 2. 故事生成函数
|
28 |
# ----------------------------
|
29 |
+
@st.cache_resource
def _load_story_generator():
    """Build the GPT-2 text-generation pipeline once and reuse it."""
    return pipeline("text-generation", model="gpt2")


def generate_story(prompt):
    """Generate story text from a prompt, aiming for at least 100 words.

    The returned string is the prompt followed by the generated continuation
    (the pipeline's full-text output). If the result is still shorter than
    100 words, generation continues from the text produced so far instead of
    restarting from the prompt, so the prompt is never repeated in the story.

    Args:
        prompt: Prompt text for the text-generation model.

    Returns:
        story: The prompt plus the generated story text.
    """
    story_generator = _load_story_generator()
    min_words = 100
    max_extra_rounds = 2  # bounded, so a degenerate model output cannot loop forever

    def _extend(text):
        # max_new_tokens budgets only the continuation; max_length would also
        # count the prompt tokens and leave no room once the prompt is long.
        # handle_long_generation="hole" left-truncates the model input when
        # prompt + new tokens would exceed the model's maximum length.
        outputs = story_generator(
            text,
            max_new_tokens=300,
            num_return_sequences=1,
            handle_long_generation="hole",
        )
        return outputs[0]['generated_text']

    story = _extend(prompt)
    for _ in range(max_extra_rounds):
        if len(story.split()) >= min_words:
            break
        story = _extend(story)
    return story
|
45 |
|
46 |
# ----------------------------
|
47 |
+
# 3. 图像生成(配图)相关函数
|
48 |
+
# ----------------------------
|
49 |
+
@st.cache_resource
def load_image_generator():
    """Load the Stable Diffusion pipeline (Diffusers) used for illustrations.

    Runs on CUDA with float16 weights when a GPU is available, otherwise on
    CPU with float32 weights. Cached with ``st.cache_resource``.

    Returns:
        The ``StableDiffusionPipeline`` moved to the selected device.
    """
    if torch.cuda.is_available():
        target_device, weights_dtype = "cuda", torch.float16
    else:
        target_device, weights_dtype = "cpu", torch.float32

    # NOTE(review): confirm this repo id resolves on the Hugging Face Hub.
    sd_pipeline = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-v1-5",
        torch_dtype=weights_dtype,
    )
    return sd_pipeline.to(target_device)
|
61 |
+
|
62 |
+
def generate_illustration(prompt):
    """Generate one illustration from a text prompt.

    Args:
        prompt: Text prompt used to generate the image.

    Returns:
        generated_image: The first image produced by the pipeline
        (a PIL Image, per the original docstring).
    """
    # The cached loader returns the shared diffusion pipeline; only the
    # first generated image is used.
    return load_image_generator()(prompt).images[0]
|
74 |
+
|
75 |
+
# ----------------------------
|
76 |
+
# 4. 文字转语音 (TTS) 函数
|
77 |
# ----------------------------
|
78 |
def text_to_speech(text, output_file="output.mp3", lang="en"):
    """Convert text to speech with gTTS and save it as an mp3 file.

    Args:
        text: Text to convert.
        output_file: File name the audio is saved under.
        lang: Language code passed to gTTS. Defaults to "en"; for Chinese
            narration use e.g. "zh-cn", in which case the text itself must
            be Chinese as well.

    Returns:
        output_file: Path of the generated audio file.
    """
    tts = gTTS(text=text, lang=lang)
    tts.save(output_file)
    return output_file
|
90 |
|
91 |
# ----------------------------
|
92 |
+
# 5. 主函数:构建 Streamlit 交互式应用
|
93 |
# ----------------------------
|
94 |
def main():
|
95 |
+
st.title("互动式故事生成与配图应用")
|
96 |
+
st.write("上传一张图片,我们会基于该图片生成描述,并自动生成一个儿童故事。你可以选择继续扩展改故事,也可以结束互动。每个生成的故事段落都会搭配 AI 配图。")
|
97 |
|
98 |
+
# 图片上传
|
99 |
uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
|
100 |
|
101 |
if uploaded_file is not None:
|
|
|
106 |
# 生成图片描述
|
107 |
with st.spinner("正在生成图片描述..."):
|
108 |
caption = generate_caption(uploaded_file)
|
109 |
+
st.write("图片描述:", caption)
|
110 |
|
111 |
+
# 使用 session_state 保存生成的故事和插图历史
|
112 |
+
if "story" not in st.session_state:
|
113 |
+
# 生成初始故事段落(至少100个单词)
|
114 |
+
with st.spinner("正在生成初始故事..."):
|
115 |
+
initial_prompt = f"Based on the image caption: '{caption}', generate a complete fairy tale story for children with at least 100 words."
|
116 |
+
story_segment = generate_story(initial_prompt)
|
117 |
+
st.session_state.story = story_segment
|
118 |
+
# 生成初始配图,使用初始故事的前200个字符作为提示
|
119 |
+
with st.spinner("正在生成初始配图..."):
|
120 |
+
illustration = generate_illustration(st.session_state.story[:200])
|
121 |
+
st.session_state.illustrations = [illustration]
|
122 |
+
|
123 |
+
st.write("### 生成的故事:")
|
124 |
+
st.write(st.session_state.story)
|
125 |
+
|
126 |
+
st.write("### 故事配图:")
|
127 |
+
for idx, illus in enumerate(st.session_state.illustrations):
|
128 |
+
st.image(illus, caption=f"配图段落 {idx+1}", use_column_width=True)
|
129 |
+
|
130 |
+
st.write("---")
|
131 |
+
st.write("是否继续生成故事?如果不再扩展,请点击“结束互动”。")
|
132 |
+
|
133 |
+
# 接收用户输入的额外情节提示(可选)
|
134 |
+
user_input = st.text_input("请输入你希望添加的故事情节(可选):", value="")
|
135 |
+
|
136 |
+
col1, col2 = st.columns(2)
|
137 |
+
if col1.button("继续生成故事"):
|
138 |
+
# 使用现有故事作为上下文,并附加用户输入的提示语生成新段落
|
139 |
+
additional_prompt = st.session_state.story + " " + (user_input if user_input.strip() != "" else "")
|
140 |
+
with st.spinner("正在生成新的故事段落..."):
|
141 |
+
new_segment = generate_story(additional_prompt)
|
142 |
+
st.session_state.story += " " + new_segment
|
143 |
+
# 为新段落生成配图,取新段落前200个字符作为提示
|
144 |
+
with st.spinner("正在生成新的配图..."):
|
145 |
+
new_illustration = generate_illustration(new_segment[:200])
|
146 |
+
st.session_state.illustrations.append(new_illustration)
|
147 |
+
st.experimental_rerun()
|
148 |
+
|
149 |
+
if col2.button("结束互动"):
|
150 |
+
with st.spinner("正在生成故事音频..."):
|
151 |
+
audio_file = text_to_speech(st.session_state.story)
|
152 |
+
st.write("故事生成完毕!请点击下方按钮播放故事音频。")
|
153 |
+
st.audio(audio_file, format="audio/mp3")
|
154 |
|
|
|
|
|
|
|
|
|
|
|
155 |
if __name__ == "__main__":
|
156 |
main()
|