hjc-owo committed on
Commit 966ae59 · 1 Parent(s): 663d321
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +172 -0
  2. .gitmodules +3 -0
  3. ImageReward/ImageReward.py +177 -0
  4. ImageReward/ReFL.py +830 -0
  5. ImageReward/__init__.py +3 -0
  6. ImageReward/models/AestheticScore.py +95 -0
  7. ImageReward/models/BLIP/__init__.py +1 -0
  8. ImageReward/models/BLIP/blip.py +70 -0
  9. ImageReward/models/BLIP/blip_pretrain.py +43 -0
  10. ImageReward/models/BLIP/med.py +947 -0
  11. ImageReward/models/BLIP/vit.py +301 -0
  12. ImageReward/models/BLIPScore.py +97 -0
  13. ImageReward/models/CLIPScore.py +78 -0
  14. ImageReward/models/__init__.py +4 -0
  15. ImageReward/utils.py +184 -0
  16. Install.md +66 -0
  17. LICENSE +373 -0
  18. README copy.md +304 -0
  19. app.py +83 -0
  20. assets/fonts/Bell-MT.ttf +0 -0
  21. assets/fonts/DeliusUnicase-Regular.ttf +0 -0
  22. assets/fonts/HobeauxRococeaux-Sherman.ttf +0 -0
  23. assets/fonts/IndieFlower-Regular.ttf +0 -0
  24. assets/fonts/JosefinSans-Light.ttf +0 -0
  25. assets/fonts/KaushanScript-Regular.ttf +0 -0
  26. assets/fonts/LuckiestGuy-Regular.ttf +0 -0
  27. assets/fonts/Noteworthy-Bold.ttf +0 -0
  28. assets/fonts/Quicksand.ttf +0 -0
  29. assets/fonts/Saira-Regular.ttf +0 -0
  30. checkpoint/placeholder.md +1 -0
  31. conf/config.yaml +56 -0
  32. conf/x/clipascene.yaml +87 -0
  33. conf/x/clipasso.yaml +48 -0
  34. conf/x/clipdraw.yaml +20 -0
  35. conf/x/clipfont.yaml +27 -0
  36. conf/x/diffsketcher.yaml +76 -0
  37. conf/x/diffvg.yaml +18 -0
  38. conf/x/live.yaml +31 -0
  39. conf/x/styleclipdraw.yaml +21 -0
  40. conf/x/stylediffsketcher.yaml +77 -0
  41. conf/x/svgdreamer.yaml +122 -0
  42. conf/x/vectorfusion.yaml +85 -0
  43. conf/x/wordasimage.yaml +46 -0
  44. data/alphabet1.svg +726 -0
  45. data/ballerina.png +0 -0
  46. data/ch1.svg +0 -0
  47. data/fallingwater.png +0 -0
  48. data/horse.png +0 -0
  49. data/simile.png +0 -0
  50. data/starry.png +0 -0
.gitignore ADDED
@@ -0,0 +1,172 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+ */.ipynb_checkpoints/*
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ .python-version
87
+
88
+ # pipenv
89
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
91
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
92
+ # install all needed dependencies.
93
+ #Pipfile.lock
94
+
95
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96
+ __pypackages__/
97
+
98
+ # Celery stuff
99
+ celerybeat-schedule
100
+ celerybeat.pid
101
+
102
+ # SageMath parsed files
103
+ *.sage.py
104
+
105
+ # Environments
106
+ .env
107
+ .venv
108
+ env/
109
+ venv/
110
+ ENV/
111
+ env.bak/
112
+ venv.bak/
113
+
114
+ # Spyder project settings
115
+ .spyderproject
116
+ .spyproject
117
+
118
+ # Rope project settings
119
+ .ropeproject
120
+
121
+ # mkdocs documentation
122
+ /site
123
+
124
+ # mypy
125
+ .mypy_cache/
126
+ .dmypy.json
127
+ dmypy.json
128
+
129
+ # Pyre type checker
130
+ .pyre/
131
+
132
+ # pytype static type analyzer
133
+ .pytype/
134
+
135
+ # Cython debug symbols
136
+ cython_debug/
137
+
138
+ # .idea
139
+ .idea/
140
+ /idea/
141
+ *.ipr
142
+ *.iml
143
+ *.iws
144
+
145
+ # macos system
146
+ .DS_Store
147
+
148
+ ### project ###
149
+
150
+ # /diffvg/
151
+ # big-lama*
152
+
153
+ # pytorch-lighting logs
154
+ lightning_logs/*
155
+
156
+ # Edit settings
157
+ .editorconfig
158
+
159
+ # model checkpoint
160
+ /checkpoint/u2net/u2net.pth
161
+ !/checkpoint/placeholder.md
162
+
163
+ # ignore local results
164
+ /workspace/
165
+ .workspace/
166
+
167
+ # ignore files
168
+ ./tmp/
169
+ ./tmp/*
170
+ /tmp/
171
+ /tmp_select/
172
+ /tmp_select/*
.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "diffvg"]
2
+ path = diffvg
3
+ url = https://github.com/BachiLi/diffvg.git
ImageReward/ImageReward.py ADDED
@@ -0,0 +1,177 @@
1
+ '''
2
+ @File : ImageReward.py
3
+ @Time : 2023/01/28 19:53:00
4
+ @Auther : Jiazheng Xu
5
+ @Contact : [email protected]
6
+ @Description: ImageReward Reward model.
7
+ * Based on CLIP code base and improved-aesthetic-predictor code base
8
+ * https://github.com/openai/CLIP
9
+ * https://github.com/christophschuhmann/improved-aesthetic-predictor
10
+ '''
11
+
12
+ import os
13
+ import torch
14
+ import torch.nn as nn
15
+ from PIL import Image
16
+ from .models.BLIP.blip_pretrain import BLIP_Pretrain
17
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
18
+
19
+ try:
20
+ from torchvision.transforms import InterpolationMode
21
+
22
+ BICUBIC = InterpolationMode.BICUBIC
23
+ except ImportError:
24
+ BICUBIC = Image.BICUBIC
25
+
26
+
27
+ def _convert_image_to_rgb(image):
28
+ return image.convert("RGB")
29
+
30
+
31
+ def _transform(n_px):
32
+ return Compose([
33
+ Resize(n_px, interpolation=BICUBIC),
34
+ CenterCrop(n_px),
35
+ _convert_image_to_rgb,
36
+ ToTensor(),
37
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
38
+ ])
39
+
40
+
41
+ class MLP(nn.Module):
42
+ def __init__(self, input_size):
43
+ super().__init__()
44
+ self.input_size = input_size
45
+
46
+ self.layers = nn.Sequential(
47
+ nn.Linear(self.input_size, 1024),
48
+ # nn.ReLU(),
49
+ nn.Dropout(0.2),
50
+ nn.Linear(1024, 128),
51
+ # nn.ReLU(),
52
+ nn.Dropout(0.2),
53
+ nn.Linear(128, 64),
54
+ # nn.ReLU(),
55
+ nn.Dropout(0.1),
56
+ nn.Linear(64, 16),
57
+ # nn.ReLU(),
58
+ nn.Linear(16, 1)
59
+ )
60
+
61
+ # initial MLP param
62
+ for name, param in self.layers.named_parameters():
63
+ if 'weight' in name:
64
+ nn.init.normal_(param, mean=0.0, std=1.0 / (self.input_size + 1))
65
+ if 'bias' in name:
66
+ nn.init.constant_(param, val=0)
67
+
68
+ def forward(self, input):
69
+ return self.layers(input)
70
+
71
+
72
+ class ImageReward(nn.Module):
73
+ def __init__(self, med_config, device='cpu'):
74
+ super().__init__()
75
+ self.device = device
76
+
77
+ self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
78
+ self.preprocess = _transform(224)
79
+ self.mlp = MLP(768)
80
+
81
+ self.mean = 0.16717362830052426
82
+ self.std = 1.0333394966054072
83
+
84
+ def score_gard(self, prompt_ids, prompt_attention_mask, image):
85
+
86
+ image_embeds = self.blip.visual_encoder(image)
87
+ # text encode cross attention with image
88
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
89
+ text_output = self.blip.text_encoder(prompt_ids,
90
+ attention_mask=prompt_attention_mask,
91
+ encoder_hidden_states=image_embeds,
92
+ encoder_attention_mask=image_atts,
93
+ return_dict=True,
94
+ )
95
+
96
+ txt_features = text_output.last_hidden_state[:, 0, :] # (feature_dim)
97
+ rewards = self.mlp(txt_features)
98
+ rewards = (rewards - self.mean) / self.std
99
+
100
+ return rewards
101
+
102
+ def score(self, prompt, image):
103
+
104
+ if (type(image).__name__ == 'list'):
105
+ _, rewards = self.inference_rank(prompt, image)
106
+ return rewards
107
+
108
+ # text encode
109
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35,
110
+ return_tensors="pt").to(self.device)
111
+
112
+ # image encode
113
+ if isinstance(image, Image.Image):
114
+ pil_image = image
115
+ elif isinstance(image, str):
116
+ if os.path.isfile(image):
117
+ pil_image = Image.open(image)
118
+ else:
119
+ raise TypeError(
120
+ r'This image parameter type is not supported yet. Please pass a PIL.Image or a file path str.')
121
+
122
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
123
+ image_embeds = self.blip.visual_encoder(image)
124
+
125
+ # text encode cross attention with image
126
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
127
+ text_output = self.blip.text_encoder(text_input.input_ids,
128
+ attention_mask=text_input.attention_mask,
129
+ encoder_hidden_states=image_embeds,
130
+ encoder_attention_mask=image_atts,
131
+ return_dict=True,
132
+ )
133
+
134
+ txt_features = text_output.last_hidden_state[:, 0, :].float() # (feature_dim)
135
+ rewards = self.mlp(txt_features)
136
+ rewards = (rewards - self.mean) / self.std
137
+
138
+ return rewards.detach().cpu().numpy().item()
139
+
140
+ def inference_rank(self, prompt, generations_list):
141
+
142
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35,
143
+ return_tensors="pt").to(self.device)
144
+
145
+ txt_set = []
146
+ for generation in generations_list:
147
+ # image encode
148
+ if isinstance(generation, Image.Image):
149
+ pil_image = generation
150
+ elif isinstance(generation, str):
151
+ if os.path.isfile(generation):
152
+ pil_image = Image.open(generation)
153
+ else:
154
+ raise TypeError(
155
+ r'This image parameter type is not supported yet. Please pass a PIL.Image or a file path str.')
156
+
157
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
158
+ image_embeds = self.blip.visual_encoder(image)
159
+
160
+ # text encode cross attention with image
161
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
162
+ text_output = self.blip.text_encoder(text_input.input_ids,
163
+ attention_mask=text_input.attention_mask,
164
+ encoder_hidden_states=image_embeds,
165
+ encoder_attention_mask=image_atts,
166
+ return_dict=True)
167
+ txt_set.append(text_output.last_hidden_state[:, 0, :])
168
+
169
+ txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
170
+ rewards = self.mlp(txt_features) # [image_num, 1]
171
+ rewards = (rewards - self.mean) / self.std
172
+ rewards = torch.squeeze(rewards)
173
+ _, rank = torch.sort(rewards, dim=0, descending=True)
174
+ _, indices = torch.sort(rank, dim=0)
175
+ indices = indices + 1
176
+
177
+ return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
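Taken together, the ImageReward module added above exposes score() for a single prompt/image pair and inference_rank() for ranking several generations against one prompt. A minimal usage sketch follows; it assumes the load() helper exported from ImageReward/utils.py (the same entry point ReFL.py uses via `import ImageReward as RM`), and the image paths are simply files shipped in this repo's data/ folder:

    # Hedged sketch, not part of this commit; RM.load() is assumed to come from ImageReward/utils.py.
    import ImageReward as RM

    model = RM.load("ImageReward-v1.0", device="cpu")  # fetch/load the reward checkpoint
    # score() accepts a PIL.Image or a file path string and returns a single float.
    single = model.score("a painting of a starry night", "data/starry.png")
    # inference_rank() returns a 1-based rank per image plus the raw reward scores.
    ranks, scores = model.inference_rank(
        "a painting of a starry night",
        ["data/starry.png", "data/horse.png"],
    )
    print(single, ranks, scores)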
ImageReward/ReFL.py ADDED
@@ -0,0 +1,830 @@
1
+ '''
2
+ @File : ReFL.py
3
+ @Time : 2023/05/01 19:36:00
4
+ @Author : Jiazheng Xu
5
+ @Contact : [email protected]
6
+ @Description: ReFL Algorithm.
7
+ * Based on diffusers code base
8
+ * https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
9
+ '''
10
+
11
+ import argparse
12
+ import logging
13
+ import math
14
+ import os
15
+ import random
16
+ from pathlib import Path
17
+
18
+ import accelerate
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn.functional as F
22
+ import torch.utils.checkpoint
23
+ import transformers
24
+ from accelerate import Accelerator
25
+ from accelerate.logging import get_logger
26
+ from accelerate.utils import ProjectConfiguration, set_seed
27
+ from datasets import load_dataset
28
+ from huggingface_hub import create_repo, upload_folder
29
+ from packaging import version
30
+ from tqdm.auto import tqdm
31
+ from transformers import CLIPTextModel, CLIPTokenizer
32
+
33
+ from PIL import Image
34
+ import ImageReward as RM
35
+
36
+ from torchvision.transforms import Compose, Resize, CenterCrop, Normalize
37
+
38
+ try:
39
+ from torchvision.transforms import InterpolationMode
40
+
41
+ BICUBIC = InterpolationMode.BICUBIC
42
+ except ImportError:
43
+ BICUBIC = Image.BICUBIC
44
+
45
+ import diffusers
46
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
47
+ from diffusers.optimization import get_scheduler
48
+ from diffusers.training_utils import EMAModel
49
+ from diffusers.utils import check_min_version, deprecate
50
+ from diffusers.utils.import_utils import is_xformers_available
51
+
52
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
53
+ check_min_version("0.16.0.dev0")
54
+
55
+ logger = get_logger(__name__, log_level="INFO")
56
+
57
+ DATASET_NAME_MAPPING = {
58
+ "refl": ("image", "text"),
59
+ }
60
+
61
+
62
+ def parse_args():
63
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
64
+ parser.add_argument(
65
+ "--grad_scale", type=float, default=1e-3, help="Scale divided for grad loss value."
66
+ )
67
+ parser.add_argument(
68
+ "--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1."
69
+ )
70
+ parser.add_argument(
71
+ "--revision",
72
+ type=str,
73
+ default=None,
74
+ required=False,
75
+ help="Revision of pretrained model identifier from huggingface.co/models.",
76
+ )
77
+ parser.add_argument(
78
+ "--dataset_name",
79
+ type=str,
80
+ default=None,
81
+ help=(
82
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
83
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
84
+ " or to a folder containing files that 🤗 Datasets can understand."
85
+ ),
86
+ )
87
+ parser.add_argument(
88
+ "--dataset_config_name",
89
+ type=str,
90
+ default=None,
91
+ help="The config of the Dataset, leave as None if there's only one config.",
92
+ )
93
+ parser.add_argument(
94
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
95
+ )
96
+ parser.add_argument(
97
+ "--caption_column",
98
+ type=str,
99
+ default="text",
100
+ help="The column of the dataset containing a caption or a list of captions.",
101
+ )
102
+ parser.add_argument(
103
+ "--max_train_samples",
104
+ type=int,
105
+ default=None,
106
+ help=(
107
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
108
+ "value if set."
109
+ ),
110
+ )
111
+ parser.add_argument(
112
+ "--validation_prompts",
113
+ type=str,
114
+ default=None,
115
+ nargs="+",
116
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
117
+ )
118
+ parser.add_argument(
119
+ "--output_dir",
120
+ type=str,
121
+ default="checkpoint/refl",
122
+ help="The output directory where the model predictions and checkpoints will be written.",
123
+ )
124
+ parser.add_argument(
125
+ "--cache_dir",
126
+ type=str,
127
+ default=None,
128
+ help="The directory where the downloaded models and datasets will be stored.",
129
+ )
130
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
131
+ parser.add_argument(
132
+ "--resolution",
133
+ type=int,
134
+ default=512,
135
+ help=(
136
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
137
+ " resolution"
138
+ ),
139
+ )
140
+ parser.add_argument(
141
+ "--center_crop",
142
+ default=False,
143
+ action="store_true",
144
+ help=(
145
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
146
+ " cropped. The images will be resized to the resolution first before cropping."
147
+ ),
148
+ )
149
+ parser.add_argument(
150
+ "--random_flip",
151
+ action="store_true",
152
+ help="whether to randomly flip images horizontally",
153
+ )
154
+ parser.add_argument(
155
+ "--train_batch_size", type=int, default=2, help="Batch size (per device) for the training dataloader."
156
+ )
157
+ parser.add_argument("--num_train_epochs", type=int, default=100)
158
+ parser.add_argument(
159
+ "--max_train_steps",
160
+ type=int,
161
+ default=100,
162
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
163
+ )
164
+ parser.add_argument(
165
+ "--gradient_accumulation_steps",
166
+ type=int,
167
+ default=4,
168
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
169
+ )
170
+ parser.add_argument(
171
+ "--gradient_checkpointing",
172
+ action="store_true",
173
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
174
+ )
175
+ parser.add_argument(
176
+ "--learning_rate",
177
+ type=float,
178
+ default=1e-5,
179
+ help="Initial learning rate (after the potential warmup period) to use.",
180
+ )
181
+ parser.add_argument(
182
+ "--scale_lr",
183
+ action="store_true",
184
+ default=False,
185
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
186
+ )
187
+ parser.add_argument(
188
+ "--lr_scheduler",
189
+ type=str,
190
+ default="constant",
191
+ help=(
192
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
193
+ ' "constant", "constant_with_warmup"]'
194
+ ),
195
+ )
196
+ parser.add_argument(
197
+ "--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
198
+ )
199
+ parser.add_argument(
200
+ "--snr_gamma",
201
+ type=float,
202
+ default=None,
203
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
204
+ "More details here: https://arxiv.org/abs/2303.09556.",
205
+ )
206
+ parser.add_argument(
207
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
208
+ )
209
+ parser.add_argument(
210
+ "--allow_tf32",
211
+ action="store_true",
212
+ help=(
213
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
214
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
215
+ ),
216
+ )
217
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
218
+ parser.add_argument(
219
+ "--non_ema_revision",
220
+ type=str,
221
+ default=None,
222
+ required=False,
223
+ help=(
224
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
225
+ " remote repository specified with --pretrained_model_name_or_path."
226
+ ),
227
+ )
228
+ parser.add_argument(
229
+ "--dataloader_num_workers",
230
+ type=int,
231
+ default=0,
232
+ help=(
233
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
234
+ ),
235
+ )
236
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
237
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
238
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
239
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
240
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
241
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
242
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
243
+ parser.add_argument(
244
+ "--hub_model_id",
245
+ type=str,
246
+ default=None,
247
+ help="The name of the repository to keep in sync with the local `output_dir`.",
248
+ )
249
+ parser.add_argument(
250
+ "--logging_dir",
251
+ type=str,
252
+ default="logs",
253
+ help=(
254
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
255
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
256
+ ),
257
+ )
258
+ parser.add_argument(
259
+ "--mixed_precision",
260
+ type=str,
261
+ default=None,
262
+ choices=["no", "fp16", "bf16"],
263
+ help=(
264
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
265
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
266
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
267
+ ),
268
+ )
269
+ parser.add_argument(
270
+ "--report_to",
271
+ type=str,
272
+ default="tensorboard",
273
+ help=(
274
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
275
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
276
+ ),
277
+ )
278
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
279
+ parser.add_argument(
280
+ "--checkpointing_steps",
281
+ type=int,
282
+ default=100,
283
+ help=(
284
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
285
+ " training using `--resume_from_checkpoint`."
286
+ ),
287
+ )
288
+ parser.add_argument(
289
+ "--checkpoints_total_limit",
290
+ type=int,
291
+ default=None,
292
+ help=(
293
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
294
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
295
+ " for more docs"
296
+ ),
297
+ )
298
+ parser.add_argument(
299
+ "--resume_from_checkpoint",
300
+ type=str,
301
+ default=None,
302
+ help=(
303
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
304
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
305
+ ),
306
+ )
307
+ parser.add_argument(
308
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
309
+ )
310
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
311
+ parser.add_argument(
312
+ "--validation_epochs",
313
+ type=int,
314
+ default=5,
315
+ help="Run validation every X epochs.",
316
+ )
317
+ parser.add_argument(
318
+ "--tracker_project_name",
319
+ type=str,
320
+ default="text2image-refl",
321
+ help=(
322
+ "The `project_name` argument passed to Accelerator.init_trackers for"
323
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
324
+ ),
325
+ )
326
+
327
+ args = parser.parse_args()
328
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
329
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
330
+ args.local_rank = env_local_rank
331
+
332
+ # default to using the same revision for the non-ema model if not specified
333
+ if args.non_ema_revision is None:
334
+ args.non_ema_revision = args.revision
335
+
336
+ return args
337
+
338
+
339
+ class Trainer(object):
340
+
341
+ def __init__(self, pretrained_model_name_or_path, train_data_dir, args):
342
+
343
+ self.pretrained_model_name_or_path = pretrained_model_name_or_path
344
+ self.train_data_dir = train_data_dir
345
+
346
+ # Sanity checks
347
+ if args.dataset_name is None and self.train_data_dir is None:
348
+ raise ValueError("Need either a dataset name or a training folder.")
349
+
350
+ if args.non_ema_revision is not None:
351
+ deprecate(
352
+ "non_ema_revision!=None",
353
+ "0.15.0",
354
+ message=(
355
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
356
+ " use `--variant=non_ema` instead."
357
+ ),
358
+ )
359
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
360
+
361
+ accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
362
+
363
+ self.accelerator = Accelerator(
364
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
365
+ mixed_precision=args.mixed_precision,
366
+ log_with=args.report_to,
367
+ logging_dir=logging_dir,
368
+ project_config=accelerator_project_config,
369
+ )
370
+
371
+ # Make one log on every process with the configuration for debugging.
372
+ logging.basicConfig(
373
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
374
+ datefmt="%m/%d/%Y %H:%M:%S",
375
+ level=logging.INFO,
376
+ )
377
+ logger.info(self.accelerator.state, main_process_only=False)
378
+ if self.accelerator.is_local_main_process:
379
+ transformers.utils.logging.set_verbosity_warning()
380
+ diffusers.utils.logging.set_verbosity_info()
381
+ else:
382
+ transformers.utils.logging.set_verbosity_error()
383
+ diffusers.utils.logging.set_verbosity_error()
384
+
385
+ # If passed along, set the training seed now.
386
+ if args.seed is not None:
387
+ set_seed(args.seed)
388
+
389
+ # Handle the repository creation
390
+ if self.accelerator.is_main_process:
391
+ if args.output_dir is not None:
392
+ os.makedirs(args.output_dir, exist_ok=True)
393
+
394
+ if args.push_to_hub:
395
+ self.repo_id = create_repo(
396
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
397
+ ).repo_id
398
+
399
+ # Load scheduler, tokenizer and models.
400
+ self.noise_scheduler = DDPMScheduler.from_pretrained(self.pretrained_model_name_or_path, subfolder="scheduler")
401
+ tokenizer = CLIPTokenizer.from_pretrained(
402
+ self.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
403
+ )
404
+ self.text_encoder = CLIPTextModel.from_pretrained(
405
+ self.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
406
+ )
407
+ self.vae = AutoencoderKL.from_pretrained(self.pretrained_model_name_or_path, subfolder="vae",
408
+ revision=args.revision)
409
+ self.unet = UNet2DConditionModel.from_pretrained(
410
+ self.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
411
+ )
412
+ self.reward_model = RM.load("ImageReward-v1.0", device=self.accelerator.device)
413
+
414
+ # Freeze vae and text_encoder
415
+ self.vae.requires_grad_(False)
416
+ self.text_encoder.requires_grad_(False)
417
+ self.reward_model.requires_grad_(False)
418
+
419
+ # Create EMA for the unet.
420
+ if args.use_ema:
421
+ self.ema_unet = UNet2DConditionModel.from_pretrained(
422
+ self.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
423
+ )
424
+ self.ema_unet = EMAModel(self.ema_unet.parameters(), model_cls=UNet2DConditionModel,
425
+ model_config=self.ema_unet.config)
426
+
427
+ if args.enable_xformers_memory_efficient_attention:
428
+ if is_xformers_available():
429
+ import xformers
430
+
431
+ xformers_version = version.parse(xformers.__version__)
432
+ if xformers_version == version.parse("0.0.16"):
433
+ logger.warn(
434
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
435
+ )
436
+ self.unet.enable_xformers_memory_efficient_attention()
437
+ else:
438
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
439
+
440
+ # `accelerate` 0.16.0 will have better support for customized saving
441
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
442
+ # create custom saving & loading hooks so that `self.accelerator.save_state(...)` serializes in a nice format
443
+ def save_model_hook(models, weights, output_dir):
444
+ if args.use_ema:
445
+ self.ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
446
+
447
+ for i, model in enumerate(models):
448
+ model.save_pretrained(os.path.join(output_dir, "unet"))
449
+
450
+ # make sure to pop weight so that corresponding model is not saved again
451
+ weights.pop()
452
+
453
+ def load_model_hook(models, input_dir):
454
+ if args.use_ema:
455
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
456
+ self.ema_unet.load_state_dict(load_model.state_dict())
457
+ self.ema_unet.to(self.accelerator.device)
458
+ del load_model
459
+
460
+ for i in range(len(models)):
461
+ # pop models so that they are not loaded again
462
+ model = models.pop()
463
+
464
+ # load diffusers style into model
465
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
466
+ model.register_to_config(**load_model.config)
467
+
468
+ model.load_state_dict(load_model.state_dict())
469
+ del load_model
470
+
471
+ self.accelerator.register_save_state_pre_hook(save_model_hook)
472
+ self.accelerator.register_load_state_pre_hook(load_model_hook)
473
+
474
+ if args.gradient_checkpointing:
475
+ self.unet.enable_gradient_checkpointing()
476
+
477
+ # Enable TF32 for faster training on Ampere GPUs,
478
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
479
+ if args.allow_tf32:
480
+ torch.backends.cuda.matmul.allow_tf32 = True
481
+
482
+ if args.scale_lr:
483
+ args.learning_rate = (
484
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * self.accelerator.num_processes
485
+ )
486
+
487
+ # Initialize the optimizer
488
+ if args.use_8bit_adam:
489
+ try:
490
+ import bitsandbytes as bnb
491
+ except ImportError:
492
+ raise ImportError(
493
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
494
+ )
495
+
496
+ optimizer_cls = bnb.optim.AdamW8bit
497
+ else:
498
+ optimizer_cls = torch.optim.AdamW
499
+
500
+ self.optimizer = optimizer_cls(
501
+ self.unet.parameters(),
502
+ lr=args.learning_rate,
503
+ betas=(args.adam_beta1, args.adam_beta2),
504
+ weight_decay=args.adam_weight_decay,
505
+ eps=args.adam_epsilon,
506
+ )
507
+
508
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
509
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
510
+
511
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
512
+ # download the dataset.
513
+ if args.dataset_name is not None:
514
+ # Downloading and loading a dataset from the hub.
515
+ dataset = load_dataset(
516
+ args.dataset_name,
517
+ args.dataset_config_name,
518
+ cache_dir=args.cache_dir,
519
+ )
520
+ else:
521
+ data_files = {}
522
+ data_files["train"] = self.train_data_dir
523
+ dataset = load_dataset(
524
+ "json",
525
+ data_files=data_files,
526
+ cache_dir=args.cache_dir,
527
+ )
528
+ # See more about loading custom images at
529
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
530
+
531
+ # Preprocessing the datasets.
532
+ # We need to tokenize inputs and targets.
533
+ column_names = dataset["train"].column_names
534
+
535
+ # Get the column names for input/target.
536
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
537
+ if args.image_column is None:
538
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
539
+ else:
540
+ image_column = args.image_column
541
+ if image_column not in column_names:
542
+ raise ValueError(
543
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
544
+ )
545
+ if args.caption_column is None:
546
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
547
+ else:
548
+ caption_column = args.caption_column
549
+ if caption_column not in column_names:
550
+ raise ValueError(
551
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
552
+ )
553
+
554
+ # Preprocessing the datasets.
555
+ # We need to tokenize input captions and transform the images.
556
+ def tokenize_captions(examples, is_train=True):
557
+ captions = []
558
+ for caption in examples[caption_column]:
559
+ if isinstance(caption, str):
560
+ captions.append(caption)
561
+ elif isinstance(caption, (list, np.ndarray)):
562
+ # take a random caption if there are multiple
563
+ captions.append(random.choice(caption) if is_train else caption[0])
564
+ else:
565
+ raise ValueError(
566
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
567
+ )
568
+ inputs = tokenizer(
569
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
570
+ return_tensors="pt"
571
+ )
572
+ return inputs.input_ids
573
+
574
+ def preprocess_train(examples):
575
+ examples["input_ids"] = tokenize_captions(examples)
576
+ examples["rm_input_ids"] = self.reward_model.blip.tokenizer(examples[caption_column], padding='max_length',
577
+ truncation=True, max_length=35,
578
+ return_tensors="pt").input_ids
579
+ examples["rm_attention_mask"] = self.reward_model.blip.tokenizer(examples[caption_column],
580
+ padding='max_length', truncation=True,
581
+ max_length=35,
582
+ return_tensors="pt").attention_mask
583
+ return examples
584
+
585
+ with self.accelerator.main_process_first():
586
+ if args.max_train_samples is not None:
587
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
588
+ # Set the training transforms
589
+ self.train_dataset = dataset["train"].with_transform(preprocess_train)
590
+
591
+ def collate_fn(examples):
592
+ input_ids = torch.stack([example["input_ids"] for example in examples])
593
+ rm_input_ids = torch.stack([example["rm_input_ids"] for example in examples])
594
+ rm_attention_mask = torch.stack([example["rm_attention_mask"] for example in examples])
595
+ input_ids = input_ids.view(-1, input_ids.shape[-1])
596
+ rm_input_ids = rm_input_ids.view(-1, rm_input_ids.shape[-1])
597
+ rm_attention_mask = rm_attention_mask.view(-1, rm_attention_mask.shape[-1])
598
+ return {"input_ids": input_ids, "rm_input_ids": rm_input_ids, "rm_attention_mask": rm_attention_mask}
599
+
600
+ # DataLoaders creation:
601
+ self.train_dataloader = torch.utils.data.DataLoader(
602
+ self.train_dataset,
603
+ shuffle=True,
604
+ collate_fn=collate_fn,
605
+ batch_size=args.train_batch_size,
606
+ num_workers=args.dataloader_num_workers,
607
+ )
608
+
609
+ # Scheduler and math around the number of training steps.
610
+ overrode_max_train_steps = False
611
+ self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps)
612
+ if args.max_train_steps is None:
613
+ args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch
614
+ overrode_max_train_steps = True
615
+
616
+ self.lr_scheduler = get_scheduler(
617
+ args.lr_scheduler,
618
+ optimizer=self.optimizer,
619
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
620
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
621
+ )
622
+
623
+ # Prepare everything with our `self.accelerator`.
624
+ self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
625
+ self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler
626
+ )
627
+
628
+ if args.use_ema:
629
+ self.ema_unet.to(self.accelerator.device)
630
+
631
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
632
+ # as these models are only used for inference, keeping weights in full precision is not required.
633
+ self.weight_dtype = torch.float32
634
+ if self.accelerator.mixed_precision == "fp16":
635
+ self.weight_dtype = torch.float16
636
+ elif self.accelerator.mixed_precision == "bf16":
637
+ self.weight_dtype = torch.bfloat16
638
+
639
+ # Move text_encode and vae to gpu and cast to self.weight_dtype
640
+ self.text_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
641
+ self.vae.to(self.accelerator.device, dtype=self.weight_dtype)
642
+ self.reward_model.to(self.accelerator.device, dtype=self.weight_dtype)
643
+
644
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
645
+ self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps)
646
+ if overrode_max_train_steps:
647
+ args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch
648
+ # Afterwards we recalculate our number of training epochs
649
+ args.num_train_epochs = math.ceil(args.max_train_steps / self.num_update_steps_per_epoch)
650
+
651
+ # We need to initialize the trackers we use, and also store our configuration.
652
+ # The trackers initializes automatically on the main process.
653
+ if self.accelerator.is_main_process:
654
+ tracker_config = dict(vars(args))
655
+ tracker_config.pop("validation_prompts")
656
+ self.accelerator.init_trackers(args.tracker_project_name, tracker_config)
657
+
658
+ def train(self, args):
659
+
660
+ # Train!
661
+ total_batch_size = args.train_batch_size * self.accelerator.num_processes * args.gradient_accumulation_steps
662
+
663
+ logger.info("***** Running training *****")
664
+ logger.info(f" Num examples = {len(self.train_dataset)}")
665
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
666
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
667
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
668
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
669
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
670
+ global_step = 0
671
+ first_epoch = 0
672
+
673
+ # Potentially load in the weights and states from a previous save
674
+ if args.resume_from_checkpoint:
675
+ if args.resume_from_checkpoint != "latest":
676
+ path = os.path.basename(args.resume_from_checkpoint)
677
+ else:
678
+ # Get the most recent checkpoint
679
+ dirs = os.listdir(args.output_dir)
680
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
681
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
682
+ path = dirs[-1] if len(dirs) > 0 else None
683
+
684
+ if path is None:
685
+ self.accelerator.print(
686
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
687
+ )
688
+ args.resume_from_checkpoint = None
689
+ else:
690
+ self.accelerator.print(f"Resuming from checkpoint {path}")
691
+ self.accelerator.load_state(os.path.join(args.output_dir, path))
692
+ global_step = int(path.split("-")[1])
693
+
694
+ resume_global_step = global_step * args.gradient_accumulation_steps
695
+ first_epoch = global_step // self.num_update_steps_per_epoch
696
+ resume_step = resume_global_step % (self.num_update_steps_per_epoch * args.gradient_accumulation_steps)
697
+
698
+ # Only show the progress bar once on each machine.
699
+ progress_bar = tqdm(range(global_step, args.max_train_steps),
700
+ disable=not self.accelerator.is_local_main_process)
701
+ progress_bar.set_description("Steps")
702
+
703
+ for epoch in range(first_epoch, args.num_train_epochs):
704
+ self.unet.train()
705
+ train_loss = 0.0
706
+ for step, batch in enumerate(self.train_dataloader):
707
+ # Skip steps until we reach the resumed step
708
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
709
+ if step % args.gradient_accumulation_steps == 0:
710
+ progress_bar.update(1)
711
+ continue
712
+
713
+ with self.accelerator.accumulate(self.unet):
714
+ encoder_hidden_states = self.text_encoder(batch["input_ids"])[0]
715
+ latents = torch.randn((args.train_batch_size, 4, 64, 64), device=self.accelerator.device)
716
+
717
+ self.noise_scheduler.set_timesteps(40, device=self.accelerator.device)
718
+ timesteps = self.noise_scheduler.timesteps
719
+
720
+ mid_timestep = random.randint(30, 39)
721
+
722
+ for i, t in enumerate(timesteps[:mid_timestep]):
723
+ with torch.no_grad():
724
+ latent_model_input = latents
725
+ latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
726
+ noise_pred = self.unet(
727
+ latent_model_input,
728
+ t,
729
+ encoder_hidden_states=encoder_hidden_states,
730
+ ).sample
731
+ latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
732
+
733
+ latent_model_input = latents
734
+ latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input,
735
+ timesteps[mid_timestep])
736
+ noise_pred = self.unet(
737
+ latent_model_input,
738
+ timesteps[mid_timestep],
739
+ encoder_hidden_states=encoder_hidden_states,
740
+ ).sample
741
+ pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[mid_timestep],
742
+ latents).pred_original_sample.to(self.weight_dtype)
743
+
744
+ pred_original_sample = 1 / self.vae.config.scaling_factor * pred_original_sample
745
+ image = self.vae.decode(pred_original_sample.to(self.weight_dtype)).sample
746
+ image = (image / 2 + 0.5).clamp(0, 1)
747
+
748
+ # image encode
749
+ def _transform():
750
+ return Compose([
751
+ Resize(224, interpolation=BICUBIC),
752
+ CenterCrop(224),
753
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
754
+ ])
755
+
756
+ rm_preprocess = _transform()
757
+ image = rm_preprocess(image).to(self.accelerator.device)
758
+
759
+ rewards = self.reward_model.score_gard(batch["rm_input_ids"], batch["rm_attention_mask"], image)
760
+ loss = F.relu(-rewards + 2)
761
+ loss = loss.mean() * args.grad_scale
762
+
763
+ # Gather the losses across all processes for logging (if we use distributed training).
764
+ avg_loss = self.accelerator.gather(loss.repeat(args.train_batch_size)).mean()
765
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
766
+
767
+ # Backpropagate
768
+ self.accelerator.backward(loss)
769
+ if self.accelerator.sync_gradients:
770
+ self.accelerator.clip_grad_norm_(self.unet.parameters(), args.max_grad_norm)
771
+ self.optimizer.step()
772
+ self.lr_scheduler.step()
773
+ self.optimizer.zero_grad()
774
+
775
+ # Checks if the self.accelerator has performed an optimization step behind the scenes
776
+ if self.accelerator.sync_gradients:
777
+ if args.use_ema:
778
+ self.ema_unet.step(self.unet.parameters())
779
+ progress_bar.update(1)
780
+ global_step += 1
781
+ self.accelerator.log({"train_loss": train_loss}, step=global_step)
782
+ train_loss = 0.0
783
+
784
+ if global_step % args.checkpointing_steps == 0:
785
+ if self.accelerator.is_main_process:
786
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
787
+ self.accelerator.save_state(save_path)
788
+ logger.info(f"Saved state to {save_path}")
789
+
790
+ logs = {"step_loss": loss.detach().item(), "lr": self.lr_scheduler.get_last_lr()[0]}
791
+ progress_bar.set_postfix(**logs)
792
+
793
+ if global_step >= args.max_train_steps:
794
+ break
795
+
796
+ if self.accelerator.is_main_process:
797
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
798
+ if args.use_ema:
799
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
800
+ self.ema_unet.store(self.unet.parameters())
801
+ self.ema_unet.copy_to(self.unet.parameters())
802
+ if args.use_ema:
803
+ # Switch back to the original UNet parameters.
804
+ self.ema_unet.restore(self.unet.parameters())
805
+
806
+ # Create the pipeline using the trained modules and save it.
807
+ self.accelerator.wait_for_everyone()
808
+ if self.accelerator.is_main_process:
809
+ self.unet = self.accelerator.unwrap_model(self.unet)
810
+ if args.use_ema:
811
+ self.ema_unet.copy_to(self.unet.parameters())
812
+
813
+ pipeline = StableDiffusionPipeline.from_pretrained(
814
+ self.pretrained_model_name_or_path,
815
+ text_encoder=self.text_encoder,
816
+ vae=self.vae,
817
+ unet=self.unet,
818
+ revision=args.revision,
819
+ )
820
+ pipeline.save_pretrained(args.output_dir)
821
+
822
+ if args.push_to_hub:
823
+ upload_folder(
824
+ repo_id=self.repo_id,
825
+ folder_path=args.output_dir,
826
+ commit_message="End of training",
827
+ ignore_patterns=["step_*", "epoch_*"],
828
+ )
829
+
830
+ self.accelerator.end_training()
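ReFL.py wraps the whole fine-tuning loop in a Trainer class: parse_args() declares the hyperparameters, train_data_dir points at a JSON caption file loaded through datasets, and train() runs the reward-guided updates and saves the resulting StableDiffusionPipeline to args.output_dir. A rough driver sketch, with a placeholder model id and data path that are not taken from this commit:

    # Hypothetical driver for the Trainer defined above; model id and data path are placeholders.
    from ImageReward.ReFL import Trainer, parse_args

    args = parse_args()                      # CLI flags declared in ReFL.py
    trainer = Trainer(
        "runwayml/stable-diffusion-v1-5",    # placeholder Stable Diffusion checkpoint
        "data/refl_prompts.json",            # placeholder JSON file with a "text" caption column
        args,
    )
    trainer.train(args)                      # trains the UNet, saves the pipeline to args.output_dir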
ImageReward/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .utils import *
2
+ from .models import *
3
+ from .ReFL import *
ImageReward/models/AestheticScore.py ADDED
@@ -0,0 +1,95 @@
1
+ '''
2
+ @File : AestheticScore.py
3
+ @Time : 2023/02/12 14:54:00
4
+ @Author : Jiazheng Xu
5
+ @Contact : [email protected]
6
+ @Description: AestheticScore.
7
+ * Based on improved-aesthetic-predictor code base
8
+ * https://github.com/christophschuhmann/improved-aesthetic-predictor
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from PIL import Image
15
+ import clip
16
+
17
+
18
+ # if you changed the MLP architecture during training, change it also here:
19
+ class MLP(nn.Module):
20
+ def __init__(self, input_size):
21
+ super().__init__()
22
+ self.input_size = input_size
23
+ self.layers = nn.Sequential(
24
+ nn.Linear(self.input_size, 1024),
25
+ # nn.ReLU(),
26
+ nn.Dropout(0.2),
27
+ nn.Linear(1024, 128),
28
+ # nn.ReLU(),
29
+ nn.Dropout(0.2),
30
+ nn.Linear(128, 64),
31
+ # nn.ReLU(),
32
+ nn.Dropout(0.1),
33
+
34
+ nn.Linear(64, 16),
35
+ # nn.ReLU(),
36
+
37
+ nn.Linear(16, 1)
38
+ )
39
+
40
+ def forward(self, x):
41
+ return self.layers(x)
42
+
43
+
44
+ class AestheticScore(nn.Module):
45
+ def __init__(self, download_root, device='cpu'):
46
+ super().__init__()
47
+ self.device = device
48
+ self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False,
49
+ download_root=download_root)
50
+ self.mlp = MLP(768)
51
+
52
+ if device == "cpu":
53
+ self.clip_model.float()
54
+ else:
55
+ clip.model.convert_weights(
56
+ self.clip_model) # Actually this line is unnecessary since CLIP weights are already float16 by default
57
+
58
+ # have clip.logit_scale require no grad.
59
+ self.clip_model.logit_scale.requires_grad_(False)
60
+
61
+ def score(self, prompt, image_path):
62
+
63
+ if (type(image_path).__name__ == 'list'):
64
+ _, rewards = self.inference_rank(prompt, image_path)
65
+ return rewards
66
+
67
+ # image encode
68
+ pil_image = Image.open(image_path)
69
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
70
+ image_features = F.normalize(self.clip_model.encode_image(image)).float()
71
+
72
+ # score
73
+ rewards = self.mlp(image_features)
74
+
75
+ return rewards.detach().cpu().numpy().item()
76
+
77
+ def inference_rank(self, prompt, generations_list):
78
+
79
+ img_set = []
80
+ for generations in generations_list:
81
+ # image encode
82
+ img_path = generations
83
+ pil_image = Image.open(img_path)
84
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
85
+ image_features = F.normalize(self.clip_model.encode_image(image))
86
+ img_set.append(image_features)
87
+
88
+ img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
89
+ rewards = self.mlp(img_features)
90
+ rewards = torch.squeeze(rewards)
91
+ _, rank = torch.sort(rewards, dim=0, descending=True)
92
+ _, indices = torch.sort(rank, dim=0)
93
+ indices = indices + 1
94
+
95
+ return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
ImageReward/models/BLIP/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .blip_pretrain import *
ImageReward/models/BLIP/blip.py ADDED
@@ -0,0 +1,70 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ '''
4
+
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+
8
+ import torch
9
+ import os
10
+ from urllib.parse import urlparse
11
+ from timm.models.hub import download_cached_file
12
+ from transformers import BertTokenizer
13
+ from .vit import VisionTransformer, interpolate_pos_embed
14
+
15
+
16
+ def init_tokenizer():
17
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
18
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
19
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
20
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
21
+ return tokenizer
22
+
23
+
24
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
25
+
26
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
27
+ if vit=='base':
28
+ vision_width = 768
29
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
30
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
31
+ drop_path_rate=0 or drop_path_rate
32
+ )
33
+ elif vit=='large':
34
+ vision_width = 1024
35
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
36
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
37
+ drop_path_rate=0.1 or drop_path_rate
38
+ )
39
+ return visual_encoder, vision_width
40
+
41
+
42
+ def is_url(url_or_filename):
43
+ parsed = urlparse(url_or_filename)
44
+ return parsed.scheme in ("http", "https")
45
+
46
+ def load_checkpoint(model,url_or_filename):
47
+ if is_url(url_or_filename):
48
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
49
+ checkpoint = torch.load(cached_file, map_location='cpu')
50
+ elif os.path.isfile(url_or_filename):
51
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
52
+ else:
53
+ raise RuntimeError('checkpoint url or path is invalid')
54
+
55
+ state_dict = checkpoint['model']
56
+
57
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
58
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
59
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
60
+ model.visual_encoder_m)
61
+ for key in model.state_dict().keys():
62
+ if key in state_dict.keys():
63
+ if state_dict[key].shape!=model.state_dict()[key].shape:
64
+ print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
65
+ del state_dict[key]
66
+
67
+ msg = model.load_state_dict(state_dict,strict=False)
68
+ print('load checkpoint from %s'%url_or_filename)
69
+ return model,msg
70
+
ImageReward/models/BLIP/blip_pretrain.py ADDED
@@ -0,0 +1,43 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ '''
4
+
5
+ import transformers
6
+ transformers.logging.set_verbosity_error()
7
+
8
+ from torch import nn
9
+ import os
10
+ from .med import BertConfig, BertModel
11
+ from .blip import create_vit, init_tokenizer
12
+
13
+ class BLIP_Pretrain(nn.Module):
14
+ def __init__(self,
15
+ med_config = "med_config.json",
16
+ image_size = 224,
17
+ vit = 'base',
18
+ vit_grad_ckpt = False,
19
+ vit_ckpt_layer = 0,
20
+ embed_dim = 256,
21
+ queue_size = 57600,
22
+ momentum = 0.995,
23
+ ):
24
+ """
25
+ Args:
26
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
27
+ image_size (int): input image size
28
+ vit (str): model size of vision transformer
29
+ """
30
+ super().__init__()
31
+
32
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
33
+
34
+ self.tokenizer = init_tokenizer()
35
+ encoder_config = BertConfig.from_json_file(med_config)
36
+ encoder_config.encoder_width = vision_width
37
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
38
+
39
+ text_width = self.text_encoder.config.hidden_size
40
+
41
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
42
+ self.text_proj = nn.Linear(text_width, embed_dim)
43
+
ImageReward/models/BLIP/med.py ADDED
@@ -0,0 +1,947 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ * Based on huggingface code base
4
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
5
+ '''
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ from torch import Tensor, device, nn
12
+ import torch.utils.checkpoint
13
+ from torch import nn
14
+ from torch.nn import CrossEntropyLoss
15
+
16
+ from transformers.activations import ACT2FN
17
+ from transformers.file_utils import (
18
+ ModelOutput,
19
+ )
20
+ from transformers.modeling_outputs import (
21
+ BaseModelOutputWithPastAndCrossAttentions,
22
+ BaseModelOutputWithPoolingAndCrossAttentions,
23
+ CausalLMOutputWithCrossAttentions,
24
+ MaskedLMOutput,
25
+ MultipleChoiceModelOutput,
26
+ NextSentencePredictorOutput,
27
+ QuestionAnsweringModelOutput,
28
+ SequenceClassifierOutput,
29
+ TokenClassifierOutput,
30
+ )
31
+ from transformers.modeling_utils import (
32
+ PreTrainedModel,
33
+ apply_chunking_to_forward,
34
+ find_pruneable_heads_and_indices,
35
+ prune_linear_layer,
36
+ )
37
+ from transformers.utils import logging
38
+ from transformers.models.bert.configuration_bert import BertConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ class BertEmbeddings(nn.Module):
45
+ """Construct the embeddings from word and position embeddings."""
46
+
47
+ def __init__(self, config):
48
+ super().__init__()
49
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
50
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
51
+
52
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
53
+ # any TensorFlow checkpoint file
54
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
55
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
56
+
57
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
58
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
59
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
60
+
61
+ self.config = config
62
+
63
+ def forward(
64
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
65
+ ):
66
+ if input_ids is not None:
67
+ input_shape = input_ids.size()
68
+ else:
69
+ input_shape = inputs_embeds.size()[:-1]
70
+
71
+ seq_length = input_shape[1]
72
+
73
+ if position_ids is None:
74
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
75
+
76
+ if inputs_embeds is None:
77
+ inputs_embeds = self.word_embeddings(input_ids)
78
+
79
+ embeddings = inputs_embeds
80
+
81
+ if self.position_embedding_type == "absolute":
82
+ position_embeddings = self.position_embeddings(position_ids)
83
+ embeddings += position_embeddings
84
+ embeddings = self.LayerNorm(embeddings)
85
+ embeddings = self.dropout(embeddings)
86
+ return embeddings
87
+
88
+
89
+ class BertSelfAttention(nn.Module):
90
+ def __init__(self, config, is_cross_attention):
91
+ super().__init__()
92
+ self.config = config
93
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
94
+ raise ValueError(
95
+ "The hidden size (%d) is not a multiple of the number of attention "
96
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
97
+ )
98
+
99
+ self.num_attention_heads = config.num_attention_heads
100
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
101
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
102
+
103
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
104
+ if is_cross_attention:
105
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
106
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
107
+ else:
108
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
109
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
110
+
111
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
112
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
113
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
114
+ self.max_position_embeddings = config.max_position_embeddings
115
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
116
+ self.save_attention = False
117
+
118
+ def save_attn_gradients(self, attn_gradients):
119
+ self.attn_gradients = attn_gradients
120
+
121
+ def get_attn_gradients(self):
122
+ return self.attn_gradients
123
+
124
+ def save_attention_map(self, attention_map):
125
+ self.attention_map = attention_map
126
+
127
+ def get_attention_map(self):
128
+ return self.attention_map
129
+
130
+ def transpose_for_scores(self, x):
131
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
132
+ x = x.view(*new_x_shape)
133
+ return x.permute(0, 2, 1, 3)
134
+
135
+ def forward(
136
+ self,
137
+ hidden_states,
138
+ attention_mask=None,
139
+ head_mask=None,
140
+ encoder_hidden_states=None,
141
+ encoder_attention_mask=None,
142
+ past_key_value=None,
143
+ output_attentions=False,
144
+ ):
145
+ mixed_query_layer = self.query(hidden_states)
146
+
147
+ # If this is instantiated as a cross-attention module, the keys
148
+ # and values come from an encoder; the attention mask needs to be
149
+ # such that the encoder's padding tokens are not attended to.
150
+ is_cross_attention = encoder_hidden_states is not None
151
+
152
+ if is_cross_attention:
153
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
154
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
155
+ attention_mask = encoder_attention_mask
156
+ elif past_key_value is not None:
157
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
158
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
159
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
160
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
161
+ else:
162
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
163
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
164
+
165
+ query_layer = self.transpose_for_scores(mixed_query_layer)
166
+
167
+ past_key_value = (key_layer, value_layer)
168
+
169
+ # Take the dot product between "query" and "key" to get the raw attention scores.
170
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
171
+
172
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
173
+ seq_length = hidden_states.size()[1]
174
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
175
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
176
+ distance = position_ids_l - position_ids_r
177
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
178
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
179
+
180
+ if self.position_embedding_type == "relative_key":
181
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
182
+ attention_scores = attention_scores + relative_position_scores
183
+ elif self.position_embedding_type == "relative_key_query":
184
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
185
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
186
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
187
+
188
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
189
+ if attention_mask is not None:
190
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
191
+ attention_scores = attention_scores + attention_mask
192
+
193
+ # Normalize the attention scores to probabilities.
194
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
195
+
196
+ if is_cross_attention and self.save_attention:
197
+ self.save_attention_map(attention_probs)
198
+ attention_probs.register_hook(self.save_attn_gradients)
199
+
200
+ # This is actually dropping out entire tokens to attend to, which might
201
+ # seem a bit unusual, but is taken from the original Transformer paper.
202
+ attention_probs_dropped = self.dropout(attention_probs)
203
+
204
+ # Mask heads if we want to
205
+ if head_mask is not None:
206
+ attention_probs_dropped = attention_probs_dropped * head_mask
207
+
208
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
209
+
210
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
211
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
212
+ context_layer = context_layer.view(*new_context_layer_shape)
213
+
214
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
215
+
216
+ outputs = outputs + (past_key_value,)
217
+ return outputs
218
+
219
+
220
+ class BertSelfOutput(nn.Module):
221
+ def __init__(self, config):
222
+ super().__init__()
223
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
224
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
225
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
226
+
227
+ def forward(self, hidden_states, input_tensor):
228
+ hidden_states = self.dense(hidden_states)
229
+ hidden_states = self.dropout(hidden_states)
230
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
231
+ return hidden_states
232
+
233
+
234
+ class BertAttention(nn.Module):
235
+ def __init__(self, config, is_cross_attention=False):
236
+ super().__init__()
237
+ self.self = BertSelfAttention(config, is_cross_attention)
238
+ self.output = BertSelfOutput(config)
239
+ self.pruned_heads = set()
240
+
241
+ def prune_heads(self, heads):
242
+ if len(heads) == 0:
243
+ return
244
+ heads, index = find_pruneable_heads_and_indices(
245
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
246
+ )
247
+
248
+ # Prune linear layers
249
+ self.self.query = prune_linear_layer(self.self.query, index)
250
+ self.self.key = prune_linear_layer(self.self.key, index)
251
+ self.self.value = prune_linear_layer(self.self.value, index)
252
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
253
+
254
+ # Update hyper params and store pruned heads
255
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
256
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
257
+ self.pruned_heads = self.pruned_heads.union(heads)
258
+
259
+ def forward(
260
+ self,
261
+ hidden_states,
262
+ attention_mask=None,
263
+ head_mask=None,
264
+ encoder_hidden_states=None,
265
+ encoder_attention_mask=None,
266
+ past_key_value=None,
267
+ output_attentions=False,
268
+ ):
269
+ self_outputs = self.self(
270
+ hidden_states,
271
+ attention_mask,
272
+ head_mask,
273
+ encoder_hidden_states,
274
+ encoder_attention_mask,
275
+ past_key_value,
276
+ output_attentions,
277
+ )
278
+ attention_output = self.output(self_outputs[0], hidden_states)
279
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
280
+ return outputs
281
+
282
+
283
+ class BertIntermediate(nn.Module):
284
+ def __init__(self, config):
285
+ super().__init__()
286
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
287
+ if isinstance(config.hidden_act, str):
288
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
289
+ else:
290
+ self.intermediate_act_fn = config.hidden_act
291
+
292
+ def forward(self, hidden_states):
293
+ hidden_states = self.dense(hidden_states)
294
+ hidden_states = self.intermediate_act_fn(hidden_states)
295
+ return hidden_states
296
+
297
+
298
+ class BertOutput(nn.Module):
299
+ def __init__(self, config):
300
+ super().__init__()
301
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
302
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
303
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
304
+
305
+ def forward(self, hidden_states, input_tensor):
306
+ hidden_states = self.dense(hidden_states)
307
+ hidden_states = self.dropout(hidden_states)
308
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
309
+ return hidden_states
310
+
311
+
312
+ class BertLayer(nn.Module):
313
+ def __init__(self, config, layer_num):
314
+ super().__init__()
315
+ self.config = config
316
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
317
+ self.seq_len_dim = 1
318
+ self.attention = BertAttention(config)
319
+ self.layer_num = layer_num
320
+ if self.config.add_cross_attention:
321
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
322
+ self.intermediate = BertIntermediate(config)
323
+ self.output = BertOutput(config)
324
+
325
+ def forward(
326
+ self,
327
+ hidden_states,
328
+ attention_mask=None,
329
+ head_mask=None,
330
+ encoder_hidden_states=None,
331
+ encoder_attention_mask=None,
332
+ past_key_value=None,
333
+ output_attentions=False,
334
+ mode=None,
335
+ ):
336
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
337
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
338
+ self_attention_outputs = self.attention(
339
+ hidden_states,
340
+ attention_mask,
341
+ head_mask,
342
+ output_attentions=output_attentions,
343
+ past_key_value=self_attn_past_key_value,
344
+ )
345
+ attention_output = self_attention_outputs[0]
346
+
347
+ outputs = self_attention_outputs[1:-1]
348
+ present_key_value = self_attention_outputs[-1]
349
+
350
+ if mode=='multimodal':
351
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
352
+
353
+ cross_attention_outputs = self.crossattention(
354
+ attention_output,
355
+ attention_mask,
356
+ head_mask,
357
+ encoder_hidden_states,
358
+ encoder_attention_mask,
359
+ output_attentions=output_attentions,
360
+ )
361
+ attention_output = cross_attention_outputs[0]
362
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
363
+ layer_output = apply_chunking_to_forward(
364
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
365
+ )
366
+ outputs = (layer_output,) + outputs
367
+
368
+ outputs = outputs + (present_key_value,)
369
+
370
+ return outputs
371
+
372
+ def feed_forward_chunk(self, attention_output):
373
+ intermediate_output = self.intermediate(attention_output)
374
+ layer_output = self.output(intermediate_output, attention_output)
375
+ return layer_output
376
+
377
+
378
+ class BertEncoder(nn.Module):
379
+ def __init__(self, config):
380
+ super().__init__()
381
+ self.config = config
382
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
383
+ self.gradient_checkpointing = False
384
+
385
+ def forward(
386
+ self,
387
+ hidden_states,
388
+ attention_mask=None,
389
+ head_mask=None,
390
+ encoder_hidden_states=None,
391
+ encoder_attention_mask=None,
392
+ past_key_values=None,
393
+ use_cache=None,
394
+ output_attentions=False,
395
+ output_hidden_states=False,
396
+ return_dict=True,
397
+ mode='multimodal',
398
+ ):
399
+ all_hidden_states = () if output_hidden_states else None
400
+ all_self_attentions = () if output_attentions else None
401
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
402
+
403
+ next_decoder_cache = () if use_cache else None
404
+
405
+ for i in range(self.config.num_hidden_layers):
406
+ layer_module = self.layer[i]
407
+ if output_hidden_states:
408
+ all_hidden_states = all_hidden_states + (hidden_states,)
409
+
410
+ layer_head_mask = head_mask[i] if head_mask is not None else None
411
+ past_key_value = past_key_values[i] if past_key_values is not None else None
412
+
413
+ if self.gradient_checkpointing and self.training:
414
+
415
+ if use_cache:
416
+ logger.warn(
417
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
418
+ )
419
+ use_cache = False
420
+
421
+ def create_custom_forward(module):
422
+ def custom_forward(*inputs):
423
+ return module(*inputs, past_key_value, output_attentions)
424
+
425
+ return custom_forward
426
+
427
+ layer_outputs = torch.utils.checkpoint.checkpoint(
428
+ create_custom_forward(layer_module),
429
+ hidden_states,
430
+ attention_mask,
431
+ layer_head_mask,
432
+ encoder_hidden_states,
433
+ encoder_attention_mask,
434
+ mode=mode,
435
+ )
436
+ else:
437
+ layer_outputs = layer_module(
438
+ hidden_states,
439
+ attention_mask,
440
+ layer_head_mask,
441
+ encoder_hidden_states,
442
+ encoder_attention_mask,
443
+ past_key_value,
444
+ output_attentions,
445
+ mode=mode,
446
+ )
447
+
448
+ hidden_states = layer_outputs[0]
449
+ if use_cache:
450
+ next_decoder_cache += (layer_outputs[-1],)
451
+ if output_attentions:
452
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
453
+
454
+ if output_hidden_states:
455
+ all_hidden_states = all_hidden_states + (hidden_states,)
456
+
457
+ if not return_dict:
458
+ return tuple(
459
+ v
460
+ for v in [
461
+ hidden_states,
462
+ next_decoder_cache,
463
+ all_hidden_states,
464
+ all_self_attentions,
465
+ all_cross_attentions,
466
+ ]
467
+ if v is not None
468
+ )
469
+ return BaseModelOutputWithPastAndCrossAttentions(
470
+ last_hidden_state=hidden_states,
471
+ past_key_values=next_decoder_cache,
472
+ hidden_states=all_hidden_states,
473
+ attentions=all_self_attentions,
474
+ cross_attentions=all_cross_attentions,
475
+ )
476
+
477
+
478
+ class BertPooler(nn.Module):
479
+ def __init__(self, config):
480
+ super().__init__()
481
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
482
+ self.activation = nn.Tanh()
483
+
484
+ def forward(self, hidden_states):
485
+ # We "pool" the model by simply taking the hidden state corresponding
486
+ # to the first token.
487
+ first_token_tensor = hidden_states[:, 0]
488
+ pooled_output = self.dense(first_token_tensor)
489
+ pooled_output = self.activation(pooled_output)
490
+ return pooled_output
491
+
492
+
493
+ class BertPredictionHeadTransform(nn.Module):
494
+ def __init__(self, config):
495
+ super().__init__()
496
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
497
+ if isinstance(config.hidden_act, str):
498
+ self.transform_act_fn = ACT2FN[config.hidden_act]
499
+ else:
500
+ self.transform_act_fn = config.hidden_act
501
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
502
+
503
+ def forward(self, hidden_states):
504
+ hidden_states = self.dense(hidden_states)
505
+ hidden_states = self.transform_act_fn(hidden_states)
506
+ hidden_states = self.LayerNorm(hidden_states)
507
+ return hidden_states
508
+
509
+
510
+ class BertLMPredictionHead(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.transform = BertPredictionHeadTransform(config)
514
+
515
+ # The output weights are the same as the input embeddings, but there is
516
+ # an output-only bias for each token.
517
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
518
+
519
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
520
+
521
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
522
+ self.decoder.bias = self.bias
523
+
524
+ def forward(self, hidden_states):
525
+ hidden_states = self.transform(hidden_states)
526
+ hidden_states = self.decoder(hidden_states)
527
+ return hidden_states
528
+
529
+
530
+ class BertOnlyMLMHead(nn.Module):
531
+ def __init__(self, config):
532
+ super().__init__()
533
+ self.predictions = BertLMPredictionHead(config)
534
+
535
+ def forward(self, sequence_output):
536
+ prediction_scores = self.predictions(sequence_output)
537
+ return prediction_scores
538
+
539
+
540
+ class BertPreTrainedModel(PreTrainedModel):
541
+ """
542
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
543
+ models.
544
+ """
545
+
546
+ config_class = BertConfig
547
+ base_model_prefix = "bert"
548
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
549
+
550
+ def _init_weights(self, module):
551
+ """ Initialize the weights """
552
+ if isinstance(module, (nn.Linear, nn.Embedding)):
553
+ # Slightly different from the TF version which uses truncated_normal for initialization
554
+ # cf https://github.com/pytorch/pytorch/pull/5617
555
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
556
+ elif isinstance(module, nn.LayerNorm):
557
+ module.bias.data.zero_()
558
+ module.weight.data.fill_(1.0)
559
+ if isinstance(module, nn.Linear) and module.bias is not None:
560
+ module.bias.data.zero_()
561
+
562
+
563
+ class BertModel(BertPreTrainedModel):
564
+ """
565
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
566
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
567
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
568
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
569
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
570
+ input to the forward pass.
571
+ """
572
+
573
+ def __init__(self, config, add_pooling_layer=True):
574
+ super().__init__(config)
575
+ self.config = config
576
+
577
+ self.embeddings = BertEmbeddings(config)
578
+
579
+ self.encoder = BertEncoder(config)
580
+
581
+ self.pooler = BertPooler(config) if add_pooling_layer else None
582
+
583
+ self.init_weights()
584
+
585
+
586
+ def get_input_embeddings(self):
587
+ return self.embeddings.word_embeddings
588
+
589
+ def set_input_embeddings(self, value):
590
+ self.embeddings.word_embeddings = value
591
+
592
+ def _prune_heads(self, heads_to_prune):
593
+ """
594
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
595
+ class PreTrainedModel
596
+ """
597
+ for layer, heads in heads_to_prune.items():
598
+ self.encoder.layer[layer].attention.prune_heads(heads)
599
+
600
+
601
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
602
+ """
603
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
604
+
605
+ Arguments:
606
+ attention_mask (:obj:`torch.Tensor`):
607
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
608
+ input_shape (:obj:`Tuple[int]`):
609
+ The shape of the input to the model.
610
+ device: (:obj:`torch.device`):
611
+ The device of the input to the model.
612
+
613
+ Returns:
614
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
615
+ """
616
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
617
+ # ourselves in which case we just need to make it broadcastable to all heads.
618
+ if attention_mask.dim() == 3:
619
+ extended_attention_mask = attention_mask[:, None, :, :]
620
+ elif attention_mask.dim() == 2:
621
+ # Provided a padding mask of dimensions [batch_size, seq_length]
622
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
623
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
624
+ if is_decoder:
625
+ batch_size, seq_length = input_shape
626
+
627
+ seq_ids = torch.arange(seq_length, device=device)
628
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
629
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
630
+ # causal and attention masks must have same type with pytorch version < 1.3
631
+ causal_mask = causal_mask.to(attention_mask.dtype)
632
+
633
+ if causal_mask.shape[1] < attention_mask.shape[1]:
634
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
635
+ causal_mask = torch.cat(
636
+ [
637
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
638
+ causal_mask,
639
+ ],
640
+ axis=-1,
641
+ )
642
+
643
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
644
+ else:
645
+ extended_attention_mask = attention_mask[:, None, None, :]
646
+ else:
647
+ raise ValueError(
648
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
649
+ input_shape, attention_mask.shape
650
+ )
651
+ )
652
+
653
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
654
+ # masked positions, this operation will create a tensor which is 0.0 for
655
+ # positions we want to attend and -10000.0 for masked positions.
656
+ # Since we are adding it to the raw scores before the softmax, this is
657
+ # effectively the same as removing these entirely.
658
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
659
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
660
+ return extended_attention_mask
661
+
662
+ def forward(
663
+ self,
664
+ input_ids=None,
665
+ attention_mask=None,
666
+ position_ids=None,
667
+ head_mask=None,
668
+ inputs_embeds=None,
669
+ encoder_embeds=None,
670
+ encoder_hidden_states=None,
671
+ encoder_attention_mask=None,
672
+ past_key_values=None,
673
+ use_cache=None,
674
+ output_attentions=None,
675
+ output_hidden_states=None,
676
+ return_dict=None,
677
+ is_decoder=False,
678
+ mode='multimodal',
679
+ ):
680
+ r"""
681
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
682
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
683
+ the model is configured as a decoder.
684
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
685
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
686
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
687
+ - 1 for tokens that are **not masked**,
688
+ - 0 for tokens that are **masked**.
689
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
690
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
691
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
692
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
693
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
694
+ use_cache (:obj:`bool`, `optional`):
695
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
696
+ decoding (see :obj:`past_key_values`).
697
+ """
698
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
699
+ output_hidden_states = (
700
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
701
+ )
702
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
703
+
704
+ if is_decoder:
705
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
706
+ else:
707
+ use_cache = False
708
+
709
+ if input_ids is not None and inputs_embeds is not None:
710
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
711
+ elif input_ids is not None:
712
+ input_shape = input_ids.size()
713
+ batch_size, seq_length = input_shape
714
+ device = input_ids.device
715
+ elif inputs_embeds is not None:
716
+ input_shape = inputs_embeds.size()[:-1]
717
+ batch_size, seq_length = input_shape
718
+ device = inputs_embeds.device
719
+ elif encoder_embeds is not None:
720
+ input_shape = encoder_embeds.size()[:-1]
721
+ batch_size, seq_length = input_shape
722
+ device = encoder_embeds.device
723
+ else:
724
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
725
+
726
+ # past_key_values_length
727
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
728
+
729
+ if attention_mask is None:
730
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
731
+
732
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
733
+ # ourselves in which case we just need to make it broadcastable to all heads.
734
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
735
+ device, is_decoder)
736
+
737
+ # If a 2D or 3D attention mask is provided for the cross-attention
738
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
739
+ if encoder_hidden_states is not None:
740
+ if type(encoder_hidden_states) == list:
741
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
742
+ else:
743
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
744
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
745
+
746
+ if type(encoder_attention_mask) == list:
747
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
748
+ elif encoder_attention_mask is None:
749
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
750
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
751
+ else:
752
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
753
+ else:
754
+ encoder_extended_attention_mask = None
755
+
756
+ # Prepare head mask if needed
757
+ # 1.0 in head_mask indicate we keep the head
758
+ # attention_probs has shape bsz x n_heads x N x N
759
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
760
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
761
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
762
+
763
+ if encoder_embeds is None:
764
+ embedding_output = self.embeddings(
765
+ input_ids=input_ids,
766
+ position_ids=position_ids,
767
+ inputs_embeds=inputs_embeds,
768
+ past_key_values_length=past_key_values_length,
769
+ )
770
+ else:
771
+ embedding_output = encoder_embeds
772
+
773
+ encoder_outputs = self.encoder(
774
+ embedding_output,
775
+ attention_mask=extended_attention_mask,
776
+ head_mask=head_mask,
777
+ encoder_hidden_states=encoder_hidden_states,
778
+ encoder_attention_mask=encoder_extended_attention_mask,
779
+ past_key_values=past_key_values,
780
+ use_cache=use_cache,
781
+ output_attentions=output_attentions,
782
+ output_hidden_states=output_hidden_states,
783
+ return_dict=return_dict,
784
+ mode=mode,
785
+ )
786
+ sequence_output = encoder_outputs[0]
787
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
788
+
789
+ if not return_dict:
790
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
791
+
792
+ return BaseModelOutputWithPoolingAndCrossAttentions(
793
+ last_hidden_state=sequence_output,
794
+ pooler_output=pooled_output,
795
+ past_key_values=encoder_outputs.past_key_values,
796
+ hidden_states=encoder_outputs.hidden_states,
797
+ attentions=encoder_outputs.attentions,
798
+ cross_attentions=encoder_outputs.cross_attentions,
799
+ )
800
+
801
+
802
+
803
+ class BertLMHeadModel(BertPreTrainedModel):
804
+
805
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
806
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
807
+
808
+ def __init__(self, config):
809
+ super().__init__(config)
810
+
811
+ self.bert = BertModel(config, add_pooling_layer=False)
812
+ self.cls = BertOnlyMLMHead(config)
813
+
814
+ self.init_weights()
815
+
816
+ def get_output_embeddings(self):
817
+ return self.cls.predictions.decoder
818
+
819
+ def set_output_embeddings(self, new_embeddings):
820
+ self.cls.predictions.decoder = new_embeddings
821
+
822
+ def forward(
823
+ self,
824
+ input_ids=None,
825
+ attention_mask=None,
826
+ position_ids=None,
827
+ head_mask=None,
828
+ inputs_embeds=None,
829
+ encoder_hidden_states=None,
830
+ encoder_attention_mask=None,
831
+ labels=None,
832
+ past_key_values=None,
833
+ use_cache=None,
834
+ output_attentions=None,
835
+ output_hidden_states=None,
836
+ return_dict=None,
837
+ return_logits=False,
838
+ is_decoder=True,
839
+ reduction='mean',
840
+ mode='multimodal',
841
+ ):
842
+ r"""
843
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
844
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
845
+ the model is configured as a decoder.
846
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
847
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
848
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
849
+ - 1 for tokens that are **not masked**,
850
+ - 0 for tokens that are **masked**.
851
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
852
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
853
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
854
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
855
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
856
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
857
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
858
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
859
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
860
+ use_cache (:obj:`bool`, `optional`):
861
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
862
+ decoding (see :obj:`past_key_values`).
863
+ Returns:
864
+ Example::
865
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
866
+ >>> import torch
867
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
868
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
869
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
870
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
871
+ >>> outputs = model(**inputs)
872
+ >>> prediction_logits = outputs.logits
873
+ """
874
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
+ if labels is not None:
876
+ use_cache = False
877
+
878
+ outputs = self.bert(
879
+ input_ids,
880
+ attention_mask=attention_mask,
881
+ position_ids=position_ids,
882
+ head_mask=head_mask,
883
+ inputs_embeds=inputs_embeds,
884
+ encoder_hidden_states=encoder_hidden_states,
885
+ encoder_attention_mask=encoder_attention_mask,
886
+ past_key_values=past_key_values,
887
+ use_cache=use_cache,
888
+ output_attentions=output_attentions,
889
+ output_hidden_states=output_hidden_states,
890
+ return_dict=return_dict,
891
+ is_decoder=is_decoder,
892
+ mode=mode,
893
+ )
894
+
895
+ sequence_output = outputs[0]
896
+ prediction_scores = self.cls(sequence_output)
897
+
898
+ if return_logits:
899
+ return prediction_scores[:, :-1, :].contiguous()
900
+
901
+ lm_loss = None
902
+ if labels is not None:
903
+ # we are doing next-token prediction; shift prediction scores and input ids by one
904
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
905
+ labels = labels[:, 1:].contiguous()
906
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
907
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
908
+ if reduction=='none':
909
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
910
+
911
+ if not return_dict:
912
+ output = (prediction_scores,) + outputs[2:]
913
+ return ((lm_loss,) + output) if lm_loss is not None else output
914
+
915
+ return CausalLMOutputWithCrossAttentions(
916
+ loss=lm_loss,
917
+ logits=prediction_scores,
918
+ past_key_values=outputs.past_key_values,
919
+ hidden_states=outputs.hidden_states,
920
+ attentions=outputs.attentions,
921
+ cross_attentions=outputs.cross_attentions,
922
+ )
923
+
924
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
925
+ input_shape = input_ids.shape
926
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
927
+ if attention_mask is None:
928
+ attention_mask = input_ids.new_ones(input_shape)
929
+
930
+ # cut decoder_input_ids if past is used
931
+ if past is not None:
932
+ input_ids = input_ids[:, -1:]
933
+
934
+ return {
935
+ "input_ids": input_ids,
936
+ "attention_mask": attention_mask,
937
+ "past_key_values": past,
938
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
939
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
940
+ "is_decoder": True,
941
+ }
942
+
943
+ def _reorder_cache(self, past, beam_idx):
944
+ reordered_past = ()
945
+ for layer_past in past:
946
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
947
+ return reordered_past
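For orientation (not part of the diff): the `mode` argument threaded through `BertEncoder`/`BertLayer` is what lets a single BERT stack serve both as a pure text encoder and as an image-grounded encoder with cross-attention. A hedged sketch with dummy tensors, assuming a `med_config.json` whose `add_cross_attention` is true and whose `encoder_width` matches the vision features:

    import torch

    config = BertConfig.from_json_file("med_config.json")  # placeholder path, add_cross_attention=True assumed
    config.encoder_width = 1024                             # width of the vision features below
    text_encoder = BertModel(config=config, add_pooling_layer=False)

    input_ids = torch.randint(0, config.vocab_size, (1, 35))
    attention_mask = torch.ones(1, 35, dtype=torch.long)
    image_embeds = torch.randn(1, 197, 1024)                # e.g. a ViT at 224px: 196 patches + [CLS]
    image_atts = torch.ones(1, 197, dtype=torch.long)

    # Unimodal pass: only the self-attention layers run; cross-attention is skipped.
    text_only = text_encoder(input_ids, attention_mask=attention_mask, mode='text')

    # Multimodal pass: every layer additionally cross-attends to the image tokens.
    fused = text_encoder(input_ids, attention_mask=attention_mask,
                         encoder_hidden_states=image_embeds,
                         encoder_attention_mask=image_atts, mode='multimodal')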
ImageReward/models/BLIP/vit.py ADDED
@@ -0,0 +1,301 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ * Based on timm code base
4
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
5
+ '''
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from functools import partial
11
+
12
+ from timm.models.vision_transformer import _cfg, PatchEmbed
13
+ from timm.models.registry import register_model
14
+ from timm.models.layers import trunc_normal_, DropPath
15
+ from timm.models.helpers import named_apply, adapt_input_conv
16
+
17
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
18
+
19
+ class Mlp(nn.Module):
20
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
21
+ """
22
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.fc1 = nn.Linear(in_features, hidden_features)
27
+ self.act = act_layer()
28
+ self.fc2 = nn.Linear(hidden_features, out_features)
29
+ self.drop = nn.Dropout(drop)
30
+
31
+ def forward(self, x):
32
+ x = self.fc1(x)
33
+ x = self.act(x)
34
+ x = self.drop(x)
35
+ x = self.fc2(x)
36
+ x = self.drop(x)
37
+ return x
38
+
39
+
40
+ class Attention(nn.Module):
41
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
42
+ super().__init__()
43
+ self.num_heads = num_heads
44
+ head_dim = dim // num_heads
45
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
46
+ self.scale = qk_scale or head_dim ** -0.5
47
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
48
+ self.attn_drop = nn.Dropout(attn_drop)
49
+ self.proj = nn.Linear(dim, dim)
50
+ self.proj_drop = nn.Dropout(proj_drop)
51
+ self.attn_gradients = None
52
+ self.attention_map = None
53
+
54
+ def save_attn_gradients(self, attn_gradients):
55
+ self.attn_gradients = attn_gradients
56
+
57
+ def get_attn_gradients(self):
58
+ return self.attn_gradients
59
+
60
+ def save_attention_map(self, attention_map):
61
+ self.attention_map = attention_map
62
+
63
+ def get_attention_map(self):
64
+ return self.attention_map
65
+
66
+ def forward(self, x, register_hook=False):
67
+ B, N, C = x.shape
68
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
69
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
70
+
71
+ attn = (q @ k.transpose(-2, -1)) * self.scale
72
+ attn = attn.softmax(dim=-1)
73
+ attn = self.attn_drop(attn)
74
+
75
+ if register_hook:
76
+ self.save_attention_map(attn)
77
+ attn.register_hook(self.save_attn_gradients)
78
+
79
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
80
+ x = self.proj(x)
81
+ x = self.proj_drop(x)
82
+ return x
83
+
84
+
85
+ class Block(nn.Module):
86
+
87
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
88
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
89
+ super().__init__()
90
+ self.norm1 = norm_layer(dim)
91
+ self.attn = Attention(
92
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
93
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
94
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
95
+ self.norm2 = norm_layer(dim)
96
+ mlp_hidden_dim = int(dim * mlp_ratio)
97
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
98
+
99
+ if use_grad_checkpointing:
100
+ self.attn = checkpoint_wrapper(self.attn)
101
+ self.mlp = checkpoint_wrapper(self.mlp)
102
+
103
+ def forward(self, x, register_hook=False):
104
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
105
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
106
+ return x
107
+
108
+
109
+ class VisionTransformer(nn.Module):
110
+ """ Vision Transformer
111
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
112
+ https://arxiv.org/abs/2010.11929
113
+ """
114
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
115
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
116
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
117
+ use_grad_checkpointing=False, ckpt_layer=0):
118
+ """
119
+ Args:
120
+ img_size (int, tuple): input image size
121
+ patch_size (int, tuple): patch size
122
+ in_chans (int): number of input channels
123
+ num_classes (int): number of classes for classification head
124
+ embed_dim (int): embedding dimension
125
+ depth (int): depth of transformer
126
+ num_heads (int): number of attention heads
127
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
128
+ qkv_bias (bool): enable bias for qkv if True
129
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
130
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
131
+ drop_rate (float): dropout rate
132
+ attn_drop_rate (float): attention dropout rate
133
+ drop_path_rate (float): stochastic depth rate
134
+ norm_layer: (nn.Module): normalization layer
135
+ """
136
+ super().__init__()
137
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
138
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
139
+
140
+ self.patch_embed = PatchEmbed(
141
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
142
+
143
+ num_patches = self.patch_embed.num_patches
144
+
145
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
146
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
147
+ self.pos_drop = nn.Dropout(p=drop_rate)
148
+
149
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
150
+ self.blocks = nn.ModuleList([
151
+ Block(
152
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
153
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
154
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
155
+ )
156
+ for i in range(depth)])
157
+ self.norm = norm_layer(embed_dim)
158
+
159
+ trunc_normal_(self.pos_embed, std=.02)
160
+ trunc_normal_(self.cls_token, std=.02)
161
+ self.apply(self._init_weights)
162
+
163
+ def _init_weights(self, m):
164
+ if isinstance(m, nn.Linear):
165
+ trunc_normal_(m.weight, std=.02)
166
+ if isinstance(m, nn.Linear) and m.bias is not None:
167
+ nn.init.constant_(m.bias, 0)
168
+ elif isinstance(m, nn.LayerNorm):
169
+ nn.init.constant_(m.bias, 0)
170
+ nn.init.constant_(m.weight, 1.0)
171
+
172
+ @torch.jit.ignore
173
+ def no_weight_decay(self):
174
+ return {'pos_embed', 'cls_token'}
175
+
176
+ def forward(self, x, register_blk=-1):
177
+ B = x.shape[0]
178
+ x = self.patch_embed(x)
179
+
180
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
181
+ x = torch.cat((cls_tokens, x), dim=1)
182
+
183
+ x = x + self.pos_embed[:,:x.size(1),:]
184
+ x = self.pos_drop(x)
185
+
186
+ for i,blk in enumerate(self.blocks):
187
+ x = blk(x, register_blk==i)
188
+ x = self.norm(x)
189
+
190
+ return x
191
+
192
+ @torch.jit.ignore()
193
+ def load_pretrained(self, checkpoint_path, prefix=''):
194
+ _load_weights(self, checkpoint_path, prefix)
195
+
196
+
197
+ @torch.no_grad()
198
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
199
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
200
+ """
201
+ import numpy as np
202
+
203
+ def _n2p(w, t=True):
204
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
205
+ w = w.flatten()
206
+ if t:
207
+ if w.ndim == 4:
208
+ w = w.transpose([3, 2, 0, 1])
209
+ elif w.ndim == 3:
210
+ w = w.transpose([2, 0, 1])
211
+ elif w.ndim == 2:
212
+ w = w.transpose([1, 0])
213
+ return torch.from_numpy(w)
214
+
215
+ w = np.load(checkpoint_path)
216
+ if not prefix and 'opt/target/embedding/kernel' in w:
217
+ prefix = 'opt/target/'
218
+
219
+ if hasattr(model.patch_embed, 'backbone'):
220
+ # hybrid
221
+ backbone = model.patch_embed.backbone
222
+ stem_only = not hasattr(backbone, 'stem')
223
+ stem = backbone if stem_only else backbone.stem
224
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
225
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
226
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
227
+ if not stem_only:
228
+ for i, stage in enumerate(backbone.stages):
229
+ for j, block in enumerate(stage.blocks):
230
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
231
+ for r in range(3):
232
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
233
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
234
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
235
+ if block.downsample is not None:
236
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
237
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
238
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
239
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
240
+ else:
241
+ embed_conv_w = adapt_input_conv(
242
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
243
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
244
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
245
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
246
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
247
+ if pos_embed_w.shape != model.pos_embed.shape:
248
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
249
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
250
+ model.pos_embed.copy_(pos_embed_w)
251
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
252
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
253
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
254
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
255
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
256
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
257
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
258
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
259
+ for i, block in enumerate(model.blocks.children()):
260
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
261
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
262
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
263
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
264
+ block.attn.qkv.weight.copy_(torch.cat([
265
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
266
+ block.attn.qkv.bias.copy_(torch.cat([
267
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
268
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
269
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
270
+ for r in range(2):
271
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
272
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
273
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
274
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
275
+
276
+
277
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
278
+ # interpolate position embedding
279
+ embedding_size = pos_embed_checkpoint.shape[-1]
280
+ num_patches = visual_encoder.patch_embed.num_patches
281
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
282
+ # height (== width) for the checkpoint position embedding
283
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
284
+ # height (== width) for the new position embedding
285
+ new_size = int(num_patches ** 0.5)
286
+
287
+ if orig_size!=new_size:
288
+ # class_token and dist_token are kept unchanged
289
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
290
+ # only the position tokens are interpolated
291
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
292
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
293
+ pos_tokens = torch.nn.functional.interpolate(
294
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
295
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
296
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
297
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
298
+
299
+ return new_pos_embed
300
+ else:
301
+ return pos_embed_checkpoint
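A short sketch of how `interpolate_pos_embed` above is typically applied when a checkpoint trained at one resolution is loaded into a model built for another; the checkpoint file below is a placeholder:

    import torch

    # Model built for 384px inputs (24x24 patch grid), checkpoint trained at 224px (14x14 grid).
    vit_384 = VisionTransformer(img_size=384, patch_size=16, embed_dim=768,
                                depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True)
    state_dict = torch.load("vit_base_224.pth", map_location="cpu")   # placeholder checkpoint
    state_dict["pos_embed"] = interpolate_pos_embed(state_dict["pos_embed"], vit_384)
    msg = vit_384.load_state_dict(state_dict, strict=False)           # strict=False tolerates missing head keys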
ImageReward/models/BLIPScore.py ADDED
@@ -0,0 +1,97 @@
+ '''
+ @File : BLIPScore.py
+ @Time : 2023/02/19 20:48:00
+ @Author : Jiazheng Xu
+ @Contact : [email protected]
+ @Description: BLIPScore.
+ * Based on BLIP code base
+ * https://github.com/salesforce/BLIP
+ '''
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from PIL import Image
+ from ImageReward.models.BLIP.blip_pretrain import BLIP_Pretrain
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+
+ try:
+     from torchvision.transforms import InterpolationMode
+     BICUBIC = InterpolationMode.BICUBIC
+ except ImportError:
+     BICUBIC = Image.BICUBIC
+
+
+ def _convert_image_to_rgb(image):
+     return image.convert("RGB")
+
+
+ def _transform(n_px):
+     return Compose([
+         Resize(n_px, interpolation=BICUBIC),
+         CenterCrop(n_px),
+         _convert_image_to_rgb,
+         ToTensor(),
+         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+     ])
+
+
+ class BLIPScore(nn.Module):
+     def __init__(self, med_config, device='cpu'):
+         super().__init__()
+         self.device = device
+
+         self.preprocess = _transform(224)
+         self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
+
+
+     def score(self, prompt, image_path):
+
+         if (type(image_path).__name__=='list'):
+             _, rewards = self.inference_rank(prompt, image_path)
+             return rewards
+
+         # text encode
+         text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
+         text_output = self.blip.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
+         txt_feature = F.normalize(self.blip.text_proj(text_output.last_hidden_state[:,0,:]))
+
+         # image encode
+         pil_image = Image.open(image_path)
+         image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+         image_embeds = self.blip.visual_encoder(image)
+         image_features = F.normalize(self.blip.vision_proj(image_embeds[:,0,:]), dim=-1)
+
+         # score
+         rewards = torch.sum(torch.mul(txt_feature, image_features), dim=1, keepdim=True)
+
+         return rewards.detach().cpu().numpy().item()
+
+
+     def inference_rank(self, prompt, generations_list):
+
+         text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
+         text_output = self.blip.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
+         txt_feature = F.normalize(self.blip.text_proj(text_output.last_hidden_state[:,0,:]))
+
+         txt_set = []
+         img_set = []
+         for generations in generations_list:
+             # image encode
+             img_path = generations
+             pil_image = Image.open(img_path)
+             image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+             image_embeds = self.blip.visual_encoder(image)
+             image_features = F.normalize(self.blip.vision_proj(image_embeds[:,0,:]), dim=-1)
+             img_set.append(image_features)
+             txt_set.append(txt_feature)
+
+         txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
+         img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
+         rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True)
+         rewards = torch.squeeze(rewards)
+         _, rank = torch.sort(rewards, dim=0, descending=True)
+         _, indices = torch.sort(rank, dim=0)
+         indices = indices + 1
+
+         return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
ImageReward/models/CLIPScore.py ADDED
@@ -0,0 +1,78 @@
1
+ '''
2
+ @File : CLIPScore.py
3
+ @Time : 2023/02/12 13:14:00
4
+ @Author : Jiazheng Xu
5
+ @Contact : [email protected]
6
+ @Description: CLIPScore.
7
+ * Based on CLIP code base
8
+ * https://github.com/openai/CLIP
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from PIL import Image
15
+ import clip
16
+
17
+ class CLIPScore(nn.Module):
18
+ def __init__(self, download_root, device='cpu'):
19
+ super().__init__()
20
+ self.device = device
21
+ self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False,
22
+ download_root=download_root)
23
+
24
+ if device == "cpu":
25
+ self.clip_model.float()
26
+ else:
27
+ clip.model.convert_weights(self.clip_model) # convert weights to fp16; technically unnecessary since CLIP already loads in float16 on GPU
28
+
29
+ # freeze logit_scale so it receives no gradient
30
+ self.clip_model.logit_scale.requires_grad_(False)
31
+
32
+
33
+ def score(self, prompt, image_path):
34
+
35
+ if (type(image_path).__name__=='list'):
36
+ _, rewards = self.inference_rank(prompt, image_path)
37
+ return rewards
38
+
39
+ # text encode
40
+ text = clip.tokenize(prompt, truncate=True).to(self.device)
41
+ txt_features = F.normalize(self.clip_model.encode_text(text))
42
+
43
+ # image encode
44
+ pil_image = Image.open(image_path)
45
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
46
+ image_features = F.normalize(self.clip_model.encode_image(image))
47
+
48
+ # score
49
+ rewards = torch.sum(torch.mul(txt_features, image_features), dim=1, keepdim=True)
50
+
51
+ return rewards.detach().cpu().numpy().item()
52
+
53
+
54
+ def inference_rank(self, prompt, generations_list):
55
+
56
+ text = clip.tokenize(prompt, truncate=True).to(self.device)
57
+ txt_feature = F.normalize(self.clip_model.encode_text(text))
58
+
59
+ txt_set = []
60
+ img_set = []
61
+ for generations in generations_list:
62
+ # image encode
63
+ img_path = generations
64
+ pil_image = Image.open(img_path)
65
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
66
+ image_features = F.normalize(self.clip_model.encode_image(image))
67
+ img_set.append(image_features)
68
+ txt_set.append(txt_feature)
69
+
70
+ txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
71
+ img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
72
+ rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True)
73
+ rewards = torch.squeeze(rewards)
74
+ _, rank = torch.sort(rewards, dim=0, descending=True)
75
+ _, indices = torch.sort(rank, dim=0)
76
+ indices = indices + 1
77
+
78
+ return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
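A minimal usage sketch for the two score models above, assuming the vendored `ImageReward` package is importable from the repository root and constructed through the `load_score` helper added in `ImageReward/utils.py` later in this commit (prompt and image paths are placeholders):

```python
import ImageReward as RM

# load_score downloads the CLIP/BLIP checkpoint on first use and returns a scorer in eval mode
clip_scorer = RM.load_score("CLIP", device="cpu")

# a single path yields one scalar; a list of paths yields one score per image (via inference_rank)
print(clip_scorer.score("a photo of a horse", "data/horse.png"))
print(clip_scorer.score("a photo of a horse", ["data/horse.png", "data/ballerina.png"]))
```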
ImageReward/models/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .AestheticScore import *
2
+ from .BLIPScore import *
3
+ from .CLIPScore import *
4
+ from .BLIP import *
ImageReward/utils.py ADDED
@@ -0,0 +1,184 @@
1
+ '''
2
+ @File : utils.py
3
+ @Time : 2023/04/05 19:18:00
4
+ @Author : Jiazheng Xu
5
+ @Contact : [email protected]
6
+ * Based on CLIP code base
7
+ * https://github.com/openai/CLIP
8
+ * Checkpoint of CLIP/BLIP/Aesthetic are from:
9
+ * https://github.com/openai/CLIP
10
+ * https://github.com/salesforce/BLIP
11
+ * https://github.com/christophschuhmann/improved-aesthetic-predictor
12
+ '''
13
+
14
+ import os
15
+ import urllib
16
+ from typing import Union, List
17
+ import pathlib
18
+
19
+ import torch
20
+ from tqdm import tqdm
21
+ from huggingface_hub import hf_hub_download
22
+
23
+ from .ImageReward import ImageReward
24
+ from .models.CLIPScore import CLIPScore
25
+ from .models.BLIPScore import BLIPScore
26
+ from .models.AestheticScore import AestheticScore
27
+
28
+ _MODELS = {
29
+ "ImageReward-v1.0": "https://huggingface.co/THUDM/ImageReward/blob/main/ImageReward.pt",
30
+ }
31
+
32
+
33
+ def available_models() -> List[str]:
34
+ """Returns the names of available ImageReward models"""
35
+ return list(_MODELS.keys())
36
+
37
+
38
+ def ImageReward_download(url: str, root: str):
39
+ os.makedirs(root, exist_ok=True)
40
+ filename = os.path.basename(url)
41
+ download_target = os.path.join(root, filename)
42
+ hf_hub_download(repo_id="THUDM/ImageReward", filename=filename, local_dir=root)
43
+ return download_target
44
+
45
+
46
+ def load(name: str = "ImageReward-v1.0",
47
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
48
+ download_root: str = None,
49
+ med_config_path: str = None):
50
+ """Load a ImageReward model
51
+
52
+ Parameters
53
+ ----------
54
+ name: str
55
+ A model name listed by `ImageReward.available_models()`, or the path to a model checkpoint containing the state_dict
56
+ device: Union[str, torch.device]
57
+ The device to put the loaded model
58
+ download_root: str
59
+ path to download the model files; by default, it uses "~/.cache/ImageReward"
60
+ med_config_path: str
+ path to BLIP's med_config.json; if None, it is downloaded to the cache directory
61
+
62
+ Returns
63
+ -------
64
+ model : torch.nn.Module
65
+ The ImageReward model
66
+ """
67
+ if name in _MODELS:
68
+ download_root = download_root or "~/.cache/ImageReward"
69
+ download_root = pathlib.Path(download_root)
70
+ model_path = pathlib.Path(download_root) / 'ImageReward.pt'
71
+
72
+ if not model_path.exists():
73
+ model_path = ImageReward_download(_MODELS[name], root=download_root.as_posix())
74
+ elif os.path.isfile(name):
75
+ model_path = name
76
+ else:
77
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
78
+
79
+ print('-> load ImageReward model from %s' % model_path)
80
+ state_dict = torch.load(model_path, map_location='cpu')
81
+
82
+ # med_config
83
+ if med_config_path is None:
84
+ med_config_root = download_root or "~/.cache/ImageReward"
85
+ med_config_root = pathlib.Path(med_config_root)
86
+ med_config_path = med_config_root / 'med_config.json'
87
+
88
+ if not med_config_path.exists():
89
+ med_config_path = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
90
+ root=med_config_root.as_posix())
91
+ print('-> load ImageReward med_config from %s' % med_config_path)
92
+
93
+ model = ImageReward(device=device, med_config=med_config_path).to(device)
94
+ msg = model.load_state_dict(state_dict, strict=False)
95
+ model.eval()
96
+
97
+ return model
98
+
99
+
100
+ _SCORES = {
101
+ "CLIP": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
102
+ "BLIP": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth",
103
+ "Aesthetic": "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac%2Blogos%2Bava1-l14-linearMSE.pth",
104
+ }
105
+
106
+
107
+ def available_scores() -> List[str]:
108
+ """Returns the names of available ImageReward scores"""
109
+ return list(_SCORES.keys())
110
+
111
+
112
+ def _download(url: str, root: str):
113
+ os.makedirs(root, exist_ok=True)
114
+ filename = os.path.basename(url)
115
+
116
+ download_target = os.path.join(root, filename)
117
+
118
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
119
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
120
+
121
+ if os.path.isfile(download_target):
122
+ return download_target
123
+
124
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
125
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
126
+ unit_divisor=1024) as loop:
127
+ while True:
128
+ buffer = source.read(8192)
129
+ if not buffer:
130
+ break
131
+
132
+ output.write(buffer)
133
+ loop.update(len(buffer))
134
+
135
+ return download_target
136
+
137
+
138
+ def load_score(name: str = "CLIP", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
139
+ download_root: str = None):
140
+ """Load a ImageReward model
141
+
142
+ Parameters
143
+ ----------
144
+ name : str
145
+ A score name listed by `ImageReward.available_scores()`
146
+
147
+ device : Union[str, torch.device]
148
+ The device to put the loaded model
149
+
150
+ download_root: str
151
+ path to download the model files; by default, it uses "~/.cache/ImageReward"
152
+
153
+ Returns
154
+ -------
155
+ model : torch.nn.Module
156
+ The requested score model
157
+ """
158
+ model_download_root = download_root or os.path.expanduser("~/.cache/ImageReward")
159
+
160
+ if name in _SCORES:
161
+ model_path = _download(_SCORES[name], model_download_root)
162
+ else:
163
+ raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")
164
+
165
+ print('load checkpoint from %s' % model_path)
166
+ if name == "BLIP":
167
+ state_dict = torch.load(model_path, map_location='cpu')
168
+ med_config = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
169
+ model_download_root)
170
+ model = BLIPScore(med_config=med_config, device=device).to(device)
171
+ model.blip.load_state_dict(state_dict['model'], strict=False)
172
+ elif name == "CLIP":
173
+ model = CLIPScore(download_root=model_download_root, device=device).to(device)
174
+ elif name == "Aesthetic":
175
+ state_dict = torch.load(model_path, map_location='cpu')
176
+ model = AestheticScore(download_root=model_download_root, device=device).to(device)
177
+ model.mlp.load_state_dict(state_dict, strict=False)
178
+ else:
179
+ raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")
180
+
181
+ print("checkpoint loaded")
182
+ model.eval()
183
+
184
+ return model
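A usage sketch for `load`, assuming the `ImageReward` model class added earlier in this commit exposes the same `score`/`inference_rank` interface as the score models above (prompt and paths are placeholders):

```python
import torch
import ImageReward as RM

# downloads ImageReward.pt and med_config.json into ~/.cache/ImageReward on first use
model = RM.load("ImageReward-v1.0")

with torch.no_grad():
    # rank several candidate renders of the same prompt; the returned ranking is 1-based
    ranking, rewards = model.inference_rank(
        "a photo of Sydney opera house",
        ["render_1.png", "render_2.png"],
    )
print(ranking, rewards)
```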
Install.md ADDED
@@ -0,0 +1,66 @@
1
+ ## Installation
2
+
3
+ Create a new conda environment:
4
+
5
+ ```shell
6
+ conda create --name svgrender python=3.10
7
+ conda activate svgrender
8
+ ```
9
+
10
+ Install pytorch and the following libraries:
11
+
12
+ ```shell
13
+ conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
14
+ pip install hydra-core omegaconf
15
+ pip install freetype-py shapely svgutils
16
+ pip install opencv-python scikit-image matplotlib visdom wandb BeautifulSoup4
17
+ pip install triton numba
18
+ pip install numpy scipy scikit-fmm einops timm fairscale==0.4.13
19
+ pip install accelerate transformers safetensors datasets
20
+ ```
21
+
22
+ Install LaMa:
23
+
24
+ ```shell
25
+ pip install easydict scikit-learn pytorch_lightning webdataset
26
+ pip install albumentations==0.5.2
27
+ pip install kornia==0.5.0
28
+ pip install wldhx.yadisk-direct
29
+
30
+ cd lama
31
+ # download LaMa model weights
32
+ # raw link (deprecated): curl -L $(yadisk-direct https://disk.yandex.ru/d/kHJkc7bs7mKIVA) -o big-lama.zip
33
+ curl -O -L https://huggingface.co/xingxm/PyTorch-SVGRender-models/resolve/main/big-lama.zip
34
+ unzip big-lama.zip
35
+ ```
36
+
37
+ Install CLIP:
38
+
39
+ ```shell
40
+ pip install ftfy regex tqdm
41
+ pip install git+https://github.com/openai/CLIP.git
42
+ ```
43
+
44
+ Install diffusers:
45
+
46
+ ```shell
47
+ pip install diffusers==0.20.2
48
+ ```
49
+
50
+ Install xformers (requires `python=3.10`):
51
+
52
+ ```shell
53
+ conda install xformers -c xformers
54
+ ```
55
+
56
+ Install diffvg:
57
+
58
+ ```shell
59
+ git clone https://github.com/BachiLi/diffvg.git
60
+ cd diffvg
61
+ git submodule update --init --recursive
62
+ conda install -y -c anaconda cmake
63
+ conda install -y -c conda-forge ffmpeg
64
+ pip install svgwrite svgpathtools cssutils torch-tools
65
+ python setup.py install
66
+ ```
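A quick post-install sanity check, as a sketch (it assumes the `pydiffvg` module built above exposes `set_use_gpu`/`get_device`, as the upstream diffvg package does):

```python
import torch
import pydiffvg

# diffvg falls back to CPU when CUDA is unavailable
pydiffvg.set_use_gpu(torch.cuda.is_available())
print("diffvg device:", pydiffvg.get_device())
```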
LICENSE ADDED
@@ -0,0 +1,373 @@
1
+ Mozilla Public License Version 2.0
2
+ ==================================
3
+
4
+ 1. Definitions
5
+ --------------
6
+
7
+ 1.1. "Contributor"
8
+ means each individual or legal entity that creates, contributes to
9
+ the creation of, or owns Covered Software.
10
+
11
+ 1.2. "Contributor Version"
12
+ means the combination of the Contributions of others (if any) used
13
+ by a Contributor and that particular Contributor's Contribution.
14
+
15
+ 1.3. "Contribution"
16
+ means Covered Software of a particular Contributor.
17
+
18
+ 1.4. "Covered Software"
19
+ means Source Code Form to which the initial Contributor has attached
20
+ the notice in Exhibit A, the Executable Form of such Source Code
21
+ Form, and Modifications of such Source Code Form, in each case
22
+ including portions thereof.
23
+
24
+ 1.5. "Incompatible With Secondary Licenses"
25
+ means
26
+
27
+ (a) that the initial Contributor has attached the notice described
28
+ in Exhibit B to the Covered Software; or
29
+
30
+ (b) that the Covered Software was made available under the terms of
31
+ version 1.1 or earlier of the License, but not also under the
32
+ terms of a Secondary License.
33
+
34
+ 1.6. "Executable Form"
35
+ means any form of the work other than Source Code Form.
36
+
37
+ 1.7. "Larger Work"
38
+ means a work that combines Covered Software with other material, in
39
+ a separate file or files, that is not Covered Software.
40
+
41
+ 1.8. "License"
42
+ means this document.
43
+
44
+ 1.9. "Licensable"
45
+ means having the right to grant, to the maximum extent possible,
46
+ whether at the time of the initial grant or subsequently, any and
47
+ all of the rights conveyed by this License.
48
+
49
+ 1.10. "Modifications"
50
+ means any of the following:
51
+
52
+ (a) any file in Source Code Form that results from an addition to,
53
+ deletion from, or modification of the contents of Covered
54
+ Software; or
55
+
56
+ (b) any new file in Source Code Form that contains any Covered
57
+ Software.
58
+
59
+ 1.11. "Patent Claims" of a Contributor
60
+ means any patent claim(s), including without limitation, method,
61
+ process, and apparatus claims, in any patent Licensable by such
62
+ Contributor that would be infringed, but for the grant of the
63
+ License, by the making, using, selling, offering for sale, having
64
+ made, import, or transfer of either its Contributions or its
65
+ Contributor Version.
66
+
67
+ 1.12. "Secondary License"
68
+ means either the GNU General Public License, Version 2.0, the GNU
69
+ Lesser General Public License, Version 2.1, the GNU Affero General
70
+ Public License, Version 3.0, or any later versions of those
71
+ licenses.
72
+
73
+ 1.13. "Source Code Form"
74
+ means the form of the work preferred for making modifications.
75
+
76
+ 1.14. "You" (or "Your")
77
+ means an individual or a legal entity exercising rights under this
78
+ License. For legal entities, "You" includes any entity that
79
+ controls, is controlled by, or is under common control with You. For
80
+ purposes of this definition, "control" means (a) the power, direct
81
+ or indirect, to cause the direction or management of such entity,
82
+ whether by contract or otherwise, or (b) ownership of more than
83
+ fifty percent (50%) of the outstanding shares or beneficial
84
+ ownership of such entity.
85
+
86
+ 2. License Grants and Conditions
87
+ --------------------------------
88
+
89
+ 2.1. Grants
90
+
91
+ Each Contributor hereby grants You a world-wide, royalty-free,
92
+ non-exclusive license:
93
+
94
+ (a) under intellectual property rights (other than patent or trademark)
95
+ Licensable by such Contributor to use, reproduce, make available,
96
+ modify, display, perform, distribute, and otherwise exploit its
97
+ Contributions, either on an unmodified basis, with Modifications, or
98
+ as part of a Larger Work; and
99
+
100
+ (b) under Patent Claims of such Contributor to make, use, sell, offer
101
+ for sale, have made, import, and otherwise transfer either its
102
+ Contributions or its Contributor Version.
103
+
104
+ 2.2. Effective Date
105
+
106
+ The licenses granted in Section 2.1 with respect to any Contribution
107
+ become effective for each Contribution on the date the Contributor first
108
+ distributes such Contribution.
109
+
110
+ 2.3. Limitations on Grant Scope
111
+
112
+ The licenses granted in this Section 2 are the only rights granted under
113
+ this License. No additional rights or licenses will be implied from the
114
+ distribution or licensing of Covered Software under this License.
115
+ Notwithstanding Section 2.1(b) above, no patent license is granted by a
116
+ Contributor:
117
+
118
+ (a) for any code that a Contributor has removed from Covered Software;
119
+ or
120
+
121
+ (b) for infringements caused by: (i) Your and any other third party's
122
+ modifications of Covered Software, or (ii) the combination of its
123
+ Contributions with other software (except as part of its Contributor
124
+ Version); or
125
+
126
+ (c) under Patent Claims infringed by Covered Software in the absence of
127
+ its Contributions.
128
+
129
+ This License does not grant any rights in the trademarks, service marks,
130
+ or logos of any Contributor (except as may be necessary to comply with
131
+ the notice requirements in Section 3.4).
132
+
133
+ 2.4. Subsequent Licenses
134
+
135
+ No Contributor makes additional grants as a result of Your choice to
136
+ distribute the Covered Software under a subsequent version of this
137
+ License (see Section 10.2) or under the terms of a Secondary License (if
138
+ permitted under the terms of Section 3.3).
139
+
140
+ 2.5. Representation
141
+
142
+ Each Contributor represents that the Contributor believes its
143
+ Contributions are its original creation(s) or it has sufficient rights
144
+ to grant the rights to its Contributions conveyed by this License.
145
+
146
+ 2.6. Fair Use
147
+
148
+ This License is not intended to limit any rights You have under
149
+ applicable copyright doctrines of fair use, fair dealing, or other
150
+ equivalents.
151
+
152
+ 2.7. Conditions
153
+
154
+ Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
155
+ in Section 2.1.
156
+
157
+ 3. Responsibilities
158
+ -------------------
159
+
160
+ 3.1. Distribution of Source Form
161
+
162
+ All distribution of Covered Software in Source Code Form, including any
163
+ Modifications that You create or to which You contribute, must be under
164
+ the terms of this License. You must inform recipients that the Source
165
+ Code Form of the Covered Software is governed by the terms of this
166
+ License, and how they can obtain a copy of this License. You may not
167
+ attempt to alter or restrict the recipients' rights in the Source Code
168
+ Form.
169
+
170
+ 3.2. Distribution of Executable Form
171
+
172
+ If You distribute Covered Software in Executable Form then:
173
+
174
+ (a) such Covered Software must also be made available in Source Code
175
+ Form, as described in Section 3.1, and You must inform recipients of
176
+ the Executable Form how they can obtain a copy of such Source Code
177
+ Form by reasonable means in a timely manner, at a charge no more
178
+ than the cost of distribution to the recipient; and
179
+
180
+ (b) You may distribute such Executable Form under the terms of this
181
+ License, or sublicense it under different terms, provided that the
182
+ license for the Executable Form does not attempt to limit or alter
183
+ the recipients' rights in the Source Code Form under this License.
184
+
185
+ 3.3. Distribution of a Larger Work
186
+
187
+ You may create and distribute a Larger Work under terms of Your choice,
188
+ provided that You also comply with the requirements of this License for
189
+ the Covered Software. If the Larger Work is a combination of Covered
190
+ Software with a work governed by one or more Secondary Licenses, and the
191
+ Covered Software is not Incompatible With Secondary Licenses, this
192
+ License permits You to additionally distribute such Covered Software
193
+ under the terms of such Secondary License(s), so that the recipient of
194
+ the Larger Work may, at their option, further distribute the Covered
195
+ Software under the terms of either this License or such Secondary
196
+ License(s).
197
+
198
+ 3.4. Notices
199
+
200
+ You may not remove or alter the substance of any license notices
201
+ (including copyright notices, patent notices, disclaimers of warranty,
202
+ or limitations of liability) contained within the Source Code Form of
203
+ the Covered Software, except that You may alter any license notices to
204
+ the extent required to remedy known factual inaccuracies.
205
+
206
+ 3.5. Application of Additional Terms
207
+
208
+ You may choose to offer, and to charge a fee for, warranty, support,
209
+ indemnity or liability obligations to one or more recipients of Covered
210
+ Software. However, You may do so only on Your own behalf, and not on
211
+ behalf of any Contributor. You must make it absolutely clear that any
212
+ such warranty, support, indemnity, or liability obligation is offered by
213
+ You alone, and You hereby agree to indemnify every Contributor for any
214
+ liability incurred by such Contributor as a result of warranty, support,
215
+ indemnity or liability terms You offer. You may include additional
216
+ disclaimers of warranty and limitations of liability specific to any
217
+ jurisdiction.
218
+
219
+ 4. Inability to Comply Due to Statute or Regulation
220
+ ---------------------------------------------------
221
+
222
+ If it is impossible for You to comply with any of the terms of this
223
+ License with respect to some or all of the Covered Software due to
224
+ statute, judicial order, or regulation then You must: (a) comply with
225
+ the terms of this License to the maximum extent possible; and (b)
226
+ describe the limitations and the code they affect. Such description must
227
+ be placed in a text file included with all distributions of the Covered
228
+ Software under this License. Except to the extent prohibited by statute
229
+ or regulation, such description must be sufficiently detailed for a
230
+ recipient of ordinary skill to be able to understand it.
231
+
232
+ 5. Termination
233
+ --------------
234
+
235
+ 5.1. The rights granted under this License will terminate automatically
236
+ if You fail to comply with any of its terms. However, if You become
237
+ compliant, then the rights granted under this License from a particular
238
+ Contributor are reinstated (a) provisionally, unless and until such
239
+ Contributor explicitly and finally terminates Your grants, and (b) on an
240
+ ongoing basis, if such Contributor fails to notify You of the
241
+ non-compliance by some reasonable means prior to 60 days after You have
242
+ come back into compliance. Moreover, Your grants from a particular
243
+ Contributor are reinstated on an ongoing basis if such Contributor
244
+ notifies You of the non-compliance by some reasonable means, this is the
245
+ first time You have received notice of non-compliance with this License
246
+ from such Contributor, and You become compliant prior to 30 days after
247
+ Your receipt of the notice.
248
+
249
+ 5.2. If You initiate litigation against any entity by asserting a patent
250
+ infringement claim (excluding declaratory judgment actions,
251
+ counter-claims, and cross-claims) alleging that a Contributor Version
252
+ directly or indirectly infringes any patent, then the rights granted to
253
+ You by any and all Contributors for the Covered Software under Section
254
+ 2.1 of this License shall terminate.
255
+
256
+ 5.3. In the event of termination under Sections 5.1 or 5.2 above, all
257
+ end user license agreements (excluding distributors and resellers) which
258
+ have been validly granted by You or Your distributors under this License
259
+ prior to termination shall survive termination.
260
+
261
+ ************************************************************************
262
+ * *
263
+ * 6. Disclaimer of Warranty *
264
+ * ------------------------- *
265
+ * *
266
+ * Covered Software is provided under this License on an "as is" *
267
+ * basis, without warranty of any kind, either expressed, implied, or *
268
+ * statutory, including, without limitation, warranties that the *
269
+ * Covered Software is free of defects, merchantable, fit for a *
270
+ * particular purpose or non-infringing. The entire risk as to the *
271
+ * quality and performance of the Covered Software is with You. *
272
+ * Should any Covered Software prove defective in any respect, You *
273
+ * (not any Contributor) assume the cost of any necessary servicing, *
274
+ * repair, or correction. This disclaimer of warranty constitutes an *
275
+ * essential part of this License. No use of any Covered Software is *
276
+ * authorized under this License except under this disclaimer. *
277
+ * *
278
+ ************************************************************************
279
+
280
+ ************************************************************************
281
+ * *
282
+ * 7. Limitation of Liability *
283
+ * -------------------------- *
284
+ * *
285
+ * Under no circumstances and under no legal theory, whether tort *
286
+ * (including negligence), contract, or otherwise, shall any *
287
+ * Contributor, or anyone who distributes Covered Software as *
288
+ * permitted above, be liable to You for any direct, indirect, *
289
+ * special, incidental, or consequential damages of any character *
290
+ * including, without limitation, damages for lost profits, loss of *
291
+ * goodwill, work stoppage, computer failure or malfunction, or any *
292
+ * and all other commercial damages or losses, even if such party *
293
+ * shall have been informed of the possibility of such damages. This *
294
+ * limitation of liability shall not apply to liability for death or *
295
+ * personal injury resulting from such party's negligence to the *
296
+ * extent applicable law prohibits such limitation. Some *
297
+ * jurisdictions do not allow the exclusion or limitation of *
298
+ * incidental or consequential damages, so this exclusion and *
299
+ * limitation may not apply to You. *
300
+ * *
301
+ ************************************************************************
302
+
303
+ 8. Litigation
304
+ -------------
305
+
306
+ Any litigation relating to this License may be brought only in the
307
+ courts of a jurisdiction where the defendant maintains its principal
308
+ place of business and such litigation shall be governed by laws of that
309
+ jurisdiction, without reference to its conflict-of-law provisions.
310
+ Nothing in this Section shall prevent a party's ability to bring
311
+ cross-claims or counter-claims.
312
+
313
+ 9. Miscellaneous
314
+ ----------------
315
+
316
+ This License represents the complete agreement concerning the subject
317
+ matter hereof. If any provision of this License is held to be
318
+ unenforceable, such provision shall be reformed only to the extent
319
+ necessary to make it enforceable. Any law or regulation which provides
320
+ that the language of a contract shall be construed against the drafter
321
+ shall not be used to construe this License against a Contributor.
322
+
323
+ 10. Versions of the License
324
+ ---------------------------
325
+
326
+ 10.1. New Versions
327
+
328
+ Mozilla Foundation is the license steward. Except as provided in Section
329
+ 10.3, no one other than the license steward has the right to modify or
330
+ publish new versions of this License. Each version will be given a
331
+ distinguishing version number.
332
+
333
+ 10.2. Effect of New Versions
334
+
335
+ You may distribute the Covered Software under the terms of the version
336
+ of the License under which You originally received the Covered Software,
337
+ or under the terms of any subsequent version published by the license
338
+ steward.
339
+
340
+ 10.3. Modified Versions
341
+
342
+ If you create software not governed by this License, and you want to
343
+ create a new license for such software, you may create and use a
344
+ modified version of this License if you rename the license and remove
345
+ any references to the name of the license steward (except to note that
346
+ such modified license differs from this License).
347
+
348
+ 10.4. Distributing Source Code Form that is Incompatible With Secondary
349
+ Licenses
350
+
351
+ If You choose to distribute Source Code Form that is Incompatible With
352
+ Secondary Licenses under the terms of this version of the License, the
353
+ notice described in Exhibit B of this License must be attached.
354
+
355
+ Exhibit A - Source Code Form License Notice
356
+ -------------------------------------------
357
+
358
+ This Source Code Form is subject to the terms of the Mozilla Public
359
+ License, v. 2.0. If a copy of the MPL was not distributed with this
360
+ file, You can obtain one at http://mozilla.org/MPL/2.0/.
361
+
362
+ If it is not possible or desirable to put the notice in a particular
363
+ file, then You may include the notice in a location (such as a LICENSE
364
+ file in a relevant directory) where a recipient would be likely to look
365
+ for such a notice.
366
+
367
+ You may add additional accurate notices of copyright ownership.
368
+
369
+ Exhibit B - "Incompatible With Secondary Licenses" Notice
370
+ ---------------------------------------------------------
371
+
372
+ This Source Code Form is "Incompatible With Secondary Licenses", as
373
+ defined by the Mozilla Public License, v. 2.0.
README copy.md ADDED
@@ -0,0 +1,304 @@
1
+ ---
2
+ title: SVGRender
3
+ emoji: 💻
4
+ colorFrom: gray
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 4.20.1
8
+ python_version: 3.10.12
9
+ app_file: app.py
10
+ pinned: false
11
+ license: apache-2.0
12
+ ---
13
+
14
+ <h1 id="ptsvg" align="center">Pytorch-SVGRender</h1>
15
+
16
+ <p align="center">
17
+ <a href="https://www.python.org/"><img src="https://img.shields.io/badge/python-3.10-or?logo=python" alt="pyhton"></a>
18
+ <a href="http://mozilla.org/MPL/2.0/"><img src="https://img.shields.io/badge/license-MPL2.0-orange" alt="license"></a>
19
+ <a href="https://ximinng.github.io/PyTorch-SVGRender-project/"><img src="https://img.shields.io/badge/website-Gitpage-yellow" alt="website"></a>
20
+ <a href="https://pytorch-svgrender.readthedocs.io/en/latest/index.html"><img src="https://img.shields.io/badge/docs-readthedocs-purple" alt="docs"></a>
21
+ </p>
22
+
23
+ <div align="center">
24
+ <img src="./assets/logo.png" style="width: 350px; height: 300px;" alt="Pytorch-SVGRender">
25
+ <p><strong>Pytorch-SVGRender: </strong>The go-to library for differentiable rendering methods for SVG generation.</p>
26
+ </div>
27
+ <p align="center">
28
+ <a href="#recent-updates">Updates</a> •
29
+ <a href="#table-of-contents">Table of Contents</a> •
30
+ <a href="#installation">Installation</a> •
31
+ <a href="#quickstart">Quickstart</a> •
32
+ <a href="#faq">FAQ</a> •
33
+ <a href="#todo">TODO</a> •
34
+ <a href="#acknowledgement">Acknowledgment</a> •
35
+ <a href="#citation">Citation</a> •
36
+ <a href="#licence">Licence</a>
37
+ </p>
38
+
39
+ <h2 align="center">Recent Updates</h2>
40
+
41
+ - [12/2023] 🔥 We open-sourced Pytorch-SVGRender V1.0.
42
+
43
+ <h2 align="center">Table of Contents</h2>
44
+ <p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
45
+
46
+ ### 1. Image Vectorization
47
+
48
+ - DiffVG: Differentiable Vector Graphics Rasterization for Editing and Learning (`SIGGRAPH 2020`)
49
+
50
+ [[Project]](https://people.csail.mit.edu/tzumao/diffvg/) [[Paper]](https://cseweb.ucsd.edu/~tzli/diffvg/diffvg.pdf) [[Code]](https://github.com/BachiLi/diffvg)
51
+
52
+ DiffVG is a differentiable rasterizer for 2D vector graphics. **This repository is heavily based on DiffVG.**
53
+
54
+ - LIVE: Towards Layer-wise Image Vectorization (`CVPR 2022`)
55
+
56
+ [[Project]](https://ma-xu.github.io/LIVE/) [[Paper]](https://ma-xu.github.io/LIVE/index_files/CVPR22_LIVE_main.pdf) [[Code]](https://github.com/Picsart-AI-Research/LIVE-Layerwise-Image-Vectorization)
57
+
58
+ - CLIPasso: Semantically-Aware Object Sketching (`SIGGRAPH 2022`)
59
+
60
+ [[Project]](https://clipasso.github.io/clipasso/) [[Paper]](https://arxiv.org/abs/2202.05822) [[Code]](https://github.com/yael-vinker/CLIPasso)
61
+
62
+ - CLIPascene: Scene Sketching with Different Types and Levels of Abstraction (`ICCV 2023`)
63
+
64
+ [[Project]](https://clipascene.github.io/CLIPascene/) [[Paper]](https://arxiv.org/abs/2211.17256) [[Code]](https://github.com/yael-vinker/SceneSketch)
65
+
66
+ ### 2. Text-to-SVG Synthesis
67
+
68
+ - CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders (`NIPS 2022`)
69
+
70
+ [[Paper]](https://arxiv.org/abs/2106.14843) [[Code]](https://github.com/kvfrans/clipdraw)
71
+
72
+ - StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Synthesis
73
+
74
+ [[Live]](https://slideslive.com/38970834/styleclipdraw-coupling-content-and-style-in-texttodrawing-synthesis?ref=account-folder-92044-folders) [[Paper]](https://arxiv.org/abs/2202.12362) [[Code]](https://github.com/pschaldenbrand/StyleCLIPDraw)
75
+
76
+ - CLIPFont: Texture Guided Vector WordArt Generation (`BMVC 2022`)
77
+
78
+ [[Paper]](https://bmvc2022.mpi-inf.mpg.de/0543.pdf) [[Code]](https://github.com/songyiren98/CLIPFont)
79
+
80
+ - VectorFusion: Text-to-SVG by Abstracting Pixel-Based Diffusion Models (`CVPR 2023`)
81
+
82
+ [[Project]](https://vectorfusion.github.io/) [[Paper]](https://openaccess.thecvf.com/content/CVPR2023/papers/Jain_VectorFusion_Text-to-SVG_by_Abstracting_Pixel-Based_Diffusion_Models_CVPR_2023_paper.pdf)
83
+
84
+ - DiffSketcher: Text Guided Vector Sketch Synthesis through Latent Diffusion Models (`NIPS 2023`)
85
+
86
+ [[Project]](https://ximinng.github.io/DiffSketcher-project/) [[Live]](https://neurips.cc/virtual/2023/poster/72425) [[Paper]](https://arxiv.org/abs/2306.14685) [[Code]](https://github.com/ximinng/DiffSketcher)
87
+
88
+ - Word-As-Image for Semantic Typography (`SIGGRAPH 2023`)
89
+
90
+ [[Project]](https://wordasimage.github.io/Word-As-Image-Page/) [[Paper]](https://arxiv.org/abs/2303.01818) [[Code]](https://github.com/Shiriluz/Word-As-Image)
91
+
92
+ - SVGDreamer: Text Guided SVG Generation with Diffusion Model (`CVPR 2024`)
93
+
94
+ [[Project]](https://ximinng.github.io/SVGDreamer-project/) [[Paper]](https://arxiv.org/abs/2312.16476) [[code]](https://github.com/ximinng/SVGDreamer)
95
+
96
+ <h2 align="center">Installation</h2>
97
+
98
+ You can follow the steps below to quickly get up and running with PyTorch-SVGRender.
99
+ These steps will let you run quick inference locally.
100
+
101
+ In the top level directory run,
102
+
103
+ ```bash
104
+ sh script/install.sh
105
+ ```
106
+
107
+ Note: Make sure that the script file has execution **permissions** (you can grant them with `chmod +x script/install.sh`), and
108
+ then run the script.
109
+
110
+ For more information, please refer to
111
+ the [Install.md](https://github.com/ximinng/PyTorch-SVGRender/blob/main/Install.md).
112
+
113
+ <h2 align="center">Quickstart</h2>
114
+ <p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
115
+
116
+ **For more information, [read the docs](https://pytorch-svgrender.readthedocs.io/en/latest/index.html).**
117
+
118
+ ### 1. Basic Usage
119
+
120
+ **DiffVG** vectorizes any raster images:
121
+
122
+ ```shell
123
+ python svg_render.py x=diffvg target='./data/fallingwater.png'
124
+ # change 'num_paths' and 'num_iter' for better results
125
+ python svg_render.py x=diffvg target='./data/fallingwater.png' x.num_paths=512 x.num_iter=2000
126
+ ```
127
+
128
+ **LIVE** vectorizes the raster emojis images (in original PNG format):
129
+
130
+ ```shell
131
+ python svg_render.py x=live target='./data/simile.png'
132
+ # change 'num_paths' and 'schedule_each' for better results
133
+ python svg_render.py x=live target='./data/simile.png' x.num_paths=5 x.schedule_each=1
134
+ ```
135
+
136
+ **CLIPasso** synthesizes vectorized sketches from images:
137
+
138
+ **note:** first download the U2Net model `sh script/download_u2net.sh`.
139
+
140
+ ```shell
141
+ python svg_render.py x=clipasso target='./data/horse.png'
142
+ ```
143
+
144
+ **CLIPascene** synthesizes vectorized sketches from images:
145
+
146
+ **note:** first download the U2Net model with `sh script/download_u2net.sh`, and make sure the `./data/background` and
+ `./data/scene` folders contain the target images.
148
+
149
+ ```shell
150
+ python svg_render.py x=clipascene target='ballerina.png'
151
+ ```
152
+
153
+ **CLIPDraw** synthesizes SVGs based on text prompts:
154
+
155
+ ```shell
156
+ python svg_render.py x=clipdraw "prompt='a photo of a cat'"
157
+ ```
158
+
159
+ **StyleCLIPDraw** synthesizes SVG based on a text prompt and a reference image:
160
+
161
+ ```shell
162
+ python svg_render.py x=styleclipdraw "prompt='a photo of a cat'" target='./data/starry.png'
163
+ ```
164
+
165
+ **CLIPFont** styles vector fonts according to text prompts:
166
+
167
+ ```shell
168
+ python svg_render.py x=clipfont "prompt='Starry Night by Vincent van gogh'" target='./data/alphabet1.svg'
169
+ ```
170
+
171
+ ---
172
+
173
+ > Because the following methods rely on stable diffusion, add `diffuser.download=True` to the command the **first time** you
174
+ run the script.
175
+
176
+ **SVGDreamer** generates various styles of SVG based on text prompts. It supports the use of six vector primitives,
177
+ including Iconography, Sketch, Pixel Art, Low-Poly, Painting, and Ink and Wash.
178
+
179
+ ```shell
180
+ # primitive: iconography
181
+ ## 1. German shepherd
182
+ python svg_render.py x=svgdreamer "prompt='A colorful German shepherd in vector art. tending on artstation.'" save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 result_path='./svgdreamer/GermanShepherd'
183
+ ## 2. sydney opera house
184
+ python svg_render.py x=svgdreamer "prompt='Sydney opera house. oil painting. by Van Gogh'" save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.num_paths=512 result_path='./svgdreamer/Sydney'
185
+ # primitive: low-poly
186
+ python svg_render.py x=svgdreamer "prompt='A picture of a bald eagle. low-ploy. polygon'" x.style='low-poly' save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.guidance.num_iter=1000 result_path='./svgdreamer/eagle'
187
+ # primitive: pixel-art
188
+ python svg_render.py x=svgdreamer "prompt='Darth vader with lightsaber. ultrarealistic. pixelart. trending on artstation.'" x.style='pixelart' save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.guidance.num_iter=1000 result_path='./svgdreamer/DarthVader'
189
+ # primitive: painting
190
+ python svg_render.py x=svgdreamer "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" x.style='painting' save_step=50 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.guidance.t_schedule='randint' x.num_paths=1500 result_path='./svgdreamer/VanGogh_portrait'
191
+ # primitive: sketch
192
+ python svg_render.py x=svgdreamer "prompt='A free-hand drawing of A speeding Lamborghini. black and white drawing.'" x.style='sketch' save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.guidance.t_schedule='randint' x.num_paths=128 result_path='./svgdreamer/Lamborghini'
193
+ # primitive: ink and wash
194
+ python svg_render.py x=svgdreamer "prompt='Big Wild Goose Pagoda. ink style. Minimalist abstract art grayscale watercolor.'" x.style='ink' save_step=30 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2 x.guidance.t_schedule='randint' x.num_paths=128 x.width=6 result_path='./svgdreamer/BigWildGoosePagoda'
195
+ ```
196
+
197
+ **VectorFusion** synthesizes SVGs in various styles based on text prompts:
198
+
199
+ ```shell
200
+ # Iconography style
201
+ python svg_render.py x=vectorfusion x.style='iconography' "prompt='a panda rowing a boat in a pond. minimal flat 2d vector icon. lineal color. trending on artstation.'"
202
+ # PixelArt style
203
+ python svg_render.py x=vectorfusion x.style='pixelart' "prompt='a panda rowing a boat in a pond. pixel art. trending on artstation.'"
204
+ # Sketch style
205
+ python svg_render.py x=vectorfusion x.style='sketch' "prompt='a panda rowing a boat in a pond. minimal 2d line drawing. trending on artstation.'"
206
+ ```
207
+
208
+ Following SVGDreamer, we've added three additional styles (`Painting`, `Ink and Wash`, and `Low-Poly`) to VectorFusion.
209
+
210
+ **DiffSketcher** synthesizes vector sketches based on text prompts:
211
+
212
+ ```shell
213
+ # DiffSketcher
214
+ python svg_render.py x=diffsketcher "prompt='a photo of Sydney opera house'" x.token_ind=5 seed=8019
215
+ # DiffSketcher, variable stroke width
216
+ python svg_render.py x=diffsketcher "prompt='a photo of Sydney opera house'" x.token_ind=5 x.optim_width=True seed=8019
217
+ # DiffSketcher RGBA version
218
+ python svg_render.py x=diffsketcher "prompt='a photo of Sydney opera house'" x.token_ind=5 x.optim_width=True x.optim_rgba=True x.optim_opacity=False seed=8019
219
+ # DiffSketcher + style transfer
220
+ python svg_render.py x=stylediffsketcher "prompt='The French Revolution. highly detailed. 8k. ornate. intricate. cinematic. dehazed. atmospheric. oil painting. by Van Gogh'" x.token_ind=4 x.num_paths=2000 target='./data/starry.png' seed=876809
221
+ ```
222
+
223
+ **Word-As-Image** follows a text prompt to style a letter in a word:
224
+
225
+ ```shell
226
+ # Inject the meaning of the word bunny into the 'Y' in the word 'BUNNY'
227
+ python svg_render.py x=wordasimage x.word='BUNNY' prompt='BUNNY' x.optim_letter='Y'
228
+ ```
229
+
230
+ ### 2. SDS Loss based Approach
231
+
232
+ These methods use a pretrained text-to-image diffusion model as a strong image prior to supervise the PyDiffVG
+ rendering, so the optimized SVG stays aligned with the text prompt. The core mechanism is Score Distillation
+ Sampling (SDS): it distills raster-level guidance from the diffusion model into the SVG domain, allowing the SVG
+ parameters to be trained without any reference images. VectorFusion, DiffSketcher and SVGDreamer all build on this
+ idea; a rough form of the SDS gradient is sketched right after this paragraph.
237
+
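For reference, the SDS gradient that drives this training, in the form introduced by DreamFusion (notation adapted here), is roughly:

```latex
% x = R(\theta): differentiable rasterization of the SVG parameters \theta
% \hat{\epsilon}_\phi: noise predicted by the frozen text-to-image diffusion model
\nabla_\theta \mathcal{L}_{\mathrm{SDS}}
  \approx \mathbb{E}_{t,\epsilon}\!\left[
    w(t)\,\big(\hat{\epsilon}_\phi(x_t;\, y,\, t) - \epsilon\big)\,
    \frac{\partial x}{\partial \theta}
  \right]
```

where $x_t$ is the noised render at timestep $t$, $y$ is the text prompt, and $w(t)$ is a timestep weighting; the diffusion model stays frozen and only the SVG parameters $\theta$ are updated.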
238
+ We only compare the performance of SDS, which means that no other loss is used:
239
+
240
+ ```shell
241
+ # SDS loss
242
+ python svg_render.py x=vectorfusion "prompt='a panda rowing a boat in a pond. minimal flat 2d vector icon. lineal color. trending on artstation.'"
243
+ # Input Augmentation SDS loss (LSDS loss)
244
+ python svg_render.py x=vectorfusion x.style='sketch' "prompt='an elephant. minimal 2d line drawing. trending on artstation.'"
245
+ # Input Augmentation SDS loss (ASDS loss)
246
+ python svg_render.py x=diffsketcher "prompt='an elephant. minimal 2d line drawing. trending on artstation.'" x.token_ind=2 x.sds.grad_scale=1 x.sds.num_aug=4 x.clip.vis_loss=0 x.perceptual.coeff=0 x.opacity_delta=0.3
247
+ # Vectorized Particle-based Score Distillation (VPSD loss)
248
+ python svg_render.py x=svgdreamer "prompt='a panda rowing a boat in a pond. minimal flat 2d vector icon. lineal color. trending on artstation.'" save_step=60 x.guidance.n_particle=6 x.guidance.vsd_n_particle=4 x.guidance.phi_n_particle=2
249
+ ```
250
+
251
+ <h2 align="center">FAQ</h2>
252
+ <p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
253
+
254
+ - Q: Where can I get more scripts and visualizations?
255
+ - A: check the [pytorch-svgrender.readthedocs.io](https://pytorch-svgrender.readthedocs.io/en/latest/index.html).
256
+
257
+ - Q: An error says HuggingFace cannot find the model in the disk cache.
258
+ - A: Add *`diffuser.download=True`* to the command for downloading model checkpoints the **first time** you run the script.
259
+
260
+ <h2 align="center">TODO</h2>
261
+ <p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
262
+
263
+ - [x] integrated SVGDreamer.
264
+
265
+ <h2 align="center">Acknowledgement</h2>
266
+ <p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
267
+
268
+ This project is built on the following repositories:
269
+
270
+ [BachiLi/diffvg](https://github.com/BachiLi/diffvg),
271
+ [huggingface/diffusers](https://github.com/huggingface/diffusers),
272
+ [threestudio-project/threestudio](https://github.com/threestudio-project/threestudio),
273
+ [yael-vinker/CLIPasso](https://github.com/yael-vinker/CLIPasso),
274
+ [ximinng/DiffSketcher](https://github.com/ximinng/DiffSketcher),
275
+ [THUDM/ImageReward](https://github.com/THUDM/ImageReward),
276
+ [advimman/lama](https://github.com/advimman/lama)
277
+
278
+ We gratefully thank the authors for their wonderful work.
279
+
280
+ <h2 align="center">Citation</h2>
281
+ <p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
282
+
283
+ If you use this code for your research, please cite the following work:
284
+
285
+ ```
286
+ @article{xing2023svgdreamer,
287
+ title={SVGDreamer: Text Guided SVG Generation with Diffusion Model},
288
+ author={Xing, Ximing and Zhou, Haitao and Wang, Chuang and Zhang, Jing and Xu, Dong and Yu, Qian},
289
+ journal={arXiv preprint arXiv:2312.16476},
290
+ year={2023}
291
+ }
292
+ @inproceedings{xing2023diffsketcher,
293
+ title={DiffSketcher: Text Guided Vector Sketch Synthesis through Latent Diffusion Models},
294
+ author={XiMing Xing and Chuang Wang and Haitao Zhou and Jing Zhang and Qian Yu and Dong Xu},
295
+ booktitle={Thirty-seventh Conference on Neural Information Processing Systems (NeurIPS)},
296
+ year={2023},
297
+ url={https://openreview.net/forum?id=CY1xatvEQj}
298
+ }
299
+ ```
300
+
301
+ <h2 align="center">Licence</h2>
302
+ <p align="right"><a href="#ptsvg"><sup>▴ Back to top</sup></a></p>
303
+
304
+ This work is licensed under the **Mozilla Public License Version 2.0**.
app.py ADDED
@@ -0,0 +1,83 @@
1
+ import os
2
+ import subprocess
3
+ import sys
4
+ import tempfile
5
+
6
+ import gradio as gr
7
+ from PIL import Image
8
+
9
+ sys.path.append('/home/user/app/code')
10
+
11
+ # set up diffvg
12
+ # os.system('git clone https://github.com/BachiLi/diffvg.git')
13
+ os.chdir('diffvg')
14
+ os.system('git submodule update --init --recursive')
15
+ os.system('python setup.py install --user')
16
+ sys.path.append("/home/user/.local/lib/python3.10/site-packages/diffvg-0.0.1-py3.10-linux-x86_64.egg")
17
+ print("diffvg installed.")
18
+ os.chdir('/home/user/app')
19
+
20
+
21
+ def process_images(prompt, num_paths, token_index, seed, optimize_width=False, optimize_color=False):
22
+ with tempfile.TemporaryDirectory() as tmpdirname:
23
+ command = [
24
+ "python", "svg_render.py",
25
+ "x=diffsketcher",
26
+ f"prompt={prompt}",
27
+ f"x.num_paths={num_paths}",
28
+ f"x.token_ind={token_index}",
29
+ f"seed={seed}",
30
+ f"x.optim_width={optimize_width}",
31
+ f"x.optim_rgba={optimize_color}",
32
+ "x.optim_opacity=False",
33
+ ]
34
+ result = subprocess.run(command, check=True)
35
+ if result.returncode == 0:
36
+ output_image = Image.open(os.path.join(tmpdirname, "final_render.png"))
37
+ return output_image
38
+
39
+
40
+ with gr.Blocks() as demo:
41
+ gr.Markdown("# DiffSketcher")
42
+ gr.Markdown("DiffSketcher synthesizes **vector sketches** based on **text prompts**.")
43
+ li = [
44
+ "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/cat.svg",
45
+ "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/rose.svg",
46
+ "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/elephant.svg",
47
+ "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/elephant_silhouette.svg",
48
+ "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/horse_width.svg",
49
+ "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/horse_rgba.svg",
50
+ "https://ximinng.github.io/PyTorch-SVGRender-project/assets/diffsketcher/Sydney_opera.svg",
51
+ "https://ximinng.github.io/PyTorch-SVGRender-project/assets/diffsketcher/Sydney_opera_width.svg",
52
+ "https://ximinng.github.io/PyTorch-SVGRender-project/assets/diffsketcher/Sydney_opera_width_color.svg",
53
+ ]
54
+ gr.Gallery(li, columns=6)
55
+ with gr.Row():
56
+ with gr.Column():
57
+ text = gr.Textbox(label="prompt")
58
+ num_paths = gr.Slider(label="path number", value=96, minimum=1, maximum=500, step=1)
59
+ token_index = gr.Textbox(label="token_index", info="CLIP embedding token index. Starting from 1.")
60
+ seed = gr.Slider(0, 10000, label="random seed", value=8019)
61
+ with gr.Accordion("Selectable Inputs"):
62
+ optimize_width = gr.Checkbox(label="optimize stroke width")
63
+ optimize_color = gr.Checkbox(label="optimize stroke color")
64
+ btn = gr.Button("Synthesize")
65
+ with gr.Column():
66
+ output = gr.Image(label="output image", height=512)
67
+ btn.click(process_images,
68
+ inputs=[text, num_paths, token_index, seed, optimize_width, optimize_color],
69
+ outputs=[output])
70
+ gr.Markdown("## Examples")
71
+ gr.Markdown("Here are some config examples. Feel free to try your own prompts!")
72
+ gr.Examples(
73
+ inputs=[text, num_paths, token_index, seed, optimize_width, optimize_color],
74
+ outputs=[output],
75
+ fn=process_images,
76
+ examples=[
77
+ ["A photo of Sydney opera house.", 96, 5, 8019, False, False],
78
+ ["A photo of Sydney opera house.", 96, 5, 8019, True, False],
79
+ ["A photo of Sydney opera house.", 128, 5, 8019, True, True],
80
+ ],
81
+ )
82
+
83
+ demo.launch()
assets/fonts/Bell-MT.ttf ADDED
Binary file (84.8 kB).
assets/fonts/DeliusUnicase-Regular.ttf ADDED
Binary file (31.5 kB).
assets/fonts/HobeauxRococeaux-Sherman.ttf ADDED
Binary file (117 kB).
assets/fonts/IndieFlower-Regular.ttf ADDED
Binary file (55.4 kB).
assets/fonts/JosefinSans-Light.ttf ADDED
Binary file (59.3 kB).
assets/fonts/KaushanScript-Regular.ttf ADDED
Binary file (184 kB).
assets/fonts/LuckiestGuy-Regular.ttf ADDED
Binary file (58.3 kB).
assets/fonts/Noteworthy-Bold.ttf ADDED
Binary file (248 kB).
assets/fonts/Quicksand.ttf ADDED
Binary file (124 kB).
assets/fonts/Saira-Regular.ttf ADDED
Binary file (82.8 kB).
checkpoint/placeholder.md ADDED
@@ -0,0 +1 @@
1
+ **place model here**
conf/config.yaml ADDED
@@ -0,0 +1,56 @@
1
+ #-----------------#
2
+ # Global Config #
3
+ #-----------------#
4
+
5
+ # optional args
6
+ target: ~
7
+ prompt: ~
8
+ neg_prompt: ~ # negative prompt
9
+
10
+ # Accelerate config
11
+ state:
12
+ cpu: False # use cpu
13
+ mprec: no # mixed precision, choices: 'no', 'fp16', 'bf16'
14
+ # wandb: False
15
+ # tensorboard: False
16
+
17
+ # Diffusers config
18
+ diffuser:
19
+ download: True # Set this variable to True the first time it runs
20
+ force_download: False
21
+ resume_download: False
22
+
23
+ # PyDiffVG config
24
+ diffvg:
25
+ print_timing: False
26
+
27
+ # reproduction
28
+ seed: 951222
29
+ # multi-run
30
+ multirun: False
31
+ srange: ~ # seed range, example: [100, 100]
32
+
33
+ # log
34
+ result_path: './workspace'
35
+ save_step: 10
36
+ eval_step: 10
37
+
38
+ # visual rendering process
39
+ mv: False # make video
40
+ framefreq: 5 # save the image interval
41
+ framerate: 24 # by adjusting the frame rate, you can control the playback speed of the output video
42
+
43
+ # hydra setting
44
+ hydra:
45
+ help:
46
+ # app name, override to match the name your app is known by
47
+ app_name: 'SVGRender'
48
+ run:
49
+ # output directory for normal runs
50
+ # warning: make sure that the L56-58 of '/libs/engine/model_state.py' and 'dir' are modified together
51
+ dir: ./${result_path}/${x.method}-${now:%Y-%m-%d-%H-%M}
52
+
53
+ # default settings
54
+ defaults:
55
+ - _self_
56
+ - x: ~
conf/x/clipascene.yaml ADDED
@@ -0,0 +1,87 @@
1
+ method: 'clipascene'
2
+
3
+ im_name: ""
4
+ image_size: 224
5
+ u2net_path: "./checkpoint/u2net/u2net.pth"
6
+
7
+ background_layer: 2 # 2, 8, 11
8
+ background_div: 0.35 # 0.35, 0.5, 0.85
9
+ background_num_iter: 1501
10
+
11
+ foreground_layer: 2 # 2, 8, 11
12
+ foreground_div: 0.4 # 0.4, 0.5, 0.9
13
+ foreground_num_iter: 600 # 1000 if foreground_layer >= 8 else 600
14
+
15
+ # general
16
+ target: null
17
+ output_dir: null
18
+ path_svg: "none"
19
+ mask_object: 0
20
+ resize_obj: 0
21
+ fix_scale: 0
22
+ display_logs: 0
23
+ display: 0
24
+ test_name: "test"
25
+
26
+ # training
27
+ num_iter: 2001
28
+ num_stages: 1
29
+ lr_scheduler: 0
30
+ lr: 0.0001
31
+ color_lr: 0.01
32
+ width_lr: 0.0001
33
+ color_vars_threshold: 0.0
34
+ batch_size: 1
35
+ save_step: 100
36
+ eval_step: 20
37
+ loss_mask: "none"
38
+ dilated_mask: 0
39
+ mask_cls: None
40
+ mask_attention: 0
41
+
42
+ # strokes params
43
+ num_paths: 64
44
+ width: 1.5
45
+ control_points_per_seg: 4
46
+ num_segments: 1
47
+ attention_init: 1
48
+ saliency_model: "clip"
49
+ saliency_clip_model: "ViT-B/32"
50
+ xdog_intersec: 1
51
+ mask_object_attention: 0
52
+ softmax_temp: 0.3
53
+ mlp_train: 1
54
+ width_optim: 0
55
+ mlp_width_weights_path: "none"
56
+ mlp_points_weights_path: "none"
57
+ switch_loss: 0
58
+ gumbel_temp: 0.2
59
+ width_loss_weight: 0
60
+ width_loss_type: "L1"
61
+ optimize_points: 1
62
+ load_points_opt_weights: 0
63
+ gradnorm: 0
64
+ width_weights_lst: ""
65
+ ratio_loss: 0
66
+
67
+ # loss
68
+ percep_loss: "none"
69
+ perceptual_weight: 0
70
+ train_with_clip: 0
71
+ clip_weight: 0
72
+ start_clip: 0
73
+ num_aug_clip: 4
74
+ include_target_in_aug: 0
75
+ augment_both: 1
76
+ augemntations: "affine"
77
+ noise_thresh: 0.5
78
+ aug_scale_min: 0.7
79
+ force_sparse: 0
80
+ clip_conv_loss: 1
81
+ clip_mask_loss: 0
82
+ clip_conv_loss_type: "L2"
83
+ clip_conv_layer_weights: "0,0,1.0,1.0,0"
84
+ clip_model_name: "ViT-B/32"
85
+ clip_fc_loss_weight: 0
86
+ clip_text_guide: 0
87
+ text_target: None
conf/x/clipasso.yaml ADDED
@@ -0,0 +1,48 @@
 
+ method: 'clipasso'
+
+ image_size: 224
+ mask_object: False
+ fix_scale: False
+ path_svg: ~ # if you want to load an SVG file and train from it
+
+ # train
+ num_iter: 2001
+ num_stages: 1 # training stages: train x strokes, freeze them, then train another x strokes, etc.
+ lr_schedule: False
+ lr: 1
+ color_lr: 0.01
+ color_vars_threshold: 0.0
+
+ # SVG path attr
+ num_paths: 24 # number of strokes
+ width: 1.5 # stroke width
+ control_points_per_seg: 4
+ num_segments: 1
+ attention_init: 1 # if True, use the attention heads of the Dino model to set the locations of the initial strokes
+ saliency_model: "clip"
+ saliency_clip_model: "ViT-B/32"
+ xdog_intersec: 1
+ mask_object_attention: 0
+ softmax_temp: 0.3
+ u2net_path: "./checkpoint/u2net/u2net.pth"
+
+ # loss
+ percep_loss: "none"
+ perceptual_weight: 0
+ train_with_clip: 0
+ clip_weight: 0
+ start_clip: 0
+ num_aug_clip: 4
+ include_target_in_aug: 0
+ augment_both: 0
+ augemntations: "affine" # can be any combination of: 'affine_noise_eraserchunks_eraser_press'
+ noise_thresh: 0.5
+ aug_scale_min: 0.7
+ force_sparse: 0 # if True, use L1 regularization on stroke opacity to encourage a small number of strokes
+ clip_conv_loss: 1
+ clip_conv_loss_type: "L2"
+ clip_conv_layer_weights: "0,0,1.0,1.0,0"
+ clip_model_name: "RN101"
+ clip_fc_loss_weight: 0.1
+ clip_text_guide: 0
+ text_target: None
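`clip_conv_layer_weights` is a comma-separated weight per conv block of the CLIP visual encoder. A rough sketch of how such a string could be parsed and applied to per-layer feature losses (the feature extraction itself is omitted; the random tensors below are stand-ins):

    import torch
    import torch.nn.functional as F

    layer_weights = [float(w) for w in "0,0,1.0,1.0,0".split(",")]

    def clip_conv_loss(sketch_feats, target_feats):
        # sketch_feats / target_feats: per-layer feature maps from the CLIP backbone
        return sum(w * F.mse_loss(fs, ft)
                   for w, fs, ft in zip(layer_weights, sketch_feats, target_feats)
                   if w > 0)

    feats = [torch.randn(1, 64, 56, 56) for _ in range(5)]
    print(clip_conv_loss(feats, [f + 0.1 for f in feats]))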
conf/x/clipdraw.yaml ADDED
@@ -0,0 +1,20 @@
+ method: 'clipdraw'
+
+ image_size: 224 # canvas size
+ path_svg: ~ # if you want to load an SVG file and train from it
+
+ # train
+ num_iter: 1000
+ num_stages: 1 # training stages: train x strokes, freeze them, then train another x strokes, etc.
+ lr_schedule: True
+ lr: 1
+ width_lr: 0.1
+ color_lr: 0.01
+
+ # SVG path attr
+ num_paths: 512 # number of strokes
+ max_width: 50 # maximum stroke width
+ black_stroke_color: False
+
+ # loss
+ num_aug: 4 # number of image augmentations
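`num_aug` controls how many augmented views of the rendered canvas are scored against the text prompt at each step. A rough, CPU-only sketch of that objective using the OpenAI clip package (the normalization constants are CLIP's published values; the function itself is illustrative, not the repo's implementation):

    import torch, clip
    import torchvision.transforms as T

    model, _ = clip.load("ViT-B/32", device="cpu")
    augment = T.Compose([
        T.RandomPerspective(distortion_scale=0.5, p=1.0),
        T.RandomResizedCrop(224, scale=(0.7, 0.9)),
        T.Normalize((0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711)),
    ])

    def clipdraw_loss(canvas, prompt, num_aug=4):
        # canvas: (1, 3, H, W) differentiable render in [0, 1]
        text = model.encode_text(clip.tokenize([prompt]))
        views = torch.cat([augment(canvas) for _ in range(num_aug)])
        return -torch.cosine_similarity(model.encode_image(views), text).mean()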
conf/x/clipfont.yaml ADDED
@@ -0,0 +1,27 @@
+ method: 'clipfont'
+
+ # optimizer
+ lr_base:
+   point: 0.1
+   color: 0.01
+ lr_decay_rate: 0.1
+ decay_steps: [ 1000, 1500 ]
+ lr_schedule: False
+
+ # train
+ num_iter: 200
+ batch_size: 1
+ font:
+   reinit: False
+   reinit_color: 'randn' # 'randn', 'randn_all', 'green', etc.
+
+ # loss
+ clip:
+   model_name: "ViT-B/32" # RN101, 'ViT-B/32', ViT-L/14
+   thresh: 0.0
+   num_crops: 128
+   crop_size: 230
+   lam_patch: 150
+   lam_dir: 30
+   lam_lpips: 0
+   lam_l2: 0
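One common reading of `lr_decay_rate` / `decay_steps` is a milestone step decay (only relevant when `lr_schedule` is True and training runs past the milestones). A small sketch under that assumption:

    def step_lr(base_lr: float, step: int, decay_steps=(1000, 1500), decay_rate=0.1) -> float:
        # multiply the base lr by decay_rate for every milestone already passed
        return base_lr * decay_rate ** sum(step >= s for s in decay_steps)

    print(step_lr(0.1, 500), step_lr(0.1, 1200), step_lr(0.1, 1600))  # 0.1, 0.01, 0.001 (approx.)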
conf/x/diffsketcher.yaml ADDED
@@ -0,0 +1,76 @@
+ method: 'diffsketcher'
+
+ image_size: 224 # canvas size
+ path_svg: ~ # if you want to load an SVG file and train from it
+ mask_object: False # if the target image contains background, it's better to mask it out
+ fix_scale: False # if the target image is not squared, it is recommended to fix the scale
+
+ # train
+ num_iter: 2000
+ num_stages: 1 # training stages: train x strokes, freeze them, then train another x strokes, etc.
+ lr_schedule: False
+ lr_decay_rate: 0.1
+ decay_steps: [ 1000, 1500 ]
+ lr: 1
+ color_lr: 0.01
+ color_vars_threshold: 0.0 # uncomment the code
+ width_lr: 0.1
+ max_width: 50 # maximum stroke width
+
+ # stroke attrs
+ num_paths: 128 # number of strokes
+ width: 1.5 # stroke width
+ control_points_per_seg: 4
+ num_segments: 1
+ optim_opacity: True # if True, the stroke opacity is optimized
+ optim_width: False # if True, the stroke width is optimized
+ optim_rgba: False # if True, the stroke RGBA is optimized
+ opacity_delta: 0 # stroke pruning
+
+ # init strokes
+ attention_init: True # if True, use attention maps to set the locations of the initial strokes
+ xdog_intersec: True # initialize along edges by combining the XDoG map with the attention map
+ softmax_temp: 0.5
+ cross_attn_res: 16
+ self_attn_res: 32
+ max_com: 20
+ mean_comp: False
+ comp_idx: 0
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
+ log_cross_attn: False # if True, log the cross-attention maps at every step
+ u2net_path: "./checkpoint/u2net/u2net.pth"
+
+ # ldm
+ model_id: "sd15"
+ ldm_speed_up: False
+ enable_xformers: True
+ gradient_checkpoint: False
+ token_ind: 5
+ use_ddim: True
+ num_inference_steps: 100
+ guidance_scale: 7.5 # sdxl default 5.0
+
+ # ASDS loss
+ sds:
+   crop_size: 512
+   augmentations: "affine"
+   guidance_scale: 100
+   grad_scale: 1e-6
+   t_range: [ 0.05, 0.95 ]
+   warmup: 2000
+
+ clip:
+   model_name: "RN101" # RN101, ViT-L/14
+   feats_loss_type: "l2" # clip visual loss type, conv layers
+   feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN-based
+   # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT-based
+   fc_loss_weight: 0.1 # clip visual loss, fc layer weight
+   augmentations: "affine" # augmentation before clip visual computation
+   num_aug: 4 # number of augmentations before clip visual computation
+   vis_loss: 1 # 1 or 0: enable or disable the clip visual loss
+   text_visual_coeff: 0 # cosine similarity between text and img
+
+ perceptual:
+   name: "lpips" # dists
+   lpips_net: 'vgg'
+   coeff: 0.2
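`attn_coeff` and `softmax_temp` describe how the two attention maps are fused into a sampling distribution for the initial stroke positions. A numpy sketch of that fusion (both maps are assumed to be resized to a common resolution; random data stands in for real attention):

    import numpy as np

    def fuse_attention(cross_attn, self_attn, attn_coeff=1.0, softmax_temp=0.5):
        # w * cross-attn + (1 - w) * self-attn, then a temperature softmax
        fused = attn_coeff * cross_attn + (1.0 - attn_coeff) * self_attn
        logits = fused.flatten() / softmax_temp
        probs = np.exp(logits - logits.max())
        return (probs / probs.sum()).reshape(fused.shape)

    attn = fuse_attention(np.random.rand(32, 32), np.random.rand(32, 32))
    # sample one starting location per stroke (num_paths = 128 above)
    idx = np.random.choice(attn.size, size=128, p=attn.flatten())
    ys, xs = np.unravel_index(idx, attn.shape)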
conf/x/diffvg.yaml ADDED
@@ -0,0 +1,18 @@
+ method: 'diffvg'
+
+ # train
+ num_iter: 2000 # number of iterations
+ lr_base:
+   point: 1
+   color: 0.01
+   stroke_width: 0.1
+   stroke_color: 0.01
+ lr_schedule: False # use lr_schedule
+
+ # SVG path attr
+ num_paths: 512 # number of paths
+ max_width: 5.0 # maximum stroke width
+ path_type: 'unclosed' # 'unclosed' or 'closed': open or closed curves
+
+ # loss
+ loss_type: 'l2' # choices: 'l1', 'l2', 'lpips', 'l2+lpips'
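The `loss_type` switch selects the pixel reconstruction term. A minimal dispatch sketch for the non-LPIPS choices (the 'lpips' variants would additionally require the lpips package):

    import torch
    import torch.nn.functional as F

    def reconstruction_loss(render, target, loss_type="l2"):
        losses = {"l1": F.l1_loss, "l2": F.mse_loss}
        return losses[loss_type](render, target)

    x, y = torch.rand(1, 3, 224, 224), torch.rand(1, 3, 224, 224)
    print(reconstruction_loss(x, y, "l1"), reconstruction_loss(x, y, "l2"))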
conf/x/live.yaml ADDED
@@ -0,0 +1,31 @@
+ method: 'live'
+
+ image_size: 240 # image and canvas size
+
+ # train
+ num_iter: 500 # iterations per path group
+ num_stages: 1 # training stages: train x strokes, freeze them, then train another x strokes, etc.
+ lr_base:
+   point: 1
+   color: 0.01
+   bg: 0.01
+   stroke_width: 0.1
+   stroke_color: 0.01
+ lr_schedule: True # use lr_schedule
+
+ # SVG path attr
+ num_paths: 5 # number of strokes
+ path_schedule: 'repeat'
+ schedule_each: 1 # [1, 3, 5, 7]
+ train_stroke: False # train stroke width and color
+ trainable_bg: False # set the background to be trainable
+ width: 3 # stroke width
+ num_segments: 4
+ segment_init: 'circle' # 'random'
+ radius: 5
+ coord_init: 'sparse' # 'sparse', 'random' or 'naive': how to place the first control point
+
+ # loss
+ use_l1_loss: False
+ use_distance_weighted_loss: True
+ xing_loss_weight: 0.01
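With `path_schedule: 'repeat'`, paths are added in groups of `schedule_each` until `num_paths` is reached, and each group is optimized for `num_iter` steps before the next is added. A plausible sketch of how such a schedule can be expanded (the function name is illustrative, not from the repo):

    def build_path_schedule(num_paths=5, schedule_each=1, path_schedule="repeat"):
        if path_schedule == "repeat":
            full, rest = divmod(num_paths, schedule_each)
            return [schedule_each] * full + ([rest] if rest else [])
        raise ValueError(f"unknown path_schedule: {path_schedule}")

    print(build_path_schedule())                                # [1, 1, 1, 1, 1]
    print(build_path_schedule(num_paths=16, schedule_each=5))   # [5, 5, 5, 1]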
conf/x/styleclipdraw.yaml ADDED
@@ -0,0 +1,21 @@
+ method: 'styleclipdraw'
+
+ image_size: 224 # canvas size
+ path_svg: ~ # if you want to load an SVG file and train from it
+
+ # train
+ num_iter: 1000
+ num_stages: 1 # training stages: train x strokes, freeze them, then train another x strokes, etc.
+ lr_schedule: True # anneal the learning rate
+ lr: 1
+ width_lr: 0.1
+ color_lr: 0.01
+
+ # strokes
+ num_paths: 512 # number of strokes
+ max_width: 50 # maximum stroke width
+ black_stroke_color: False
+ style_strength: 50 # style strength, from 0 (no style) to 100 (maximum style)
+
+ # loss
+ num_aug: 10 # number of image augmentations
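`style_strength` is a 0-100 knob. One simple way such a knob can be folded into a combined objective is a linear blend of content and style terms; the actual StyleCLIPDraw implementation may apply it differently, so treat this purely as an illustration:

    def blend_losses(content_loss: float, style_loss: float, style_strength: int = 50) -> float:
        # 0 -> content only, 100 -> style only
        w = style_strength / 100.0
        return (1.0 - w) * content_loss + w * style_loss

    print(blend_losses(content_loss=2.0, style_loss=1.0, style_strength=50))  # 1.5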
conf/x/stylediffsketcher.yaml ADDED
@@ -0,0 +1,77 @@
+ method: 'stylediffsketcher'
+
+ image_size: 224 # canvas size
+ path_svg: ~ # if you want to load an SVG file and train from it
+ mask_object: False # if the target image contains background, it's better to mask it out
+ fix_scale: False # if the target image is not squared, it is recommended to fix the scale
+
+ # train
+ num_iter: 2000
+ num_stages: 1 # training stages: train x strokes, freeze them, then train another x strokes, etc.
+ lr_schedule: False
+ lr_decay_rate: 0.1
+ decay_steps: [ 1000, 1500 ]
+ lr: 1
+ color_lr: 0.01
+ color_vars_threshold: 0.0 # uncomment the code
+ width_lr: 0.1
+ max_width: 50 # maximum stroke width
+
+ # SVG path attrs
+ num_paths: 512 # number of strokes
+ width: 1.5 # init stroke width
+ control_points_per_seg: 4
+ num_segments: 1
+ optim_opacity: True # if True, the stroke opacity is optimized
+ optim_width: True # if True, the stroke width is optimized
+ optim_rgba: True # if True, the stroke RGBA is optimized
+
+ # init strokes
+ attention_init: True # if True, use attention maps to set the locations of the initial strokes
+ xdog_intersec: True # initialize along edges by combining the XDoG map with the attention map
+ softmax_temp: 0.4
+ cross_attn_res: 16
+ self_attn_res: 32
+ max_com: 20
+ mean_comp: False
+ comp_idx: 0
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
+ log_cross_attn: False
+ u2net_path: "./checkpoint/u2net/u2net.pth"
+
+ # ldm
+ model_id: "sd15"
+ ldm_speed_up: False
+ enable_xformers: True
+ gradient_checkpoint: False
+ token_ind: 5
+ use_ddim: True
+ num_inference_steps: 100
+ guidance_scale: 7.5
+
+ # ASDS loss
+ sds:
+   crop_size: 512
+   augmentations: "affine"
+   guidance_scale: 100
+   grad_scale: 0
+   t_range: [ 0.05, 0.95 ]
+   warmup: 120
+
+ clip:
+   model_name: "RN101" # RN101, ViT-L/14
+   feats_loss_type: "l2" # clip visual loss type, conv layers
+   feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN-based
+   # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT-based
+   fc_loss_weight: 0.1 # clip visual loss, fc layer weight
+   augmentations: "affine_norm" # augmentation before clip visual computation, affine_norm_trivial
+   num_aug: 4 # number of augmentations before clip visual computation
+   vis_loss: 1 # 1 or 0: enable or disable the clip visual loss
+   text_visual_coeff: 0 # cosine similarity between text and img
+
+ perceptual:
+   name: "lpips" # dists
+   lpips_net: 'vgg'
+   coeff: 0.2
+
+ style_strength: 1 # style strength, from 0 (no style) to 100 (maximum style)
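The `perceptual` block selects an LPIPS term with a VGG backbone, scaled by `coeff`. A short sketch of how that term is typically computed with the lpips package (how it is summed into the total loss here is an assumption):

    import torch
    import lpips

    loss_fn = lpips.LPIPS(net='vgg')  # matches lpips_net: 'vgg'

    def perceptual_term(render, target, coeff=0.2):
        # lpips expects inputs scaled to [-1, 1], shape (N, 3, H, W)
        return coeff * loss_fn(render, target).mean()

    x = torch.rand(1, 3, 224, 224) * 2 - 1
    print(perceptual_term(x, x.clone()))   # ~0 for identical inputs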
conf/x/svgdreamer.yaml ADDED
@@ -0,0 +1,122 @@
+ method: "svgdreamer"
+
+ image_size: 600 # canvas size
+ path_svg: ~ # if you want to load an SVG file and train from it
+ num_stages: 1 # training stages: train x strokes, freeze them, then train another x strokes, etc.
+ skip_sive: True # if True, skip the SIVE stage and optimize from scratch
+ color_init: 'rand' # if skip_sive=True, use color_init to initialize the target image
+ style: "iconography" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
+
+ # lr and optim
+ lr_stage_one: # SIVE stage
+   point: 1 # control points
+   width: 0.1 # stroke width
+   color: 0.01 # fill color and stroke color
+   bg: 0.01 # bg in render_warp
+   optim:
+     name: 'adam'
+     betas: [ 0.9, 0.9 ]
+     eps: 1e-6
+   lr_schedule: True # use lr_scheduler
+   schedule:
+     name: 'linear'
+     keep_ratio: 0.2
+     decay_ratio: 0.4
+ lr_stage_two: # VPSD stage
+   point: 1
+   width: 0.1
+   color: 0.01
+   bg: 0.01
+   lr_schedule: True # use lr_scheduler
+   optim:
+     name: 'adam'
+     betas: [ 0.9, 0.9 ]
+     eps: 1e-6
+   schedule:
+     name: 'cosine'
+     warmup_steps: 10
+     warmup_start_lr: 0.02
+     warmup_end_lr: 0.8
+     cosine_end_lr: 0.4
+
+ # primitives
+ num_paths: 256 # number of strokes
+ trainable_bg: False # set the background to be trainable
+ width: 3 # stroke width
+ num_segments: 4
+ segment_init: 'circle' # 'random'
+ radius: 20
+ coord_init: 'random' # 'random' or 'naive': how to place the first control point
+ grid: 50 # divide the canvas into n grids
+ path_reinit: # reinitializing paths
+   use: True
+   freq: 100 # reinitialize paths every 100 iterations
+   stop_step: 1000 # stop path reinitialization after this VPSD fine-tuning step
+   opacity_threshold: 0.05
+   area_threshold: 64
+
+ # diffusion
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
+ ldm_speed_up: False
+ enable_xformers: True
+ gradient_checkpoint: False
+ cpu_offload: True
+ num_inference_steps: 50
+ guidance_scale: 7.5 # sdxl default 5.0
+ K: 4
+ lora_path: ~
+
+ # VPSD loss
+ guidance:
+   use: True
+   type: 'vpsd'
+   n_particle: 1 # 4, 8, 16
+   vsd_n_particle: 1 # the batch size of particles
+   particle_aug: False # apply data augmentation to the input particles
+   num_iter: 2000 # total iterations
+   guidance_scale: 7.5 # CFG value
+   grad_scale: 1.0 # scales the gradient magnitude up or down
+   grad_clip_val: ~ # e.g. 10; clip the VPSD gradient
+   t_range: [ 0.02, 0.98 ]
+   # 'randint': random time steps, which may give a more authentic style.
+   # 'max_0.5_900': anneal the upper bound from 0.98 to 0.5 after 900 steps, which may give more colorful results.
+   t_schedule: 'max_0.5_1000' # or 'randint'
+   # phi model config
+   phi_single: False # if False, instantiate a new UNet to estimate noise
+   phi_model: 'lora' # 'lora', 'unet_simple'
+   use_attn_scale: ${x.guidance.phi_single} # use lora_attn_scale or not
+   lora_attn_scale: 1.0 # the scale of the attention-based LoRA layer
+   phi_guidance_scale: 1.0
+   phi_t: False # different t for phi fine-tuning
+   phi_update_step: 1 # number of phi model updates per iteration
+   phi_lr: 0.0001 # learning rate of the phi model
+   phi_scheduler: 'ddim' # 'dpm-solver'
+   phi_n_particle: 1 # the batch size of phi_model
+   # ReFL config
+   phi_ReFL: False # enable reward feedback learning (ReFL)
+   n_phi_sample: 1 # number of samples used in ReFL
+   phi_sample_step: 200 # the phi log step
+   phi_infer_step: 50 # num_inference_steps for the phi model
+   # phi model optim
+   phi_optim:
+     name: 'adamw'
+     betas: [ 0.9, 0.999 ]
+     eps: 1e-8
+     weight_decay: ~ # 1e-5
+   # phi model lr schedule
+   phi_schedule:
+     use: False
+     name: 'cosine'
+     warmup_steps: 50
+     warmup_start_lr: 0.00001
+     warmup_end_lr: 0.0001
+     total_step: 800
+     cosine_end_lr: 0.0001
+
+   # reward model
+   reward_path: './checkpoint/ImageReward'
+
+ # xing loss for closed paths (self-intersection penalty)
+ xing_loss:
+   use: False
+   weight: 0.01
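The `t_schedule` comments describe two timestep strategies: pure `randint` sampling within `t_range`, or `max_<bound>_<step>`, which caps the upper bound once training passes the given step. A sketch of that parsing under those assumptions (the helper name is illustrative):

    import random

    def sample_timestep(step, t_schedule="max_0.5_1000", t_range=(0.02, 0.98),
                        num_train_timesteps=1000):
        t_min, t_max = t_range
        if t_schedule.startswith("max_"):
            _, bound, switch_step = t_schedule.split("_")
            if step >= int(switch_step):
                t_max = float(bound)   # anneal the upper bound, e.g. 0.98 -> 0.5
        return random.randint(int(t_min * num_train_timesteps),
                              int(t_max * num_train_timesteps) - 1)

    print(sample_timestep(step=500), sample_timestep(step=1500))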
conf/x/vectorfusion.yaml ADDED
@@ -0,0 +1,85 @@
+ method: "vectorfusion"
+
+ image_size: 600 # canvas size
+ path_svg: ~ # if you want to load an SVG file and train from it
+ num_stages: 1 # training stages: train x strokes, freeze them, then train another x strokes, etc.
+ skip_live: False # if True, skip the LIVE stage and train from scratch
+ style: "iconography" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
+
+ # train
+ batch_size: 1
+ num_iter: 500 # iterations per path group
+ # lr and optim
+ lr_stage_one:
+   point: 1
+   width: 0.1
+   color: 0.01
+   bg: 0.01
+   optim:
+     name: 'adam'
+     betas: [ 0.9, 0.9 ]
+     eps: 1e-6
+   lr_schedule: True # use lr_scheduler
+   schedule:
+     name: 'linear'
+     keep_ratio: 0.2
+     decay_ratio: 0.4
+ lr_stage_two:
+   point: 1
+   width: 0.1
+   color: 0.01
+   bg: 0.01
+   lr_schedule: True # use lr_scheduler
+   optim:
+     name: 'adam'
+     betas: [ 0.9, 0.9 ]
+     eps: 1e-6
+   schedule:
+     name: 'cosine'
+     warmup_steps: 50
+     warmup_start_lr: 0.02
+     warmup_end_lr: 1.0
+     cosine_end_lr: 0.4
+
+ # primitives
+ num_paths: 128 # number of strokes
+ path_schedule: 'repeat' # 'list'
+ schedule_each: 16 # [1, 3, 5, 7]
+ trainable_bg: False # set the background to be trainable
+ width: 3 # stroke width
+ num_segments: 4
+ segment_init: 'circle' # 'random'
+ radius: 20
+ coord_init: 'sparse' # 'sparse', 'random' or 'naive': how to place the first control point
+ grid: 32 # divide the canvas into n grids
+ path_reinit: # reinitializing paths
+   use: True
+   freq: 50 # reinitialize paths every 50 iterations
+   stop_step: 800 # stop path reinitialization after this SDS fine-tuning step
+   opacity_threshold: 0.05
+   area_threshold: 64
+
+ # diffusion
+ model_id: "sd15" # sd14, sd15, sd21, sd21b, sdxl
+ ldm_speed_up: False
+ enable_xformers: True
+ gradient_checkpoint: False
+ cpu_offload: True
+ num_inference_steps: 50
+ guidance_scale: 7.5 # sdxl default 5.0
+ K: 6
+ lora_path: ~
+
+ # SDS
+ sds:
+   im_size: 512
+   guidance_scale: 100
+   grad_scale: 1.0
+   t_range: [ 0.05, 0.95 ]
+   num_iter: 1000 # fine-tuning steps
+
+ # Live loss
+ use_distance_weighted_loss: True
+ xing_loss_weight: 0.01
+ # pixel loss
+ penalty_weight: 0.05
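The stage-two `schedule` block reads like a warmup-plus-cosine multiplier applied on top of the per-parameter learning rates (`point`, `width`, `color`). A sketch under that reading; the exact formula used by the repo may differ:

    import math

    def warmup_cosine(step, warmup_steps=50, warmup_start_lr=0.02,
                      warmup_end_lr=1.0, cosine_end_lr=0.4, total_steps=1000):
        if step < warmup_steps:
            return warmup_start_lr + (warmup_end_lr - warmup_start_lr) * step / warmup_steps
        t = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        return cosine_end_lr + 0.5 * (warmup_end_lr - cosine_end_lr) * (1 + math.cos(math.pi * t))

    print(warmup_cosine(0), warmup_cosine(50), warmup_cosine(1000))  # 0.02, 1.0, 0.4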
conf/x/wordasimage.yaml ADDED
@@ -0,0 +1,46 @@
+ method: "wordasimage"
+
+ image_size: 600 # canvas size
+ word: "BUNNY"
+ optim_letter: "Y"
+ prompt_suffix: "minimal flat 2d vector. lineal color. trending on artstation"
+
+ # train
+ num_iter: 500
+ lr_schedule: True
+ lr:
+   point_lr: 1
+   lr_init: 0.002
+   lr_final: 0.0008
+   lr_delay_mult: 0.1
+   lr_delay_steps: 100
+
+ # font
+ font: 'KaushanScript-Regular'
+ font_path: "./assets/fonts/${x.font}.ttf"
+ level_of_cc: 1 # 0 - original number of control points / 1 - recommended / 2 - more control points
+
+ # diffusion
+ model_id: "sd15"
+ ldm_speed_up: False
+ enable_xformers: False
+ gradient_checkpoint: False
+ lora_path: ~
+
+ # SDS
+ sds:
+   im_size: 512
+   guidance_scale: 100
+   grad_scale: 1.0
+   t_range: [ 0.05, 0.95 ]
+   num_iter: 1000
+
+ tone_loss:
+   use: True
+   dist_loss_weight: 100
+   pixel_dist_kernel_blur: 201
+   pixel_dist_sigma: 30
+
+ conformal:
+   use: True
+   angeles_w: 0.5
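The `lr` block (`lr_init`, `lr_final`, `lr_delay_mult`, `lr_delay_steps`) matches the shape of the NeRF-style delayed log-linear decay; the sketch below assumes that form and should not be read as the repo's exact code:

    import math

    def delayed_lr(step, lr_init=0.002, lr_final=0.0008,
                   lr_delay_steps=100, lr_delay_mult=0.1, max_steps=500):
        t = min(step / max_steps, 1.0)
        lr = math.exp((1 - t) * math.log(lr_init) + t * math.log(lr_final))
        if lr_delay_steps > 0:
            # ramp from lr_delay_mult up to 1 over the first lr_delay_steps steps
            lr *= lr_delay_mult + (1 - lr_delay_mult) * math.sin(
                0.5 * math.pi * min(step / lr_delay_steps, 1.0))
        return lr

    print(delayed_lr(0), delayed_lr(250), delayed_lr(500))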
data/alphabet1.svg ADDED
data/ballerina.png ADDED
data/ch1.svg ADDED
data/fallingwater.png ADDED
data/horse.png ADDED
data/simile.png ADDED
data/starry.png ADDED