Update app.py
Browse files
app.py
CHANGED
@@ -12,25 +12,20 @@ def generate_caption(image_file):
|
|
12 |
caption = caption_results[0]['generated_text']
|
13 |
return caption
|
14 |
|
15 |
-
|
16 |
def generate_story(caption):
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
story = result[0]['generated_text']
|
22 |
-
|
23 |
-
if len(story.split()) < 100:
|
24 |
-
additional = story_generator(prompt, max_length=350, num_return_sequences=1)[0]['generated_text']
|
25 |
-
story += " " + additional
|
26 |
return story
|
27 |
|
28 |
-
# ----------------------------
|
29 |
-
# generate_illustration
|
30 |
-
# ----------------------------
|
31 |
@st.cache_resource
|
32 |
# def load_image_generator():
|
33 |
-
|
34 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
35 |
# torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
36 |
# pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
|
@@ -43,7 +38,6 @@ def generate_story(caption):
|
|
43 |
# generated_image = image_result.images[0]
|
44 |
# return generated_image
|
45 |
|
46 |
-
|
47 |
def text_to_speech(text, output_file="output.mp3"):
|
48 |
tts = gTTS(text=text, lang="en")
|
49 |
tts.save(output_file)
|
@@ -70,7 +64,6 @@ def main():
|
|
70 |
|
71 |
# with st.spinner("Generating illustration..."):
|
72 |
# illustration = generate_illustration(story[:200])
|
73 |
-
|
74 |
# st.write("### Story Illustrations:")
|
75 |
# st.image(illustration, caption="Story Illustrations", use_container_width=True)
|
76 |
|
|
|
12 |
caption = caption_results[0]['generated_text']
|
13 |
return caption
|
14 |
|
|
|
15 |
def generate_story(caption):
|
16 |
+
story_generator = pipeline("text-generation", model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B")
|
17 |
+
messages = [
|
18 |
+
{
|
19 |
+
"role": "user",
|
20 |
+
"content": f"Please based on following image caption: '{caption}', generate a complete fairy tale story for children with at least 100 words and max 300 words"
|
21 |
+
}
|
22 |
+
]
|
23 |
+
result = story_generator(messages, max_length=300, num_return_sequences=1)
|
24 |
story = result[0]['generated_text']
|
|
|
|
|
|
|
|
|
25 |
return story
|
26 |
|
|
|
|
|
|
|
27 |
@st.cache_resource
|
28 |
# def load_image_generator():
|
|
|
29 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
30 |
# torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
31 |
# pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
|
|
|
38 |
# generated_image = image_result.images[0]
|
39 |
# return generated_image
|
40 |
|
|
|
41 |
def text_to_speech(text, output_file="output.mp3"):
|
42 |
tts = gTTS(text=text, lang="en")
|
43 |
tts.save(output_file)
|
|
|
64 |
|
65 |
# with st.spinner("Generating illustration..."):
|
66 |
# illustration = generate_illustration(story[:200])
|
|
|
67 |
# st.write("### Story Illustrations:")
|
68 |
# st.image(illustration, caption="Story Illustrations", use_container_width=True)
|
69 |
|