Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							Β·
						
						c34205b
	
1
								Parent(s):
							
							6f70ac0
								
yes
Browse files- __pycache__/two_stream_shunt_adapter.cpython-310.pyc +0 -0
- app.py +5 -15
- requirements.txt +1 -3
    	
        __pycache__/two_stream_shunt_adapter.cpython-310.pyc
    CHANGED
    
    | Binary files a/__pycache__/two_stream_shunt_adapter.cpython-310.pyc and b/__pycache__/two_stream_shunt_adapter.cpython-310.pyc differ | 
|  | 
    	
        app.py
    CHANGED
    
    | @@ -1,16 +1,15 @@ | |
| 1 | 
            -
            import spaces
         | 
| 2 | 
            -
             | 
| 3 | 
             
            import torch
         | 
| 4 | 
             
            import gradio as gr
         | 
| 5 | 
             
            import numpy as np
         | 
| 6 | 
             
            import matplotlib.pyplot as plt
         | 
|  | |
|  | |
| 7 | 
             
            from transformers import T5Tokenizer, T5EncoderModel
         | 
| 8 | 
             
            from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
         | 
| 9 | 
             
            from safetensors.torch import load_file
         | 
| 10 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 11 | 
             
            from two_stream_shunt_adapter import TwoStreamShuntAdapter
         | 
| 12 | 
             
            from configs import T5_SHUNT_REPOS
         | 
| 13 | 
            -
            from PIL import Image
         | 
| 14 |  | 
| 15 | 
             
            # βββ Device & Model Setup βββββββββββββββββββββββββββββββββββββ
         | 
| 16 | 
             
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| @@ -69,8 +68,7 @@ def plot_heat(mat, title): | |
| 69 | 
             
                plt.savefig(buf, format="png", bbox_inches='tight')
         | 
| 70 | 
             
                buf.seek(0)
         | 
| 71 | 
             
                plt.close(fig)
         | 
| 72 | 
            -
                 | 
| 73 | 
            -
                return pil_image
         | 
| 74 |  | 
| 75 | 
             
            # βββ SDXL Text Encoding βββββββββββββββββββββββββββββββββββββββ
         | 
| 76 | 
             
            def encode_sdxl_prompt(prompt, negative_prompt=""):
         | 
| @@ -136,7 +134,6 @@ def encode_sdxl_prompt(prompt, negative_prompt=""): | |
| 136 | 
             
                }
         | 
| 137 |  | 
| 138 | 
             
            # βββ Inference ββββββββββββββββββββββββββββββββββββββββββββββββ
         | 
| 139 | 
            -
            @spaces.GPU
         | 
| 140 | 
             
            @torch.no_grad()
         | 
| 141 | 
             
            def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, 
         | 
| 142 | 
             
                      use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
         | 
| @@ -344,15 +341,8 @@ with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo: | |
| 344 | 
             
                        prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob, 
         | 
| 345 | 
             
                        use_anchor, steps, cfg_scale, scheduler_name, width, height, seed
         | 
| 346 | 
             
                    ],
         | 
| 347 | 
            -
                    outputs=[
         | 
| 348 | 
            -
                        out_img, 
         | 
| 349 | 
            -
                        delta_l, 
         | 
| 350 | 
            -
                        gate_l, 
         | 
| 351 | 
            -
                        delta_g, 
         | 
| 352 | 
            -
                        gate_g, 
         | 
| 353 | 
            -
                        stats_l, 
         | 
| 354 | 
            -
                        stats_g]
         | 
| 355 | 
             
                )
         | 
| 356 |  | 
| 357 | 
             
            if __name__ == "__main__":
         | 
| 358 | 
            -
                demo.launch( | 
|  | |
|  | |
|  | |
| 1 | 
             
            import torch
         | 
| 2 | 
             
            import gradio as gr
         | 
| 3 | 
             
            import numpy as np
         | 
| 4 | 
             
            import matplotlib.pyplot as plt
         | 
| 5 | 
            +
            from PIL import Image
         | 
| 6 | 
            +
            import spaces
         | 
| 7 | 
             
            from transformers import T5Tokenizer, T5EncoderModel
         | 
| 8 | 
             
            from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
         | 
| 9 | 
             
            from safetensors.torch import load_file
         | 
| 10 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 11 | 
             
            from two_stream_shunt_adapter import TwoStreamShuntAdapter
         | 
| 12 | 
             
            from configs import T5_SHUNT_REPOS
         | 
|  | |
| 13 |  | 
| 14 | 
             
            # βββ Device & Model Setup βββββββββββββββββββββββββββββββββββββ
         | 
| 15 | 
             
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
|  | |
| 68 | 
             
                plt.savefig(buf, format="png", bbox_inches='tight')
         | 
| 69 | 
             
                buf.seek(0)
         | 
| 70 | 
             
                plt.close(fig)
         | 
| 71 | 
            +
                return buf
         | 
|  | |
| 72 |  | 
| 73 | 
             
            # βββ SDXL Text Encoding βββββββββββββββββββββββββββββββββββββββ
         | 
| 74 | 
             
            def encode_sdxl_prompt(prompt, negative_prompt=""):
         | 
|  | |
| 134 | 
             
                }
         | 
| 135 |  | 
| 136 | 
             
            # βββ Inference ββββββββββββββββββββββββββββββββββββββββββββββββ
         | 
|  | |
| 137 | 
             
            @torch.no_grad()
         | 
| 138 | 
             
            def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, 
         | 
| 139 | 
             
                      use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
         | 
|  | |
| 341 | 
             
                        prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob, 
         | 
| 342 | 
             
                        use_anchor, steps, cfg_scale, scheduler_name, width, height, seed
         | 
| 343 | 
             
                    ],
         | 
| 344 | 
            +
                    outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 345 | 
             
                )
         | 
| 346 |  | 
| 347 | 
             
            if __name__ == "__main__":
         | 
| 348 | 
            +
                demo.launch()
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -1,4 +1,3 @@ | |
| 1 | 
            -
            spaces
         | 
| 2 | 
             
            sentencepiece
         | 
| 3 | 
             
            accelerate
         | 
| 4 | 
             
            diffusers
         | 
| @@ -6,5 +5,4 @@ invisible_watermark | |
| 6 | 
             
            torch
         | 
| 7 | 
             
            transformers
         | 
| 8 | 
             
            xformers
         | 
| 9 | 
            -
            matplotlib
         | 
| 10 | 
            -
            pillow
         | 
|  | |
|  | |
| 1 | 
             
            sentencepiece
         | 
| 2 | 
             
            accelerate
         | 
| 3 | 
             
            diffusers
         | 
|  | |
| 5 | 
             
            torch
         | 
| 6 | 
             
            transformers
         | 
| 7 | 
             
            xformers
         | 
| 8 | 
            +
            matplotlib
         | 
|  | 
