xieqilenb commited on
Commit
d42bff7
·
verified ·
1 Parent(s): 6b5dec5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -15
app.py CHANGED
@@ -12,25 +12,20 @@ def generate_caption(image_file):
12
  caption = caption_results[0]['generated_text']
13
  return caption
14
 
15
-
16
  def generate_story(caption):
17
-
18
- story_generator = pipeline("text-generation", model="openai-community/gpt2")
19
- prompt = f"Please based on following image caption: '{caption}', generate a complete fairy tale story for children with at least 100 words. "
20
- result = story_generator(prompt, max_length=300, num_return_sequences=1)
 
 
 
 
21
  story = result[0]['generated_text']
22
-
23
- if len(story.split()) < 100:
24
- additional = story_generator(prompt, max_length=350, num_return_sequences=1)[0]['generated_text']
25
- story += " " + additional
26
  return story
27
 
28
- # ----------------------------
29
- # generate_illustration
30
- # ----------------------------
31
  @st.cache_resource
32
  # def load_image_generator():
33
-
34
  # device = "cuda" if torch.cuda.is_available() else "cpu"
35
  # torch_dtype = torch.float16 if device == "cuda" else torch.float32
36
  # pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
@@ -43,7 +38,6 @@ def generate_story(caption):
43
  # generated_image = image_result.images[0]
44
  # return generated_image
45
 
46
-
47
  def text_to_speech(text, output_file="output.mp3"):
48
  tts = gTTS(text=text, lang="en")
49
  tts.save(output_file)
@@ -70,7 +64,6 @@ def main():
70
 
71
  # with st.spinner("Generating illustration..."):
72
  # illustration = generate_illustration(story[:200])
73
-
74
  # st.write("### Story Illustrations:")
75
  # st.image(illustration, caption="Story Illustrations", use_container_width=True)
76
 
 
12
  caption = caption_results[0]['generated_text']
13
  return caption
14
 
 
15
  def generate_story(caption):
16
+ story_generator = pipeline("text-generation", model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B")
17
+ messages = [
18
+ {
19
+ "role": "user",
20
+ "content": f"Please based on following image caption: '{caption}', generate a complete fairy tale story for children with at least 100 words and max 300 words"
21
+ }
22
+ ]
23
+ result = story_generator(messages, max_length=300, num_return_sequences=1)
24
  story = result[0]['generated_text']
 
 
 
 
25
  return story
26
 
 
 
 
27
  @st.cache_resource
28
  # def load_image_generator():
 
29
  # device = "cuda" if torch.cuda.is_available() else "cpu"
30
  # torch_dtype = torch.float16 if device == "cuda" else torch.float32
31
  # pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
 
38
  # generated_image = image_result.images[0]
39
  # return generated_image
40
 
 
41
  def text_to_speech(text, output_file="output.mp3"):
42
  tts = gTTS(text=text, lang="en")
43
  tts.save(output_file)
 
64
 
65
  # with st.spinner("Generating illustration..."):
66
  # illustration = generate_illustration(story[:200])
 
67
  # st.write("### Story Illustrations:")
68
  # st.image(illustration, caption="Story Illustrations", use_container_width=True)
69