[feat] add extend

Files changed:
- pipeline_ace_step.py (+100, -16)
- ui/components.py (+120, -5)
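The new "extend" task is routed through the existing repaint machinery: the requested left/right extension in seconds becomes a repaint window that reaches outside [0, audio_duration], and the frames outside the source audio are the region to generate. Below is a minimal sketch of that mapping, using the frame-rate constants that appear in the diff (44100 Hz sample rate divided by 512 and by 8); the helper names are illustrative, not part of the repo:

def seconds_to_latent_frames(seconds: float) -> int:
    # 44100 / 512 / 8 latent frames per second, matching pipeline_ace_step.py
    return int(seconds * 44100 / 512 / 8)

def extend_as_repaint(duration_s: float, left_extend_s: float, right_extend_s: float):
    # An extend request is expressed as a repaint window outside the source audio.
    repaint_start_frame = seconds_to_latent_frames(-left_extend_s)
    repaint_end_frame = seconds_to_latent_frames(duration_s + right_extend_s)
    frame_length = seconds_to_latent_frames(duration_s)
    is_extend = repaint_start_frame < 0 or repaint_end_frame > frame_length
    return repaint_start_frame, repaint_end_frame, is_extend

# A 120 s song extended by 30 s on the right:
print(extend_as_repaint(120.0, 0.0, 30.0))  # (0, 1614, True)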
    	
pipeline_ace_step.py  (CHANGED)

@@ -595,23 +595,83 @@ class ACEStepPipeline:
         target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)
 
         is_repaint = False
+        is_extend = False
         if add_retake_noise:
+            n_min = int(infer_steps * (1 - retake_variance))
             retake_variance = torch.tensor(retake_variance * math.pi/2).to(device).to(dtype)
             retake_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=retake_random_generators, device=device, dtype=dtype)
             repaint_start_frame = int(repaint_start * 44100 / 512 / 8)
             repaint_end_frame = int(repaint_end * 44100 / 512 / 8)
-
+            x0 = src_latents
             # retake
-            is_repaint = repaint_end_frame - repaint_start_frame != frame_length
+            is_repaint = (repaint_end_frame - repaint_start_frame != frame_length)
+
+            is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length)
+            if is_extend:
+                is_repaint = True
+
+            # TODO: train a mask-aware repainting controlnet
             # to make sure mean = 0, std = 1
             if not is_repaint:
                 target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
-            else:
+            elif not is_extend:
                 repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=device, dtype=dtype)
                 repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
                 repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
                 repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents)
                 z0 = repaint_noise
+            elif is_extend:
+                to_right_pad_gt_latents = None
+                to_left_pad_gt_latents = None
+                gt_latents = src_latents
+                src_latents_length = gt_latents.shape[-1]
+                max_infer_frame_length = int(240 * 44100 / 512 / 8)
+                left_pad_frame_length = 0
+                right_pad_frame_length = 0
+                right_trim_length = 0
+                left_trim_length = 0
+                if repaint_start_frame < 0:
+                    left_pad_frame_length = abs(repaint_start_frame)
+                    frame_length = left_pad_frame_length + gt_latents.shape[-1]
+                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0)
+                    if frame_length > max_infer_frame_length:
+                        right_trim_length = frame_length - max_infer_frame_length
+                        extend_gt_latents = extend_gt_latents[:, :, :, :max_infer_frame_length]
+                        to_right_pad_gt_latents = extend_gt_latents[:, :, :, -right_trim_length:]
+                        frame_length = max_infer_frame_length
+                    repaint_start_frame = 0
+                    gt_latents = extend_gt_latents
+
+                if repaint_end_frame > src_latents_length:
+                    right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1]
+                    frame_length = gt_latents.shape[-1] + right_pad_frame_length
+                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0)
+                    if frame_length > max_infer_frame_length:
+                        left_trim_length = frame_length - max_infer_frame_length
+                        extend_gt_latents = extend_gt_latents[:, :, :, -max_infer_frame_length:]
+                        to_left_pad_gt_latents = extend_gt_latents[:, :, :, :left_trim_length]
+                        frame_length = max_infer_frame_length
+                    repaint_end_frame = frame_length
+                    gt_latents = extend_gt_latents
+
+                repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=device, dtype=dtype)
+                if left_pad_frame_length > 0:
+                    repaint_mask[:, :, :, :left_pad_frame_length] = 1.0
+                if right_pad_frame_length > 0:
+                    repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0
+                x0 = gt_latents
+                pad_list = []
+                if left_pad_frame_length > 0:
+                    pad_list.append(retake_latents[:, :, :, :left_pad_frame_length])
+                pad_list.append(target_latents[:, :, :, left_trim_length:target_latents.shape[-1] - right_trim_length])
+                if right_pad_frame_length > 0:
+                    pad_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
+                target_latents = torch.cat(pad_list, dim=-1)
+                assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}"
+
+            zt_edit = x0.clone()
+            z0 = target_latents
 
         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
 
@@ -716,6 +776,16 @@ class ACEStepPipeline:
             return sample
 
         for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
+
+            if is_repaint:
+                if i < n_min:
+                    continue
+                elif i == n_min:
+                    t_i = t / 1000
+                    zt_src = (1 - t_i) * x0 + (t_i) * z0
+                    target_latents = zt_edit + zt_src - x0
+                    logger.info(f"repaint starts from step {n_min}, adding noise level {t_i}")
+
             # expand the latents if we are doing classifier free guidance
             latents = target_latents
 
@@ -818,14 +888,27 @@ class ACEStepPipeline:
                     timestep=timestep,
                 ).sample
 
-            target_latents = scheduler.step(
-                model_output=noise_pred,
-                timestep=t,
-                sample=target_latents,
-                return_dict=False,
-                omega=omega_scale)[0]
+            if is_repaint and i >= n_min:
+                t_i = t / 1000
+                if i + 1 < len(timesteps):
+                    t_im1 = (timesteps[i + 1]) / 1000
+                else:
+                    t_im1 = torch.zeros_like(t_i).to(t_i.device)
+                dtype = noise_pred.dtype
+                target_latents = target_latents.to(torch.float32)
+                prev_sample = target_latents + (t_im1 - t_i) * noise_pred
+                prev_sample = prev_sample.to(dtype)
+                target_latents = prev_sample
+                zt_src = (1 - t_im1) * x0 + (t_im1) * z0
+                target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src)
+            else:
+                target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]
 
-
+        if is_extend:
+            if to_right_pad_gt_latents is not None:
+                target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1)
+            if to_left_pad_gt_latents is not None:
+                target_latents = torch.cat([to_left_pad_gt_latents, target_latents], dim=-1)
         return target_latents
 
     def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):
@@ -899,6 +982,7 @@ class ACEStepPipeline:
         save_path: str = None,
         format: str = "flac",
         batch_size: int = 1,
+        debug: bool = False,
     ):
 
         start_time = time.time()
@@ -936,7 +1020,7 @@ class ACEStepPipeline:
         lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
         lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
         if len(lyrics) > 0:
-            lyric_token_idx = self.tokenize_lyrics(lyrics, debug=…
+            lyric_token_idx = self.tokenize_lyrics(lyrics, debug=debug)
             lyric_mask = [1] * len(lyric_token_idx)
             lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
             lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)
@@ -949,7 +1033,7 @@ class ACEStepPipeline:
         preprocess_time_cost = end_time - start_time
         start_time = end_time
 
-        add_retake_noise = task in ("retake", "repaint")
+        add_retake_noise = task in ("retake", "repaint", "extend")
         # retake equal to repaint
         if task == "retake":
             repaint_start = 0
@@ -957,7 +1041,7 @@ class ACEStepPipeline:
 
         src_latents = None
         if src_audio_path is not None:
-            assert src_audio_path is not None and task in ("repaint", "edit"), "src_audio_path is required for repaint task"
+            assert src_audio_path is not None and task in ("repaint", "edit", "extend"), "src_audio_path is required for repaint/edit/extend tasks"
             assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist"
             src_latents = self.infer_latents(src_audio_path)
 
@@ -989,7 +1073,7 @@ class ACEStepPipeline:
             target_lyric_token_ids=target_lyric_token_idx,
             target_lyric_mask=target_lyric_mask,
             src_latents=src_latents,
-            random_generators=random_generators,
+            random_generators=retake_random_generators,  # more diversity
             infer_steps=infer_step,
             guidance_scale=guidance_scale,
             n_min=edit_n_min,
@@ -1048,8 +1132,8 @@ class ACEStepPipeline:
 
         input_params_json = {
            "task": task,
-            "prompt": prompt,
-            "lyrics": lyrics,
+            "prompt": prompt if task != "edit" else edit_target_prompt,
+            "lyrics": lyrics if task != "edit" else edit_target_lyrics,
             "audio_duration": audio_duration,
             "infer_step": infer_step,
             "guidance_scale": guidance_scale,
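For repaint and extend, the denoising loop above skips the first n_min steps, re-noises the ground truth at step n_min, and then replaces the scheduler step with a plain Euler step on the flow-matching ODE, re-anchoring everything outside the repaint mask to the noised source trajectory zt = (1 - t) * x0 + t * z0. A self-contained sketch of that per-step update (the function wrapper is mine; the commit inlines this logic in the loop):

import torch

def masked_repaint_step(target_latents, noise_pred, t_i, t_im1, x0, z0, repaint_mask):
    # Euler step on the flow-matching ODE, computed in float32 for stability.
    prev = target_latents.to(torch.float32) + (t_im1 - t_i) * noise_pred.to(torch.float32)
    prev = prev.to(noise_pred.dtype)
    # Outside the mask, snap back to the noised ground truth at t_{i-1}.
    zt_src = (1 - t_im1) * x0 + t_im1 * z0
    return torch.where(repaint_mask == 1.0, prev, zt_src)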
    	
ui/components.py  (CHANGED)

@@ -65,7 +65,7 @@ def create_text2music_ui(
         with gr.Column():
             with gr.Row(equal_height=True):
                 # add markdown, tags and lyrics examples are from ai music generation community
-                audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=…
+                audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=-1, label="Audio Duration", interactive=True, info="-1 means random duration (30 ~ 240).", scale=9)
                 sample_bnt = gr.Button("Sample", variant="primary", scale=1)
 
             prompt = gr.Textbox(lines=2, label="Tags", max_lines=4, placeholder=TAG_PLACEHOLDER, info="Supports tags, descriptions, and scenes. Use commas to separate different tags.\nTags and lyrics examples are from the AI music generation community.")
@@ -252,14 +252,15 @@ def create_text2music_ui(
             with gr.Tab("edit"):
                 edit_prompt = gr.Textbox(lines=2, label="Edit Tags", max_lines=4)
                 edit_lyrics = gr.Textbox(lines=9, label="Edit Lyrics", max_lines=13)
-
+                retake_seeds = gr.Textbox(label="edit seeds (default None)", placeholder="", value=None)
+
                 edit_type = gr.Radio(["only_lyrics", "remix"], value="only_lyrics", label="Edit Type", elem_id="edit_type", info="`only_lyrics` keeps the whole song the same except for the lyrics difference. Keep the difference small, e.g. a one-line lyric change.\n`remix` can change the song's melody and genre.")
-                edit_n_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.…
+                edit_n_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.6, label="edit_n_min", interactive=True)
                 edit_n_max = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="edit_n_max", interactive=True)
 
                 def edit_type_change_func(edit_type):
                     if edit_type == "only_lyrics":
-                        n_min = 0.…
+                        n_min = 0.6
                         n_max = 1.0
                     elif edit_type == "remix":
                         n_min = 0.2
@@ -309,6 +310,7 @@ def create_text2music_ui(
                     oss_steps,
                     guidance_scale_text,
                     guidance_scale_lyric,
+                    retake_seeds,
                 ):
                     if edit_source == "upload":
                         src_audio_path = edit_source_audio_upload
@@ -349,7 +351,8 @@ def create_text2music_ui(
                         edit_target_prompt=edit_prompt,
                         edit_target_lyrics=edit_lyrics,
                         edit_n_min=edit_n_min,
-                        edit_n_max=edit_n_max
+                        edit_n_max=edit_n_max,
+                        retake_seeds=retake_seeds,
                     )
 
                 edit_bnt.click(
@@ -380,9 +383,121 @@ def create_text2music_ui(
                         oss_steps,
                         guidance_scale_text,
                         guidance_scale_lyric,
+                        retake_seeds,
                     ],
                     outputs=edit_outputs + [edit_input_params_json],
                 )
+            with gr.Tab("extend"):
+                extend_seeds = gr.Textbox(label="extend seeds (default None)", placeholder="", value=None)
+                left_extend_length = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=0.0, label="Left Extend Length", interactive=True)
+                right_extend_length = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=30.0, label="Right Extend Length", interactive=True)
+                extend_source = gr.Radio(["text2music", "last_extend", "upload"], value="text2music", label="Extend Source", elem_id="extend_source")
+
+                extend_source_audio_upload = gr.Audio(label="Upload Audio", type="filepath", visible=False, elem_id="extend_source_audio_upload")
+                extend_source.change(
+                    fn=lambda x: gr.update(visible=x == "upload", elem_id="extend_source_audio_upload"),
+                    inputs=[extend_source],
+                    outputs=[extend_source_audio_upload],
+                )
+
+                extend_bnt = gr.Button("Extend", variant="primary")
+                extend_outputs, extend_input_params_json = create_output_ui("Extend")
+
+                def extend_process_func(
+                    text2music_json_data,
+                    extend_input_params_json,
+                    extend_seeds,
+                    left_extend_length,
+                    right_extend_length,
+                    extend_source,
+                    extend_source_audio_upload,
+                    prompt,
+                    lyrics,
+                    infer_step,
+                    guidance_scale,
+                    scheduler_type,
+                    cfg_type,
+                    omega_scale,
+                    manual_seeds,
+                    guidance_interval,
+                    guidance_interval_decay,
+                    min_guidance_scale,
+                    use_erg_tag,
+                    use_erg_lyric,
+                    use_erg_diffusion,
+                    oss_steps,
+                    guidance_scale_text,
+                    guidance_scale_lyric,
+                ):
+                    if extend_source == "upload":
+                        src_audio_path = extend_source_audio_upload
+                        json_data = text2music_json_data
+                    elif extend_source == "text2music":
+                        json_data = text2music_json_data
+                        src_audio_path = json_data["audio_path"]
+                    elif extend_source == "last_extend":
+                        json_data = extend_input_params_json
+                        src_audio_path = json_data["audio_path"]
+
+                    repaint_start = -left_extend_length
+                    repaint_end = json_data["audio_duration"] + right_extend_length
+                    return text2music_process_func(
+                        json_data["audio_duration"],
+                        prompt,
+                        lyrics,
+                        infer_step,
+                        guidance_scale,
+                        scheduler_type,
+                        cfg_type,
+                        omega_scale,
+                        manual_seeds,
+                        guidance_interval,
+                        guidance_interval_decay,
+                        min_guidance_scale,
+                        use_erg_tag,
+                        use_erg_lyric,
+                        use_erg_diffusion,
+                        oss_steps,
+                        guidance_scale_text,
+                        guidance_scale_lyric,
+                        retake_seeds=extend_seeds,
+                        retake_variance=1.0,
+                        task="extend",
+                        repaint_start=repaint_start,
+                        repaint_end=repaint_end,
+                        src_audio_path=src_audio_path,
+                    )
+
+                extend_bnt.click(
+                    fn=extend_process_func,
+                    inputs=[
+                        input_params_json,
+                        extend_input_params_json,
+                        extend_seeds,
+                        left_extend_length,
+                        right_extend_length,
+                        extend_source,
+                        extend_source_audio_upload,
+                        prompt,
+                        lyrics,
+                        infer_step,
+                        guidance_scale,
+                        scheduler_type,
+                        cfg_type,
+                        omega_scale,
+                        manual_seeds,
+                        guidance_interval,
+                        guidance_interval_decay,
+                        min_guidance_scale,
+                        use_erg_tag,
+                        use_erg_lyric,
+                        use_erg_diffusion,
+                        oss_steps,
+                        guidance_scale_text,
+                        guidance_scale_lyric,
+                    ],
+                    outputs=extend_outputs + [extend_input_params_json],
+                )
 
         def sample_data():
             json_data = sample_data_func()