Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	no progress on batch
Browse files
    	
        app.py
    CHANGED
    
    | @@ -49,6 +49,7 @@ def interrupt(): | |
| 49 | 
             
                global INTERRUPTING
         | 
| 50 | 
             
                INTERRUPTING = True
         | 
| 51 |  | 
|  | |
| 52 | 
             
            def make_waveform(*args, **kwargs):
         | 
| 53 | 
             
                # Further remove some warnings.
         | 
| 54 | 
             
                be = time.time()
         | 
| @@ -66,7 +67,7 @@ def load_model(version='melody'): | |
| 66 | 
             
                    MODEL = MusicGen.get_pretrained(version)
         | 
| 67 |  | 
| 68 |  | 
| 69 | 
            -
            def _do_predictions(texts, melodies, duration, **gen_kwargs):
         | 
| 70 | 
             
                MODEL.set_generation_params(duration=duration, **gen_kwargs)
         | 
| 71 | 
             
                print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
         | 
| 72 | 
             
                be = time.time()
         | 
| @@ -89,10 +90,10 @@ def _do_predictions(texts, melodies, duration, **gen_kwargs): | |
| 89 | 
             
                        descriptions=texts,
         | 
| 90 | 
             
                        melody_wavs=processed_melodies,
         | 
| 91 | 
             
                        melody_sample_rate=target_sr,
         | 
| 92 | 
            -
                        progress= | 
| 93 | 
             
                    )
         | 
| 94 | 
             
                else:
         | 
| 95 | 
            -
                    outputs = MODEL.generate(texts, progress= | 
| 96 |  | 
| 97 | 
             
                outputs = outputs.detach().cpu().float()
         | 
| 98 | 
             
                out_files = []
         | 
| @@ -128,7 +129,7 @@ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coe | |
| 128 | 
             
                MODEL.set_custom_progress_callback(_progress)
         | 
| 129 |  | 
| 130 | 
             
                outs = _do_predictions(
         | 
| 131 | 
            -
                    [text], [melody], duration,
         | 
| 132 | 
             
                    top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
         | 
| 133 | 
             
                return outs[0]
         | 
| 134 |  | 
| @@ -324,6 +325,8 @@ if __name__ == "__main__": | |
| 324 | 
             
                args = parser.parse_args()
         | 
| 325 |  | 
| 326 | 
             
                launch_kwargs = {}
         | 
|  | |
|  | |
| 327 | 
             
                if args.username and args.password:
         | 
| 328 | 
             
                    launch_kwargs['auth'] = (args.username, args.password)
         | 
| 329 | 
             
                if args.server_port:
         | 
|  | |
| 49 | 
             
                global INTERRUPTING
         | 
| 50 | 
             
                INTERRUPTING = True
         | 
| 51 |  | 
| 52 | 
            +
             | 
| 53 | 
             
            def make_waveform(*args, **kwargs):
         | 
| 54 | 
             
                # Further remove some warnings.
         | 
| 55 | 
             
                be = time.time()
         | 
|  | |
| 67 | 
             
                    MODEL = MusicGen.get_pretrained(version)
         | 
| 68 |  | 
| 69 |  | 
| 70 | 
            +
            def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
         | 
| 71 | 
             
                MODEL.set_generation_params(duration=duration, **gen_kwargs)
         | 
| 72 | 
             
                print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
         | 
| 73 | 
             
                be = time.time()
         | 
|  | |
| 90 | 
             
                        descriptions=texts,
         | 
| 91 | 
             
                        melody_wavs=processed_melodies,
         | 
| 92 | 
             
                        melody_sample_rate=target_sr,
         | 
| 93 | 
            +
                        progress=progress,
         | 
| 94 | 
             
                    )
         | 
| 95 | 
             
                else:
         | 
| 96 | 
            +
                    outputs = MODEL.generate(texts, progress=progress)
         | 
| 97 |  | 
| 98 | 
             
                outputs = outputs.detach().cpu().float()
         | 
| 99 | 
             
                out_files = []
         | 
|  | |
| 129 | 
             
                MODEL.set_custom_progress_callback(_progress)
         | 
| 130 |  | 
| 131 | 
             
                outs = _do_predictions(
         | 
| 132 | 
            +
                    [text], [melody], duration, progress=True,
         | 
| 133 | 
             
                    top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
         | 
| 134 | 
             
                return outs[0]
         | 
| 135 |  | 
|  | |
| 325 | 
             
                args = parser.parse_args()
         | 
| 326 |  | 
| 327 | 
             
                launch_kwargs = {}
         | 
| 328 | 
            +
                launch_kwargs['server_name'] = args.listen
         | 
| 329 | 
            +
             | 
| 330 | 
             
                if args.username and args.password:
         | 
| 331 | 
             
                    launch_kwargs['auth'] = (args.username, args.password)
         | 
| 332 | 
             
                if args.server_port:
         | 
 
			

