Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						ad569d5
	
1
								Parent(s):
							
							2a2c118
								
Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -2,6 +2,7 @@ import gradio as gr | |
| 2 | 
             
            import torch
         | 
| 3 | 
             
            from diffusers import StableDiffusionXLPipeline, AutoencoderKL
         | 
| 4 | 
             
            from huggingface_hub import hf_hub_download
         | 
|  | |
| 5 | 
             
            from share_btn import community_icon_html, loading_icon_html, share_js
         | 
| 6 | 
             
            from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
         | 
| 7 | 
             
            import lora
         | 
| @@ -26,12 +27,19 @@ with open("sdxl_loras.json", "r") as file: | |
| 26 | 
             
                    }
         | 
| 27 | 
             
                    for item in data
         | 
| 28 | 
             
                ]
         | 
| 29 | 
            -
            print(sdxl_loras)
         | 
| 30 | 
            -
            saved_names = [
         | 
| 31 | 
            -
                hf_hub_download(item["repo"], item["weights"]) for item in sdxl_loras
         | 
| 32 | 
            -
            ]
         | 
| 33 |  | 
| 34 | 
            -
            device = "cuda" | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 35 |  | 
| 36 | 
             
            vae = AutoencoderKL.from_pretrained(
         | 
| 37 | 
             
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
         | 
| @@ -40,14 +48,13 @@ pipe = StableDiffusionXLPipeline.from_pretrained( | |
| 40 | 
             
                "stabilityai/stable-diffusion-xl-base-1.0",
         | 
| 41 | 
             
                vae=vae,
         | 
| 42 | 
             
                torch_dtype=torch.float16,
         | 
| 43 | 
            -
            ) | 
| 44 | 
             
            original_pipe = copy.deepcopy(pipe)
         | 
| 45 | 
             
            pipe.to(device)
         | 
| 46 |  | 
| 47 | 
             
            last_lora = ""
         | 
| 48 | 
             
            last_merged = False
         | 
| 49 |  | 
| 50 | 
            -
             | 
| 51 | 
             
            def update_selection(selected_state: gr.SelectData):
         | 
| 52 | 
             
                lora_repo = sdxl_loras[selected_state.index]["repo"]
         | 
| 53 | 
             
                instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
         | 
| @@ -128,7 +135,7 @@ def merge_incompatible_lora(full_path_lora, lora_scale): | |
| 128 | 
             
                            del lora_model
         | 
| 129 | 
             
                            gc.collect()
         | 
| 130 |  | 
| 131 | 
            -
            def run_lora(prompt, negative, lora_scale, selected_state):
         | 
| 132 | 
             
                global last_lora, last_merged, pipe
         | 
| 133 |  | 
| 134 | 
             
                if negative == "":
         | 
| @@ -138,7 +145,8 @@ def run_lora(prompt, negative, lora_scale, selected_state): | |
| 138 | 
             
                    raise gr.Error("You must select a LoRA")
         | 
| 139 | 
             
                repo_name = sdxl_loras[selected_state.index]["repo"]
         | 
| 140 | 
             
                weight_name = sdxl_loras[selected_state.index]["weights"]
         | 
| 141 | 
            -
                full_path_lora =  | 
|  | |
| 142 | 
             
                cross_attention_kwargs = None
         | 
| 143 | 
             
                if last_lora != repo_name:
         | 
| 144 | 
             
                    if last_merged:
         | 
| @@ -148,17 +156,17 @@ def run_lora(prompt, negative, lora_scale, selected_state): | |
| 148 | 
             
                        pipe.to(device)
         | 
| 149 | 
             
                    else:
         | 
| 150 | 
             
                        pipe.unload_lora_weights()
         | 
|  | |
| 151 | 
             
                    is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
         | 
| 152 | 
             
                    if is_compatible:
         | 
| 153 | 
            -
                        pipe.load_lora_weights( | 
| 154 | 
            -
                         | 
| 155 | 
             
                    else:
         | 
| 156 | 
             
                        is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
         | 
| 157 | 
             
                        if(is_pivotal):
         | 
|  | |
|  | |
| 158 |  | 
| 159 | 
            -
                            pipe.load_lora_weights(full_path_lora)
         | 
| 160 | 
            -
                            cross_attention_kwargs = {"scale": lora_scale}
         | 
| 161 | 
            -
             | 
| 162 | 
             
                            #Add the textual inversion embeddings from pivotal tuning models
         | 
| 163 | 
             
                            text_embedding_name = sdxl_loras[selected_state.index]["text_embedding_weights"]
         | 
| 164 | 
             
                            text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
         | 
| @@ -177,7 +185,6 @@ def run_lora(prompt, negative, lora_scale, selected_state): | |
| 177 | 
             
                    height=768,
         | 
| 178 | 
             
                    num_inference_steps=20,
         | 
| 179 | 
             
                    guidance_scale=7.5,
         | 
| 180 | 
            -
                    cross_attention_kwargs=cross_attention_kwargs,
         | 
| 181 | 
             
                ).images[0]
         | 
| 182 | 
             
                last_lora = repo_name
         | 
| 183 | 
             
                gc.collect()
         | 
|  | |
| 2 | 
             
            import torch
         | 
| 3 | 
             
            from diffusers import StableDiffusionXLPipeline, AutoencoderKL
         | 
| 4 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 5 | 
            +
            from safetensors.torch import load_file
         | 
| 6 | 
             
            from share_btn import community_icon_html, loading_icon_html, share_js
         | 
| 7 | 
             
            from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
         | 
| 8 | 
             
            import lora
         | 
|  | |
| 27 | 
             
                    }
         | 
| 28 | 
             
                    for item in data
         | 
| 29 | 
             
                ]
         | 
|  | |
|  | |
|  | |
|  | |
| 30 |  | 
| 31 | 
            +
            device = "cuda" 
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            for item in sdxl_loras:
         | 
| 34 | 
            +
                saved_name = hf_hub_download(item["repo"], item["weights"])
         | 
| 35 | 
            +
                
         | 
| 36 | 
            +
                if not saved_name.endswith('.safetensors'):
         | 
| 37 | 
            +
                    state_dict = torch.load(saved_name)
         | 
| 38 | 
            +
                else:
         | 
| 39 | 
            +
                    state_dict = load_file(saved_name)
         | 
| 40 | 
            +
                    
         | 
| 41 | 
            +
                item["saved_name"] = saved_name
         | 
| 42 | 
            +
                item["state_dict"] = state_dict #{k: v.to(device=device, dtype=torch.float16) for k, v in state_dict.items() if torch.is_tensor(v)}
         | 
| 43 |  | 
| 44 | 
             
            vae = AutoencoderKL.from_pretrained(
         | 
| 45 | 
             
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
         | 
|  | |
| 48 | 
             
                "stabilityai/stable-diffusion-xl-base-1.0",
         | 
| 49 | 
             
                vae=vae,
         | 
| 50 | 
             
                torch_dtype=torch.float16,
         | 
| 51 | 
            +
            )
         | 
| 52 | 
             
            original_pipe = copy.deepcopy(pipe)
         | 
| 53 | 
             
            pipe.to(device)
         | 
| 54 |  | 
| 55 | 
             
            last_lora = ""
         | 
| 56 | 
             
            last_merged = False
         | 
| 57 |  | 
|  | |
| 58 | 
             
            def update_selection(selected_state: gr.SelectData):
         | 
| 59 | 
             
                lora_repo = sdxl_loras[selected_state.index]["repo"]
         | 
| 60 | 
             
                instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
         | 
|  | |
| 135 | 
             
                            del lora_model
         | 
| 136 | 
             
                            gc.collect()
         | 
| 137 |  | 
| 138 | 
            +
            def run_lora(prompt, negative, lora_scale, selected_state, progress=gr.Progress(track_tqdm=True)):
         | 
| 139 | 
             
                global last_lora, last_merged, pipe
         | 
| 140 |  | 
| 141 | 
             
                if negative == "":
         | 
|  | |
| 145 | 
             
                    raise gr.Error("You must select a LoRA")
         | 
| 146 | 
             
                repo_name = sdxl_loras[selected_state.index]["repo"]
         | 
| 147 | 
             
                weight_name = sdxl_loras[selected_state.index]["weights"]
         | 
| 148 | 
            +
                full_path_lora = sdxl_loras[selected_state.index]["saved_name"]
         | 
| 149 | 
            +
                loaded_state_dict = sdxl_loras[selected_state.index]["state_dict"]
         | 
| 150 | 
             
                cross_attention_kwargs = None
         | 
| 151 | 
             
                if last_lora != repo_name:
         | 
| 152 | 
             
                    if last_merged:
         | 
|  | |
| 156 | 
             
                        pipe.to(device)
         | 
| 157 | 
             
                    else:
         | 
| 158 | 
             
                        pipe.unload_lora_weights()
         | 
| 159 | 
            +
                        pipe.unfuse_lora()
         | 
| 160 | 
             
                    is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
         | 
| 161 | 
             
                    if is_compatible:
         | 
| 162 | 
            +
                        pipe.load_lora_weights(loaded_state_dict)
         | 
| 163 | 
            +
                        pipe.fuse_lora(lora_scale)
         | 
| 164 | 
             
                    else:
         | 
| 165 | 
             
                        is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
         | 
| 166 | 
             
                        if(is_pivotal):
         | 
| 167 | 
            +
                            pipe.load_lora_weights(loaded_state_dict)
         | 
| 168 | 
            +
                            pipe.fuse_lora(lora_scale)
         | 
| 169 |  | 
|  | |
|  | |
|  | |
| 170 | 
             
                            #Add the textual inversion embeddings from pivotal tuning models
         | 
| 171 | 
             
                            text_embedding_name = sdxl_loras[selected_state.index]["text_embedding_weights"]
         | 
| 172 | 
             
                            text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
         | 
|  | |
| 185 | 
             
                    height=768,
         | 
| 186 | 
             
                    num_inference_steps=20,
         | 
| 187 | 
             
                    guidance_scale=7.5,
         | 
|  | |
| 188 | 
             
                ).images[0]
         | 
| 189 | 
             
                last_lora = repo_name
         | 
| 190 | 
             
                gc.collect()
         | 

