Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						8a80eb5
	
1
								Parent(s):
							
							5e340e8
								
fixed typo
Browse files
    	
        app.py
    CHANGED
    
    | @@ -118,7 +118,7 @@ def process_image(image, repaint, busyness): | |
| 118 | 
             
                print("Saving masked image file to ", masked_img_file)
         | 
| 119 | 
             
                image.save(masked_img_file)
         | 
| 120 | 
             
                num = 64 # number of images to generate; we'll take the one with the most notes in the masked region
         | 
| 121 | 
            -
                bs =  | 
| 122 | 
             
                repaint = repaint
         | 
| 123 | 
             
                seed_scale = 1.0
         | 
| 124 | 
             
                CT_HOME = '.'
         | 
|  | |
| 118 | 
             
                print("Saving masked image file to ", masked_img_file)
         | 
| 119 | 
             
                image.save(masked_img_file)
         | 
| 120 | 
             
                num = 64 # number of images to generate; we'll take the one with the most notes in the masked region
         | 
| 121 | 
            +
                bs = num
         | 
| 122 | 
             
                repaint = repaint
         | 
| 123 | 
             
                seed_scale = 1.0
         | 
| 124 | 
             
                CT_HOME = '.'
         | 
    	
        sample.py
    CHANGED
    
    | @@ -5,9 +5,7 @@ | |
| 5 |  | 
| 6 | 
             
            """Samples from k-diffusion models."""
         | 
| 7 |  | 
| 8 | 
            -
             | 
| 9 | 
            -
            import spaces
         | 
| 10 | 
            -
            import natten
         | 
| 11 | 
             
            import argparse
         | 
| 12 | 
             
            from pathlib import Path
         | 
| 13 |  | 
| @@ -24,11 +22,11 @@ from pom.v_diffusion import DDPM, LogSchedule, CrashSchedule | |
| 24 | 
             
            #CHORD_BORDER = 8   # chord border size in pixels
         | 
| 25 | 
             
            from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
         | 
| 26 |  | 
|  | |
| 27 |  | 
| 28 | 
             
            # ---- my mangled sampler that includes repaint 
         | 
| 29 | 
             
            import torchsde 
         | 
| 30 |  | 
| 31 | 
            -
            #@spaces.GPU
         | 
| 32 | 
             
            class BatchedBrownianTree:
         | 
| 33 | 
             
                """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
         | 
| 34 |  | 
| @@ -56,7 +54,6 @@ class BatchedBrownianTree: | |
| 56 | 
             
                    return w if self.batched else w[0]
         | 
| 57 |  | 
| 58 |  | 
| 59 | 
            -
            #@spaces.GPU
         | 
| 60 | 
             
            class BrownianTreeNoiseSampler:
         | 
| 61 | 
             
                """A noise sampler backed by a torchsde.BrownianTree.
         | 
| 62 |  | 
| @@ -94,7 +91,6 @@ def to_d(x, sigma, denoised): | |
| 94 | 
             
                return (x - denoised) / append_dims(sigma, x.ndim)
         | 
| 95 |  | 
| 96 |  | 
| 97 | 
            -
            #@spaces.GPU
         | 
| 98 | 
             
            @torch.no_grad()
         | 
| 99 | 
             
            def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
         | 
| 100 | 
             
                """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
         | 
| @@ -129,7 +125,6 @@ def get_scalings(sigma, sigma_data=0.5): | |
| 129 | 
             
                return c_skip, c_out, c_in
         | 
| 130 |  | 
| 131 |  | 
| 132 | 
            -
            #@spaces.GPU
         | 
| 133 | 
             
            @torch.no_grad()
         | 
| 134 | 
             
            def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, 
         | 
| 135 | 
             
                                disable=None, eta=1., s_noise=1., noise_sampler=None, 
         | 
| @@ -289,14 +284,12 @@ def sample(model, x, steps, eta, **extra_args): | |
| 289 |  | 
| 290 | 
             
            # Soft mask inpainting is just shrinking hard (binary) mask inpainting
         | 
| 291 | 
             
            # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
         | 
| 292 | 
            -
            #@spaces.GPU
         | 
| 293 | 
             
            def get_bmask(i, steps, mask):
         | 
| 294 | 
             
                strength = (i+1)/(steps)
         | 
| 295 | 
             
                # convert to binary mask
         | 
| 296 | 
             
                bmask = torch.where(mask<=strength,1,0)
         | 
| 297 | 
             
                return bmask
         | 
| 298 |  | 
| 299 | 
            -
            #@spaces.GPU
         | 
| 300 | 
             
            def make_cond_model_fn(model, cond_fn):
         | 
| 301 | 
             
                def cond_model_fn(x, sigma, **kwargs):
         | 
| 302 | 
             
                    with torch.enable_grad():
         | 
| @@ -312,7 +305,6 @@ def make_cond_model_fn(model, cond_fn): | |
| 312 | 
             
            # For sampling, set both init_data and mask to None
         | 
| 313 | 
             
            # For variations, set init_data 
         | 
| 314 | 
             
            # For inpainting, set both init_data & mask 
         | 
| 315 | 
            -
            #@spaces.GPU
         | 
| 316 | 
             
            def sample_k(
         | 
| 317 | 
             
                    model_fn, 
         | 
| 318 | 
             
                    noise, 
         | 
| @@ -425,7 +417,7 @@ def infer_mask_from_init_img(img, mask_with='white'): | |
| 425 | 
             
                    mask[img[2,:,:]==1] = 1  # blue
         | 
| 426 | 
             
                return mask*1.0
         | 
| 427 |  | 
| 428 | 
            -
             | 
| 429 | 
             
            def grow_mask(init_mask, grow_by=2):
         | 
| 430 | 
             
                "adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
         | 
| 431 | 
             
                new_mask = init_mask.clone()
         | 
| @@ -434,7 +426,7 @@ def grow_mask(init_mask, grow_by=2): | |
| 434 | 
             
                    new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0 
         | 
| 435 | 
             
                return new_mask
         | 
| 436 |  | 
| 437 | 
            -
             | 
| 438 | 
             
            def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
         | 
| 439 | 
             
                "adds extra noise inside mask"
         | 
| 440 | 
             
                init_mask = grow_mask(init_mask, grow_by=grow_by)  # make the mask bigger
         | 
| @@ -448,15 +440,13 @@ def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0): | |
| 448 | 
             
                init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
         | 
| 449 | 
             
                return init_image
         | 
| 450 |  | 
| 451 | 
            -
             | 
| 452 | 
             
            def get_init_image_and_mask(args, device):
         | 
| 453 | 
             
                convert_tensor = transforms.ToTensor()
         | 
| 454 | 
             
                init_image = Image.open(args.init_image).convert('RGB')
         | 
| 455 | 
             
                init_image = convert_tensor(init_image)
         | 
| 456 | 
             
                #normalize image from 0..1 to -1..1
         | 
| 457 | 
             
                init_image = (2.0 * init_image) - 1.0
         | 
| 458 | 
            -
             | 
| 459 | 
            -
             | 
| 460 | 
             
                init_mask = torch.ones(init_image.shape[-2:])  # ones are where stuff will change, zeros will stay the same
         | 
| 461 |  | 
| 462 | 
             
                inpaint_task = 'infer'  # infer mask from init_image
         | 
| @@ -522,7 +512,115 @@ def get_init_image_and_mask(args, device): | |
| 522 | 
             
                init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
         | 
| 523 | 
             
                return init_image.to(device), init_mask.to(device)
         | 
| 524 |  | 
| 525 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 526 | 
             
            def main():
         | 
| 527 | 
             
                global init_image, init_mask
         | 
| 528 | 
             
                p = argparse.ArgumentParser(description=__doc__,
         | 
| @@ -574,12 +672,7 @@ def main(): | |
| 574 | 
             
                sigma_min = model_config['sigma_min']
         | 
| 575 | 
             
                sigma_max = model_config['sigma_max']
         | 
| 576 |  | 
| 577 | 
            -
                # SHH modified
         | 
| 578 | 
             
                torch.set_float32_matmul_precision('high')
         | 
| 579 | 
            -
                #class_cond = torch.tensor([0]).to(device)
         | 
| 580 | 
            -
                #num_classes = 10
         | 
| 581 | 
            -
                #class_cond = torch.remainder(torch.arange(0, args.n), num_classes).int().to(device)
         | 
| 582 | 
            -
                #extra_args = {'class_cond':class_cond}
         | 
| 583 | 
             
                extra_args = {}
         | 
| 584 | 
             
                init_image, init_mask = None, None
         | 
| 585 | 
             
                if args.init_image is not None:
         | 
| @@ -595,11 +688,6 @@ def main(): | |
| 595 | 
             
                        tqdm.write('Sampling...')
         | 
| 596 | 
             
                    sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
         | 
| 597 |  | 
| 598 | 
            -
                    #ddpm_sampler = DDPM(model)
         | 
| 599 | 
            -
                    #model_fn = model
         | 
| 600 | 
            -
                    #ddpm_sampler = K.external.VDenoiser(model_fn)
         | 
| 601 | 
            -
             | 
| 602 | 
            -
                    #@spaces.GPU
         | 
| 603 | 
             
                    def sample_fn(n, debug=True):
         | 
| 604 | 
             
                        x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
         | 
| 605 | 
             
                        print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
         | 
|  | |
| 5 |  | 
| 6 | 
             
            """Samples from k-diffusion models."""
         | 
| 7 |  | 
| 8 | 
            +
             | 
|  | |
|  | |
| 9 | 
             
            import argparse
         | 
| 10 | 
             
            from pathlib import Path
         | 
| 11 |  | 
|  | |
| 22 | 
             
            #CHORD_BORDER = 8   # chord border size in pixels
         | 
| 23 | 
             
            from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
         | 
| 24 |  | 
| 25 | 
            +
            import spaces
         | 
| 26 |  | 
| 27 | 
             
            # ---- my mangled sampler that includes repaint 
         | 
| 28 | 
             
            import torchsde 
         | 
| 29 |  | 
|  | |
| 30 | 
             
            class BatchedBrownianTree:
         | 
| 31 | 
             
                """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
         | 
| 32 |  | 
|  | |
| 54 | 
             
                    return w if self.batched else w[0]
         | 
| 55 |  | 
| 56 |  | 
|  | |
| 57 | 
             
            class BrownianTreeNoiseSampler:
         | 
| 58 | 
             
                """A noise sampler backed by a torchsde.BrownianTree.
         | 
| 59 |  | 
|  | |
| 91 | 
             
                return (x - denoised) / append_dims(sigma, x.ndim)
         | 
| 92 |  | 
| 93 |  | 
|  | |
| 94 | 
             
            @torch.no_grad()
         | 
| 95 | 
             
            def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
         | 
| 96 | 
             
                """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
         | 
|  | |
| 125 | 
             
                return c_skip, c_out, c_in
         | 
| 126 |  | 
| 127 |  | 
|  | |
| 128 | 
             
            @torch.no_grad()
         | 
| 129 | 
             
            def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, 
         | 
| 130 | 
             
                                disable=None, eta=1., s_noise=1., noise_sampler=None, 
         | 
|  | |
| 284 |  | 
| 285 | 
             
            # Soft mask inpainting is just shrinking hard (binary) mask inpainting
         | 
| 286 | 
             
            # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
         | 
|  | |
| 287 | 
             
            def get_bmask(i, steps, mask):
         | 
| 288 | 
             
                strength = (i+1)/(steps)
         | 
| 289 | 
             
                # convert to binary mask
         | 
| 290 | 
             
                bmask = torch.where(mask<=strength,1,0)
         | 
| 291 | 
             
                return bmask
         | 
| 292 |  | 
|  | |
| 293 | 
             
            def make_cond_model_fn(model, cond_fn):
         | 
| 294 | 
             
                def cond_model_fn(x, sigma, **kwargs):
         | 
| 295 | 
             
                    with torch.enable_grad():
         | 
|  | |
| 305 | 
             
            # For sampling, set both init_data and mask to None
         | 
| 306 | 
             
            # For variations, set init_data 
         | 
| 307 | 
             
            # For inpainting, set both init_data & mask 
         | 
|  | |
| 308 | 
             
            def sample_k(
         | 
| 309 | 
             
                    model_fn, 
         | 
| 310 | 
             
                    noise, 
         | 
|  | |
| 417 | 
             
                    mask[img[2,:,:]==1] = 1  # blue
         | 
| 418 | 
             
                return mask*1.0
         | 
| 419 |  | 
| 420 | 
            +
             | 
| 421 | 
             
            def grow_mask(init_mask, grow_by=2):
         | 
| 422 | 
             
                "adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
         | 
| 423 | 
             
                new_mask = init_mask.clone()
         | 
|  | |
| 426 | 
             
                    new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0 
         | 
| 427 | 
             
                return new_mask
         | 
| 428 |  | 
| 429 | 
            +
             | 
| 430 | 
             
            def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
         | 
| 431 | 
             
                "adds extra noise inside mask"
         | 
| 432 | 
             
                init_mask = grow_mask(init_mask, grow_by=grow_by)  # make the mask bigger
         | 
|  | |
| 440 | 
             
                init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
         | 
| 441 | 
             
                return init_image
         | 
| 442 |  | 
| 443 | 
            +
             | 
| 444 | 
             
            def get_init_image_and_mask(args, device):
         | 
| 445 | 
             
                convert_tensor = transforms.ToTensor()
         | 
| 446 | 
             
                init_image = Image.open(args.init_image).convert('RGB')
         | 
| 447 | 
             
                init_image = convert_tensor(init_image)
         | 
| 448 | 
             
                #normalize image from 0..1 to -1..1
         | 
| 449 | 
             
                init_image = (2.0 * init_image) - 1.0
         | 
|  | |
|  | |
| 450 | 
             
                init_mask = torch.ones(init_image.shape[-2:])  # ones are where stuff will change, zeros will stay the same
         | 
| 451 |  | 
| 452 | 
             
                inpaint_task = 'infer'  # infer mask from init_image
         | 
|  | |
| 512 | 
             
                init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
         | 
| 513 | 
             
                return init_image.to(device), init_mask.to(device)
         | 
| 514 |  | 
| 515 | 
            +
             | 
| 516 | 
            +
             | 
| 517 | 
            +
             | 
| 518 | 
            +
            # wrapper compatible with ZeroGPU, callable from outside
         | 
| 519 | 
            +
            @spaces.GPU
         | 
| 520 | 
            +
            def zero_wrapper(args, device): 
         | 
| 521 | 
            +
                global init_image, init_mask
         | 
| 522 | 
            +
             | 
| 523 | 
            +
                config = K.config.load_config(args.config if args.config else args.checkpoint)
         | 
| 524 | 
            +
                model_config = config['model']
         | 
| 525 | 
            +
                # TODO: allow non-square input sizes
         | 
| 526 | 
            +
                assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
         | 
| 527 | 
            +
                size = model_config['input_size']
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                print('zero_wrapper: Using device:', device, flush=True)
         | 
| 530 | 
            +
             | 
| 531 | 
            +
                inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
         | 
| 532 | 
            +
                cse = None # ChordSeqEncoder().eval().requires_grad_(False).to(device)  # add chord embedding-maker to main model
         | 
| 533 | 
            +
                if cse is not None:
         | 
| 534 | 
            +
                    inner_model.cse = cse
         | 
| 535 | 
            +
                try:
         | 
| 536 | 
            +
                    inner_model.load_state_dict(safetorch.load_file(args.checkpoint))
         | 
| 537 | 
            +
                except:
         | 
| 538 | 
            +
                    #ckpt = torch.load(args.checkpoint).to(device)
         | 
| 539 | 
            +
                    ckpt = torch.load(args.checkpoint, map_location='cpu')
         | 
| 540 | 
            +
                    inner_model.load_state_dict(ckpt['model'])
         | 
| 541 | 
            +
             | 
| 542 | 
            +
                print('Parameters:', K.utils.n_params(inner_model))
         | 
| 543 | 
            +
                model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data'])
         | 
| 544 | 
            +
             | 
| 545 | 
            +
                sigma_min = model_config['sigma_min']
         | 
| 546 | 
            +
                sigma_max = model_config['sigma_max']
         | 
| 547 | 
            +
                torch.set_float32_matmul_precision('high')
         | 
| 548 | 
            +
                extra_args = {}
         | 
| 549 | 
            +
                init_image, init_mask = None, None
         | 
| 550 | 
            +
                if args.init_image is not None:
         | 
| 551 | 
            +
                    init_image, init_mask = get_init_image_and_mask(args, device)
         | 
| 552 | 
            +
                    init_image = init_image.to(device)
         | 
| 553 | 
            +
                    init_mask = init_mask.to(device)
         | 
| 554 | 
            +
                @torch.no_grad()
         | 
| 555 | 
            +
                @K.utils.eval_mode(model)
         | 
| 556 | 
            +
                def run():
         | 
| 557 | 
            +
                    global init_image, init_mask
         | 
| 558 | 
            +
                    if accelerator.is_local_main_process:
         | 
| 559 | 
            +
                        tqdm.write('Sampling...')
         | 
| 560 | 
            +
                    sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
         | 
| 561 | 
            +
             | 
| 562 | 
            +
                    def sample_fn(n, debug=True):
         | 
| 563 | 
            +
                        x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
         | 
| 564 | 
            +
                        print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
         | 
| 565 | 
            +
             | 
| 566 | 
            +
                        if args.init_image is not None:
         | 
| 567 | 
            +
                            init_data, mask = get_init_image_and_mask(args, device)
         | 
| 568 | 
            +
                            init_data = args.seed_scale*x*mask + (1-mask)*init_data  # extra nucleation?
         | 
| 569 | 
            +
                            if cse is not None: 
         | 
| 570 | 
            +
                                chord_cond = img_batch_to_seq_emb(init_data, inner_model.cse).to(device)
         | 
| 571 | 
            +
                            else: 
         | 
| 572 | 
            +
                                chord_cond = None
         | 
| 573 | 
            +
                            #print("init_data.shape, init_data.min, init_data.max = ", init_data.shape, init_data.min(), init_data.max())
         | 
| 574 | 
            +
                        else:
         | 
| 575 | 
            +
                            init_data, mask, chord_cond = None, None, None
         | 
| 576 | 
            +
                        # chord_cond doesn't work anyway so f it: 
         | 
| 577 | 
            +
                        chord_cond = None
         | 
| 578 | 
            +
             | 
| 579 | 
            +
                        print("chord_cond = ", chord_cond)
         | 
| 580 | 
            +
                        if chord_cond is not None: 
         | 
| 581 | 
            +
                            extra_args['chord_cond'] = chord_cond
         | 
| 582 | 
            +
                        # these two work:
         | 
| 583 | 
            +
                        #x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process, extra_args=extra_args)
         | 
| 584 | 
            +
                        #x_0 = K.sampling.sample_dpmpp_2m_sde(model, x, sigmas, disable=not accelerator.is_local_main_process, extra_args=extra_args)
         | 
| 585 | 
            +
             | 
| 586 | 
            +
                        noise = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) 
         | 
| 587 | 
            +
             | 
| 588 | 
            +
                        sampler_type="my-dpmpp-2m-sde"  # "k-lms"
         | 
| 589 | 
            +
                        #sampler_type="my-sample-euler"
         | 
| 590 | 
            +
                        #sampler_type="dpmpp-2m-sde"  
         | 
| 591 | 
            +
                        #sampler_type = "dpmpp-3m-sde"
         | 
| 592 | 
            +
                        #sampler_type = "k-dpmpp-2s-ancestral"
         | 
| 593 | 
            +
                        print("dtypes:", [x.dtype if x is not None else None  for x in [noise, init_data, mask, chord_cond]])
         | 
| 594 | 
            +
                        x_0 = sample_k(inner_model, noise, sampler_type=sampler_type, 
         | 
| 595 | 
            +
                                       init_data=init_data, mask=mask, steps=args.steps, 
         | 
| 596 | 
            +
                                       sigma_min=sigma_min, sigma_max=sigma_max, rho=7., 
         | 
| 597 | 
            +
                                       device=device, model_config=model_config, repaint=args.repaint, 
         | 
| 598 | 
            +
                                       **extra_args)
         | 
| 599 | 
            +
                        #x_0 = sample_k(inner_model, noise, sampler_type="dpmpp-2m-sde", steps=100,  sigma_min=0.5, sigma_max=50, rho=1., device=device,  model_config=model_config, **extra_args)
         | 
| 600 | 
            +
                        print("x_0.min, x_0.max = ", x_0.min(), x_0.max())
         | 
| 601 | 
            +
                        if x_0.isnan().any():
         | 
| 602 | 
            +
                            assert False, "x_0 has NaNs"
         | 
| 603 | 
            +
                        
         | 
| 604 | 
            +
                        # do gpu garbage collection before proceeding
         | 
| 605 | 
            +
                        torch.cuda.empty_cache()
         | 
| 606 | 
            +
                        return x_0
         | 
| 607 | 
            +
                    
         | 
| 608 | 
            +
                    x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
         | 
| 609 | 
            +
                    if accelerator.is_main_process:
         | 
| 610 | 
            +
                        for i, out in enumerate(x_0):
         | 
| 611 | 
            +
                            filename = f'{args.prefix}_{i:05}.png'
         | 
| 612 | 
            +
                            K.utils.to_pil_image(out).save(filename)
         | 
| 613 | 
            +
             | 
| 614 | 
            +
                try:
         | 
| 615 | 
            +
                    run()
         | 
| 616 | 
            +
                except KeyboardInterrupt:
         | 
| 617 | 
            +
                    pass
         | 
| 618 | 
            +
             | 
| 619 | 
            +
             | 
| 620 | 
            +
             | 
| 621 | 
            +
             | 
| 622 | 
            +
             | 
| 623 | 
            +
             | 
| 624 | 
             
            def main():
         | 
| 625 | 
             
                global init_image, init_mask
         | 
| 626 | 
             
                p = argparse.ArgumentParser(description=__doc__,
         | 
|  | |
| 672 | 
             
                sigma_min = model_config['sigma_min']
         | 
| 673 | 
             
                sigma_max = model_config['sigma_max']
         | 
| 674 |  | 
|  | |
| 675 | 
             
                torch.set_float32_matmul_precision('high')
         | 
|  | |
|  | |
|  | |
|  | |
| 676 | 
             
                extra_args = {}
         | 
| 677 | 
             
                init_image, init_mask = None, None
         | 
| 678 | 
             
                if args.init_image is not None:
         | 
|  | |
| 688 | 
             
                        tqdm.write('Sampling...')
         | 
| 689 | 
             
                    sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
         | 
| 690 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 691 | 
             
                    def sample_fn(n, debug=True):
         | 
| 692 | 
             
                        x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
         | 
| 693 | 
             
                        print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
         | 
