charlescxk committed
Commit a38a851 · 1 Parent(s): 2c415b6

update

Files changed:
- README.md +0 -13
- Upsample/__init__.py +1 -0
- Upsample/__pycache__/__init__.cpython-38.pyc +0 -0
- Upsample/__pycache__/arch_utils.cpython-38.pyc +0 -0
- Upsample/__pycache__/model.cpython-38.pyc +0 -0
- Upsample/__pycache__/rrdbnet_arch.cpython-38.pyc +0 -0
- Upsample/__pycache__/utils.cpython-38.pyc +0 -0
- Upsample/arch_utils.py +197 -0
- Upsample/model.py +93 -0
- Upsample/rrdbnet_arch.py +121 -0
- Upsample/utils.py +135 -0
- app.py +268 -0
- doge.png +0 -0
- equation.png +0 -0
- janus/__init__.py +31 -0
- janus/__pycache__/__init__.cpython-38.pyc +0 -0
- janus/models/__init__.py +28 -0
- janus/models/__pycache__/__init__.cpython-38.pyc +0 -0
- janus/models/__pycache__/clip_encoder.cpython-38.pyc +0 -0
- janus/models/__pycache__/image_processing_vlm.cpython-38.pyc +0 -0
- janus/models/__pycache__/modeling_vlm.cpython-38.pyc +0 -0
- janus/models/__pycache__/processing_vlm.cpython-38.pyc +0 -0
- janus/models/__pycache__/projector.cpython-38.pyc +0 -0
- janus/models/__pycache__/siglip_vit.cpython-38.pyc +0 -0
- janus/models/__pycache__/vq_model.cpython-38.pyc +0 -0
- janus/models/clip_encoder.py +122 -0
- janus/models/image_processing_vlm.py +208 -0
- janus/models/modeling_vlm.py +272 -0
- janus/models/processing_vlm.py +418 -0
- janus/models/projector.py +100 -0
- janus/models/siglip_vit.py +681 -0
- janus/models/vq_model.py +527 -0
- janus/utils/__init__.py +18 -0
- janus/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- janus/utils/__pycache__/conversation.cpython-38.pyc +0 -0
- janus/utils/__pycache__/io.cpython-38.pyc +0 -0
- janus/utils/conversation.py +365 -0
- janus/utils/io.py +89 -0
- requirements.txt +8 -0
- weights/RealESRGAN_x2.pth +3 -0
README.md
DELETED
@@ -1,13 +0,0 @@
----
-title: Janus Pro 7B
-emoji: 🐢
-colorFrom: purple
-colorTo: gray
-sdk: gradio
-sdk_version: 5.13.1
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Upsample/__init__.py
ADDED
@@ -0,0 +1 @@
+from .model import RealESRGAN
Upsample/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (213 Bytes).
Upsample/__pycache__/arch_utils.cpython-38.pyc
ADDED
Binary file (7.14 kB).
Upsample/__pycache__/model.cpython-38.pyc
ADDED
Binary file (3.11 kB).
Upsample/__pycache__/rrdbnet_arch.cpython-38.pyc
ADDED
Binary file (4.47 kB).
Upsample/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (4.05 kB).
Upsample/arch_utils.py
ADDED
@@ -0,0 +1,197 @@
+import math
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+    """Initialize network weights.
+
+    Args:
+        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+        scale (float): Scale initialized weights, especially for residual
+            blocks. Default: 1.
+        bias_fill (float): The value to fill bias. Default: 0
+        kwargs (dict): Other arguments for initialization function.
+    """
+    if not isinstance(module_list, list):
+        module_list = [module_list]
+    for module in module_list:
+        for m in module.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, _BatchNorm):
+                init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+
+    Args:
+        basic_block (nn.module): nn.module class for basic block.
+        num_basic_block (int): number of blocks.
+
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+    """Residual block without BN.
+
+    It has a style of:
+        ---Conv-ReLU-Conv-+-
+         |________________|
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        res_scale (float): Residual scale. Default: 1.
+        pytorch_init (bool): If set to True, use pytorch default init,
+            otherwise, use default_init_weights. Default: False.
+    """
+
+    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+        super(ResidualBlockNoBN, self).__init__()
+        self.res_scale = res_scale
+        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.relu = nn.ReLU(inplace=True)
+
+        if not pytorch_init:
+            default_init_weights([self.conv1, self.conv2], 0.1)
+
+    def forward(self, x):
+        identity = x
+        out = self.conv2(self.relu(self.conv1(x)))
+        return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+    """Upsample module.
+
+    Args:
+        scale (int): Scale factor. Supported scales: 2^n and 3.
+        num_feat (int): Channel number of intermediate features.
+    """
+
+    def __init__(self, scale, num_feat):
+        m = []
+        if (scale & (scale - 1)) == 0:  # scale = 2^n
+            for _ in range(int(math.log(scale, 2))):
+                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                m.append(nn.PixelShuffle(2))
+        elif scale == 3:
+            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+            m.append(nn.PixelShuffle(3))
+        else:
+            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+        super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+    """Warp an image or feature map with optical flow.
+
+    Args:
+        x (Tensor): Tensor with size (n, c, h, w).
+        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+        padding_mode (str): 'zeros' or 'border' or 'reflection'.
+            Default: 'zeros'.
+        align_corners (bool): Before pytorch 1.3, the default value is
+            align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use the True as default.
+
+    Returns:
+        Tensor: Warped image or feature map.
+    """
+    assert x.size()[-2:] == flow.size()[1:3]
+    _, _, h, w = x.size()
+    # create mesh grid
+    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+    grid.requires_grad = False
+
+    vgrid = grid + flow
+    # scale grid to [-1,1]
+    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+    # TODO, what if align_corners=False
+    return output
+
+
+def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+    """Resize a flow according to ratio or shape.
+
+    Args:
+        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+        size_type (str): 'ratio' or 'shape'.
+        sizes (list[int | float]): the ratio for resizing or the final output
+            shape.
+            1) The order of ratio should be [ratio_h, ratio_w]. For
+                downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+                < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+                ratio > 1.0).
+            2) The order of output_size should be [out_h, out_w].
+        interp_mode (str): The mode of interpolation for resizing.
+            Default: 'bilinear'.
+        align_corners (bool): Whether align corners. Default: False.
+
+    Returns:
+        Tensor: Resized flow.
+    """
+    _, _, flow_h, flow_w = flow.size()
+    if size_type == 'ratio':
+        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+    elif size_type == 'shape':
+        output_h, output_w = sizes[0], sizes[1]
+    else:
+        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+    input_flow = flow.clone()
+    ratio_h = output_h / flow_h
+    ratio_w = output_w / flow_w
+    input_flow[:, 0, :, :] *= ratio_w
+    input_flow[:, 1, :, :] *= ratio_h
+    resized_flow = F.interpolate(
+        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+    return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+    """ Pixel unshuffle.
+
+    Args:
+        x (Tensor): Input feature with shape (b, c, hh, hw).
+        scale (int): Downsample ratio.
+
+    Returns:
+        Tensor: the pixel unshuffled feature.
+    """
+    b, c, hh, hw = x.size()
+    out_channel = c * (scale**2)
+    assert hh % scale == 0 and hw % scale == 0
+    h = hh // scale
+    w = hw // scale
+    x_view = x.view(b, c, h, scale, w, scale)
+    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
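
For reference, a minimal shape check (not part of the commit) showing what pixel_unshuffle does: it trades spatial resolution for channels, which is how the RRDBNet added below supports the x1 and x2 scales. It assumes the Upsample package from this commit is importable.

import torch
from Upsample.arch_utils import pixel_unshuffle

x = torch.randn(1, 3, 64, 64)      # (b, c, hh, hw); hh and hw must be divisible by scale
y = pixel_unshuffle(x, scale=2)    # channels multiplied by scale**2, spatial dims divided by scale
print(y.shape)                     # torch.Size([1, 12, 32, 32])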
Upsample/model.py
ADDED
@@ -0,0 +1,93 @@
+import os
+import torch
+from torch.nn import functional as F
+from PIL import Image
+import numpy as np
+import cv2
+from huggingface_hub import hf_hub_url, hf_hub_download
+
+from .rrdbnet_arch import RRDBNet
+from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
+    unpad_image
+
+HF_MODELS = {
+    2: dict(
+        repo_id='sberbank-ai/Real-ESRGAN',
+        filename='RealESRGAN_x2.pth',
+    ),
+    4: dict(
+        repo_id='sberbank-ai/Real-ESRGAN',
+        filename='RealESRGAN_x4.pth',
+    ),
+    8: dict(
+        repo_id='sberbank-ai/Real-ESRGAN',
+        filename='RealESRGAN_x8.pth',
+    ),
+}
+
+
+class RealESRGAN:
+    def __init__(self, device, scale=4):
+        self.device = device
+        self.scale = scale
+        self.model = RRDBNet(
+            num_in_ch=3, num_out_ch=3, num_feat=64,
+            num_block=23, num_grow_ch=32, scale=scale
+        )
+
+    def load_weights(self, model_path, download=True):
+        if not os.path.exists(model_path) and download:
+            assert self.scale in [2, 4, 8], 'You can download models only with scales: 2, 4, 8'
+            config = HF_MODELS[self.scale]
+            cache_dir = os.path.dirname(model_path)
+            local_filename = os.path.basename(model_path)
+            config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
+            htr = hf_hub_download(repo_id=config['repo_id'], cache_dir=cache_dir, local_dir=cache_dir,
+                                  filename=config['filename'])
+            print(htr)
+            # cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
+            print('Weights downloaded to:', os.path.join(cache_dir, local_filename))
+
+        loadnet = torch.load(model_path)
+        if 'params' in loadnet:
+            self.model.load_state_dict(loadnet['params'], strict=True)
+        elif 'params_ema' in loadnet:
+            self.model.load_state_dict(loadnet['params_ema'], strict=True)
+        else:
+            self.model.load_state_dict(loadnet, strict=True)
+        self.model.eval()
+        self.model.to(self.device)
+
+    # @torch.cuda.amp.autocast()
+    def predict(self, lr_image, batch_size=4, patches_size=192,
+                padding=24, pad_size=15):
+        torch.autocast(device_type=self.device.type)
+        scale = self.scale
+        device = self.device
+        lr_image = np.array(lr_image)
+        lr_image = pad_reflect(lr_image, pad_size)
+
+        patches, p_shape = split_image_into_overlapping_patches(
+            lr_image, patch_size=patches_size, padding_size=padding
+        )
+        img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()
+
+        with torch.no_grad():
+            res = self.model(img[0:batch_size])
+            for i in range(batch_size, img.shape[0], batch_size):
+                res = torch.cat((res, self.model(img[i:i + batch_size])), 0)
+
+        sr_image = res.permute((0, 2, 3, 1)).cpu().clamp_(0, 1)
+        np_sr_image = sr_image.numpy()
+
+        padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
+        scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
+        np_sr_image = stich_together(
+            np_sr_image, padded_image_shape=padded_size_scaled,
+            target_shape=scaled_image_shape, padding_size=padding * scale
+        )
+        sr_img = (np_sr_image * 255).astype(np.uint8)
+        sr_img = unpad_image(sr_img, pad_size * scale)
+        sr_img = Image.fromarray(sr_img)
+
+        return sr_img
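
A minimal usage sketch of the RealESRGAN wrapper above (not part of the commit). The weight path mirrors the one app.py uses; the input file name is an assumption.

import torch
from PIL import Image
from Upsample import RealESRGAN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RealESRGAN(device, scale=2)
model.load_weights('weights/RealESRGAN_x2.pth', download=False)  # weights are shipped in this commit

lr_img = Image.open('input.png').convert('RGB')  # hypothetical low-resolution input
sr_img = model.predict(lr_img)                   # PIL.Image upscaled 2x via patch-wise inference
sr_img.save('output_x2.png')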
Upsample/rrdbnet_arch.py
ADDED
@@ -0,0 +1,121 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .arch_utils import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+    """Residual Dense Block.
+
+    Used in RRDB block in ESRGAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_grow_ch (int): Channels for each growth.
+    """
+
+    def __init__(self, num_feat=64, num_grow_ch=32):
+        super(ResidualDenseBlock, self).__init__()
+        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        # initialization
+        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+    def forward(self, x):
+        x1 = self.lrelu(self.conv1(x))
+        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        # Emperically, we use 0.2 to scale the residual for better performance
+        return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+    """Residual in Residual Dense Block.
+
+    Used in RRDB-Net in ESRGAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_grow_ch (int): Channels for each growth.
+    """
+
+    def __init__(self, num_feat, num_grow_ch=32):
+        super(RRDB, self).__init__()
+        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+    def forward(self, x):
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.rdb3(out)
+        # Emperically, we use 0.2 to scale the residual for better performance
+        return out * 0.2 + x
+
+
+class RRDBNet(nn.Module):
+    """Networks consisting of Residual in Residual Dense Block, which is used
+    in ESRGAN.
+
+    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+    We extend ESRGAN for scale x2 and scale x1.
+    Note: This is one option for scale 1, scale 2 in RRDBNet.
+    We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
+    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+
+    Args:
+        num_in_ch (int): Channel number of inputs.
+        num_out_ch (int): Channel number of outputs.
+        num_feat (int): Channel number of intermediate features.
+            Default: 64
+        num_block (int): Block number in the trunk network. Defaults: 23
+        num_grow_ch (int): Channels for each growth. Default: 32.
+    """
+
+    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+        super(RRDBNet, self).__init__()
+        self.scale = scale
+        if scale == 2:
+            num_in_ch = num_in_ch * 4
+        elif scale == 1:
+            num_in_ch = num_in_ch * 16
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        # upsample
+        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        if scale == 8:
+            self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        if self.scale == 2:
+            feat = pixel_unshuffle(x, scale=2)
+        elif self.scale == 1:
+            feat = pixel_unshuffle(x, scale=4)
+        else:
+            feat = x
+        feat = self.conv_first(feat)
+        body_feat = self.conv_body(self.body(feat))
+        feat = feat + body_feat
+        # upsample
+        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        if self.scale == 8:
+            feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+        return out
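
A quick sanity-check sketch (not committed code) of RRDBNet's effective output scale: the upsampling path applies two nearest-neighbour 2x steps (three for scale=8), and for scale=2 or scale=1 the input is pixel-unshuffled first, so the end-to-end factor matches the scale argument. num_block is kept small here only to make the check fast.

import torch
from Upsample.rrdbnet_arch import RRDBNet

net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=2, num_grow_ch=32, scale=2)
x = torch.randn(1, 3, 32, 32)      # spatial size must be divisible by 2 for the pixel-unshuffle step
with torch.no_grad():
    y = net(x)
print(y.shape)                     # torch.Size([1, 3, 64, 64]) -- a 2x output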
Upsample/utils.py
ADDED
@@ -0,0 +1,135 @@
+import numpy as np
+import torch
+from PIL import Image
+import os
+import io
+
+
+def pad_reflect(image, pad_size):
+    imsize = image.shape
+    height, width = imsize[:2]
+    new_img = np.zeros([height + pad_size * 2, width + pad_size * 2, imsize[2]]).astype(np.uint8)
+    new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image
+
+    new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0)  # top
+    new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)  # bottom
+    new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size * 2, :], axis=1)  # left
+    new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size * 2:-pad_size, :], axis=1)  # right
+
+    return new_img
+
+
+def unpad_image(image, pad_size):
+    return image[pad_size:-pad_size, pad_size:-pad_size, :]
+
+
+def process_array(image_array, expand=True):
+    """ Process a 3-dimensional array into a scaled, 4 dimensional batch of size 1. """
+
+    image_batch = image_array / 255.0
+    if expand:
+        image_batch = np.expand_dims(image_batch, axis=0)
+    return image_batch
+
+
+def process_output(output_tensor):
+    """ Transforms the 4-dimensional output tensor into a suitable image format. """
+
+    sr_img = output_tensor.clip(0, 1) * 255
+    sr_img = np.uint8(sr_img)
+    return sr_img
+
+
+def pad_patch(image_patch, padding_size, channel_last=True):
+    """ Pads image_patch with with padding_size edge values. """
+
+    if channel_last:
+        return np.pad(
+            image_patch,
+            ((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
+            'edge',
+        )
+    else:
+        return np.pad(
+            image_patch,
+            ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
+            'edge',
+        )
+
+
+def unpad_patches(image_patches, padding_size):
+    return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :]
+
+
+def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
+    """ Splits the image into partially overlapping patches.
+    The patches overlap by padding_size pixels.
+    Pads the image twice:
+        - first to have a size multiple of the patch size,
+        - then to have equal padding at the borders.
+    Args:
+        image_array: numpy array of the input image.
+        patch_size: size of the patches from the original image (without padding).
+        padding_size: size of the overlapping area.
+    """
+
+    xmax, ymax, _ = image_array.shape
+    x_remainder = xmax % patch_size
+    y_remainder = ymax % patch_size
+
+    # modulo here is to avoid extending of patch_size instead of 0
+    x_extend = (patch_size - x_remainder) % patch_size
+    y_extend = (patch_size - y_remainder) % patch_size
+
+    # make sure the image is divisible into regular patches
+    extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
+
+    # add padding around the image to simplify computations
+    padded_image = pad_patch(extended_image, padding_size, channel_last=True)
+
+    xmax, ymax, _ = padded_image.shape
+    patches = []
+
+    x_lefts = range(padding_size, xmax - padding_size, patch_size)
+    y_tops = range(padding_size, ymax - padding_size, patch_size)
+
+    for x in x_lefts:
+        for y in y_tops:
+            x_left = x - padding_size
+            y_top = y - padding_size
+            x_right = x + patch_size + padding_size
+            y_bottom = y + patch_size + padding_size
+            patch = padded_image[x_left:x_right, y_top:y_bottom, :]
+            patches.append(patch)
+
+    return np.array(patches), padded_image.shape
+
+
+def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
+    """ Reconstruct the image from overlapping patches.
+    After scaling, shapes and padding should be scaled too.
+    Args:
+        patches: patches obtained with split_image_into_overlapping_patches
+        padded_image_shape: shape of the padded image contructed in split_image_into_overlapping_patches
+        target_shape: shape of the final image
+        padding_size: size of the overlapping area.
+    """
+
+    xmax, ymax, _ = padded_image_shape
+    patches = unpad_patches(patches, padding_size)
+    patch_size = patches.shape[1]
+    n_patches_per_row = ymax // patch_size
+
+    complete_image = np.zeros((xmax, ymax, 3))
+
+    row = -1
+    col = 0
+    for i in range(len(patches)):
+        if i % n_patches_per_row == 0:
+            row += 1
+            col = 0
+        complete_image[
+            row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, :
+        ] = patches[i]
+        col += 1
+    return complete_image[0: target_shape[0], 0: target_shape[1], :]
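
A sketch (with a random array standing in for a real image, not part of the commit) showing that split_image_into_overlapping_patches and stich_together round-trip an image exactly when no scaling happens in between; RealESRGAN.predict relies on this by scaling the shapes and padding before stitching.

import numpy as np
from Upsample.utils import split_image_into_overlapping_patches, stich_together

img = (np.random.rand(200, 300, 3) * 255).astype(np.uint8)    # hypothetical H x W x 3 image
patches, padded_shape = split_image_into_overlapping_patches(img, patch_size=96, padding_size=8)
restored = stich_together(patches, padded_image_shape=padded_shape,
                          target_shape=img.shape, padding_size=8)
assert np.allclose(restored, img)   # same pixels come back when no scaling is applied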
app.py
ADDED
@@ -0,0 +1,268 @@
+import gradio as gr
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM
+from janus.models import MultiModalityCausalLM, VLChatProcessor
+from janus.utils.io import load_pil_images
+from PIL import Image
+
+import numpy as np
+import os
+import time
+from Upsample import RealESRGAN
+import spaces  # Import spaces for ZeroGPU compatibility
+
+
+# Load model and processor
+model_path = "deepseek-ai/Janus-Pro-7B"
+config = AutoConfig.from_pretrained(model_path)
+language_config = config.language_config
+language_config._attn_implementation = 'eager'
+vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
+                                              language_config=language_config,
+                                              trust_remote_code=True)
+if torch.cuda.is_available():
+    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+else:
+    vl_gpt = vl_gpt.to(torch.float16)
+
+vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+tokenizer = vl_chat_processor.tokenizer
+cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# SR model
+sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
+sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)
+
+@torch.inference_mode()
+@spaces.GPU(duration=120)
+# Multimodal Understanding function
+def multimodal_understanding(image, question, seed, top_p, temperature):
+    # Clear CUDA cache before generating
+    torch.cuda.empty_cache()
+
+    # set seed
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    conversation = [
+        {
+            "role": "<|User|>",
+            "content": f"<image_placeholder>\n{question}",
+            "images": [image],
+        },
+        {"role": "<|Assistant|>", "content": ""},
+    ]
+
+    pil_images = [Image.fromarray(image)]
+    prepare_inputs = vl_chat_processor(
+        conversations=conversation, images=pil_images, force_batchify=True
+    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+
+
+    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+    outputs = vl_gpt.language_model.generate(
+        inputs_embeds=inputs_embeds,
+        attention_mask=prepare_inputs.attention_mask,
+        pad_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=512,
+        do_sample=False if temperature == 0 else True,
+        use_cache=True,
+        temperature=temperature,
+        top_p=top_p,
+    )
+
+    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+    return answer
+
+
+def generate(input_ids,
+             width,
+             height,
+             temperature: float = 1,
+             parallel_size: int = 5,
+             cfg_weight: float = 5,
+             image_token_num_per_image: int = 576,
+             patch_size: int = 16):
+    # Clear CUDA cache before generating
+    torch.cuda.empty_cache()
+
+    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+    for i in range(parallel_size * 2):
+        tokens[i, :] = input_ids
+        if i % 2 != 0:
+            tokens[i, 1:-1] = vl_chat_processor.pad_id
+    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+
+    pkv = None
+    for i in range(image_token_num_per_image):
+        with torch.no_grad():
+            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
+                                                  use_cache=True,
+                                                  past_key_values=pkv)
+            pkv = outputs.past_key_values
+            hidden_states = outputs.last_hidden_state
+            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+            logit_cond = logits[0::2, :]
+            logit_uncond = logits[1::2, :]
+            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+            probs = torch.softmax(logits / temperature, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)
+            generated_tokens[:, i] = next_token.squeeze(dim=-1)
+            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+
+            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+            inputs_embeds = img_embeds.unsqueeze(dim=1)
+
+
+
+    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
+                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])
+
+    return generated_tokens.to(dtype=torch.int), patches
+
+def unpack(dec, width, height, parallel_size=5):
+    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
+    visual_img[:, :, :] = dec
+
+    return visual_img
+
+
+
+@torch.inference_mode()
+@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
+def generate_image(prompt,
+                   seed=None,
+                   guidance=5,
+                   t2i_temperature=1.0):
+    # Clear CUDA cache and avoid tracking gradients
+    torch.cuda.empty_cache()
+    # Set the seed for reproducible results
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        np.random.seed(seed)
+    width = 384
+    height = 384
+    parallel_size = 5
+
+    with torch.no_grad():
+        messages = [{'role': '<|User|>', 'content': prompt},
+                    {'role': '<|Assistant|>', 'content': ''}]
+        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
+                                                                           sft_format=vl_chat_processor.sft_format,
+                                                                           system_prompt='')
+        text = text + vl_chat_processor.image_start_tag
+
+        input_ids = torch.LongTensor(tokenizer.encode(text))
+        output, patches = generate(input_ids,
+                                   width // 16 * 16,
+                                   height // 16 * 16,
+                                   cfg_weight=guidance,
+                                   parallel_size=parallel_size,
+                                   temperature=t2i_temperature)
+        images = unpack(patches,
+                        width // 16 * 16,
+                        height // 16 * 16,
+                        parallel_size=parallel_size)
+
+        # return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
+        stime = time.time()
+        ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
+        print(f'upsample time: {time.time() - stime}')
+        return ret_images
+
+
+@spaces.GPU(duration=60)
+def image_upsample(img: Image.Image) -> Image.Image:
+    if img is None:
+        raise Exception("Image not uploaded")
+
+    width, height = img.size
+
+    if width >= 5000 or height >= 5000:
+        raise Exception("The image is too large.")
+
+    global sr_model
+    result = sr_model.predict(img.convert('RGB'))
+    return result
+
+
+# Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown(value="# Multimodal Understanding")
+    with gr.Row():
+        image_input = gr.Image()
+        with gr.Column():
+            question_input = gr.Textbox(label="Question")
+            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
+
+    understanding_button = gr.Button("Chat")
+    understanding_output = gr.Textbox(label="Response")
+
+    examples_inpainting = gr.Examples(
+        label="Multimodal Understanding examples",
+        examples=[
+            [
+                "explain this meme",
+                "doge.png",
+            ],
+            [
+                "Convert the formula into latex code.",
+                "equation.png",
+            ],
+        ],
+        inputs=[question_input, image_input],
+    )
+
+
+    gr.Markdown(value="# Text-to-Image Generation")
+
+
+
+    with gr.Row():
+        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
+        t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
+
+    prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
+    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
+
+    generation_button = gr.Button("Generate Images")
+
+    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
+
+    examples_t2i = gr.Examples(
+        label="Text to image generation examples.",
+        examples=[
+            "Master shifu racoon wearing drip attire as a street gangster.",
+            "The face of a beautiful girl",
+            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "A glass of red wine on a reflective surface.",
+            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
+            "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
+        ],
+        inputs=prompt_input,
+    )
+
+    understanding_button.click(
+        multimodal_understanding,
+        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
+        outputs=understanding_output
+    )
+
+    generation_button.click(
+        fn=generate_image,
+        inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
+        outputs=image_output
+    )
+
+demo.launch(share=True)
doge.png
ADDED
equation.png
ADDED
janus/__init__.py
ADDED
@@ -0,0 +1,31 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+
+# check if python version is above 3.10
+import sys
+
+if sys.version_info >= (3, 10):
+    print("Python version is above 3.10, patching the collections module.")
+    # Monkey patch collections
+    import collections
+    import collections.abc
+
+    for type_name in collections.abc.__all__:
+        setattr(collections, type_name, getattr(collections.abc, type_name))
janus/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (433 Bytes).
janus/models/__init__.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+from .image_processing_vlm import VLMImageProcessor
+from .modeling_vlm import MultiModalityCausalLM
+from .processing_vlm import VLChatProcessor
+
+__all__ = [
+    "VLMImageProcessor",
+    "VLChatProcessor",
+    "MultiModalityCausalLM",
+]
janus/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (391 Bytes).
janus/models/__pycache__/clip_encoder.cpython-38.pyc
ADDED
Binary file (2.74 kB).
janus/models/__pycache__/image_processing_vlm.cpython-38.pyc
ADDED
Binary file (4.98 kB).
janus/models/__pycache__/modeling_vlm.cpython-38.pyc
ADDED
Binary file (7.1 kB).
janus/models/__pycache__/processing_vlm.cpython-38.pyc
ADDED
Binary file (11.1 kB).
janus/models/__pycache__/projector.cpython-38.pyc
ADDED
Binary file (2.23 kB).
janus/models/__pycache__/siglip_vit.cpython-38.pyc
ADDED
Binary file (18.4 kB).
janus/models/__pycache__/vq_model.cpython-38.pyc
ADDED
Binary file (12.5 kB).
janus/models/clip_encoder.py
ADDED
@@ -0,0 +1,122 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+from typing import Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torchvision.transforms
+from einops import rearrange
+
+from janus.models.siglip_vit import create_siglip_vit
+
+
+class CLIPVisionTower(nn.Module):
+    def __init__(
+        self,
+        model_name: str = "siglip_large_patch16_384",
+        image_size: Union[Tuple[int, int], int] = 336,
+        select_feature: str = "patch",
+        select_layer: int = -2,
+        select_layers: list = None,
+        ckpt_path: str = "",
+        pixel_mean: Optional[List[float]] = None,
+        pixel_std: Optional[List[float]] = None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.model_name = model_name
+        self.select_feature = select_feature
+        self.select_layer = select_layer
+        self.select_layers = select_layers
+
+        vision_tower_params = {
+            "model_name": model_name,
+            "image_size": image_size,
+            "ckpt_path": ckpt_path,
+            "select_layer": select_layer,
+        }
+        vision_tower_params.update(kwargs)
+        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
+            vision_tower_params
+        )
+
+        if pixel_mean is not None and pixel_std is not None:
+            image_norm = torchvision.transforms.Normalize(
+                mean=pixel_mean, std=pixel_std
+            )
+        else:
+            image_norm = None
+
+        self.image_norm = image_norm
+
+    def build_vision_tower(self, vision_tower_params):
+        if self.model_name.startswith("siglip"):
+            self.select_feature = "same"
+            vision_tower = create_siglip_vit(**vision_tower_params)
+            forward_kwargs = dict()
+
+        elif self.model_name.startswith("sam"):
+            vision_tower = create_sam_vit(**vision_tower_params)
+            forward_kwargs = dict()
+
+        else:  # huggingface
+            from transformers import CLIPVisionModel
+
+            vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
+            forward_kwargs = dict(output_hidden_states=True)
+
+        return vision_tower, forward_kwargs
+
+    def feature_select(self, image_forward_outs):
+        if isinstance(image_forward_outs, torch.Tensor):
+            # the output has been the self.select_layer"s features
+            image_features = image_forward_outs
+        else:
+            image_features = image_forward_outs.hidden_states[self.select_layer]
+
+        if self.select_feature == "patch":
+            # if the output has cls_token
+            image_features = image_features[:, 1:]
+        elif self.select_feature == "cls_patch":
+            image_features = image_features
+        elif self.select_feature == "same":
+            image_features = image_features
+
+        else:
+            raise ValueError(f"Unexpected select feature: {self.select_feature}")
+        return image_features
+
+    def forward(self, images):
+        """
+
+        Args:
+            images (torch.Tensor): [b, 3, H, W]
+
+        Returns:
+            image_features (torch.Tensor): [b, n_patch, d]
+        """
+
+        if self.image_norm is not None:
+            images = self.image_norm(images)
+
+        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
+        image_features = self.feature_select(image_forward_outs)
+        return image_features
janus/models/image_processing_vlm.py
ADDED
@@ -0,0 +1,208 @@
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
19 |
+
|
20 |
+
from typing import List, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
import torchvision
|
25 |
+
import torchvision.transforms.functional
|
26 |
+
from PIL import Image
|
27 |
+
from transformers import AutoImageProcessor, PretrainedConfig
|
28 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
29 |
+
from transformers.image_utils import to_numpy_array
|
30 |
+
from transformers.utils import logging
|
31 |
+
|
32 |
+
logger = logging.get_logger(__name__)
|
33 |
+
|
34 |
+
ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
|
35 |
+
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
36 |
+
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
|
37 |
+
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
38 |
+
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
|
39 |
+
|
40 |
+
|
41 |
+
def expand2square(pil_img, background_color):
|
42 |
+
width, height = pil_img.size
|
43 |
+
if width == height:
|
44 |
+
return pil_img
|
45 |
+
elif width > height:
|
46 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
47 |
+
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


class VLMImageProcessorConfig(PretrainedConfig):
    model_type = "deepseek_vlm"
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        self.image_size = image_size
        self.min_size = min_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize

        super().__init__(**kwargs)


class VLMImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """

        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """

        width, height = pil_img.size
        max_size = max(width, height)

        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        pil_img = torchvision.transforms.functional.resize(
            pil_img,
            size,
            interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
            antialias=True,
        )

        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        images: List[np.ndarray] = [self.resize(image) for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        return [3, self.image_size, self.image_size]


AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)


if __name__ == "__main__":
    image_processor = VLMImageProcessor(
        image_size=1024,
        image_mean=IMAGENET_INCEPTION_MEAN,
        image_std=IMAGENET_INCEPTION_STD,
        do_normalize=True,
    )
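For reference, a minimal usage sketch of the processor added above; the file name "example.png" and the image_size value are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch of VLMImageProcessor; the path and image_size are assumptions.
from PIL import Image

from janus.models.image_processing_vlm import VLMImageProcessor

image_processor = VLMImageProcessor(image_size=384)
img = Image.open("example.png").convert("RGB")
batch = image_processor.preprocess([img], return_tensors="pt")
print(batch.pixel_values.shape)  # expected: torch.Size([1, 3, 384, 384])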
janus/models/modeling_vlm.py
ADDED
@@ -0,0 +1,272 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import torch
from attrdict import AttrDict
from einops import rearrange
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedModel,
)
from transformers.configuration_utils import PretrainedConfig

from janus.models.clip_encoder import CLIPVisionTower
from janus.models.projector import MlpProjector


class vision_head(torch.nn.Module):
    def __init__(self, params):
        super().__init__()
        self.output_mlp_projector = torch.nn.Linear(
            params.n_embed, params.image_token_embed
        )
        self.vision_activation = torch.nn.GELU()
        self.vision_head = torch.nn.Linear(
            params.image_token_embed, params.image_token_size
        )

    def forward(self, x):
        x = self.output_mlp_projector(x)
        x = self.vision_activation(x)
        x = self.vision_head(x)
        return x


def model_name_to_cls(cls_name):
    if "MlpProjector" in cls_name:
        cls = MlpProjector

    elif "CLIPVisionTower" in cls_name:
        cls = CLIPVisionTower

    elif "VQ" in cls_name:
        from janus.models.vq_model import VQ_models

        cls = VQ_models[cls_name]
    elif "vision_head" in cls_name:
        cls = vision_head
    else:
        raise ValueError(f"class_name {cls_name} is invalid.")

    return cls


class VisionConfig(PretrainedConfig):
    model_type = "vision"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class AlignerConfig(PretrainedConfig):
    model_type = "aligner"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class GenVisionConfig(PretrainedConfig):
    model_type = "gen_vision"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class GenAlignerConfig(PretrainedConfig):
    model_type = "gen_aligner"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class GenHeadConfig(PretrainedConfig):
    model_type = "gen_head"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class MultiModalityConfig(PretrainedConfig):
    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig

    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig

    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        vision_config = kwargs.get("vision_config", {})
        self.vision_config = VisionConfig(**vision_config)

        aligner_config = kwargs.get("aligner_config", {})
        self.aligner_config = AlignerConfig(**aligner_config)

        gen_vision_config = kwargs.get("gen_vision_config", {})
        self.gen_vision_config = GenVisionConfig(**gen_vision_config)

        gen_aligner_config = kwargs.get("gen_aligner_config", {})
        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)

        gen_head_config = kwargs.get("gen_head_config", {})
        self.gen_head_config = GenHeadConfig(**gen_head_config)

        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)


class MultiModalityPreTrainedModel(PreTrainedModel):
    config_class = MultiModalityConfig
    base_model_prefix = "multi_modality"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"


class MultiModalityCausalLM(MultiModalityPreTrainedModel):
    def __init__(self, config: MultiModalityConfig):
        super().__init__(config)

        vision_config = config.vision_config
        vision_cls = model_name_to_cls(vision_config.cls)
        self.vision_model = vision_cls(**vision_config.params)

        aligner_config = config.aligner_config
        aligner_cls = model_name_to_cls(aligner_config.cls)
        self.aligner = aligner_cls(aligner_config.params)

        gen_vision_config = config.gen_vision_config
        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
        self.gen_vision_model = gen_vision_cls()

        gen_aligner_config = config.gen_aligner_config
        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)

        gen_head_config = config.gen_head_config
        gen_head_cls = model_name_to_cls(gen_head_config.cls)
        self.gen_head = gen_head_cls(gen_head_config.params)

        self.gen_embed = torch.nn.Embedding(
            gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
        )

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        images_seq_mask: torch.LongTensor,
        images_emb_mask: torch.LongTensor,
        **kwargs,
    ):
        """

        Args:
            input_ids (torch.LongTensor): [b, T]
            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
            images_seq_mask (torch.BoolTensor): [b, T]
            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]

            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)

        Returns:
            input_embeds (torch.Tensor): [b, T, D]
        """

        bs, n = pixel_values.shape[0:2]
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
        # [b x n, T2, D]
        images_embeds = self.aligner(self.vision_model(images))

        # [b x n, T2, D] -> [b, n x T2, D]
        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
        # [b, n, T2] -> [b, n x T2]
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # [b, T, D]
        input_ids[input_ids < 0] = 0  # ignore the image embeddings
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # replace with the image embeddings
        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]

        return inputs_embeds

    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
        return self.gen_aligner(self.gen_embed(image_ids))


AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("gen_vision", GenVisionConfig)
AutoConfig.register("gen_aligner", GenAlignerConfig)
AutoConfig.register("gen_head", GenHeadConfig)
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
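A hedged sketch of how the understanding path is typically wired together; `vl_gpt` is assumed to be a loaded MultiModalityCausalLM, `vl_chat_processor` a loaded VLChatProcessor, and `prepare_inputs` a BatchedVLChatProcessorOutput already moved to the model's device and dtype:

# Sketch only: splice image embeddings into the token sequence, then generate text.
inputs_embeds = vl_gpt.prepare_inputs_embeds(
    input_ids=prepare_inputs.input_ids,
    pixel_values=prepare_inputs.pixel_values,
    images_seq_mask=prepare_inputs.images_seq_mask,
    images_emb_mask=prepare_inputs.images_emb_mask,
)
outputs = vl_gpt.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=prepare_inputs.attention_mask,
    max_new_tokens=512,
)
answer = vl_chat_processor.tokenizer.decode(outputs[0].tolist(), skip_special_tokens=True)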
janus/models/processing_vlm.py
ADDED
@@ -0,0 +1,418 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from dataclasses import dataclass
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin

from janus.models.image_processing_vlm import VLMImageProcessor
from janus.utils.conversation import get_conv_template


class DictOutput(object):
    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    sft_format: List[str]
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_emb_mask: torch.BoolTensor

    def to(self, device, dtype=torch.bfloat16):
        self.input_ids = self.input_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        self.images_seq_mask = self.images_seq_mask.to(device)
        self.images_emb_mask = self.images_emb_mask.to(device)
        self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
        return self


class VLChatProcessor(ProcessorMixin):
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    system_prompt = (
        "You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language."
    )

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        image_start_tag: str = "<begin_of_image>",
        image_end_tag: str = "<end_of_image>",
        pad_tag: str = "<|▁pad▁|>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.pad_tag = pad_tag

        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            image_tag,
            num_image_tokens,
            add_special_token,
            sft_format,
            mask_prompt,
            ignore_id,
            **kwargs,
        )

    def new_chat_template(self):
        conv = get_conv_template(self.sft_format)
        conv.set_system_message(self.system_prompt)
        return conv

    def apply_sft_template_for_multi_turn_prompts(
        self,
        conversations: List[Dict[str, str]],
        sft_format: str = "deepseek",
        system_prompt: str = "",
    ):
        """
        Applies the SFT template to conversation.

        An example of conversation:
        conversation = [
            {
                "role": "User",
                "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
                "images": [
                    "./multi-images/attribute_comparison_1.png",
                    "./multi-images/attribute_comparison_2.png"
                ]
            },
            {
                "role": "Assistant",
                "content": ""
            }
        ]

        Args:
            conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
            sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
            system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".

        Returns:
            sft_prompt (str): The formatted text.
        """

        conv = get_conv_template(sft_format)
        conv.set_system_message(system_prompt)
        for message in conversations:
            conv.append_message(message["role"], message["content"].strip())
        sft_prompt = conv.get_prompt().strip()

        return sft_prompt

    @property
    def image_token(self):
        return self.image_tag

    @property
    def image_id(self):
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def image_start_id(self):
        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
        return image_start_id

    @property
    def image_end_id(self):
        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
        return image_end_id

    @property
    def image_start_token(self):
        return self.image_start_tag

    @property
    def image_end_token(self):
        return self.image_end_tag

    @property
    def pad_id(self):
        pad_id = self.tokenizer.vocab.get(self.pad_tag)
        # pad_id = self.tokenizer.pad_token_id
        # if pad_id is None:
        #     pad_id = self.tokenizer.eos_token_id

        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """

        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """

        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # add boi, image tokens, eoi and set the mask as False
            input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

        return input_ids, num_image_tokens

    def process_one(
        self,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        assert (
            prompt is None or conversations is None
        ), "prompt and conversations cannot be used at the same time."

        if prompt is None:
            # apply sft format
            sft_format = self.apply_sft_template_for_multi_turn_prompts(
                conversations=conversations,
                sft_format=self.sft_format,
                system_prompt=self.system_prompt,
            )
        else:
            sft_format = prompt

        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.BoolTensor = input_ids == self.image_id
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """

        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares
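A hedged end-to-end sketch of building batched inputs with this processor; the model id and image path are assumptions for illustration, not part of this commit:

# Hypothetical usage of VLChatProcessor; model id and image path are assumptions.
from PIL import Image

from janus.models.processing_vlm import VLChatProcessor

vl_chat_processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-Pro-7B")
conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>\nDescribe this picture.",
        "images": ["example.png"],
    },
    {"role": "Assistant", "content": ""},
]
pil_images = [Image.open(p).convert("RGB") for p in conversation[0]["images"]]
prepare_inputs = vl_chat_processor(
    conversations=conversation, images=pil_images, force_batchify=True
)
print(prepare_inputs.input_ids.shape, prepare_inputs.pixel_values.shape)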
janus/models/projector.py
ADDED
@@ -0,0 +1,100 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from typing import Tuple, Union

import torch
import torch.nn as nn
from attrdict import AttrDict


class MlpProjector(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        self.layers = modules

    def forward(
        self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
    ):
        """

        Args:
            x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it is a tuple of torch.Tensor,
                then it comes from the hybrid vision encoder, and x = (high_res_x, low_res_x);
                otherwise it is the feature from the single vision encoder.

        Returns:
            x (torch.Tensor): [b, s, c]
        """

        if isinstance(x_or_tuple, tuple):
            # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            high_x, low_x = x_or_tuple
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)
        else:
            x = x_or_tuple

        return self.layers(x)


if __name__ == "__main__":
    cfg = AttrDict(
        input_dim=1024,
        n_embed=2048,
        depth=2,
        projector_type="low_high_hybrid_split_mlp_gelu",
    )
    inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))

    m = MlpProjector(cfg)
    out = m(inputs)
    print(out.shape)
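The __main__ block above exercises the hybrid split path; the plain mlp_gelu branch is configured analogously. A small sketch with illustrative dimensions (the values below are assumptions, not taken from a Janus config):

# Sketch of the non-hybrid projector path; dimensions are illustrative assumptions.
import torch
from attrdict import AttrDict

from janus.models.projector import MlpProjector

cfg = AttrDict(input_dim=1024, n_embed=2048, depth=2, projector_type="mlp_gelu")
proj = MlpProjector(cfg)
out = proj(torch.rand(4, 576, 1024))  # single-encoder features: [b, s, input_dim]
print(out.shape)  # torch.Size([4, 576, 2048])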
janus/models/siglip_vit.py
ADDED
@@ -0,0 +1,681 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
import math
import warnings
from dataclasses import dataclass
from functools import partial
from typing import (
    Callable,
    Dict,
    Final,
    List,
    Literal,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import (
    AttentionPoolLatent,
    DropPath,
    LayerType,
    Mlp,
    PatchDropout,
    PatchEmbed,
    resample_abs_pos_embed,
)
from timm.models._manipulate import checkpoint_seq, named_apply


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)  # noqa: E741
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
    r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
    convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
    Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
    from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """

    with torch.no_grad():
        dtype = tensor.dtype
        tensor_fp32 = tensor.float()
        tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
        tensor_dtype = tensor_fp32.to(dtype=dtype)
        tensor.copy_(tensor_dtype)


def init_weights(self):
    if self.pos_embed is not None:
        trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
    trunc_normal_(self.latent, std=self.latent_dim**-0.5)


def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()


class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        # self.fused_attn = use_fused_attn()
        self.fused_attn = True

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    dynamic_img_size: Final[bool]

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        global_pool: Literal["", "avg", "token", "map"] = "token",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: Optional[float] = None,
        class_token: bool = True,
        no_embed_class: bool = False,
        reg_tokens: int = 0,
        pre_norm: bool = False,
        fc_norm: Optional[bool] = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        drop_rate: float = 0.0,
        pos_drop_rate: float = 0.0,
        patch_drop_rate: float = 0.0,
        proj_drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
        embed_layer: Callable = PatchEmbed,
        norm_layer: Optional[LayerType] = None,
        act_layer: Optional[LayerType] = None,
        block_fn: Type[nn.Module] = Block,
        mlp_layer: Type[nn.Module] = Mlp,
        ignore_head: bool = False,
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head.
            global_pool: Type of global pooling for final sequence (default: 'token').
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            class_token: Use class token.
            no_embed_class: Don't include position embeddings for class (or reg) tokens.
            reg_tokens: Number of register tokens.
            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme.
            embed_layer: Patch embedding layer.
            norm_layer: Normalization layer.
            act_layer: MLP activation layer.
            block_fn: Transformer block layer.
        """
        super().__init__()
        assert global_pool in ("", "avg", "token", "map")
        assert class_token or global_pool != "token"
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
        # act_layer = get_act_layer(act_layer) or nn.GELU
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        act_layer = nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens
        self.num_reg_tokens = reg_tokens
        self.has_class_token = class_token
        self.no_embed_class = (
            no_embed_class  # don't embed prefix positions (includes reg)
        )
        self.dynamic_img_size = dynamic_img_size
        self.grad_checkpointing = False
        self.ignore_head = ignore_head

        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        )
        self.reg_token = (
            nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
        )
        embed_len = (
            num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        )
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.Sequential(
            *[
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_norm=qk_norm,
                    init_values=init_values,
                    proj_drop=proj_drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    mlp_layer=mlp_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # Classifier Head
        if global_pool == "map":
            AttentionPoolLatent.init_weights = init_weights
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

        if weight_init != "skip":
            self.init_weights(weight_init)

    def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
        assert mode in ("jax", "jax_nlhb", "moco", "")
        # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
        trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    @torch.jit.ignore
    def no_weight_decay(self) -> Set:
        return {"pos_embed", "cls_token", "dist_token"}

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict:
        return dict(
            stem=r"^cls_token|pos_embed|patch_embed",  # stem and embed
            blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None) -> None:
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ("", "avg", "token", "map")
            if global_pool == "map" and self.attn_pool is None:
                assert (
                    False
                ), "Cannot currently add attention pooling in reset_classifier()."
            elif global_pool != "map" and self.attn_pool is not None:
                self.attn_pool = None  # remove attention pooling
            self.global_pool = global_pool
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        if self.dynamic_img_size:
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        else:
            pos_embed = self.pos_embed

        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed

        return self.pos_drop(x)

    def _intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,
    ) -> List[torch.Tensor]:
        outputs, num_blocks = [], len(self.blocks)
        take_indices = set(
            range(num_blocks - n, num_blocks) if isinstance(n, int) else n
        )

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in take_indices:
                outputs.append(x)

        return outputs

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,
        reshape: bool = False,
        return_prefix_tokens: bool = False,
        norm: bool = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface
        """
        # take last n blocks if n is an int, if n is a sequence, select by matching indices
        outputs = self._intermediate_layers(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens :] for out in outputs]

        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]

        if return_prefix_tokens:
            return tuple(zip(outputs, prefix_tokens))
        return tuple(outputs)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        if self.attn_pool is not None:
            x = self.attn_pool(x)
        elif self.global_pool == "avg":
            x = x[:, self.num_prefix_tokens :].mean(dim=1)
        elif self.global_pool:
            x = x[:, 0]  # class token
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        if not self.ignore_head:
            x = self.forward_head(x)
        return x


@dataclass
class SigLIPVisionCfg:
    width: int = 1152
    layers: Union[Tuple[int, int, int, int], int] = 27
    heads: int = 16
    patch_size: int = 14
    image_size: Union[Tuple[int, int], int] = 336
    global_pool: str = "map"
    mlp_ratio: float = 3.7362
    class_token: bool = False
    num_classes: int = 0
|
603 |
+
use_checkpoint: bool = False
|
604 |
+
|
605 |
+
|
606 |
+
SigLIP_MODEL_CONFIG = {
|
607 |
+
"siglip_so400m_patch14_384": {
|
608 |
+
"image_size": 336,
|
609 |
+
"patch_size": 14,
|
610 |
+
"width": 1152,
|
611 |
+
"layers": 27,
|
612 |
+
"heads": 16,
|
613 |
+
"mlp_ratio": 3.7362,
|
614 |
+
"global_pool": "map",
|
615 |
+
"use_checkpoint": False,
|
616 |
+
},
|
617 |
+
"siglip_so400m_patch14_224": {
|
618 |
+
"image_size": 224,
|
619 |
+
"patch_size": 14,
|
620 |
+
"width": 1152,
|
621 |
+
"layers": 27,
|
622 |
+
"heads": 16,
|
623 |
+
"mlp_ratio": 3.7362,
|
624 |
+
"global_pool": "map",
|
625 |
+
"use_checkpoint": False,
|
626 |
+
},
|
627 |
+
"siglip_large_patch16_384": {
|
628 |
+
"image_size": 384,
|
629 |
+
"patch_size": 16,
|
630 |
+
"width": 1024,
|
631 |
+
"layers": 24,
|
632 |
+
"heads": 16,
|
633 |
+
"mlp_ratio": 4,
|
634 |
+
"global_pool": "map",
|
635 |
+
"use_checkpoint": False,
|
636 |
+
},
|
637 |
+
}
|
638 |
+
|
639 |
+
|
640 |
+
def create_siglip_vit(
|
641 |
+
model_name: str = "siglip_so400m_patch14_384",
|
642 |
+
image_size: int = 384,
|
643 |
+
select_layer: int = -1,
|
644 |
+
ckpt_path: str = "",
|
645 |
+
**kwargs,
|
646 |
+
):
|
647 |
+
assert (
|
648 |
+
model_name in SigLIP_MODEL_CONFIG.keys()
|
649 |
+
), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
|
650 |
+
|
651 |
+
vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
|
652 |
+
|
653 |
+
if select_layer <= 0:
|
654 |
+
layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
|
655 |
+
else:
|
656 |
+
layers = min(vision_cfg.layers, select_layer)
|
657 |
+
|
658 |
+
model = VisionTransformer(
|
659 |
+
img_size=image_size,
|
660 |
+
patch_size=vision_cfg.patch_size,
|
661 |
+
embed_dim=vision_cfg.width,
|
662 |
+
depth=layers,
|
663 |
+
num_heads=vision_cfg.heads,
|
664 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
665 |
+
class_token=vision_cfg.class_token,
|
666 |
+
global_pool=vision_cfg.global_pool,
|
667 |
+
ignore_head=kwargs.get("ignore_head", True),
|
668 |
+
weight_init=kwargs.get("weight_init", "skip"),
|
669 |
+
num_classes=0,
|
670 |
+
)
|
671 |
+
|
672 |
+
if ckpt_path:
|
673 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
674 |
+
|
675 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
676 |
+
print(
|
677 |
+
f"SigLIP-ViT restores from {ckpt_path},\n"
|
678 |
+
f"\tincompatible_keys:', {incompatible_keys}."
|
679 |
+
)
|
680 |
+
|
681 |
+
return model
|
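For orientation, here is a minimal usage sketch of the create_siglip_vit factory added above. It is illustrative only (not one of the committed files) and assumes the package root is on the Python path; with ignore_head=True (the default), the model returns per-patch features rather than pooled logits.

import torch

from janus.models.siglip_vit import create_siglip_vit

# Build the so400m/14 tower at 336px without loading a checkpoint (ckpt_path="").
vision_tower = create_siglip_vit("siglip_so400m_patch14_384", image_size=336).eval()

dummy = torch.randn(1, 3, 336, 336)  # one RGB image, 336 x 336
with torch.no_grad():
    feats = vision_tower(dummy)  # ignore_head=True -> forward_features output

# (336 / 14) ** 2 = 576 patch tokens, each of width 1152.
print(feats.shape)  # expected: torch.Size([1, 576, 1152])
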
janus/models/vq_model.py
ADDED
@@ -0,0 +1,527 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


from dataclasses import dataclass, field
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial


@dataclass
class ModelArgs:
    codebook_size: int = 16384
    codebook_embed_dim: int = 8
    codebook_l2_norm: bool = True
    codebook_show_usage: bool = True
    commit_loss_beta: float = 0.25
    entropy_loss_ratio: float = 0.0

    encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    z_channels: int = 256
    dropout_p: float = 0.0


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        ch=128,
        ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2,
        norm_type="group",
        dropout=0.0,
        resamp_with_conv=True,
        z_channels=256,
    ):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # downsampling
        in_ch_mult = (1,) + tuple(ch_mult)
        self.conv_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_block.append(
                    ResnetBlock(
                        block_in, block_out, dropout=dropout, norm_type=norm_type
                    )
                )
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # downsample
            if i_level != self.num_resolutions - 1:
                conv_block.downsample = Downsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(
            block_in, z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        h = self.conv_in(x)
        # downsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.downsample(h)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        z_channels=256,
        ch=128,
        ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2,
        norm_type="group",
        dropout=0.0,
        resamp_with_conv=True,
        out_channels=3,
    ):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks

        block_in = ch * ch_mult[self.num_resolutions - 1]
        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )

        # upsampling
        self.conv_blocks = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_block.append(
                    ResnetBlock(
                        block_in, block_out, dropout=dropout, norm_type=norm_type
                    )
                )
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # upsample
            if i_level != 0:
                conv_block.upsample = Upsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(
            block_in, out_channels, kernel_size=3, stride=1, padding=1
        )

    @property
    def last_layer(self):
        return self.conv_out.weight

    def forward(self, z):
        # z to block_in
        h = self.conv_in(z)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # upsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks + 1):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class VectorQuantizer(nn.Module):
    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = l2_norm
        self.show_usage = show_usage

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:
            self.embedding.weight.data = F.normalize(
                self.embedding.weight.data, p=2, dim=-1
            )
        if self.show_usage:
            self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = torch.einsum("b c h w -> b h w c", z).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        if self.l2_norm:
            z = F.normalize(z, p=2, dim=-1)
            z_flattened = F.normalize(z_flattened, p=2, dim=-1)
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight

        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(embedding**2, dim=1)
            - 2
            * torch.einsum(
                "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)
            )
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = embedding[min_encoding_indices].view(z.shape)
        perplexity = None
        min_encodings = None
        vq_loss = None
        commit_loss = None
        entropy_loss = None

        # compute loss for embedding
        if self.training:
            vq_loss = torch.mean((z_q - z.detach()) ** 2)
            commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = torch.einsum("b h w c -> b c h w", z_q)

        return (
            z_q,
            (vq_loss, commit_loss, entropy_loss),
            (perplexity, min_encodings, min_encoding_indices),
        )

    def get_codebook_entry(self, indices, shape=None, channel_first=True):
        # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
        if self.l2_norm:
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight
        z_q = embedding[indices]  # (b*h*w, c)

        if shape is not None:
            if channel_first:
                z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
                # reshape back to match original input shape
                z_q = z_q.permute(0, 3, 1, 2).contiguous()
            else:
                z_q = z_q.view(shape)
        return z_q


class ResnetBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        norm_type="group",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = Normalize(out_channels, norm_type)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels, norm_type="group"):
        super().__init__()
        self.norm = Normalize(in_channels, norm_type)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, norm_type="group"):
    assert norm_type in ["group", "batch"]
    if norm_type == "group":
        return nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
    elif norm_type == "batch":
        return nn.SyncBatchNorm(in_channels)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        if x.dtype != torch.float32:
            x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
                torch.bfloat16
            )
        else:
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
        return x


def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    flat_affinity /= temperature
    probs = F.softmax(flat_affinity, dim=-1)
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss


class VQModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.encoder = Encoder(
            ch_mult=config.encoder_ch_mult,
            z_channels=config.z_channels,
            dropout=config.dropout_p,
        )
        self.decoder = Decoder(
            ch_mult=config.decoder_ch_mult,
            z_channels=config.z_channels,
            dropout=config.dropout_p,
        )

        self.quantize = VectorQuantizer(
            config.codebook_size,
            config.codebook_embed_dim,
            config.commit_loss_beta,
            config.entropy_loss_ratio,
            config.codebook_l2_norm,
            config.codebook_show_usage,
        )
        self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(
            config.codebook_embed_dim, config.z_channels, 1
        )

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b, shape=None, channel_first=True):
        quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff


#################################################################################
#                              VQ Model Configs                                 #
#################################################################################
def VQ_16(**kwargs):
    return VQModel(
        ModelArgs(
            encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs
        )
    )


VQ_models = {"VQ-16": VQ_16}

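A small, hypothetical sketch of the decode path defined above (not part of the committed files): it decodes a 24x24 grid of random codebook indices with randomly initialized weights, purely to illustrate the tensor shapes involved.

import torch

from janus.models.vq_model import VQ_16

vq = VQ_16().eval()  # defaults: 16384-entry codebook, 8-dim codes, 16x spatial upsampling

# 576 random token ids arranged as a 1 x 24 x 24 latent grid.
indices = torch.randint(0, 16384, (1 * 24 * 24,))
with torch.no_grad():
    imgs = vq.decode_code(indices, shape=(1, 8, 24, 24), channel_first=True)

# Each latent cell maps to a 16x16 pixel patch through the 4 upsampling stages.
print(imgs.shape)  # expected: torch.Size([1, 3, 384, 384])
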
janus/utils/__init__.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

janus/utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (174 Bytes). View file
janus/utils/__pycache__/conversation.cpython-38.pyc
ADDED
Binary file (7.5 kB). View file
janus/utils/__pycache__/io.cpython-38.pyc
ADDED
Binary file (2.06 kB). View file
janus/utils/conversation.py
ADDED
@@ -0,0 +1,365 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

"""
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""

import dataclasses
from enum import IntEnum, auto
from typing import Dict, List


class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
    LLAMA2 = auto()
    CHATGLM = auto()
    CHATML = auto()
    CHATINTERN = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    ROBIN = auto()
    DeepSeek = auto()
    PLAIN = auto()
    ALIGNMENT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The system message
    system_message: str = ""
    # The names of two roles
    roles: List[str] = (("USER", "ASSISTANT"),)
    # All messages. Each item is (role, message).
    messages: List[List[str]] = ()
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = "\n"
    sep2: str = None
    # Stop criteria (the default one is EOS token)
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        system_prompt = self.system_template.format(system_message=self.system_message)

        if self.sep_style == SeparatorStyle.DeepSeek:
            seps = [self.sep, self.sep2]
            if system_prompt == "" or system_prompt is None:
                ret = ""
            else:
                ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA2:
            seps = [self.sep, self.sep2]
            if self.system_message:
                ret = system_prompt
            else:
                ret = "[INST] "
            for i, (role, message) in enumerate(self.messages):
                tag = self.roles[i % 2]
                if message:
                    if type(message) is tuple:  # multimodal message
                        message, _ = message
                    if i == 0:
                        ret += message + " "
                    else:
                        ret += tag + " " + message + seps[i % 2]
                else:
                    ret += tag
            return ret
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        ret += message + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
                else:
                    ret += ""
            return ret
        elif self.sep_style == SeparatorStyle.ALIGNMENT:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        ret += "<image>\n" + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
                else:
                    ret += ""
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def get_prompt_for_current_round(self, content=None):
        """Get current round formatted question prompt during sft training"""
        if self.sep_style == SeparatorStyle.PLAIN:
            formatted_question = "<image>\n"
        elif self.sep_style == SeparatorStyle.DeepSeek:
            formatted_question = (
                f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
            )
        else:
            raise ValueError(f"Unsupported sep_style: {self.sep_style}")
        return formatted_question

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def reset_message(self):
        """Reset a new message."""
        self.messages = []

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        ret = [{"role": "system", "content": system_prompt}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert (
            template.name not in conv_templates
        ), f"{template.name} has been registered."

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


# llava_llama2 template
register_conv_template(
    Conversation(
        name="llava_llama2",
        system_message="You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language.",
        system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
        roles=("[INST]", "[/INST]"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        stop_token_ids=[2],
    )
)

# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
    Conversation(
        name="llama-2",
        system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
        roles=("[INST]", "[/INST]"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        stop_token_ids=[2],
    )
)


# deepseek template
register_conv_template(
    Conversation(
        name="deepseek_old",
        system_template="{system_message}",
        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
        # "thinking step by step to be sure you get the right answer.",
        system_message="",
        roles=("User", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"],
    )
)
register_conv_template(
    Conversation(
        name="deepseek",
        system_template="{system_message}",
        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
        # "thinking step by step to be sure you get the right answer.",
        system_message="",
        roles=("<|User|>", "<|Assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["<|User|>", "<|end▁of▁sentence|>"],
    )
)

register_conv_template(
    Conversation(
        name="plain",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.PLAIN,
        sep="",
        sep2="",
        stop_token_ids=[2],
        stop_str=["</s>"],
    )
)


register_conv_template(
    Conversation(
        name="alignment",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ALIGNMENT,
        sep="",
        sep2="",
        stop_token_ids=[2],
        stop_str=["</s>"],
    )
)


if __name__ == "__main__":
    # print("Llama-2 template:")
    # conv = get_conv_template("llama-2")
    # conv.set_system_message("You are a helpful, respectful and honest assistant.")
    # conv.append_message(conv.roles[0], "Hello!")
    # conv.append_message(conv.roles[1], "Hi!")
    # conv.append_message(conv.roles[0], "How are you?")
    # conv.append_message(conv.roles[1], None)
    # print(conv.get_prompt())

    # print("\n")

    print("deepseek template:")
    conv = get_conv_template("deepseek")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi! This is Tony.")
    conv.append_message(conv.roles[0], "Who are you?")
    conv.append_message(conv.roles[1], "I am a helpful assistant.")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())

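A short usage sketch for the template registry above (illustrative, not part of the committed files), showing the prompt string the "deepseek" template produces for a single user turn:

from janus.utils.conversation import get_conv_template

conv = get_conv_template("deepseek")
conv.append_message(conv.roles[0], "Describe this image.")
conv.append_message(conv.roles[1], None)  # leave the assistant slot open for generation

print(repr(conv.get_prompt()))
# '<|User|>: Describe this image.\n\n<|Assistant|>:'
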
janus/utils/io.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import json
from typing import Dict, List

import PIL.Image
import torch
import base64
import io
from transformers import AutoModelForCausalLM

from janus.models import MultiModalityCausalLM, VLChatProcessor


def load_pretrained_model(model_path: str):
    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    tokenizer = vl_chat_processor.tokenizer

    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True
    )
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

    return tokenizer, vl_chat_processor, vl_gpt


def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
    """

    Support file path or base64 images.

    Args:
        conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
            [
                {
                    "role": "User",
                    "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
                    "images": ["./examples/table_datasets.png"]
                },
                {"role": "Assistant", "content": ""},
            ]

    Returns:
        pil_images (List[PIL.Image.Image]): the list of PIL images.

    """

    pil_images = []

    for message in conversations:
        if "images" not in message:
            continue

        for image_data in message["images"]:
            if image_data.startswith("data:image"):
                # Image data is in base64 format
                _, image_data = image_data.split(",", 1)
                image_bytes = base64.b64decode(image_data)
                pil_img = PIL.Image.open(io.BytesIO(image_bytes))
            else:
                # Image data is a file path
                pil_img = PIL.Image.open(image_data)
            pil_img = pil_img.convert("RGB")
            pil_images.append(pil_img)

    return pil_images


def load_json(filepath):
    with open(filepath, "r") as f:
        data = json.load(f)
    return data

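A brief, hypothetical example of the image-loading helper above (not part of the committed files); it resolves the images referenced by a conversation without loading any model, using the doge.png sample that ships with this Space.

from janus.utils.io import load_pil_images

conversation = [
    {
        "role": "<|User|>",
        "content": "<image_placeholder>\nDescribe this image.",
        "images": ["doge.png"],
    },
    {"role": "<|Assistant|>", "content": ""},
]

pil_images = load_pil_images(conversation)  # -> list of RGB PIL.Image.Image objects
print(len(pil_images), pil_images[0].size)
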
requirements.txt
ADDED
@@ -0,0 +1,8 @@
accelerate
diffusers
gradio
numpy
torch
safetensors
transformers
git+https://github.com/deepseek-ai/Janus

weights/RealESRGAN_x2.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c830d067d54fc767b9543a8432f36d91bc2de313584e8bbfe4ac26a47339e899
size 67061725