xieqilenb committed on
Commit
342ef7b
·
verified ·
1 Parent(s): 7577927

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -3,6 +3,8 @@ from PIL import Image
3
  from transformers import pipeline
4
  from gtts import gTTS
5
 
 
 
6
 
7
  def generate_caption(image_file):
8
  image = Image.open(image_file)
@@ -11,23 +13,22 @@ def generate_caption(image_file):
11
  caption = caption_results[0]['generated_text']
12
  return caption
13
 
14
-
15
  def generate_story(prompt):
16
  story_generator = pipeline("text-generation", model="gpt2")
17
  result = story_generator(prompt, max_length=300, num_return_sequences=1)
18
  story = result[0]['generated_text']
19
 
 
20
  if len(story.split()) < 100:
21
  additional = story_generator(prompt, max_length=350, num_return_sequences=1)[0]['generated_text']
22
  story += " " + additional
23
  return story
24
 
25
  # ----------------------------
26
- # generate_illustration
27
  # ----------------------------
28
  @st.cache_resource
29
  # def load_image_generator():
30
-
31
  # device = "cuda" if torch.cuda.is_available() else "cpu"
32
  # torch_dtype = torch.float16 if device == "cuda" else torch.float32
33
  # pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
@@ -40,38 +41,40 @@ def generate_story(prompt):
40
  # generated_image = image_result.images[0]
41
  # return generated_image
42
 
43
-
44
  def text_to_speech(text, output_file="output.mp3"):
45
  tts = gTTS(text=text, lang="en")
46
  tts.save(output_file)
47
  return output_file
48
 
49
  def main():
50
- st.title("Storytelling App")
51
- st.write("Upload a image and we will generate an interesting story based on the picture and convert it into a voice playback!")
 
52
 
 
53
  uploaded_file = st.file_uploader("Select Image", type=["png", "jpg", "jpeg"])
54
 
55
  if uploaded_file is not None:
56
  image = Image.open(uploaded_file)
57
  st.image(image, caption="Uploaded image", use_column_width=True)
58
 
59
- with st.spinner("Image caption being generated..."):
60
  caption = generate_caption(uploaded_file)
61
- st.write("Image Caption", caption)
62
 
63
  with st.spinner("Generating story..."):
64
- story = generate_story(caption)
65
- st.write("Story:")
 
66
  st.write(story)
67
 
 
68
  # with st.spinner("Generating illustration..."):
69
  # illustration = generate_illustration(story[:200])
70
-
71
- # st.write("### Story Illustrations:")
72
  # st.image(illustration, caption="Story Illustrations", use_column_width=True)
73
 
74
- with st.spinner("Converting to voice...."):
75
  audio_file = text_to_speech(story)
76
  st.audio(audio_file, format="audio/mp3")
77
 
 
3
  from transformers import pipeline
4
  from gtts import gTTS
5
 
6
+ # Set the basic page configuration: title, tab label, and icon
7
+ st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
8
 
9
  def generate_caption(image_file):
10
  image = Image.open(image_file)
 
13
  caption = caption_results[0]['generated_text']
14
  return caption
15
 
 
16
  def generate_story(prompt):
17
  story_generator = pipeline("text-generation", model="gpt2")
18
  result = story_generator(prompt, max_length=300, num_return_sequences=1)
19
  story = result[0]['generated_text']
20
 
21
+ # 如果生成的故事长度较短,则额外生成一部分内容
22
  if len(story.split()) < 100:
23
  additional = story_generator(prompt, max_length=350, num_return_sequences=1)[0]['generated_text']
24
  story += " " + additional
25
  return story
26
 
27
  # ----------------------------
28
+ # generate_illustration (temporarily commented out; uncomment to enable)
29
  # ----------------------------
30
  @st.cache_resource
31
  # def load_image_generator():
 
32
  # device = "cuda" if torch.cuda.is_available() else "cpu"
33
  # torch_dtype = torch.float16 if device == "cuda" else torch.float32
34
  # pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
 
41
  # generated_image = image_result.images[0]
42
  # return generated_image
43
 
 
44
  def text_to_speech(text, output_file="output.mp3"):
45
  tts = gTTS(text=text, lang="en")
46
  tts.save(output_file)
47
  return output_file
48
 
49
  def main():
50
+ # Display the title centered, with an icon
51
+ st.markdown("<h1 style='text-align: center;'>Your Image to Audio Story 🦜</h1>", unsafe_allow_html=True)
52
+ st.write("Upload an image below and we will generate an engaging story from the picture, then convert the story into an audio playback!")
53
 
54
+ # Image upload
55
  uploaded_file = st.file_uploader("Select Image", type=["png", "jpg", "jpeg"])
56
 
57
  if uploaded_file is not None:
58
  image = Image.open(uploaded_file)
59
  st.image(image, caption="Uploaded image", use_column_width=True)
60
 
61
+ with st.spinner("Generating image caption..."):
62
  caption = generate_caption(uploaded_file)
63
+ st.write("**Image Caption:**", caption)
64
 
65
  with st.spinner("Generating story..."):
66
+ story_prompt = f"Please generate a children's story based on this description: {caption}"
67
+ story = generate_story(story_prompt)
68
+ st.write("**Story:**")
69
  st.write(story)
70
 
71
+ # The code below generates a story illustration (requires enabling the relevant model support)
72
  # with st.spinner("Generating illustration..."):
73
  # illustration = generate_illustration(story[:200])
74
+ # st.write("### Story Illustrations:")
 
75
  # st.image(illustration, caption="Story Illustrations", use_column_width=True)
76
 
77
+ with st.spinner("Converting text to voice..."):
78
  audio_file = text_to_speech(story)
79
  st.audio(audio_file, format="audio/mp3")
80