Maykeye
commited on
Commit
·
1a030c8
1
Parent(s):
6a2e483
Initial commit
Browse files- .gitignore +167 -0
- LICENSE +201 -0
- README.md +63 -3
- cli_imgen3_flip.py +54 -0
- image_utils.py +130 -0
- imgen3.py +100 -0
- imgen3flip.py +132 -0
- imgen3test.ipynb +0 -0
- imgen3test_flip.ipynb +0 -0
- krita-flip.png +0 -0
- krita-nonflip.png +0 -0
- krita/face1.png +0 -0
- krita/face2.png +0 -0
- torch_utils.py +13 -0
- valid.png +0 -0
.gitignore
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Will go to HF someday
|
2 |
+
data/
|
3 |
+
data/all_images_8.bin
|
4 |
+
data/all_images_64.bin
|
5 |
+
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
share/python-wheels/
|
29 |
+
*.egg-info/
|
30 |
+
.installed.cfg
|
31 |
+
*.egg
|
32 |
+
MANIFEST
|
33 |
+
|
34 |
+
# PyInstaller
|
35 |
+
# Usually these files are written by a python script from a template
|
36 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
37 |
+
*.manifest
|
38 |
+
*.spec
|
39 |
+
|
40 |
+
# Installer logs
|
41 |
+
pip-log.txt
|
42 |
+
pip-delete-this-directory.txt
|
43 |
+
|
44 |
+
# Unit test / coverage reports
|
45 |
+
htmlcov/
|
46 |
+
.tox/
|
47 |
+
.nox/
|
48 |
+
.coverage
|
49 |
+
.coverage.*
|
50 |
+
.cache
|
51 |
+
nosetests.xml
|
52 |
+
coverage.xml
|
53 |
+
*.cover
|
54 |
+
*.py,cover
|
55 |
+
.hypothesis/
|
56 |
+
.pytest_cache/
|
57 |
+
cover/
|
58 |
+
|
59 |
+
# Translations
|
60 |
+
*.mo
|
61 |
+
*.pot
|
62 |
+
|
63 |
+
# Django stuff:
|
64 |
+
*.log
|
65 |
+
local_settings.py
|
66 |
+
db.sqlite3
|
67 |
+
db.sqlite3-journal
|
68 |
+
|
69 |
+
# Flask stuff:
|
70 |
+
instance/
|
71 |
+
.webassets-cache
|
72 |
+
|
73 |
+
# Scrapy stuff:
|
74 |
+
.scrapy
|
75 |
+
|
76 |
+
# Sphinx documentation
|
77 |
+
docs/_build/
|
78 |
+
|
79 |
+
# PyBuilder
|
80 |
+
.pybuilder/
|
81 |
+
target/
|
82 |
+
|
83 |
+
# Jupyter Notebook
|
84 |
+
.ipynb_checkpoints
|
85 |
+
|
86 |
+
# IPython
|
87 |
+
profile_default/
|
88 |
+
ipython_config.py
|
89 |
+
|
90 |
+
# pyenv
|
91 |
+
# For a library or package, you might want to ignore these files since the code is
|
92 |
+
# intended to run in multiple environments; otherwise, check them in:
|
93 |
+
# .python-version
|
94 |
+
|
95 |
+
# pipenv
|
96 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
97 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
98 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
99 |
+
# install all needed dependencies.
|
100 |
+
#Pipfile.lock
|
101 |
+
|
102 |
+
# poetry
|
103 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
104 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
105 |
+
# commonly ignored for libraries.
|
106 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
107 |
+
#poetry.lock
|
108 |
+
|
109 |
+
# pdm
|
110 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
111 |
+
#pdm.lock
|
112 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
113 |
+
# in version control.
|
114 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
115 |
+
.pdm.toml
|
116 |
+
.pdm-python
|
117 |
+
.pdm-build/
|
118 |
+
|
119 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
120 |
+
__pypackages__/
|
121 |
+
|
122 |
+
# Celery stuff
|
123 |
+
celerybeat-schedule
|
124 |
+
celerybeat.pid
|
125 |
+
|
126 |
+
# SageMath parsed files
|
127 |
+
*.sage.py
|
128 |
+
|
129 |
+
# Environments
|
130 |
+
.env
|
131 |
+
.venv
|
132 |
+
env/
|
133 |
+
venv/
|
134 |
+
ENV/
|
135 |
+
env.bak/
|
136 |
+
venv.bak/
|
137 |
+
|
138 |
+
# Spyder project settings
|
139 |
+
.spyderproject
|
140 |
+
.spyproject
|
141 |
+
|
142 |
+
# Rope project settings
|
143 |
+
.ropeproject
|
144 |
+
|
145 |
+
# mkdocs documentation
|
146 |
+
/site
|
147 |
+
|
148 |
+
# mypy
|
149 |
+
.mypy_cache/
|
150 |
+
.dmypy.json
|
151 |
+
dmypy.json
|
152 |
+
|
153 |
+
# Pyre type checker
|
154 |
+
.pyre/
|
155 |
+
|
156 |
+
# pytype static type analyzer
|
157 |
+
.pytype/
|
158 |
+
|
159 |
+
# Cython debug symbols
|
160 |
+
cython_debug/
|
161 |
+
|
162 |
+
# PyCharm
|
163 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
164 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
165 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
166 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
167 |
+
#.idea/
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,3 +1,63 @@
|
|
1 |
-
---
|
2 |
-
license: apache-2.0
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
datasets:
|
4 |
+
- huggan/anime-faces
|
5 |
+
---
|
6 |
+
|
7 |
+
# Mamba face kiss
|
8 |
+
|
9 |
+
## KISS
|
10 |
+
|
11 |
+
This repo contains two Keep It Simple Stupid anime face generators that generates 64x64 faces from 8x8 provided images.
|
12 |
+
|
13 |
+
Basic idea was to take 64x64 anime faces dataset(https://huggingface.co/datasets/huggan/anime-faces), resize it to 8x8, then teach the model to restore original images, intuition is that after that if new unseen images are provided, it will make some face.
|
14 |
+
|
15 |
+

|
16 |
+
|
17 |
+
Mamba is being fed a sequence `[A][A]...[A][SEP][B][B][B]...[B]` where there are 64 `[A]` that came from the 8x8 draft. there are 64x64 `[B]`s that are initially are upscaled draft(nearest neighbor) with addition of PAE. Model run several layers of mamba, and spits last 64x64 into RGB image. (`[SEP]` is not used for anything significant other than BERT has it to separate sentences, so I used it too as placeholder for command "Upscale from here")
|
18 |
+
|
19 |
+
Two models are used.
|
20 |
+
|
21 |
+
### RNN goess brr (one way)
|
22 |
+
|
23 |
+
One(`imgen3test.ipynb` and `imgen3.py`) always feeds images from top-left pixel to bottom-right pixel row by row
|
24 |
+
|
25 |
+

|
26 |
+
|
27 |
+
|
28 |
+
### "Bi-directional"
|
29 |
+
|
30 |
+
Another take(`imgen3test_flip.ipynb` and `imgen3_flip.py`) feed from top-left pixel to bottom-right pixel in every even layer and every odd layer sees upscaled images in reverse order
|
31 |
+
|
32 |
+

|
33 |
+
|
34 |
+
This flip version also uses way more parameters and different dtype. I didn't notice that much difference.
|
35 |
+
|
36 |
+
|
37 |
+
#### Command line tool
|
38 |
+
|
39 |
+
Simple script can be used to call the model on a single image
|
40 |
+
|
41 |
+
```console
|
42 |
+
$ cli_imgen3_flip ./krita/face1.png face1.out.png
|
43 |
+
|
44 |
+
python cli_imgen3_flip.py ./krita/face1.png /tmp/face1.png
|
45 |
+
Weight path is data/image-flip-weights-1024x4-torch.bfloat16.bin
|
46 |
+
Loading the model
|
47 |
+
Loading 8x8 input image from ./krita/face1.png
|
48 |
+
Writing 64x64 image to /tmp/face1.png
|
49 |
+
```
|
50 |
+
|
51 |
+
It's not really good way to use, comparing to calling through jupyter it though: mamba2 is implemented using triton and it takes around 30 seconds to initialize the model each time (on Raider GE76).
|
52 |
+
|
53 |
+
|
54 |
+
## Recreating
|
55 |
+
|
56 |
+
Training is done in `imgen3(_flip)?.py`. Testing is in notebook. `Image_utils` should provide path to anime faces dataset.
|
57 |
+
|
58 |
+
## Naming and configuring
|
59 |
+
|
60 |
+
Name imgen3 comes from "image generation 3".
|
61 |
+
Two other attemts are not that interesting to even backup them.
|
62 |
+
|
63 |
+
I'm too lazy to pass configuration around so parameters are hardcoded in the beginning of the file.
|
cli_imgen3_flip.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from imgen3flip import weights_path, Model, ImageBatch, OPTS
|
2 |
+
import torch
|
3 |
+
import torchvision as TV
|
4 |
+
import torchvision.transforms.functional as VF
|
5 |
+
import sys
|
6 |
+
|
7 |
+
|
8 |
+
assert weights_path.exists(), "Model weights do not exist"
|
9 |
+
|
10 |
+
assert len(sys.argv) == 3, f"Usage: {
|
11 |
+
sys.argv[0]} <input-filename> <output-filename>"
|
12 |
+
|
13 |
+
input_filename = sys.argv[1]
|
14 |
+
output_filename = sys.argv[2]
|
15 |
+
|
16 |
+
assert input_filename != output_filename, f"Use different file names"
|
17 |
+
|
18 |
+
print("Loading the model")
|
19 |
+
model = Model()
|
20 |
+
model.load_state_dict(torch.load(weights_path))
|
21 |
+
|
22 |
+
print(f"Loading 8x8 input image from {input_filename}")
|
23 |
+
# read image and ditch alpha-channel if it presents
|
24 |
+
image = TV.io.read_image(input_filename)[:3]
|
25 |
+
# Convert range from 0..255 to 0.0..1.0
|
26 |
+
image = image / 255.0
|
27 |
+
assert image.shape[0] == 3, "RGB image expected"
|
28 |
+
# Convert C H W -> H W C
|
29 |
+
image = image.permute(1, 2, 0)
|
30 |
+
# Now add batch dimension(B=1): H W C -> 1 H W C
|
31 |
+
# We also specify H, W, C explicitly as model expect them to be 8x8x3
|
32 |
+
image = image.view(1, 8, 8, 3)
|
33 |
+
|
34 |
+
# Now construct batch that model uses
|
35 |
+
# Target and loss are not used in inference, as model code always calculates loss
|
36 |
+
dummy_target = torch.zeros(1, 64, 64, 3, **OPTS)
|
37 |
+
dummy_loss = torch.tensor(-1, **OPTS)
|
38 |
+
inference_batch = ImageBatch(
|
39 |
+
im8=image.to(**OPTS),
|
40 |
+
im64=dummy_target,
|
41 |
+
loss=dummy_loss)
|
42 |
+
result = model(inference_batch)
|
43 |
+
|
44 |
+
# Now convert image to PIL format so we can save it
|
45 |
+
new_image = result.im64.detach().float().cpu()
|
46 |
+
# new_image: 1 H W C -> H W C
|
47 |
+
new_image = new_image[0]
|
48 |
+
# new_image: H W C -> C H W
|
49 |
+
new_image = new_image.permute(2, 0, 1)
|
50 |
+
assert new_image.shape == (3, 64, 64)
|
51 |
+
img = VF.to_pil_image(new_image)
|
52 |
+
# Save
|
53 |
+
print(f"Writing {img.height}x{img.width} image to {output_filename}")
|
54 |
+
img.save(output_filename)
|
image_utils.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from torch import Tensor
|
3 |
+
from pathlib import Path
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
import torchvision.io as VIO
|
7 |
+
import torchvision.transforms.functional as VF
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from tqdm.auto import tqdm
|
10 |
+
|
11 |
+
# https://huggingface.co/datasets/huggan/anime-faces
|
12 |
+
RAW_IMAGES_PATH = Path(
|
13 |
+
"~/Downloads/datasets/anime/anime-faces/images").expanduser()
|
14 |
+
RESOLUTIONS = [64, 8]
|
15 |
+
|
16 |
+
AS_TENSORS_64 = Path(f"data/all_images_64.bin")
|
17 |
+
AS_TENSORS_8 = Path(f"data/all_images_8.bin")
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class ImageBatch:
|
22 |
+
im8: Tensor
|
23 |
+
im64: Tensor
|
24 |
+
loss: Tensor
|
25 |
+
|
26 |
+
@property
|
27 |
+
def n_batch(self):
|
28 |
+
return self.im8.shape[0]
|
29 |
+
|
30 |
+
def as_1d(self):
|
31 |
+
return ImageBatch(
|
32 |
+
im8=self.im8.view(self.n_batch, 8*8, self.im8.shape[-1]),
|
33 |
+
im64=self.im64.view(self.n_batch, 64*64, self.im64.shape[-1]),
|
34 |
+
loss=self.loss
|
35 |
+
)
|
36 |
+
|
37 |
+
def as_2d(self):
|
38 |
+
return ImageBatch(
|
39 |
+
im8=self.im8.view(self.n_batch, 8, 8, self.im8.shape[-1]),
|
40 |
+
im64=self.im64.view(self.n_batch, 64, 64, self.im64.shape[-1]),
|
41 |
+
loss=self.loss
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
class ImageDB:
|
46 |
+
def __init__(self, val_ratio=0.05, dtype=None) -> None:
|
47 |
+
if not AS_TENSORS_64.exists():
|
48 |
+
self.make_tensor_version()
|
49 |
+
print("Load tensors file")
|
50 |
+
self.dtype = dtype or torch.bfloat16
|
51 |
+
self.all_images_64 = torch.load(AS_TENSORS_64).to(self.dtype)
|
52 |
+
self.all_images_8 = torch.load(AS_TENSORS_8).to(self.dtype)
|
53 |
+
self.n_val = int(len(self.all_images_64) * val_ratio)
|
54 |
+
|
55 |
+
def split(self, s: str):
|
56 |
+
if s == "train":
|
57 |
+
return {
|
58 |
+
8: self.all_images_8[:-self.n_val],
|
59 |
+
64: self.all_images_64[:-self.n_val]
|
60 |
+
}
|
61 |
+
if s == "valid":
|
62 |
+
return {
|
63 |
+
8: self.all_images_8[-self.n_val:],
|
64 |
+
64: self.all_images_64[-self.n_val:]
|
65 |
+
}
|
66 |
+
raise ValueError(f"Invalid split {s}")
|
67 |
+
|
68 |
+
@property
|
69 |
+
def train_ds(self):
|
70 |
+
return self.split("train")
|
71 |
+
|
72 |
+
@property
|
73 |
+
def valid_ds(self):
|
74 |
+
return self.split("valid")
|
75 |
+
|
76 |
+
@torch.no_grad()
|
77 |
+
def make_tensor_version(self, path=RAW_IMAGES_PATH):
|
78 |
+
items = list(path.glob("*.png"))
|
79 |
+
all_tensors = [load_single_image(item) for item in tqdm(items)]
|
80 |
+
t64 = torch.stack([t[64] for t in all_tensors])
|
81 |
+
t8 = torch.stack([t[8] for t in all_tensors])
|
82 |
+
torch.save(t64, AS_TENSORS_64)
|
83 |
+
torch.save(t8, AS_TENSORS_8)
|
84 |
+
return {8: t8, 64: t64}
|
85 |
+
|
86 |
+
def random_batch(self, bs: int, split: str = "train"):
|
87 |
+
split_dict = self.split(split)
|
88 |
+
im8 = split_dict[8]
|
89 |
+
im64 = split_dict[64]
|
90 |
+
keys = list(range(len(im8)))
|
91 |
+
random.shuffle(keys)
|
92 |
+
keys = keys[: bs]
|
93 |
+
return ImageBatch(
|
94 |
+
im64=im64[keys].cuda(),
|
95 |
+
im8=im8[keys].cuda(),
|
96 |
+
loss=torch.tensor(-1))
|
97 |
+
|
98 |
+
|
99 |
+
def load_single_image(path: Path):
|
100 |
+
im = VIO.read_image(str(path))
|
101 |
+
im = im / 255.0
|
102 |
+
# resize to 8x8
|
103 |
+
im8 = VF.resize(im, [8, 8], VF.InterpolationMode.NEAREST_EXACT)
|
104 |
+
# C H W -> H W C
|
105 |
+
im = im.permute(1, 2, 0).contiguous()
|
106 |
+
im8 = im8.permute(1, 2, 0).contiguous()
|
107 |
+
|
108 |
+
return {64: im, 8: im8}
|
109 |
+
|
110 |
+
|
111 |
+
class RGBToModel(nn.Module):
|
112 |
+
def __init__(self, d_model, device=None, dtype=None):
|
113 |
+
super().__init__()
|
114 |
+
self.fc = nn.Linear(3, d_model, device=device, dtype=dtype)
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
return self.fc(x)
|
118 |
+
|
119 |
+
|
120 |
+
class ModelToRGB(nn.Module):
|
121 |
+
def __init__(self, d_model, device=None, dtype=None):
|
122 |
+
super().__init__()
|
123 |
+
self.norm = nn.LayerNorm(d_model, device=device, dtype=dtype)
|
124 |
+
self.fc = nn.Linear(d_model, 3, device=device, dtype=dtype)
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
x = self.norm(x)
|
128 |
+
x = self.fc(x)
|
129 |
+
x = x.sigmoid()
|
130 |
+
return x
|
imgen3.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mamba_ssm.modules.mamba_simple import Mamba
|
2 |
+
from mamba_ssm.modules.mamba2_simple import Mamba2Simple
|
3 |
+
from mamba_ssm.modules.mamba2 import Mamba2
|
4 |
+
import torch
|
5 |
+
from torch import Tensor
|
6 |
+
import random
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from tqdm.auto import tqdm
|
10 |
+
from pathlib import Path
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
from typing import Optional
|
13 |
+
|
14 |
+
|
15 |
+
from image_utils import ImageDB, ImageBatch, RGBToModel
|
16 |
+
from image_utils import ModelToRGB
|
17 |
+
from torch_utils import model_numel
|
18 |
+
|
19 |
+
epochs = 10_000
|
20 |
+
bs = 16
|
21 |
+
d_model = 768
|
22 |
+
weights_path = Path(f"data/weights-{d_model}.bin")
|
23 |
+
|
24 |
+
OPTS = {
|
25 |
+
'device': "cuda",
|
26 |
+
'dtype': torch.float32
|
27 |
+
}
|
28 |
+
|
29 |
+
|
30 |
+
class MambaWrap(nn.Module):
|
31 |
+
def __init__(self) -> None:
|
32 |
+
super().__init__()
|
33 |
+
self.mamba = Mamba2Simple(d_model, **OPTS, headdim=64)
|
34 |
+
self.norm = nn.LayerNorm(d_model, **OPTS)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
residual = x
|
38 |
+
x = self.norm(x)
|
39 |
+
x = self.mamba(x)
|
40 |
+
x = residual + x
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class Model(nn.Module):
|
45 |
+
def __init__(self) -> None:
|
46 |
+
super().__init__()
|
47 |
+
self.from_rgb = RGBToModel(d_model, **OPTS)
|
48 |
+
self.to_rgb = ModelToRGB(d_model, **OPTS)
|
49 |
+
self.s0 = nn.Parameter(torch.randn(1, 1, d_model, **OPTS))
|
50 |
+
self.suffix = nn.Parameter(torch.randn(64*64, d_model, **OPTS))
|
51 |
+
self.layers = nn.ModuleList([MambaWrap() for _ in range(4)])
|
52 |
+
self.norm0 = nn.LayerNorm(d_model, **OPTS)
|
53 |
+
|
54 |
+
def forward(self, batch: ImageBatch):
|
55 |
+
B = batch.n_batch
|
56 |
+
batch = batch.as_1d()
|
57 |
+
batch.im8 = self.from_rgb(batch.im8)
|
58 |
+
|
59 |
+
s0 = self.s0.repeat(B, 1, 1)
|
60 |
+
s1 = self.zoom(batch.im8)
|
61 |
+
|
62 |
+
x = torch.cat((s0, batch.im8, s1), 1)
|
63 |
+
x = self.norm0(x)
|
64 |
+
x = self.mamba(x)
|
65 |
+
x = x[:, -64*64:]
|
66 |
+
y_hat = self.to_rgb(x)
|
67 |
+
y_true = batch.im64
|
68 |
+
batch.loss = F.mse_loss(y_hat, y_true)
|
69 |
+
batch.im64 = y_hat
|
70 |
+
return batch.as_2d()
|
71 |
+
|
72 |
+
def zoom(self, im8):
|
73 |
+
im8 = im8.view(im8.shape[0], 8, 8, im8.shape[-1])
|
74 |
+
im8 = repeat(
|
75 |
+
im8, "B H W C -> B (H 8) (W 8) C").view(im8.shape[0], 64*64, im8.shape[-1])
|
76 |
+
im8 = im8 + self.suffix
|
77 |
+
return im8
|
78 |
+
|
79 |
+
def mamba(self, x):
|
80 |
+
for layer in self.layers:
|
81 |
+
x = layer(x)
|
82 |
+
return x
|
83 |
+
|
84 |
+
if __name__ == "__main_":
|
85 |
+
image_db = ImageDB(dtype=OPTS["dtype"])
|
86 |
+
model = Model()
|
87 |
+
if weights_path.exists():
|
88 |
+
print(f"*** Load {weights_path:s}")
|
89 |
+
model.load_state_dict(torch.load(weights_path))
|
90 |
+
opt = torch.optim.AdamW(model.parameters(), fused=True)
|
91 |
+
|
92 |
+
for e in (bar := tqdm(range(epochs))):
|
93 |
+
b = model(image_db.random_batch(bs))
|
94 |
+
b.loss.backward()
|
95 |
+
opt.step()
|
96 |
+
opt.zero_grad()
|
97 |
+
bar.set_description(f'L:{b.loss.item():.4f}')
|
98 |
+
if e and e % 100 == 0:
|
99 |
+
torch.save(model.state_dict(), weights_path)
|
100 |
+
torch.save(model.state_dict(), weights_path)
|
imgen3flip.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mamba_ssm.modules.mamba2_simple import Mamba2Simple
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from tqdm.auto import tqdm
|
6 |
+
from pathlib import Path
|
7 |
+
from einops import repeat
|
8 |
+
|
9 |
+
|
10 |
+
from image_utils import ImageDB, ImageBatch, RGBToModel
|
11 |
+
from image_utils import ModelToRGB
|
12 |
+
|
13 |
+
epochs = 10_000
|
14 |
+
bs = 16
|
15 |
+
# orig;
|
16 |
+
# bs = 16
|
17 |
+
# d_model = 768
|
18 |
+
# headdim = 64
|
19 |
+
# n_layer = 4
|
20 |
+
|
21 |
+
d_model = 1024
|
22 |
+
headdim = 64
|
23 |
+
n_layer = 4
|
24 |
+
|
25 |
+
OPTS = {
|
26 |
+
'device': "cuda",
|
27 |
+
'dtype': torch.bfloat16
|
28 |
+
}
|
29 |
+
# Since we have KISS flip/flop think that number of mamba layers are actually 2 times higher
|
30 |
+
# This is somewhat relatable to LLM model where 1 block had two mamba layers: one replaced ATTN, one replaced MLP
|
31 |
+
|
32 |
+
weights_path = Path(
|
33 |
+
f"data/image-flip-weights-{d_model}x{n_layer}-{str(OPTS['dtype'])}.bin")
|
34 |
+
print(f"Weight path is {str(weights_path)}")
|
35 |
+
|
36 |
+
|
37 |
+
class MambaWrap(nn.Module):
|
38 |
+
def __init__(self) -> None:
|
39 |
+
super().__init__()
|
40 |
+
self.mamba = Mamba2Simple(d_model, **OPTS, headdim=headdim)
|
41 |
+
self.norm = nn.LayerNorm(d_model, **OPTS)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
residual = x
|
45 |
+
x = self.norm(x)
|
46 |
+
x = self.mamba(x)
|
47 |
+
x = residual + x
|
48 |
+
return x
|
49 |
+
|
50 |
+
|
51 |
+
class MambaFlipFlop(nn.Module):
|
52 |
+
def __init__(self, n_values) -> None:
|
53 |
+
super().__init__()
|
54 |
+
self.mb_forward = MambaWrap()
|
55 |
+
self.mb_backward = MambaWrap()
|
56 |
+
self.n_values = n_values
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
x = self.mb_forward(x)
|
60 |
+
x = self.swap_order(x)
|
61 |
+
x = self.mb_backward(x)
|
62 |
+
x = self.swap_order(x)
|
63 |
+
return x
|
64 |
+
|
65 |
+
def swap_order(self, x):
|
66 |
+
T = x.shape[1]
|
67 |
+
head = torch.arange(0, T - self.n_values)
|
68 |
+
tail = torch.arange(T - 1, T - self.n_values - 1, -1)
|
69 |
+
seq = torch.cat((head, tail))
|
70 |
+
x = x[:, seq]
|
71 |
+
return x
|
72 |
+
|
73 |
+
|
74 |
+
class Model(nn.Module):
|
75 |
+
def __init__(self) -> None:
|
76 |
+
super().__init__()
|
77 |
+
self.from_rgb = RGBToModel(d_model, **OPTS)
|
78 |
+
self.to_rgb = ModelToRGB(d_model, **OPTS)
|
79 |
+
self.s0 = nn.Parameter(torch.randn(1, 1, d_model, **OPTS))
|
80 |
+
self.suffix = nn.Parameter(torch.randn(64*64, d_model, **OPTS))
|
81 |
+
self.layers = nn.ModuleList([MambaFlipFlop(64*64)
|
82 |
+
for _ in range(n_layer)])
|
83 |
+
self.norm0 = nn.LayerNorm(d_model, **OPTS)
|
84 |
+
|
85 |
+
def forward(self, batch: ImageBatch):
|
86 |
+
B = batch.n_batch
|
87 |
+
batch = batch.as_1d()
|
88 |
+
batch.im8 = self.from_rgb(batch.im8)
|
89 |
+
|
90 |
+
s0 = self.s0.repeat(B, 1, 1)
|
91 |
+
s1 = self.zoom(batch.im8)
|
92 |
+
|
93 |
+
x = torch.cat((s0, batch.im8, s1), 1)
|
94 |
+
x = self.norm0(x)
|
95 |
+
x = self.mamba(x)
|
96 |
+
x = x[:, -64*64:]
|
97 |
+
y_hat = self.to_rgb(x)
|
98 |
+
y_true = batch.im64
|
99 |
+
batch.loss = F.mse_loss(y_hat, y_true)
|
100 |
+
batch.im64 = y_hat
|
101 |
+
return batch.as_2d()
|
102 |
+
|
103 |
+
def zoom(self, im8):
|
104 |
+
im8 = im8.view(im8.shape[0], 8, 8, im8.shape[-1])
|
105 |
+
im8 = repeat(im8, "B H W C -> B (H 8) (W 8) C")
|
106 |
+
im8 = im8.view(im8.shape[0], 64*64, im8.shape[-1])
|
107 |
+
im8 = im8 + self.suffix
|
108 |
+
return im8
|
109 |
+
|
110 |
+
def mamba(self, x):
|
111 |
+
for layer in self.layers:
|
112 |
+
x = layer(x)
|
113 |
+
return x
|
114 |
+
|
115 |
+
|
116 |
+
if __name__ == "__main__":
|
117 |
+
image_db = ImageDB(dtype=OPTS["dtype"])
|
118 |
+
model = Model()
|
119 |
+
if weights_path.exists():
|
120 |
+
print(f"*** Load {str(weights_path)}")
|
121 |
+
model.load_state_dict(torch.load(weights_path))
|
122 |
+
opt = torch.optim.AdamW(model.parameters(), fused=True)
|
123 |
+
|
124 |
+
for e in (bar := tqdm(range(epochs))):
|
125 |
+
b = model(image_db.random_batch(bs))
|
126 |
+
b.loss.backward()
|
127 |
+
opt.step()
|
128 |
+
opt.zero_grad()
|
129 |
+
bar.set_description(f'L:{b.loss.item():.4f}')
|
130 |
+
if e and e % 100 == 0:
|
131 |
+
torch.save(model.state_dict(), weights_path)
|
132 |
+
torch.save(model.state_dict(), weights_path)
|
imgen3test.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
imgen3test_flip.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
krita-flip.png
ADDED
![]() |
krita-nonflip.png
ADDED
![]() |
krita/face1.png
ADDED
![]() |
krita/face2.png
ADDED
![]() |
torch_utils.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def model_device(m: nn.Module):
|
6 |
+
return next(iter(m.parameters())).device
|
7 |
+
|
8 |
+
|
9 |
+
def model_numel(m: nn.Module, requires_grad=False):
|
10 |
+
if requires_grad:
|
11 |
+
return sum(p.numel() for p in m.parameters() if p.requires_grad)
|
12 |
+
else:
|
13 |
+
return sum(p.numel() for p in m.parameters())
|
valid.png
ADDED
![]() |