Commit 
							
							·
						
						bd73f2a
	
1
								Parent(s):
							
							114c79c
								
Do not assume 8 devices in JAX (#154)
Browse files- Do not assume 8 devices in JAX (e124bbdca2dab1af0cdce19d575f8043eab9341e)
Co-authored-by: Pedro Cuenca <[email protected]>
    	
        README.md
    CHANGED
    
    | @@ -154,7 +154,7 @@ prompt_ids = pipeline.prepare_inputs(prompt) | |
| 154 |  | 
| 155 | 
             
            # shard inputs and rng
         | 
| 156 | 
             
            params = replicate(params)
         | 
| 157 | 
            -
            prng_seed = jax.random.split(prng_seed,  | 
| 158 | 
             
            prompt_ids = shard(prompt_ids)
         | 
| 159 |  | 
| 160 | 
             
            images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
         | 
| @@ -187,7 +187,7 @@ prompt_ids = pipeline.prepare_inputs(prompt) | |
| 187 |  | 
| 188 | 
             
            # shard inputs and rng
         | 
| 189 | 
             
            params = replicate(params)
         | 
| 190 | 
            -
            prng_seed = jax.random.split(prng_seed,  | 
| 191 | 
             
            prompt_ids = shard(prompt_ids)
         | 
| 192 |  | 
| 193 | 
             
            images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
         | 
|  | |
| 154 |  | 
| 155 | 
             
            # shard inputs and rng
         | 
| 156 | 
             
            params = replicate(params)
         | 
| 157 | 
            +
            prng_seed = jax.random.split(prng_seed, num_samples)
         | 
| 158 | 
             
            prompt_ids = shard(prompt_ids)
         | 
| 159 |  | 
| 160 | 
             
            images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
         | 
|  | |
| 187 |  | 
| 188 | 
             
            # shard inputs and rng
         | 
| 189 | 
             
            params = replicate(params)
         | 
| 190 | 
            +
            prng_seed = jax.random.split(prng_seed, num_samples)
         | 
| 191 | 
             
            prompt_ids = shard(prompt_ids)
         | 
| 192 |  | 
| 193 | 
             
            images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
         | 

 
		