Shankhdhar commited on
Commit
c8d0344
·
1 Parent(s): 1d7eda9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -1,21 +1,27 @@
 
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelWithLMHead
 
3
  from transformers import pipeline
4
  st.title("Rap Lyrics Generator")
5
  st.image('./parental.png')
6
  model_ckpt = "flax-community/gpt2-rap-lyric-generator"
7
  tokenizer = AutoTokenizer.from_pretrained(model_ckpt,from_flax=True)
8
- model = AutoModelWithLMHead.from_pretrained(model_ckpt,from_flax=True)
9
  text_generation = pipeline("text-generation", model=model, tokenizer=tokenizer)
10
- artist = st.text_input("Enter the artist", "Kanye West")
11
- prefix_text = st.text_input("Enter the prefix text", "Let's party tonight")
12
- if len(artist)>0 and len(prefix_text)>0:
13
- prefix_text = "[Verse 1: "+artist+" ]" + "\n" + prefix_text
14
- elif len(artist)>0:
15
- prefix_text = "[Verse 1: "+artist+" ]"
16
- generated_text= text_generation(prefix_text, max_length=500, do_sample=True)[0]
17
- output = generated_text['generated_text']
18
- list1 = output.split("\n")
19
- for l in list1:
20
- st.write(l)
21
-
 
 
 
 
 
1
+
2
  import streamlit as st
3
+ from streamlit.elements.altair import generate_chart
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from transformers import pipeline
6
  st.title("Rap Lyrics Generator")
7
  st.image('./parental.png')
8
  model_ckpt = "flax-community/gpt2-rap-lyric-generator"
9
  tokenizer = AutoTokenizer.from_pretrained(model_ckpt,from_flax=True)
10
+ model = AutoModelForCausalLM.from_pretrained(model_ckpt,from_flax=True)
11
  text_generation = pipeline("text-generation", model=model, tokenizer=tokenizer)
12
+ artist = st.text_input("Enter the artist", "Eminem")
13
+ song_name = st.text_input("Enter the desired song name", "Gas is going")
14
+ if st.button("Generate lyrics"):
15
+ st.title(f"{artist}: {song_name}")
16
+ prefix_text = f"<BOS>{song_name} [Verse 1:{artist}]"
17
+ generated_song = text_generation(prefix_text, max_length=500, do_sample=True)[0]
18
+ for count, line in enumerate(generated_song['generated_text'].split("\n")):
19
+ if count == 0:
20
+ st.write(line[line.find('['):])
21
+ continue
22
+ if"<EOS>" in line:
23
+ break
24
+ if "<BOS>" in line:
25
+ st.write(line[5:])
26
+ continue
27
+ st.write(line)