xieqilenb committed on
Commit
b55983c
·
verified ·
1 Parent(s): 342ef7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -3,7 +3,6 @@ from PIL import Image
3
  from transformers import pipeline
4
  from gtts import gTTS
5
 
6
- # 设置页面基本配置:标题、标签和图标
7
  st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
8
 
9
  def generate_caption(image_file):
@@ -13,22 +12,23 @@ def generate_caption(image_file):
13
  caption = caption_results[0]['generated_text']
14
  return caption
15
 
 
16
  def generate_story(prompt):
17
  story_generator = pipeline("text-generation", model="gpt2")
18
  result = story_generator(prompt, max_length=300, num_return_sequences=1)
19
  story = result[0]['generated_text']
20
 
21
- # 如果生成的故事长度较短,则额外生成一部分内容
22
  if len(story.split()) < 100:
23
  additional = story_generator(prompt, max_length=350, num_return_sequences=1)[0]['generated_text']
24
  story += " " + additional
25
  return story
26
 
27
  # ----------------------------
28
- # generate_illustration (暂时注释掉,如果需要启用请解除注释)
29
  # ----------------------------
30
  @st.cache_resource
31
  # def load_image_generator():
 
32
  # device = "cuda" if torch.cuda.is_available() else "cpu"
33
  # torch_dtype = torch.float16 if device == "cuda" else torch.float32
34
  # pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
@@ -41,24 +41,23 @@ def generate_story(prompt):
41
  # generated_image = image_result.images[0]
42
  # return generated_image
43
 
 
44
  def text_to_speech(text, output_file="output.mp3"):
45
  tts = gTTS(text=text, lang="en")
46
  tts.save(output_file)
47
  return output_file
48
 
49
  def main():
50
- # 居中显示标题,添加图标
51
  st.markdown("<h1 style='text-align: center;'>Your Image to Audio Story 🦜</h1>", unsafe_allow_html=True)
52
  st.write("Upload an image below and we will generate an engaging story from the picture, then convert the story into an audio playback!")
53
 
54
- # 图片上传
55
  uploaded_file = st.file_uploader("Select Image", type=["png", "jpg", "jpeg"])
56
 
57
  if uploaded_file is not None:
58
  image = Image.open(uploaded_file)
59
  st.image(image, caption="Uploaded image", use_column_width=True)
60
 
61
- with st.spinner("Generating image caption..."):
62
  caption = generate_caption(uploaded_file)
63
  st.write("**Image Caption:**", caption)
64
 
@@ -68,13 +67,13 @@ def main():
68
  st.write("**Story:**")
69
  st.write(story)
70
 
71
- # 以下代码为生成故事插图(需要启用相关模型支持)
72
  # with st.spinner("Generating illustration..."):
73
  # illustration = generate_illustration(story[:200])
74
- # st.write("### Story Illustrations:")
 
75
  # st.image(illustration, caption="Story Illustrations", use_column_width=True)
76
 
77
- with st.spinner("Converting text to voice..."):
78
  audio_file = text_to_speech(story)
79
  st.audio(audio_file, format="audio/mp3")
80
 
 
3
  from transformers import pipeline
4
  from gtts import gTTS
5
 
 
6
  st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
7
 
8
  def generate_caption(image_file):
 
12
  caption = caption_results[0]['generated_text']
13
  return caption
14
 
15
+
16
  def generate_story(prompt):
17
  story_generator = pipeline("text-generation", model="gpt2")
18
  result = story_generator(prompt, max_length=300, num_return_sequences=1)
19
  story = result[0]['generated_text']
20
 
 
21
  if len(story.split()) < 100:
22
  additional = story_generator(prompt, max_length=350, num_return_sequences=1)[0]['generated_text']
23
  story += " " + additional
24
  return story
25
 
26
  # ----------------------------
27
+ # generate_illustration
28
  # ----------------------------
29
  @st.cache_resource
30
  # def load_image_generator():
31
+
32
  # device = "cuda" if torch.cuda.is_available() else "cpu"
33
  # torch_dtype = torch.float16 if device == "cuda" else torch.float32
34
  # pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
 
41
  # generated_image = image_result.images[0]
42
  # return generated_image
43
 
44
+
45
  def text_to_speech(text, output_file="output.mp3"):
46
  tts = gTTS(text=text, lang="en")
47
  tts.save(output_file)
48
  return output_file
49
 
50
  def main():
 
51
  st.markdown("<h1 style='text-align: center;'>Your Image to Audio Story 🦜</h1>", unsafe_allow_html=True)
52
  st.write("Upload an image below and we will generate an engaging story from the picture, then convert the story into an audio playback!")
53
 
 
54
  uploaded_file = st.file_uploader("Select Image", type=["png", "jpg", "jpeg"])
55
 
56
  if uploaded_file is not None:
57
  image = Image.open(uploaded_file)
58
  st.image(image, caption="Uploaded image", use_column_width=True)
59
 
60
+ with st.spinner("Image caption being generated..."):
61
  caption = generate_caption(uploaded_file)
62
  st.write("**Image Caption:**", caption)
63
 
 
67
  st.write("**Story:**")
68
  st.write(story)
69
 
 
70
  # with st.spinner("Generating illustration..."):
71
  # illustration = generate_illustration(story[:200])
72
+
73
+ # st.write("### Story Illustrations:")
74
  # st.image(illustration, caption="Story Illustrations", use_column_width=True)
75
 
76
+ with st.spinner("Converting to voice...."):
77
  audio_file = text_to_speech(story)
78
  st.audio(audio_file, format="audio/mp3")
79