xieqilenb committed on
Commit
7577927
·
verified ·
1 Parent(s): c3409d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -84
app.py CHANGED
@@ -2,37 +2,17 @@ import streamlit as st
2
  from PIL import Image
3
  from transformers import pipeline
4
  from gtts import gTTS
5
- import torch
6
- import os
7
- from diffusers import DiffusionPipeline
8
- # ----------------------------
9
- # 1. 图像描述生成函数
10
- # ----------------------------
11
  def generate_caption(image_file):
12
- """
13
- 使用 Hugging Face pipeline 的 image-to-text 模型生成图片描述
14
- 参数:
15
- image_file: 上传的图片文件(文件对象或文件路径)
16
- 返回:
17
- caption: 生成的图片描述文本
18
- """
19
  image = Image.open(image_file)
20
  caption_generator = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
21
  caption_results = caption_generator(image)
22
- caption = caption_results[0]['generated_text'] # 取返回结果的第一个描述
23
  return caption
24
 
25
- # ----------------------------
26
- # 2. 故事生成函数
27
- # ----------------------------
28
  def generate_story(prompt):
29
- """
30
- 基于提示语生成故事段落,要求至少100个单词,如果生成的文本字数不够,则再次补充
31
- 参数:
32
- prompt: 文本生成的提示语
33
- 返回:
34
- story: 生成的故事文本片段
35
- """
36
  story_generator = pipeline("text-generation", model="gpt2")
37
  result = story_generator(prompt, max_length=300, num_return_sequences=1)
38
  story = result[0]['generated_text']
@@ -43,87 +23,55 @@ def generate_story(prompt):
43
  return story
44
 
45
  # ----------------------------
46
- # 3. 图像生成(配图)相关函数
47
  # ----------------------------
48
  @st.cache_resource
49
- def load_image_generator():
50
- """
51
- 加载稳定扩散模型,使用 Diffusers 库生成插图
52
- 使用 StableDiffusionPipeline 替代 DiffusionPipeline
53
- """
54
- device = "cuda" if torch.cuda.is_available() else "cpu"
55
-
56
- # 导入 StableDiffusionPipeline
57
- # 对于 GPU,采用 fp16 精度以加速推理;否则使用默认精度
58
- torch_dtype = torch.float16 if device == "cuda" else torch.float32
59
- pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
60
- pipe = pipe.to(device)
61
- return pipe
62
 
63
- def generate_illustration(prompt):
64
- """
65
- 基于输入的提示语生成一张配图
66
- 参数:
67
- prompt: 用于生成图像的文本提示
68
- 返回:
69
- generated_image: 生成的 PIL Image 图像
70
- """
71
- pipe = load_image_generator()
72
- image_result = pipe(prompt)
73
- generated_image = image_result.images[0]
74
- return generated_image
75
 
76
- # ----------------------------
77
- # 4. 文字转语音 (TTS) 函数
78
- # ----------------------------
79
  def text_to_speech(text, output_file="output.mp3"):
80
- """
81
- 将输入文本转换为语音,并保存为 mp3 文件
82
- 参数:
83
- text: 要转换的文本
84
- output_file: 保存的音频文件名
85
- 返回:
86
- output_file: 转换后生成的音频文件路径
87
- """
88
- tts = gTTS(text=text, lang="en") # 如需中文,lang 可设置为 "zh-cn"
89
  tts.save(output_file)
90
  return output_file
91
 
92
- # ----------------------------
93
- # 5. 主函数:构建 Streamlit 交互式应用
94
- # ----------------------------
95
  def main():
96
- st.title("儿童故事生成应用")
97
- st.write("上传一张图片,我们将根据图片生成有趣的故事,并转换成语音播放!")
98
 
99
- uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
100
 
101
  if uploaded_file is not None:
102
- # 显示上传的图片
103
  image = Image.open(uploaded_file)
104
- st.image(image, caption="上传的图片", use_column_width=True)
105
 
106
- # 生成图片描述
107
- with st.spinner("正在生成图片描述..."):
108
  caption = generate_caption(uploaded_file)
109
- st.write("图片描述:", caption)
110
 
111
- # 根据图片描述生成完整故事
112
- with st.spinner("正在生成故事..."):
113
  story = generate_story(caption)
114
- st.write("生成的故事:")
115
  st.write(story)
116
 
117
- # 生成配图
118
- # 这里使用故事内容的前200个字符作为提示生成配图,实际中可以根据需要调整策略
119
- with st.spinner("正在生成插图..."):
120
- illustration = generate_illustration(story[:200])
121
 
122
- st.write("### 故事配图:")
123
- st.image(illustration, caption="配图", use_column_width=True)
124
 
125
- # 文本转语音
126
- with st.spinner("正在转换成语音..."):
127
  audio_file = text_to_speech(story)
128
  st.audio(audio_file, format="audio/mp3")
129
 
 
2
  from PIL import Image
3
  from transformers import pipeline
4
  from gtts import gTTS
5
+
6
+
 
 
 
 
7
  def generate_caption(image_file):
 
 
 
 
 
 
 
8
  image = Image.open(image_file)
9
  caption_generator = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
10
  caption_results = caption_generator(image)
11
+ caption = caption_results[0]['generated_text']
12
  return caption
13
 
14
+
 
 
15
  def generate_story(prompt):
 
 
 
 
 
 
 
16
  story_generator = pipeline("text-generation", model="gpt2")
17
  result = story_generator(prompt, max_length=300, num_return_sequences=1)
18
  story = result[0]['generated_text']
 
23
  return story
24
 
25
  # ----------------------------
26
+ # generate_illustration
27
  # ----------------------------
28
  @st.cache_resource
29
+ # def load_image_generator():
30
+
31
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ # torch_dtype = torch.float16 if device == "cuda" else torch.float32
33
+ # pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
34
+ # pipe = pipe.to(device)
35
+ # return pipe
36
+
37
+ # def generate_illustration(prompt):
38
+ # pipe = load_image_generator()
39
+ # image_result = pipe(prompt)
40
+ # generated_image = image_result.images[0]
41
+ # return generated_image
42
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
 
 
 
44
  def text_to_speech(text, output_file="output.mp3"):
45
+ tts = gTTS(text=text, lang="en")
 
 
 
 
 
 
 
 
46
  tts.save(output_file)
47
  return output_file
48
 
 
 
 
49
  def main():
50
+ st.title("Storytelling App")
51
+ st.write("Upload a image and we will generate an interesting story based on the picture and convert it into a voice playback!")
52
 
53
+ uploaded_file = st.file_uploader("Select Image", type=["png", "jpg", "jpeg"])
54
 
55
  if uploaded_file is not None:
 
56
  image = Image.open(uploaded_file)
57
+ st.image(image, caption="Uploaded image", use_column_width=True)
58
 
59
+ with st.spinner("Image caption being generated..."):
 
60
  caption = generate_caption(uploaded_file)
61
+ st.write("Image Caption:", caption)
62
 
63
+ with st.spinner("Generating story..."):
 
64
  story = generate_story(caption)
65
+ st.write("Story:")
66
  st.write(story)
67
 
68
+ # with st.spinner("Generating illustration..."):
69
+ # illustration = generate_illustration(story[:200])
 
 
70
 
71
+ # st.write("### Story Illustrations:")
72
+ # st.image(illustration, caption="Story Illustrations", use_column_width=True)
73
 
74
+ with st.spinner("Converting to voice...."):
 
75
  audio_file = text_to_speech(story)
76
  st.audio(audio_file, format="audio/mp3")
77