Spaces:
Running
Running
hjc-owo
commited on
Commit
·
966ae59
1
Parent(s):
663d321
init repo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +172 -0
- .gitmodules +3 -0
- ImageReward/ImageReward.py +177 -0
- ImageReward/ReFL.py +830 -0
- ImageReward/__init__.py +3 -0
- ImageReward/models/AestheticScore.py +95 -0
- ImageReward/models/BLIP/__init__.py +1 -0
- ImageReward/models/BLIP/blip.py +70 -0
- ImageReward/models/BLIP/blip_pretrain.py +43 -0
- ImageReward/models/BLIP/med.py +947 -0
- ImageReward/models/BLIP/vit.py +301 -0
- ImageReward/models/BLIPScore.py +97 -0
- ImageReward/models/CLIPScore.py +78 -0
- ImageReward/models/__init__.py +4 -0
- ImageReward/utils.py +184 -0
- Install.md +66 -0
- LICENSE +373 -0
- README copy.md +304 -0
- app.py +83 -0
- assets/fonts/Bell-MT.ttf +0 -0
- assets/fonts/DeliusUnicase-Regular.ttf +0 -0
- assets/fonts/HobeauxRococeaux-Sherman.ttf +0 -0
- assets/fonts/IndieFlower-Regular.ttf +0 -0
- assets/fonts/JosefinSans-Light.ttf +0 -0
- assets/fonts/KaushanScript-Regular.ttf +0 -0
- assets/fonts/LuckiestGuy-Regular.ttf +0 -0
- assets/fonts/Noteworthy-Bold.ttf +0 -0
- assets/fonts/Quicksand.ttf +0 -0
- assets/fonts/Saira-Regular.ttf +0 -0
- checkpoint/placeholder.md +1 -0
- conf/config.yaml +56 -0
- conf/x/clipascene.yaml +87 -0
- conf/x/clipasso.yaml +48 -0
- conf/x/clipdraw.yaml +20 -0
- conf/x/clipfont.yaml +27 -0
- conf/x/diffsketcher.yaml +76 -0
- conf/x/diffvg.yaml +18 -0
- conf/x/live.yaml +31 -0
- conf/x/styleclipdraw.yaml +21 -0
- conf/x/stylediffsketcher.yaml +77 -0
- conf/x/svgdreamer.yaml +122 -0
- conf/x/vectorfusion.yaml +85 -0
- conf/x/wordasimage.yaml +46 -0
- data/alphabet1.svg +726 -0
- data/ballerina.png +0 -0
- data/ch1.svg +0 -0
- data/fallingwater.png +0 -0
- data/horse.png +0 -0
- data/simile.png +0 -0
- data/starry.png +0 -0
.gitignore
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
*/.ipynb_checkpoints/*
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
.python-version
|
87 |
+
|
88 |
+
# pipenv
|
89 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
90 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
91 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
92 |
+
# install all needed dependencies.
|
93 |
+
#Pipfile.lock
|
94 |
+
|
95 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
96 |
+
__pypackages__/
|
97 |
+
|
98 |
+
# Celery stuff
|
99 |
+
celerybeat-schedule
|
100 |
+
celerybeat.pid
|
101 |
+
|
102 |
+
# SageMath parsed files
|
103 |
+
*.sage.py
|
104 |
+
|
105 |
+
# Environments
|
106 |
+
.env
|
107 |
+
.venv
|
108 |
+
env/
|
109 |
+
venv/
|
110 |
+
ENV/
|
111 |
+
env.bak/
|
112 |
+
venv.bak/
|
113 |
+
|
114 |
+
# Spyder project settings
|
115 |
+
.spyderproject
|
116 |
+
.spyproject
|
117 |
+
|
118 |
+
# Rope project settings
|
119 |
+
.ropeproject
|
120 |
+
|
121 |
+
# mkdocs documentation
|
122 |
+
/site
|
123 |
+
|
124 |
+
# mypy
|
125 |
+
.mypy_cache/
|
126 |
+
.dmypy.json
|
127 |
+
dmypy.json
|
128 |
+
|
129 |
+
# Pyre type checker
|
130 |
+
.pyre/
|
131 |
+
|
132 |
+
# pytype static type analyzer
|
133 |
+
.pytype/
|
134 |
+
|
135 |
+
# Cython debug symbols
|
136 |
+
cython_debug/
|
137 |
+
|
138 |
+
# .idea
|
139 |
+
.idea/
|
140 |
+
/idea/
|
141 |
+
*.ipr
|
142 |
+
*.iml
|
143 |
+
*.iws
|
144 |
+
|
145 |
+
# macos system
|
146 |
+
.DS_Store
|
147 |
+
|
148 |
+
### project ###
|
149 |
+
|
150 |
+
# /diffvg/
|
151 |
+
# big-lama*
|
152 |
+
|
153 |
+
# pytorch-lighting logs
|
154 |
+
lightning_logs/*
|
155 |
+
|
156 |
+
# Edit settings
|
157 |
+
.editorconfig
|
158 |
+
|
159 |
+
# model checkpoint
|
160 |
+
/checkpoint/u2net/u2net.pth
|
161 |
+
!/checkpoint/placeholder.md
|
162 |
+
|
163 |
+
# ignore local results
|
164 |
+
/workspace/
|
165 |
+
.workspace/
|
166 |
+
|
167 |
+
# ignore files
|
168 |
+
./tmp/
|
169 |
+
./tmp/*
|
170 |
+
/tmp/
|
171 |
+
/tmp_select/
|
172 |
+
/tmp_select/*
|
.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "diffvg"]
|
2 |
+
path = diffvg
|
3 |
+
url = https://github.com/BachiLi/diffvg.git
|
ImageReward/ImageReward.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
@File : ImageReward.py
|
3 |
+
@Time : 2023/01/28 19:53:00
|
4 |
+
@Auther : Jiazheng Xu
|
5 |
+
@Contact : [email protected]
|
6 |
+
@Description: ImageReward Reward model.
|
7 |
+
* Based on CLIP code base and improved-aesthetic-predictor code base
|
8 |
+
* https://github.com/openai/CLIP
|
9 |
+
* https://github.com/christophschuhmann/improved-aesthetic-predictor
|
10 |
+
'''
|
11 |
+
|
12 |
+
import os
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
from PIL import Image
|
16 |
+
from .models.BLIP.blip_pretrain import BLIP_Pretrain
|
17 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
18 |
+
|
19 |
+
try:
|
20 |
+
from torchvision.transforms import InterpolationMode
|
21 |
+
|
22 |
+
BICUBIC = InterpolationMode.BICUBIC
|
23 |
+
except ImportError:
|
24 |
+
BICUBIC = Image.BICUBIC
|
25 |
+
|
26 |
+
|
27 |
+
def _convert_image_to_rgb(image):
|
28 |
+
return image.convert("RGB")
|
29 |
+
|
30 |
+
|
31 |
+
def _transform(n_px):
|
32 |
+
return Compose([
|
33 |
+
Resize(n_px, interpolation=BICUBIC),
|
34 |
+
CenterCrop(n_px),
|
35 |
+
_convert_image_to_rgb,
|
36 |
+
ToTensor(),
|
37 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
38 |
+
])
|
39 |
+
|
40 |
+
|
41 |
+
class MLP(nn.Module):
|
42 |
+
def __init__(self, input_size):
|
43 |
+
super().__init__()
|
44 |
+
self.input_size = input_size
|
45 |
+
|
46 |
+
self.layers = nn.Sequential(
|
47 |
+
nn.Linear(self.input_size, 1024),
|
48 |
+
# nn.ReLU(),
|
49 |
+
nn.Dropout(0.2),
|
50 |
+
nn.Linear(1024, 128),
|
51 |
+
# nn.ReLU(),
|
52 |
+
nn.Dropout(0.2),
|
53 |
+
nn.Linear(128, 64),
|
54 |
+
# nn.ReLU(),
|
55 |
+
nn.Dropout(0.1),
|
56 |
+
nn.Linear(64, 16),
|
57 |
+
# nn.ReLU(),
|
58 |
+
nn.Linear(16, 1)
|
59 |
+
)
|
60 |
+
|
61 |
+
# initial MLP param
|
62 |
+
for name, param in self.layers.named_parameters():
|
63 |
+
if 'weight' in name:
|
64 |
+
nn.init.normal_(param, mean=0.0, std=1.0 / (self.input_size + 1))
|
65 |
+
if 'bias' in name:
|
66 |
+
nn.init.constant_(param, val=0)
|
67 |
+
|
68 |
+
def forward(self, input):
|
69 |
+
return self.layers(input)
|
70 |
+
|
71 |
+
|
72 |
+
class ImageReward(nn.Module):
|
73 |
+
def __init__(self, med_config, device='cpu'):
|
74 |
+
super().__init__()
|
75 |
+
self.device = device
|
76 |
+
|
77 |
+
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
|
78 |
+
self.preprocess = _transform(224)
|
79 |
+
self.mlp = MLP(768)
|
80 |
+
|
81 |
+
self.mean = 0.16717362830052426
|
82 |
+
self.std = 1.0333394966054072
|
83 |
+
|
84 |
+
def score_gard(self, prompt_ids, prompt_attention_mask, image):
|
85 |
+
|
86 |
+
image_embeds = self.blip.visual_encoder(image)
|
87 |
+
# text encode cross attention with image
|
88 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
89 |
+
text_output = self.blip.text_encoder(prompt_ids,
|
90 |
+
attention_mask=prompt_attention_mask,
|
91 |
+
encoder_hidden_states=image_embeds,
|
92 |
+
encoder_attention_mask=image_atts,
|
93 |
+
return_dict=True,
|
94 |
+
)
|
95 |
+
|
96 |
+
txt_features = text_output.last_hidden_state[:, 0, :] # (feature_dim)
|
97 |
+
rewards = self.mlp(txt_features)
|
98 |
+
rewards = (rewards - self.mean) / self.std
|
99 |
+
|
100 |
+
return rewards
|
101 |
+
|
102 |
+
def score(self, prompt, image):
|
103 |
+
|
104 |
+
if (type(image).__name__ == 'list'):
|
105 |
+
_, rewards = self.inference_rank(prompt, image)
|
106 |
+
return rewards
|
107 |
+
|
108 |
+
# text encode
|
109 |
+
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35,
|
110 |
+
return_tensors="pt").to(self.device)
|
111 |
+
|
112 |
+
# image encode
|
113 |
+
if isinstance(image, Image.Image):
|
114 |
+
pil_image = image
|
115 |
+
elif isinstance(image, str):
|
116 |
+
if os.path.isfile(image):
|
117 |
+
pil_image = Image.open(image)
|
118 |
+
else:
|
119 |
+
raise TypeError(
|
120 |
+
r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.')
|
121 |
+
|
122 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
123 |
+
image_embeds = self.blip.visual_encoder(image)
|
124 |
+
|
125 |
+
# text encode cross attention with image
|
126 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
127 |
+
text_output = self.blip.text_encoder(text_input.input_ids,
|
128 |
+
attention_mask=text_input.attention_mask,
|
129 |
+
encoder_hidden_states=image_embeds,
|
130 |
+
encoder_attention_mask=image_atts,
|
131 |
+
return_dict=True,
|
132 |
+
)
|
133 |
+
|
134 |
+
txt_features = text_output.last_hidden_state[:, 0, :].float() # (feature_dim)
|
135 |
+
rewards = self.mlp(txt_features)
|
136 |
+
rewards = (rewards - self.mean) / self.std
|
137 |
+
|
138 |
+
return rewards.detach().cpu().numpy().item()
|
139 |
+
|
140 |
+
def inference_rank(self, prompt, generations_list):
|
141 |
+
|
142 |
+
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35,
|
143 |
+
return_tensors="pt").to(self.device)
|
144 |
+
|
145 |
+
txt_set = []
|
146 |
+
for generation in generations_list:
|
147 |
+
# image encode
|
148 |
+
if isinstance(generation, Image.Image):
|
149 |
+
pil_image = generation
|
150 |
+
elif isinstance(generation, str):
|
151 |
+
if os.path.isfile(generation):
|
152 |
+
pil_image = Image.open(generation)
|
153 |
+
else:
|
154 |
+
raise TypeError(
|
155 |
+
r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.')
|
156 |
+
|
157 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
158 |
+
image_embeds = self.blip.visual_encoder(image)
|
159 |
+
|
160 |
+
# text encode cross attention with image
|
161 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
162 |
+
text_output = self.blip.text_encoder(text_input.input_ids,
|
163 |
+
attention_mask=text_input.attention_mask,
|
164 |
+
encoder_hidden_states=image_embeds,
|
165 |
+
encoder_attention_mask=image_atts,
|
166 |
+
return_dict=True)
|
167 |
+
txt_set.append(text_output.last_hidden_state[:, 0, :])
|
168 |
+
|
169 |
+
txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
|
170 |
+
rewards = self.mlp(txt_features) # [image_num, 1]
|
171 |
+
rewards = (rewards - self.mean) / self.std
|
172 |
+
rewards = torch.squeeze(rewards)
|
173 |
+
_, rank = torch.sort(rewards, dim=0, descending=True)
|
174 |
+
_, indices = torch.sort(rank, dim=0)
|
175 |
+
indices = indices + 1
|
176 |
+
|
177 |
+
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
ImageReward/ReFL.py
ADDED
@@ -0,0 +1,830 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
@File : ReFL.py
|
3 |
+
@Time : 2023/05/01 19:36:00
|
4 |
+
@Auther : Jiazheng Xu
|
5 |
+
@Contact : [email protected]
|
6 |
+
@Description: ReFL Algorithm.
|
7 |
+
* Based on diffusers code base
|
8 |
+
* https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
|
9 |
+
'''
|
10 |
+
|
11 |
+
import argparse
|
12 |
+
import logging
|
13 |
+
import math
|
14 |
+
import os
|
15 |
+
import random
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import accelerate
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import torch.nn.functional as F
|
22 |
+
import torch.utils.checkpoint
|
23 |
+
import transformers
|
24 |
+
from accelerate import Accelerator
|
25 |
+
from accelerate.logging import get_logger
|
26 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
27 |
+
from datasets import load_dataset
|
28 |
+
from huggingface_hub import create_repo, upload_folder
|
29 |
+
from packaging import version
|
30 |
+
from tqdm.auto import tqdm
|
31 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
32 |
+
|
33 |
+
from PIL import Image
|
34 |
+
import ImageReward as RM
|
35 |
+
|
36 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize
|
37 |
+
|
38 |
+
try:
|
39 |
+
from torchvision.transforms import InterpolationMode
|
40 |
+
|
41 |
+
BICUBIC = InterpolationMode.BICUBIC
|
42 |
+
except ImportError:
|
43 |
+
BICUBIC = Image.BICUBIC
|
44 |
+
|
45 |
+
import diffusers
|
46 |
+
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
47 |
+
from diffusers.optimization import get_scheduler
|
48 |
+
from diffusers.training_utils import EMAModel
|
49 |
+
from diffusers.utils import check_min_version, deprecate
|
50 |
+
from diffusers.utils.import_utils import is_xformers_available
|
51 |
+
|
52 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
53 |
+
check_min_version("0.16.0.dev0")
|
54 |
+
|
55 |
+
logger = get_logger(__name__, log_level="INFO")
|
56 |
+
|
57 |
+
DATASET_NAME_MAPPING = {
|
58 |
+
"refl": ("image", "text"),
|
59 |
+
}
|
60 |
+
|
61 |
+
|
62 |
+
def parse_args():
|
63 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
64 |
+
parser.add_argument(
|
65 |
+
"--grad_scale", type=float, default=1e-3, help="Scale divided for grad loss value."
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1."
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--revision",
|
72 |
+
type=str,
|
73 |
+
default=None,
|
74 |
+
required=False,
|
75 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
76 |
+
)
|
77 |
+
parser.add_argument(
|
78 |
+
"--dataset_name",
|
79 |
+
type=str,
|
80 |
+
default=None,
|
81 |
+
help=(
|
82 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
83 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
84 |
+
" or to a folder containing files that 🤗 Datasets can understand."
|
85 |
+
),
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--dataset_config_name",
|
89 |
+
type=str,
|
90 |
+
default=None,
|
91 |
+
help="The config of the Dataset, leave as None if there's only one config.",
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--caption_column",
|
98 |
+
type=str,
|
99 |
+
default="text",
|
100 |
+
help="The column of the dataset containing a caption or a list of captions.",
|
101 |
+
)
|
102 |
+
parser.add_argument(
|
103 |
+
"--max_train_samples",
|
104 |
+
type=int,
|
105 |
+
default=None,
|
106 |
+
help=(
|
107 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
108 |
+
"value if set."
|
109 |
+
),
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"--validation_prompts",
|
113 |
+
type=str,
|
114 |
+
default=None,
|
115 |
+
nargs="+",
|
116 |
+
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
|
117 |
+
)
|
118 |
+
parser.add_argument(
|
119 |
+
"--output_dir",
|
120 |
+
type=str,
|
121 |
+
default="checkpoint/refl",
|
122 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--cache_dir",
|
126 |
+
type=str,
|
127 |
+
default=None,
|
128 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
129 |
+
)
|
130 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
131 |
+
parser.add_argument(
|
132 |
+
"--resolution",
|
133 |
+
type=int,
|
134 |
+
default=512,
|
135 |
+
help=(
|
136 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
137 |
+
" resolution"
|
138 |
+
),
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"--center_crop",
|
142 |
+
default=False,
|
143 |
+
action="store_true",
|
144 |
+
help=(
|
145 |
+
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
146 |
+
" cropped. The images will be resized to the resolution first before cropping."
|
147 |
+
),
|
148 |
+
)
|
149 |
+
parser.add_argument(
|
150 |
+
"--random_flip",
|
151 |
+
action="store_true",
|
152 |
+
help="whether to randomly flip images horizontally",
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--train_batch_size", type=int, default=2, help="Batch size (per device) for the training dataloader."
|
156 |
+
)
|
157 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
158 |
+
parser.add_argument(
|
159 |
+
"--max_train_steps",
|
160 |
+
type=int,
|
161 |
+
default=100,
|
162 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
163 |
+
)
|
164 |
+
parser.add_argument(
|
165 |
+
"--gradient_accumulation_steps",
|
166 |
+
type=int,
|
167 |
+
default=4,
|
168 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
169 |
+
)
|
170 |
+
parser.add_argument(
|
171 |
+
"--gradient_checkpointing",
|
172 |
+
action="store_true",
|
173 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--learning_rate",
|
177 |
+
type=float,
|
178 |
+
default=1e-5,
|
179 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
180 |
+
)
|
181 |
+
parser.add_argument(
|
182 |
+
"--scale_lr",
|
183 |
+
action="store_true",
|
184 |
+
default=False,
|
185 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--lr_scheduler",
|
189 |
+
type=str,
|
190 |
+
default="constant",
|
191 |
+
help=(
|
192 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
193 |
+
' "constant", "constant_with_warmup"]'
|
194 |
+
),
|
195 |
+
)
|
196 |
+
parser.add_argument(
|
197 |
+
"--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
198 |
+
)
|
199 |
+
parser.add_argument(
|
200 |
+
"--snr_gamma",
|
201 |
+
type=float,
|
202 |
+
default=None,
|
203 |
+
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
204 |
+
"More details here: https://arxiv.org/abs/2303.09556.",
|
205 |
+
)
|
206 |
+
parser.add_argument(
|
207 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
208 |
+
)
|
209 |
+
parser.add_argument(
|
210 |
+
"--allow_tf32",
|
211 |
+
action="store_true",
|
212 |
+
help=(
|
213 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
214 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
215 |
+
),
|
216 |
+
)
|
217 |
+
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
218 |
+
parser.add_argument(
|
219 |
+
"--non_ema_revision",
|
220 |
+
type=str,
|
221 |
+
default=None,
|
222 |
+
required=False,
|
223 |
+
help=(
|
224 |
+
"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
|
225 |
+
" remote repository specified with --pretrained_model_name_or_path."
|
226 |
+
),
|
227 |
+
)
|
228 |
+
parser.add_argument(
|
229 |
+
"--dataloader_num_workers",
|
230 |
+
type=int,
|
231 |
+
default=0,
|
232 |
+
help=(
|
233 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
234 |
+
),
|
235 |
+
)
|
236 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
237 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
238 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
239 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
240 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
241 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
242 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
243 |
+
parser.add_argument(
|
244 |
+
"--hub_model_id",
|
245 |
+
type=str,
|
246 |
+
default=None,
|
247 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
248 |
+
)
|
249 |
+
parser.add_argument(
|
250 |
+
"--logging_dir",
|
251 |
+
type=str,
|
252 |
+
default="logs",
|
253 |
+
help=(
|
254 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
255 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
256 |
+
),
|
257 |
+
)
|
258 |
+
parser.add_argument(
|
259 |
+
"--mixed_precision",
|
260 |
+
type=str,
|
261 |
+
default=None,
|
262 |
+
choices=["no", "fp16", "bf16"],
|
263 |
+
help=(
|
264 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
265 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
266 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
267 |
+
),
|
268 |
+
)
|
269 |
+
parser.add_argument(
|
270 |
+
"--report_to",
|
271 |
+
type=str,
|
272 |
+
default="tensorboard",
|
273 |
+
help=(
|
274 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
275 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
276 |
+
),
|
277 |
+
)
|
278 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
279 |
+
parser.add_argument(
|
280 |
+
"--checkpointing_steps",
|
281 |
+
type=int,
|
282 |
+
default=100,
|
283 |
+
help=(
|
284 |
+
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
285 |
+
" training using `--resume_from_checkpoint`."
|
286 |
+
),
|
287 |
+
)
|
288 |
+
parser.add_argument(
|
289 |
+
"--checkpoints_total_limit",
|
290 |
+
type=int,
|
291 |
+
default=None,
|
292 |
+
help=(
|
293 |
+
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
294 |
+
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
295 |
+
" for more docs"
|
296 |
+
),
|
297 |
+
)
|
298 |
+
parser.add_argument(
|
299 |
+
"--resume_from_checkpoint",
|
300 |
+
type=str,
|
301 |
+
default=None,
|
302 |
+
help=(
|
303 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
304 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
305 |
+
),
|
306 |
+
)
|
307 |
+
parser.add_argument(
|
308 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
309 |
+
)
|
310 |
+
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
|
311 |
+
parser.add_argument(
|
312 |
+
"--validation_epochs",
|
313 |
+
type=int,
|
314 |
+
default=5,
|
315 |
+
help="Run validation every X epochs.",
|
316 |
+
)
|
317 |
+
parser.add_argument(
|
318 |
+
"--tracker_project_name",
|
319 |
+
type=str,
|
320 |
+
default="text2image-refl",
|
321 |
+
help=(
|
322 |
+
"The `project_name` argument passed to Accelerator.init_trackers for"
|
323 |
+
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
324 |
+
),
|
325 |
+
)
|
326 |
+
|
327 |
+
args = parser.parse_args()
|
328 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
329 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
330 |
+
args.local_rank = env_local_rank
|
331 |
+
|
332 |
+
# default to using the same revision for the non-ema model if not specified
|
333 |
+
if args.non_ema_revision is None:
|
334 |
+
args.non_ema_revision = args.revision
|
335 |
+
|
336 |
+
return args
|
337 |
+
|
338 |
+
|
339 |
+
class Trainer(object):
|
340 |
+
|
341 |
+
def __init__(self, pretrained_model_name_or_path, train_data_dir, args):
|
342 |
+
|
343 |
+
self.pretrained_model_name_or_path = pretrained_model_name_or_path
|
344 |
+
self.train_data_dir = train_data_dir
|
345 |
+
|
346 |
+
# Sanity checks
|
347 |
+
if args.dataset_name is None and self.train_data_dir is None:
|
348 |
+
raise ValueError("Need either a dataset name or a training folder.")
|
349 |
+
|
350 |
+
if args.non_ema_revision is not None:
|
351 |
+
deprecate(
|
352 |
+
"non_ema_revision!=None",
|
353 |
+
"0.15.0",
|
354 |
+
message=(
|
355 |
+
"Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
|
356 |
+
" use `--variant=non_ema` instead."
|
357 |
+
),
|
358 |
+
)
|
359 |
+
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
360 |
+
|
361 |
+
accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
|
362 |
+
|
363 |
+
self.accelerator = Accelerator(
|
364 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
365 |
+
mixed_precision=args.mixed_precision,
|
366 |
+
log_with=args.report_to,
|
367 |
+
logging_dir=logging_dir,
|
368 |
+
project_config=accelerator_project_config,
|
369 |
+
)
|
370 |
+
|
371 |
+
# Make one log on every process with the configuration for debugging.
|
372 |
+
logging.basicConfig(
|
373 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
374 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
375 |
+
level=logging.INFO,
|
376 |
+
)
|
377 |
+
logger.info(self.accelerator.state, main_process_only=False)
|
378 |
+
if self.accelerator.is_local_main_process:
|
379 |
+
transformers.utils.logging.set_verbosity_warning()
|
380 |
+
diffusers.utils.logging.set_verbosity_info()
|
381 |
+
else:
|
382 |
+
transformers.utils.logging.set_verbosity_error()
|
383 |
+
diffusers.utils.logging.set_verbosity_error()
|
384 |
+
|
385 |
+
# If passed along, set the training seed now.
|
386 |
+
if args.seed is not None:
|
387 |
+
set_seed(args.seed)
|
388 |
+
|
389 |
+
# Handle the repository creation
|
390 |
+
if self.accelerator.is_main_process:
|
391 |
+
if args.output_dir is not None:
|
392 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
393 |
+
|
394 |
+
if args.push_to_hub:
|
395 |
+
self.repo_id = create_repo(
|
396 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
397 |
+
).repo_id
|
398 |
+
|
399 |
+
# Load scheduler, tokenizer and models.
|
400 |
+
self.noise_scheduler = DDPMScheduler.from_pretrained(self.pretrained_model_name_or_path, subfolder="scheduler")
|
401 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
402 |
+
self.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
403 |
+
)
|
404 |
+
self.text_encoder = CLIPTextModel.from_pretrained(
|
405 |
+
self.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
406 |
+
)
|
407 |
+
self.vae = AutoencoderKL.from_pretrained(self.pretrained_model_name_or_path, subfolder="vae",
|
408 |
+
revision=args.revision)
|
409 |
+
self.unet = UNet2DConditionModel.from_pretrained(
|
410 |
+
self.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
|
411 |
+
)
|
412 |
+
self.reward_model = RM.load("ImageReward-v1.0", device=self.accelerator.device)
|
413 |
+
|
414 |
+
# Freeze vae and text_encoder
|
415 |
+
self.vae.requires_grad_(False)
|
416 |
+
self.text_encoder.requires_grad_(False)
|
417 |
+
self.reward_model.requires_grad_(False)
|
418 |
+
|
419 |
+
# Create EMA for the unet.
|
420 |
+
if args.use_ema:
|
421 |
+
self.ema_unet = UNet2DConditionModel.from_pretrained(
|
422 |
+
self.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
423 |
+
)
|
424 |
+
self.ema_unet = EMAModel(self.ema_unet.parameters(), model_cls=UNet2DConditionModel,
|
425 |
+
model_config=self.ema_unet.config)
|
426 |
+
|
427 |
+
if args.enable_xformers_memory_efficient_attention:
|
428 |
+
if is_xformers_available():
|
429 |
+
import xformers
|
430 |
+
|
431 |
+
xformers_version = version.parse(xformers.__version__)
|
432 |
+
if xformers_version == version.parse("0.0.16"):
|
433 |
+
logger.warn(
|
434 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
435 |
+
)
|
436 |
+
self.unet.enable_xformers_memory_efficient_attention()
|
437 |
+
else:
|
438 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
439 |
+
|
440 |
+
# `accelerate` 0.16.0 will have better support for customized saving
|
441 |
+
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
442 |
+
# create custom saving & loading hooks so that `self.accelerator.save_state(...)` serializes in a nice format
|
443 |
+
def save_model_hook(models, weights, output_dir):
|
444 |
+
if args.use_ema:
|
445 |
+
self.ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
446 |
+
|
447 |
+
for i, model in enumerate(models):
|
448 |
+
model.save_pretrained(os.path.join(output_dir, "unet"))
|
449 |
+
|
450 |
+
# make sure to pop weight so that corresponding model is not saved again
|
451 |
+
weights.pop()
|
452 |
+
|
453 |
+
def load_model_hook(models, input_dir):
|
454 |
+
if args.use_ema:
|
455 |
+
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
|
456 |
+
self.ema_unet.load_state_dict(load_model.state_dict())
|
457 |
+
self.ema_unet.to(self.accelerator.device)
|
458 |
+
del load_model
|
459 |
+
|
460 |
+
for i in range(len(models)):
|
461 |
+
# pop models so that they are not loaded again
|
462 |
+
model = models.pop()
|
463 |
+
|
464 |
+
# load diffusers style into model
|
465 |
+
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
|
466 |
+
model.register_to_config(**load_model.config)
|
467 |
+
|
468 |
+
model.load_state_dict(load_model.state_dict())
|
469 |
+
del load_model
|
470 |
+
|
471 |
+
self.accelerator.register_save_state_pre_hook(save_model_hook)
|
472 |
+
self.accelerator.register_load_state_pre_hook(load_model_hook)
|
473 |
+
|
474 |
+
if args.gradient_checkpointing:
|
475 |
+
self.unet.enable_gradient_checkpointing()
|
476 |
+
|
477 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
478 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
479 |
+
if args.allow_tf32:
|
480 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
481 |
+
|
482 |
+
if args.scale_lr:
|
483 |
+
args.learning_rate = (
|
484 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * self.accelerator.num_processes
|
485 |
+
)
|
486 |
+
|
487 |
+
# Initialize the optimizer
|
488 |
+
if args.use_8bit_adam:
|
489 |
+
try:
|
490 |
+
import bitsandbytes as bnb
|
491 |
+
except ImportError:
|
492 |
+
raise ImportError(
|
493 |
+
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
494 |
+
)
|
495 |
+
|
496 |
+
optimizer_cls = bnb.optim.AdamW8bit
|
497 |
+
else:
|
498 |
+
optimizer_cls = torch.optim.AdamW
|
499 |
+
|
500 |
+
self.optimizer = optimizer_cls(
|
501 |
+
self.unet.parameters(),
|
502 |
+
lr=args.learning_rate,
|
503 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
504 |
+
weight_decay=args.adam_weight_decay,
|
505 |
+
eps=args.adam_epsilon,
|
506 |
+
)
|
507 |
+
|
508 |
+
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
509 |
+
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
510 |
+
|
511 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
512 |
+
# download the dataset.
|
513 |
+
if args.dataset_name is not None:
|
514 |
+
# Downloading and loading a dataset from the hub.
|
515 |
+
dataset = load_dataset(
|
516 |
+
args.dataset_name,
|
517 |
+
args.dataset_config_name,
|
518 |
+
cache_dir=args.cache_dir,
|
519 |
+
)
|
520 |
+
else:
|
521 |
+
data_files = {}
|
522 |
+
data_files["train"] = self.train_data_dir
|
523 |
+
dataset = load_dataset(
|
524 |
+
"json",
|
525 |
+
data_files=data_files,
|
526 |
+
cache_dir=args.cache_dir,
|
527 |
+
)
|
528 |
+
# See more about loading custom images at
|
529 |
+
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
530 |
+
|
531 |
+
# Preprocessing the datasets.
|
532 |
+
# We need to tokenize inputs and targets.
|
533 |
+
column_names = dataset["train"].column_names
|
534 |
+
|
535 |
+
# Get the column names for input/target.
|
536 |
+
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
537 |
+
if args.image_column is None:
|
538 |
+
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
539 |
+
else:
|
540 |
+
image_column = args.image_column
|
541 |
+
if image_column not in column_names:
|
542 |
+
raise ValueError(
|
543 |
+
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
|
544 |
+
)
|
545 |
+
if args.caption_column is None:
|
546 |
+
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
547 |
+
else:
|
548 |
+
caption_column = args.caption_column
|
549 |
+
if caption_column not in column_names:
|
550 |
+
raise ValueError(
|
551 |
+
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
|
552 |
+
)
|
553 |
+
|
554 |
+
# Preprocessing the datasets.
|
555 |
+
# We need to tokenize input captions and transform the images.
|
556 |
+
def tokenize_captions(examples, is_train=True):
|
557 |
+
captions = []
|
558 |
+
for caption in examples[caption_column]:
|
559 |
+
if isinstance(caption, str):
|
560 |
+
captions.append(caption)
|
561 |
+
elif isinstance(caption, (list, np.ndarray)):
|
562 |
+
# take a random caption if there are multiple
|
563 |
+
captions.append(random.choice(caption) if is_train else caption[0])
|
564 |
+
else:
|
565 |
+
raise ValueError(
|
566 |
+
f"Caption column `{caption_column}` should contain either strings or lists of strings."
|
567 |
+
)
|
568 |
+
inputs = tokenizer(
|
569 |
+
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
|
570 |
+
return_tensors="pt"
|
571 |
+
)
|
572 |
+
return inputs.input_ids
|
573 |
+
|
574 |
+
def preprocess_train(examples):
|
575 |
+
examples["input_ids"] = tokenize_captions(examples)
|
576 |
+
examples["rm_input_ids"] = self.reward_model.blip.tokenizer(examples[caption_column], padding='max_length',
|
577 |
+
truncation=True, max_length=35,
|
578 |
+
return_tensors="pt").input_ids
|
579 |
+
examples["rm_attention_mask"] = self.reward_model.blip.tokenizer(examples[caption_column],
|
580 |
+
padding='max_length', truncation=True,
|
581 |
+
max_length=35,
|
582 |
+
return_tensors="pt").attention_mask
|
583 |
+
return examples
|
584 |
+
|
585 |
+
with self.accelerator.main_process_first():
|
586 |
+
if args.max_train_samples is not None:
|
587 |
+
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
588 |
+
# Set the training transforms
|
589 |
+
self.train_dataset = dataset["train"].with_transform(preprocess_train)
|
590 |
+
|
591 |
+
def collate_fn(examples):
|
592 |
+
input_ids = torch.stack([example["input_ids"] for example in examples])
|
593 |
+
rm_input_ids = torch.stack([example["rm_input_ids"] for example in examples])
|
594 |
+
rm_attention_mask = torch.stack([example["rm_attention_mask"] for example in examples])
|
595 |
+
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
596 |
+
rm_input_ids = rm_input_ids.view(-1, rm_input_ids.shape[-1])
|
597 |
+
rm_attention_mask = rm_attention_mask.view(-1, rm_attention_mask.shape[-1])
|
598 |
+
return {"input_ids": input_ids, "rm_input_ids": rm_input_ids, "rm_attention_mask": rm_attention_mask}
|
599 |
+
|
600 |
+
# DataLoaders creation:
|
601 |
+
self.train_dataloader = torch.utils.data.DataLoader(
|
602 |
+
self.train_dataset,
|
603 |
+
shuffle=True,
|
604 |
+
collate_fn=collate_fn,
|
605 |
+
batch_size=args.train_batch_size,
|
606 |
+
num_workers=args.dataloader_num_workers,
|
607 |
+
)
|
608 |
+
|
609 |
+
# Scheduler and math around the number of training steps.
|
610 |
+
overrode_max_train_steps = False
|
611 |
+
self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps)
|
612 |
+
if args.max_train_steps is None:
|
613 |
+
args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch
|
614 |
+
overrode_max_train_steps = True
|
615 |
+
|
616 |
+
self.lr_scheduler = get_scheduler(
|
617 |
+
args.lr_scheduler,
|
618 |
+
optimizer=self.optimizer,
|
619 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
620 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
621 |
+
)
|
622 |
+
|
623 |
+
# Prepare everything with our `self.accelerator`.
|
624 |
+
self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
|
625 |
+
self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler
|
626 |
+
)
|
627 |
+
|
628 |
+
if args.use_ema:
|
629 |
+
self.ema_unet.to(self.accelerator.device)
|
630 |
+
|
631 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
632 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
633 |
+
self.weight_dtype = torch.float32
|
634 |
+
if self.accelerator.mixed_precision == "fp16":
|
635 |
+
self.weight_dtype = torch.float16
|
636 |
+
elif self.accelerator.mixed_precision == "bf16":
|
637 |
+
self.weight_dtype = torch.bfloat16
|
638 |
+
|
639 |
+
# Move text_encode and vae to gpu and cast to self.weight_dtype
|
640 |
+
self.text_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
|
641 |
+
self.vae.to(self.accelerator.device, dtype=self.weight_dtype)
|
642 |
+
self.reward_model.to(self.accelerator.device, dtype=self.weight_dtype)
|
643 |
+
|
644 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
645 |
+
self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps)
|
646 |
+
if overrode_max_train_steps:
|
647 |
+
args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch
|
648 |
+
# Afterwards we recalculate our number of training epochs
|
649 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / self.num_update_steps_per_epoch)
|
650 |
+
|
651 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
652 |
+
# The trackers initializes automatically on the main process.
|
653 |
+
if self.accelerator.is_main_process:
|
654 |
+
tracker_config = dict(vars(args))
|
655 |
+
tracker_config.pop("validation_prompts")
|
656 |
+
self.accelerator.init_trackers(args.tracker_project_name, tracker_config)
|
657 |
+
|
658 |
+
def train(self, args):
|
659 |
+
|
660 |
+
# Train!
|
661 |
+
total_batch_size = args.train_batch_size * self.accelerator.num_processes * args.gradient_accumulation_steps
|
662 |
+
|
663 |
+
logger.info("***** Running training *****")
|
664 |
+
logger.info(f" Num examples = {len(self.train_dataset)}")
|
665 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
666 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
667 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
668 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
669 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
670 |
+
global_step = 0
|
671 |
+
first_epoch = 0
|
672 |
+
|
673 |
+
# Potentially load in the weights and states from a previous save
|
674 |
+
if args.resume_from_checkpoint:
|
675 |
+
if args.resume_from_checkpoint != "latest":
|
676 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
677 |
+
else:
|
678 |
+
# Get the most recent checkpoint
|
679 |
+
dirs = os.listdir(args.output_dir)
|
680 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
681 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
682 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
683 |
+
|
684 |
+
if path is None:
|
685 |
+
self.accelerator.print(
|
686 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
687 |
+
)
|
688 |
+
args.resume_from_checkpoint = None
|
689 |
+
else:
|
690 |
+
self.accelerator.print(f"Resuming from checkpoint {path}")
|
691 |
+
self.accelerator.load_state(os.path.join(args.output_dir, path))
|
692 |
+
global_step = int(path.split("-")[1])
|
693 |
+
|
694 |
+
resume_global_step = global_step * args.gradient_accumulation_steps
|
695 |
+
first_epoch = global_step // self.num_update_steps_per_epoch
|
696 |
+
resume_step = resume_global_step % (self.num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
697 |
+
|
698 |
+
# Only show the progress bar once on each machine.
|
699 |
+
progress_bar = tqdm(range(global_step, args.max_train_steps),
|
700 |
+
disable=not self.accelerator.is_local_main_process)
|
701 |
+
progress_bar.set_description("Steps")
|
702 |
+
|
703 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
704 |
+
self.unet.train()
|
705 |
+
train_loss = 0.0
|
706 |
+
for step, batch in enumerate(self.train_dataloader):
|
707 |
+
# Skip steps until we reach the resumed step
|
708 |
+
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
709 |
+
if step % args.gradient_accumulation_steps == 0:
|
710 |
+
progress_bar.update(1)
|
711 |
+
continue
|
712 |
+
|
713 |
+
with self.accelerator.accumulate(self.unet):
|
714 |
+
encoder_hidden_states = self.text_encoder(batch["input_ids"])[0]
|
715 |
+
latents = torch.randn((args.train_batch_size, 4, 64, 64), device=self.accelerator.device)
|
716 |
+
|
717 |
+
self.noise_scheduler.set_timesteps(40, device=self.accelerator.device)
|
718 |
+
timesteps = self.noise_scheduler.timesteps
|
719 |
+
|
720 |
+
mid_timestep = random.randint(30, 39)
|
721 |
+
|
722 |
+
for i, t in enumerate(timesteps[:mid_timestep]):
|
723 |
+
with torch.no_grad():
|
724 |
+
latent_model_input = latents
|
725 |
+
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
|
726 |
+
noise_pred = self.unet(
|
727 |
+
latent_model_input,
|
728 |
+
t,
|
729 |
+
encoder_hidden_states=encoder_hidden_states,
|
730 |
+
).sample
|
731 |
+
latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
|
732 |
+
|
733 |
+
latent_model_input = latents
|
734 |
+
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input,
|
735 |
+
timesteps[mid_timestep])
|
736 |
+
noise_pred = self.unet(
|
737 |
+
latent_model_input,
|
738 |
+
timesteps[mid_timestep],
|
739 |
+
encoder_hidden_states=encoder_hidden_states,
|
740 |
+
).sample
|
741 |
+
pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[mid_timestep],
|
742 |
+
latents).pred_original_sample.to(self.weight_dtype)
|
743 |
+
|
744 |
+
pred_original_sample = 1 / self.vae.config.scaling_factor * pred_original_sample
|
745 |
+
image = self.vae.decode(pred_original_sample.to(self.weight_dtype)).sample
|
746 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
747 |
+
|
748 |
+
# image encode
|
749 |
+
def _transform():
|
750 |
+
return Compose([
|
751 |
+
Resize(224, interpolation=BICUBIC),
|
752 |
+
CenterCrop(224),
|
753 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
754 |
+
])
|
755 |
+
|
756 |
+
rm_preprocess = _transform()
|
757 |
+
image = rm_preprocess(image).to(self.accelerator.device)
|
758 |
+
|
759 |
+
rewards = self.reward_model.score_gard(batch["rm_input_ids"], batch["rm_attention_mask"], image)
|
760 |
+
loss = F.relu(-rewards + 2)
|
761 |
+
loss = loss.mean() * args.grad_scale
|
762 |
+
|
763 |
+
# Gather the losses across all processes for logging (if we use distributed training).
|
764 |
+
avg_loss = self.accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
765 |
+
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
766 |
+
|
767 |
+
# Backpropagate
|
768 |
+
self.accelerator.backward(loss)
|
769 |
+
if self.accelerator.sync_gradients:
|
770 |
+
self.accelerator.clip_grad_norm_(self.unet.parameters(), args.max_grad_norm)
|
771 |
+
self.optimizer.step()
|
772 |
+
self.lr_scheduler.step()
|
773 |
+
self.optimizer.zero_grad()
|
774 |
+
|
775 |
+
# Checks if the self.accelerator has performed an optimization step behind the scenes
|
776 |
+
if self.accelerator.sync_gradients:
|
777 |
+
if args.use_ema:
|
778 |
+
self.ema_unet.step(self.unet.parameters())
|
779 |
+
progress_bar.update(1)
|
780 |
+
global_step += 1
|
781 |
+
self.accelerator.log({"train_loss": train_loss}, step=global_step)
|
782 |
+
train_loss = 0.0
|
783 |
+
|
784 |
+
if global_step % args.checkpointing_steps == 0:
|
785 |
+
if self.accelerator.is_main_process:
|
786 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
787 |
+
self.accelerator.save_state(save_path)
|
788 |
+
logger.info(f"Saved state to {save_path}")
|
789 |
+
|
790 |
+
logs = {"step_loss": loss.detach().item(), "lr": self.lr_scheduler.get_last_lr()[0]}
|
791 |
+
progress_bar.set_postfix(**logs)
|
792 |
+
|
793 |
+
if global_step >= args.max_train_steps:
|
794 |
+
break
|
795 |
+
|
796 |
+
if self.accelerator.is_main_process:
|
797 |
+
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
|
798 |
+
if args.use_ema:
|
799 |
+
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
|
800 |
+
self.ema_unet.store(self.unet.parameters())
|
801 |
+
self.ema_unet.copy_to(self.unet.parameters())
|
802 |
+
if args.use_ema:
|
803 |
+
# Switch back to the original UNet parameters.
|
804 |
+
self.ema_unet.restore(self.unet.parameters())
|
805 |
+
|
806 |
+
# Create the pipeline using the trained modules and save it.
|
807 |
+
self.accelerator.wait_for_everyone()
|
808 |
+
if self.accelerator.is_main_process:
|
809 |
+
self.unet = self.accelerator.unwrap_model(self.unet)
|
810 |
+
if args.use_ema:
|
811 |
+
self.ema_unet.copy_to(self.unet.parameters())
|
812 |
+
|
813 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
814 |
+
self.pretrained_model_name_or_path,
|
815 |
+
text_encoder=self.text_encoder,
|
816 |
+
vae=self.vae,
|
817 |
+
unet=self.unet,
|
818 |
+
revision=args.revision,
|
819 |
+
)
|
820 |
+
pipeline.save_pretrained(args.output_dir)
|
821 |
+
|
822 |
+
if args.push_to_hub:
|
823 |
+
upload_folder(
|
824 |
+
repo_id=self.repo_id,
|
825 |
+
folder_path=args.output_dir,
|
826 |
+
commit_message="End of training",
|
827 |
+
ignore_patterns=["step_*", "epoch_*"],
|
828 |
+
)
|
829 |
+
|
830 |
+
self.accelerator.end_training()
|
ImageReward/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .utils import *
|
2 |
+
from .models import *
|
3 |
+
from .ReFL import *
|
ImageReward/models/AestheticScore.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
@File : AestheticScore.py
|
3 |
+
@Time : 2023/02/12 14:54:00
|
4 |
+
@Auther : Jiazheng Xu
|
5 |
+
@Contact : [email protected]
|
6 |
+
@Description: AestheticScore.
|
7 |
+
* Based on improved-aesthetic-predictor code base
|
8 |
+
* https://github.com/christophschuhmann/improved-aesthetic-predictor
|
9 |
+
'''
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from PIL import Image
|
15 |
+
import clip
|
16 |
+
|
17 |
+
|
18 |
+
# if you changed the MLP architecture during training, change it also here:
|
19 |
+
class MLP(nn.Module):
|
20 |
+
def __init__(self, input_size):
|
21 |
+
super().__init__()
|
22 |
+
self.input_size = input_size
|
23 |
+
self.layers = nn.Sequential(
|
24 |
+
nn.Linear(self.input_size, 1024),
|
25 |
+
# nn.ReLU(),
|
26 |
+
nn.Dropout(0.2),
|
27 |
+
nn.Linear(1024, 128),
|
28 |
+
# nn.ReLU(),
|
29 |
+
nn.Dropout(0.2),
|
30 |
+
nn.Linear(128, 64),
|
31 |
+
# nn.ReLU(),
|
32 |
+
nn.Dropout(0.1),
|
33 |
+
|
34 |
+
nn.Linear(64, 16),
|
35 |
+
# nn.ReLU(),
|
36 |
+
|
37 |
+
nn.Linear(16, 1)
|
38 |
+
)
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
return self.layers(x)
|
42 |
+
|
43 |
+
|
44 |
+
class AestheticScore(nn.Module):
|
45 |
+
def __init__(self, download_root, device='cpu'):
|
46 |
+
super().__init__()
|
47 |
+
self.device = device
|
48 |
+
self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False,
|
49 |
+
download_root=download_root)
|
50 |
+
self.mlp = MLP(768)
|
51 |
+
|
52 |
+
if device == "cpu":
|
53 |
+
self.clip_model.float()
|
54 |
+
else:
|
55 |
+
clip.model.convert_weights(
|
56 |
+
self.clip_model) # Actually this line is unnecessary since clip by default already on float16
|
57 |
+
|
58 |
+
# have clip.logit_scale require no grad.
|
59 |
+
self.clip_model.logit_scale.requires_grad_(False)
|
60 |
+
|
61 |
+
def score(self, prompt, image_path):
|
62 |
+
|
63 |
+
if (type(image_path).__name__ == 'list'):
|
64 |
+
_, rewards = self.inference_rank(prompt, image_path)
|
65 |
+
return rewards
|
66 |
+
|
67 |
+
# image encode
|
68 |
+
pil_image = Image.open(image_path)
|
69 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
70 |
+
image_features = F.normalize(self.clip_model.encode_image(image)).float()
|
71 |
+
|
72 |
+
# score
|
73 |
+
rewards = self.mlp(image_features)
|
74 |
+
|
75 |
+
return rewards.detach().cpu().numpy().item()
|
76 |
+
|
77 |
+
def inference_rank(self, prompt, generations_list):
|
78 |
+
|
79 |
+
img_set = []
|
80 |
+
for generations in generations_list:
|
81 |
+
# image encode
|
82 |
+
img_path = generations
|
83 |
+
pil_image = Image.open(img_path)
|
84 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
85 |
+
image_features = F.normalize(self.clip_model.encode_image(image))
|
86 |
+
img_set.append(image_features)
|
87 |
+
|
88 |
+
img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
|
89 |
+
rewards = self.mlp(img_features)
|
90 |
+
rewards = torch.squeeze(rewards)
|
91 |
+
_, rank = torch.sort(rewards, dim=0, descending=True)
|
92 |
+
_, indices = torch.sort(rank, dim=0)
|
93 |
+
indices = indices + 1
|
94 |
+
|
95 |
+
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
ImageReward/models/BLIP/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .blip_pretrain import *
|
ImageReward/models/BLIP/blip.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
3 |
+
'''
|
4 |
+
|
5 |
+
import warnings
|
6 |
+
warnings.filterwarnings("ignore")
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import os
|
10 |
+
from urllib.parse import urlparse
|
11 |
+
from timm.models.hub import download_cached_file
|
12 |
+
from transformers import BertTokenizer
|
13 |
+
from .vit import VisionTransformer, interpolate_pos_embed
|
14 |
+
|
15 |
+
|
16 |
+
def init_tokenizer():
|
17 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
18 |
+
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
19 |
+
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
20 |
+
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
21 |
+
return tokenizer
|
22 |
+
|
23 |
+
|
24 |
+
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
25 |
+
|
26 |
+
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
27 |
+
if vit=='base':
|
28 |
+
vision_width = 768
|
29 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
30 |
+
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
31 |
+
drop_path_rate=0 or drop_path_rate
|
32 |
+
)
|
33 |
+
elif vit=='large':
|
34 |
+
vision_width = 1024
|
35 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
36 |
+
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
37 |
+
drop_path_rate=0.1 or drop_path_rate
|
38 |
+
)
|
39 |
+
return visual_encoder, vision_width
|
40 |
+
|
41 |
+
|
42 |
+
def is_url(url_or_filename):
|
43 |
+
parsed = urlparse(url_or_filename)
|
44 |
+
return parsed.scheme in ("http", "https")
|
45 |
+
|
46 |
+
def load_checkpoint(model,url_or_filename):
|
47 |
+
if is_url(url_or_filename):
|
48 |
+
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
49 |
+
checkpoint = torch.load(cached_file, map_location='cpu')
|
50 |
+
elif os.path.isfile(url_or_filename):
|
51 |
+
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
52 |
+
else:
|
53 |
+
raise RuntimeError('checkpoint url or path is invalid')
|
54 |
+
|
55 |
+
state_dict = checkpoint['model']
|
56 |
+
|
57 |
+
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
58 |
+
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
59 |
+
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
60 |
+
model.visual_encoder_m)
|
61 |
+
for key in model.state_dict().keys():
|
62 |
+
if key in state_dict.keys():
|
63 |
+
if state_dict[key].shape!=model.state_dict()[key].shape:
|
64 |
+
print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
|
65 |
+
del state_dict[key]
|
66 |
+
|
67 |
+
msg = model.load_state_dict(state_dict,strict=False)
|
68 |
+
print('load checkpoint from %s'%url_or_filename)
|
69 |
+
return model,msg
|
70 |
+
|
ImageReward/models/BLIP/blip_pretrain.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
3 |
+
'''
|
4 |
+
|
5 |
+
import transformers
|
6 |
+
transformers.logging.set_verbosity_error()
|
7 |
+
|
8 |
+
from torch import nn
|
9 |
+
import os
|
10 |
+
from .med import BertConfig, BertModel
|
11 |
+
from .blip import create_vit, init_tokenizer
|
12 |
+
|
13 |
+
class BLIP_Pretrain(nn.Module):
|
14 |
+
def __init__(self,
|
15 |
+
med_config = "med_config.json",
|
16 |
+
image_size = 224,
|
17 |
+
vit = 'base',
|
18 |
+
vit_grad_ckpt = False,
|
19 |
+
vit_ckpt_layer = 0,
|
20 |
+
embed_dim = 256,
|
21 |
+
queue_size = 57600,
|
22 |
+
momentum = 0.995,
|
23 |
+
):
|
24 |
+
"""
|
25 |
+
Args:
|
26 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
27 |
+
image_size (int): input image size
|
28 |
+
vit (str): model size of vision transformer
|
29 |
+
"""
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
|
33 |
+
|
34 |
+
self.tokenizer = init_tokenizer()
|
35 |
+
encoder_config = BertConfig.from_json_file(med_config)
|
36 |
+
encoder_config.encoder_width = vision_width
|
37 |
+
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
|
38 |
+
|
39 |
+
text_width = self.text_encoder.config.hidden_size
|
40 |
+
|
41 |
+
self.vision_proj = nn.Linear(vision_width, embed_dim)
|
42 |
+
self.text_proj = nn.Linear(text_width, embed_dim)
|
43 |
+
|
ImageReward/models/BLIP/med.py
ADDED
@@ -0,0 +1,947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
3 |
+
* Based on huggingface code base
|
4 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
5 |
+
'''
|
6 |
+
|
7 |
+
import math
|
8 |
+
from typing import Tuple
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import Tensor, device, nn
|
12 |
+
import torch.utils.checkpoint
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn import CrossEntropyLoss
|
15 |
+
|
16 |
+
from transformers.activations import ACT2FN
|
17 |
+
from transformers.file_utils import (
|
18 |
+
ModelOutput,
|
19 |
+
)
|
20 |
+
from transformers.modeling_outputs import (
|
21 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
22 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
23 |
+
CausalLMOutputWithCrossAttentions,
|
24 |
+
MaskedLMOutput,
|
25 |
+
MultipleChoiceModelOutput,
|
26 |
+
NextSentencePredictorOutput,
|
27 |
+
QuestionAnsweringModelOutput,
|
28 |
+
SequenceClassifierOutput,
|
29 |
+
TokenClassifierOutput,
|
30 |
+
)
|
31 |
+
from transformers.modeling_utils import (
|
32 |
+
PreTrainedModel,
|
33 |
+
apply_chunking_to_forward,
|
34 |
+
find_pruneable_heads_and_indices,
|
35 |
+
prune_linear_layer,
|
36 |
+
)
|
37 |
+
from transformers.utils import logging
|
38 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
39 |
+
|
40 |
+
|
41 |
+
logger = logging.get_logger(__name__)
|
42 |
+
|
43 |
+
|
44 |
+
class BertEmbeddings(nn.Module):
|
45 |
+
"""Construct the embeddings from word and position embeddings."""
|
46 |
+
|
47 |
+
def __init__(self, config):
|
48 |
+
super().__init__()
|
49 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
50 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
51 |
+
|
52 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
53 |
+
# any TensorFlow checkpoint file
|
54 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
55 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
56 |
+
|
57 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
58 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
59 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
60 |
+
|
61 |
+
self.config = config
|
62 |
+
|
63 |
+
def forward(
|
64 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
65 |
+
):
|
66 |
+
if input_ids is not None:
|
67 |
+
input_shape = input_ids.size()
|
68 |
+
else:
|
69 |
+
input_shape = inputs_embeds.size()[:-1]
|
70 |
+
|
71 |
+
seq_length = input_shape[1]
|
72 |
+
|
73 |
+
if position_ids is None:
|
74 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
75 |
+
|
76 |
+
if inputs_embeds is None:
|
77 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
78 |
+
|
79 |
+
embeddings = inputs_embeds
|
80 |
+
|
81 |
+
if self.position_embedding_type == "absolute":
|
82 |
+
position_embeddings = self.position_embeddings(position_ids)
|
83 |
+
embeddings += position_embeddings
|
84 |
+
embeddings = self.LayerNorm(embeddings)
|
85 |
+
embeddings = self.dropout(embeddings)
|
86 |
+
return embeddings
|
87 |
+
|
88 |
+
|
89 |
+
class BertSelfAttention(nn.Module):
|
90 |
+
def __init__(self, config, is_cross_attention):
|
91 |
+
super().__init__()
|
92 |
+
self.config = config
|
93 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
94 |
+
raise ValueError(
|
95 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
96 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
97 |
+
)
|
98 |
+
|
99 |
+
self.num_attention_heads = config.num_attention_heads
|
100 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
101 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
102 |
+
|
103 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
104 |
+
if is_cross_attention:
|
105 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
106 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
107 |
+
else:
|
108 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
109 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
110 |
+
|
111 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
112 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
113 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
114 |
+
self.max_position_embeddings = config.max_position_embeddings
|
115 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
116 |
+
self.save_attention = False
|
117 |
+
|
118 |
+
def save_attn_gradients(self, attn_gradients):
|
119 |
+
self.attn_gradients = attn_gradients
|
120 |
+
|
121 |
+
def get_attn_gradients(self):
|
122 |
+
return self.attn_gradients
|
123 |
+
|
124 |
+
def save_attention_map(self, attention_map):
|
125 |
+
self.attention_map = attention_map
|
126 |
+
|
127 |
+
def get_attention_map(self):
|
128 |
+
return self.attention_map
|
129 |
+
|
130 |
+
def transpose_for_scores(self, x):
|
131 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
132 |
+
x = x.view(*new_x_shape)
|
133 |
+
return x.permute(0, 2, 1, 3)
|
134 |
+
|
135 |
+
def forward(
|
136 |
+
self,
|
137 |
+
hidden_states,
|
138 |
+
attention_mask=None,
|
139 |
+
head_mask=None,
|
140 |
+
encoder_hidden_states=None,
|
141 |
+
encoder_attention_mask=None,
|
142 |
+
past_key_value=None,
|
143 |
+
output_attentions=False,
|
144 |
+
):
|
145 |
+
mixed_query_layer = self.query(hidden_states)
|
146 |
+
|
147 |
+
# If this is instantiated as a cross-attention module, the keys
|
148 |
+
# and values come from an encoder; the attention mask needs to be
|
149 |
+
# such that the encoder's padding tokens are not attended to.
|
150 |
+
is_cross_attention = encoder_hidden_states is not None
|
151 |
+
|
152 |
+
if is_cross_attention:
|
153 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
154 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
155 |
+
attention_mask = encoder_attention_mask
|
156 |
+
elif past_key_value is not None:
|
157 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
158 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
159 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
160 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
161 |
+
else:
|
162 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
163 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
164 |
+
|
165 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
166 |
+
|
167 |
+
past_key_value = (key_layer, value_layer)
|
168 |
+
|
169 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
170 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
171 |
+
|
172 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
173 |
+
seq_length = hidden_states.size()[1]
|
174 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
175 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
176 |
+
distance = position_ids_l - position_ids_r
|
177 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
178 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
179 |
+
|
180 |
+
if self.position_embedding_type == "relative_key":
|
181 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
182 |
+
attention_scores = attention_scores + relative_position_scores
|
183 |
+
elif self.position_embedding_type == "relative_key_query":
|
184 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
185 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
186 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
187 |
+
|
188 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
189 |
+
if attention_mask is not None:
|
190 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
191 |
+
attention_scores = attention_scores + attention_mask
|
192 |
+
|
193 |
+
# Normalize the attention scores to probabilities.
|
194 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
195 |
+
|
196 |
+
if is_cross_attention and self.save_attention:
|
197 |
+
self.save_attention_map(attention_probs)
|
198 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
199 |
+
|
200 |
+
# This is actually dropping out entire tokens to attend to, which might
|
201 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
202 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
203 |
+
|
204 |
+
# Mask heads if we want to
|
205 |
+
if head_mask is not None:
|
206 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
207 |
+
|
208 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
209 |
+
|
210 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
211 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
212 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
213 |
+
|
214 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
215 |
+
|
216 |
+
outputs = outputs + (past_key_value,)
|
217 |
+
return outputs
|
218 |
+
|
219 |
+
|
220 |
+
class BertSelfOutput(nn.Module):
|
221 |
+
def __init__(self, config):
|
222 |
+
super().__init__()
|
223 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
224 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
225 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
226 |
+
|
227 |
+
def forward(self, hidden_states, input_tensor):
|
228 |
+
hidden_states = self.dense(hidden_states)
|
229 |
+
hidden_states = self.dropout(hidden_states)
|
230 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
231 |
+
return hidden_states
|
232 |
+
|
233 |
+
|
234 |
+
class BertAttention(nn.Module):
|
235 |
+
def __init__(self, config, is_cross_attention=False):
|
236 |
+
super().__init__()
|
237 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
238 |
+
self.output = BertSelfOutput(config)
|
239 |
+
self.pruned_heads = set()
|
240 |
+
|
241 |
+
def prune_heads(self, heads):
|
242 |
+
if len(heads) == 0:
|
243 |
+
return
|
244 |
+
heads, index = find_pruneable_heads_and_indices(
|
245 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
246 |
+
)
|
247 |
+
|
248 |
+
# Prune linear layers
|
249 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
250 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
251 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
252 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
253 |
+
|
254 |
+
# Update hyper params and store pruned heads
|
255 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
256 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
257 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
258 |
+
|
259 |
+
def forward(
|
260 |
+
self,
|
261 |
+
hidden_states,
|
262 |
+
attention_mask=None,
|
263 |
+
head_mask=None,
|
264 |
+
encoder_hidden_states=None,
|
265 |
+
encoder_attention_mask=None,
|
266 |
+
past_key_value=None,
|
267 |
+
output_attentions=False,
|
268 |
+
):
|
269 |
+
self_outputs = self.self(
|
270 |
+
hidden_states,
|
271 |
+
attention_mask,
|
272 |
+
head_mask,
|
273 |
+
encoder_hidden_states,
|
274 |
+
encoder_attention_mask,
|
275 |
+
past_key_value,
|
276 |
+
output_attentions,
|
277 |
+
)
|
278 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
279 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
280 |
+
return outputs
|
281 |
+
|
282 |
+
|
283 |
+
class BertIntermediate(nn.Module):
|
284 |
+
def __init__(self, config):
|
285 |
+
super().__init__()
|
286 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
287 |
+
if isinstance(config.hidden_act, str):
|
288 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
289 |
+
else:
|
290 |
+
self.intermediate_act_fn = config.hidden_act
|
291 |
+
|
292 |
+
def forward(self, hidden_states):
|
293 |
+
hidden_states = self.dense(hidden_states)
|
294 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
295 |
+
return hidden_states
|
296 |
+
|
297 |
+
|
298 |
+
class BertOutput(nn.Module):
|
299 |
+
def __init__(self, config):
|
300 |
+
super().__init__()
|
301 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
302 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
303 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
304 |
+
|
305 |
+
def forward(self, hidden_states, input_tensor):
|
306 |
+
hidden_states = self.dense(hidden_states)
|
307 |
+
hidden_states = self.dropout(hidden_states)
|
308 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
309 |
+
return hidden_states
|
310 |
+
|
311 |
+
|
312 |
+
class BertLayer(nn.Module):
|
313 |
+
def __init__(self, config, layer_num):
|
314 |
+
super().__init__()
|
315 |
+
self.config = config
|
316 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
317 |
+
self.seq_len_dim = 1
|
318 |
+
self.attention = BertAttention(config)
|
319 |
+
self.layer_num = layer_num
|
320 |
+
if self.config.add_cross_attention:
|
321 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
322 |
+
self.intermediate = BertIntermediate(config)
|
323 |
+
self.output = BertOutput(config)
|
324 |
+
|
325 |
+
def forward(
|
326 |
+
self,
|
327 |
+
hidden_states,
|
328 |
+
attention_mask=None,
|
329 |
+
head_mask=None,
|
330 |
+
encoder_hidden_states=None,
|
331 |
+
encoder_attention_mask=None,
|
332 |
+
past_key_value=None,
|
333 |
+
output_attentions=False,
|
334 |
+
mode=None,
|
335 |
+
):
|
336 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
337 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
338 |
+
self_attention_outputs = self.attention(
|
339 |
+
hidden_states,
|
340 |
+
attention_mask,
|
341 |
+
head_mask,
|
342 |
+
output_attentions=output_attentions,
|
343 |
+
past_key_value=self_attn_past_key_value,
|
344 |
+
)
|
345 |
+
attention_output = self_attention_outputs[0]
|
346 |
+
|
347 |
+
outputs = self_attention_outputs[1:-1]
|
348 |
+
present_key_value = self_attention_outputs[-1]
|
349 |
+
|
350 |
+
if mode=='multimodal':
|
351 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
352 |
+
|
353 |
+
cross_attention_outputs = self.crossattention(
|
354 |
+
attention_output,
|
355 |
+
attention_mask,
|
356 |
+
head_mask,
|
357 |
+
encoder_hidden_states,
|
358 |
+
encoder_attention_mask,
|
359 |
+
output_attentions=output_attentions,
|
360 |
+
)
|
361 |
+
attention_output = cross_attention_outputs[0]
|
362 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
363 |
+
layer_output = apply_chunking_to_forward(
|
364 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
365 |
+
)
|
366 |
+
outputs = (layer_output,) + outputs
|
367 |
+
|
368 |
+
outputs = outputs + (present_key_value,)
|
369 |
+
|
370 |
+
return outputs
|
371 |
+
|
372 |
+
def feed_forward_chunk(self, attention_output):
|
373 |
+
intermediate_output = self.intermediate(attention_output)
|
374 |
+
layer_output = self.output(intermediate_output, attention_output)
|
375 |
+
return layer_output
|
376 |
+
|
377 |
+
|
378 |
+
class BertEncoder(nn.Module):
|
379 |
+
def __init__(self, config):
|
380 |
+
super().__init__()
|
381 |
+
self.config = config
|
382 |
+
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
383 |
+
self.gradient_checkpointing = False
|
384 |
+
|
385 |
+
def forward(
|
386 |
+
self,
|
387 |
+
hidden_states,
|
388 |
+
attention_mask=None,
|
389 |
+
head_mask=None,
|
390 |
+
encoder_hidden_states=None,
|
391 |
+
encoder_attention_mask=None,
|
392 |
+
past_key_values=None,
|
393 |
+
use_cache=None,
|
394 |
+
output_attentions=False,
|
395 |
+
output_hidden_states=False,
|
396 |
+
return_dict=True,
|
397 |
+
mode='multimodal',
|
398 |
+
):
|
399 |
+
all_hidden_states = () if output_hidden_states else None
|
400 |
+
all_self_attentions = () if output_attentions else None
|
401 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
402 |
+
|
403 |
+
next_decoder_cache = () if use_cache else None
|
404 |
+
|
405 |
+
for i in range(self.config.num_hidden_layers):
|
406 |
+
layer_module = self.layer[i]
|
407 |
+
if output_hidden_states:
|
408 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
409 |
+
|
410 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
411 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
412 |
+
|
413 |
+
if self.gradient_checkpointing and self.training:
|
414 |
+
|
415 |
+
if use_cache:
|
416 |
+
logger.warn(
|
417 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
418 |
+
)
|
419 |
+
use_cache = False
|
420 |
+
|
421 |
+
def create_custom_forward(module):
|
422 |
+
def custom_forward(*inputs):
|
423 |
+
return module(*inputs, past_key_value, output_attentions)
|
424 |
+
|
425 |
+
return custom_forward
|
426 |
+
|
427 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
428 |
+
create_custom_forward(layer_module),
|
429 |
+
hidden_states,
|
430 |
+
attention_mask,
|
431 |
+
layer_head_mask,
|
432 |
+
encoder_hidden_states,
|
433 |
+
encoder_attention_mask,
|
434 |
+
mode=mode,
|
435 |
+
)
|
436 |
+
else:
|
437 |
+
layer_outputs = layer_module(
|
438 |
+
hidden_states,
|
439 |
+
attention_mask,
|
440 |
+
layer_head_mask,
|
441 |
+
encoder_hidden_states,
|
442 |
+
encoder_attention_mask,
|
443 |
+
past_key_value,
|
444 |
+
output_attentions,
|
445 |
+
mode=mode,
|
446 |
+
)
|
447 |
+
|
448 |
+
hidden_states = layer_outputs[0]
|
449 |
+
if use_cache:
|
450 |
+
next_decoder_cache += (layer_outputs[-1],)
|
451 |
+
if output_attentions:
|
452 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
453 |
+
|
454 |
+
if output_hidden_states:
|
455 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
456 |
+
|
457 |
+
if not return_dict:
|
458 |
+
return tuple(
|
459 |
+
v
|
460 |
+
for v in [
|
461 |
+
hidden_states,
|
462 |
+
next_decoder_cache,
|
463 |
+
all_hidden_states,
|
464 |
+
all_self_attentions,
|
465 |
+
all_cross_attentions,
|
466 |
+
]
|
467 |
+
if v is not None
|
468 |
+
)
|
469 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
470 |
+
last_hidden_state=hidden_states,
|
471 |
+
past_key_values=next_decoder_cache,
|
472 |
+
hidden_states=all_hidden_states,
|
473 |
+
attentions=all_self_attentions,
|
474 |
+
cross_attentions=all_cross_attentions,
|
475 |
+
)
|
476 |
+
|
477 |
+
|
478 |
+
class BertPooler(nn.Module):
|
479 |
+
def __init__(self, config):
|
480 |
+
super().__init__()
|
481 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
482 |
+
self.activation = nn.Tanh()
|
483 |
+
|
484 |
+
def forward(self, hidden_states):
|
485 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
486 |
+
# to the first token.
|
487 |
+
first_token_tensor = hidden_states[:, 0]
|
488 |
+
pooled_output = self.dense(first_token_tensor)
|
489 |
+
pooled_output = self.activation(pooled_output)
|
490 |
+
return pooled_output
|
491 |
+
|
492 |
+
|
493 |
+
class BertPredictionHeadTransform(nn.Module):
|
494 |
+
def __init__(self, config):
|
495 |
+
super().__init__()
|
496 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
497 |
+
if isinstance(config.hidden_act, str):
|
498 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
499 |
+
else:
|
500 |
+
self.transform_act_fn = config.hidden_act
|
501 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
502 |
+
|
503 |
+
def forward(self, hidden_states):
|
504 |
+
hidden_states = self.dense(hidden_states)
|
505 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
506 |
+
hidden_states = self.LayerNorm(hidden_states)
|
507 |
+
return hidden_states
|
508 |
+
|
509 |
+
|
510 |
+
class BertLMPredictionHead(nn.Module):
|
511 |
+
def __init__(self, config):
|
512 |
+
super().__init__()
|
513 |
+
self.transform = BertPredictionHeadTransform(config)
|
514 |
+
|
515 |
+
# The output weights are the same as the input embeddings, but there is
|
516 |
+
# an output-only bias for each token.
|
517 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
518 |
+
|
519 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
520 |
+
|
521 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
522 |
+
self.decoder.bias = self.bias
|
523 |
+
|
524 |
+
def forward(self, hidden_states):
|
525 |
+
hidden_states = self.transform(hidden_states)
|
526 |
+
hidden_states = self.decoder(hidden_states)
|
527 |
+
return hidden_states
|
528 |
+
|
529 |
+
|
530 |
+
class BertOnlyMLMHead(nn.Module):
|
531 |
+
def __init__(self, config):
|
532 |
+
super().__init__()
|
533 |
+
self.predictions = BertLMPredictionHead(config)
|
534 |
+
|
535 |
+
def forward(self, sequence_output):
|
536 |
+
prediction_scores = self.predictions(sequence_output)
|
537 |
+
return prediction_scores
|
538 |
+
|
539 |
+
|
540 |
+
class BertPreTrainedModel(PreTrainedModel):
|
541 |
+
"""
|
542 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
543 |
+
models.
|
544 |
+
"""
|
545 |
+
|
546 |
+
config_class = BertConfig
|
547 |
+
base_model_prefix = "bert"
|
548 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
549 |
+
|
550 |
+
def _init_weights(self, module):
|
551 |
+
""" Initialize the weights """
|
552 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
553 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
554 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
555 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
556 |
+
elif isinstance(module, nn.LayerNorm):
|
557 |
+
module.bias.data.zero_()
|
558 |
+
module.weight.data.fill_(1.0)
|
559 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
560 |
+
module.bias.data.zero_()
|
561 |
+
|
562 |
+
|
563 |
+
class BertModel(BertPreTrainedModel):
|
564 |
+
"""
|
565 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
566 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
567 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
568 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
569 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
570 |
+
input to the forward pass.
|
571 |
+
"""
|
572 |
+
|
573 |
+
def __init__(self, config, add_pooling_layer=True):
|
574 |
+
super().__init__(config)
|
575 |
+
self.config = config
|
576 |
+
|
577 |
+
self.embeddings = BertEmbeddings(config)
|
578 |
+
|
579 |
+
self.encoder = BertEncoder(config)
|
580 |
+
|
581 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
582 |
+
|
583 |
+
self.init_weights()
|
584 |
+
|
585 |
+
|
586 |
+
def get_input_embeddings(self):
|
587 |
+
return self.embeddings.word_embeddings
|
588 |
+
|
589 |
+
def set_input_embeddings(self, value):
|
590 |
+
self.embeddings.word_embeddings = value
|
591 |
+
|
592 |
+
def _prune_heads(self, heads_to_prune):
|
593 |
+
"""
|
594 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
595 |
+
class PreTrainedModel
|
596 |
+
"""
|
597 |
+
for layer, heads in heads_to_prune.items():
|
598 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
599 |
+
|
600 |
+
|
601 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
602 |
+
"""
|
603 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
604 |
+
|
605 |
+
Arguments:
|
606 |
+
attention_mask (:obj:`torch.Tensor`):
|
607 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
608 |
+
input_shape (:obj:`Tuple[int]`):
|
609 |
+
The shape of the input to the model.
|
610 |
+
device: (:obj:`torch.device`):
|
611 |
+
The device of the input to the model.
|
612 |
+
|
613 |
+
Returns:
|
614 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
615 |
+
"""
|
616 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
617 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
618 |
+
if attention_mask.dim() == 3:
|
619 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
620 |
+
elif attention_mask.dim() == 2:
|
621 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
622 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
623 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
624 |
+
if is_decoder:
|
625 |
+
batch_size, seq_length = input_shape
|
626 |
+
|
627 |
+
seq_ids = torch.arange(seq_length, device=device)
|
628 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
629 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
630 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
631 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
632 |
+
|
633 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
634 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
635 |
+
causal_mask = torch.cat(
|
636 |
+
[
|
637 |
+
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
638 |
+
causal_mask,
|
639 |
+
],
|
640 |
+
axis=-1,
|
641 |
+
)
|
642 |
+
|
643 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
644 |
+
else:
|
645 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
646 |
+
else:
|
647 |
+
raise ValueError(
|
648 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
649 |
+
input_shape, attention_mask.shape
|
650 |
+
)
|
651 |
+
)
|
652 |
+
|
653 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
654 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
655 |
+
# positions we want to attend and -10000.0 for masked positions.
|
656 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
657 |
+
# effectively the same as removing these entirely.
|
658 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
659 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
660 |
+
return extended_attention_mask
|
661 |
+
|
662 |
+
def forward(
|
663 |
+
self,
|
664 |
+
input_ids=None,
|
665 |
+
attention_mask=None,
|
666 |
+
position_ids=None,
|
667 |
+
head_mask=None,
|
668 |
+
inputs_embeds=None,
|
669 |
+
encoder_embeds=None,
|
670 |
+
encoder_hidden_states=None,
|
671 |
+
encoder_attention_mask=None,
|
672 |
+
past_key_values=None,
|
673 |
+
use_cache=None,
|
674 |
+
output_attentions=None,
|
675 |
+
output_hidden_states=None,
|
676 |
+
return_dict=None,
|
677 |
+
is_decoder=False,
|
678 |
+
mode='multimodal',
|
679 |
+
):
|
680 |
+
r"""
|
681 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
682 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
683 |
+
the model is configured as a decoder.
|
684 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
685 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
686 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
687 |
+
- 1 for tokens that are **not masked**,
|
688 |
+
- 0 for tokens that are **masked**.
|
689 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
690 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
691 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
692 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
693 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
694 |
+
use_cache (:obj:`bool`, `optional`):
|
695 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
696 |
+
decoding (see :obj:`past_key_values`).
|
697 |
+
"""
|
698 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
699 |
+
output_hidden_states = (
|
700 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
701 |
+
)
|
702 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
703 |
+
|
704 |
+
if is_decoder:
|
705 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
706 |
+
else:
|
707 |
+
use_cache = False
|
708 |
+
|
709 |
+
if input_ids is not None and inputs_embeds is not None:
|
710 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
711 |
+
elif input_ids is not None:
|
712 |
+
input_shape = input_ids.size()
|
713 |
+
batch_size, seq_length = input_shape
|
714 |
+
device = input_ids.device
|
715 |
+
elif inputs_embeds is not None:
|
716 |
+
input_shape = inputs_embeds.size()[:-1]
|
717 |
+
batch_size, seq_length = input_shape
|
718 |
+
device = inputs_embeds.device
|
719 |
+
elif encoder_embeds is not None:
|
720 |
+
input_shape = encoder_embeds.size()[:-1]
|
721 |
+
batch_size, seq_length = input_shape
|
722 |
+
device = encoder_embeds.device
|
723 |
+
else:
|
724 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
725 |
+
|
726 |
+
# past_key_values_length
|
727 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
728 |
+
|
729 |
+
if attention_mask is None:
|
730 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
731 |
+
|
732 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
733 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
734 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
735 |
+
device, is_decoder)
|
736 |
+
|
737 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
738 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
739 |
+
if encoder_hidden_states is not None:
|
740 |
+
if type(encoder_hidden_states) == list:
|
741 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
742 |
+
else:
|
743 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
744 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
745 |
+
|
746 |
+
if type(encoder_attention_mask) == list:
|
747 |
+
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
748 |
+
elif encoder_attention_mask is None:
|
749 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
750 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
751 |
+
else:
|
752 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
753 |
+
else:
|
754 |
+
encoder_extended_attention_mask = None
|
755 |
+
|
756 |
+
# Prepare head mask if needed
|
757 |
+
# 1.0 in head_mask indicate we keep the head
|
758 |
+
# attention_probs has shape bsz x n_heads x N x N
|
759 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
760 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
761 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
762 |
+
|
763 |
+
if encoder_embeds is None:
|
764 |
+
embedding_output = self.embeddings(
|
765 |
+
input_ids=input_ids,
|
766 |
+
position_ids=position_ids,
|
767 |
+
inputs_embeds=inputs_embeds,
|
768 |
+
past_key_values_length=past_key_values_length,
|
769 |
+
)
|
770 |
+
else:
|
771 |
+
embedding_output = encoder_embeds
|
772 |
+
|
773 |
+
encoder_outputs = self.encoder(
|
774 |
+
embedding_output,
|
775 |
+
attention_mask=extended_attention_mask,
|
776 |
+
head_mask=head_mask,
|
777 |
+
encoder_hidden_states=encoder_hidden_states,
|
778 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
779 |
+
past_key_values=past_key_values,
|
780 |
+
use_cache=use_cache,
|
781 |
+
output_attentions=output_attentions,
|
782 |
+
output_hidden_states=output_hidden_states,
|
783 |
+
return_dict=return_dict,
|
784 |
+
mode=mode,
|
785 |
+
)
|
786 |
+
sequence_output = encoder_outputs[0]
|
787 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
788 |
+
|
789 |
+
if not return_dict:
|
790 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
791 |
+
|
792 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
793 |
+
last_hidden_state=sequence_output,
|
794 |
+
pooler_output=pooled_output,
|
795 |
+
past_key_values=encoder_outputs.past_key_values,
|
796 |
+
hidden_states=encoder_outputs.hidden_states,
|
797 |
+
attentions=encoder_outputs.attentions,
|
798 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
799 |
+
)
|
800 |
+
|
801 |
+
|
802 |
+
|
803 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
804 |
+
|
805 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
806 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
807 |
+
|
808 |
+
def __init__(self, config):
|
809 |
+
super().__init__(config)
|
810 |
+
|
811 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
812 |
+
self.cls = BertOnlyMLMHead(config)
|
813 |
+
|
814 |
+
self.init_weights()
|
815 |
+
|
816 |
+
def get_output_embeddings(self):
|
817 |
+
return self.cls.predictions.decoder
|
818 |
+
|
819 |
+
def set_output_embeddings(self, new_embeddings):
|
820 |
+
self.cls.predictions.decoder = new_embeddings
|
821 |
+
|
822 |
+
def forward(
|
823 |
+
self,
|
824 |
+
input_ids=None,
|
825 |
+
attention_mask=None,
|
826 |
+
position_ids=None,
|
827 |
+
head_mask=None,
|
828 |
+
inputs_embeds=None,
|
829 |
+
encoder_hidden_states=None,
|
830 |
+
encoder_attention_mask=None,
|
831 |
+
labels=None,
|
832 |
+
past_key_values=None,
|
833 |
+
use_cache=None,
|
834 |
+
output_attentions=None,
|
835 |
+
output_hidden_states=None,
|
836 |
+
return_dict=None,
|
837 |
+
return_logits=False,
|
838 |
+
is_decoder=True,
|
839 |
+
reduction='mean',
|
840 |
+
mode='multimodal',
|
841 |
+
):
|
842 |
+
r"""
|
843 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
844 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
845 |
+
the model is configured as a decoder.
|
846 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
847 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
848 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
849 |
+
- 1 for tokens that are **not masked**,
|
850 |
+
- 0 for tokens that are **masked**.
|
851 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
852 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
853 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
854 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
855 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
856 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
857 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
858 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
859 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
860 |
+
use_cache (:obj:`bool`, `optional`):
|
861 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
862 |
+
decoding (see :obj:`past_key_values`).
|
863 |
+
Returns:
|
864 |
+
Example::
|
865 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
866 |
+
>>> import torch
|
867 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
868 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
869 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
870 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
871 |
+
>>> outputs = model(**inputs)
|
872 |
+
>>> prediction_logits = outputs.logits
|
873 |
+
"""
|
874 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
875 |
+
if labels is not None:
|
876 |
+
use_cache = False
|
877 |
+
|
878 |
+
outputs = self.bert(
|
879 |
+
input_ids,
|
880 |
+
attention_mask=attention_mask,
|
881 |
+
position_ids=position_ids,
|
882 |
+
head_mask=head_mask,
|
883 |
+
inputs_embeds=inputs_embeds,
|
884 |
+
encoder_hidden_states=encoder_hidden_states,
|
885 |
+
encoder_attention_mask=encoder_attention_mask,
|
886 |
+
past_key_values=past_key_values,
|
887 |
+
use_cache=use_cache,
|
888 |
+
output_attentions=output_attentions,
|
889 |
+
output_hidden_states=output_hidden_states,
|
890 |
+
return_dict=return_dict,
|
891 |
+
is_decoder=is_decoder,
|
892 |
+
mode=mode,
|
893 |
+
)
|
894 |
+
|
895 |
+
sequence_output = outputs[0]
|
896 |
+
prediction_scores = self.cls(sequence_output)
|
897 |
+
|
898 |
+
if return_logits:
|
899 |
+
return prediction_scores[:, :-1, :].contiguous()
|
900 |
+
|
901 |
+
lm_loss = None
|
902 |
+
if labels is not None:
|
903 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
904 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
905 |
+
labels = labels[:, 1:].contiguous()
|
906 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
907 |
+
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
908 |
+
if reduction=='none':
|
909 |
+
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
910 |
+
|
911 |
+
if not return_dict:
|
912 |
+
output = (prediction_scores,) + outputs[2:]
|
913 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
914 |
+
|
915 |
+
return CausalLMOutputWithCrossAttentions(
|
916 |
+
loss=lm_loss,
|
917 |
+
logits=prediction_scores,
|
918 |
+
past_key_values=outputs.past_key_values,
|
919 |
+
hidden_states=outputs.hidden_states,
|
920 |
+
attentions=outputs.attentions,
|
921 |
+
cross_attentions=outputs.cross_attentions,
|
922 |
+
)
|
923 |
+
|
924 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
925 |
+
input_shape = input_ids.shape
|
926 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
927 |
+
if attention_mask is None:
|
928 |
+
attention_mask = input_ids.new_ones(input_shape)
|
929 |
+
|
930 |
+
# cut decoder_input_ids if past is used
|
931 |
+
if past is not None:
|
932 |
+
input_ids = input_ids[:, -1:]
|
933 |
+
|
934 |
+
return {
|
935 |
+
"input_ids": input_ids,
|
936 |
+
"attention_mask": attention_mask,
|
937 |
+
"past_key_values": past,
|
938 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
939 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
940 |
+
"is_decoder": True,
|
941 |
+
}
|
942 |
+
|
943 |
+
def _reorder_cache(self, past, beam_idx):
|
944 |
+
reordered_past = ()
|
945 |
+
for layer_past in past:
|
946 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
947 |
+
return reordered_past
|
ImageReward/models/BLIP/vit.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
3 |
+
* Based on timm code base
|
4 |
+
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
5 |
+
'''
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from functools import partial
|
11 |
+
|
12 |
+
from timm.models.vision_transformer import _cfg, PatchEmbed
|
13 |
+
from timm.models.registry import register_model
|
14 |
+
from timm.models.layers import trunc_normal_, DropPath
|
15 |
+
from timm.models.helpers import named_apply, adapt_input_conv
|
16 |
+
|
17 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
18 |
+
|
19 |
+
class Mlp(nn.Module):
|
20 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
21 |
+
"""
|
22 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
23 |
+
super().__init__()
|
24 |
+
out_features = out_features or in_features
|
25 |
+
hidden_features = hidden_features or in_features
|
26 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
27 |
+
self.act = act_layer()
|
28 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
29 |
+
self.drop = nn.Dropout(drop)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
x = self.fc1(x)
|
33 |
+
x = self.act(x)
|
34 |
+
x = self.drop(x)
|
35 |
+
x = self.fc2(x)
|
36 |
+
x = self.drop(x)
|
37 |
+
return x
|
38 |
+
|
39 |
+
|
40 |
+
class Attention(nn.Module):
|
41 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
42 |
+
super().__init__()
|
43 |
+
self.num_heads = num_heads
|
44 |
+
head_dim = dim // num_heads
|
45 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
46 |
+
self.scale = qk_scale or head_dim ** -0.5
|
47 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
48 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
49 |
+
self.proj = nn.Linear(dim, dim)
|
50 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
51 |
+
self.attn_gradients = None
|
52 |
+
self.attention_map = None
|
53 |
+
|
54 |
+
def save_attn_gradients(self, attn_gradients):
|
55 |
+
self.attn_gradients = attn_gradients
|
56 |
+
|
57 |
+
def get_attn_gradients(self):
|
58 |
+
return self.attn_gradients
|
59 |
+
|
60 |
+
def save_attention_map(self, attention_map):
|
61 |
+
self.attention_map = attention_map
|
62 |
+
|
63 |
+
def get_attention_map(self):
|
64 |
+
return self.attention_map
|
65 |
+
|
66 |
+
def forward(self, x, register_hook=False):
|
67 |
+
B, N, C = x.shape
|
68 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
69 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
70 |
+
|
71 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
72 |
+
attn = attn.softmax(dim=-1)
|
73 |
+
attn = self.attn_drop(attn)
|
74 |
+
|
75 |
+
if register_hook:
|
76 |
+
self.save_attention_map(attn)
|
77 |
+
attn.register_hook(self.save_attn_gradients)
|
78 |
+
|
79 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
80 |
+
x = self.proj(x)
|
81 |
+
x = self.proj_drop(x)
|
82 |
+
return x
|
83 |
+
|
84 |
+
|
85 |
+
class Block(nn.Module):
|
86 |
+
|
87 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
88 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
89 |
+
super().__init__()
|
90 |
+
self.norm1 = norm_layer(dim)
|
91 |
+
self.attn = Attention(
|
92 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
93 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
94 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
95 |
+
self.norm2 = norm_layer(dim)
|
96 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
97 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
98 |
+
|
99 |
+
if use_grad_checkpointing:
|
100 |
+
self.attn = checkpoint_wrapper(self.attn)
|
101 |
+
self.mlp = checkpoint_wrapper(self.mlp)
|
102 |
+
|
103 |
+
def forward(self, x, register_hook=False):
|
104 |
+
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
105 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
106 |
+
return x
|
107 |
+
|
108 |
+
|
109 |
+
class VisionTransformer(nn.Module):
|
110 |
+
""" Vision Transformer
|
111 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
112 |
+
https://arxiv.org/abs/2010.11929
|
113 |
+
"""
|
114 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
115 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
116 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
117 |
+
use_grad_checkpointing=False, ckpt_layer=0):
|
118 |
+
"""
|
119 |
+
Args:
|
120 |
+
img_size (int, tuple): input image size
|
121 |
+
patch_size (int, tuple): patch size
|
122 |
+
in_chans (int): number of input channels
|
123 |
+
num_classes (int): number of classes for classification head
|
124 |
+
embed_dim (int): embedding dimension
|
125 |
+
depth (int): depth of transformer
|
126 |
+
num_heads (int): number of attention heads
|
127 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
128 |
+
qkv_bias (bool): enable bias for qkv if True
|
129 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
130 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
131 |
+
drop_rate (float): dropout rate
|
132 |
+
attn_drop_rate (float): attention dropout rate
|
133 |
+
drop_path_rate (float): stochastic depth rate
|
134 |
+
norm_layer: (nn.Module): normalization layer
|
135 |
+
"""
|
136 |
+
super().__init__()
|
137 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
138 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
139 |
+
|
140 |
+
self.patch_embed = PatchEmbed(
|
141 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
142 |
+
|
143 |
+
num_patches = self.patch_embed.num_patches
|
144 |
+
|
145 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
146 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
147 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
148 |
+
|
149 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
150 |
+
self.blocks = nn.ModuleList([
|
151 |
+
Block(
|
152 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
153 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
154 |
+
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
155 |
+
)
|
156 |
+
for i in range(depth)])
|
157 |
+
self.norm = norm_layer(embed_dim)
|
158 |
+
|
159 |
+
trunc_normal_(self.pos_embed, std=.02)
|
160 |
+
trunc_normal_(self.cls_token, std=.02)
|
161 |
+
self.apply(self._init_weights)
|
162 |
+
|
163 |
+
def _init_weights(self, m):
|
164 |
+
if isinstance(m, nn.Linear):
|
165 |
+
trunc_normal_(m.weight, std=.02)
|
166 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
167 |
+
nn.init.constant_(m.bias, 0)
|
168 |
+
elif isinstance(m, nn.LayerNorm):
|
169 |
+
nn.init.constant_(m.bias, 0)
|
170 |
+
nn.init.constant_(m.weight, 1.0)
|
171 |
+
|
172 |
+
@torch.jit.ignore
|
173 |
+
def no_weight_decay(self):
|
174 |
+
return {'pos_embed', 'cls_token'}
|
175 |
+
|
176 |
+
def forward(self, x, register_blk=-1):
|
177 |
+
B = x.shape[0]
|
178 |
+
x = self.patch_embed(x)
|
179 |
+
|
180 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
181 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
182 |
+
|
183 |
+
x = x + self.pos_embed[:,:x.size(1),:]
|
184 |
+
x = self.pos_drop(x)
|
185 |
+
|
186 |
+
for i,blk in enumerate(self.blocks):
|
187 |
+
x = blk(x, register_blk==i)
|
188 |
+
x = self.norm(x)
|
189 |
+
|
190 |
+
return x
|
191 |
+
|
192 |
+
@torch.jit.ignore()
|
193 |
+
def load_pretrained(self, checkpoint_path, prefix=''):
|
194 |
+
_load_weights(self, checkpoint_path, prefix)
|
195 |
+
|
196 |
+
|
197 |
+
@torch.no_grad()
|
198 |
+
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
199 |
+
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
200 |
+
"""
|
201 |
+
import numpy as np
|
202 |
+
|
203 |
+
def _n2p(w, t=True):
|
204 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
205 |
+
w = w.flatten()
|
206 |
+
if t:
|
207 |
+
if w.ndim == 4:
|
208 |
+
w = w.transpose([3, 2, 0, 1])
|
209 |
+
elif w.ndim == 3:
|
210 |
+
w = w.transpose([2, 0, 1])
|
211 |
+
elif w.ndim == 2:
|
212 |
+
w = w.transpose([1, 0])
|
213 |
+
return torch.from_numpy(w)
|
214 |
+
|
215 |
+
w = np.load(checkpoint_path)
|
216 |
+
if not prefix and 'opt/target/embedding/kernel' in w:
|
217 |
+
prefix = 'opt/target/'
|
218 |
+
|
219 |
+
if hasattr(model.patch_embed, 'backbone'):
|
220 |
+
# hybrid
|
221 |
+
backbone = model.patch_embed.backbone
|
222 |
+
stem_only = not hasattr(backbone, 'stem')
|
223 |
+
stem = backbone if stem_only else backbone.stem
|
224 |
+
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
225 |
+
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
226 |
+
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
227 |
+
if not stem_only:
|
228 |
+
for i, stage in enumerate(backbone.stages):
|
229 |
+
for j, block in enumerate(stage.blocks):
|
230 |
+
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
231 |
+
for r in range(3):
|
232 |
+
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
233 |
+
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
234 |
+
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
235 |
+
if block.downsample is not None:
|
236 |
+
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
237 |
+
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
238 |
+
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
239 |
+
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
240 |
+
else:
|
241 |
+
embed_conv_w = adapt_input_conv(
|
242 |
+
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
243 |
+
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
244 |
+
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
245 |
+
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
246 |
+
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
247 |
+
if pos_embed_w.shape != model.pos_embed.shape:
|
248 |
+
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
249 |
+
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
250 |
+
model.pos_embed.copy_(pos_embed_w)
|
251 |
+
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
252 |
+
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
253 |
+
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
254 |
+
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
255 |
+
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
256 |
+
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
257 |
+
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
258 |
+
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
259 |
+
for i, block in enumerate(model.blocks.children()):
|
260 |
+
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
261 |
+
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
262 |
+
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
263 |
+
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
264 |
+
block.attn.qkv.weight.copy_(torch.cat([
|
265 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
266 |
+
block.attn.qkv.bias.copy_(torch.cat([
|
267 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
268 |
+
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
269 |
+
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
270 |
+
for r in range(2):
|
271 |
+
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
272 |
+
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
273 |
+
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
274 |
+
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
275 |
+
|
276 |
+
|
277 |
+
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
278 |
+
# interpolate position embedding
|
279 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
280 |
+
num_patches = visual_encoder.patch_embed.num_patches
|
281 |
+
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
282 |
+
# height (== width) for the checkpoint position embedding
|
283 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
284 |
+
# height (== width) for the new position embedding
|
285 |
+
new_size = int(num_patches ** 0.5)
|
286 |
+
|
287 |
+
if orig_size!=new_size:
|
288 |
+
# class_token and dist_token are kept unchanged
|
289 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
290 |
+
# only the position tokens are interpolated
|
291 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
292 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
293 |
+
pos_tokens = torch.nn.functional.interpolate(
|
294 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
295 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
296 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
297 |
+
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
298 |
+
|
299 |
+
return new_pos_embed
|
300 |
+
else:
|
301 |
+
return pos_embed_checkpoint
|
ImageReward/models/BLIPScore.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
@File : BLIPScore.py
|
3 |
+
@Time : 2023/02/19 20:48:00
|
4 |
+
@Auther : Jiazheng Xu
|
5 |
+
@Contact : [email protected]
|
6 |
+
@Description: BLIPScore.
|
7 |
+
* Based on BLIP code base
|
8 |
+
* https://github.com/salesforce/BLIP
|
9 |
+
'''
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from PIL import Image
|
15 |
+
from ImageReward.models.BLIP.blip_pretrain import BLIP_Pretrain
|
16 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
17 |
+
|
18 |
+
try:
|
19 |
+
from torchvision.transforms import InterpolationMode
|
20 |
+
BICUBIC = InterpolationMode.BICUBIC
|
21 |
+
except ImportError:
|
22 |
+
BICUBIC = Image.BICUBIC
|
23 |
+
|
24 |
+
|
25 |
+
def _convert_image_to_rgb(image):
|
26 |
+
return image.convert("RGB")
|
27 |
+
|
28 |
+
|
29 |
+
def _transform(n_px):
|
30 |
+
return Compose([
|
31 |
+
Resize(n_px, interpolation=BICUBIC),
|
32 |
+
CenterCrop(n_px),
|
33 |
+
_convert_image_to_rgb,
|
34 |
+
ToTensor(),
|
35 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
36 |
+
])
|
37 |
+
|
38 |
+
|
39 |
+
class BLIPScore(nn.Module):
|
40 |
+
def __init__(self, med_config, device='cpu'):
|
41 |
+
super().__init__()
|
42 |
+
self.device = device
|
43 |
+
|
44 |
+
self.preprocess = _transform(224)
|
45 |
+
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
|
46 |
+
|
47 |
+
|
48 |
+
def score(self, prompt, image_path):
|
49 |
+
|
50 |
+
if (type(image_path).__name__=='list'):
|
51 |
+
_, rewards = self.inference_rank(prompt, image_path)
|
52 |
+
return rewards
|
53 |
+
|
54 |
+
# text encode
|
55 |
+
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
56 |
+
text_output = self.blip.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
|
57 |
+
txt_feature = F.normalize(self.blip.text_proj(text_output.last_hidden_state[:,0,:]))
|
58 |
+
|
59 |
+
# image encode
|
60 |
+
pil_image = Image.open(image_path)
|
61 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
62 |
+
image_embeds = self.blip.visual_encoder(image)
|
63 |
+
image_features = F.normalize(self.blip.vision_proj(image_embeds[:,0,:]), dim=-1)
|
64 |
+
|
65 |
+
# score
|
66 |
+
rewards = torch.sum(torch.mul(txt_feature, image_features), dim=1, keepdim=True)
|
67 |
+
|
68 |
+
return rewards.detach().cpu().numpy().item()
|
69 |
+
|
70 |
+
|
71 |
+
def inference_rank(self, prompt, generations_list):
|
72 |
+
|
73 |
+
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
74 |
+
text_output = self.blip.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
|
75 |
+
txt_feature = F.normalize(self.blip.text_proj(text_output.last_hidden_state[:,0,:]))
|
76 |
+
|
77 |
+
txt_set = []
|
78 |
+
img_set = []
|
79 |
+
for generations in generations_list:
|
80 |
+
# image encode
|
81 |
+
img_path = generations
|
82 |
+
pil_image = Image.open(img_path)
|
83 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
84 |
+
image_embeds = self.blip.visual_encoder(image)
|
85 |
+
image_features = F.normalize(self.blip.vision_proj(image_embeds[:,0,:]), dim=-1)
|
86 |
+
img_set.append(image_features)
|
87 |
+
txt_set.append(txt_feature)
|
88 |
+
|
89 |
+
txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
|
90 |
+
img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
|
91 |
+
rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True)
|
92 |
+
rewards = torch.squeeze(rewards)
|
93 |
+
_, rank = torch.sort(rewards, dim=0, descending=True)
|
94 |
+
_, indices = torch.sort(rank, dim=0)
|
95 |
+
indices = indices + 1
|
96 |
+
|
97 |
+
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
ImageReward/models/CLIPScore.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
@File : CLIPScore.py
|
3 |
+
@Time : 2023/02/12 13:14:00
|
4 |
+
@Auther : Jiazheng Xu
|
5 |
+
@Contact : [email protected]
|
6 |
+
@Description: CLIPScore.
|
7 |
+
* Based on CLIP code base
|
8 |
+
* https://github.com/openai/CLIP
|
9 |
+
'''
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from PIL import Image
|
15 |
+
import clip
|
16 |
+
|
17 |
+
class CLIPScore(nn.Module):
|
18 |
+
def __init__(self, download_root, device='cpu'):
|
19 |
+
super().__init__()
|
20 |
+
self.device = device
|
21 |
+
self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False,
|
22 |
+
download_root=download_root)
|
23 |
+
|
24 |
+
if device == "cpu":
|
25 |
+
self.clip_model.float()
|
26 |
+
else:
|
27 |
+
clip.model.convert_weights(self.clip_model) # Actually this line is unnecessary since clip by default already on float16
|
28 |
+
|
29 |
+
# have clip.logit_scale require no grad.
|
30 |
+
self.clip_model.logit_scale.requires_grad_(False)
|
31 |
+
|
32 |
+
|
33 |
+
def score(self, prompt, image_path):
|
34 |
+
|
35 |
+
if (type(image_path).__name__=='list'):
|
36 |
+
_, rewards = self.inference_rank(prompt, image_path)
|
37 |
+
return rewards
|
38 |
+
|
39 |
+
# text encode
|
40 |
+
text = clip.tokenize(prompt, truncate=True).to(self.device)
|
41 |
+
txt_features = F.normalize(self.clip_model.encode_text(text))
|
42 |
+
|
43 |
+
# image encode
|
44 |
+
pil_image = Image.open(image_path)
|
45 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
46 |
+
image_features = F.normalize(self.clip_model.encode_image(image))
|
47 |
+
|
48 |
+
# score
|
49 |
+
rewards = torch.sum(torch.mul(txt_features, image_features), dim=1, keepdim=True)
|
50 |
+
|
51 |
+
return rewards.detach().cpu().numpy().item()
|
52 |
+
|
53 |
+
|
54 |
+
def inference_rank(self, prompt, generations_list):
|
55 |
+
|
56 |
+
text = clip.tokenize(prompt, truncate=True).to(self.device)
|
57 |
+
txt_feature = F.normalize(self.clip_model.encode_text(text))
|
58 |
+
|
59 |
+
txt_set = []
|
60 |
+
img_set = []
|
61 |
+
for generations in generations_list:
|
62 |
+
# image encode
|
63 |
+
img_path = generations
|
64 |
+
pil_image = Image.open(img_path)
|
65 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
66 |
+
image_features = F.normalize(self.clip_model.encode_image(image))
|
67 |
+
img_set.append(image_features)
|
68 |
+
txt_set.append(txt_feature)
|
69 |
+
|
70 |
+
txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
|
71 |
+
img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
|
72 |
+
rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True)
|
73 |
+
rewards = torch.squeeze(rewards)
|
74 |
+
_, rank = torch.sort(rewards, dim=0, descending=True)
|
75 |
+
_, indices = torch.sort(rank, dim=0)
|
76 |
+
indices = indices + 1
|
77 |
+
|
78 |
+
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
ImageReward/models/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .AestheticScore import *
|
2 |
+
from .BLIPScore import *
|
3 |
+
from .CLIPScore import *
|
4 |
+
from .BLIP import *
|
ImageReward/utils.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
@File : utils.py
|
3 |
+
@Time : 2023/04/05 19:18:00
|
4 |
+
@Auther : Jiazheng Xu
|
5 |
+
@Contact : [email protected]
|
6 |
+
* Based on CLIP code base
|
7 |
+
* https://github.com/openai/CLIP
|
8 |
+
* Checkpoint of CLIP/BLIP/Aesthetic are from:
|
9 |
+
* https://github.com/openai/CLIP
|
10 |
+
* https://github.com/salesforce/BLIP
|
11 |
+
* https://github.com/christophschuhmann/improved-aesthetic-predictor
|
12 |
+
'''
|
13 |
+
|
14 |
+
import os
|
15 |
+
import urllib
|
16 |
+
from typing import Union, List
|
17 |
+
import pathlib
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from tqdm import tqdm
|
21 |
+
from huggingface_hub import hf_hub_download
|
22 |
+
|
23 |
+
from .ImageReward import ImageReward
|
24 |
+
from .models.CLIPScore import CLIPScore
|
25 |
+
from .models.BLIPScore import BLIPScore
|
26 |
+
from .models.AestheticScore import AestheticScore
|
27 |
+
|
28 |
+
_MODELS = {
|
29 |
+
"ImageReward-v1.0": "https://huggingface.co/THUDM/ImageReward/blob/main/ImageReward.pt",
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
def available_models() -> List[str]:
|
34 |
+
"""Returns the names of available ImageReward models"""
|
35 |
+
return list(_MODELS.keys())
|
36 |
+
|
37 |
+
|
38 |
+
def ImageReward_download(url: str, root: str):
|
39 |
+
os.makedirs(root, exist_ok=True)
|
40 |
+
filename = os.path.basename(url)
|
41 |
+
download_target = os.path.join(root, filename)
|
42 |
+
hf_hub_download(repo_id="THUDM/ImageReward", filename=filename, local_dir=root)
|
43 |
+
return download_target
|
44 |
+
|
45 |
+
|
46 |
+
def load(name: str = "ImageReward-v1.0",
|
47 |
+
device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
|
48 |
+
download_root: str = None,
|
49 |
+
med_config_path: str = None):
|
50 |
+
"""Load a ImageReward model
|
51 |
+
|
52 |
+
Parameters
|
53 |
+
----------
|
54 |
+
name: str
|
55 |
+
A model name listed by `ImageReward.available_models()`, or the path to a model checkpoint containing the state_dict
|
56 |
+
device: Union[str, torch.device]
|
57 |
+
The device to put the loaded model
|
58 |
+
download_root: str
|
59 |
+
path to download the model files; by default, it uses "~/.cache/ImageReward"
|
60 |
+
med_config_path: str
|
61 |
+
|
62 |
+
Returns
|
63 |
+
-------
|
64 |
+
model : torch.nn.Module
|
65 |
+
The ImageReward model
|
66 |
+
"""
|
67 |
+
if name in _MODELS:
|
68 |
+
download_root = download_root or "~/.cache/ImageReward"
|
69 |
+
download_root = pathlib.Path(download_root)
|
70 |
+
model_path = pathlib.Path(download_root) / 'ImageReward.pt'
|
71 |
+
|
72 |
+
if not model_path.exists():
|
73 |
+
model_path = ImageReward_download(_MODELS[name], root=download_root.as_posix())
|
74 |
+
elif os.path.isfile(name):
|
75 |
+
model_path = name
|
76 |
+
else:
|
77 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
78 |
+
|
79 |
+
print('-> load ImageReward model from %s' % model_path)
|
80 |
+
state_dict = torch.load(model_path, map_location='cpu')
|
81 |
+
|
82 |
+
# med_config
|
83 |
+
if med_config_path is None:
|
84 |
+
med_config_root = download_root or "~/.cache/ImageReward"
|
85 |
+
med_config_root = pathlib.Path(med_config_root)
|
86 |
+
med_config_path = med_config_root / 'med_config.json'
|
87 |
+
|
88 |
+
if not med_config_path.exists():
|
89 |
+
med_config_path = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
|
90 |
+
root=med_config_root.as_posix())
|
91 |
+
print('-> load ImageReward med_config from %s' % med_config_path)
|
92 |
+
|
93 |
+
model = ImageReward(device=device, med_config=med_config_path).to(device)
|
94 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
95 |
+
model.eval()
|
96 |
+
|
97 |
+
return model
|
98 |
+
|
99 |
+
|
100 |
+
_SCORES = {
|
101 |
+
"CLIP": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
102 |
+
"BLIP": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth",
|
103 |
+
"Aesthetic": "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac%2Blogos%2Bava1-l14-linearMSE.pth",
|
104 |
+
}
|
105 |
+
|
106 |
+
|
107 |
+
def available_scores() -> List[str]:
|
108 |
+
"""Returns the names of available ImageReward scores"""
|
109 |
+
return list(_SCORES.keys())
|
110 |
+
|
111 |
+
|
112 |
+
def _download(url: str, root: str):
|
113 |
+
os.makedirs(root, exist_ok=True)
|
114 |
+
filename = os.path.basename(url)
|
115 |
+
|
116 |
+
download_target = os.path.join(root, filename)
|
117 |
+
|
118 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
119 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
120 |
+
|
121 |
+
if os.path.isfile(download_target):
|
122 |
+
return download_target
|
123 |
+
|
124 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
125 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
|
126 |
+
unit_divisor=1024) as loop:
|
127 |
+
while True:
|
128 |
+
buffer = source.read(8192)
|
129 |
+
if not buffer:
|
130 |
+
break
|
131 |
+
|
132 |
+
output.write(buffer)
|
133 |
+
loop.update(len(buffer))
|
134 |
+
|
135 |
+
return download_target
|
136 |
+
|
137 |
+
|
138 |
+
def load_score(name: str = "CLIP", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
|
139 |
+
download_root: str = None):
|
140 |
+
"""Load a ImageReward model
|
141 |
+
|
142 |
+
Parameters
|
143 |
+
----------
|
144 |
+
name : str
|
145 |
+
A model name listed by `ImageReward.available_models()`
|
146 |
+
|
147 |
+
device : Union[str, torch.device]
|
148 |
+
The device to put the loaded model
|
149 |
+
|
150 |
+
download_root: str
|
151 |
+
path to download the model files; by default, it uses "~/.cache/ImageReward"
|
152 |
+
|
153 |
+
Returns
|
154 |
+
-------
|
155 |
+
model : torch.nn.Module
|
156 |
+
The ImageReward model
|
157 |
+
"""
|
158 |
+
model_download_root = download_root or os.path.expanduser("~/.cache/ImageReward")
|
159 |
+
|
160 |
+
if name in _SCORES:
|
161 |
+
model_path = _download(_SCORES[name], model_download_root)
|
162 |
+
else:
|
163 |
+
raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")
|
164 |
+
|
165 |
+
print('load checkpoint from %s' % model_path)
|
166 |
+
if name == "BLIP":
|
167 |
+
state_dict = torch.load(model_path, map_location='cpu')
|
168 |
+
med_config = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
|
169 |
+
model_download_root)
|
170 |
+
model = BLIPScore(med_config=med_config, device=device).to(device)
|
171 |
+
model.blip.load_state_dict(state_dict['model'], strict=False)
|
172 |
+
elif name == "CLIP":
|
173 |
+
model = CLIPScore(download_root=model_download_root, device=device).to(device)
|
174 |
+
elif name == "Aesthetic":
|
175 |
+
state_dict = torch.load(model_path, map_location='cpu')
|
176 |
+
model = AestheticScore(download_root=model_download_root, device=device).to(device)
|
177 |
+
model.mlp.load_state_dict(state_dict, strict=False)
|
178 |
+
else:
|
179 |
+
raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")
|
180 |
+
|
181 |
+
print("checkpoint loaded")
|
182 |
+
model.eval()
|
183 |
+
|
184 |
+
return model
|
Install.md
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Installation
|
2 |
+
|
3 |
+
Create a new conda environment:
|
4 |
+
|
5 |
+
```shell
|
6 |
+
conda create --name svgrender python=3.10
|
7 |
+
conda activate svgrender
|
8 |
+
```
|
9 |
+
|
10 |
+
Install pytorch and the following libraries:
|
11 |
+
|
12 |
+
```shell
|
13 |
+
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
|
14 |
+
pip install hydra-core omegaconf
|
15 |
+
pip install freetype-py shapely svgutils
|
16 |
+
pip install opencv-python scikit-image matplotlib visdom wandb BeautifulSoup4
|
17 |
+
pip install triton numba
|
18 |
+
pip install numpy scipy scikit-fmm einops timm fairscale=0.4.13
|
19 |
+
pip install accelerate transformers safetensors datasets
|
20 |
+
```
|
21 |
+
|
22 |
+
Install LaMa:
|
23 |
+
|
24 |
+
```shell
|
25 |
+
pip install easydict scikit-learn pytorch_lightning webdataset
|
26 |
+
pip install albumentations==0.5.2
|
27 |
+
pip install kornia==0.5.0
|
28 |
+
pip install wldhx.yadisk-direct
|
29 |
+
|
30 |
+
cd lama
|
31 |
+
# download LaMa model weights
|
32 |
+
# raw link(deprecated): curl -L $(yadisk-direct https://disk.yandex.ru/d/kHJkc7bs7mKIVA) -o big-lama.zip
|
33 |
+
curl -O -L https://huggingface.co/xingxm/PyTorch-SVGRender-models/resolve/main/big-lama.zip
|
34 |
+
unzip big-lama.zip
|
35 |
+
```
|
36 |
+
|
37 |
+
Install CLIP:
|
38 |
+
|
39 |
+
```shell
|
40 |
+
pip install ftfy regex tqdm
|
41 |
+
pip install git+https://github.com/openai/CLIP.git
|
42 |
+
```
|
43 |
+
|
44 |
+
Install diffusers:
|
45 |
+
|
46 |
+
```shell
|
47 |
+
pip install diffusers==0.20.2
|
48 |
+
```
|
49 |
+
|
50 |
+
Install xformers (require `python=3.10`):
|
51 |
+
|
52 |
+
```shell
|
53 |
+
conda install xformers -c xformers
|
54 |
+
```
|
55 |
+
|
56 |
+
Install diffvg:
|
57 |
+
|
58 |
+
```shell
|
59 |
+
git clone https://github.com/BachiLi/diffvg.git
|
60 |
+
cd diffvg
|
61 |
+
git submodule update --init --recursive
|
62 |
+
conda install -y -c anaconda cmake
|
63 |
+
conda install -y -c conda-forge ffmpeg
|
64 |
+
pip install svgwrite svgpathtools cssutils torch-tools
|
65 |
+
python setup.py install
|
66 |
+
```
|
LICENSE
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Mozilla Public License Version 2.0
|
2 |
+
==================================
|
3 |
+
|
4 |
+
1. Definitions
|
5 |
+
--------------
|
6 |
+
|
7 |
+
1.1. "Contributor"
|
8 |
+
means each individual or legal entity that creates, contributes to
|
9 |
+
the creation of, or owns Covered Software.
|
10 |
+
|
11 |
+
1.2. "Contributor Version"
|
12 |
+
means the combination of the Contributions of others (if any) used
|
13 |
+
by a Contributor and that particular Contributor's Contribution.
|
14 |
+
|
15 |
+
1.3. "Contribution"
|
16 |
+
means Covered Software of a particular Contributor.
|
17 |
+
|
18 |
+
1.4. "Covered Software"
|
19 |
+
means Source Code Form to which the initial Contributor has attached
|
20 |
+
the notice in Exhibit A, the Executable Form of such Source Code
|
21 |
+
Form, and Modifications of such Source Code Form, in each case
|
22 |
+
including portions thereof.
|
23 |
+
|
24 |
+
1.5. "Incompatible With Secondary Licenses"
|
25 |
+
means
|
26 |
+
|
27 |
+
(a) that the initial Contributor has attached the notice described
|
28 |
+
in Exhibit B to the Covered Software; or
|
29 |
+
|
30 |
+
(b) that the Covered Software was made available under the terms of
|
31 |
+
version 1.1 or earlier of the License, but not also under the
|
32 |
+
terms of a Secondary License.
|
33 |
+
|
34 |
+
1.6. "Executable Form"
|
35 |
+
means any form of the work other than Source Code Form.
|
36 |
+
|
37 |
+
1.7. "Larger Work"
|
38 |
+
means a work that combines Covered Software with other material, in
|
39 |
+
a separate file or files, that is not Covered Software.
|
40 |
+
|
41 |
+
1.8. "License"
|
42 |
+
means this document.
|
43 |
+
|
44 |
+
1.9. "Licensable"
|
45 |
+
means having the right to grant, to the maximum extent possible,
|
46 |
+
whether at the time of the initial grant or subsequently, any and
|
47 |
+
all of the rights conveyed by this License.
|
48 |
+
|
49 |
+
1.10. "Modifications"
|
50 |
+
means any of the following:
|
51 |
+
|
52 |
+
(a) any file in Source Code Form that results from an addition to,
|
53 |
+
deletion from, or modification of the contents of Covered
|
54 |
+
Software; or
|
55 |
+
|
56 |
+
(b) any new file in Source Code Form that contains any Covered
|
57 |
+
Software.
|
58 |
+
|
59 |
+
1.11. "Patent Claims" of a Contributor
|
60 |
+
means any patent claim(s), including without limitation, method,
|
61 |
+
process, and apparatus claims, in any patent Licensable by such
|
62 |
+
Contributor that would be infringed, but for the grant of the
|
63 |
+
License, by the making, using, selling, offering for sale, having
|
64 |
+
made, import, or transfer of either its Contributions or its
|
65 |
+
Contributor Version.
|
66 |
+
|
67 |
+
1.12. "Secondary License"
|
68 |
+
means either the GNU General Public License, Version 2.0, the GNU
|
69 |
+
Lesser General Public License, Version 2.1, the GNU Affero General
|
70 |
+
Public License, Version 3.0, or any later versions of those
|
71 |
+
licenses.
|
72 |
+
|
73 |
+
1.13. "Source Code Form"
|
74 |
+
means the form of the work preferred for making modifications.
|
75 |
+
|
76 |
+
1.14. "You" (or "Your")
|
77 |
+
means an individual or a legal entity exercising rights under this
|
78 |
+
License. For legal entities, "You" includes any entity that
|
79 |
+
controls, is controlled by, or is under common control with You. For
|
80 |
+
purposes of this definition, "control" means (a) the power, direct
|
81 |
+
or indirect, to cause the direction or management of such entity,
|
82 |
+
whether by contract or otherwise, or (b) ownership of more than
|
83 |
+
fifty percent (50%) of the outstanding shares or beneficial
|
84 |
+
ownership of such entity.
|
85 |
+
|
86 |
+
2. License Grants and Conditions
|
87 |
+
--------------------------------
|
88 |
+
|
89 |
+
2.1. Grants
|
90 |
+
|
91 |
+
Each Contributor hereby grants You a world-wide, royalty-free,
|
92 |
+
non-exclusive license:
|
93 |
+
|
94 |
+
(a) under intellectual property rights (other than patent or trademark)
|
95 |
+
Licensable by such Contributor to use, reproduce, make available,
|
96 |
+
modify, display, perform, distribute, and otherwise exploit its
|
97 |
+
Contributions, either on an unmodified basis, with Modifications, or
|
98 |
+
as part of a Larger Work; and
|
99 |
+
|
100 |
+
(b) under Patent Claims of such Contributor to make, use, sell, offer
|
101 |
+
for sale, have made, import, and otherwise transfer either its
|
102 |
+
Contributions or its Contributor Version.
|
103 |
+
|
104 |
+
2.2. Effective Date
|
105 |
+
|
106 |
+
The licenses granted in Section 2.1 with respect to any Contribution
|
107 |
+
become effective for each Contribution on the date the Contributor first
|
108 |
+
distributes such Contribution.
|
109 |
+
|
110 |
+
2.3. Limitations on Grant Scope
|
111 |
+
|
112 |
+
The licenses granted in this Section 2 are the only rights granted under
|
113 |
+
this License. No additional rights or licenses will be implied from the
|
114 |
+
distribution or licensing of Covered Software under this License.
|
115 |
+
Notwithstanding Section 2.1(b) above, no patent license is granted by a
|
116 |
+
Contributor:
|
117 |
+
|
118 |
+
(a) for any code that a Contributor has removed from Covered Software;
|
119 |
+
or
|
120 |
+
|
121 |
+
(b) for infringements caused by: (i) Your and any other third party's
|
122 |
+
modifications of Covered Software, or (ii) the combination of its
|
123 |
+
Contributions with other software (except as part of its Contributor
|
124 |
+
Version); or
|
125 |
+
|
126 |
+
(c) under Patent Claims infringed by Covered Software in the absence of
|
127 |
+
its Contributions.
|
128 |
+
|
129 |
+
This License does not grant any rights in the trademarks, service marks,
|
130 |
+
or logos of any Contributor (except as may be necessary to comply with
|
131 |
+
the notice requirements in Section 3.4).
|
132 |
+
|
133 |
+
2.4. Subsequent Licenses
|
134 |
+
|
135 |
+
No Contributor makes additional grants as a result of Your choice to
|
136 |
+
distribute the Covered Software under a subsequent version of this
|
137 |
+
License (see Section 10.2) or under the terms of a Secondary License (if
|
138 |
+
permitted under the terms of Section 3.3).
|
139 |
+
|
140 |
+
2.5. Representation
|
141 |
+
|
142 |
+
Each Contributor represents that the Contributor believes its
|
143 |
+
Contributions are its original creation(s) or it has sufficient rights
|
144 |
+
to grant the rights to its Contributions conveyed by this License.
|
145 |
+
|
146 |
+
2.6. Fair Use
|
147 |
+
|
148 |
+
This License is not intended to limit any rights You have under
|
149 |
+
applicable copyright doctrines of fair use, fair dealing, or other
|
150 |
+
equivalents.
|
151 |
+
|
152 |
+
2.7. Conditions
|
153 |
+
|
154 |
+
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
|
155 |
+
in Section 2.1.
|
156 |
+
|
157 |
+
3. Responsibilities
|
158 |
+
-------------------
|
159 |
+
|
160 |
+
3.1. Distribution of Source Form
|
161 |
+
|
162 |
+
All distribution of Covered Software in Source Code Form, including any
|
163 |
+
Modifications that You create or to which You contribute, must be under
|
164 |
+
the terms of this License. You must inform recipients that the Source
|
165 |
+
Code Form of the Covered Software is governed by the terms of this
|
166 |
+
License, and how they can obtain a copy of this License. You may not
|
167 |
+
attempt to alter or restrict the recipients' rights in the Source Code
|
168 |
+
Form.
|
169 |
+
|
170 |
+
3.2. Distribution of Executable Form
|
171 |
+
|
172 |
+
If You distribute Covered Software in Executable Form then:
|
173 |
+
|
174 |
+
(a) such Covered Software must also be made available in Source Code
|
175 |
+
Form, as described in Section 3.1, and You must inform recipients of
|
176 |
+
the Executable Form how they can obtain a copy of such Source Code
|
177 |
+
Form by reasonable means in a timely manner, at a charge no more
|
178 |
+
than the cost of distribution to the recipient; and
|
179 |
+
|
180 |
+
(b) You may distribute such Executable Form under the terms of this
|
181 |
+
License, or sublicense it under different terms, provided that the
|
182 |
+
license for the Executable Form does not attempt to limit or alter
|
183 |
+
the recipients' rights in the Source Code Form under this License.
|
184 |
+
|
185 |
+
3.3. Distribution of a Larger Work
|
186 |
+
|
187 |
+
You may create and distribute a Larger Work under terms of Your choice,
|
188 |
+
provided that You also comply with the requirements of this License for
|
189 |
+
the Covered Software. If the Larger Work is a combination of Covered
|
190 |
+
Software with a work governed by one or more Secondary Licenses, and the
|
191 |
+
Covered Software is not Incompatible With Secondary Licenses, this
|
192 |
+
License permits You to additionally distribute such Covered Software
|
193 |
+
under the terms of such Secondary License(s), so that the recipient of
|
194 |
+
the Larger Work may, at their option, further distribute the Covered
|
195 |
+
Software under the terms of either this License or such Secondary
|
196 |
+
License(s).
|
197 |
+
|
198 |
+
3.4. Notices
|
199 |
+
|
200 |
+
You may not remove or alter the substance of any license notices
|
201 |
+
(including copyright notices, patent notices, disclaimers of warranty,
|
202 |
+
or limitations of liability) contained within the Source Code Form of
|
203 |
+
the Covered Software, except that You may alter any license notices to
|
204 |
+
the extent required to remedy known factual inaccuracies.
|
205 |
+
|
206 |
+
3.5. Application of Additional Terms
|
207 |
+
|
208 |
+
You may choose to offer, and to charge a fee for, warranty, support,
|
209 |
+
indemnity or liability obligations to one or more recipients of Covered
|
210 |
+
Software. However, You may do so only on Your own behalf, and not on
|
211 |
+
behalf of any Contributor. You must make it absolutely clear that any
|
212 |
+
such warranty, support, indemnity, or liability obligation is offered by
|
213 |
+
You alone, and You hereby agree to indemnify every Contributor for any
|
214 |
+
liability incurred by such Contributor as a result of warranty, support,
|
215 |
+
indemnity or liability terms You offer. You may include additional
|
216 |
+
disclaimers of warranty and limitations of liability specific to any
|
217 |
+
jurisdiction.
|
218 |
+
|
219 |
+
4. Inability to Comply Due to Statute or Regulation
|
220 |
+
---------------------------------------------------
|
221 |
+
|
222 |
+
If it is impossible for You to comply with any of the terms of this
|
223 |
+
License with respect to some or all of the Covered Software due to
|
224 |
+
statute, judicial order, or regulation then You must: (a) comply with
|
225 |
+
the terms of this License to the maximum extent possible; and (b)
|
226 |
+
describe the limitations and the code they affect. Such description must
|
227 |
+
be placed in a text file included with all distributions of the Covered
|
228 |
+
Software under this License. Except to the extent prohibited by statute
|
229 |
+
or regulation, such description must be sufficiently detailed for a
|
230 |
+
recipient of ordinary skill to be able to understand it.
|
231 |
+
|
232 |
+
5. Termination
|
233 |
+
--------------
|
234 |
+
|
235 |
+
5.1. The rights granted under this License will terminate automatically
|
236 |
+
if You fail to comply with any of its terms. However, if You become
|
237 |
+
compliant, then the rights granted under this License from a particular
|
238 |
+
Contributor are reinstated (a) provisionally, unless and until such
|
239 |
+
Contributor explicitly and finally terminates Your grants, and (b) on an
|
240 |
+
ongoing basis, if such Contributor fails to notify You of the
|
241 |
+
non-compliance by some reasonable means prior to 60 days after You have
|
242 |
+
come back into compliance. Moreover, Your grants from a particular
|
243 |
+
Contributor are reinstated on an ongoing basis if such Contributor
|
244 |
+
notifies You of the non-compliance by some reasonable means, this is the
|
245 |
+
first time You have received notice of non-compliance with this License
|
246 |
+
from such Contributor, and You become compliant prior to 30 days after
|
247 |
+
Your receipt of the notice.
|
248 |
+
|
249 |
+
5.2. If You initiate litigation against any entity by asserting a patent
|
250 |
+
infringement claim (excluding declaratory judgment actions,
|
251 |
+
counter-claims, and cross-claims) alleging that a Contributor Version
|
252 |
+
directly or indirectly infringes any patent, then the rights granted to
|
253 |
+
You by any and all Contributors for the Covered Software under Section
|
254 |
+
2.1 of this License shall terminate.
|
255 |
+
|
256 |
+
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
|
257 |
+
end user license agreements (excluding distributors and resellers) which
|
258 |
+
have been validly granted by You or Your distributors under this License
|
259 |
+
prior to termination shall survive termination.
|
260 |
+
|
261 |
+
************************************************************************
|
262 |
+
* *
|
263 |
+
* 6. Disclaimer of Warranty *
|
264 |
+
* ------------------------- *
|
265 |
+
* *
|
266 |
+
* Covered Software is provided under this License on an "as is" *
|
267 |
+
* basis, without warranty of any kind, either expressed, implied, or *
|
268 |
+
* statutory, including, without limitation, warranties that the *
|
269 |
+
* Covered Software is free of defects, merchantable, fit for a *
|
270 |
+
* particular purpose or non-infringing. The entire risk as to the *
|
271 |
+
* quality and performance of the Covered Software is with You. *
|
272 |
+
* Should any Covered Software prove defective in any respect, You *
|
273 |
+
* (not any Contributor) assume the cost of any necessary servicing, *
|
274 |
+
* repair, or correction. This disclaimer of warranty constitutes an *
|
275 |
+
* essential part of this License. No use of any Covered Software is *
|
276 |
+
* authorized under this License except under this disclaimer. *
|
277 |
+
* *
|
278 |
+
************************************************************************
|
279 |
+
|
280 |
+
************************************************************************
|
281 |
+
* *
|
282 |
+
* 7. Limitation of Liability *
|
283 |
+
* -------------------------- *
|
284 |
+
* *
|
285 |
+
* Under no circumstances and under no legal theory, whether tort *
|
286 |
+
* (including negligence), contract, or otherwise, shall any *
|
287 |
+
* Contributor, or anyone who distributes Covered Software as *
|
288 |
+
* permitted above, be liable to You for any direct, indirect, *
|
289 |
+
* special, incidental, or consequential damages of any character *
|
290 |
+
* including, without limitation, damages for lost profits, loss of *
|
291 |
+
* goodwill, work stoppage, computer failure or malfunction, or any *
|
292 |
+
* and all other commercial damages or losses, even if such party *
|
293 |
+
* shall have been informed of the possibility of such damages. This *
|
294 |
+
* limitation of liability shall not apply to liability for death or *
|
295 |
+
* personal injury resulting from such party's negligence to the *
|
296 |
+
* extent applicable law prohibits such limitation. Some *
|
297 |
+
* jurisdictions do not allow the exclusion or limitation of *
|
298 |
+
* incidental or consequential damages, so this exclusion and *
|
299 |
+
* limitation may not apply to You. *
|
300 |
+
* *
|
301 |
+
************************************************************************
|
302 |
+
|
303 |
+
8. Litigation
|
304 |
+
-------------
|
305 |
+
|
306 |
+
Any litigation relating to this License may be brought only in the
|
307 |
+
courts of a jurisdiction where the defendant maintains its principal
|
308 |
+
place of business and such litigation shall be governed by laws of that
|
309 |
+
jurisdiction, without reference to its conflict-of-law provisions.
|
310 |
+
Nothing in this Section shall prevent a party's ability to bring
|
311 |
+
cross-claims or counter-claims.
|
312 |
+
|
313 |
+
9. Miscellaneous
|
314 |
+
----------------
|
315 |
+
|
316 |
+
This License represents the complete agreement concerning the subject
|
317 |
+
matter hereof. If any provision of this License is held to be
|
318 |
+
unenforceable, such provision shall be reformed only to the extent
|
319 |
+
necessary to make it enforceable. Any law or regulation which provides
|
320 |
+
that the language of a contract shall be construed against the drafter
|
321 |
+
shall not be used to construe this License against a Contributor.
|
322 |
+
|
323 |
+
10. Versions of the License
|
324 |
+
---------------------------
|
325 |
+
|
326 |
+
10.1. New Versions
|
327 |
+
|
328 |
+
Mozilla Foundation is the license steward. Except as provided in Section
|
329 |
+
10.3, no one other than the license steward has the right to modify or
|
330 |
+
publish new versions of this License. Each version will be given a
|
331 |
+
distinguishing version number.
|
332 |
+
|
333 |
+
10.2. Effect of New Versions
|
334 |
+
|
335 |
+
You may distribute the Covered Software under the terms of the version
|
336 |
+
of the License under which You originally received the Covered Software,
|
337 |
+
or under the terms of any subsequent version published by the license
|
338 |
+
steward.
|
339 |
+
|
340 |
+
10.3. Modified Versions
|
341 |
+
|
342 |
+
If you create software not governed by this License, and you want to
|
343 |
+
create a new license for such software, you may create and use a
|
344 |
+
modified version of this License if you rename the license and remove
|
345 |
+
any references to the name of the license steward (except to note that
|
346 |
+
such modified license differs from this License).
|
347 |
+
|
348 |
+
10.4. Distributing Source Code Form that is Incompatible With Secondary
|
349 |
+
Licenses
|
350 |
+
|
351 |
+
If You choose to distribute Source Code Form that is Incompatible With
|
352 |
+
Secondary Licenses under the terms of this version of the License, the
|
353 |
+
notice described in Exhibit B of this License must be attached.
|
354 |
+
|
355 |
+
Exhibit A - Source Code Form License Notice
|
356 |
+
-------------------------------------------
|
357 |
+
|
358 |
+
This Source Code Form is subject to the terms of the Mozilla Public
|
359 |
+
License, v. 2.0. If a copy of the MPL was not distributed with this
|
360 |
+
file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
361 |
+
|
362 |
+
If it is not possible or desirable to put the notice in a particular
|
363 |
+
file, then You may include the notice in a location (such as a LICENSE
|
364 |
+
file in a relevant directory) where a recipient would be likely to look
|
365 |
+
for such a notice.
|
366 |
+
|
367 |
+
You may add additional accurate notices of copyright ownership.
|
368 |
+
|
369 |
+
Exhibit B - "Incompatible With Secondary Licenses" Notice
|
370 |
+
---------------------------------------------------------
|
371 |
+
|
372 |
+
This Source Code Form is "Incompatible With Secondary Licenses", as
|
373 |
+
defined by the Mozilla Public License, v. 2.0.
|
README copy.md
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: SVGRender
|
3 |
+
emoji: 💻
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.20.1
|
8 |
+
python_version: 3.10.12
|
9 |
+
app_file: app.py
|
10 |
+
pinned: false
|
11 |
+
license: apache-2.0
|
12 |
+
---
|
13 |
+
|
14 |
+
<h1 id="ptsvg" align="center">Pytorch-SVGRender</h1>
|
15 |
+
|
16 |
+
<p align="center">
|
17 |
+
<a href="https://www.python.org/"><img src="https://img.shields.io/badge/python-3.10-or?logo=python" alt="pyhton"></a>
|
18 |
+
<a href="http://mozilla.org/MPL/2.0/"><img src="https://img.shields.io/badge/license-MPL2.0-orange" alt="license"></a>
|
19 |
+
<a href="https://ximinng.github.io/PyTorch-SVGRender-project/"><img src="https://img.shields.io/badge/website-Gitpage-yellow" alt="website"></a>
|
20 |
+
<a href="https://pytorch-svgrender.readthedocs.io/en/latest/index.html"><img src="https://img.shields.io/badge/docs-readthedocs-purple" alt="docs"></a>
|
21 |
+
</p>
|
22 |
+
|
23 |
+
<div align="center">
|
24 |
+
<img src="./assets/logo.png" style="width: 350px; height: 300px;" alt="Pytorch-SVGRender">
|
25 |
+
<p><strong>Pytorch-SVGRender: </strong>The go-to library for differentiable rendering methods for SVG generation.</p>
|
26 |
+
</div>
|
27 |
+
<p align="center">
|
28 |
+
<a href="#recent-updates">Updates</a> •
|
29 |
+
<a href="#table-of-contents">Table of Contents</a> •
|
30 |
+
<a href="#installation">Installation</a> •
|
31 |
+
<a href="#quickstart">Quickstart</a> •
|
32 |
+
<a href="#faq">FAQ</a> •
|
33 |
+
<a href="#todo">TODO</a> •
|
34 |
+
<a href="#acknowledgement">Acknowledgment</a> •
|
35 |
+
<a href="#citation">Citation</a> •
|
36 |
+
<a href="#licence">Licence</a>
|
37 |
+
</p>
|
38 |
+
|
39 |
+
<h2 align="center">Recent Updates</h2>
|
40 |
+
|
41 |
+
- [12/2023] 🔥 We open-sourced Pytorch-SVGRender V1.0.
|
42 |
+
|
43 |
+
<h2 align="center">Table of Contents</h2>
|
44 |
+
<p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
|
45 |
+
|
46 |
+
### 1. Image Vectorization
|
47 |
+
|
48 |
+
- DiffVG: Differentiable Vector Graphics Rasterization for Editing and Learning (`SIGGRAPH 2020`)
|
49 |
+
|
50 |
+
[[Project]](https://people.csail.mit.edu/tzumao/diffvg/) [[Paper]](https://cseweb.ucsd.edu/~tzli/diffvg/diffvg.pdf) [[Code]](https://github.com/BachiLi/diffvg)
|
51 |
+
|
52 |
+
DiffVG is a differentiable rasterizer for 2D vector graphics. **This repository is heavily based on DiffVG.**
|
53 |
+
|
54 |
+
- LIVE: Towards Layer-wise Image Vectorization (`CVPR 2022`)
|
55 |
+
|
56 |
+
[[Project]](https://ma-xu.github.io/LIVE/) [[Paper]](https://ma-xu.github.io/LIVE/index_files/CVPR22_LIVE_main.pdf) [[Code]](https://github.com/Picsart-AI-Research/LIVE-Layerwise-Image-Vectorization)
|
57 |
+
|
58 |
+
- CLIPasso: Semantically-Aware Object Sketching (`SIGGRAPH 2022`)
|
59 |
+
|
60 |
+
[[Project]](https://clipasso.github.io/clipasso/) [[Paper]](https://arxiv.org/abs/2202.05822) [[Code]](https://github.com/yael-vinker/CLIPasso)
|
61 |
+
|
62 |
+
- CLIPascene: Scene Sketching with Different Types and Levels of Abstraction (`ICCV 2023`)
|
63 |
+
|
64 |
+
[[Project]](https://clipascene.github.io/CLIPascene/) [[Paper]](https://arxiv.org/abs/2211.17256) [[Code]](https://github.com/yael-vinker/SceneSketch)
|
65 |
+
|
66 |
+
### 2. Text-to-SVG Synthesis
|
67 |
+
|
68 |
+
- CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders (`NIPS 2022`)
|
69 |
+
|
70 |
+
[[Paper]](https://arxiv.org/abs/2106.14843) [[Code]](https://github.com/kvfrans/clipdraw)
|
71 |
+
|
72 |
+
- StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Synthesis
|
73 |
+
|
74 |
+
[[Live]](https://slideslive.com/38970834/styleclipdraw-coupling-content-and-style-in-texttodrawing-synthesis?ref=account-folder-92044-folders) [[Paper]](https://arxiv.org/abs/2202.12362) [[Code]](https://github.com/pschaldenbrand/StyleCLIPDraw)
|
75 |
+
|
76 |
+
- CLIPFont: Texture Guided Vector WordArt Generation (`BMVC 2022`)
|
77 |
+
|
78 |
+
[[Paper]](https://bmvc2022.mpi-inf.mpg.de/0543.pdf) [[Code]](https://github.com/songyiren98/CLIPFont)
|
79 |
+
|
80 |
+
- VectorFusion: Text-to-SVG by Abstracting Pixel-Based Diffusion Models (`CVPR 2023`)
|
81 |
+
|
82 |
+
[[Project]](https://vectorfusion.github.io/) [[Paper]](https://openaccess.thecvf.com/content/CVPR2023/papers/Jain_VectorFusion_Text-to-SVG_by_Abstracting_Pixel-Based_Diffusion_Models_CVPR_2023_paper.pdf)
|
83 |
+
|
84 |
+
- DiffSketcher: Text Guided Vector Sketch Synthesis through Latent Diffusion Models (`NIPS 2023`)
|
85 |
+
|
86 |
+
[[Project]](https://ximinng.github.io/DiffSketcher-project/) [[Live]](https://neurips.cc/virtual/2023/poster/72425) [[Paper]](https://arxiv.org/abs/2306.14685) [[Code]](https://github.com/ximinng/DiffSketcher)
|
87 |
+
|
88 |
+
- Word-As-Image for Semantic Typography (`SIGGRAPH 2023`)
|
89 |
+
|
90 |
+
[[Project]](https://wordasimage.github.io/Word-As-Image-Page/) [[Paper]](https://arxiv.org/abs/2303.01818) [[Code]](https://github.com/Shiriluz/Word-As-Image)
|
91 |
+
|
92 |
+
- SVGDreamer: Text Guided SVG Generation with Diffusion Model (`CVPR 2024`)
|
93 |
+
|
94 |
+
[[Project]](https://ximinng.github.io/SVGDreamer-project/) [[Paper]](https://arxiv.org/abs/2312.16476) [[code]](https://github.com/ximinng/SVGDreamer)
|
95 |
+
|
96 |
+
<h2 align="center">Installation</h2>
|
97 |
+
|
98 |
+
You can follow the steps below to quickly get up and running with PyTorch-SVGRender.
|
99 |
+
These steps will let you run quick inference locally.
|
100 |
+
|
101 |
+
In the top level directory run,
|
102 |
+
|
103 |
+
```bash
|
104 |
+
sh script/install.sh
|
105 |
+
```
|
106 |
+
|
107 |
+
Note: Make sure that the script file has execution **permissions** (you can give them using `chmod +x script.sh`), and
|
108 |
+
then run the script.
|
109 |
+
|
110 |
+
For more information, please refer to
|
111 |
+
the [Install.md](https://github.com/ximinng/PyTorch-SVGRender/blob/main/Install.md).
|
112 |
+
|
113 |
+
<h2 align="center">Quickstart</h2>
|
114 |
+
<p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
|
115 |
+
|
116 |
+
**For more information, [read the docs](https://pytorch-svgrender.readthedocs.io/en/latest/index.html).**
|
117 |
+
|
118 |
+
### 1. Basic Usage
|
119 |
+
|
120 |
+
**DiffVG** vectorizes any raster images:
|
121 |
+
|
122 |
+
```shell
|
123 |
+
python svg_render.py x=diffvg target='./data/fallingwater.png'
|
124 |
+
# change 'num_paths' and 'num_iter' for better results
|
125 |
+
python svg_render.py x=diffvg target='./data/fallingwater.png' x.num_paths=512 x.num_iter=2000
|
126 |
+
```
|
127 |
+
|
128 |
+
**LIVE** vectorizes the raster emojis images (in original PNG format):
|
129 |
+
|
130 |
+
```shell
|
131 |
+
python svg_render.py x=live target='./data/simile.png'
|
132 |
+
# change 'num_paths' and 'schedule_each' for better results
|
133 |
+
python svg_render.py x=live target='./data/simile.png' x.num_paths=5 x.schedule_each=1
|
134 |
+
```
|
135 |
+
|
136 |
+
**CLIPasso** synthesizes vectorized sketches from images:
|
137 |
+
|
138 |
+
**note:** first download the U2Net model `sh script/download_u2net.sh`.
|
139 |
+
|
140 |
+
```shell
|
141 |
+
python svg_render.py x=clipasso target='./data/horse.png'
|
142 |
+
```
|
143 |
+
|
144 |
+
**CLIPascene** synthesizes vectorized sketches from images:
|
145 |
+
|
146 |
+
**note:** first download the U2Net model `sh script/download_u2net.sh`, and make sure the `./data/background` folder and
|
147 |
+
the `./data/scene` folder exist with target images.
|
148 |
+
|
149 |
+
```shell
|
150 |
+
python svg_render.py x=clipascene target='ballerina.png'
|
151 |
+
```
|
152 |
+
|
153 |
+
**CLIPDraw** synthesizes SVGs based on text prompts:
|
154 |
+
|
155 |
+
```shell
|
156 |
+
python svg_render.py x=clipdraw "prompt='a photo of a cat'"
|
157 |
+
```
|
158 |
+
|
159 |
+
**StyleCLIPDraw** synthesizes SVG based on a text prompt and a reference image:
|
160 |
+
|
161 |
+
```shell
|
162 |
+
python svg_render.py x=styleclipdraw "prompt='a photo of a cat'" target='./data/starry.png'
|
163 |
+
```
|
164 |
+
|
165 |
+
**CLIPFont** styles vector fonts according to text prompts:
|
166 |
+
|
167 |
+
```shell
|
168 |
+
python svg_render.py x=clipfont "prompt='Starry Night by Vincent van gogh'" target='./data/alphabet1.svg'
|
169 |
+
```
|
170 |
+
|
171 |
+
---
|
172 |
+
|
173 |
+
> Because the following methods rely on stable diffusion, add `diffuser.download=True` to the command the **first time** you
|
174 |
+
run the script.
|
175 |
+
|
176 |
+
**SVGDreamer** generates various styles of SVG based on text prompts. It supports the use of six vector primitives,
|
177 |
+
including Iconography, Sketch, Pixel Art, Low-Poly, Painting, and Ink and Wash.
|
178 |
+
|
179 |
+
```shell
|
180 |
+
# primitive: iconography
|
181 |
+
## 1. German shepherd
|
182 |
+
python svg_render.py x=svgdreamer "prompt='A colorful German shepherd in vector art. tending on artstation.'" save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 result_path='./svgdreamer/GermanShepherd'
|
183 |
+
## 2. sydney opera house
|
184 |
+
python svg_render.py x=svgdreamer "prompt='Sydney opera house. oil painting. by Van Gogh'" save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.num_paths=512 result_path='./svgdreamer/Sydney'
|
185 |
+
# primitive: low-ploy
|
186 |
+
python svg_render.py x=svgdreamer "prompt='A picture of a bald eagle. low-ploy. polygon'" x.style='low-poly' save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.guidance.num_iter=1000 result_path='./svgdreamer/eagle'
|
187 |
+
# primitive: pixel-art
|
188 |
+
python svg_render.py x=svgdreamer "prompt='Darth vader with lightsaber. ultrarealistic. pixelart. trending on artstation.'" x.style='pixelart' save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.guidance.num_iter=1000 result_path='./svgdreamer/DarthVader'
|
189 |
+
# primitive: painting
|
190 |
+
python svg_render.py x=svgdreamer "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" x.style='painting' save_step=50 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.guidance.t_schedule='randint' x.num_paths=1500 result_path='./svgdreamer/VanGogh_portrait'
|
191 |
+
# primitive: sketch
|
192 |
+
python svg_render.py x=svgdreamer "prompt='A free-hand drawing of A speeding Lamborghini. black and white drawing.'" x.style='sketch' save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.guidance.t_schedule='randint' x.num_paths=128 result_path='./svgdreamer/Lamborghini'
|
193 |
+
# primitive: ink and wash
|
194 |
+
python svg_render.py x=svgdreamer "prompt='Big Wild Goose Pagoda. ink style. Minimalist abstract art grayscale watercolor.'" x.style='ink' save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.guidance.t_schedule='randint' x.num_paths=128 x.width=6 result_path='./svgdreamer/BigWildGoosePagoda'
|
195 |
+
```
|
196 |
+
|
197 |
+
**VectorFusion** synthesizes SVGs in various styles based on text prompts:
|
198 |
+
|
199 |
+
```shell
|
200 |
+
# Iconography style
|
201 |
+
python svg_render.py x=vectorfusion x.style='iconography' "prompt='a panda rowing a boat in a pond. minimal flat 2d vector icon. lineal color. trending on artstation.'"
|
202 |
+
# PixelArt style
|
203 |
+
python svg_render.py x=vectorfusion x.style='pixelart' "prompt='a panda rowing a boat in a pond. pixel art. trending on artstation.'"
|
204 |
+
# Sketch style
|
205 |
+
python svg_render.py x=vectorfusion x.style='sketch' "prompt='a panda rowing a boat in a pond. minimal 2d line drawing. trending on artstation.'"
|
206 |
+
```
|
207 |
+
|
208 |
+
Following SVGDreamer, we've added three additional styles (`Paining`, `Ink and Wash` and `low-ploy`) to VectorFusion.
|
209 |
+
|
210 |
+
**DiffSketcher** synthesizes vector sketches based on text prompts:
|
211 |
+
|
212 |
+
```shell
|
213 |
+
# DiffSketcher
|
214 |
+
python svg_render.py x=diffsketcher "prompt='a photo of Sydney opera house'" x.token_ind=5 seed=8019
|
215 |
+
# DiffSketcher, variable stroke width
|
216 |
+
python svg_render.py x=diffsketcher "prompt='a photo of Sydney opera house'" x.token_ind=5 x.optim_width=True seed=8019
|
217 |
+
# DiffSketcher RGBA version
|
218 |
+
python svg_render.py x=diffsketcher "prompt='a photo of Sydney opera house'" x.token_ind=5 x.optim_width=True x.optim_rgba=True x.optim_opacity=False seed=8019
|
219 |
+
# DiffSketcher + style transfer
|
220 |
+
python svg_render.py x=stylediffsketcher "prompt='The French Revolution. highly detailed. 8k. ornate. intricate. cinematic. dehazed. atmospheric. oil painting. by Van Gogh'" x.token_ind=4 x.num_paths=2000 target='./data/starry.png' seed=876809
|
221 |
+
```
|
222 |
+
|
223 |
+
**Word-As-Image** follow a text prompt to style a letter in a word:
|
224 |
+
|
225 |
+
```shell
|
226 |
+
# Inject the meaning of the word bunny into the 'Y' in the word 'BUNNY'
|
227 |
+
python svg_render.py x=wordasimage x.word='BUNNY' prompt='BUNNY' x.optim_letter='Y'
|
228 |
+
```
|
229 |
+
|
230 |
+
### 2. SDS Loss based Approach
|
231 |
+
|
232 |
+
This is achieved by utilizing a pretrained text-to-image diffusion model as a strong image prior to supervise the
|
233 |
+
training of the PyDiffVG, enabling rendering SVG alignment with the text. This remarkable capability is fundamentally
|
234 |
+
grounded in the use of Score Distillation Sampling (SDS). SDS acts as the core mechanism that lifts raster images from
|
235 |
+
diffusion models to the SVG domain, enabling the training of SVG parameters without images.
|
236 |
+
This includes the methods VectorFusion, DiffSketcher and SVGDreamer.
|
237 |
+
|
238 |
+
We only compare the performance of SDS, which means that no other loss is used:
|
239 |
+
|
240 |
+
```shell
|
241 |
+
# SDS loss
|
242 |
+
python svg_render.py x=vectorfusion "prompt='a panda rowing a boat in a pond. minimal flat 2d vector icon. lineal color. trending on artstation.'"
|
243 |
+
# Input Augmentation SDS loss (LSDS loss)
|
244 |
+
python svg_render.py x=vectorfusion x.style='sketch' "prompt='an elephant. minimal 2d line drawing. trending on artstation.'"
|
245 |
+
# Input Augmentation SDS loss (ASDS loss)
|
246 |
+
python svg_render.py x=diffsketcher "prompt='an elephant. minimal 2d line drawing. trending on artstation.'" x.token_ind=2 x.sds.grad_scale=1 x.sds.num_aug=4 x.clip.vis_loss=0 x.perceptual.coeff=0 x.opacity_delta=0.3
|
247 |
+
# Vectorized Particle-based Score Distillation (VPSD loss)
|
248 |
+
python svg_render.py x=svgdreamer "prompt='a panda rowing a boat in a pond. minimal flat 2d vector icon. lineal color. trending on artstation.'" save_step=60 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2
|
249 |
+
```
|
250 |
+
|
251 |
+
<h2 align="center">FAQ</h2>
|
252 |
+
<p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
|
253 |
+
|
254 |
+
- Q: Where can I get more scripts and visualizations?
|
255 |
+
- A: check the [pytorch-svgrender.readthedocs.io](https://pytorch-svgrender.readthedocs.io/en/latest/index.html).
|
256 |
+
|
257 |
+
- Q: An error says HuggingFace cannot find the model in the disk cache.
|
258 |
+
- A: Add *`diffuser.download=True`* to the command for downloading model checkpoints the **first time** you run the script.
|
259 |
+
|
260 |
+
<h2 align="center">TODO</h2>
|
261 |
+
<p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
|
262 |
+
|
263 |
+
- [x] integrated SVGDreamer.
|
264 |
+
|
265 |
+
<h2 align="center">Acknowledgement</h2>
|
266 |
+
<p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
|
267 |
+
|
268 |
+
The project is built based on the following repository:
|
269 |
+
|
270 |
+
[BachiLi/diffvg](https://github.com/BachiLi/diffvg),
|
271 |
+
[huggingface/diffusers](https://github.com/huggingface/diffusers),
|
272 |
+
[threestudio-project/threestudio](https://github.com/threestudio-project/threestudio),
|
273 |
+
[yael-vinker/CLIPasso](https://github.com/yael-vinker/CLIPasso),
|
274 |
+
[ximinng/DiffSketcher](https://github.com/ximinng/DiffSketcher),
|
275 |
+
[THUDM/ImageReward](https://github.com/THUDM/ImageReward),
|
276 |
+
[advimman/lama](https://github.com/advimman/lama)
|
277 |
+
|
278 |
+
We gratefully thank the authors for their wonderful works.
|
279 |
+
|
280 |
+
<h2 align="center">Citation</h2>
|
281 |
+
<p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
|
282 |
+
|
283 |
+
If you use this code for your research, please cite the following work:
|
284 |
+
|
285 |
+
```
|
286 |
+
@article{xing2023svgdreamer,
|
287 |
+
title={SVGDreamer: Text Guided SVG Generation with Diffusion Model},
|
288 |
+
author={Xing, Ximing and Zhou, Haitao and Wang, Chuang and Zhang, Jing and Xu, Dong and Yu, Qian},
|
289 |
+
journal={arXiv preprint arXiv:2312.16476},
|
290 |
+
year={2023}
|
291 |
+
}
|
292 |
+
@inproceedings{xing2023diffsketcher,
|
293 |
+
title={DiffSketcher: Text Guided Vector Sketch Synthesis through Latent Diffusion Models},
|
294 |
+
author={XiMing Xing and Chuang Wang and Haitao Zhou and Jing Zhang and Qian Yu and Dong Xu},
|
295 |
+
booktitle={Thirty-seventh Conference on Neural Information Processing Systems (NeurIPS)},
|
296 |
+
year={2023},
|
297 |
+
url={https://openreview.net/forum?id=CY1xatvEQj}
|
298 |
+
}
|
299 |
+
```
|
300 |
+
|
301 |
+
<h2 align="center">Licence</h2>
|
302 |
+
<p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
|
303 |
+
|
304 |
+
This work is licensed under a **Mozilla Public License Version 2.0**.
|
app.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
import sys
|
4 |
+
import tempfile
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
sys.path.append('/home/user/app/code')
|
10 |
+
|
11 |
+
# set up diffvg
|
12 |
+
# os.system('git clone https://github.com/BachiLi/diffvg.git')
|
13 |
+
os.chdir('diffvg')
|
14 |
+
os.system('git submodule update --init --recursive')
|
15 |
+
os.system('python setup.py install --user')
|
16 |
+
sys.path.append("/home/user/.local/lib/python3.10/site-packages/diffvg-0.0.1-py3.10-linux-x86_64.egg")
|
17 |
+
print("diffvg installed.")
|
18 |
+
os.chdir('/home/user/app')
|
19 |
+
|
20 |
+
|
21 |
+
def process_images(prompt, num_paths, token_index, seed, optimize_width=False, optimize_color=False):
|
22 |
+
with tempfile.TemporaryDirectory() as tmpdirname:
|
23 |
+
command = [
|
24 |
+
"python", "svg_render.py",
|
25 |
+
"x=diffsketcher",
|
26 |
+
f"prompt={prompt}",
|
27 |
+
f"x.num_paths={num_paths}",
|
28 |
+
f"x.token_ind={token_index}",
|
29 |
+
f"seed={seed}",
|
30 |
+
f"x.optim_width={optimize_width}",
|
31 |
+
f"x.optim_rgba={optimize_color}",
|
32 |
+
"x.optim_opacity=False",
|
33 |
+
]
|
34 |
+
result = subprocess.run(command, check=True)
|
35 |
+
if result.returncode == 0:
|
36 |
+
output_image = Image.open(os.path.join(tmpdirname, "final_render.png"))
|
37 |
+
return output_image
|
38 |
+
|
39 |
+
|
40 |
+
with gr.Blocks() as demo:
|
41 |
+
gr.Markdown("# DiffSketcher")
|
42 |
+
gr.Markdown("DiffSketcher synthesizes **vector sketches** based on **text prompts**.")
|
43 |
+
li = [
|
44 |
+
"https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/cat.svg",
|
45 |
+
"https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/rose.svg",
|
46 |
+
"https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/elephant.svg",
|
47 |
+
"https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/elephant_silhouette.svg",
|
48 |
+
"https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/horse_width.svg",
|
49 |
+
"https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/horse_rgba.svg",
|
50 |
+
"https://ximinng.github.io/PyTorch-SVGRender-project/assets/diffsketcher/Sydney_opera.svg",
|
51 |
+
"https://ximinng.github.io/PyTorch-SVGRender-project/assets/diffsketcher/Sydney_opera_width.svg",
|
52 |
+
"https://ximinng.github.io/PyTorch-SVGRender-project/assets/diffsketcher/Sydney_opera_width_color.svg",
|
53 |
+
]
|
54 |
+
gr.Gallery(li, columns=6)
|
55 |
+
with gr.Row():
|
56 |
+
with gr.Column():
|
57 |
+
text = gr.Textbox(label="prompt")
|
58 |
+
num_paths = gr.Slider(label="path number", value=96, minimum=1, maximum=500, step=1)
|
59 |
+
token_index = gr.Textbox(label="token_index", info="CLIP embedding token index. Starting from 1.")
|
60 |
+
seed = gr.Slider(0, 10000, label="random seed", value=8019)
|
61 |
+
with gr.Accordion("Selectable Inputs"):
|
62 |
+
optimize_width = gr.Checkbox(label="optimize stroke width")
|
63 |
+
optimize_color = gr.Checkbox(label="optimize stroke color")
|
64 |
+
btn = gr.Button("Synthesize")
|
65 |
+
with gr.Column():
|
66 |
+
output = gr.Image(label="output image", height=512)
|
67 |
+
btn.click(process_images,
|
68 |
+
inputs=[text, num_paths, token_index, seed, optimize_width, optimize_color],
|
69 |
+
outputs=[output])
|
70 |
+
gr.Markdown("## Examples")
|
71 |
+
gr.Markdown("Here are some config examples. Feel free to try your own prompts!")
|
72 |
+
gr.Examples(
|
73 |
+
inputs=[text, num_paths, token_index, seed, optimize_width, optimize_color],
|
74 |
+
outputs=[output],
|
75 |
+
fn=process_images,
|
76 |
+
examples=[
|
77 |
+
["A photo of Sydney opera house.", 96, 5, 8019, False, False],
|
78 |
+
["A photo of Sydney opera house.", 96, 5, 8019, True, False],
|
79 |
+
["A photo of Sydney opera house.", 128, 5, 8019, True, True],
|
80 |
+
],
|
81 |
+
)
|
82 |
+
|
83 |
+
demo.launch()
|
assets/fonts/Bell-MT.ttf
ADDED
Binary file (84.8 kB). View file
|
|
assets/fonts/DeliusUnicase-Regular.ttf
ADDED
Binary file (31.5 kB). View file
|
|
assets/fonts/HobeauxRococeaux-Sherman.ttf
ADDED
Binary file (117 kB). View file
|
|
assets/fonts/IndieFlower-Regular.ttf
ADDED
Binary file (55.4 kB). View file
|
|
assets/fonts/JosefinSans-Light.ttf
ADDED
Binary file (59.3 kB). View file
|
|
assets/fonts/KaushanScript-Regular.ttf
ADDED
Binary file (184 kB). View file
|
|
assets/fonts/LuckiestGuy-Regular.ttf
ADDED
Binary file (58.3 kB). View file
|
|
assets/fonts/Noteworthy-Bold.ttf
ADDED
Binary file (248 kB). View file
|
|
assets/fonts/Quicksand.ttf
ADDED
Binary file (124 kB). View file
|
|
assets/fonts/Saira-Regular.ttf
ADDED
Binary file (82.8 kB). View file
|
|
checkpoint/placeholder.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
**place model here**
|
conf/config.yaml
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#-----------------#
|
2 |
+
# Global Config #
|
3 |
+
#-----------------#
|
4 |
+
|
5 |
+
# optional args
|
6 |
+
target: ~
|
7 |
+
prompt: ~
|
8 |
+
neg_prompt: ~ # negative prompt
|
9 |
+
|
10 |
+
# Accelerate config
|
11 |
+
state:
|
12 |
+
cpu: False # use cpu
|
13 |
+
mprec: no # mixed precision, choices: 'no', 'fp16', 'bf16'
|
14 |
+
# wandb: False
|
15 |
+
# tensorboard: False
|
16 |
+
|
17 |
+
# Diffusers config
|
18 |
+
diffuser:
|
19 |
+
download: True # Set this variable to True the first time it runs
|
20 |
+
force_download: False
|
21 |
+
resume_download: False
|
22 |
+
|
23 |
+
# PyDiffVG config
|
24 |
+
diffvg:
|
25 |
+
print_timing: False
|
26 |
+
|
27 |
+
# reproduction
|
28 |
+
seed: 951222
|
29 |
+
# multi-run
|
30 |
+
multirun: False
|
31 |
+
srange: ~ # seed range, example: [100, 100]
|
32 |
+
|
33 |
+
# log
|
34 |
+
result_path: './workspace'
|
35 |
+
save_step: 10
|
36 |
+
eval_step: 10
|
37 |
+
|
38 |
+
# visual rendering process
|
39 |
+
mv: False # make video
|
40 |
+
framefreq: 5 # save the image interval
|
41 |
+
framerate: 24 # by adjusting the frame rate, you can control the playback speed of the output video
|
42 |
+
|
43 |
+
# hydra setting
|
44 |
+
hydra:
|
45 |
+
help:
|
46 |
+
# app name, override to match the name your app is known by
|
47 |
+
app_name: 'SVGRender'
|
48 |
+
run:
|
49 |
+
# output directory for normal runs
|
50 |
+
# warning: make sure that the L56-58 of '/libs/engine/model_state.py' and 'dir' are modified together
|
51 |
+
dir: ./${result_path}/${x.method}-${now:%Y-%m-%d-%H-%M}
|
52 |
+
|
53 |
+
# default settings
|
54 |
+
defaults:
|
55 |
+
- _self_
|
56 |
+
- x: ~
|
conf/x/clipascene.yaml
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: 'clipascene'
|
2 |
+
|
3 |
+
im_name: ""
|
4 |
+
image_size: 224
|
5 |
+
u2net_path: "./checkpoint/u2net/u2net.pth"
|
6 |
+
|
7 |
+
background_layer: 2 # 2, 8, 11
|
8 |
+
background_div: 0.35 # 0.35, 0.5, 0.85
|
9 |
+
background_num_iter: 1501
|
10 |
+
|
11 |
+
foreground_layer: 2 # 2, 8, 11
|
12 |
+
foreground_div: 0.4 # 0.4, 0.5, 0.9
|
13 |
+
foreground_num_iter: 600 # 1000 if foreground_layer >= 8 else 600
|
14 |
+
|
15 |
+
# general
|
16 |
+
target: null
|
17 |
+
output_dir: null
|
18 |
+
path_svg: "none"
|
19 |
+
mask_object: 0
|
20 |
+
resize_obj: 0
|
21 |
+
fix_scale: 0
|
22 |
+
display_logs: 0
|
23 |
+
display: 0
|
24 |
+
test_name: "test"
|
25 |
+
|
26 |
+
# training
|
27 |
+
num_iter: 2001
|
28 |
+
num_stages: 1
|
29 |
+
lr_scheduler: 0
|
30 |
+
lr: 0.0001
|
31 |
+
color_lr: 0.01
|
32 |
+
width_lr: 0.0001
|
33 |
+
color_vars_threshold: 0.0
|
34 |
+
batch_size: 1
|
35 |
+
save_step: 100
|
36 |
+
eval_step: 20
|
37 |
+
loss_mask: "none"
|
38 |
+
dilated_mask: 0
|
39 |
+
mask_cls: None
|
40 |
+
mask_attention: 0
|
41 |
+
|
42 |
+
# strokes params
|
43 |
+
num_paths: 64
|
44 |
+
width: 1.5
|
45 |
+
control_points_per_seg: 4
|
46 |
+
num_segments: 1
|
47 |
+
attention_init: 1
|
48 |
+
saliency_model: "clip"
|
49 |
+
saliency_clip_model: "ViT-B/32"
|
50 |
+
xdog_intersec: 1
|
51 |
+
mask_object_attention: 0
|
52 |
+
softmax_temp: 0.3
|
53 |
+
mlp_train: 1
|
54 |
+
width_optim: 0
|
55 |
+
mlp_width_weights_path: "none"
|
56 |
+
mlp_points_weights_path: "none"
|
57 |
+
switch_loss: 0
|
58 |
+
gumbel_temp: 0.2
|
59 |
+
width_loss_weight: 0
|
60 |
+
width_loss_type: "L1"
|
61 |
+
optimize_points: 1
|
62 |
+
load_points_opt_weights: 0
|
63 |
+
gradnorm: 0
|
64 |
+
width_weights_lst: ""
|
65 |
+
ratio_loss: 0
|
66 |
+
|
67 |
+
# loss
|
68 |
+
percep_loss: "none"
|
69 |
+
perceptual_weight: 0
|
70 |
+
train_with_clip: 0
|
71 |
+
clip_weight: 0
|
72 |
+
start_clip: 0
|
73 |
+
num_aug_clip: 4
|
74 |
+
include_target_in_aug: 0
|
75 |
+
augment_both: 1
|
76 |
+
augemntations: "affine"
|
77 |
+
noise_thresh: 0.5
|
78 |
+
aug_scale_min: 0.7
|
79 |
+
force_sparse: 0
|
80 |
+
clip_conv_loss: 1
|
81 |
+
clip_mask_loss: 0
|
82 |
+
clip_conv_loss_type: "L2"
|
83 |
+
clip_conv_layer_weights: "0,0,1.0,1.0,0"
|
84 |
+
clip_model_name: "ViT-B/32"
|
85 |
+
clip_fc_loss_weight: 0
|
86 |
+
clip_text_guide: 0
|
87 |
+
text_target: None
|
conf/x/clipasso.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: 'clipasso'
|
2 |
+
|
3 |
+
image_size: 224
|
4 |
+
mask_object: False
|
5 |
+
fix_scale: False
|
6 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
7 |
+
|
8 |
+
# train
|
9 |
+
num_iter: 2001
|
10 |
+
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
|
11 |
+
lr_schedule: False
|
12 |
+
lr: 1
|
13 |
+
color_lr: 0.01
|
14 |
+
color_vars_threshold: 0.0
|
15 |
+
|
16 |
+
# SVG path attr
|
17 |
+
num_paths: 24 # number of strokes
|
18 |
+
width: 1.5 # stroke width
|
19 |
+
control_points_per_seg: 4
|
20 |
+
num_segments: 1
|
21 |
+
attention_init: 1 # if True, use the attention heads of Dino model to set the location of the initial strokes
|
22 |
+
saliency_model: "clip"
|
23 |
+
saliency_clip_model: "ViT-B/32"
|
24 |
+
xdog_intersec: 1
|
25 |
+
mask_object_attention: 0
|
26 |
+
softmax_temp: 0.3
|
27 |
+
u2net_path: "./checkpoint/u2net/u2net.pth"
|
28 |
+
|
29 |
+
# loss
|
30 |
+
percep_loss: "none"
|
31 |
+
perceptual_weight: 0
|
32 |
+
train_with_clip: 0
|
33 |
+
clip_weight: 0
|
34 |
+
start_clip: 0
|
35 |
+
num_aug_clip: 4
|
36 |
+
include_target_in_aug: 0
|
37 |
+
augment_both: 0
|
38 |
+
augemntations: "affine" # can be any combination of: 'affine_noise_eraserchunks_eraser_press'
|
39 |
+
noise_thresh: 0.5
|
40 |
+
aug_scale_min: 0.7
|
41 |
+
force_sparse: 0 # if True, use L1 regularization on stroke's opacity to encourage small number of strokes
|
42 |
+
clip_conv_loss: 1
|
43 |
+
clip_conv_loss_type: "L2"
|
44 |
+
clip_conv_layer_weights: "0,0,1.0,1.0,0"
|
45 |
+
clip_model_name: "RN101"
|
46 |
+
clip_fc_loss_weight: 0.1
|
47 |
+
clip_text_guide: 0
|
48 |
+
text_target: None
|
conf/x/clipdraw.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: 'clipdraw'
|
2 |
+
|
3 |
+
image_size: 224 # canvas size
|
4 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
5 |
+
|
6 |
+
# train
|
7 |
+
num_iter: 1000
|
8 |
+
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
|
9 |
+
lr_schedule: True
|
10 |
+
lr: 1
|
11 |
+
width_lr: 0.1
|
12 |
+
color_lr: 0.01
|
13 |
+
|
14 |
+
# SVG path attr
|
15 |
+
num_paths: 512 # number of strokes
|
16 |
+
max_width: 50 # stroke width
|
17 |
+
black_stroke_color: False
|
18 |
+
|
19 |
+
# loss
|
20 |
+
num_aug: 4
|
conf/x/clipfont.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: 'clipfont'
|
2 |
+
|
3 |
+
# optimizer
|
4 |
+
lr_base:
|
5 |
+
point: 0.1
|
6 |
+
color: 0.01
|
7 |
+
lr_decay_rate: 0.1
|
8 |
+
decay_steps: [ 1000, 1500 ]
|
9 |
+
lr_schedule: False
|
10 |
+
|
11 |
+
# train
|
12 |
+
num_iter: 200
|
13 |
+
batch_size: 1
|
14 |
+
font:
|
15 |
+
reinit: False
|
16 |
+
reinit_color: 'randn' # 'randn', 'randn_all', 'green' et al
|
17 |
+
|
18 |
+
# loss
|
19 |
+
clip:
|
20 |
+
model_name: "ViT-B/32" # RN101, 'ViT-B/32', ViT-L/14
|
21 |
+
thresh: 0.0
|
22 |
+
num_crops: 128
|
23 |
+
crop_size: 230
|
24 |
+
lam_patch: 150
|
25 |
+
lam_dir: 30
|
26 |
+
lam_lpips: 0
|
27 |
+
lam_l2: 0
|
conf/x/diffsketcher.yaml
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: 'diffsketcher'
|
2 |
+
|
3 |
+
image_size: 224 # canvas size
|
4 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
5 |
+
mask_object: False # if the target image contains background, it's better to mask it out
|
6 |
+
fix_scale: False # if the target image is not squared, it is recommended to fix the scale
|
7 |
+
|
8 |
+
# train
|
9 |
+
num_iter: 2000
|
10 |
+
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
|
11 |
+
lr_schedule: False
|
12 |
+
lr_decay_rate: 0.1
|
13 |
+
decay_steps: [ 1000, 1500 ]
|
14 |
+
lr: 1
|
15 |
+
color_lr: 0.01
|
16 |
+
color_vars_threshold: 0.0 # uncomment the code
|
17 |
+
width_lr: 0.1
|
18 |
+
max_width: 50 # stroke width
|
19 |
+
|
20 |
+
# stroke attrs
|
21 |
+
num_paths: 128 # number of strokes
|
22 |
+
width: 1.5 # stroke width
|
23 |
+
control_points_per_seg: 4
|
24 |
+
num_segments: 1
|
25 |
+
optim_opacity: True # if True, the stroke opacity is optimized
|
26 |
+
optim_width: False # if True, the stroke width is optimized
|
27 |
+
optim_rgba: False # if True, the stroke RGBA is optimized
|
28 |
+
opacity_delta: 0 # stroke pruning
|
29 |
+
|
30 |
+
# init strokes
|
31 |
+
attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
|
32 |
+
xdog_intersec: True # initialize along the edge, mix XDoG and attn up
|
33 |
+
softmax_temp: 0.5
|
34 |
+
cross_attn_res: 16
|
35 |
+
self_attn_res: 32
|
36 |
+
max_com: 20
|
37 |
+
mean_comp: False
|
38 |
+
comp_idx: 0
|
39 |
+
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
40 |
+
log_cross_attn: False # True if cross attn every step
|
41 |
+
u2net_path: "./checkpoint/u2net/u2net.pth"
|
42 |
+
|
43 |
+
# ldm
|
44 |
+
model_id: "sd15"
|
45 |
+
ldm_speed_up: False
|
46 |
+
enable_xformers: True
|
47 |
+
gradient_checkpoint: False
|
48 |
+
token_ind: 5
|
49 |
+
use_ddim: True
|
50 |
+
num_inference_steps: 100
|
51 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
52 |
+
|
53 |
+
# ASDS loss
|
54 |
+
sds:
|
55 |
+
crop_size: 512
|
56 |
+
augmentations: "affine"
|
57 |
+
guidance_scale: 100
|
58 |
+
grad_scale: 1e-6
|
59 |
+
t_range: [ 0.05, 0.95 ]
|
60 |
+
warmup: 2000
|
61 |
+
|
62 |
+
clip:
|
63 |
+
model_name: "RN101" # RN101, ViT-L/14
|
64 |
+
feats_loss_type: "l2" # clip visual loss type, conv layers
|
65 |
+
feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
|
66 |
+
# feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
|
67 |
+
fc_loss_weight: 0.1 # clip visual loss, fc layer weight
|
68 |
+
augmentations: "affine" # augmentation before clip visual computation
|
69 |
+
num_aug: 4 # num of augmentation before clip visual computation
|
70 |
+
vis_loss: 1 # 1 or 0 for use or disable clip visual loss
|
71 |
+
text_visual_coeff: 0 # cosine similarity between text and img
|
72 |
+
|
73 |
+
perceptual:
|
74 |
+
name: "lpips" # dists
|
75 |
+
lpips_net: 'vgg'
|
76 |
+
coeff: 0.2
|
conf/x/diffvg.yaml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: 'diffvg'
|
2 |
+
|
3 |
+
# train
|
4 |
+
num_iter: 2000 # num_iter
|
5 |
+
lr_base:
|
6 |
+
point: 1
|
7 |
+
color: 0.01
|
8 |
+
stroke_width: 0.1
|
9 |
+
stroke_color: 0.01
|
10 |
+
lr_schedule: False # use lr_schedule
|
11 |
+
|
12 |
+
# SVG path attr
|
13 |
+
num_paths: 512 # number of paths
|
14 |
+
max_width: 5.0 # maximum width
|
15 |
+
path_type: 'unclosed' # or 'closed', using Closed curve or non-closed curve
|
16 |
+
|
17 |
+
# loss
|
18 |
+
loss_type: 'l2' # or 'l1', 'l2', 'lpips', 'l2+lpips', loss type
|
conf/x/live.yaml
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: 'live'
|
2 |
+
|
3 |
+
image_size: 240 # img size and canvas size
|
4 |
+
|
5 |
+
# train
|
6 |
+
num_iter: 500 # num_iter per path group
|
7 |
+
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
|
8 |
+
lr_base:
|
9 |
+
point: 1
|
10 |
+
color: 0.01
|
11 |
+
bg: 0.01
|
12 |
+
stroke_width: 0.1
|
13 |
+
stroke_color: 0.01
|
14 |
+
lr_schedule: True # use lr_schedule
|
15 |
+
|
16 |
+
# SVG path attr
|
17 |
+
num_paths: 5 # number of strokes
|
18 |
+
path_schedule: 'repeat'
|
19 |
+
schedule_each: 1 # [1, 3, 5, 7]
|
20 |
+
train_stroke: False # train stroke width and color
|
21 |
+
trainable_bg: False # set the background to be trainable
|
22 |
+
width: 3 # stroke width
|
23 |
+
num_segments: 4
|
24 |
+
segment_init: 'circle' # 'random'
|
25 |
+
radius: 5
|
26 |
+
coord_init: 'sparse' # 'random', 'naive', place the first control point
|
27 |
+
|
28 |
+
# loss
|
29 |
+
use_l1_loss: False
|
30 |
+
use_distance_weighted_loss: True
|
31 |
+
xing_loss_weight: 0.01
|
conf/x/styleclipdraw.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: 'styleclipdraw'
|
2 |
+
|
3 |
+
image_size: 224 # canvas size
|
4 |
+
path_svg: ~ # if you want to load an svg file and train from it
|
5 |
+
|
6 |
+
# train
|
7 |
+
num_iter: 1000
|
8 |
+
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
|
9 |
+
lr_schedule: True # anneal learning rate
|
10 |
+
lr: 1
|
11 |
+
width_lr: 0.1
|
12 |
+
color_lr: 0.01
|
13 |
+
|
14 |
+
# strokes
|
15 |
+
num_paths: 512 # number of strokes
|
16 |
+
max_width: 50 # stroke width
|
17 |
+
black_stroke_color: False
|
18 |
+
style_strength: 50 # How strong the style should be. 100 (max) is a lot. 0 (min) is no style.
|
19 |
+
|
20 |
+
# loss
|
21 |
+
num_aug: 10 # Number of image augmentations
|
conf/x/stylediffsketcher.yaml
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: 'stylediffsketcher'
|
2 |
+
|
3 |
+
image_size: 224 # canvas size
|
4 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
5 |
+
mask_object: False # if the target image contains background, it's better to mask it out
|
6 |
+
fix_scale: False # if the target image is not squared, it is recommended to fix the scale
|
7 |
+
|
8 |
+
# train
|
9 |
+
num_iter: 2000
|
10 |
+
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
|
11 |
+
lr_schedule: False
|
12 |
+
lr_decay_rate: 0.1
|
13 |
+
decay_steps: [ 1000, 1500 ]
|
14 |
+
lr: 1
|
15 |
+
color_lr: 0.01
|
16 |
+
color_vars_threshold: 0.0 # uncomment the code
|
17 |
+
width_lr: 0.1
|
18 |
+
max_width: 50 # stroke width
|
19 |
+
|
20 |
+
# SVG path attrs
|
21 |
+
num_paths: 512 # number of strokes
|
22 |
+
width: 1.5 # init stroke width
|
23 |
+
control_points_per_seg: 4
|
24 |
+
num_segments: 1
|
25 |
+
optim_opacity: True # if True, the stroke opacity is optimized
|
26 |
+
optim_width: True # if True, the stroke width is optimized
|
27 |
+
optim_rgba: True # if True, the stroke RGBA is optimized
|
28 |
+
|
29 |
+
# init strokes
|
30 |
+
attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
|
31 |
+
xdog_intersec: True # initialize along the edge, mix XDoG and attn up
|
32 |
+
softmax_temp: 0.4
|
33 |
+
cross_attn_res: 16
|
34 |
+
self_attn_res: 32
|
35 |
+
max_com: 20
|
36 |
+
mean_comp: False
|
37 |
+
comp_idx: 0
|
38 |
+
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
39 |
+
log_cross_attn: False
|
40 |
+
u2net_path: "./checkpoint/u2net/u2net.pth"
|
41 |
+
|
42 |
+
# ldm
|
43 |
+
model_id: "sd15"
|
44 |
+
ldm_speed_up: False
|
45 |
+
enable_xformers: True
|
46 |
+
gradient_checkpoint: False
|
47 |
+
token_ind: 5
|
48 |
+
use_ddim: True
|
49 |
+
num_inference_steps: 100
|
50 |
+
guidance_scale: 7.5
|
51 |
+
|
52 |
+
# ASDS loss
|
53 |
+
sds:
|
54 |
+
crop_size: 512
|
55 |
+
augmentations: "affine"
|
56 |
+
guidance_scale: 100
|
57 |
+
grad_scale: 0
|
58 |
+
t_range: [ 0.05, 0.95 ]
|
59 |
+
warmup: 120
|
60 |
+
|
61 |
+
clip:
|
62 |
+
model_name: "RN101" # RN101, ViT-L/14
|
63 |
+
feats_loss_type: "l2" # clip visual loss type, conv layers
|
64 |
+
feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
|
65 |
+
# feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
|
66 |
+
fc_loss_weight: 0.1 # clip visual loss, fc layer weight
|
67 |
+
augmentations: "affine_norm" # augmentation before clip visual computation, affine_norm_trivial
|
68 |
+
num_aug: 4 # num of augmentation before clip visual computation
|
69 |
+
vis_loss: 1 # 1 or 0 for use or disable clip visual loss
|
70 |
+
text_visual_coeff: 0 # cosine similarity between text and img
|
71 |
+
|
72 |
+
perceptual:
|
73 |
+
name: "lpips" # dists
|
74 |
+
lpips_net: 'vgg'
|
75 |
+
coeff: 0.2
|
76 |
+
|
77 |
+
style_strength: 1 # How strong the style should be. 100 (max) is a lot. 0 (min) is no style.
|
conf/x/svgdreamer.yaml
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: "svgdreamer"
|
2 |
+
|
3 |
+
image_size: 600 # canvas size
|
4 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
5 |
+
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
|
6 |
+
skip_sive: True # optimize from scratch without SIVE init
|
7 |
+
color_init: 'rand' # if skip_live=True, then use color_init to init target_img
|
8 |
+
style: "iconography" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
|
9 |
+
|
10 |
+
# lr and optim
|
11 |
+
lr_stage_one: # SIVE stage
|
12 |
+
point: 1 # control points
|
13 |
+
width: 0.1 # stroke width
|
14 |
+
color: 0.01 # fill color and stroke color
|
15 |
+
bg: 0.01 # bg in render_warp
|
16 |
+
optim:
|
17 |
+
name: 'adam'
|
18 |
+
betas: [ 0.9, 0.9 ]
|
19 |
+
eps: 1e-6
|
20 |
+
lr_schedule: True # use lr_scheduler
|
21 |
+
schedule:
|
22 |
+
name: 'linear'
|
23 |
+
keep_ratio: 0.2
|
24 |
+
decay_ratio: 0.4
|
25 |
+
lr_stage_two: # VPSD stage
|
26 |
+
point: 1
|
27 |
+
width: 0.1
|
28 |
+
color: 0.01
|
29 |
+
bg: 0.01
|
30 |
+
lr_schedule: True # use lr_scheduler
|
31 |
+
optim:
|
32 |
+
name: 'adam'
|
33 |
+
betas: [ 0.9, 0.9 ]
|
34 |
+
eps: 1e-6
|
35 |
+
schedule:
|
36 |
+
name: 'cosine'
|
37 |
+
warmup_steps: 10
|
38 |
+
warmup_start_lr: 0.02
|
39 |
+
warmup_end_lr: 0.8
|
40 |
+
cosine_end_lr: 0.4
|
41 |
+
|
42 |
+
# primitives
|
43 |
+
num_paths: 256 # number of strokes
|
44 |
+
trainable_bg: False # set the background to be trainable
|
45 |
+
width: 3 # stroke width
|
46 |
+
num_segments: 4
|
47 |
+
segment_init: 'circle' # 'random'
|
48 |
+
radius: 20
|
49 |
+
coord_init: 'random' # 'random', 'naive', place the first control point
|
50 |
+
grid: 50 # divide the canvas into n grids
|
51 |
+
path_reinit: # reinitializing paths
|
52 |
+
use: True
|
53 |
+
freq: 100 # every 50 iterations
|
54 |
+
stop_step: 1000 # for VPSD fine-tuning
|
55 |
+
opacity_threshold: 0.05
|
56 |
+
area_threshold: 64
|
57 |
+
|
58 |
+
# diffusion
|
59 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
60 |
+
ldm_speed_up: False
|
61 |
+
enable_xformers: True
|
62 |
+
gradient_checkpoint: False
|
63 |
+
cpu_offload: True
|
64 |
+
num_inference_steps: 50
|
65 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
66 |
+
K: 4
|
67 |
+
lora_path: ~
|
68 |
+
|
69 |
+
# VPSD loss
|
70 |
+
guidance:
|
71 |
+
use: True
|
72 |
+
type: 'vpsd'
|
73 |
+
n_particle: 1 # 4, 8, 16
|
74 |
+
vsd_n_particle: 1 # the batch size of particles
|
75 |
+
particle_aug: False # do data enhancement for the input particles
|
76 |
+
num_iter: 2000 # total iterations
|
77 |
+
guidance_scale: 7.5 # CFG value
|
78 |
+
grad_scale: 1.0 # increase or decrease the gradient
|
79 |
+
grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
|
80 |
+
t_range: [ 0.02, 0.98 ]
|
81 |
+
# 'randint': random time steps, this may have a more authentic style.
|
82 |
+
# 'max_0.5_900': annealing from 0.98 to 0.5 after 900 steps, this may have a more colorful results.
|
83 |
+
t_schedule: 'max_0.5_1000' # or 'randint'
|
84 |
+
# phi model config
|
85 |
+
phi_single: False # if False new an unet model to estimate noise
|
86 |
+
phi_model: 'lora' # 'lora', 'unet_simple'
|
87 |
+
use_attn_scale: ${x.guidance.phi_single} # use lora_attn_scale or not
|
88 |
+
lora_attn_scale: 1.0 # the scale of the attn based lora layer
|
89 |
+
phi_guidance_scale: 1.0
|
90 |
+
phi_t: False # different t for phi fine-tuning
|
91 |
+
phi_update_step: 1 # enable multi-update phi model or not
|
92 |
+
phi_lr: 0.0001 # learning rate of phi model
|
93 |
+
phi_scheduler: 'ddim' # 'dpm-solver'
|
94 |
+
phi_n_particle: 1 # the batch size of phi_model
|
95 |
+
# ReFL config
|
96 |
+
phi_ReFL: False # enable reward feed back learning
|
97 |
+
n_phi_sample: 1 # number of samples used in ReFL
|
98 |
+
phi_sample_step: 200 # the phi log step
|
99 |
+
phi_infer_step: 50 # the phi num_inference_steps
|
100 |
+
# phi model optim
|
101 |
+
phi_optim:
|
102 |
+
name: 'adamw'
|
103 |
+
betas: [ 0.9, 0.999 ]
|
104 |
+
eps: 1e-8
|
105 |
+
weight_decay: ~ # 1e-5
|
106 |
+
# phi model lr learning schedule
|
107 |
+
phi_schedule:
|
108 |
+
use: False
|
109 |
+
name: 'cosine'
|
110 |
+
warmup_steps: 50
|
111 |
+
warmup_start_lr: 0.00001
|
112 |
+
warmup_end_lr: 0.0001
|
113 |
+
total_step: 800
|
114 |
+
cosine_end_lr: 0.0001
|
115 |
+
|
116 |
+
# reward model
|
117 |
+
reward_path: './checkpoint/ImageReward'
|
118 |
+
|
119 |
+
# xing loss for closed-form paths
|
120 |
+
xing_loss:
|
121 |
+
use: False
|
122 |
+
weight: 0.01
|
conf/x/vectorfusion.yaml
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: "vectorfusion"
|
2 |
+
|
3 |
+
image_size: 600 # canvas size
|
4 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
5 |
+
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
|
6 |
+
skip_live: False # if skip_live then training from scratch
|
7 |
+
style: "iconography" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
|
8 |
+
|
9 |
+
# train
|
10 |
+
batch_size: 1
|
11 |
+
num_iter: 500 # num_iter per path group
|
12 |
+
# lr and optim
|
13 |
+
lr_stage_one:
|
14 |
+
point: 1
|
15 |
+
width: 0.1
|
16 |
+
color: 0.01
|
17 |
+
bg: 0.01
|
18 |
+
optim:
|
19 |
+
name: 'adam'
|
20 |
+
betas: [ 0.9, 0.9 ]
|
21 |
+
eps: 1e-6
|
22 |
+
lr_schedule: True # use lr_scheduler
|
23 |
+
schedule:
|
24 |
+
name: 'linear'
|
25 |
+
keep_ratio: 0.2
|
26 |
+
decay_ratio: 0.4
|
27 |
+
lr_stage_two:
|
28 |
+
point: 1
|
29 |
+
width: 0.1
|
30 |
+
color: 0.01
|
31 |
+
bg: 0.01
|
32 |
+
lr_schedule: True # use lr_scheduler
|
33 |
+
optim:
|
34 |
+
name: 'adam'
|
35 |
+
betas: [ 0.9, 0.9 ]
|
36 |
+
eps: 1e-6
|
37 |
+
schedule:
|
38 |
+
name: 'cosine'
|
39 |
+
warmup_steps: 50
|
40 |
+
warmup_start_lr: 0.02
|
41 |
+
warmup_end_lr: 1.0
|
42 |
+
cosine_end_lr: 0.4
|
43 |
+
|
44 |
+
# primitives
|
45 |
+
num_paths: 128 # number of strokes
|
46 |
+
path_schedule: 'repeat' # 'list'
|
47 |
+
schedule_each: 16 # [1, 3, 5, 7]
|
48 |
+
trainable_bg: False # set the background to be trainable
|
49 |
+
width: 3 # stroke width
|
50 |
+
num_segments: 4
|
51 |
+
segment_init: 'circle' # 'random'
|
52 |
+
radius: 20
|
53 |
+
coord_init: 'sparse' # 'random', 'naive', place the first control point
|
54 |
+
grid: 32 # divide the canvas into n grids
|
55 |
+
path_reinit: # reinitializing paths
|
56 |
+
use: True
|
57 |
+
freq: 50 # every 50 iterations
|
58 |
+
stop_step: 800 # for SDS fine-tuning
|
59 |
+
opacity_threshold: 0.05
|
60 |
+
area_threshold: 64
|
61 |
+
|
62 |
+
# diffusion
|
63 |
+
model_id: "sd15" # sd14, sd15, sd21, sd21b, sdxl
|
64 |
+
ldm_speed_up: False
|
65 |
+
enable_xformers: True
|
66 |
+
gradient_checkpoint: False
|
67 |
+
cpu_offload: True
|
68 |
+
num_inference_steps: 50
|
69 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
70 |
+
K: 6
|
71 |
+
lora_path: ~
|
72 |
+
|
73 |
+
# SDS
|
74 |
+
sds:
|
75 |
+
im_size: 512
|
76 |
+
guidance_scale: 100
|
77 |
+
grad_scale: 1.0
|
78 |
+
t_range: [ 0.05, 0.95 ]
|
79 |
+
num_iter: 1000 # fine-tuning steps
|
80 |
+
|
81 |
+
# Live loss
|
82 |
+
use_distance_weighted_loss: True
|
83 |
+
xing_loss_weight: 0.01
|
84 |
+
# pixel loss
|
85 |
+
penalty_weight: 0.05
|
conf/x/wordasimage.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: "wordasimage"
|
2 |
+
|
3 |
+
image_size: 600 # canvas size
|
4 |
+
word: "BUNNY"
|
5 |
+
optim_letter: "Y"
|
6 |
+
prompt_suffix: "minimal flat 2d vector. lineal color. trending on artstation"
|
7 |
+
|
8 |
+
# train
|
9 |
+
num_iter: 500
|
10 |
+
lr_schedule: True
|
11 |
+
lr:
|
12 |
+
point_lr: 1
|
13 |
+
lr_init: 0.002
|
14 |
+
lr_final: 0.0008
|
15 |
+
lr_delay_mult: 0.1
|
16 |
+
lr_delay_steps: 100
|
17 |
+
|
18 |
+
# font
|
19 |
+
font: 'KaushanScript-Regular'
|
20 |
+
font_path: "./assets/fonts/${x.font}.ttf"
|
21 |
+
level_of_cc: 1 # 0 - original number of cc / 1 - recommended / 2 - more control points
|
22 |
+
|
23 |
+
# diffusion
|
24 |
+
model_id: "sd15"
|
25 |
+
ldm_speed_up: False
|
26 |
+
enable_xformers: False
|
27 |
+
gradient_checkpoint: False
|
28 |
+
lora_path: ~
|
29 |
+
|
30 |
+
# SDS
|
31 |
+
sds:
|
32 |
+
im_size: 512
|
33 |
+
guidance_scale: 100
|
34 |
+
grad_scale: 1.0
|
35 |
+
t_range: [ 0.05, 0.95 ]
|
36 |
+
num_iter: 1000
|
37 |
+
|
38 |
+
tone_loss:
|
39 |
+
use: True
|
40 |
+
dist_loss_weight: 100
|
41 |
+
pixel_dist_kernel_blur: 201
|
42 |
+
pixel_dist_sigma: 30
|
43 |
+
|
44 |
+
conformal:
|
45 |
+
use: True
|
46 |
+
angeles_w: 0.5
|
data/alphabet1.svg
ADDED
|
data/ballerina.png
ADDED
![]() |
data/ch1.svg
ADDED
|
data/fallingwater.png
ADDED
![]() |
data/horse.png
ADDED
![]() |
data/simile.png
ADDED
![]() |
data/starry.png
ADDED
![]() |