File size: 3,172 Bytes
1fb6258
0887ac5
1fb6258
6b8bebd
7577927
342ef7b
7577927
1fb6258
 
 
 
7577927
0887ac5
e775448
b55983c
8769d88
 
1fb6258
8769d88
1fb6258
0887ac5
1fb6258
0887ac5
1fb6258
 
0887ac5
e775448
1fb6258
b55983c
6b8bebd
 
7577927
b55983c
7577927
 
 
 
 
 
 
 
 
 
 
6b8bebd
b55983c
0887ac5
7577927
0887ac5
 
e775448
0887ac5
342ef7b
 
0887ac5
7577927
1fb6258
0887ac5
 
f72a080
0887ac5
b55983c
1fb6258
342ef7b
0887ac5
7577927
342ef7b
 
 
a31b925
6b8bebd
7577927
 
b55983c
 
f72a080
0887ac5
b55983c
a31b925
 
 
0887ac5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import streamlit as st
from PIL import Image
from transformers import pipeline
from gtts import gTTS

# Configure the browser tab title/icon; Streamlit requires this to be the
# first st.* command executed in the script.
st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")

@st.cache_resource
def _load_caption_pipeline():
    """Load the BLIP image-captioning pipeline once and reuse it across reruns."""
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


def generate_caption(image_file):
    """Generate a one-sentence caption for an uploaded image.

    Args:
        image_file: A file-like object (e.g. a Streamlit UploadedFile)
            that PIL can open.

    Returns:
        The generated caption string.
    """
    image = Image.open(image_file)
    # BLIP expects 3-channel input; PNG uploads may be RGBA or grayscale.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Previously the pipeline (a full model load) was rebuilt on every call;
    # the cached loader above makes repeat captions fast.
    caption_results = _load_caption_pipeline()(image)
    return caption_results[0]['generated_text']


@st.cache_resource
def _load_story_pipeline():
    """Load the GPT-2 text-generation pipeline once and reuse it across reruns."""
    return pipeline("text-generation", model="gpt2")


def generate_story(caption):
    """Generate a children's story (target >= 100 words) from an image caption.

    Args:
        caption: Caption text describing the uploaded image.

    Returns:
        The generated story text with the instruction prompt stripped.
    """
    story_generator = _load_story_pipeline()
    prompt = f"Please based on following image caption: '{caption}', generate a complete fairy tale story for children with at least 100 words. "

    def _generate(max_len):
        # GPT-2's text-generation pipeline echoes the prompt at the start of
        # its output; strip it so only the story itself is returned and the
        # word count below measures story words, not instruction words.
        text = story_generator(prompt, max_length=max_len, num_return_sequences=1)[0]['generated_text']
        return text[len(prompt):].strip() if text.startswith(prompt) else text

    story = _generate(300)
    # Retry once with a larger budget if the story came out too short.
    # (The old code appended the entire second generation — including a
    # duplicate of the prompt — onto the first one.)
    if len(story.split()) < 100:
        story += " " + _generate(350)
    return story

# ----------------------------
# generate_illustration
# ----------------------------
@st.cache_resource
# def load_image_generator():

#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     torch_dtype = torch.float16 if device == "cuda" else torch.float32
#     pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
#     pipe = pipe.to(device)
#     return pipe

# def generate_illustration(prompt):
#     pipe = load_image_generator()
#     image_result = pipe(prompt)
#     generated_image = image_result.images[0]
#     return generated_image


def text_to_speech(text, output_file="output.mp3"):
    tts = gTTS(text=text, lang="en")
    tts.save(output_file)
    return output_file

def main():
    """Streamlit entry point: upload image -> caption -> story -> audio."""
    st.markdown("<h1 style='text-align: center;'>Your Image to Audio Story 🦜</h1>", unsafe_allow_html=True)
    st.write("Upload an image below and we will generate an engaging story from the picture, then convert the story into an audio playback!")

    uploaded_file = st.file_uploader("Select Image", type=["png", "jpg", "jpeg"])

    if uploaded_file is None:
        # Nothing to do until the user uploads an image.
        return

    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded image", use_container_width=True)

    with st.spinner("Image caption being generated..."):
        caption = generate_caption(uploaded_file)
    st.write("**Image Caption:**", caption)

    with st.spinner("Generating story..."):
        # BUG FIX: pass the raw caption. generate_story already wraps the
        # caption in its own instruction prompt, so building a second prompt
        # here fed the model a prompt-within-a-prompt.
        story = generate_story(caption)
    st.write("**Story:**")
    st.write(story)

    # Illustration generation is disabled along with generate_illustration.
    # with st.spinner("Generating illustration..."):
    #     illustration = generate_illustration(story[:200])
    # st.write("### Story Illustrations:")
    # st.image(illustration, caption="Story Illustrations", use_container_width=True)

    with st.spinner("Converting to voice...."):
        audio_file = text_to_speech(story)
    st.audio(audio_file, format="audio/mp3")


if __name__ == "__main__":
    main()