asigalov61 commited on
Commit
fb0ff2b
·
verified ·
1 Parent(s): 832c4d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -223,20 +223,19 @@ def Generate_Rock_Song(input_midi,
223
 
224
  #==================================================================
225
 
226
- def generate_tokens(seq, max_num_ptcs=10):
227
 
228
  input = copy.deepcopy(seq)
229
 
230
- input = input[-input_num_memory_tokens:]
231
-
232
  pcount = 0
233
  y = 545
 
234
 
235
  gen_tokens = []
 
 
236
 
237
- while pcount < max_num_ptcs and y > 255:
238
-
239
- x = torch.tensor(input, dtype=torch.long, device=DEVICE)
240
 
241
  with ctx:
242
  out = model.generate(x,
@@ -249,12 +248,15 @@ def Generate_Rock_Song(input_midi,
249
 
250
  y = out[0].tolist()[0]
251
 
252
- if pcount < max_num_ptcs and y > 255:
253
  input.append(y)
254
  gen_tokens.append(y)
255
  if y > 544:
256
  pcount += 1
257
-
 
 
 
258
  return gen_tokens
259
 
260
  #==================================================================
 
223
 
224
  #==================================================================
225
 
226
+ def generate_tokens(seq, max_num_ptcs=4, max_tries=10):
227
 
228
  input = copy.deepcopy(seq)
229
 
 
 
230
  pcount = 0
231
  y = 545
232
+ tries = 0
233
 
234
  gen_tokens = []
235
+
236
+ while pcount < max_num_ptcs and y > 255 and tries < max_tries:
237
 
238
+ x = torch.tensor(input[-input_num_memory_tokens:], dtype=torch.long, device=DEVICE)
 
 
239
 
240
  with ctx:
241
  out = model.generate(x,
 
248
 
249
  y = out[0].tolist()[0]
250
 
251
+ if pcount < max_num_ptcs and y > 255 and y not in gen_tokens:
252
  input.append(y)
253
  gen_tokens.append(y)
254
  if y > 544:
255
  pcount += 1
256
+
257
+ else:
258
+ tries += 1
259
+
260
  return gen_tokens
261
 
262
  #==================================================================