xieqilenb committed on
Commit
06bce28
·
verified ·
1 Parent(s): 7f4a56e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -65
app.py CHANGED
@@ -1,74 +1,92 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
  from PIL import Image
 
4
  from gtts import gTTS
5
- from io import BytesIO
6
 
7
- def get_image_captioner():
8
- return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
9
 
10
- def get_story_generator():
11
- return pipeline("text-generation", model="Qwen/Qwen2-1.5B")
12
 
13
- def text_to_speech(text):
14
- tts = gTTS(text=text, lang="en")
15
- audio_bytes = BytesIO()
16
- tts.write_to_fp(audio_bytes)
17
- audio_bytes.seek(0)
18
- return audio_bytes
 
 
 
 
 
19
 
20
- st.set_page_config(page_title="Image to Audio Story", page_icon="🦜")
21
- st.title("Children's Storytelling App")
22
- st.write("Upload an image and let the magic create a story, then convert it to audio!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- uploaded_file = st.file_uploader("Select an image...", type=["jpg", "png", "jpeg"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- if uploaded_file is not None:
27
- try:
28
- image = Image.open(uploaded_file).convert("RGB")
29
- st.image(image, caption="Uploaded Image", use_container_width=True)
30
- if st.button("Generate Story"):
31
- with st.spinner("Generating content..."):
32
- captioner = get_image_captioner()
33
- caption_result = captioner(image)
34
- caption = caption_result[0]["generated_text"]
35
- st.subheader("Image Caption")
36
- st.write(caption)
37
- prompt = (
38
- "You are a creative children's story writer. Based on the following image details, "
39
- "please write an imaginative story for children aged 3-10. Do not simply rephrase the image details; "
40
- "instead, expand creatively by adding fun characters, adventures, and unexpected twists. "
41
- "The story must be at least 100 words long.\n\n"
42
- f"Image Details: {caption}\n\nStory:"
43
- )
44
- story_generator = get_story_generator()
45
- story_result = story_generator(
46
- prompt,
47
- max_length=300,
48
- min_length=100,
49
- num_return_sequences=1,
50
- do_sample=True,
51
- top_p=0.95,
52
- top_k=50
53
- )
54
- story = story_result[0]["generated_text"]
55
- while len(story.split()) < 100:
56
- story_result = story_generator(
57
- prompt,
58
- max_length=300,
59
- min_length=100,
60
- num_return_sequences=1,
61
- do_sample=True,
62
- top_p=0.95,
63
- top_k=50
64
- )
65
- story = story_result[0]["generated_text"]
66
- if "Story:" in story:
67
- story = story.split("Story:", 1)[-1].strip()
68
- st.subheader("Generated Story")
69
- st.write(story)
70
- audio_bytes = text_to_speech(story)
71
- st.subheader("Listen to the Story")
72
- st.audio(audio_bytes, format="audio/mp3")
73
- except Exception as e:
74
- st.error(f"An error occurred: {e}")
 
1
  import streamlit as st
 
2
  from PIL import Image
3
+ from transformers import pipeline
4
  from gtts import gTTS
5
+ import torch
6
 
7
# Page metadata; this runs before any other Streamlit call in the script.
st.set_page_config(
    page_title="Your Image to Audio Story",
    page_icon="🦜",
)

# Device index handed to the transformers pipelines:
# 0 when CUDA is available, otherwise -1 (CPU).
if torch.cuda.is_available():
    device_id = 0
else:
    device_id = -1
11
 
12
@st.cache_resource
def _load_caption_generator():
    """Build the BLIP image-captioning pipeline once and reuse it across reruns."""
    return pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device_id,
    )


def generate_caption(image_file):
    """Return a caption describing the given image.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts (the app passes the
            uploaded file object).

    Returns:
        The ``generated_text`` field of the first result produced by the
        image-to-text pipeline.
    """
    image = Image.open(image_file)
    # Streamlit re-executes the script on every interaction; going through the
    # cached loader avoids rebuilding the pipeline on each call.
    caption_results = _load_caption_generator()(image)
    return caption_results[0]["generated_text"]
23
 
24
@st.cache_resource
def _load_story_generator():
    """Build the Qwen2 text-generation pipeline once and reuse it across reruns."""
    return pipeline(
        "text-generation",
        model="Qwen/Qwen2-1.5B",
        device=device_id,
    )


def generate_story(caption):
    """Generate a children's fairy tale from an image caption.

    Args:
        caption: Text describing the image; it is embedded in the prompt.

    Returns:
        The generated story with surrounding whitespace stripped. The prompt
        itself is not part of the returned text.
    """
    prompt = (
        "Please based on following image caption: " + caption +
        ", generate a complete fairy tale story for children with at least 100 words and max 300 words"
    )
    result = _load_story_generator()(
        prompt,
        # Budget only the newly generated tokens; ``max_length`` also counts
        # the prompt, so a long caption would shorten the story.
        max_new_tokens=300,
        num_return_sequences=1,
        # Without this the pipeline returns the prompt followed by the story,
        # and the prompt would be displayed and read aloud as well.
        return_full_text=False,
    )
    return result[0]["generated_text"].strip()
38
+
39
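+ # The code below is an example of illustration generation and is commented out. To use it (on a GPU when available), uncomment it and make sure the diffusers dependencies are installed.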
40
+ # @st.cache_resource
41
+ # def load_image_generator():
42
+ # from diffusers import DiffusionPipeline
43
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
44
+ # torch_dtype = torch.float16 if device == "cuda" else torch.float32
45
+ # pipe = DiffusionPipeline.from_pretrained(
46
+ # "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype
47
+ # )
48
+ # pipe = pipe.to(device)
49
+ # return pipe
50
+ #
51
+ # def generate_illustration(prompt):
52
+ # pipe = load_image_generator()
53
+ # image_result = pipe(prompt)
54
+ # generated_image = image_result.images[0]
55
+ # return generated_image
56
+
57
def text_to_speech(text, output_file="output.mp3"):
    """Convert English text to speech with gTTS and save it to a file.

    Args:
        text: The text to be spoken.
        output_file: Destination path for the audio; defaults to
            ``"output.mp3"``.

    Returns:
        ``output_file``, unchanged, so the caller can pass it on to a player.
    """
    gTTS(text=text, lang="en").save(output_file)
    return output_file
61
 
62
def main():
    """Render the app: upload an image, caption it, write a story, play it as audio.

    Each stage runs under its own spinner. Any exception raised while opening
    the image, running the models or synthesizing speech is shown to the user
    with ``st.error`` instead of a raw traceback.
    """
    st.markdown("<h1 style='text-align: center;'>Your Image to Audio Story 🦜</h1>", unsafe_allow_html=True)
    st.write("Upload an image below and we will generate an engaging story from the picture, then convert the story into an audio playback!")

    uploaded_file = st.file_uploader("Select Image", type=["png", "jpg", "jpeg"])

    # Nothing to do until the user has picked a file.
    if uploaded_file is None:
        return

    # Top-level boundary: every step below can fail (unreadable image, model
    # loading or inference, the speech request), so report failures on the page.
    try:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded image", use_container_width=True)

        with st.spinner("Image caption being generated..."):
            caption = generate_caption(uploaded_file)
        st.write("**Image Caption:**", caption)

        with st.spinner("Generating story..."):
            story = generate_story(caption)
        st.write("**Story:**")
        st.write(story)

        # To generate an illustration as well, uncomment the code below.
        # with st.spinner("Generating illustration..."):
        #     illustration = generate_illustration(story[:200])
        # st.write("### Story Illustrations:")
        # st.image(illustration, caption="Story Illustrations", use_container_width=True)

        with st.spinner("Converting to voice..."):
            audio_file = text_to_speech(story)
        st.audio(audio_file, format="audio/mp3")
    except Exception as e:
        st.error(f"An error occurred: {e}")


if __name__ == "__main__":
    main()