Spaces · Running on Zero

wuwenxu.01 committed
Commit def2fd8 · 1 Parent(s): e8e76e7
feat: filter move app code from github
Files changed:
- .gitignore  +178 -0
- README.md  +17 -1
- app.py  +104 -0
- requirements.txt  +10 -0
- uno/dataset/uno.py  +132 -0
- uno/flux/math.py  +45 -0
- uno/flux/model.py  +222 -0
- uno/flux/modules/autoencoder.py  +327 -0
- uno/flux/modules/conditioner.py  +53 -0
- uno/flux/modules/layers.py  +435 -0
- uno/flux/pipeline.py  +324 -0
- uno/flux/sampling.py  +271 -0
- uno/flux/util.py  +390 -0
- uno/utils/convert_yaml_to_args_file.py  +34 -0
 
    	
.gitignore  ADDED
@@ -0,0 +1,178 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# UV
#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#uv.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# User config files
.vscode/
output/
    	
README.md  CHANGED
@@ -1,6 +1,6 @@
 ---
 title: UNO FLUX
-emoji: 
+emoji: ⚡️
 colorFrom: indigo
 colorTo: yellow
 sdk: gradio
@@ -9,6 +9,22 @@ app_file: app.py
 pinned: false
 license: cc-by-nc-4.0
 short_description: Generate customized images using text and multiple images
+models:
+  - black-forest-labs/FLUX.1-dev
+  - bytedance-research/UNO
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+## 📄 Disclaimer
+
+We open-source this project for academic research. The vast majority of images
+used in this project are either generated or licensed. If you have any concerns,
+please contact us, and we will promptly remove any inappropriate content.
+Our code is released under the Apache 2.0 License, while our models are under
+the CC BY-NC 4.0 License. Any models related to the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
+base model must adhere to the original licensing terms.
+This research aims to advance the field of generative AI. Users are free to
+create images using this tool, provided they comply with local laws and exercise
+responsible usage. The developers are not liable for any misuse of the tool by users.
+
    	
app.py  ADDED
@@ -0,0 +1,104 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses

import gradio as gr
import torch

from uno.flux.pipeline import UNOPipeline


def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
):
    pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)

    with gr.Blocks() as demo:
        gr.Markdown(f"# UNO by UNO team")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
                with gr.Row():
                    image_prompt1 = gr.Image(label="ref img1", visible=True, interactive=True, type="pil")
                    image_prompt2 = gr.Image(label="ref img2", visible=True, interactive=True, type="pil")
                    image_prompt3 = gr.Image(label="ref img3", visible=True, interactive=True, type="pil")
                    image_prompt4 = gr.Image(label="ref img4", visible=True, interactive=True, type="pil")

                with gr.Row():
                    with gr.Column():
                        ref_long_side = gr.Slider(128, 512, 512, step=16, label="Long side of Ref Images")
                    with gr.Column():
                        gr.Markdown("📌 **The recommended ref scale** is related to the ref img number.\n")
                        gr.Markdown("   1->512 / 2->320 / 3...n->256")

                with gr.Row():
                    with gr.Column():
                        width = gr.Slider(512, 2048, 512, step=16, label="Generation Width")
                        height = gr.Slider(512, 2048, 512, step=16, label="Generation Height")
                    with gr.Column():
                        gr.Markdown("📌 The model was trained at 512x512 resolution.\n")
                        gr.Markdown(
                            "Sizes closer to 512 are more stable,"
                            " while larger sizes give better visual quality but are less stable."
                        )

                with gr.Accordion("Generation Options", open=False):
                    with gr.Row():
                        num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                        guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
                        seed = gr.Number(-1, label="Seed (-1 for random)")

                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)

            inputs = [
                prompt, width, height, guidance, num_steps,
                seed, ref_long_side, image_prompt1, image_prompt2, image_prompt3, image_prompt4
            ]
            generate_btn.click(
                fn=pipeline.gradio_generate,
                inputs=inputs,
                outputs=[output_image, download_btn],
            )

    return demo


if __name__ == "__main__":
    from typing import Literal

    from transformers import HfArgumentParser

    @dataclasses.dataclass
    class AppArgs:
        name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
        device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
        offload: bool = dataclasses.field(
            default=False,
            metadata={"help": "If True, sequentially offload the models (AE, DiT, text encoder) to CPU when not in use."}
        )
        port: int = 7860

    parser = HfArgumentParser([AppArgs])
    args_tuple = parser.parse_args_into_dataclasses()  # type: tuple[AppArgs]
    args = args_tuple[0]

    demo = create_demo(args.name, args.device, args.offload)
    demo.launch(server_port=args.port)
     | 
    	
requirements.txt  ADDED
@@ -0,0 +1,10 @@
einops==0.8.0
transformers==4.43.3
huggingface-hub
diffusers==0.30.1
sentencepiece==0.2.0
gradio==5.22.0

--extra-index-url https://download.pytorch.org/whl/cu124
torch==2.4.0
torchvision==0.19.0
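Note that the extra index URL pins torch 2.4.0 / torchvision 0.19.0 to CUDA 12.4 wheels. As a quick, purely illustrative sanity check (not part of the commit) that the CUDA build was picked up after `pip install -r requirements.txt`:

    import torch

    # Expect something like "2.4.0+cu124 12.4 True" on a GPU machine.
    print(torch.__version__, torch.version.cuda, torch.cuda.is_available())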
    	
uno/dataset/uno.py  ADDED
@@ -0,0 +1,132 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import numpy as np
import torch
import torchvision.transforms.functional as TVF
from PIL import Image  # needed for Image.open below; missing from the original commit
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor


def bucket_images(images: list[torch.Tensor], resolution: int = 512):
    bucket_override = [
        # h    w
        (256, 768),
        (320, 768),
        (320, 704),
        (384, 640),
        (448, 576),
        (512, 512),
        (576, 448),
        (640, 384),
        (704, 320),
        (768, 320),
        (768, 256)
    ]
    bucket_override = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in bucket_override]
    bucket_override = [(h // 16 * 16, w // 16 * 16) for h, w in bucket_override]

    aspect_ratios = [image.shape[-2] / image.shape[-1] for image in images]
    mean_aspect_ratio = np.mean(aspect_ratios)

    new_h, new_w = bucket_override[0]
    min_aspect_diff = np.abs(new_h / new_w - mean_aspect_ratio)
    for h, w in bucket_override:
        aspect_diff = np.abs(h / w - mean_aspect_ratio)
        if aspect_diff < min_aspect_diff:
            min_aspect_diff = aspect_diff
            new_h, new_w = h, w

    images = [TVF.resize(image, (new_h, new_w)) for image in images]
    images = torch.stack(images, dim=0)
    return images


class FluxPairedDatasetV2(Dataset):
    def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None):
        super().__init__()
        self.json_file = json_file
        self.resolution = resolution
        self.resolution_ref = resolution_ref if resolution_ref is not None else resolution
        self.image_root = os.path.dirname(json_file)

        with open(self.json_file, "rt") as f:
            self.data_dicts = json.load(f)

        self.transform = Compose([
            ToTensor(),
            Normalize([0.5], [0.5]),
        ])

    def __getitem__(self, idx):
        data_dict = self.data_dicts[idx]
        image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"]
        txt = data_dict["prompt"]
        image_tgt_path = data_dict.get("image_tgt_path", None)
        ref_imgs = [
            Image.open(os.path.join(self.image_root, path)).convert("RGB")
            for path in image_paths
        ]
        ref_imgs = [self.transform(img) for img in ref_imgs]
        img = None
        if image_tgt_path is not None:
            img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB")
            img = self.transform(img)

        return {
            "img": img,
            "txt": txt,
            "ref_imgs": ref_imgs,
        }

    def __len__(self):
        return len(self.data_dicts)

    def collate_fn(self, batch):
        img = [data["img"] for data in batch]
        txt = [data["txt"] for data in batch]
        ref_imgs = [data["ref_imgs"] for data in batch]
        assert all([len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs))])

        n_ref = len(ref_imgs[0])

        img = bucket_images(img, self.resolution)
        ref_imgs_new = []
        for i in range(n_ref):
            ref_imgs_i = [refs[i] for refs in ref_imgs]
            ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref)
            ref_imgs_new.append(ref_imgs_i)

        return {
            "txt": txt,
            "img": img,
            "ref_imgs": ref_imgs_new,
        }


if __name__ == '__main__':
    import argparse
    from pprint import pprint
    parser = argparse.ArgumentParser()
    # parser.add_argument("--json_file", type=str, required=True)
    parser.add_argument("--json_file", type=str, default="datasets/fake_train_data.json")
    args = parser.parse_args()
    dataset = FluxPairedDatasetV2(args.json_file, 512)
    dataloader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)

    for i, data_dict in enumerate(dataloader):
        pprint(i)
        pprint(data_dict)
        breakpoint()
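collate_fn relies on bucket_images to pick a single (height, width) bucket per batch: it computes the mean aspect ratio of the batch and chooses the override bucket whose h/w ratio is closest, then resizes every image to that size and stacks them. A small illustration of the selection under the default resolution=512 (the tensor shapes are made up for the example and are not part of the commit):

    import torch
    from uno.dataset.uno import bucket_images

    # Two reference images with aspect ratios 800/600 ≈ 1.33 and 1024/640 = 1.60.
    imgs = [torch.rand(3, 800, 600), torch.rand(3, 1024, 640)]

    # Mean aspect ratio ≈ 1.47; the closest bucket is (576, 448) with ratio ≈ 1.29,
    # so both images are resized to 576x448 and stacked into one batch tensor.
    batch = bucket_images(imgs, resolution=512)
    print(batch.shape)  # torch.Size([2, 3, 576, 448])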
    	
uno/flux/math.py  ADDED
@@ -0,0 +1,45 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
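rope builds one 2x2 rotation matrix per position and per channel pair (the standard rotary position embedding construction), and apply_rope rotates the query/key pairs with those matrices before scaled-dot-product attention. A minimal shape check, purely illustrative (the tensor sizes are arbitrary and not part of the commit):

    import torch
    from uno.flux.math import rope, attention

    B, H, L, D = 1, 2, 8, 64  # batch, heads, tokens, head dim (D must be even)
    q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

    pos = torch.arange(L, dtype=torch.float64)[None]  # (1, L) token positions
    pe = rope(pos, D, theta=10_000)                   # (1, L, D/2, 2, 2) rotation matrices

    out = attention(q, k, v, pe)  # RoPE applied to q and k, then SDPA
    print(out.shape)              # torch.Size([1, 8, 128]) == (B, L, H * D)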
    	
uno/flux/model.py  ADDED
@@ -0,0 +1,222 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from .modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """
    _supports_gradient_checkpointing = True

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    def attn_processors(self):
        # set recursively
        processors = {}  # type: dict[str, nn.Module]

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
        ref_img: Tensor | None = None,
        ref_img_ids: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)

        # concat ref_img/img
        img_end = img.shape[1]
        if ref_img is not None:
            if isinstance(ref_img, tuple) or isinstance(ref_img, list):
                img_in = [img] + [self.img_in(ref) for ref in ref_img]
                img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids]
                img = torch.cat(img_in, dim=1)
                ids = torch.cat(img_ids, dim=1)
            else:
                img = torch.cat((img, self.img_in(ref_img)), dim=1)
                ids = torch.cat((ids, ref_img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for index_block, block in enumerate(self.double_blocks):
            if self.training and self.gradient_checkpointing:
                img, txt = torch.utils.checkpoint.checkpoint(
                    block,
                    img=img,
                    txt=txt,
                    vec=vec,
                    pe=pe,
                    use_reentrant=False,
                )
            else:
                img, txt = block(
                    img=img,
                    txt=txt,
                    vec=vec,
                    pe=pe
                )

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            if self.training and self.gradient_checkpointing:
                img = torch.utils.checkpoint.checkpoint(
                    block,
                    img, vec=vec, pe=pe,
                    use_reentrant=False
                )
            else:
                img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]
        # index img
        img = img[:, :img_end, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
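For orientation, a minimal smoke-test sketch of the model above (not part of the commit; the tiny hyper-parameters and sequence lengths are invented for a cheap shape check, while the released FLUX.1 weights use a far larger configuration such as hidden_size=3072 with 19 double and 38 single blocks):

import torch
from uno.flux.model import Flux, FluxParams

# Toy sizes for a shape check only; axes_dim must sum to hidden_size // num_heads.
params = FluxParams(
    in_channels=64, vec_in_dim=768, context_in_dim=4096,
    hidden_size=128, mlp_ratio=4.0, num_heads=4,
    depth=1, depth_single_blocks=1, axes_dim=[8, 12, 12],
    theta=10_000, qkv_bias=True, guidance_embed=False,
)
model = Flux(params).eval()

B, L_img, L_txt = 1, 256, 32
img = torch.randn(B, L_img, 64)       # packed latent patches
img_ids = torch.zeros(B, L_img, 3)    # (t, h, w) position ids per patch
txt = torch.randn(B, L_txt, 4096)     # text encoder token embeddings
txt_ids = torch.zeros(B, L_txt, 3)
y = torch.randn(B, 768)               # pooled text vector
timesteps = torch.rand(B)

with torch.no_grad():
    out = model(img, img_ids, txt, txt_ids, timesteps, y)
print(out.shape)                      # torch.Size([1, 256, 64])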
    	
uno/flux/modules/autoencoder.py ADDED
@@ -0,0 +1,327 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
         
     | 
| 309 | 
         
            +
                        num_res_blocks=params.num_res_blocks,
         
     | 
| 310 | 
         
            +
                        z_channels=params.z_channels,
         
     | 
| 311 | 
         
            +
                    )
         
     | 
| 312 | 
         
            +
                    self.reg = DiagonalGaussian()
         
     | 
| 313 | 
         
            +
             
     | 
| 314 | 
         
            +
                    self.scale_factor = params.scale_factor
         
     | 
| 315 | 
         
            +
                    self.shift_factor = params.shift_factor
         
     | 
| 316 | 
         
            +
             
     | 
| 317 | 
         
            +
                def encode(self, x: Tensor) -> Tensor:
         
     | 
| 318 | 
         
            +
                    z = self.reg(self.encoder(x))
         
     | 
| 319 | 
         
            +
                    z = self.scale_factor * (z - self.shift_factor)
         
     | 
| 320 | 
         
            +
                    return z
         
     | 
| 321 | 
         
            +
             
     | 
| 322 | 
         
            +
                def decode(self, z: Tensor) -> Tensor:
         
     | 
| 323 | 
         
            +
                    z = z / self.scale_factor + self.shift_factor
         
     | 
| 324 | 
         
            +
                    return self.decoder(z)
         
     | 
| 325 | 
         
            +
             
     | 
| 326 | 
         
            +
                def forward(self, x: Tensor) -> Tensor:
         
     | 
| 327 | 
         
            +
                    return self.decode(self.encode(x))
         
uno/flux/modules/conditioner.py
ADDED
@@ -0,0 +1,53 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch import Tensor, nn
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
                          T5Tokenizer)


class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
     | 
    	
        uno/flux/modules/layers.py
    ADDED
    
    | 
         @@ -0,0 +1,435 @@ 
     | 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn

from ..math import attention, rope
import torch.nn.functional as F


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        t.device
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding
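# Illustrative check (added here, not part of the original file): fractional timesteps
# in [0, 1] are rescaled by time_factor=1000 before the (N, dim) sinusoid is built, e.g.
#   t = torch.rand(4)                       # 4 fractional timesteps
#   emb = timestep_embedding(t, dim=256)    # emb.shape == (4, 256)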
         
class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        super().__init__()

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)
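# Illustrative note (added here, not part of the original file): because `up` is
# zero-initialised, a fresh LoRA layer contributes nothing until it is trained, e.g.
#   lora = LoRALinearLayer(3072, 3072, rank=16, network_alpha=16)
#   x = torch.randn(1, 64, 3072)
#   assert torch.allclose(lora(x), torch.zeros_like(x))   # the LoRA delta starts as a no-op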
         
class FLuxSelfAttnProcessor:
    def __call__(self, attn, x, pe, **attention_kwargs):
        qkv = attn.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = attn.proj(x)
        return x


class LoraFluxAttnProcessor(nn.Module):
    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn, x, pe, **attention_kwargs):
        qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
        return x


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward():
        pass


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )
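# Illustrative note (added here, not part of the original file): the three chunks act as an
# adaptive LayerNorm, matching how the processors below consume them, e.g. with
# `mod, _ = Modulation(3072, double=False)(vec)`:
#   x_mod = (1 + mod.scale) * norm(x) + mod.shift   # scale/shift the normalised stream
#   x_out = x + mod.gate * sublayer(x_mod)          # gate the residual update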
         
class DoubleStreamBlockLoraProcessor(nn.Module):
    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
        self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class DoubleStreamBlockProcessor:
    def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
        return img, txt
         
class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_heads

        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )
        processor = DoubleStreamBlockProcessor()
        self.set_processor(processor)

    def set_processor(self, processor) -> None:
        self.processor = processor

    def get_processor(self):
        return self.processor

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        vec: Tensor,
        pe: Tensor,
        image_proj: Tensor = None,
        ip_scale: float = 1.0,
    ) -> tuple[Tensor, Tensor]:
        if image_proj is None:
            return self.processor(self, img, txt, vec, pe)
        else:
            return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
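# Illustrative note (added here, not part of the original file): swapping the processor
# toggles the LoRA path without touching the block's own weights; the sizes below are
# assumed for illustration, e.g.
#   block = DoubleStreamBlock(hidden_size=3072, num_heads=24, mlp_ratio=4.0, qkv_bias=True)
#   block.set_processor(DoubleStreamBlockLoraProcessor(dim=3072, rank=16))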
         
class SingleStreamBlockLoraProcessor(nn.Module):
    def __init__(self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1):
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
        qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)

        # compute attention
        attn_1 = attention(q, k, v, pe=pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
        output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight
        output = x + mod.gate * output
        return output


class SingleStreamBlockProcessor:
    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs) -> Tensor:
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)

        # compute attention
        attn_1 = attention(q, k, v, pe=pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
        output = x + mod.gate * output
        return output


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(self.head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

        processor = SingleStreamBlockProcessor()
        self.set_processor(processor)

    def set_processor(self, processor) -> None:
        self.processor = processor

    def get_processor(self):
        return self.processor

    def forward(
        self,
        x: Tensor,
        vec: Tensor,
        pe: Tensor,
        image_proj: Tensor | None = None,
         
     | 
| 415 | 
         
            +
                    ip_scale: float = 1.0,
         
     | 
| 416 | 
         
            +
                ) -> Tensor:
         
     | 
| 417 | 
         
            +
                    if image_proj is None:
         
     | 
| 418 | 
         
            +
                        return self.processor(self, x, vec, pe)
         
     | 
| 419 | 
         
            +
                    else:
         
     | 
| 420 | 
         
            +
                        return self.processor(self, x, vec, pe, image_proj, ip_scale)
         
     | 
| 421 | 
         
            +
             
     | 
| 422 | 
         
            +
             
     | 
| 423 | 
         
            +
             
     | 
| 424 | 
         
            +
            class LastLayer(nn.Module):
         
     | 
| 425 | 
         
            +
                def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
         
     | 
| 426 | 
         
            +
                    super().__init__()
         
     | 
| 427 | 
         
            +
                    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         
     | 
| 428 | 
         
            +
                    self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
         
     | 
| 429 | 
         
            +
                    self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
         
     | 
| 430 | 
         
            +
             
     | 
| 431 | 
         
            +
                def forward(self, x: Tensor, vec: Tensor) -> Tensor:
         
     | 
| 432 | 
         
            +
                    shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
         
     | 
| 433 | 
         
            +
                    x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
         
     | 
| 434 | 
         
            +
                    x = self.linear(x)
         
     | 
| 435 | 
         
            +
                    return x
         
     | 
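LastLayer is the only piece in this tail with its own learned conditioning path: adaLN_modulation turns the pooled vector into a per-sample shift/scale that modulates the parameter-free final LayerNorm before the output projection. A minimal sketch of its shapes on dummy tensors; the sizes below (3072 hidden dim, 1024 tokens, 64 output channels) are illustrative assumptions, not values taken from this commit:

import torch
from uno.flux.modules.layers import LastLayer

last = LastLayer(hidden_size=3072, patch_size=1, out_channels=64)
x = torch.randn(1, 1024, 3072)   # packed latent tokens (illustrative shape)
vec = torch.randn(1, 3072)       # pooled conditioning vector

# vec -> (shift, scale), applied to the affine-free LayerNorm, then projected
out = last(x, vec)
print(out.shape)                 # torch.Size([1, 1024, 64])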
    	
uno/flux/pipeline.py ADDED
@@ -0,0 +1,324 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Literal
+
+import torch
+from einops import rearrange
+from PIL import ExifTags, Image
+import torchvision.transforms.functional as TVF
+
+from uno.flux.modules.layers import (
+    DoubleStreamBlockLoraProcessor,
+    DoubleStreamBlockProcessor,
+    SingleStreamBlockLoraProcessor,
+    SingleStreamBlockProcessor,
+)
+from uno.flux.sampling import denoise, get_noise, get_schedule, prepare, prepare_multi_ip, unpack
+from uno.flux.util import (
+    get_lora_rank,
+    load_ae,
+    load_checkpoint,
+    load_clip,
+    load_flow_model,
+    load_flow_model_only_lora,
+    load_flow_model_quintized,
+    load_t5,
+)
+
+
+def find_nearest_scale(image_h, image_w, predefined_scales):
+    """
+    Find the predefined scale closest to the image's height and width.
+
+    :param image_h: image height
+    :param image_w: image width
+    :param predefined_scales: list of predefined scales [(h1, w1), (h2, w2), ...]
+    :return: the nearest predefined scale (h, w)
+    """
+    # compute the aspect ratio of the input image
+    image_ratio = image_h / image_w
+
+    # track the smallest difference and the nearest scale
+    min_diff = float('inf')
+    nearest_scale = None
+
+    # iterate over all predefined scales and keep the one whose aspect ratio is closest
+    for scale_h, scale_w in predefined_scales:
+        predefined_ratio = scale_h / scale_w
+        diff = abs(predefined_ratio - image_ratio)
+
+        if diff < min_diff:
+            min_diff = diff
+            nearest_scale = (scale_h, scale_w)
+
+    return nearest_scale
+
+def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
+    # get the width and height of the original image
+    image_w, image_h = raw_image.size
+
+    # compute the new long and short sides
+    if image_w >= image_h:
+        new_w = long_size
+        new_h = int((long_size / image_w) * image_h)
+    else:
+        new_h = long_size
+        new_w = int((long_size / image_h) * image_w)
+
+    # resize proportionally to the new width and height
+    raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
+    target_w = new_w // 16 * 16
+    target_h = new_h // 16 * 16
+
+    # compute the starting coordinates for a center crop
+    left = (new_w - target_w) // 2
+    top = (new_h - target_h) // 2
+    right = left + target_w
+    bottom = top + target_h
+
+    # apply the center crop
+    raw_image = raw_image.crop((left, top, right, bottom))
+
+    # convert to RGB mode
+    raw_image = raw_image.convert("RGB")
+    return raw_image
+
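preprocess_ref therefore scales a reference image so its long side equals long_size, then center-crops both sides down to multiples of 16 (the packing granularity used downstream). A quick worked check; the 1000×770 input size is an arbitrary illustration, not something from this commit:

from PIL import Image
from uno.flux.pipeline import preprocess_ref

ref = Image.new("RGB", (1000, 770))   # placeholder image, width x height
out = preprocess_ref(ref, long_size=512)

# long side -> 512, short side -> int(512/1000 * 770) = 394, cropped down to 384
print(out.size)                       # (512, 384)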
+class UNOPipeline:
+    def __init__(
+        self,
+        model_type: str,
+        device: torch.device,
+        offload: bool = False,
+        only_lora: bool = False,
+        lora_rank: int = 16
+    ):
+        self.device = device
+        self.offload = offload
+        self.model_type = model_type
+
+        self.clip = load_clip(self.device)
+        self.t5 = load_t5(self.device, max_length=512)
+        self.ae = load_ae(model_type, device="cpu" if offload else self.device)
+        if "fp8" in model_type:
+            self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
+        elif only_lora:
+            self.model = load_flow_model_only_lora(
+                model_type, device="cpu" if offload else self.device, lora_rank=lora_rank
+            )
+        else:
+            self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
+
+
+    def load_ckpt(self, ckpt_path):
+        if ckpt_path is not None:
+            from safetensors.torch import load_file as load_sft
+            print("Loading checkpoint to replace old keys")
+            # load_sft doesn't support torch.device
+            if ckpt_path.endswith('safetensors'):
+                sd = load_sft(ckpt_path, device='cpu')
+                missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
+            else:
+                dit_state = torch.load(ckpt_path, map_location='cpu')
+                sd = {}
+                for k in dit_state.keys():
+                    sd[k.replace('module.', '')] = dit_state[k]
+                missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
+                self.model.to(str(self.device))
+            print(f"missing keys: {missing}\n\n\n\n\nunexpected keys: {unexpected}")
+
+    def set_lora(self, local_path: str = None, repo_id: str = None,
+                 name: str = None, lora_weight: float = 0.7):
+        checkpoint = load_checkpoint(local_path, repo_id, name)
+        self.update_model_with_lora(checkpoint, lora_weight)
+
+    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
+        checkpoint = load_checkpoint(
+            None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
+        )
+        self.update_model_with_lora(checkpoint, lora_weight)
+
+    def update_model_with_lora(self, checkpoint, lora_weight):
+        rank = get_lora_rank(checkpoint)
+        lora_attn_procs = {}
+
+        for name, _ in self.model.attn_processors.items():
+            lora_state_dict = {}
+            for k in checkpoint.keys():
+                if name in k:
+                    lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
+
+            if len(lora_state_dict):
+                if name.startswith("single_blocks"):
+                    lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
+                else:
+                    lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
+                lora_attn_procs[name].load_state_dict(lora_state_dict)
+                lora_attn_procs[name].to(self.device)
+            else:
+                if name.startswith("single_blocks"):
+                    lora_attn_procs[name] = SingleStreamBlockProcessor()
+                else:
+                    lora_attn_procs[name] = DoubleStreamBlockProcessor()
+
+        self.model.set_attn_processor(lora_attn_procs)
+
+
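update_model_with_lora folds the LoRA strength directly into the weights: every checkpoint tensor belonging to an attention processor is multiplied by lora_weight before being loaded into a fresh SingleStreamBlockLoraProcessor / DoubleStreamBlockLoraProcessor (dim=3072), and blocks with no matching keys fall back to the plain processors. A minimal usage sketch; the model type string and the safetensors path are placeholders, not artifacts shipped with this commit:

import torch
from uno.flux.pipeline import UNOPipeline

# Hypothetical call: "flux-dev" and the LoRA path below are illustrative only.
pipeline = UNOPipeline("flux-dev", device=torch.device("cuda"), only_lora=True)

# Scales matching processor tensors by 0.7, then swaps the processors in place.
pipeline.set_lora(local_path="path/to/uno_lora.safetensors", lora_weight=0.7)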
+    def __call__(
+        self,
+        prompt: str,
+        width: int = 512,
+        height: int = 512,
+        guidance: float = 4,
+        num_steps: int = 50,
+        seed: int = 123456789,
+        true_gs: float = 3,
+        neg_prompt: str = '',
+        neg_image_prompt: Image.Image | None = None,
+        timestep_to_start_cfg: int = 0,
+        **kwargs
+    ):
+        width = 16 * (width // 16)
+        height = 16 * (height // 16)
+
+        return self.forward(
+            prompt,
+            width,
+            height,
+            guidance,
+            num_steps,
+            seed,
+            timestep_to_start_cfg=timestep_to_start_cfg,
+            true_gs=true_gs,
+            neg_prompt=neg_prompt,
+            **kwargs
+        )
+
+    @torch.inference_mode()
+    def gradio_generate(
+        self,
+        prompt: str,
+        width: int,
+        height: int,
+        guidance: float,
+        num_steps: int,
+        seed: int,
+        ref_long_side: int,
+        image_prompt1: Image.Image,
+        image_prompt2: Image.Image,
+        image_prompt3: Image.Image,
+        image_prompt4: Image.Image,
+    ):
+        ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4]
+        ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)]
+        ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs]
+
+        seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item()
+
+        img = self(prompt=prompt, width=width, height=height, guidance=guidance,
+                   num_steps=num_steps, seed=seed, ref_imgs=ref_imgs)
+
+        filename = f"output/gradio/{seed}_{prompt[:20]}.png"
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        exif_data = Image.Exif()
+        exif_data[ExifTags.Base.Make] = "UNO"
+        exif_data[ExifTags.Base.Model] = self.model_type
+        info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}"
+        exif_data[ExifTags.Base.ImageDescription] = info
+        img.save(filename, format="png", exif=exif_data)
+        return img, filename
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        prompt: str,
+        width: int,
+        height: int,
+        guidance: float,
+        num_steps: int,
+        seed: int,
+        timestep_to_start_cfg: int = 1e5,  # TODO: unused, remove
+        true_gs: float = 3.5,
+        neg_prompt: str = "",
+        ref_imgs: list[Image.Image] | None = None,
+        pe: Literal['d', 'h', 'w', 'o'] = 'd',
+    ):
+        x = get_noise(
+            1, height, width, device=self.device,
+            dtype=torch.bfloat16, seed=seed
+        )
+        timesteps = get_schedule(
+            num_steps,
+            (width // 8) * (height // 8) // (16 * 16),
+            shift=True,
+        )
+        if self.offload:
+            self.ae.encoder = self.ae.encoder.to(self.device)
+        x_1_refs = [
+            self.ae.encode(
+                (TVF.to_tensor(ref_img) * 2.0 - 1.0)
+                .unsqueeze(0).to(self.device, torch.float32)
+            ).to(torch.bfloat16)
+            for ref_img in ref_imgs
+        ]
+
+        if self.offload:
+            self.ae.encoder = self.offload_model_to_cpu(self.ae.encoder)
+            self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
+        inp_cond = prepare_multi_ip(
+            t5=self.t5, clip=self.clip,
+            img=x,
+            prompt=prompt, ref_imgs=x_1_refs, pe=pe
+        )
+        neg_inp_cond = prepare_multi_ip(
+            t5=self.t5, clip=self.clip,
+            img=x,
+            prompt=neg_prompt, ref_imgs=x_1_refs, pe=pe
+        )
+
+        if self.offload:
+            self.offload_model_to_cpu(self.t5, self.clip)
+            self.model = self.model.to(self.device)
+
+        x = denoise(
+            self.model,
+            **inp_cond,
+            timesteps=timesteps,
+            guidance=guidance,
+            timestep_to_start_cfg=timestep_to_start_cfg,
+            neg_txt=neg_inp_cond['txt'],
+            neg_txt_ids=neg_inp_cond['txt_ids'],
+            neg_vec=neg_inp_cond['vec'],
+            true_gs=true_gs,
+        )
+
+        if self.offload:
+            self.offload_model_to_cpu(self.model)
+            self.ae.decoder.to(x.device)
+        x = unpack(x.float(), height, width)
+        x = self.ae.decode(x)
+        self.offload_model_to_cpu(self.ae.decoder)
+
+        x1 = x.clamp(-1, 1)
+        x1 = rearrange(x1[-1], "c h w -> h w c")
+        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
+        return output_img
+
+    def offload_model_to_cpu(self, *models):
+        if not self.offload: return
+        for model in models:
+            model.cpu()
+            torch.cuda.empty_cache()
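Everything above hangs together as one object: __call__ snaps width/height to multiples of 16, forward encodes the reference images with the autoencoder, builds positive and negative conditioning via prepare_multi_ip, denoises, and decodes back to a PIL image, optionally shuttling sub-models between CPU and GPU when offload is on. A compact end-to-end sketch mirroring gradio_generate; the model type, device, and file names are illustrative assumptions:

import torch
from PIL import Image
from uno.flux.pipeline import UNOPipeline, preprocess_ref

# Placeholders: "flux-dev", the CUDA device, and "ref.png" are not shipped with this commit.
pipeline = UNOPipeline("flux-dev", device=torch.device("cuda"), offload=True)

ref = preprocess_ref(Image.open("ref.png"), long_size=512)
img = pipeline(
    prompt="a clock on the beach",
    width=512, height=512,
    guidance=4, num_steps=25, seed=42,
    ref_imgs=[ref],             # forwarded through **kwargs to forward()
)
img.save("uno_output.png")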
    	
uno/flux/sampling.py ADDED
@@ -0,0 +1,271 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Literal
+
+import torch
+from einops import rearrange, repeat
+from torch import Tensor
+from tqdm import tqdm
+
+from .model import Flux
+from .modules.conditioner import HFEmbedder
+
+
+def get_noise(
+    num_samples: int,
+    height: int,
+    width: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    seed: int,
+):
+    return torch.randn(
+        num_samples,
+        16,
+        # allow for packing
+        2 * math.ceil(height / 16),
+        2 * math.ceil(width / 16),
+        device=device,
+        dtype=dtype,
+        generator=torch.Generator(device=device).manual_seed(seed),
+    )
+
+
            def prepare(
         
     | 
| 49 | 
         
            +
                t5: HFEmbedder,
         
     | 
| 50 | 
         
            +
                clip: HFEmbedder,
         
     | 
| 51 | 
         
            +
                img: Tensor,
         
     | 
| 52 | 
         
            +
                prompt: str | list[str],
         
     | 
| 53 | 
         
            +
                ref_img: None | Tensor=None,
         
     | 
| 54 | 
         
            +
                pe: Literal['d', 'h', 'w', 'o'] ='d'
         
     | 
| 55 | 
         
            +
            ) -> dict[str, Tensor]:
         
     | 
| 56 | 
         
            +
                assert pe in ['d', 'h', 'w', 'o']
         
     | 
| 57 | 
         
            +
                bs, c, h, w = img.shape
         
     | 
| 58 | 
         
            +
                if bs == 1 and not isinstance(prompt, str):
         
     | 
| 59 | 
         
            +
                    bs = len(prompt)
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
         
     | 
| 62 | 
         
            +
                if img.shape[0] == 1 and bs > 1:
         
     | 
| 63 | 
         
            +
                    img = repeat(img, "1 ... -> bs ...", bs=bs)
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
                img_ids = torch.zeros(h // 2, w // 2, 3)
         
     | 
| 66 | 
         
            +
                img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
         
     | 
| 67 | 
         
            +
                img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
         
     | 
| 68 | 
         
            +
                img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                if ref_img is not None:
         
     | 
| 71 | 
         
            +
        _, _, ref_h, ref_w = ref_img.shape
        ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        if ref_img.shape[0] == 1 and bs > 1:
            ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
        ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
        # offset the reference image ids by the target image's maximum height / width ids
        h_offset = h // 2 if pe in {'d', 'h'} else 0
        w_offset = w // 2 if pe in {'d', 'w'} else 0
        ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset
        ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset
        ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    if ref_img is not None:
        return {
            "img": img,
            "img_ids": img_ids.to(img.device),
            "ref_img": ref_img,
            "ref_img_ids": ref_img_ids.to(img.device),
            "txt": txt.to(img.device),
            "txt_ids": txt_ids.to(img.device),
            "vec": vec.to(img.device),
        }
    else:
        return {
            "img": img,
            "img_ids": img_ids.to(img.device),
            "txt": txt.to(img.device),
            "txt_ids": txt_ids.to(img.device),
            "vec": vec.to(img.device),
        }
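For orientation only (not part of the diff): a quick check of how the pe flag above shifts the reference image's position ids so they never collide with the target's. The latent sizes below are hypothetical and assume the usual VAE stride of 8.

import torch

h, w = 512 // 8, 512 // 8          # target latent (assumed 512x512 pixels)
ref_h, ref_w = 256 // 8, 256 // 8  # one reference latent (assumed 256x256 pixels)
for pe in ("d", "h", "w", "o"):
    h_offset = h // 2 if pe in {"d", "h"} else 0
    w_offset = w // 2 if pe in {"d", "w"} else 0
    ref_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
    ref_ids[..., 1] += torch.arange(ref_h // 2)[:, None] + h_offset
    ref_ids[..., 2] += torch.arange(ref_w // 2)[None, :] + w_offset
    print(pe, int(ref_ids[..., 1].max()), int(ref_ids[..., 2].max()))
# 'd' offsets both axes, 'h' only rows, 'w' only columns, 'o' applies no offset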
def prepare_multi_ip(
    t5: HFEmbedder,
    clip: HFEmbedder,
    img: Tensor,
    prompt: str | list[str],
    ref_imgs: list[Tensor] | None = None,
    pe: Literal['d', 'h', 'w', 'o'] = 'd'
) -> dict[str, Tensor]:
    assert pe in ['d', 'h', 'w', 'o']
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    ref_img_ids = []
    ref_imgs_list = []
    pe_shift_w, pe_shift_h = w // 2, h // 2
    for ref_img in ref_imgs:
        _, _, ref_h1, ref_w1 = ref_img.shape
        ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        if ref_img.shape[0] == 1 and bs > 1:
            ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
        ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
        # offset the reference image ids by the running maxima along height and width
        h_offset = pe_shift_h if pe in {'d', 'h'} else 0
        w_offset = pe_shift_w if pe in {'d', 'w'} else 0
        ref_img_ids1[..., 1] = ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
        ref_img_ids1[..., 2] = ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
        ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
        ref_img_ids.append(ref_img_ids1)
        ref_imgs_list.append(ref_img)

        # update the pe shift for the next reference image
        pe_shift_h += ref_h1 // 2
        pe_shift_w += ref_w1 // 2

    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "ref_img": tuple(ref_imgs_list),
        "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }
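Illustration only (sizes are hypothetical): because pe_shift_h and pe_shift_w keep growing inside the loop above, each reference image's position ids start where the previous one's ended, so no two images share token positions.

pe_shift_h = 64 // 2                # target latent is 64x64, so its row ids occupy 0..31
for ref_h1 in (32, 48):             # two reference latents of height 32 and 48
    print("rows", pe_shift_h, "to", pe_shift_h + ref_h1 // 2 - 1)
    pe_shift_h += ref_h1 // 2
# prints: rows 32 to 47, then rows 48 to 71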
def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
):
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu by linear interpolation between two reference points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()
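A small, illustrative check of the shift (not part of the diff): a 1024x1024 image packs into (1024 / 16) ** 2 = 4096 latent tokens, which is exactly the upper anchor of the linear function, so mu lands at max_shift and the endpoints of the schedule stay at 1.0 and 0.0.

from uno.flux.sampling import get_lin_function, get_schedule

mu = get_lin_function(y1=0.5, y2=1.15)(4096)          # -> 1.15
steps = get_schedule(num_steps=25, image_seq_len=4096)
print(round(mu, 4), len(steps), steps[0], steps[-1])  # 1.15 26 1.0 0.0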
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    neg_txt: Tensor,
    neg_txt_ids: Tensor,
    neg_vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    true_gs=1,
    timestep_to_start_cfg=0,
    ref_img: Tensor = None,
    ref_img_ids: Tensor = None,
):
    i = 0
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=img,
            img_ids=img_ids,
            ref_img=ref_img,
            ref_img_ids=ref_img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec
        )
        if i >= timestep_to_start_cfg:
            # not tested
            neg_pred = model(
                img=img,
                img_ids=img_ids,
                ref_img=ref_img,  # TODO: neg img embedding
                ref_img_ids=ref_img_ids,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t_vec,
                guidance=guidance_vec,
            )
            pred = neg_pred + true_gs * (pred - neg_pred)
        img = img + (t_prev - t_curr) * pred
        i += 1
    return img


def unpack(x: Tensor, height: int, width: int) -> Tensor:
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )
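To show how these pieces chain together, here is a minimal sampling sketch (illustration only, not part of the commit; it assumes a CUDA device, that prepare accepts ref_img/pe keywords like prepare_multi_ip above, and that the loaders come from uno/flux/util.py added below):

import torch
from uno.flux.sampling import prepare, get_schedule, denoise, unpack
from uno.flux.util import load_t5, load_clip, load_ae, load_flow_model

device = "cuda"  # assumption: a GPU is available
t5, clip = load_t5(device), load_clip(device)
ae = load_ae("flux-dev", device=device)
model = load_flow_model("flux-dev", device=device)

height = width = 512
x = torch.randn(1, 16, height // 8, width // 8, device=device, dtype=torch.bfloat16)  # latent-space noise

inp = prepare(t5, clip, img=x, prompt="a plush toy on a beach", ref_img=None, pe="d")
timesteps = get_schedule(num_steps=25, image_seq_len=inp["img"].shape[1])

# placeholder negatives plus a late CFG start keep classifier-free guidance effectively off
x = denoise(model, **inp, neg_txt=inp["txt"], neg_txt_ids=inp["txt_ids"], neg_vec=inp["vec"],
            timesteps=timesteps, guidance=4.0, timestep_to_start_cfg=len(timesteps))

x = unpack(x.float(), height, width)
img = ae.decode(x)  # assumed AutoEncoder.decode, as in the reference FLUX code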
    	
        uno/flux/util.py
    ADDED
    
    | 
         @@ -0,0 +1,390 @@ 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass

import torch
import json
import numpy as np
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from safetensors.torch import load_file as load_sft

from .model import Flux, FluxParams
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder

import re
from uno.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor


def load_model(ckpt, device='cpu'):
    if ckpt.endswith('safetensors'):
        from safetensors import safe_open
        pl_sd = {}
        with safe_open(ckpt, framework="pt", device=device) as f:
            for k in f.keys():
                pl_sd[k] = f.get_tensor(k)
    else:
        pl_sd = torch.load(ckpt, map_location=device)
    return pl_sd

def load_safetensors(path):
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors

def get_lora_rank(checkpoint):
    for k in checkpoint.keys():
        if k.endswith(".down.weight"):
            return checkpoint[k].shape[0]

def load_checkpoint(local_path, repo_id, name):
    if local_path is not None:
        if '.safetensors' in local_path:
            print(f"Loading .safetensors checkpoint from {local_path}")
            checkpoint = load_safetensors(local_path)
        else:
            print(f"Loading checkpoint from {local_path}")
            checkpoint = torch.load(local_path, map_location='cpu')
    elif repo_id is not None and name is not None:
        print(f"Loading checkpoint {name} from repo id {repo_id}")
        checkpoint = load_from_repo_id(repo_id, name)
    else:
        raise ValueError(
            "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
        )
    return checkpoint
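For orientation only (the local filename is hypothetical): load_checkpoint accepts either a local path or a repo_id plus file name on the Hub, and get_lora_rank reads the rank off the first ".down.weight" tensor it finds.

sd = load_checkpoint("ckpts/dit_lora.safetensors", None, None)                    # local file
# sd = load_checkpoint(None, "bytedance-research/UNO", "dit_lora.safetensors")    # or download from the Hub
rank = get_lora_rank(sd)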
def c_crop(image):
    width, height = image.size
    new_size = min(width, height)
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))

def pad64(x):
    return int(np.ceil(float(x) / 64.0) * 64 - x)

def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y

@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    ae_path: str | None
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None
    repo_id_ae: str | None


configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-fp8": ModelSpec(
        repo_id="XLabs-AI/flux-dev-fp8",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux-dev-fp8.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV_FP8"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}
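For reference (not part of the file): a spec is looked up by name, local checkpoint paths can be pinned via the FLUX_DEV / FLUX_DEV_FP8 / FLUX_SCHNELL and AE environment variables read above, and otherwise the loaders below fall back to hf_hub_download. A quick lookup:

spec = configs["flux-dev"]
print(spec.repo_id, spec.repo_flow)                # black-forest-labs/FLUX.1-dev flux1-dev.safetensors
print(spec.params.hidden_size, spec.params.depth)  # 3072 19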
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))

def load_from_repo_id(repo_id, checkpoint_name):
    ckpt_path = hf_hub_download(repo_id, checkpoint_name)
    sd = load_sft(ckpt_path, device='cpu')
    return sd

def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params).to(torch.bfloat16)

    if ckpt_path is not None:
        print("Loading checkpoint")
        # load_sft doesn't support torch.device
        sd = load_model(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return model

def load_flow_model_only_lora(
    name: str,
    device: str | torch.device = "cuda",
    hf_download: bool = True,
    lora_rank: int = 16
):
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))

    if hf_download:
        lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
    else:
        lora_ckpt_path = os.environ.get("LORA", None)

    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params)

    model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device)

    if ckpt_path is not None:
        print("Loading lora")
        lora_sd = load_sft(lora_ckpt_path, device=str(device)) if lora_ckpt_path.endswith("safetensors") \
            else torch.load(lora_ckpt_path, map_location='cpu')

        print("Loading main checkpoint")
        # load_sft doesn't support torch.device

        if ckpt_path.endswith('safetensors'):
            sd = load_sft(ckpt_path, device=str(device))
            sd.update(lora_sd)
            missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        else:
            dit_state = torch.load(ckpt_path, map_location='cpu')
            sd = {}
            for k in dit_state.keys():
                sd[k.replace('module.', '')] = dit_state[k]
            sd.update(lora_sd)
            missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
            model.to(str(device))
        print_load_warning(missing, unexpected)
    return model


def set_lora(
    model: Flux,
    lora_rank: int,
    double_blocks_indices: list[int] | None = None,
    single_blocks_indices: list[int] | None = None,
    device: str | torch.device = "cpu",
) -> Flux:
    double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
    single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
                            else single_blocks_indices

    lora_attn_procs = {}
    with torch.device(device):
        for name, attn_processor in model.attn_processors.items():
            match = re.search(r'\.(\d+)\.', name)
            if match:
                layer_index = int(match.group(1))

            if name.startswith("double_blocks") and layer_index in double_blocks_indices:
                lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
            elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
                lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
            else:
                lora_attn_procs[name] = attn_processor
    model.set_attn_processor(lora_attn_procs)
    return model
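Illustration only: by default set_lora wraps every double and single stream block; passing explicit index lists (the values below are hypothetical) limits LoRA to a subset while the remaining blocks keep their original attention processors.

model = set_lora(model, lora_rank=16, double_blocks_indices=[0, 1, 2],
                 single_blocks_indices=[], device="cpu")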
| 336 | 
         
            +
            def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
         
     | 
| 337 | 
         
            +
                # Loading Flux
         
     | 
| 338 | 
         
            +
                from optimum.quanto import requantize
         
     | 
| 339 | 
         
            +
                print("Init model")
         
     | 
| 340 | 
         
            +
                ckpt_path = configs[name].ckpt_path
         
     | 
| 341 | 
         
            +
                if (
         
     | 
| 342 | 
         
            +
                    ckpt_path is None
         
     | 
| 343 | 
         
            +
                    and configs[name].repo_id is not None
         
     | 
| 344 | 
         
            +
                    and configs[name].repo_flow is not None
         
     | 
| 345 | 
         
            +
                    and hf_download
         
     | 
| 346 | 
         
            +
                ):
         
     | 
| 347 | 
         
            +
                    ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
         
     | 
| 348 | 
         
            +
                json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
         
     | 
| 349 | 
         
            +
             
     | 
| 350 | 
         
            +
             
     | 
| 351 | 
         
            +
                model = Flux(configs[name].params).to(torch.bfloat16)
         
     | 
| 352 | 
         
            +
             
     | 
| 353 | 
         
            +
                print("Loading checkpoint")
         
     | 
| 354 | 
         
            +
                # load_sft doesn't support torch.device
         
     | 
| 355 | 
         
            +
                sd = load_sft(ckpt_path, device='cpu')
         
     | 
| 356 | 
         
            +
                with open(json_path, "r") as f:
         
     | 
| 357 | 
         
            +
                    quantization_map = json.load(f)
         
     | 
| 358 | 
         
            +
                print("Start a quantization process...")
         
     | 
| 359 | 
         
            +
                requantize(model, sd, quantization_map, device=device)
         
     | 
| 360 | 
         
            +
                print("Model is quantized!")
         
     | 
| 361 | 
         
            +
                return model
         
     | 
| 362 | 
         
            +
             
     | 
| 363 | 
         
            +
            def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
         
     | 
| 364 | 
         
            +
                # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
         
     | 
| 365 | 
         
            +
                return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
         
     | 
| 366 | 
         
            +
             
     | 
| 367 | 
         
            +
            def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
         
     | 
| 368 | 
         
            +
                return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
         
     | 
| 369 | 
         
            +
             
     | 
| 370 | 
         
            +
             
     | 
| 371 | 
         
            +
            def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
         
     | 
| 372 | 
         
            +
                ckpt_path = configs[name].ae_path
         
     | 
| 373 | 
         
            +
                if (
         
     | 
| 374 | 
         
            +
                    ckpt_path is None
         
     | 
| 375 | 
         
            +
                    and configs[name].repo_id is not None
         
     | 
| 376 | 
         
            +
                    and configs[name].repo_ae is not None
         
     | 
| 377 | 
         
            +
                    and hf_download
         
     | 
| 378 | 
         
            +
                ):
         
     | 
| 379 | 
         
            +
                    ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
         
     | 
| 380 | 
         
            +
             
     | 
| 381 | 
         
            +
                # Loading the autoencoder
         
     | 
| 382 | 
         
            +
                print("Init AE")
         
     | 
| 383 | 
         
            +
                with torch.device("meta" if ckpt_path is not None else device):
         
     | 
| 384 | 
         
            +
                    ae = AutoEncoder(configs[name].ae_params)
         
     | 
| 385 | 
         
            +
             
     | 
| 386 | 
         
            +
                if ckpt_path is not None:
         
     | 
| 387 | 
         
            +
                    sd = load_sft(ckpt_path, device=str(device))
         
     | 
| 388 | 
         
            +
                    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
         
     | 
| 389 | 
         
            +
                    print_load_warning(missing, unexpected)
         
     | 
| 390 | 
         
            +
                return ae
         
     | 
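Taken together, these helpers load every model component the Space needs: the T5 and CLIP text encoders, the Flux transformer (optionally requantized from a quanto checkpoint), and the latent autoencoder. Below is a minimal sketch of how they could be wired together for inference; the config name "flux-dev" and the CPU fallback are illustrative assumptions, not part of this commit.

    # Sketch only: assumes a "flux-dev" entry in `configs` and enough GPU memory
    # to keep all components resident at once.
    import torch
    from uno.flux.util import load_ae, load_clip, load_flow_model_quintized, load_t5

    device = "cuda" if torch.cuda.is_available() else "cpu"

    t5 = load_t5(device, max_length=512)                           # T5 text encoder (long prompts)
    clip = load_clip(device)                                       # CLIP text encoder (77 tokens)
    model = load_flow_model_quintized("flux-dev", device=device)   # quantized Flux transformer
    ae = load_ae("flux-dev", device=device)                        # autoencoder for latent <-> pixel space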
    	
uno/utils/convert_yaml_to_args_file.py
ADDED
@@ -0,0 +1,34 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import yaml
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--yaml", type=str, required=True)
+parser.add_argument("--arg", type=str, required=True)
+args = parser.parse_args()
+
+
+with open(args.yaml, "r") as f:
+    data = yaml.safe_load(f)
+
+with open(args.arg, "w") as f:
+    for k, v in data.items():
+        if isinstance(v, list):
+            v = list(map(str, v))
+            v = " ".join(v)
+        if v is None:
+            continue
+        print(f"--{k} {v}", end=" ", file=f)
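The script flattens a YAML config into a single line of `--key value` arguments (invoked as `python uno/utils/convert_yaml_to_args_file.py --yaml <config>.yaml --arg <out>.args`): list values are space-joined and keys whose value is null are dropped. A minimal sketch of the same flattening rule on an in-memory dict follows; the example keys are hypothetical, not taken from any config in this repo.

    # Sketch of the flattening rule used above; the example keys are invented.
    example = {"lora_rank": 16, "double_blocks_indices": [0, 1, 2], "resume_path": None}

    parts = []
    for k, v in example.items():
        if isinstance(v, list):
            v = " ".join(map(str, v))
        if v is None:
            continue  # null YAML values produce no flag at all
        parts.append(f"--{k} {v}")

    print(" ".join(parts))  # -> --lora_rank 16 --double_blocks_indices 0 1 2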