Spaces:
Running
on
Zero
Running
on
Zero
Upload 31 files
Browse files- .gitattributes +2 -0
- Liquid_icon.png +3 -0
- README.md +0 -14
- app.py +343 -51
- baklava.png +3 -0
- chameleon/__init__.py +4 -0
- chameleon/download_data.py +88 -0
- chameleon/inference/__init__.py +4 -0
- chameleon/inference/alignment.py +79 -0
- chameleon/inference/chameleon.py +673 -0
- chameleon/inference/cudagraph.py +85 -0
- chameleon/inference/examples/batch.py +23 -0
- chameleon/inference/examples/multimodal_input.py +28 -0
- chameleon/inference/examples/simple.py +22 -0
- chameleon/inference/examples/streaming.py +22 -0
- chameleon/inference/examples/streaming_batch.py +24 -0
- chameleon/inference/generation.py +162 -0
- chameleon/inference/image_tokenizer.py +127 -0
- chameleon/inference/loader.py +71 -0
- chameleon/inference/logits_processor.py +336 -0
- chameleon/inference/model_adapter.py +118 -0
- chameleon/inference/stopping_criteria.py +55 -0
- chameleon/inference/token_selector.py +47 -0
- chameleon/inference/transformer.py +421 -0
- chameleon/inference/utils.py +34 -0
- chameleon/inference/vocab.py +122 -0
- chameleon/inference/vqgan.py +675 -0
- chameleon/vqgan.ckpt +3 -0
- chameleon/vqgan.yaml +57 -0
- conversation.py +460 -0
- helpers.py +99 -0
- requirements.txt +6 -1
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
baklava.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
Liquid_icon.png filter=lfs diff=lfs merge=lfs -text
|
Liquid_icon.png
ADDED
![]() |
Git LFS Details
|
README.md
CHANGED
@@ -1,14 +0,0 @@
|
|
1 |
-
---
|
2 |
-
title: Liquid Demo
|
3 |
-
emoji: 💬
|
4 |
-
colorFrom: yellow
|
5 |
-
colorTo: purple
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.0.1
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: mit
|
11 |
-
short_description: A unified understanding and generation multimodal model
|
12 |
-
---
|
13 |
-
|
14 |
-
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,64 +1,356 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
"""
|
5 |
-
|
|
|
|
|
|
|
|
|
6 |
"""
|
7 |
-
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
8 |
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
temperature,
|
16 |
-
top_p,
|
17 |
-
):
|
18 |
-
messages = [{"role": "system", "content": system_message}]
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
if val[1]:
|
24 |
-
messages.append({"role": "assistant", "content": val[1]})
|
25 |
|
26 |
-
|
|
|
27 |
|
28 |
-
response = ""
|
29 |
|
30 |
-
for message in client.chat_completion(
|
31 |
-
messages,
|
32 |
-
max_tokens=max_tokens,
|
33 |
-
stream=True,
|
34 |
-
temperature=temperature,
|
35 |
-
top_p=top_p,
|
36 |
-
):
|
37 |
-
token = message.choices[0].delta.content
|
38 |
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
|
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
""
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from threading import Thread
|
3 |
+
|
4 |
import gradio as gr
|
5 |
+
import torch
|
6 |
+
import PIL
|
7 |
+
from PIL import Image
|
8 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
9 |
+
import torch
|
10 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
11 |
+
import os
|
12 |
+
from tqdm import tqdm
|
13 |
+
from chameleon.inference.image_tokenizer import ImageTokenizer
|
14 |
+
from helpers import sample, expand2square, tokenizer_image_token
|
15 |
+
|
16 |
+
# from transformers import AutoProcessor, LlavaForConditionalGeneration
|
17 |
+
from transformers import TextIteratorStreamer
|
18 |
+
from conversation import conv_templates
|
19 |
+
import spaces
|
20 |
+
|
21 |
+
|
22 |
+
import os
|
23 |
+
os.system("pip uninstall -y gradio")
|
24 |
+
os.system("pip install gradio==4.44.1")
|
25 |
+
os.system("pip install gradio_client==1.3.0")
|
26 |
+
|
27 |
|
28 |
+
IMAGE_TOKEN_INDEX=-200
|
29 |
+
PLACEHOLDER = """
|
30 |
+
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
|
31 |
+
<img src='file/Liquid_icon.png' style="width: 80%; max-width: 600px; height: auto; opacity: 0.5;">
|
32 |
+
<h1 style="font-size: 20px; margin-bottom: 1px; opacity: 0.55;">Liquid-7B</h1>
|
33 |
+
</div>
|
34 |
"""
|
35 |
+
|
36 |
+
CSS ="""
|
37 |
+
.contain { display: flex; flex-direction: column; }
|
38 |
+
#component-0 { height: 100%; }
|
39 |
+
#chatbot { flex-grow: 1; }
|
40 |
"""
|
|
|
41 |
|
42 |
|
43 |
+
title_html = """
|
44 |
+
<div style="display: flex; flex-direction: column; align-items: center; gap: 10px;">
|
45 |
+
<h1 style="margin: 0; line-height: 1; text-align: center;"> Liquid: Language Models are Scalable Multi-modal <br> Generators via Unified Understanding and Generation</h1>
|
46 |
+
</div>
|
47 |
+
"""
|
|
|
|
|
|
|
|
|
48 |
|
49 |
+
links_html = f"""
|
50 |
+
<center><font size=3><a href='https://foundationvision.github.io/Liquid/'>Liquid</a> has been open-sourced on <a href='https://huggingface.co/Junfeng5/Liquid_V1_7B'>😊 Huggingface</a> and <a href='https://github.com/FoundationVision/Liquid'>🌟 GitHub</a>. If you find Liquid useful, a like❤️ or a star🌟 would be appreciated.</font></center>
|
51 |
+
"""
|
|
|
|
|
52 |
|
53 |
+
introduction = f"""
|
54 |
+
Liquid explores the potential of a single LLM as a multimodal generator and its scaling laws. It achieves the level of diffusion models in visual generation and discovers the mutual enhancement between understanding and generation. More details can be found on the project <a href='https://foundationvision.github.io/Liquid/'> homepage</a> and in the <a href='https://arxiv.org/abs/2412.04332'> paper</a>. """
|
55 |
|
|
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
+
model_id = 'Junfeng5/Liquid_V1_7B'
|
59 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id,padding_side='left')
|
60 |
+
vqllm = AutoModelForCausalLM.from_pretrained(
|
61 |
+
model_id,
|
62 |
+
attn_implementation='flash_attention_2',
|
63 |
+
torch_dtype=torch.bfloat16,
|
64 |
+
load_in_8bit=True,
|
65 |
+
max_memory={0: "40GiB" },
|
66 |
+
) # .to("cuda:0")
|
67 |
|
68 |
+
stop_flag = False
|
69 |
|
70 |
+
ori_vocabe_size = len(tokenizer)
|
71 |
+
|
72 |
+
vqgan_cfg_path = "chameleon/vqgan.yaml"
|
73 |
+
vqgan_ckpt_path = "chameleon/vqgan.ckpt"
|
74 |
+
image_tokenizer = ImageTokenizer( cfg_path=vqgan_cfg_path, ckpt_path=vqgan_ckpt_path, device="cuda:0",)
|
75 |
+
|
76 |
+
@spaces.GPU
|
77 |
+
def bot_streaming_I2T(message, history):
|
78 |
+
print(message)
|
79 |
+
global stop_flag
|
80 |
+
stop_flag = True
|
81 |
+
time.sleep(0.2)
|
82 |
+
stop_flag = False
|
83 |
+
torch.cuda.empty_cache()
|
84 |
+
if message["files"]:
|
85 |
+
# message["files"][-1] is a Dict or just a string
|
86 |
+
if type(message["files"][-1]) == dict:
|
87 |
+
image = message["files"][-1]["path"]
|
88 |
+
else:
|
89 |
+
image = message["files"][-1]
|
90 |
+
else:
|
91 |
+
# if there's no image uploaded for this turn, look for images in the past turns
|
92 |
+
# kept inside tuples, take the last one
|
93 |
+
for hist in history:
|
94 |
+
if type(hist[0]) == tuple:
|
95 |
+
image = hist[0][0]
|
96 |
+
try:
|
97 |
+
if image is None:
|
98 |
+
# Handle the case where image is None
|
99 |
+
gr.Error("You need to upload an image for LLaVA to work.")
|
100 |
+
except NameError:
|
101 |
+
# Handle the case where 'image' is not defined at all
|
102 |
+
gr.Error("You need to upload an image for LLaVA to work.")
|
103 |
+
|
104 |
+
qs = message['text']
|
105 |
+
qs = '<boi><image><eoi>' + '\n' + qs
|
106 |
+
conv = conv_templates['gemma'].copy()
|
107 |
+
conv.append_message(conv.roles[0], qs)
|
108 |
+
conv.append_message(conv.roles[1], None)
|
109 |
+
prompt = conv.get_prompt()
|
110 |
+
|
111 |
+
|
112 |
+
print(prompt)
|
113 |
+
image = Image.open(image).convert('RGB')
|
114 |
+
pad_image = expand2square(image, (122, 116, 104) )
|
115 |
+
input_image = pad_image.resize((512,512), PIL.Image.LANCZOS)
|
116 |
+
with torch.no_grad():
|
117 |
+
vq_code = image_tokenizer.img_tokens_from_pil(input_image)
|
118 |
+
vqcode = vq_code.cpu()
|
119 |
+
vqcode = vqcode+ len(tokenizer)
|
120 |
+
|
121 |
+
text_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
|
122 |
+
num_images = (text_ids == IMAGE_TOKEN_INDEX).sum()
|
123 |
+
image_token_indices = [-1] + torch.where(text_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [text_ids.shape[0]]
|
124 |
+
cur_input_ids = []
|
125 |
+
for i in range(num_images + 1):
|
126 |
+
cur_input_ids.append(text_ids[image_token_indices[i]+1:image_token_indices[i+1]])
|
127 |
+
if i < num_images:
|
128 |
+
cur_input_ids.append( vqcode )
|
129 |
+
input_ids = torch.cat(cur_input_ids, dim=0)
|
130 |
+
# input_embeddings = vqllm.embed_tokens(input_ids)
|
131 |
+
inputs = {
|
132 |
+
"input_ids":input_ids.unsqueeze(0).to("cuda:0"),
|
133 |
+
"max_new_tokens":1024,
|
134 |
+
"bos_token_id":tokenizer.bos_token_id, # Begin of sequence token
|
135 |
+
"eos_token_id":tokenizer.eos_token_id, # End of sequence token
|
136 |
+
"pad_token_id":tokenizer.pad_token_id, # Pad token
|
137 |
+
}
|
138 |
+
streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": False, "skip_prompt": True})
|
139 |
+
|
140 |
+
# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
|
141 |
+
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
|
142 |
+
thread = Thread(target=vqllm.generate, kwargs=generation_kwargs)
|
143 |
+
thread.start()
|
144 |
+
generated_text = ""
|
145 |
+
for new_text in streamer:
|
146 |
+
generated_text += new_text
|
147 |
+
time.sleep(0.06)
|
148 |
+
yield generated_text
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
def show_gallery(images):
|
153 |
+
gallery = gr.Gallery(images, label="Gallery", columns=4, height="auto",preview=True,scale=0.05) # 设置两行两列的布局
|
154 |
+
return gallery
|
155 |
+
|
156 |
+
@spaces.GPU
|
157 |
+
def bot_streaming_T2I(message, history,guidance_scale, temperature, top_K, top_P):
|
158 |
+
|
159 |
+
global stop_flag
|
160 |
+
stop_flag = True
|
161 |
+
time.sleep(0.2)
|
162 |
+
stop_flag = False
|
163 |
+
|
164 |
+
text_inputs = [message]*4 # generate 4 samples once
|
165 |
+
uncondition_text_inputs = ['<unconditional><boi>']*len(text_inputs)
|
166 |
+
for i in range(len(text_inputs)):
|
167 |
+
text_inputs[i] = text_inputs[i]+' Generate an image based on this description.<boi>'
|
168 |
+
|
169 |
+
ori_batchsize = len(text_inputs)
|
170 |
+
|
171 |
+
if guidance_scale>1:
|
172 |
+
model_inputs = tokenizer(text_inputs+uncondition_text_inputs, return_tensors="pt",padding=True).to("cuda:0")
|
173 |
+
else:
|
174 |
+
model_inputs = tokenizer(text_inputs, return_tensors="pt",padding=True).to("cuda:0")
|
175 |
+
with torch.no_grad():
|
176 |
+
sampling_kwargs={'temperature': temperature, 'top_k': top_K, 'top_p': top_P, 'sample_logits': True}
|
177 |
+
input_ids = model_inputs['input_ids']
|
178 |
+
cur_len = input_ids.shape[1]
|
179 |
+
model_kwargs = {'attention_mask':model_inputs['attention_mask'] , 'use_cache': True}
|
180 |
+
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
181 |
+
|
182 |
+
pred_tokens = []
|
183 |
+
for i in tqdm(range(1024)):
|
184 |
+
if stop_flag:
|
185 |
+
print("generation is stoped")
|
186 |
+
del sampling_kwargs
|
187 |
+
del model_inputs
|
188 |
+
del outputs
|
189 |
+
torch.cuda.empty_cache()
|
190 |
+
break
|
191 |
+
model_inputs = vqllm.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
192 |
+
|
193 |
+
if i > 0 and guidance_scale>1:
|
194 |
+
outputs = vqllm(
|
195 |
+
**model_inputs,
|
196 |
+
return_dict=True,
|
197 |
+
output_attentions=False,
|
198 |
+
output_hidden_states=False,
|
199 |
+
)
|
200 |
+
else:
|
201 |
+
outputs = vqllm(
|
202 |
+
**model_inputs,
|
203 |
+
return_dict=True,
|
204 |
+
output_attentions=False,
|
205 |
+
output_hidden_states=False,
|
206 |
+
)
|
207 |
+
|
208 |
+
next_token_logits = outputs.logits[:, -1:, :]
|
209 |
+
|
210 |
+
if guidance_scale>1:
|
211 |
+
cond_logits, uncond_logits = torch.split(next_token_logits, len(next_token_logits) // 2, dim=0)
|
212 |
+
cfg_logits = uncond_logits + (cond_logits - uncond_logits) * guidance_scale
|
213 |
+
half_next_token, _ = sample(cfg_logits, **sampling_kwargs)
|
214 |
+
pred_tokens.append(half_next_token)
|
215 |
+
next_token = torch.cat([half_next_token,half_next_token])
|
216 |
+
|
217 |
+
|
218 |
+
else:
|
219 |
+
next_token, next_prob = sample(next_token_logits, **sampling_kwargs)
|
220 |
+
pred_tokens.append(next_token)
|
221 |
+
|
222 |
+
# update generated ids, model inputs, and length for next step
|
223 |
+
input_ids = torch.cat([input_ids, next_token], dim=-1)
|
224 |
+
model_kwargs = vqllm._update_model_kwargs_for_generation(
|
225 |
+
outputs,
|
226 |
+
model_kwargs,
|
227 |
+
is_encoder_decoder=vqllm.config.is_encoder_decoder,
|
228 |
+
)
|
229 |
+
|
230 |
+
del sampling_kwargs
|
231 |
+
del model_inputs
|
232 |
+
del outputs
|
233 |
+
image_vq_id = torch.cat(pred_tokens,dim=1)-ori_vocabe_size
|
234 |
+
image_vq_id = torch.clamp(image_vq_id, min=0, max=8191)
|
235 |
+
|
236 |
+
generated_image_list = []
|
237 |
+
for index, generate_id in enumerate(image_vq_id):
|
238 |
+
rec_img = image_tokenizer.pil_from_img_toks(generate_id)
|
239 |
+
generated_image_list.append(rec_img)
|
240 |
+
# rec_img.save('{}/{}.jpg'.format(image_save_pth,str(idx)))
|
241 |
+
|
242 |
+
torch.cuda.empty_cache()
|
243 |
+
# yield gr.Image(value=generated_image_list[0], label="Generated Image", show_download_button=True)
|
244 |
+
yield show_gallery(generated_image_list)
|
245 |
+
|
246 |
+
@spaces.GPU
|
247 |
+
def bot_streaming_T2T(message, history,temperature):
|
248 |
+
print(message)
|
249 |
+
global stop_flag
|
250 |
+
stop_flag = True
|
251 |
+
time.sleep(0.2)
|
252 |
+
stop_flag = False
|
253 |
+
torch.cuda.empty_cache()
|
254 |
+
qs = message
|
255 |
+
conv = conv_templates['gemma'].copy()
|
256 |
+
conv.append_message(conv.roles[0], qs)
|
257 |
+
conv.append_message(conv.roles[1], None)
|
258 |
+
prompt = conv.get_prompt()
|
259 |
+
|
260 |
+
print(prompt)
|
261 |
+
with torch.no_grad():
|
262 |
+
inputs = tokenizer([prompt], return_tensors="pt").to('cuda')
|
263 |
+
streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": False, "skip_prompt": True})
|
264 |
+
|
265 |
+
# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
|
266 |
+
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
|
267 |
+
thread = Thread(target=vqllm.generate, kwargs=generation_kwargs)
|
268 |
+
thread.start()
|
269 |
+
generated_text = ""
|
270 |
+
for new_text in streamer:
|
271 |
+
generated_text += new_text
|
272 |
+
yield generated_text
|
273 |
+
|
274 |
+
|
275 |
+
chatbot_T2I=gr.Chatbot(placeholder=PLACEHOLDER,height=600)
|
276 |
+
chat_input_T2I = gr.Textbox(placeholder="Enter text prompts...", show_label=False)
|
277 |
+
|
278 |
+
chatbot_I2T=gr.Chatbot(placeholder=PLACEHOLDER, scale=1)
|
279 |
+
chat_input_I2T = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
|
280 |
+
|
281 |
+
chatbot_T2T=gr.Chatbot(placeholder=PLACEHOLDER, scale=1)
|
282 |
+
chat_input_T2T = gr.Textbox(placeholder="Enter text prompts...", show_label=False)
|
283 |
+
|
284 |
+
|
285 |
+
with gr.Blocks(fill_height=True) as demo:
|
286 |
+
|
287 |
+
gr.Markdown(title_html)
|
288 |
+
gr.Markdown(links_html)
|
289 |
+
gr.Markdown(introduction)
|
290 |
+
|
291 |
+
with gr.Tab("Text To Image"):
|
292 |
+
|
293 |
+
description="Enter a text prompt or simply try one of the examples below to generate 4 images at once. Click to display the full image. You can configure hyperparameters for image generation in the Advanced Settings. "
|
294 |
+
gr.Markdown(description)
|
295 |
+
with gr.Accordion("⚙️ Advanced Settings", open=False):
|
296 |
+
with gr.Row():
|
297 |
+
guidance_scale = gr.Slider(1.0, 20.0, value=7.0, label="Guidance Scale")
|
298 |
+
temperature = gr.Slider(0.0, 1.0, value=0.9, label="temperature")
|
299 |
+
top_K = gr.Slider(1, 8192, value=4096, label="Top K")
|
300 |
+
top_P = gr.Slider(0.0, 1.0, value=0.99, label="Top P")
|
301 |
+
|
302 |
+
aaa = gr.ChatInterface(
|
303 |
+
fn=bot_streaming_T2I,
|
304 |
+
examples=[
|
305 |
+
["young blue dragon with horn lightning in the style of dd fantasy full body",5.0, 0.9,4096,0.99],
|
306 |
+
["A majestic Goddes of beauty, charming dressed in a regal, jeweled gown and ornate crown, her golden hair cascading down her back, in the style of Pino Daeni",5.0, 0.9,4096,0.99],
|
307 |
+
["A highly realistic, closeup photograph of a beautiful 35 year old redread woman writing in her journal, sitting on her balcony wearing warm, stylish outfits. Shot on a Canon EOS R5, the image boasts sharp focus and intricate details. The heartwarming scene conveys love, connection, and the crisp winter atmosphere, dramatic lighting.",5.0, 0.9,4096,0.99],
|
308 |
+
["Portrait of an asian woman. She has pink violet hair style with modern complex hairdressing. The background is dark with cyberpunk neon lights. Inspired by Cyberpunk 2077 and Blade Runner. Ultra realistic picture. To capture the image, you will use a fullframe DSLR or mirrorless camera with a highresolution sensor, an aperture of f2.8 or wider, and a shutter speed of 1500 second or faster. You will use natural light and reflectors to create a balanced and welllit image, and will experiment with different angles and compositions to create the most i",5.0, 0.9,4096,0.99],
|
309 |
+
["female character fantasy world, for fantasy story, protagonist, interesting and detailed clothes, beautiful, medieval fantasy cinematic shot photo taken by canon, photo taken by fuji, photo taken by kodak incredibly detailed, sharpen, details professional lighting , film lighting 350mm lightroom cinematography, hyper realism, cinematic, film quality",5.0, 0.9,4096,0.99],
|
310 |
+
["strawberries splashing, swirling liquid, realism, octane render, raytracing",5.0, 0.9,4096,0.99],
|
311 |
+
["hedgehog face, floating in space, wearing space suit no helmet, cinematic, 50mm f1.8, unreal engine 5",5.0, 0.9,4096,0.99],
|
312 |
+
["artificial intelligence, revolution, publishing, writer, hyperrealistic",5.0, 0.9,4096,0.99],
|
313 |
+
["A pig dressed as a mason, by Bill Gekas",5.0, 0.9,4096,0.99],
|
314 |
+
],
|
315 |
+
stop_btn="Stop Generation",
|
316 |
+
additional_inputs = [guidance_scale, temperature, top_K, top_P],
|
317 |
+
additional_inputs_accordion="⚙️ Advanced Settings",
|
318 |
+
multimodal=False,
|
319 |
+
textbox=chat_input_T2I,
|
320 |
+
chatbot=chatbot_T2I,
|
321 |
+
fill_height=True,
|
322 |
+
)
|
323 |
+
|
324 |
+
|
325 |
+
|
326 |
+
|
327 |
+
with gr.Tab("Image To Text"):
|
328 |
+
bbb = gr.ChatInterface(
|
329 |
+
fn=bot_streaming_I2T,
|
330 |
+
examples=[ {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
|
331 |
+
description="Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
|
332 |
+
stop_btn="Stop Generation",
|
333 |
+
multimodal=True,
|
334 |
+
textbox=chat_input_I2T,
|
335 |
+
chatbot=chatbot_I2T,
|
336 |
+
)
|
337 |
+
|
338 |
+
with gr.Tab("Text To Text"):
|
339 |
+
|
340 |
+
with gr.Accordion("⚙️ Advanced Settings", open=False):
|
341 |
+
with gr.Row():
|
342 |
+
texttemperature = gr.Slider(0.0, 1.0, value=0.9, label="texttemperature")
|
343 |
+
|
344 |
+
gr.ChatInterface(
|
345 |
+
fn=bot_streaming_T2T,
|
346 |
+
examples=[["a dog", 0.9]],
|
347 |
+
description="Chat with Liquid without images.",
|
348 |
+
stop_btn="Stop Generation",
|
349 |
+
additional_inputs = [texttemperature],
|
350 |
+
additional_inputs_accordion="⚙️ Advanced Settings",
|
351 |
+
multimodal=False,
|
352 |
+
textbox=chat_input_T2T,
|
353 |
+
chatbot=chatbot_T2T,
|
354 |
+
)
|
355 |
+
demo.queue(api_open=False)
|
356 |
+
demo.launch(allowed_paths=["./"], share=False )
|
baklava.png
ADDED
![]() |
Git LFS Details
|
chameleon/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
chameleon/download_data.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# This software may be used and distributed according to the terms of the Chameleon License Agreement.
|
3 |
+
|
4 |
+
import hashlib
|
5 |
+
from pathlib import Path
|
6 |
+
import subprocess
|
7 |
+
import sys
|
8 |
+
|
9 |
+
|
10 |
+
def download_file(url: str, output_path: Path):
|
11 |
+
print(f"Downloading {output_path}")
|
12 |
+
subprocess.check_call(["wget", "--continue", url, "-O", str(output_path)])
|
13 |
+
|
14 |
+
|
15 |
+
def validate_checksum(folder: Path):
|
16 |
+
chks_parts = (folder / "checklist.chk").read_text().split()
|
17 |
+
for expected_checksum, file in zip(chks_parts[::2], chks_parts[1::2]):
|
18 |
+
file_path = folder / file
|
19 |
+
checksum = hashlib.md5(file_path.read_bytes()).hexdigest()
|
20 |
+
if checksum != expected_checksum:
|
21 |
+
print(f"Checksum mismatch for {file_path}")
|
22 |
+
sys.exit(1)
|
23 |
+
|
24 |
+
|
25 |
+
def download_tokenizer(presigned_url: str, target_folder: Path):
|
26 |
+
tokenizer_folder = target_folder / "tokenizer"
|
27 |
+
tokenizer_folder.mkdir(parents=True, exist_ok=True)
|
28 |
+
|
29 |
+
for filename in [
|
30 |
+
"text_tokenizer.json",
|
31 |
+
"vqgan.ckpt",
|
32 |
+
"vqgan.yaml",
|
33 |
+
"checklist.chk",
|
34 |
+
]:
|
35 |
+
download_file(
|
36 |
+
presigned_url.replace("*", f"tokenizer/{filename}"),
|
37 |
+
tokenizer_folder / filename,
|
38 |
+
)
|
39 |
+
|
40 |
+
validate_checksum(tokenizer_folder)
|
41 |
+
|
42 |
+
|
43 |
+
def download_model(presigned_url: str, target_folder: Path, model: str):
|
44 |
+
model_folder = target_folder / "models" / model
|
45 |
+
model_folder.mkdir(parents=True, exist_ok=True)
|
46 |
+
|
47 |
+
download_filenames = ["params.json", "consolidate_params.json", "checklist.chk"]
|
48 |
+
|
49 |
+
if model == "7b":
|
50 |
+
download_filenames += ["consolidated.pth"]
|
51 |
+
elif model == "30b":
|
52 |
+
download_filenames += [f"consolidated.{i:02}.pth" for i in range(4)]
|
53 |
+
else:
|
54 |
+
print(f"Unknown model: {model}")
|
55 |
+
sys.exit(1)
|
56 |
+
|
57 |
+
for filename in download_filenames:
|
58 |
+
download_file(
|
59 |
+
presigned_url.replace("*", f"{model}/{filename}"),
|
60 |
+
model_folder / filename,
|
61 |
+
)
|
62 |
+
|
63 |
+
validate_checksum(model_folder)
|
64 |
+
|
65 |
+
|
66 |
+
def main():
|
67 |
+
presigned_url = (
|
68 |
+
sys.argv[1] if len(sys.argv) > 1 else input("Enter the URL from email: ")
|
69 |
+
)
|
70 |
+
|
71 |
+
target_folder = Path("./data")
|
72 |
+
target_folder.mkdir(parents=True, exist_ok=True)
|
73 |
+
|
74 |
+
download_tokenizer(presigned_url, target_folder)
|
75 |
+
|
76 |
+
model_size = input(
|
77 |
+
"Enter the list of models to download without spaces (7B,30B), or press Enter for all: "
|
78 |
+
)
|
79 |
+
if not model_size:
|
80 |
+
model_size = "7B,30B"
|
81 |
+
|
82 |
+
for model in model_size.split(","):
|
83 |
+
model = model.strip().lower()
|
84 |
+
download_model(presigned_url, target_folder, model)
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
main()
|
chameleon/inference/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
chameleon/inference/alignment.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
class PromptAlignment(ABC):
|
12 |
+
@abstractmethod
|
13 |
+
def start_index(self, input_ids: list[list[int]]) -> int:
|
14 |
+
...
|
15 |
+
|
16 |
+
@abstractmethod
|
17 |
+
def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
|
18 |
+
...
|
19 |
+
|
20 |
+
@abstractmethod
|
21 |
+
def postprocess_inputs(
|
22 |
+
self, inputs: torch.Tensor, original_inputs: torch.Tensor
|
23 |
+
) -> torch.Tensor:
|
24 |
+
...
|
25 |
+
|
26 |
+
|
27 |
+
class AlignPromptRight(PromptAlignment):
|
28 |
+
def __init__(self, pad_id: int):
|
29 |
+
self.pad_id = pad_id
|
30 |
+
|
31 |
+
def start_index(self, input_ids: list[list[int]]) -> int:
|
32 |
+
return max(len(sublist) for sublist in input_ids)
|
33 |
+
|
34 |
+
def prepare_inputs(self, input_ids: list[list[int]]) -> torch.LongTensor:
|
35 |
+
max_length = max(len(sublist) for sublist in input_ids)
|
36 |
+
return torch.tensor(
|
37 |
+
[
|
38 |
+
([self.pad_id] * (max_length - len(sublist))) + sublist
|
39 |
+
for sublist in input_ids
|
40 |
+
],
|
41 |
+
requires_grad=False,
|
42 |
+
)
|
43 |
+
|
44 |
+
def postprocess_inputs(
|
45 |
+
self,
|
46 |
+
inputs: torch.Tensor,
|
47 |
+
original_inputs: torch.Tensor,
|
48 |
+
) -> torch.Tensor:
|
49 |
+
return inputs
|
50 |
+
|
51 |
+
|
52 |
+
class AlignPromptLeft(PromptAlignment):
|
53 |
+
def __init__(self, pad_id: int = -1):
|
54 |
+
self.pad_id = pad_id
|
55 |
+
|
56 |
+
def start_index(self, input_ids: list[list[int]]) -> int:
|
57 |
+
return min(len(sublist) for sublist in input_ids)
|
58 |
+
|
59 |
+
def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
|
60 |
+
max_length = max(len(sublist) for sublist in input_ids)
|
61 |
+
return torch.tensor(
|
62 |
+
[
|
63 |
+
sublist + ([self.pad_id] * (max_length - len(sublist)))
|
64 |
+
for sublist in input_ids
|
65 |
+
],
|
66 |
+
requires_grad=False,
|
67 |
+
)
|
68 |
+
|
69 |
+
def postprocess_inputs(
|
70 |
+
self,
|
71 |
+
inputs: torch.Tensor,
|
72 |
+
original_inputs: torch.Tensor,
|
73 |
+
) -> torch.Tensor:
|
74 |
+
max_init_len = original_inputs.shape[1]
|
75 |
+
if inputs.shape[1] <= max_init_len:
|
76 |
+
original_inputs_limited = original_inputs[:, : inputs.shape[1]]
|
77 |
+
mask = original_inputs_limited != self.pad_id
|
78 |
+
inputs[mask] = original_inputs_limited[mask]
|
79 |
+
return inputs
|
chameleon/inference/chameleon.py
ADDED
@@ -0,0 +1,673 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import base64
|
7 |
+
import io
|
8 |
+
import json
|
9 |
+
import math
|
10 |
+
import queue
|
11 |
+
import threading
|
12 |
+
from dataclasses import dataclass, field
|
13 |
+
from enum import Enum
|
14 |
+
from multiprocessing import managers, queues, synchronize
|
15 |
+
from typing import Literal, Union
|
16 |
+
|
17 |
+
import PIL
|
18 |
+
import torch
|
19 |
+
import torch.distributed as dist
|
20 |
+
import torch.multiprocessing as mp
|
21 |
+
from PIL.Image import Image
|
22 |
+
from tokenizers import Tokenizer
|
23 |
+
from transformers import (
|
24 |
+
LogitsProcessor,
|
25 |
+
RepetitionPenaltyLogitsProcessor,
|
26 |
+
TemperatureLogitsWarper,
|
27 |
+
TopPLogitsWarper,
|
28 |
+
enable_full_determinism,
|
29 |
+
)
|
30 |
+
|
31 |
+
from chameleon.inference import loader
|
32 |
+
from chameleon.inference.alignment import AlignPromptRight
|
33 |
+
from chameleon.inference.generation import ChameleonGenerator
|
34 |
+
from chameleon.inference.image_tokenizer import ImageTokenizer
|
35 |
+
from chameleon.inference.logits_processor import (
|
36 |
+
AllowOnlyTokensLogitsProcessor,
|
37 |
+
DisallowTokensAtOrAfterIndexLogitsProcessor,
|
38 |
+
InBatchInstructCFGLogitsProcessor,
|
39 |
+
)
|
40 |
+
from chameleon.inference.model_adapter import ChameleonModelAdapter
|
41 |
+
from chameleon.inference.stopping_criteria import (
|
42 |
+
MaxLengthCriteria,
|
43 |
+
StopOnEOSAfterBatchIndex,
|
44 |
+
)
|
45 |
+
from chameleon.inference.token_selector import (
|
46 |
+
ArgmaxTokenSelector,
|
47 |
+
MultinomialTokenSelector,
|
48 |
+
ReplicatedInputTokenSelector,
|
49 |
+
)
|
50 |
+
from chameleon.inference.transformer import Transformer
|
51 |
+
from chameleon.inference.utils import DynamicGenerator, advance, random_unused_port
|
52 |
+
from chameleon.inference.vocab import VocabInfo, VocabTranslation
|
53 |
+
|
54 |
+
|
55 |
+
@dataclass
|
56 |
+
class Options:
|
57 |
+
@dataclass
|
58 |
+
class Text:
|
59 |
+
repetition_penalty: float = 1.2
|
60 |
+
temp: float = 0.7
|
61 |
+
top_p: float = 0.9
|
62 |
+
greedy: bool = False
|
63 |
+
|
64 |
+
@dataclass
|
65 |
+
class Image:
|
66 |
+
@dataclass
|
67 |
+
class CFG:
|
68 |
+
guidance_scale_text: float = 3.0
|
69 |
+
guidance_scale_image: float = 1.2
|
70 |
+
|
71 |
+
cfg: CFG = field(default_factory=CFG)
|
72 |
+
temp: float = 0.7
|
73 |
+
top_p: float = 0.9
|
74 |
+
greedy: bool = False
|
75 |
+
|
76 |
+
max_seq_len: int = 4096
|
77 |
+
max_gen_len: int = 4096
|
78 |
+
seed: int | None = None
|
79 |
+
txt: Text | bool = True
|
80 |
+
img: Image | bool = False
|
81 |
+
extra_eos_tokens: list[int | str] = field(default_factory=lambda: ["<racm3:break>"])
|
82 |
+
|
83 |
+
def __post_init__(self):
|
84 |
+
if self.txt == True:
|
85 |
+
self.txt = Options.Text()
|
86 |
+
if self.img == True:
|
87 |
+
self.img = Options.Image()
|
88 |
+
|
89 |
+
|
90 |
+
class TokenManager:
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
tokenizer_path: str,
|
94 |
+
vqgan_cfg_path: str,
|
95 |
+
vqgan_ckpt_path: str,
|
96 |
+
device: str | None = None,
|
97 |
+
):
|
98 |
+
self.tokenizer = Tokenizer.from_file(tokenizer_path)
|
99 |
+
self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
|
100 |
+
self.translation = VocabTranslation(self.vocab, device=device)
|
101 |
+
self.image_tokenizer = ImageTokenizer(
|
102 |
+
cfg_path=vqgan_cfg_path, ckpt_path=vqgan_ckpt_path, device=device
|
103 |
+
)
|
104 |
+
|
105 |
+
def pil_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image:
|
106 |
+
image_tensor = self.translation.convert_bpe2img(bpe_tokens)
|
107 |
+
if image_tensor.shape[0] < 1024:
|
108 |
+
padding = (
|
109 |
+
torch.ones(
|
110 |
+
[1024 - image_tensor.shape[0]],
|
111 |
+
dtype=int,
|
112 |
+
device=image_tensor.device,
|
113 |
+
)
|
114 |
+
* image_tensor[0]
|
115 |
+
)
|
116 |
+
image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)
|
117 |
+
|
118 |
+
return self.image_tokenizer.pil_from_img_toks(image_tensor)
|
119 |
+
|
120 |
+
def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
|
121 |
+
pil = self.pil_from_bpe_tokens(bpe_tokens)
|
122 |
+
img_io = io.BytesIO()
|
123 |
+
pil.save(img_io, format="PNG")
|
124 |
+
return img_io.getvalue()
|
125 |
+
|
126 |
+
def tokenize_text(self, text: str) -> list[int]:
|
127 |
+
return self.tokenizer.encode(text).ids
|
128 |
+
|
129 |
+
def tokenize_image(self, img: Image) -> list[int]:
|
130 |
+
return (
|
131 |
+
[self.vocab.begin_image]
|
132 |
+
+ self.translation.convert_img2bp2(
|
133 |
+
self.image_tokenizer.img_tokens_from_pil(img)
|
134 |
+
).tolist()
|
135 |
+
+ [self.vocab.end_image]
|
136 |
+
)
|
137 |
+
|
138 |
+
def tokenize_b64img(self, b64img: str) -> list[int]:
|
139 |
+
image_data = base64.b64decode(b64img)
|
140 |
+
image_file = io.BytesIO(image_data)
|
141 |
+
return self.tokenize_image(PIL.Image.open(image_file))
|
142 |
+
|
143 |
+
def tokens_from_ui(self, inputs: list[dict]) -> list[int]:
|
144 |
+
tokens = [self.vocab.bos_id]
|
145 |
+
for input_ in inputs:
|
146 |
+
if input_["type"] == "text":
|
147 |
+
tokens += self.tokenize_text(input_["value"])
|
148 |
+
elif input_["type"] == "image":
|
149 |
+
if type(input_["value"]) == str:
|
150 |
+
if input_["value"].startswith("data:"):
|
151 |
+
# Value Format: 'data:image/[^;]+;base64,[A-Za-z0-9+/]+={0,2}'
|
152 |
+
tokens += self.tokenize_b64img(input_["value"].split(",", 1)[1])
|
153 |
+
elif input_["value"].startswith("file:"):
|
154 |
+
tokens += self.tokenize_image(
|
155 |
+
PIL.Image.open(input_["value"].split(":", 1)[1])
|
156 |
+
)
|
157 |
+
else:
|
158 |
+
raise ValueError("Unknown image format.")
|
159 |
+
elif type(input_["value"]) == Image:
|
160 |
+
tokens += self.tokenize_image(input_["value"])
|
161 |
+
else:
|
162 |
+
raise ValueError("Unknown image type.")
|
163 |
+
elif input_["type"] == "sentinel":
|
164 |
+
tokens += [
|
165 |
+
{
|
166 |
+
"<START-OF-IMAGE>": self.vocab.begin_image,
|
167 |
+
"<END-OF-TURN>": self.vocab.eot_id,
|
168 |
+
}[input_["value"]]
|
169 |
+
]
|
170 |
+
elif input_["type"] == "ids":
|
171 |
+
tokens += input_["value"]
|
172 |
+
else:
|
173 |
+
raise ValueError("Unknown input type.")
|
174 |
+
return tokens
|
175 |
+
|
176 |
+
def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
|
177 |
+
if isinstance(ids, torch.Tensor):
|
178 |
+
ids = ids.tolist()
|
179 |
+
|
180 |
+
for row, values in enumerate(ids):
|
181 |
+
try:
|
182 |
+
ids[row] = values[: values.index(self.vocab.eos_id)]
|
183 |
+
except ValueError:
|
184 |
+
pass
|
185 |
+
|
186 |
+
return self.tokenizer.decode_batch(ids)
|
187 |
+
|
188 |
+
def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image]:
|
189 |
+
return [self.pil_from_bpe_tokens(sample) for sample in ids]
|
190 |
+
|
191 |
+
|
192 |
+
@dataclass
|
193 |
+
class DecodePiece:
|
194 |
+
token: ChameleonGenerator.Token
|
195 |
+
next_decoder: type["Decoder"] | None
|
196 |
+
|
197 |
+
|
198 |
+
class Decoder:
|
199 |
+
def __init__(
|
200 |
+
self,
|
201 |
+
model: Transformer,
|
202 |
+
vocab: VocabInfo,
|
203 |
+
options: Options,
|
204 |
+
input_ids: list[int],
|
205 |
+
): ...
|
206 |
+
|
207 |
+
def __next__(self) -> DecodePiece: ...
|
208 |
+
|
209 |
+
|
210 |
+
class TextDecoder(Decoder):
|
211 |
+
def __init__(
|
212 |
+
self,
|
213 |
+
model: Transformer,
|
214 |
+
vocab: VocabInfo,
|
215 |
+
options: Options,
|
216 |
+
input_ids: list[list[int]],
|
217 |
+
):
|
218 |
+
self.vocab = vocab
|
219 |
+
self.options = options
|
220 |
+
assert vocab.eos_id is not None
|
221 |
+
|
222 |
+
prompt_lens = [len(inp) for inp in input_ids]
|
223 |
+
max_prompt_len = max(prompt_lens)
|
224 |
+
max_seq_len = min(options.max_seq_len, max_prompt_len + options.max_gen_len)
|
225 |
+
|
226 |
+
self.eos_ids = [vocab.eos_id]
|
227 |
+
for extra_eos_token in options.extra_eos_tokens:
|
228 |
+
if isinstance(extra_eos_token, str):
|
229 |
+
extra_eos_token = vocab.name2val[extra_eos_token]
|
230 |
+
assert isinstance(extra_eos_token, int)
|
231 |
+
self.eos_ids.append(extra_eos_token)
|
232 |
+
|
233 |
+
stopping_criteria = [
|
234 |
+
MaxLengthCriteria(max_seq_len),
|
235 |
+
] + [StopOnEOSAfterBatchIndex(eos_id, [max_prompt_len] * len(prompt_lens)) for eos_id in self.eos_ids]
|
236 |
+
|
237 |
+
self.gen = ChameleonGenerator(
|
238 |
+
model=ChameleonModelAdapter(model, max_seq_len=max_seq_len),
|
239 |
+
input_ids=input_ids,
|
240 |
+
stopping_criteria=stopping_criteria,
|
241 |
+
logits_processors=self._logits_processors(),
|
242 |
+
alignment=AlignPromptRight(vocab.pad_id),
|
243 |
+
token_selector=(
|
244 |
+
ArgmaxTokenSelector()
|
245 |
+
if options.txt.greedy
|
246 |
+
else MultinomialTokenSelector()
|
247 |
+
),
|
248 |
+
)
|
249 |
+
advance(self.gen, max_prompt_len)
|
250 |
+
|
251 |
+
def _allowed_tokens(self) -> list[int]:
|
252 |
+
allowed_tokens = [self.vocab.eos_id]
|
253 |
+
if self.options.txt:
|
254 |
+
allowed_tokens += self.vocab.text_tokens
|
255 |
+
if self.options.img:
|
256 |
+
allowed_tokens += [self.vocab.begin_image]
|
257 |
+
return allowed_tokens
|
258 |
+
|
259 |
+
def _logits_processors(self) -> list[LogitsProcessor]:
|
260 |
+
logits_processors = [
|
261 |
+
AllowOnlyTokensLogitsProcessor(self._allowed_tokens()),
|
262 |
+
]
|
263 |
+
if isinstance(self.options.img, Options.Image):
|
264 |
+
logits_processors += [
|
265 |
+
DisallowTokensAtOrAfterIndexLogitsProcessor(
|
266 |
+
[self.vocab.begin_image],
|
267 |
+
self.options.max_seq_len - 1026,
|
268 |
+
),
|
269 |
+
]
|
270 |
+
if isinstance(self.options.txt, Options.Text):
|
271 |
+
logits_processors += [
|
272 |
+
RepetitionPenaltyLogitsProcessor(self.options.txt.repetition_penalty),
|
273 |
+
TemperatureLogitsWarper(self.options.txt.temp),
|
274 |
+
TopPLogitsWarper(self.options.txt.top_p),
|
275 |
+
]
|
276 |
+
return logits_processors
|
277 |
+
|
278 |
+
def __next__(self) -> DecodePiece:
|
279 |
+
tok = next(self.gen)
|
280 |
+
next_decoder = None
|
281 |
+
if (
|
282 |
+
self.vocab.begin_image not in self.eos_ids
|
283 |
+
and (tok.id == self.vocab.begin_image).all()
|
284 |
+
):
|
285 |
+
next_decoder = ImageDecoder
|
286 |
+
return DecodePiece(tok, next_decoder)
|
287 |
+
|
288 |
+
|
289 |
+
class ImageDecoder(Decoder):
|
290 |
+
def __init__(
|
291 |
+
self,
|
292 |
+
model: Transformer,
|
293 |
+
vocab: VocabInfo,
|
294 |
+
options: Options,
|
295 |
+
input_ids: list[list[int]],
|
296 |
+
):
|
297 |
+
assert isinstance(options.img, Options.Image)
|
298 |
+
self.vocab = vocab
|
299 |
+
self.options = options
|
300 |
+
self.batch_size = len(input_ids)
|
301 |
+
logits_processors = [
|
302 |
+
InBatchInstructCFGLogitsProcessor(
|
303 |
+
options.img.cfg.guidance_scale_text,
|
304 |
+
options.img.cfg.guidance_scale_image,
|
305 |
+
),
|
306 |
+
AllowOnlyTokensLogitsProcessor(vocab.image_tokens),
|
307 |
+
TemperatureLogitsWarper(options.img.temp),
|
308 |
+
TopPLogitsWarper(options.img.top_p),
|
309 |
+
]
|
310 |
+
|
311 |
+
for inp in input_ids:
|
312 |
+
if inp[-1] != self.vocab.begin_image:
|
313 |
+
inp.append(self.vocab.begin_image)
|
314 |
+
|
315 |
+
max_prompt_len = max(len(inp) for inp in input_ids)
|
316 |
+
self.gen = ChameleonGenerator(
|
317 |
+
model=ChameleonModelAdapter(model, max_seq_len=max_prompt_len + 1024),
|
318 |
+
input_ids=self._split_inputs_for_cfg(input_ids),
|
319 |
+
logits_processors=logits_processors,
|
320 |
+
alignment=AlignPromptRight(vocab.pad_id),
|
321 |
+
token_selector=ReplicatedInputTokenSelector(
|
322 |
+
(
|
323 |
+
ArgmaxTokenSelector()
|
324 |
+
if options.img.greedy
|
325 |
+
else MultinomialTokenSelector()
|
326 |
+
),
|
327 |
+
n=3,
|
328 |
+
),
|
329 |
+
)
|
330 |
+
advance(self.gen, max_prompt_len)
|
331 |
+
self.gen_count = 0
|
332 |
+
|
333 |
+
def _split_inputs_for_cfg(self, input_ids: list[list[int]]) -> list[list[int]]:
|
334 |
+
image_conditioned_allowed = set(self.vocab.image_tokens) | {
|
335 |
+
self.vocab.bos_id,
|
336 |
+
self.vocab.begin_image,
|
337 |
+
self.vocab.end_image,
|
338 |
+
}
|
339 |
+
|
340 |
+
full_conditioned = input_ids
|
341 |
+
|
342 |
+
image_conditioned = [
|
343 |
+
[id for id in sample if id in image_conditioned_allowed]
|
344 |
+
for sample in input_ids
|
345 |
+
]
|
346 |
+
|
347 |
+
unconditioned = [
|
348 |
+
[
|
349 |
+
self.vocab.bos_id,
|
350 |
+
self.vocab.begin_image,
|
351 |
+
]
|
352 |
+
] * self.batch_size
|
353 |
+
|
354 |
+
return full_conditioned + image_conditioned + unconditioned
|
355 |
+
|
356 |
+
def __next__(self) -> DecodePiece:
|
357 |
+
if self.gen_count == 1024:
|
358 |
+
id = torch.tensor([self.vocab.end_image] * self.batch_size)
|
359 |
+
logits = torch.full(
|
360 |
+
(self.batch_size, len(self.vocab.all_tokens)), -math.inf
|
361 |
+
)
|
362 |
+
logits[:, self.vocab.end_image] = 0
|
363 |
+
return DecodePiece(
|
364 |
+
ChameleonGenerator.Token(id=id, logits=logits),
|
365 |
+
TextDecoder,
|
366 |
+
)
|
367 |
+
|
368 |
+
tok = next(self.gen)
|
369 |
+
tok.id = tok.id.chunk(3)[0]
|
370 |
+
self.gen_count += 1
|
371 |
+
return DecodePiece(tok, None)
|
372 |
+
|
373 |
+
|
374 |
+
class Generator(Decoder):
|
375 |
+
def __init__(
|
376 |
+
self,
|
377 |
+
model: Transformer,
|
378 |
+
vocab: VocabInfo,
|
379 |
+
options: Options,
|
380 |
+
input_ids: list[list[int]],
|
381 |
+
):
|
382 |
+
if options.seed is not None:
|
383 |
+
enable_full_determinism(options.seed, warn_only=True)
|
384 |
+
|
385 |
+
self.model = model
|
386 |
+
self.vocab = vocab
|
387 |
+
self.input_ids = input_ids[:]
|
388 |
+
self.generated_token_ids: list[torch.LongTensor] = []
|
389 |
+
self.options = options
|
390 |
+
if not self.options.txt:
|
391 |
+
self.dyngen = DynamicGenerator(
|
392 |
+
ImageDecoder(model, vocab, options, input_ids)
|
393 |
+
)
|
394 |
+
else:
|
395 |
+
self.dyngen = DynamicGenerator(
|
396 |
+
TextDecoder(model, vocab, options, input_ids)
|
397 |
+
)
|
398 |
+
|
399 |
+
def __iter__(self):
|
400 |
+
return self
|
401 |
+
|
402 |
+
def __next__(self) -> ChameleonGenerator.Token:
|
403 |
+
piece = next(self.dyngen)
|
404 |
+
self.generated_token_ids.append(piece.token.id)
|
405 |
+
if piece.next_decoder is not None:
|
406 |
+
if not self.options.txt:
|
407 |
+
raise StopIteration
|
408 |
+
|
409 |
+
self.input_ids = [
|
410 |
+
old_list + generated
|
411 |
+
for old_list, generated in zip(
|
412 |
+
self.input_ids, torch.stack(self.generated_token_ids).T.tolist()
|
413 |
+
)
|
414 |
+
]
|
415 |
+
self.generated_token_ids = []
|
416 |
+
self.dyngen.gen = piece.next_decoder(
|
417 |
+
self.model,
|
418 |
+
self.vocab,
|
419 |
+
self.options,
|
420 |
+
self.input_ids,
|
421 |
+
)
|
422 |
+
return piece.token
|
423 |
+
|
424 |
+
|
425 |
+
class DistributedMode(Enum):
|
426 |
+
AUTO = 0
|
427 |
+
THREAD = 1
|
428 |
+
PROCESS = 2
|
429 |
+
|
430 |
+
|
431 |
+
@dataclass
|
432 |
+
class _DistributedContext:
|
433 |
+
req_q: Union[queue.Queue, queues.Queue]
|
434 |
+
res_q: Union[queue.Queue, queues.Queue]
|
435 |
+
active_key: Union[dict[int, Literal[True]], managers.DictProxy]
|
436 |
+
active_key_lock: Union[threading.Lock, synchronize.Lock]
|
437 |
+
ready_barrier: Union[threading.Barrier, synchronize.Barrier]
|
438 |
+
worker_launcher: Union[type[threading.Thread], type[mp.Process]]
|
439 |
+
|
440 |
+
@staticmethod
|
441 |
+
def make_for_threading(world_size: int):
|
442 |
+
return _DistributedContext(
|
443 |
+
req_q=queue.Queue(),
|
444 |
+
res_q=queue.Queue(),
|
445 |
+
active_key={},
|
446 |
+
active_key_lock=threading.Lock(),
|
447 |
+
ready_barrier=threading.Barrier(world_size + 1),
|
448 |
+
worker_launcher=threading.Thread,
|
449 |
+
)
|
450 |
+
|
451 |
+
@staticmethod
|
452 |
+
def make_for_multiprocessing(world_size: int):
|
453 |
+
local_mp = mp.get_context("spawn")
|
454 |
+
return _DistributedContext(
|
455 |
+
req_q=local_mp.Queue(),
|
456 |
+
res_q=local_mp.Queue(),
|
457 |
+
active_key=local_mp.Manager().dict(),
|
458 |
+
active_key_lock=local_mp.Lock(),
|
459 |
+
ready_barrier=local_mp.Barrier(world_size + 1),
|
460 |
+
worker_launcher=local_mp.Process,
|
461 |
+
)
|
462 |
+
|
463 |
+
@staticmethod
|
464 |
+
def make(mode: DistributedMode, world_size: int):
|
465 |
+
if mode == DistributedMode.AUTO:
|
466 |
+
mode = DistributedMode.PROCESS
|
467 |
+
|
468 |
+
if mode == DistributedMode.THREAD:
|
469 |
+
return _DistributedContext.make_for_threading(world_size)
|
470 |
+
elif mode == DistributedMode.PROCESS:
|
471 |
+
return _DistributedContext.make_for_multiprocessing(world_size)
|
472 |
+
else:
|
473 |
+
raise ValueError("Unknown DistributedMode")
|
474 |
+
|
475 |
+
|
476 |
+
def _worker_impl(
|
477 |
+
init_method: str,
|
478 |
+
model: Transformer | str,
|
479 |
+
world_size: int,
|
480 |
+
rank: int,
|
481 |
+
vocab: VocabInfo,
|
482 |
+
dctx: _DistributedContext,
|
483 |
+
):
|
484 |
+
dist.init_process_group(
|
485 |
+
"nccl",
|
486 |
+
init_method=init_method,
|
487 |
+
world_size=world_size,
|
488 |
+
rank=rank,
|
489 |
+
)
|
490 |
+
|
491 |
+
torch.set_default_device(f"cuda:{rank}")
|
492 |
+
torch.cuda.set_device(rank)
|
493 |
+
if isinstance(model, str):
|
494 |
+
model = loader.load_model(model, rank=rank)
|
495 |
+
dctx.ready_barrier.wait()
|
496 |
+
|
497 |
+
is_coord = rank == 0
|
498 |
+
|
499 |
+
while True:
|
500 |
+
req = [Options(), [], 0, False]
|
501 |
+
if is_coord:
|
502 |
+
req = dctx.req_q.get()
|
503 |
+
|
504 |
+
dist.broadcast_object_list(req, src=0)
|
505 |
+
options, input_ids, key, shutdown = req
|
506 |
+
if shutdown:
|
507 |
+
break
|
508 |
+
|
509 |
+
for token in Generator(
|
510 |
+
model=model,
|
511 |
+
vocab=vocab,
|
512 |
+
options=options,
|
513 |
+
input_ids=input_ids,
|
514 |
+
):
|
515 |
+
if is_coord:
|
516 |
+
dctx.res_q.put((key, token))
|
517 |
+
|
518 |
+
to_continue = [True]
|
519 |
+
if is_coord:
|
520 |
+
with dctx.active_key_lock:
|
521 |
+
to_continue = [key in dctx.active_key]
|
522 |
+
dist.broadcast_object_list(to_continue, src=0)
|
523 |
+
if not to_continue[0]:
|
524 |
+
break
|
525 |
+
|
526 |
+
if is_coord:
|
527 |
+
dctx.res_q.put((key, None))
|
528 |
+
|
529 |
+
|
530 |
+
class ChameleonInferenceModel:
|
531 |
+
def __init__(
|
532 |
+
self,
|
533 |
+
model: Transformer | str,
|
534 |
+
tokenizer_path: str,
|
535 |
+
vqgan_cfg_path: str,
|
536 |
+
vqgan_ckpt_path: str,
|
537 |
+
*,
|
538 |
+
options: Options | None = None,
|
539 |
+
distributed_mode: DistributedMode = DistributedMode.AUTO,
|
540 |
+
):
|
541 |
+
self.options = options or Options()
|
542 |
+
self.next_key = 0
|
543 |
+
|
544 |
+
self.token_manager = TokenManager(
|
545 |
+
tokenizer_path=tokenizer_path,
|
546 |
+
vqgan_cfg_path=vqgan_cfg_path,
|
547 |
+
vqgan_ckpt_path=vqgan_ckpt_path,
|
548 |
+
device="cuda",
|
549 |
+
)
|
550 |
+
self.vocab = self.token_manager.vocab
|
551 |
+
|
552 |
+
world_size = 1
|
553 |
+
if isinstance(model, str):
|
554 |
+
world_size = loader.detect_shard_count(model)
|
555 |
+
self.dctx = _DistributedContext.make(distributed_mode, world_size)
|
556 |
+
|
557 |
+
init_method = f"tcp://0.0.0.0:{random_unused_port()}"
|
558 |
+
self.workers = [
|
559 |
+
self.dctx.worker_launcher(
|
560 |
+
target=_worker_impl,
|
561 |
+
args=(init_method, model, world_size, i, self.vocab, self.dctx),
|
562 |
+
daemon=True,
|
563 |
+
)
|
564 |
+
for i in range(world_size)
|
565 |
+
]
|
566 |
+
for w in self.workers:
|
567 |
+
w.start()
|
568 |
+
self.dctx.ready_barrier.wait()
|
569 |
+
|
570 |
+
def __del__(self):
|
571 |
+
try:
|
572 |
+
with self.dctx.active_key_lock:
|
573 |
+
self.dctx.active_key.clear()
|
574 |
+
self.dctx.req_q.put([None, None, None, True])
|
575 |
+
for w in self.workers:
|
576 |
+
w.join()
|
577 |
+
except FileNotFoundError:
|
578 |
+
pass
|
579 |
+
|
580 |
+
def stream(
|
581 |
+
self,
|
582 |
+
*,
|
583 |
+
input_ids: list[int] | None = None,
|
584 |
+
prompt_text: str | None = None,
|
585 |
+
prompt_ui: list[dict] | None = None,
|
586 |
+
batch_input_ids: list[list[int]] | None = None,
|
587 |
+
batch_prompt_text: list[str] | None = None,
|
588 |
+
batch_prompt_ui: list[list[dict]] | None = None,
|
589 |
+
options: Options | None = None,
|
590 |
+
):
|
591 |
+
# NOTE: Not thread-safe! Only one instance of generate may be run at a time.
|
592 |
+
|
593 |
+
if (
|
594 |
+
sum(
|
595 |
+
x is not None
|
596 |
+
for x in [
|
597 |
+
input_ids,
|
598 |
+
prompt_text,
|
599 |
+
prompt_ui,
|
600 |
+
batch_input_ids,
|
601 |
+
batch_prompt_text,
|
602 |
+
batch_prompt_ui,
|
603 |
+
]
|
604 |
+
)
|
605 |
+
!= 1
|
606 |
+
):
|
607 |
+
raise ValueError(
|
608 |
+
"Must specify exactly one of: input_ids, prompt_text, prompt_ui, batch_input_ids, batch_prompt_text, batch_prompt_ui"
|
609 |
+
)
|
610 |
+
|
611 |
+
options = options or self.options
|
612 |
+
|
613 |
+
if prompt_text is not None:
|
614 |
+
batch_prompt_text = [prompt_text]
|
615 |
+
if prompt_ui is not None:
|
616 |
+
batch_prompt_ui = [prompt_ui]
|
617 |
+
if input_ids is not None:
|
618 |
+
batch_input_ids = [input_ids]
|
619 |
+
if batch_prompt_text is not None:
|
620 |
+
batch_prompt_ui = [
|
621 |
+
[{"type": "text", "value": prompt_text}]
|
622 |
+
for prompt_text in batch_prompt_text
|
623 |
+
]
|
624 |
+
if batch_prompt_ui is not None:
|
625 |
+
batch_input_ids = [
|
626 |
+
self.token_manager.tokens_from_ui(prompt_ui)
|
627 |
+
for prompt_ui in batch_prompt_ui
|
628 |
+
]
|
629 |
+
|
630 |
+
assert batch_input_ids
|
631 |
+
|
632 |
+
if not options.txt and not options.img:
|
633 |
+
raise ValueError("Must specify at least one modality.")
|
634 |
+
if options.txt and options.img and len(batch_input_ids) > 1:
|
635 |
+
raise ValueError(
|
636 |
+
"Batch generation only supported for one modality at a time."
|
637 |
+
)
|
638 |
+
|
639 |
+
req_key = self.next_key
|
640 |
+
self.next_key += 1
|
641 |
+
|
642 |
+
with self.dctx.active_key_lock:
|
643 |
+
self.dctx.active_key[req_key] = True
|
644 |
+
|
645 |
+
self.dctx.req_q.put([options, batch_input_ids, req_key, False])
|
646 |
+
|
647 |
+
try:
|
648 |
+
while key_token := self.dctx.res_q.get():
|
649 |
+
key, token = key_token
|
650 |
+
if key != req_key:
|
651 |
+
# Residual from prior calls to generation. Skip.
|
652 |
+
continue
|
653 |
+
if token is None:
|
654 |
+
break
|
655 |
+
yield token
|
656 |
+
finally:
|
657 |
+
with self.dctx.active_key_lock:
|
658 |
+
del self.dctx.active_key[req_key]
|
659 |
+
|
660 |
+
def step(self, *args, **kwargs) -> ChameleonGenerator.Token:
|
661 |
+
return next(self.stream(*args, **kwargs))
|
662 |
+
|
663 |
+
def generate(self, *args, **kwargs) -> torch.LongTensor:
|
664 |
+
tokens = [t.id for t in self.stream(*args, **kwargs)]
|
665 |
+
if not tokens:
|
666 |
+
return torch.LongTensor()
|
667 |
+
return torch.stack(tokens).T
|
668 |
+
|
669 |
+
def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
|
670 |
+
return self.token_manager.decode_text(ids)
|
671 |
+
|
672 |
+
def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image]:
|
673 |
+
return self.token_manager.decode_image(ids)
|
chameleon/inference/cudagraph.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import functools
|
7 |
+
from typing import Any, Callable, TypeVar
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
T = TypeVar("T")
|
12 |
+
FN = Callable[..., T] # type: ignore
|
13 |
+
|
14 |
+
|
15 |
+
class CUDAGraphWrapper:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
fn: FN[T],
|
19 |
+
warmup_iter: int = 1,
|
20 |
+
debug_dump_path: str | None = None,
|
21 |
+
):
|
22 |
+
self.fn = fn
|
23 |
+
self.warmup_iter = warmup_iter
|
24 |
+
self.debug_dump_path = debug_dump_path
|
25 |
+
self.graph: torch.cuda.CUDAGraph | None = None
|
26 |
+
self.result: T | None = None
|
27 |
+
|
28 |
+
def __call__(self, *args, **kwargs) -> Any: # type: ignore
|
29 |
+
if self.warmup_iter > 0:
|
30 |
+
self.warmup_iter -= 1
|
31 |
+
return self.fn(*args, **kwargs)
|
32 |
+
|
33 |
+
if self.graph is None:
|
34 |
+
self.graph = torch.cuda.CUDAGraph()
|
35 |
+
if self.debug_dump_path is not None:
|
36 |
+
self.graph.enable_debug_mode()
|
37 |
+
recording_kwargs = {}
|
38 |
+
if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
|
39 |
+
# In PyTorch 2.1+ and nightlies from late Aug 2023,
|
40 |
+
# we can do this to maybe avoid watchdog-related crashes
|
41 |
+
recording_kwargs["capture_error_mode"] = "thread_local"
|
42 |
+
with torch.cuda.graph(self.graph, **recording_kwargs):
|
43 |
+
self.result = self.fn(*args, **kwargs)
|
44 |
+
torch.cuda.synchronize()
|
45 |
+
if self.debug_dump_path is not None:
|
46 |
+
self.graph.debug_dump(self.debug_dump_path)
|
47 |
+
|
48 |
+
assert self.graph is not None
|
49 |
+
self.graph.replay()
|
50 |
+
return self.result
|
51 |
+
|
52 |
+
|
53 |
+
def cudagraph_wrap(
|
54 |
+
*args,
|
55 |
+
warmup_iter: int = 1,
|
56 |
+
debug_dump_path: str | None = None,
|
57 |
+
) -> Callable[[FN[T]], FN[T]]:
|
58 |
+
def wrapper(fn: FN[T]) -> FN[T]:
|
59 |
+
graph_wrapper = CUDAGraphWrapper(
|
60 |
+
fn, warmup_iter=warmup_iter, debug_dump_path=debug_dump_path
|
61 |
+
)
|
62 |
+
|
63 |
+
@functools.wraps(fn)
|
64 |
+
def call_wrapper(*inner_args, **inner_kwargs):
|
65 |
+
return graph_wrapper(*inner_args, **inner_kwargs)
|
66 |
+
|
67 |
+
return call_wrapper
|
68 |
+
|
69 |
+
# @cudagraph_wrap
|
70 |
+
# def fn(...):
|
71 |
+
# ...
|
72 |
+
#
|
73 |
+
# - or -
|
74 |
+
#
|
75 |
+
# fast_fn = cudagraph_wrap(slow_fn, warmup_iter=2)
|
76 |
+
if len(args) == 1 and callable(args[0]):
|
77 |
+
return wrapper(args[0])
|
78 |
+
|
79 |
+
# @cudagraph_wrap(warmup_iter=3)
|
80 |
+
# def fn(...):
|
81 |
+
# ...
|
82 |
+
def decorator(fn: FN[T]) -> FN[T]:
|
83 |
+
return wrapper(fn)
|
84 |
+
|
85 |
+
return decorator
|
chameleon/inference/examples/batch.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from chameleon.inference.chameleon import ChameleonInferenceModel
|
7 |
+
|
8 |
+
|
9 |
+
def main():
|
10 |
+
model = ChameleonInferenceModel(
|
11 |
+
"./data/models/7b/",
|
12 |
+
"./data/tokenizer/text_tokenizer.json",
|
13 |
+
"./data/tokenizer/vqgan.yaml",
|
14 |
+
"./data/tokenizer/vqgan.ckpt",
|
15 |
+
)
|
16 |
+
|
17 |
+
batch_tokens = model.generate(batch_prompt_text=["All your base", "import asyncio"])
|
18 |
+
for text in model.decode_text(batch_tokens):
|
19 |
+
print(text)
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
main()
|
chameleon/inference/examples/multimodal_input.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from chameleon.inference.chameleon import ChameleonInferenceModel
|
7 |
+
|
8 |
+
|
9 |
+
def main():
|
10 |
+
model = ChameleonInferenceModel(
|
11 |
+
"./data/models/7b/",
|
12 |
+
"./data/tokenizer/text_tokenizer.json",
|
13 |
+
"./data/tokenizer/vqgan.yaml",
|
14 |
+
"./data/tokenizer/vqgan.ckpt",
|
15 |
+
)
|
16 |
+
|
17 |
+
tokens = model.generate(
|
18 |
+
prompt_ui=[
|
19 |
+
{"type": "image", "value": "file:/path/to/image.jpeg"},
|
20 |
+
{"type": "text", "value": "What do you see?"},
|
21 |
+
{"type": "sentinel", "value": "<END-OF-TURN>"},
|
22 |
+
]
|
23 |
+
)
|
24 |
+
print(model.decode_text(tokens)[0])
|
25 |
+
|
26 |
+
|
27 |
+
if __name__ == "__main__":
|
28 |
+
main()
|
chameleon/inference/examples/simple.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from chameleon.inference.chameleon import ChameleonInferenceModel
|
7 |
+
|
8 |
+
|
9 |
+
def main():
|
10 |
+
model = ChameleonInferenceModel(
|
11 |
+
"./data/models/7b/",
|
12 |
+
"./data/tokenizer/text_tokenizer.json",
|
13 |
+
"./data/tokenizer/vqgan.yaml",
|
14 |
+
"./data/tokenizer/vqgan.ckpt",
|
15 |
+
)
|
16 |
+
|
17 |
+
tokens = model.generate(prompt_text="All your base")
|
18 |
+
print(model.decode_text(tokens)[0])
|
19 |
+
|
20 |
+
|
21 |
+
if __name__ == "__main__":
|
22 |
+
main()
|
chameleon/inference/examples/streaming.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from chameleon.inference.chameleon import ChameleonInferenceModel
|
7 |
+
|
8 |
+
|
9 |
+
def main():
|
10 |
+
model = ChameleonInferenceModel(
|
11 |
+
"./data/models/7b/",
|
12 |
+
"./data/tokenizer/text_tokenizer.json",
|
13 |
+
"./data/tokenizer/vqgan.yaml",
|
14 |
+
"./data/tokenizer/vqgan.ckpt",
|
15 |
+
)
|
16 |
+
|
17 |
+
for tokens in model.stream(prompt_text="All your base"):
|
18 |
+
print(model.decode_text(tokens.id.view(-1, 1))[0], end="")
|
19 |
+
|
20 |
+
|
21 |
+
if __name__ == "__main__":
|
22 |
+
main()
|
chameleon/inference/examples/streaming_batch.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from chameleon.inference.chameleon import ChameleonInferenceModel
|
7 |
+
|
8 |
+
|
9 |
+
def main():
|
10 |
+
model = ChameleonInferenceModel(
|
11 |
+
"./data/models/7b/",
|
12 |
+
"./data/tokenizer/text_tokenizer.json",
|
13 |
+
"./data/tokenizer/vqgan.yaml",
|
14 |
+
"./data/tokenizer/vqgan.ckpt",
|
15 |
+
)
|
16 |
+
|
17 |
+
for i, batch_tokens in enumerate(
|
18 |
+
model.stream(batch_prompt_text=["All your base", "import asyncio"])
|
19 |
+
):
|
20 |
+
print(model.decode_text(batch_tokens.id.view(-1, 1)))
|
21 |
+
|
22 |
+
|
23 |
+
if __name__ == "__main__":
|
24 |
+
main()
|
chameleon/inference/generation.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from dataclasses import dataclass
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from transformers import (
|
10 |
+
LogitsProcessor,
|
11 |
+
LogitsProcessorList,
|
12 |
+
)
|
13 |
+
from transformers.generation.streamers import BaseStreamer
|
14 |
+
|
15 |
+
from chameleon.inference.alignment import AlignPromptLeft, PromptAlignment
|
16 |
+
from chameleon.inference.model_adapter import ModelAdapter
|
17 |
+
from chameleon.inference.stopping_criteria import StoppingCriteria, StoppingCriteriaList
|
18 |
+
from chameleon.inference.token_selector import MultinomialTokenSelector, TokenSelector
|
19 |
+
|
20 |
+
|
21 |
+
class ChameleonGenerator:
|
22 |
+
@dataclass
|
23 |
+
class Token:
|
24 |
+
id: torch.LongTensor
|
25 |
+
logits: torch.Tensor | None
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
model: ModelAdapter,
|
30 |
+
input_ids: list[list[int]],
|
31 |
+
stopping_criteria: StoppingCriteriaList | list[StoppingCriteria] | None = None,
|
32 |
+
logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
|
33 |
+
probability_processors: LogitsProcessorList
|
34 |
+
| list[LogitsProcessor]
|
35 |
+
| None = None,
|
36 |
+
token_selector: TokenSelector | None = None,
|
37 |
+
alignment: PromptAlignment = AlignPromptLeft(),
|
38 |
+
):
|
39 |
+
assert model.supports_alignment(alignment)
|
40 |
+
|
41 |
+
self.model = model
|
42 |
+
|
43 |
+
self.stopping_criteria = stopping_criteria
|
44 |
+
self.logits_processors = logits_processors
|
45 |
+
self.probability_processors = probability_processors
|
46 |
+
self.token_selector: TokenSelector = (
|
47 |
+
token_selector or MultinomialTokenSelector()
|
48 |
+
)
|
49 |
+
|
50 |
+
self.alignment = alignment
|
51 |
+
|
52 |
+
self.model.initialize(input_ids)
|
53 |
+
|
54 |
+
self._inputs = self.alignment.prepare_inputs(
|
55 |
+
input_ids
|
56 |
+
) # inputs.shape = [batch, seq-len]
|
57 |
+
|
58 |
+
self._idx = 0
|
59 |
+
self._start_idx = self.alignment.start_index(input_ids)
|
60 |
+
|
61 |
+
self._original_inputs = self._inputs.clone()
|
62 |
+
self._inputs = self._inputs[:, : self._start_idx]
|
63 |
+
|
64 |
+
def __iter__(self):
|
65 |
+
return self
|
66 |
+
|
67 |
+
@torch.inference_mode()
|
68 |
+
def __next__(self) -> Token:
|
69 |
+
# Are we done?
|
70 |
+
if self.stopping_criteria(self._inputs, None):
|
71 |
+
raise StopIteration
|
72 |
+
|
73 |
+
# Emit initial tokens.
|
74 |
+
# Model is not run for these.
|
75 |
+
# If you want the logits, you can do a separate forward pass outside generation.
|
76 |
+
if self._idx < self._start_idx:
|
77 |
+
idx, self._idx = self._idx, self._idx + 1
|
78 |
+
return ChameleonGenerator.Token(id=self._inputs[:, idx], logits=None)
|
79 |
+
|
80 |
+
# Run the model for the next token.
|
81 |
+
self._inputs = self._inputs.contiguous()
|
82 |
+
outputs = self.model(self._inputs) # outputs.shape = [batch, seq-len, vocab]
|
83 |
+
|
84 |
+
# Pull out and process the logits.
|
85 |
+
logits = outputs[:, -1, :] # logits.shape = [batch, vocab]
|
86 |
+
logits = self.logits_processors(self._inputs, logits)
|
87 |
+
probs = logits.softmax(dim=1) # probs.shape = [batch, vocab]
|
88 |
+
probs = self.probability_processors(self._inputs, probs)
|
89 |
+
|
90 |
+
# Select a token and add it to the inputs.
|
91 |
+
next_tokens = self.token_selector(
|
92 |
+
self._inputs, probs
|
93 |
+
) # next_tokens.shape = [batch]
|
94 |
+
self._inputs = torch.cat([self._inputs, next_tokens[:, None]], dim=1)
|
95 |
+
|
96 |
+
# Run alignment specific postprocessing.
|
97 |
+
self._inputs = self.alignment.postprocess_inputs(
|
98 |
+
self._inputs, self._original_inputs
|
99 |
+
)
|
100 |
+
|
101 |
+
# Return the next step result.
|
102 |
+
return ChameleonGenerator.Token(id=self._inputs[:, -1], logits=logits)
|
103 |
+
|
104 |
+
@property
|
105 |
+
def stopping_criteria(self) -> StoppingCriteriaList:
|
106 |
+
return self._stopping_criteria
|
107 |
+
|
108 |
+
@stopping_criteria.setter
|
109 |
+
def stopping_criteria(
|
110 |
+
self, value: StoppingCriteriaList | list[StoppingCriteria] | None
|
111 |
+
):
|
112 |
+
self._stopping_criteria = StoppingCriteriaList(value or [])
|
113 |
+
|
114 |
+
@property
|
115 |
+
def logits_processors(self) -> LogitsProcessorList:
|
116 |
+
return self._logits_processors
|
117 |
+
|
118 |
+
@logits_processors.setter
|
119 |
+
def logits_processors(
|
120 |
+
self, value: LogitsProcessorList | list[LogitsProcessor] | None
|
121 |
+
):
|
122 |
+
self._logits_processors = LogitsProcessorList(value or [])
|
123 |
+
|
124 |
+
@property
|
125 |
+
def probability_processors(self) -> LogitsProcessorList:
|
126 |
+
return self._probability_processors
|
127 |
+
|
128 |
+
@probability_processors.setter
|
129 |
+
def probability_processors(
|
130 |
+
self, value: LogitsProcessorList | list[LogitsProcessor] | None
|
131 |
+
):
|
132 |
+
self._probability_processors = LogitsProcessorList(value or [])
|
133 |
+
|
134 |
+
|
135 |
+
def run_generation(
|
136 |
+
model: torch.nn.Module,
|
137 |
+
input_ids: list[list[int]],
|
138 |
+
stopping_criteria: StoppingCriteriaList | list[StoppingCriteria],
|
139 |
+
logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
|
140 |
+
probability_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
|
141 |
+
token_selector: TokenSelector | None = None,
|
142 |
+
alignment: PromptAlignment = AlignPromptLeft(),
|
143 |
+
streamer: BaseStreamer | None = None,
|
144 |
+
) -> torch.LongTensor:
|
145 |
+
result = torch.empty((len(input_ids), 0), dtype=int)
|
146 |
+
for tok in ChameleonGenerator(
|
147 |
+
model=model,
|
148 |
+
input_ids=input_ids,
|
149 |
+
stopping_criteria=stopping_criteria,
|
150 |
+
logits_processors=logits_processors,
|
151 |
+
probability_processors=probability_processors,
|
152 |
+
token_selector=token_selector,
|
153 |
+
alignment=alignment,
|
154 |
+
):
|
155 |
+
if streamer is not None:
|
156 |
+
streamer.put(tok.id)
|
157 |
+
result = torch.cat([result, tok.id.view(-1, 1)], dim=1)
|
158 |
+
|
159 |
+
if streamer is not None:
|
160 |
+
streamer.end()
|
161 |
+
|
162 |
+
return result
|
chameleon/inference/image_tokenizer.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import PIL
|
8 |
+
import torch
|
9 |
+
import yaml
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from chameleon.inference.vqgan import VQModel
|
13 |
+
|
14 |
+
|
15 |
+
class ImageTokenizer:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
cfg_path: str,
|
19 |
+
ckpt_path: str,
|
20 |
+
device: str,
|
21 |
+
):
|
22 |
+
with open(cfg_path) as f:
|
23 |
+
config = yaml.safe_load(f)
|
24 |
+
|
25 |
+
params = config["model"]["params"]
|
26 |
+
if "lossconfig" in params:
|
27 |
+
del params["lossconfig"]
|
28 |
+
params["ckpt_path"] = ckpt_path
|
29 |
+
|
30 |
+
self._vq_model = VQModel(**params)
|
31 |
+
self._vq_model.eval()
|
32 |
+
|
33 |
+
if device is None:
|
34 |
+
devices = {p.device for p in self._vq_model.parameters()}
|
35 |
+
assert len(devices) == 1
|
36 |
+
device = devices.pop()
|
37 |
+
else:
|
38 |
+
self._vq_model.to(device)
|
39 |
+
self._device = device
|
40 |
+
|
41 |
+
dtypes = {p.dtype for p in self._vq_model.parameters()}
|
42 |
+
assert len(dtypes) == 1
|
43 |
+
self._dtype = dtypes.pop()
|
44 |
+
|
45 |
+
def _whiten_transparency(self, img: PIL.Image) -> PIL.Image:
|
46 |
+
# Check if it's already in RGB format.
|
47 |
+
if img.mode == "RGB":
|
48 |
+
return img
|
49 |
+
|
50 |
+
vals_rgba = np.array(img.convert("RGBA"))
|
51 |
+
|
52 |
+
# If there is no transparency layer, simple convert and return.
|
53 |
+
if not (vals_rgba[:, :, 3] < 255).any():
|
54 |
+
return img.convert("RGB")
|
55 |
+
|
56 |
+
# There is a transparency layer, blend it with a white background.
|
57 |
+
|
58 |
+
# Calculate the alpha proportion for blending.
|
59 |
+
alpha = vals_rgba[:, :, 3] / 255.0
|
60 |
+
# Blend with white background.
|
61 |
+
vals_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[
|
62 |
+
:, :, np.newaxis
|
63 |
+
] * vals_rgba[:, :, :3]
|
64 |
+
return PIL.Image.fromarray(vals_rgb.astype("uint8"), "RGB")
|
65 |
+
|
66 |
+
def _vqgan_input_from(self, img: PIL.Image, target_image_size=512) -> torch.Tensor:
|
67 |
+
# Resize with aspect ratio preservation.
|
68 |
+
s = min(img.size)
|
69 |
+
scale = target_image_size / s
|
70 |
+
new_size = (round(scale * img.size[0]), round(scale * img.size[1]))
|
71 |
+
img = img.resize(new_size, PIL.Image.LANCZOS)
|
72 |
+
|
73 |
+
# Center crop.
|
74 |
+
x0 = (img.width - target_image_size) // 2
|
75 |
+
y0 = (img.height - target_image_size) // 2
|
76 |
+
img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size))
|
77 |
+
|
78 |
+
# Convert to tensor.
|
79 |
+
np_img = np.array(img) / 255.0 # Normalize to [0, 1]
|
80 |
+
np_img = np_img * 2 - 1 # Scale to [-1, 1]
|
81 |
+
tensor_img = (
|
82 |
+
torch.from_numpy(np_img).permute(2, 0, 1).float()
|
83 |
+
) # (Channels, Height, Width) format.
|
84 |
+
|
85 |
+
# Add batch dimension.
|
86 |
+
return tensor_img.unsqueeze(0)
|
87 |
+
|
88 |
+
def img_tokens_from_pil(self, image: PIL.Image) -> list[int]:
|
89 |
+
image = self._whiten_transparency(image)
|
90 |
+
vqgan_input = self._vqgan_input_from(image).to(self._device).to(self._dtype)
|
91 |
+
_, _, [_, _, img_toks] = self._vq_model.encode(vqgan_input)
|
92 |
+
return img_toks
|
93 |
+
|
94 |
+
def _pil_from_chw_tensor(self, chw_tensor: torch.Tensor) -> PIL.Image:
|
95 |
+
# Ensure detachment and move tensor to CPU.
|
96 |
+
detached_chw_tensor = chw_tensor.detach().cpu()
|
97 |
+
|
98 |
+
# Normalize tensor to [0, 1] range from [-1, 1] range.
|
99 |
+
normalized_chw_tensor = (
|
100 |
+
torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0
|
101 |
+
) / 2.0
|
102 |
+
|
103 |
+
# Permute CHW tensor to HWC format and convert to NumPy array.
|
104 |
+
hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy()
|
105 |
+
|
106 |
+
# Convert to an 8-bit unsigned integer format.
|
107 |
+
image_array_uint8 = (hwc_array * 255).astype(np.uint8)
|
108 |
+
|
109 |
+
# Convert NumPy array to PIL Image.
|
110 |
+
pil_image = Image.fromarray(image_array_uint8)
|
111 |
+
|
112 |
+
# Convert image to RGB if it is not already.
|
113 |
+
if pil_image.mode != "RGB":
|
114 |
+
pil_image = pil_image.convert("RGB")
|
115 |
+
|
116 |
+
return pil_image
|
117 |
+
|
118 |
+
def pil_from_img_toks(self, img_tensor: torch.Tensor, height=32,width=32) -> PIL.Image:
|
119 |
+
emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
|
120 |
+
# import pdb;pdb.set_trace()
|
121 |
+
codebook_entry = self._vq_model.quantize.get_codebook_entry(
|
122 |
+
img_tensor, (1, height, width, emb_dim)
|
123 |
+
)
|
124 |
+
# import pdb;pdb.set_trace()
|
125 |
+
pixels = self._vq_model.decode(codebook_entry)
|
126 |
+
# import pdb;pdb.set_trace()
|
127 |
+
return self._pil_from_chw_tensor(pixels[0])
|
chameleon/inference/loader.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import glob
|
7 |
+
import inspect
|
8 |
+
import json
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from chameleon.inference.transformer import ModelArgs, Transformer
|
14 |
+
|
15 |
+
|
16 |
+
def _convert(model_args: ModelArgs, consolidated_path: Path) -> Transformer:
|
17 |
+
old_default_dtype = torch.get_default_dtype()
|
18 |
+
torch.set_default_dtype(torch.bfloat16)
|
19 |
+
|
20 |
+
model = Transformer(model_args)
|
21 |
+
|
22 |
+
transfer_results = model.load_state_dict(
|
23 |
+
torch.load(str(consolidated_path)),
|
24 |
+
strict=False,
|
25 |
+
)
|
26 |
+
|
27 |
+
# TODO: More generally, assert missing or unexpected keys are buffers.
|
28 |
+
assert transfer_results.missing_keys == []
|
29 |
+
assert transfer_results.unexpected_keys == ["rope.freqs"]
|
30 |
+
|
31 |
+
model.eval()
|
32 |
+
|
33 |
+
torch.set_default_dtype(old_default_dtype)
|
34 |
+
return model
|
35 |
+
|
36 |
+
|
37 |
+
def _get_checkpoint_path(src_dir: Path, rank: int | None) -> Path:
|
38 |
+
base_path = src_dir / "consolidated.pth"
|
39 |
+
if not rank and base_path.exists():
|
40 |
+
return base_path
|
41 |
+
|
42 |
+
alt_path = src_dir / f"consolidated.{rank:02}.pth"
|
43 |
+
if alt_path.exists():
|
44 |
+
return alt_path
|
45 |
+
|
46 |
+
raise ValueError("Consolidated checkpoint not found.")
|
47 |
+
|
48 |
+
|
49 |
+
def load_model(path: str, rank: int | None = None) -> Transformer:
|
50 |
+
src_dir = Path(path)
|
51 |
+
|
52 |
+
with open(src_dir / "params.json", "r") as f:
|
53 |
+
params = json.loads(f.read())
|
54 |
+
with open(src_dir / "consolidate_params.json", "r") as f:
|
55 |
+
consolidate_params = json.loads(f.read())
|
56 |
+
params = {**params, **params["model"], **consolidate_params}
|
57 |
+
|
58 |
+
known_param = inspect.signature(ModelArgs.__init__).parameters
|
59 |
+
filtered_params = {k: v for k, v in params.items() if k in known_param}
|
60 |
+
|
61 |
+
return _convert(
|
62 |
+
ModelArgs(**filtered_params),
|
63 |
+
_get_checkpoint_path(src_dir, rank),
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
+
def detect_shard_count(path: str) -> int:
|
68 |
+
src_dir = Path(path)
|
69 |
+
if (src_dir / "consolidated.pth").exists():
|
70 |
+
return 1
|
71 |
+
return len(glob.glob(str(src_dir / "consolidated.*.pth")))
|
chameleon/inference/logits_processor.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from transformers import LogitsProcessor
|
10 |
+
|
11 |
+
|
12 |
+
class TopPProbabilityProcessor(LogitsProcessor):
|
13 |
+
# Modified version of TopPLogitsWarper to act on probabilities.
|
14 |
+
# Changes:
|
15 |
+
# * filter_value changed from -inf to 0
|
16 |
+
# * removed softmax
|
17 |
+
# * renormalize L1
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
top_p: float,
|
22 |
+
min_tokens_to_keep: int = 1,
|
23 |
+
):
|
24 |
+
top_p = float(top_p)
|
25 |
+
if top_p < 0 or top_p > 1.0:
|
26 |
+
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
27 |
+
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
|
28 |
+
raise ValueError(
|
29 |
+
f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
|
30 |
+
)
|
31 |
+
|
32 |
+
self.top_p = top_p
|
33 |
+
self.min_tokens_to_keep = min_tokens_to_keep
|
34 |
+
|
35 |
+
def __call__(
|
36 |
+
self, input_ids: torch.LongTensor, probs: torch.FloatTensor
|
37 |
+
) -> torch.FloatTensor:
|
38 |
+
# input_ids.shape=[batch, seq-len]
|
39 |
+
# probs.shape=[batch, vocab]
|
40 |
+
sorted_probs, sorted_indices = torch.sort(probs, descending=False)
|
41 |
+
cumulative_probs = sorted_probs.cumsum(dim=-1)
|
42 |
+
|
43 |
+
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
44 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
|
45 |
+
# Keep at least min_tokens_to_keep
|
46 |
+
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
47 |
+
|
48 |
+
# scatter sorted tensors to original indexing
|
49 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
50 |
+
1, sorted_indices, sorted_indices_to_remove
|
51 |
+
)
|
52 |
+
probs = probs.masked_fill(indices_to_remove, 0.0)
|
53 |
+
probs = probs / probs.sum(dim=-1, keepdim=True)
|
54 |
+
return probs
|
55 |
+
|
56 |
+
|
57 |
+
class DisallowTokensInIndexRangeLogitsProcessor(LogitsProcessor):
|
58 |
+
def __init__(
|
59 |
+
self, token_ids: list[int], start_index: int, end_index: int | None = None
|
60 |
+
):
|
61 |
+
self.token_ids = torch.tensor(token_ids)
|
62 |
+
self.start_index = start_index
|
63 |
+
self.end_index = end_index if end_index is not None else math.inf
|
64 |
+
|
65 |
+
def __call__(
|
66 |
+
self, input_ids: torch.LongTensor, logits: torch.FloatTensor
|
67 |
+
) -> torch.FloatTensor:
|
68 |
+
current_index = input_ids.shape[1]
|
69 |
+
if self.start_index <= current_index < self.end_index:
|
70 |
+
logits[:, self.token_ids] = -math.inf
|
71 |
+
return logits
|
72 |
+
|
73 |
+
|
74 |
+
class DisallowTokensLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
|
75 |
+
def __init__(self, token_ids: list[int]):
|
76 |
+
super().__init__(token_ids, 0)
|
77 |
+
|
78 |
+
|
79 |
+
class DisallowTokensAtIndexLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
|
80 |
+
def __init__(self, token_ids: list[int], index: int):
|
81 |
+
super().__init__(token_ids, index, index + 1)
|
82 |
+
|
83 |
+
|
84 |
+
class DisallowTokensAfterIndexLogitsProcessor(
|
85 |
+
DisallowTokensInIndexRangeLogitsProcessor
|
86 |
+
):
|
87 |
+
def __init__(self, token_ids: list[int], index: int):
|
88 |
+
super().__init__(token_ids, index + 1)
|
89 |
+
|
90 |
+
|
91 |
+
class DisallowTokensAtOrAfterIndexLogitsProcessor(
|
92 |
+
DisallowTokensInIndexRangeLogitsProcessor
|
93 |
+
):
|
94 |
+
def __init__(self, token_ids: list[int], index: int):
|
95 |
+
super().__init__(token_ids, index)
|
96 |
+
|
97 |
+
|
98 |
+
class DisallowTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
token_ids: list[int],
|
102 |
+
start_indices: list[int],
|
103 |
+
end_indices: list[int] | None = None,
|
104 |
+
):
|
105 |
+
self.token_ids = torch.tensor(token_ids)
|
106 |
+
self.start_indices = torch.tensor(start_indices)
|
107 |
+
self.end_indices = (
|
108 |
+
torch.tensor(end_indices)
|
109 |
+
if end_indices is not None
|
110 |
+
else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
|
111 |
+
)
|
112 |
+
|
113 |
+
def __call__(
|
114 |
+
self, input_ids: torch.LongTensor, logits: torch.FloatTensor
|
115 |
+
) -> torch.FloatTensor:
|
116 |
+
# input_ids.shape = [batch, seq_len]
|
117 |
+
# logits.shape = [batch, vocab]
|
118 |
+
current_index = input_ids.shape[1]
|
119 |
+
mask = (self.start_indices <= current_index) & (
|
120 |
+
current_index < self.end_indices
|
121 |
+
)
|
122 |
+
# The following will fail if the mask is all False.
|
123 |
+
# logits[mask, self.token_ids] = -math.inf
|
124 |
+
logits[torch.where(mask)[0].unsqueeze(1), self.token_ids] = -math.inf
|
125 |
+
return logits
|
126 |
+
|
127 |
+
|
128 |
+
class DisallowTokensAtBatchIndexLogitsProcessor(
|
129 |
+
DisallowTokensInBatchIndexRangeLogitsProcessor
|
130 |
+
):
|
131 |
+
def __init__(self, token_ids: list[int], batch_index: list[int]):
|
132 |
+
super().__init__(token_ids, batch_index, [i + 1 for i in batch_index])
|
133 |
+
|
134 |
+
|
135 |
+
class AllowOnlyTokensInIndexRangeLogitsProcessor(LogitsProcessor):
|
136 |
+
def __init__(
|
137 |
+
self, token_ids: list[int], start_index: int, end_index: int | None = None
|
138 |
+
):
|
139 |
+
self.token_ids = torch.tensor(token_ids)
|
140 |
+
self.start_index = start_index
|
141 |
+
self.end_index = end_index if end_index is not None else math.inf
|
142 |
+
|
143 |
+
def __call__(
|
144 |
+
self, input_ids: torch.LongTensor, logits: torch.FloatTensor
|
145 |
+
) -> torch.FloatTensor:
|
146 |
+
current_index = input_ids.shape[1]
|
147 |
+
if self.start_index <= current_index < self.end_index:
|
148 |
+
replacement = torch.full_like(logits, -math.inf)
|
149 |
+
replacement[:, self.token_ids] = logits[:, self.token_ids]
|
150 |
+
logits[:] = replacement
|
151 |
+
return logits
|
152 |
+
|
153 |
+
|
154 |
+
class AllowOnlyTokensLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
|
155 |
+
def __init__(self, token_ids: list[int]):
|
156 |
+
super().__init__(token_ids, 0)
|
157 |
+
|
158 |
+
|
159 |
+
class AllowOnlyTokensAtIndexLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
|
160 |
+
def __init__(self, token_ids: list[int], index: int):
|
161 |
+
super().__init__(token_ids, index, index + 1)
|
162 |
+
|
163 |
+
|
164 |
+
class AllowOnlyTokensAfterIndexLogitsProcessor(
|
165 |
+
AllowOnlyTokensInIndexRangeLogitsProcessor
|
166 |
+
):
|
167 |
+
def __init__(self, token_ids: list[int], index: int):
|
168 |
+
super().__init__(token_ids, index + 1)
|
169 |
+
|
170 |
+
|
171 |
+
class AllowOnlyTokensAtOrAfterIndexLogitsProcessor(
|
172 |
+
AllowOnlyTokensInIndexRangeLogitsProcessor
|
173 |
+
):
|
174 |
+
def __init__(self, token_ids: list[int], index: int):
|
175 |
+
super().__init__(token_ids, index)
|
176 |
+
|
177 |
+
|
178 |
+
class AllowOnlyTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
|
179 |
+
def __init__(
|
180 |
+
self,
|
181 |
+
token_ids: list[int],
|
182 |
+
start_indices: list[int],
|
183 |
+
end_indices: list[int] | None = None,
|
184 |
+
):
|
185 |
+
self.token_ids = torch.tensor(token_ids)
|
186 |
+
self.start_indices = torch.tensor(start_indices)
|
187 |
+
self.end_indices = (
|
188 |
+
torch.tensor(end_indices)
|
189 |
+
if end_indices is not None
|
190 |
+
else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
|
191 |
+
)
|
192 |
+
|
193 |
+
def __call__(
|
194 |
+
self, input_ids: torch.LongTensor, logits: torch.FloatTensor
|
195 |
+
) -> torch.FloatTensor:
|
196 |
+
# input_ids.shape = [batch, seq_len]
|
197 |
+
# logits.shape = [batch, vocab]
|
198 |
+
current_index = input_ids.shape[1]
|
199 |
+
mask = (self.start_indices <= current_index) & (
|
200 |
+
current_index < self.end_indices
|
201 |
+
)
|
202 |
+
|
203 |
+
valid_batch_indices = torch.where(mask)[0].unsqueeze(1)
|
204 |
+
full_mask = torch.full_like(logits, -math.inf)
|
205 |
+
full_mask[valid_batch_indices, self.token_ids] = logits[
|
206 |
+
valid_batch_indices, self.token_ids
|
207 |
+
]
|
208 |
+
|
209 |
+
logits[:] = torch.where(full_mask != -math.inf, full_mask, logits)
|
210 |
+
return logits
|
211 |
+
|
212 |
+
|
213 |
+
class AllowOnlyTokensAtRelativeOffsetLogitsProcessor(LogitsProcessor):
|
214 |
+
def __init__(
|
215 |
+
self, trigger_token_id: int, subsequent_token_ids: list[int], offset: int
|
216 |
+
):
|
217 |
+
self.trigger_token_id = trigger_token_id
|
218 |
+
self.subsequent_token_ids = torch.tensor(subsequent_token_ids)
|
219 |
+
self.offset = offset
|
220 |
+
|
221 |
+
def __call__(
|
222 |
+
self, input_ids: torch.LongTensor, logits: torch.FloatTensor
|
223 |
+
) -> torch.FloatTensor:
|
224 |
+
# input_ids.shape=[batch, seq_len]
|
225 |
+
# logits.shape=[batch, vocab]
|
226 |
+
if input_ids.shape[1] < self.offset:
|
227 |
+
return logits
|
228 |
+
|
229 |
+
trigger_positions = (
|
230 |
+
input_ids[:, -self.offset] == self.trigger_token_id
|
231 |
+
).unsqueeze(-1)
|
232 |
+
|
233 |
+
disallowed_tokens_mask = torch.ones_like(logits, dtype=bool)
|
234 |
+
disallowed_tokens_mask[:, self.subsequent_token_ids] = False
|
235 |
+
|
236 |
+
return logits.masked_fill_(
|
237 |
+
disallowed_tokens_mask & trigger_positions,
|
238 |
+
-math.inf,
|
239 |
+
)
|
240 |
+
|
241 |
+
|
242 |
+
class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor):
|
243 |
+
def __init__(self, trigger_token_id: int, allowed_token_ids: list[int], width: int):
|
244 |
+
self.trigger_token_id = trigger_token_id
|
245 |
+
self.allowed_token_ids = torch.tensor(allowed_token_ids).unsqueeze(
|
246 |
+
0
|
247 |
+
) # shape: [1, num_allowed_tokens]
|
248 |
+
self.width = width
|
249 |
+
|
250 |
+
def __call__(
|
251 |
+
self, input_ids: torch.LongTensor, logits: torch.FloatTensor
|
252 |
+
) -> torch.FloatTensor:
|
253 |
+
# input_ids.shape=[batch, seq_len]
|
254 |
+
# logits.shape=[batch, vocab]
|
255 |
+
width = min(self.width, input_ids.shape[1])
|
256 |
+
trigger_positions = (
|
257 |
+
(input_ids[:, -width:] == self.trigger_token_id).any(dim=1).unsqueeze(-1)
|
258 |
+
)
|
259 |
+
|
260 |
+
disallowed_tokens_mask = torch.ones_like(logits, dtype=bool)
|
261 |
+
disallowed_tokens_mask[:, self.allowed_token_ids] = False
|
262 |
+
|
263 |
+
return logits.masked_fill_(
|
264 |
+
disallowed_tokens_mask & trigger_positions,
|
265 |
+
-math.inf,
|
266 |
+
)
|
267 |
+
|
268 |
+
|
269 |
+
class CFGLogitsProcessor(LogitsProcessor):
|
270 |
+
def __init__(
|
271 |
+
self,
|
272 |
+
guidance_scale: float,
|
273 |
+
unconditional_ids: torch.LongTensor,
|
274 |
+
model,
|
275 |
+
):
|
276 |
+
self.guidance_scale = guidance_scale
|
277 |
+
self.unconditional_ids = unconditional_ids
|
278 |
+
self.model = model
|
279 |
+
|
280 |
+
def __call__(
|
281 |
+
self, input_ids: torch.LongTensor, logits: torch.FloatTensor
|
282 |
+
) -> torch.FloatTensor:
|
283 |
+
conditioned_logits = logits
|
284 |
+
|
285 |
+
self.unconditional_ids = torch.cat(
|
286 |
+
[self.unconditional_ids, input_ids[:, -1:]], dim=1
|
287 |
+
)
|
288 |
+
unconditioned_outputs = self.model(self.unconditional_ids)
|
289 |
+
unconditioned_logits = unconditioned_outputs[:, -1, :]
|
290 |
+
return (
|
291 |
+
self.guidance_scale * (conditioned_logits - unconditioned_logits)
|
292 |
+
+ unconditioned_logits
|
293 |
+
)
|
294 |
+
|
295 |
+
|
296 |
+
class InBatchCFGLogitsProcessor(LogitsProcessor):
|
297 |
+
def __init__(self, guidance_scale: float):
|
298 |
+
self.guidance_scale = guidance_scale
|
299 |
+
|
300 |
+
def __call__(
|
301 |
+
self, input_ids: torch.LongTensor, logits: torch.FloatTensor
|
302 |
+
) -> torch.FloatTensor:
|
303 |
+
# input_ids.shape=[2*batch, seq-len]
|
304 |
+
# logits.shape=[2*batch, vocab]
|
305 |
+
conditioned_logits, unconditioned_logits = torch.chunk(logits, chunks=2, dim=0)
|
306 |
+
mixed_logits = unconditioned_logits + self.guidance_scale * (
|
307 |
+
conditioned_logits - unconditioned_logits
|
308 |
+
)
|
309 |
+
return mixed_logits.repeat(2, 1)
|
310 |
+
|
311 |
+
|
312 |
+
class InBatchInstructCFGLogitsProcessor(LogitsProcessor):
|
313 |
+
# See https://arxiv.org/abs/2211.09800
|
314 |
+
|
315 |
+
def __init__(self, guidance_scale_text: float, guidance_scale_image: float):
|
316 |
+
self.guidance_scale_text = guidance_scale_text
|
317 |
+
self.guidance_scale_image = guidance_scale_image
|
318 |
+
|
319 |
+
def __call__(
|
320 |
+
self, input_ids: torch.LongTensor, logits: torch.FloatTensor
|
321 |
+
) -> torch.FloatTensor:
|
322 |
+
# input_ids.shape=[3*batch, seq-len]
|
323 |
+
# logits.shape=[3*batch, vocab]
|
324 |
+
(
|
325 |
+
full_conditioned_logits,
|
326 |
+
image_conditioned_logits,
|
327 |
+
unconditioned_logits,
|
328 |
+
) = logits.chunk(3)
|
329 |
+
mixed_logits = (
|
330 |
+
unconditioned_logits
|
331 |
+
+ self.guidance_scale_image
|
332 |
+
* (image_conditioned_logits - unconditioned_logits)
|
333 |
+
+ self.guidance_scale_text
|
334 |
+
* (full_conditioned_logits - image_conditioned_logits)
|
335 |
+
)
|
336 |
+
return mixed_logits.repeat(3, 1)
|
chameleon/inference/model_adapter.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from chameleon.inference import transformer
|
12 |
+
from chameleon.inference.alignment import (
|
13 |
+
AlignPromptLeft,
|
14 |
+
AlignPromptRight,
|
15 |
+
PromptAlignment,
|
16 |
+
)
|
17 |
+
from chameleon.inference.cudagraph import cudagraph_wrap
|
18 |
+
|
19 |
+
|
20 |
+
class ModelAdapter(ABC):
|
21 |
+
@abstractmethod
|
22 |
+
def initialize(self, prompt_tokens: list[list[int]]):
|
23 |
+
...
|
24 |
+
|
25 |
+
@abstractmethod
|
26 |
+
def supports_alignment(self, alignment: PromptAlignment) -> bool:
|
27 |
+
...
|
28 |
+
|
29 |
+
@abstractmethod
|
30 |
+
@torch.inference_mode()
|
31 |
+
def __call__(self, inputs: torch.LongTensor) -> torch.FloatTensor:
|
32 |
+
...
|
33 |
+
|
34 |
+
|
35 |
+
class ChameleonModelAdapter(ModelAdapter):
|
36 |
+
"""Adapter for Chameleon-style model that handles state, such as cache."""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
model: transformer.Transformer,
|
41 |
+
max_seq_len: int,
|
42 |
+
dtype: torch.dtype | None = None,
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
self._args = model.args
|
46 |
+
self._model = model
|
47 |
+
self._max_seq_len = max_seq_len
|
48 |
+
self._dtype = dtype or next(model.parameters()).data.dtype
|
49 |
+
|
50 |
+
def initialize(self, prompt_tokens: list[list[int]]):
|
51 |
+
self._prompt_lengths = [len(toks) for toks in prompt_tokens]
|
52 |
+
batch_size = len(prompt_tokens)
|
53 |
+
|
54 |
+
self._cache = transformer.make_cache(
|
55 |
+
args=self._args,
|
56 |
+
length=batch_size * self._max_seq_len,
|
57 |
+
dtype=self._dtype,
|
58 |
+
)
|
59 |
+
|
60 |
+
self._local_inputs = torch.zeros([batch_size], dtype=int, device="cuda")
|
61 |
+
|
62 |
+
self._forward = cudagraph_wrap(self._model.forward_with_attn_bias)
|
63 |
+
|
64 |
+
self._first_pass = True
|
65 |
+
|
66 |
+
def supports_alignment(self, alignment: PromptAlignment) -> bool:
|
67 |
+
return isinstance(alignment, AlignPromptLeft) or isinstance(
|
68 |
+
alignment, AlignPromptRight
|
69 |
+
)
|
70 |
+
|
71 |
+
def __call__(self, inputs: torch.LongTensor) -> torch.FloatTensor:
|
72 |
+
# inputs.shape=[batch, seq-len]
|
73 |
+
batch_size, seq_len = inputs.shape
|
74 |
+
|
75 |
+
if self._first_pass:
|
76 |
+
attn_seqlen = [min(pl, seq_len) for pl in self._prompt_lengths]
|
77 |
+
self._bias = transformer.AttnBias.from_seqlens(
|
78 |
+
q_seqlen=attn_seqlen,
|
79 |
+
kv_seqlen=attn_seqlen,
|
80 |
+
kv_padding=self._max_seq_len,
|
81 |
+
)
|
82 |
+
|
83 |
+
mask = torch.zeros_like(inputs, dtype=torch.bool)
|
84 |
+
for i, k in enumerate(self._prompt_lengths):
|
85 |
+
mask[i, -k:] = True
|
86 |
+
|
87 |
+
flat_outputs: torch.Tensor = self._forward( # type: ignore
|
88 |
+
token_values=inputs[mask],
|
89 |
+
attn_bias=self._bias,
|
90 |
+
cache=self._cache,
|
91 |
+
)
|
92 |
+
self._local_outputs = torch.full(
|
93 |
+
(inputs.shape[0], inputs.shape[1], flat_outputs.shape[-1]),
|
94 |
+
-math.inf,
|
95 |
+
)
|
96 |
+
self._local_outputs[mask] = flat_outputs
|
97 |
+
|
98 |
+
self._vocab_size = self._local_outputs.shape[-1]
|
99 |
+
|
100 |
+
self._bias.q_seqinfo.seqstart.copy_(
|
101 |
+
torch.arange(batch_size + 1, dtype=torch.int)
|
102 |
+
)
|
103 |
+
self._bias.q_seqinfo.max_seqlen = 1
|
104 |
+
self._bias.q_seqinfo.seqstart_py = self._bias.q_seqinfo.seqstart.tolist()
|
105 |
+
|
106 |
+
self._first_pass = False
|
107 |
+
|
108 |
+
else:
|
109 |
+
self._local_inputs.copy_(inputs[:, -1]) # type: ignore
|
110 |
+
|
111 |
+
self._local_outputs = self._forward( # type: ignore
|
112 |
+
token_values=self._local_inputs,
|
113 |
+
attn_bias=self._bias,
|
114 |
+
cache=self._cache,
|
115 |
+
)
|
116 |
+
|
117 |
+
self._bias.k_seqinfo.seqlen.add_(1)
|
118 |
+
return self._local_outputs.view(batch_size, -1, self._vocab_size)
|
chameleon/inference/stopping_criteria.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class StoppingCriteria:
|
10 |
+
def __call__(
|
11 |
+
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
12 |
+
) -> bool:
|
13 |
+
raise NotImplementedError("StoppingCriteria needs to be subclassed")
|
14 |
+
|
15 |
+
|
16 |
+
class StoppingCriteriaList(list):
|
17 |
+
def __call__(
|
18 |
+
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
19 |
+
) -> bool:
|
20 |
+
return any(criteria(input_ids, scores, **kwargs) for criteria in self)
|
21 |
+
|
22 |
+
|
23 |
+
class MaxLengthCriteria(StoppingCriteria):
|
24 |
+
def __init__(self, max_length: int):
|
25 |
+
self.max_length = max_length
|
26 |
+
|
27 |
+
def __call__(
|
28 |
+
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
29 |
+
) -> bool:
|
30 |
+
cur_len = input_ids.shape[-1]
|
31 |
+
return cur_len >= self.max_length
|
32 |
+
|
33 |
+
|
34 |
+
class StopOnEOS(StoppingCriteria):
|
35 |
+
def __init__(self, eos_id: int):
|
36 |
+
self._eos_id = eos_id
|
37 |
+
|
38 |
+
def __call__(self, input_ids: torch.LongTensor, _: torch.FloatTensor) -> bool:
|
39 |
+
# input_ids.shape=[batch, seq_len]
|
40 |
+
return (input_ids == self._eos_id).sum(dim=1).all()
|
41 |
+
|
42 |
+
|
43 |
+
class StopOnEOSAfterBatchIndex(StoppingCriteria):
|
44 |
+
def __init__(self, eos_id: int, batch_index: list[int]):
|
45 |
+
self._eos_id = eos_id
|
46 |
+
self.batch_index = torch.tensor(batch_index, dtype=torch.long).unsqueeze(1)
|
47 |
+
|
48 |
+
def __call__(self, input_ids: torch.LongTensor, _: torch.FloatTensor) -> bool:
|
49 |
+
# input_ids.shape=[batch, seq_len]
|
50 |
+
eos_mask = input_ids == self._eos_id
|
51 |
+
consider_eos_mask = (
|
52 |
+
torch.arange(input_ids.shape[1]).unsqueeze(0) >= self.batch_index
|
53 |
+
)
|
54 |
+
valid_eos = eos_mask & consider_eos_mask
|
55 |
+
return valid_eos.sum(dim=1).all()
|
chameleon/inference/token_selector.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class TokenSelector:
|
10 |
+
def __call__(
|
11 |
+
self, input_ids: torch.LongTensor, probs: torch.FloatTensor
|
12 |
+
) -> torch.FloatTensor:
|
13 |
+
# input_ids.shape=[batch, seq_len]
|
14 |
+
# probs.shape=[batch, vocab]
|
15 |
+
...
|
16 |
+
|
17 |
+
|
18 |
+
class ArgmaxTokenSelector(TokenSelector):
|
19 |
+
def __call__(
|
20 |
+
self, _: torch.LongTensor, probs: torch.FloatTensor
|
21 |
+
) -> torch.LongTensor:
|
22 |
+
# probs.shape=[batch, vocab]
|
23 |
+
return probs.argmax(dim=1)
|
24 |
+
|
25 |
+
|
26 |
+
class MultinomialTokenSelector(TokenSelector):
|
27 |
+
def __call__(
|
28 |
+
self, _: torch.LongTensor, probs: torch.FloatTensor
|
29 |
+
) -> torch.LongTensor:
|
30 |
+
# probs.shape=[batch, vocab]
|
31 |
+
return probs.multinomial(num_samples=1).squeeze(1)
|
32 |
+
|
33 |
+
|
34 |
+
class ReplicatedInputTokenSelector(TokenSelector):
|
35 |
+
def __init__(self, token_selector: TokenSelector, n: int):
|
36 |
+
self.token_selector = token_selector
|
37 |
+
self.n = n
|
38 |
+
|
39 |
+
def __call__(
|
40 |
+
self, input_ids: torch.LongTensor, probs: torch.FloatTensor
|
41 |
+
) -> torch.LongTensor:
|
42 |
+
# input_ids.shape=[n*batch, seq_len]
|
43 |
+
# probs.shape=[n*batch, vocab]
|
44 |
+
primary_input_ids = torch.chunk(input_ids, chunks=self.n, dim=0)[0]
|
45 |
+
primary_probs = torch.chunk(probs, chunks=self.n, dim=0)[0]
|
46 |
+
tokens = self.token_selector(primary_input_ids, primary_probs)
|
47 |
+
return tokens.repeat(self.n)
|
chameleon/inference/transformer.py
ADDED
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from dataclasses import dataclass
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import distributed as dist
|
10 |
+
from torch import nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from xformers.ops import RMSNorm, fmha, rope_padded
|
13 |
+
from xformers.ops.fmha.attn_bias import (
|
14 |
+
BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
@dataclass
|
19 |
+
class ModelArgs:
|
20 |
+
model_parallel_size: int = 1
|
21 |
+
dim: int = 512
|
22 |
+
n_layers: int = 8
|
23 |
+
n_heads: int = 8
|
24 |
+
n_kv_heads: int | None = None
|
25 |
+
vocab_size: int = -1
|
26 |
+
ffn_dim_multiplier: float | None = None
|
27 |
+
multiple_of: int = 256
|
28 |
+
norm_eps: float = 1e-5
|
29 |
+
rope_theta: float = 10000.0
|
30 |
+
qk_normalization: bool = False
|
31 |
+
swin_norm: bool = False
|
32 |
+
|
33 |
+
|
34 |
+
LayerCache = tuple[torch.Tensor, torch.Tensor]
|
35 |
+
|
36 |
+
|
37 |
+
class Attention(nn.Module):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
model_parallel_size: int,
|
41 |
+
dim: int,
|
42 |
+
head_dim: int,
|
43 |
+
n_heads: int,
|
44 |
+
n_kv_heads: int,
|
45 |
+
rope_theta: float,
|
46 |
+
qk_normalization: bool = False,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
|
50 |
+
self.model_parallel_size = model_parallel_size
|
51 |
+
|
52 |
+
self.head_dim = head_dim
|
53 |
+
self.rope_theta = rope_theta
|
54 |
+
|
55 |
+
self.n_local_heads = n_heads // model_parallel_size
|
56 |
+
self.n_local_kv_heads = n_kv_heads // model_parallel_size
|
57 |
+
|
58 |
+
self.wqkv = nn.Linear(
|
59 |
+
dim,
|
60 |
+
(self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
|
61 |
+
bias=False,
|
62 |
+
dtype=torch.bfloat16,
|
63 |
+
)
|
64 |
+
self.wo = nn.Linear(
|
65 |
+
self.n_local_heads * head_dim,
|
66 |
+
dim,
|
67 |
+
bias=False,
|
68 |
+
dtype=torch.bfloat16,
|
69 |
+
)
|
70 |
+
|
71 |
+
self.qk_normalization = qk_normalization
|
72 |
+
if qk_normalization:
|
73 |
+
self.q_normalization = torch.nn.LayerNorm(head_dim)
|
74 |
+
self.k_normalization = torch.nn.LayerNorm(head_dim)
|
75 |
+
|
76 |
+
self._register_load_state_dict_pre_hook(self.load_hook)
|
77 |
+
|
78 |
+
# This adapter makes sure we can load vanilla
|
79 |
+
# Llama checkpoints where wq, wk, and wv are
|
80 |
+
# not fused in a single parameter
|
81 |
+
def load_hook(
|
82 |
+
self,
|
83 |
+
state_dict,
|
84 |
+
prefix,
|
85 |
+
local_metadata,
|
86 |
+
strict,
|
87 |
+
missing_keys,
|
88 |
+
unexpected_keys,
|
89 |
+
error_msgs,
|
90 |
+
):
|
91 |
+
if prefix + "wq.weight" in state_dict:
|
92 |
+
wq = state_dict.pop(prefix + "wq.weight")
|
93 |
+
wk = state_dict.pop(prefix + "wk.weight")
|
94 |
+
wv = state_dict.pop(prefix + "wv.weight")
|
95 |
+
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
|
96 |
+
|
97 |
+
def forward(
|
98 |
+
self,
|
99 |
+
x: torch.Tensor,
|
100 |
+
cache: LayerCache,
|
101 |
+
attn_bias: AttnBias,
|
102 |
+
group: dist.ProcessGroup | None = None,
|
103 |
+
) -> torch.Tensor:
|
104 |
+
# x.shape is (sum(seq_lens), dim)
|
105 |
+
#
|
106 |
+
# Since we support heterogenous sequence
|
107 |
+
# lengths, the hidden states are all
|
108 |
+
# concatenated together along the usual
|
109 |
+
# sequence dimension. The attention below
|
110 |
+
# finds out where sequences start & end
|
111 |
+
# using the provided attention bias.
|
112 |
+
xqkv = self.wqkv(x)
|
113 |
+
xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
|
114 |
+
xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
|
115 |
+
xk, xv = xkv.chunk(2, 1)
|
116 |
+
|
117 |
+
if self.qk_normalization:
|
118 |
+
xq = xq.view(-1, self.n_local_heads, self.head_dim)
|
119 |
+
xq = self.q_normalization(xq)
|
120 |
+
xq = xq.view(-1, self.n_local_heads * self.head_dim)
|
121 |
+
|
122 |
+
xk = xk.view(-1, self.n_local_kv_heads, self.head_dim)
|
123 |
+
xk = self.k_normalization(xk)
|
124 |
+
xk = xk.view(-1, self.n_local_kv_heads * self.head_dim)
|
125 |
+
|
126 |
+
output_shape = xq.shape
|
127 |
+
xq = xq.view(1, xq.shape[0], self.n_local_heads, self.head_dim)
|
128 |
+
xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, self.head_dim)
|
129 |
+
xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, self.head_dim)
|
130 |
+
cache_k, cache_v = cache
|
131 |
+
|
132 |
+
xq = rope_padded(
|
133 |
+
xq=xq,
|
134 |
+
xk=xk,
|
135 |
+
xv=xv,
|
136 |
+
cache_k=cache_k,
|
137 |
+
cache_v=cache_v,
|
138 |
+
attn_bias=attn_bias,
|
139 |
+
theta=self.rope_theta,
|
140 |
+
)
|
141 |
+
|
142 |
+
# Handle GQA
|
143 |
+
# Q shape: [B, M, Hkv, Hq // Hkv, K]
|
144 |
+
heads_per_group = self.n_local_heads // self.n_local_kv_heads
|
145 |
+
cache_k = cache_k.unsqueeze(3).expand(-1, -1, -1, heads_per_group, -1)
|
146 |
+
cache_v = cache_v.unsqueeze(3).expand(-1, -1, -1, heads_per_group, -1)
|
147 |
+
xq = xq.reshape(
|
148 |
+
[*xq.shape[:2], self.n_local_kv_heads, heads_per_group, xq.shape[-1]]
|
149 |
+
)
|
150 |
+
|
151 |
+
# rope_padded() updated the caches, so we
|
152 |
+
# call attention directly
|
153 |
+
output = fmha.memory_efficient_attention_forward(
|
154 |
+
xq, cache_k, cache_v, attn_bias
|
155 |
+
)
|
156 |
+
|
157 |
+
output = self.wo(output.reshape(output_shape))
|
158 |
+
if self.model_parallel_size > 1:
|
159 |
+
dist.all_reduce(output, group=group)
|
160 |
+
|
161 |
+
return output
|
162 |
+
|
163 |
+
|
164 |
+
class FeedForward(nn.Module):
|
165 |
+
def __init__(
|
166 |
+
self,
|
167 |
+
model_parallel_size: int,
|
168 |
+
dim: int,
|
169 |
+
hidden_dim: int,
|
170 |
+
multiple_of: int,
|
171 |
+
ffn_dim_multiplier: float | None,
|
172 |
+
):
|
173 |
+
super().__init__()
|
174 |
+
|
175 |
+
self.model_parallel_size = model_parallel_size
|
176 |
+
|
177 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
178 |
+
if ffn_dim_multiplier is not None:
|
179 |
+
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
180 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
181 |
+
assert hidden_dim % model_parallel_size == 0
|
182 |
+
|
183 |
+
self.w13 = nn.Linear(
|
184 |
+
dim,
|
185 |
+
2 * hidden_dim // model_parallel_size,
|
186 |
+
bias=False,
|
187 |
+
)
|
188 |
+
self.w2 = nn.Linear(
|
189 |
+
hidden_dim // model_parallel_size,
|
190 |
+
dim,
|
191 |
+
bias=False,
|
192 |
+
)
|
193 |
+
self._register_load_state_dict_pre_hook(self.load_hook)
|
194 |
+
|
195 |
+
# This adapter makes sure we can load vanilla
|
196 |
+
# Llama checkpoints where w1 and w3 are not
|
197 |
+
# fused in a single parameter
|
198 |
+
def load_hook(
|
199 |
+
self,
|
200 |
+
state_dict,
|
201 |
+
prefix,
|
202 |
+
local_metadata,
|
203 |
+
strict,
|
204 |
+
missing_keys,
|
205 |
+
unexpected_keys,
|
206 |
+
error_msgs,
|
207 |
+
):
|
208 |
+
if prefix + "w1.weight" in state_dict:
|
209 |
+
w1 = state_dict.pop(prefix + "w1.weight")
|
210 |
+
w3 = state_dict.pop(prefix + "w3.weight")
|
211 |
+
state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])
|
212 |
+
|
213 |
+
def forward(
|
214 |
+
self, x: torch.Tensor, group: dist.ProcessGroup | None = None
|
215 |
+
) -> torch.Tensor:
|
216 |
+
x13 = self.w13(x)
|
217 |
+
x1, x3 = x13.chunk(2, -1)
|
218 |
+
output = self.w2(F.silu(x1) * x3)
|
219 |
+
if self.model_parallel_size > 1:
|
220 |
+
dist.all_reduce(output, group=group)
|
221 |
+
return output
|
222 |
+
|
223 |
+
|
224 |
+
class TransformerBlock(nn.Module):
|
225 |
+
def __init__(self, args: ModelArgs):
|
226 |
+
super().__init__()
|
227 |
+
|
228 |
+
assert args.dim % args.n_heads == 0
|
229 |
+
head_dim = args.dim // args.n_heads
|
230 |
+
if args.n_kv_heads is not None:
|
231 |
+
n_kv_heads = args.n_kv_heads
|
232 |
+
else:
|
233 |
+
n_kv_heads = args.n_heads
|
234 |
+
|
235 |
+
model_parallel_size = args.model_parallel_size
|
236 |
+
assert args.n_heads % n_kv_heads == 0
|
237 |
+
assert args.n_heads % model_parallel_size == 0
|
238 |
+
assert n_kv_heads % model_parallel_size == 0
|
239 |
+
|
240 |
+
self.attention = Attention(
|
241 |
+
model_parallel_size=model_parallel_size,
|
242 |
+
dim=args.dim,
|
243 |
+
head_dim=head_dim,
|
244 |
+
n_heads=args.n_heads,
|
245 |
+
n_kv_heads=n_kv_heads,
|
246 |
+
rope_theta=args.rope_theta,
|
247 |
+
qk_normalization=args.qk_normalization,
|
248 |
+
)
|
249 |
+
self.feed_forward = FeedForward(
|
250 |
+
model_parallel_size=model_parallel_size,
|
251 |
+
dim=args.dim,
|
252 |
+
hidden_dim=4 * args.dim,
|
253 |
+
multiple_of=args.multiple_of,
|
254 |
+
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
255 |
+
)
|
256 |
+
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
257 |
+
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
258 |
+
self.swin_norm = args.swin_norm
|
259 |
+
|
260 |
+
def forward(
|
261 |
+
self,
|
262 |
+
x: torch.Tensor,
|
263 |
+
cache: LayerCache,
|
264 |
+
attn_bias: AttnBias,
|
265 |
+
group: dist.ProcessGroup | None = None,
|
266 |
+
) -> torch.Tensor:
|
267 |
+
if self.swin_norm:
|
268 |
+
h = x + self.attention_norm(
|
269 |
+
self.attention.forward(
|
270 |
+
x,
|
271 |
+
cache,
|
272 |
+
attn_bias,
|
273 |
+
group=group,
|
274 |
+
)
|
275 |
+
)
|
276 |
+
out = h + self.ffn_norm(self.feed_forward(h, group=group))
|
277 |
+
else:
|
278 |
+
h = x + self.attention.forward(
|
279 |
+
self.attention_norm(x),
|
280 |
+
cache,
|
281 |
+
attn_bias,
|
282 |
+
group=group,
|
283 |
+
)
|
284 |
+
out = h + self.feed_forward(self.ffn_norm(h), group=group)
|
285 |
+
return out
|
286 |
+
|
287 |
+
|
288 |
+
class Transformer(nn.Module):
|
289 |
+
def __init__(self, args: ModelArgs):
|
290 |
+
super().__init__()
|
291 |
+
self.args = args
|
292 |
+
|
293 |
+
self.model_parallel_size = args.model_parallel_size
|
294 |
+
assert args.dim % self.model_parallel_size == 0
|
295 |
+
assert args.vocab_size > 0
|
296 |
+
assert args.vocab_size % self.model_parallel_size == 0
|
297 |
+
|
298 |
+
self.tok_embeddings = nn.Embedding(
|
299 |
+
num_embeddings=args.vocab_size,
|
300 |
+
embedding_dim=args.dim // self.model_parallel_size,
|
301 |
+
)
|
302 |
+
|
303 |
+
self.layers = nn.ModuleList()
|
304 |
+
for _ in range(args.n_layers):
|
305 |
+
self.layers.append(TransformerBlock(args))
|
306 |
+
|
307 |
+
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
308 |
+
|
309 |
+
self.output = nn.Linear(
|
310 |
+
args.dim,
|
311 |
+
args.vocab_size // self.model_parallel_size,
|
312 |
+
bias=False,
|
313 |
+
)
|
314 |
+
|
315 |
+
@torch.no_grad()
|
316 |
+
def forward_with_attn_bias(
|
317 |
+
self,
|
318 |
+
token_values: torch.Tensor,
|
319 |
+
attn_bias: AttnBias,
|
320 |
+
cache: list[LayerCache],
|
321 |
+
group: dist.ProcessGroup | None = None,
|
322 |
+
) -> torch.Tensor:
|
323 |
+
h = self.tok_embeddings(token_values)
|
324 |
+
if self.model_parallel_size > 1:
|
325 |
+
gather = [torch.empty_like(h) for _ in range(self.model_parallel_size)]
|
326 |
+
dist.all_gather(gather, h, group=group)
|
327 |
+
h = torch.cat(gather, dim=-1)
|
328 |
+
|
329 |
+
for i, layer in enumerate(self.layers):
|
330 |
+
h = layer(h, cache[i], attn_bias, group=group)
|
331 |
+
|
332 |
+
logits = self.output(self.norm(h))
|
333 |
+
if self.model_parallel_size > 1:
|
334 |
+
gather = [torch.empty_like(logits) for _ in range(self.model_parallel_size)]
|
335 |
+
dist.all_gather(gather, logits, group=group)
|
336 |
+
logits = torch.cat(gather, dim=-1)
|
337 |
+
return logits.float()
|
338 |
+
|
339 |
+
def forward(
|
340 |
+
self,
|
341 |
+
token_values: torch.Tensor,
|
342 |
+
token_lengths: torch.Tensor,
|
343 |
+
start_pos: torch.Tensor,
|
344 |
+
cache: list[LayerCache],
|
345 |
+
kv_padding: int,
|
346 |
+
group: dist.ProcessGroup | None = None,
|
347 |
+
) -> torch.Tensor:
|
348 |
+
attn_bias = AttnBias.from_seqlens(
|
349 |
+
q_seqlen=token_lengths.tolist(),
|
350 |
+
kv_seqlen=(start_pos + token_lengths).tolist(),
|
351 |
+
kv_padding=kv_padding,
|
352 |
+
)
|
353 |
+
return self.forward_with_attn_bias(token_values, attn_bias, cache, group=group)
|
354 |
+
|
355 |
+
|
356 |
+
def make_cache(
|
357 |
+
args: ModelArgs,
|
358 |
+
length: int,
|
359 |
+
device: str | torch.device | None = None,
|
360 |
+
n_layers: int | None = None,
|
361 |
+
dtype: torch.dtype | None = None,
|
362 |
+
) -> list[LayerCache]:
|
363 |
+
"""
|
364 |
+
Allocate a cache to be used with the Transformer module.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
args (ModelArgs): the model configuration.
|
368 |
+
length (int): per layer cache size.
|
369 |
+
It is usually budgeted as ``max_batch * max_seq``
|
370 |
+
device (torch.device, optional): the device on which
|
371 |
+
the cache should be allocated.
|
372 |
+
n_layers (int, optional): the number of layers to
|
373 |
+
allocate a cache for (defaults to the model
|
374 |
+
settings).
|
375 |
+
dtype (torch.dtype, optional): the dtype to use for
|
376 |
+
cache entries (defaults to the default dtype).
|
377 |
+
|
378 |
+
Returns:
|
379 |
+
The cache object to pass to ``Tranformer.forward``.
|
380 |
+
"""
|
381 |
+
|
382 |
+
head_dim = args.dim // args.n_heads
|
383 |
+
n_kv_heads = args.n_kv_heads
|
384 |
+
if n_kv_heads is None:
|
385 |
+
n_kv_heads = args.n_heads
|
386 |
+
n_local_kv_heads = n_kv_heads // args.model_parallel_size
|
387 |
+
|
388 |
+
if n_layers is None:
|
389 |
+
n_layers = args.n_layers
|
390 |
+
|
391 |
+
shape = (1, length, n_local_kv_heads, head_dim)
|
392 |
+
return [
|
393 |
+
(
|
394 |
+
torch.zeros(shape, device=device, dtype=dtype),
|
395 |
+
torch.zeros(shape, device=device, dtype=dtype),
|
396 |
+
)
|
397 |
+
for _ in range(n_layers)
|
398 |
+
]
|
399 |
+
|
400 |
+
|
401 |
+
def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
|
402 |
+
"""
|
403 |
+
Take a prefix view of a larger cache.
|
404 |
+
|
405 |
+
The original cache object remains of identical size and valid
|
406 |
+
after the shrinked alias has been used. This function is useful
|
407 |
+
when a cache was allocated for a larger batch size than what is
|
408 |
+
necessary.
|
409 |
+
|
410 |
+
Args:
|
411 |
+
cache: the cache to take a view in.
|
412 |
+
length (int): the desired length
|
413 |
+
|
414 |
+
Returns:
|
415 |
+
A view in the input cache object.
|
416 |
+
"""
|
417 |
+
|
418 |
+
if len(cache) > 0:
|
419 |
+
assert cache[0][0].shape[1] >= length
|
420 |
+
|
421 |
+
return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
|
chameleon/inference/utils.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import socket
|
7 |
+
from typing import Generator, Generic, Iterator, TypeVar
|
8 |
+
|
9 |
+
T = TypeVar("T")
|
10 |
+
|
11 |
+
|
12 |
+
class DynamicGenerator(Generic[T]):
|
13 |
+
def __init__(self, gen: Generator[T, None, None]):
|
14 |
+
self.gen = gen
|
15 |
+
|
16 |
+
def __iter__(self) -> Iterator[T]:
|
17 |
+
return self
|
18 |
+
|
19 |
+
def __next__(self) -> T:
|
20 |
+
return next(self.gen)
|
21 |
+
|
22 |
+
|
23 |
+
def advance(iterator: Iterator[T], steps: int):
|
24 |
+
try:
|
25 |
+
for _ in range(steps):
|
26 |
+
next(iterator)
|
27 |
+
except StopIteration:
|
28 |
+
pass
|
29 |
+
|
30 |
+
|
31 |
+
def random_unused_port():
|
32 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
33 |
+
s.bind(("", 0))
|
34 |
+
return s.getsockname()[1]
|
chameleon/inference/vocab.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from functools import cached_property
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
class VocabInfo:
|
12 |
+
def __init__(self, vocab_map: dict[str, int]):
|
13 |
+
self.name2val = vocab_map
|
14 |
+
|
15 |
+
self.bos_id = vocab_map.get("<s>")
|
16 |
+
self.eos_id = vocab_map.get("</s>")
|
17 |
+
self.boi_id = vocab_map.get("<racm3:break>")
|
18 |
+
self.eoi_id = vocab_map.get("<eoss>")
|
19 |
+
self.pad_id = vocab_map.get("<pad>")
|
20 |
+
self.eot_id = vocab_map.get("<reserved08706>")
|
21 |
+
|
22 |
+
@property
|
23 |
+
def begin_sequence(self) -> int:
|
24 |
+
return self.bos_id
|
25 |
+
|
26 |
+
@property
|
27 |
+
def end_sequence(self) -> int:
|
28 |
+
return self.eos_id
|
29 |
+
|
30 |
+
@property
|
31 |
+
def begin_image(self) -> int:
|
32 |
+
return self.boi_id
|
33 |
+
|
34 |
+
@property
|
35 |
+
def end_image(self) -> int:
|
36 |
+
return self.eoi_id
|
37 |
+
|
38 |
+
@property
|
39 |
+
def padding(self) -> int:
|
40 |
+
return self.pad_id
|
41 |
+
|
42 |
+
@property
|
43 |
+
def end_turn(self) -> int:
|
44 |
+
return self.eot_id
|
45 |
+
|
46 |
+
@cached_property
|
47 |
+
def val2name(self) -> dict[int, str]:
|
48 |
+
return {v: k for k, v in self.name2val.items()}
|
49 |
+
|
50 |
+
@cached_property
|
51 |
+
def all_tokens(self) -> list[int]:
|
52 |
+
return sorted(self.name2val.values())
|
53 |
+
|
54 |
+
@cached_property
|
55 |
+
def image_tokens(self) -> list[int]:
|
56 |
+
return sorted(
|
57 |
+
[val for name, val in self.name2val.items() if name.startswith("IMGIMG")]
|
58 |
+
)
|
59 |
+
|
60 |
+
@cached_property
|
61 |
+
def special_tokens(self) -> list[int]:
|
62 |
+
return sorted(
|
63 |
+
[
|
64 |
+
val
|
65 |
+
for name, val in self.name2val.items()
|
66 |
+
if name.startswith("<") and name != "<"
|
67 |
+
]
|
68 |
+
)
|
69 |
+
|
70 |
+
@cached_property
|
71 |
+
def text_tokens(self) -> list[int]:
|
72 |
+
return sorted(
|
73 |
+
set(self.all_tokens) - set(self.image_tokens) - set(self.special_tokens)
|
74 |
+
)
|
75 |
+
|
76 |
+
|
77 |
+
class VocabTranslation:
|
78 |
+
def __init__(self, vocab_info: VocabInfo, device: str | None = None):
|
79 |
+
self._vocab = vocab_info
|
80 |
+
self._device = device
|
81 |
+
|
82 |
+
@cached_property
|
83 |
+
def bpe2img(self) -> dict[int, int]:
|
84 |
+
img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}
|
85 |
+
|
86 |
+
def remap(old_name: str) -> str:
|
87 |
+
return "".join(
|
88 |
+
img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]
|
89 |
+
)
|
90 |
+
|
91 |
+
return {
|
92 |
+
tok: int(remap(self._vocab.val2name[tok]))
|
93 |
+
for tok in self._vocab.image_tokens
|
94 |
+
}
|
95 |
+
|
96 |
+
@cached_property
|
97 |
+
def img2bpe(self) -> dict[int, int]:
|
98 |
+
return {v: k for k, v in self.bpe2img.items()}
|
99 |
+
|
100 |
+
@cached_property
|
101 |
+
def bpe2img_search_tensors(self) -> tuple[torch.Tensor, torch.Tensor]:
|
102 |
+
sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()), device=self._device)
|
103 |
+
sorted_img = torch.tensor(sorted(self.bpe2img.values()), device=self._device)
|
104 |
+
return sorted_bpe, sorted_img
|
105 |
+
|
106 |
+
@cached_property
|
107 |
+
def img2bpe_mapping_tensor(self) -> torch.LongTensor:
|
108 |
+
mapping = torch.zeros(
|
109 |
+
max(self.img2bpe.keys()) + 1,
|
110 |
+
dtype=torch.int,
|
111 |
+
device=self._device,
|
112 |
+
)
|
113 |
+
for k, v in self.img2bpe.items():
|
114 |
+
mapping[k] = v
|
115 |
+
return mapping
|
116 |
+
|
117 |
+
def convert_bpe2img(self, bpe_batch: torch.Tensor) -> torch.Tensor:
|
118 |
+
bpe_tok, img_tok = self.bpe2img_search_tensors
|
119 |
+
return img_tok[torch.searchsorted(bpe_tok, bpe_batch)]
|
120 |
+
|
121 |
+
def convert_img2bp2(self, img_batch: torch.Tensor) -> torch.Tensor:
|
122 |
+
return self.img2bpe_mapping_tensor[img_batch]
|
chameleon/inference/vqgan.py
ADDED
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
"""
|
7 |
+
Contents of this file are taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/models/vqgan.py
|
8 |
+
[with minimal dependencies]
|
9 |
+
|
10 |
+
This implementation is inference-only -- training steps and optimizer components
|
11 |
+
introduce significant additional dependencies
|
12 |
+
"""
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
|
20 |
+
class VectorQuantizer2(nn.Module):
|
21 |
+
"""
|
22 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
23 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
24 |
+
"""
|
25 |
+
|
26 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
27 |
+
# backwards compatibility we use the buggy version by default, but you can
|
28 |
+
# specify legacy=False to fix it.
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
n_e,
|
32 |
+
e_dim,
|
33 |
+
beta,
|
34 |
+
remap=None,
|
35 |
+
unknown_index="random",
|
36 |
+
sane_index_shape=False,
|
37 |
+
legacy=True,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
self.n_e = n_e
|
41 |
+
self.e_dim = e_dim
|
42 |
+
self.beta = beta
|
43 |
+
self.legacy = legacy
|
44 |
+
|
45 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
46 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
47 |
+
|
48 |
+
self.remap = remap
|
49 |
+
if self.remap is not None:
|
50 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
51 |
+
self.re_embed = self.used.shape[0]
|
52 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
53 |
+
if self.unknown_index == "extra":
|
54 |
+
self.unknown_index = self.re_embed
|
55 |
+
self.re_embed = self.re_embed + 1
|
56 |
+
print(
|
57 |
+
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
58 |
+
f"Using {self.unknown_index} for unknown indices."
|
59 |
+
)
|
60 |
+
else:
|
61 |
+
self.re_embed = n_e
|
62 |
+
|
63 |
+
self.sane_index_shape = sane_index_shape
|
64 |
+
|
65 |
+
def remap_to_used(self, inds):
|
66 |
+
ishape = inds.shape
|
67 |
+
assert len(ishape) > 1
|
68 |
+
inds = inds.reshape(ishape[0], -1)
|
69 |
+
used = self.used.to(inds)
|
70 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
71 |
+
new = match.argmax(-1)
|
72 |
+
unknown = match.sum(2) < 1
|
73 |
+
if self.unknown_index == "random":
|
74 |
+
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
75 |
+
device=new.device
|
76 |
+
)
|
77 |
+
else:
|
78 |
+
new[unknown] = self.unknown_index
|
79 |
+
return new.reshape(ishape)
|
80 |
+
|
81 |
+
def unmap_to_all(self, inds):
|
82 |
+
ishape = inds.shape
|
83 |
+
assert len(ishape) > 1
|
84 |
+
inds = inds.reshape(ishape[0], -1)
|
85 |
+
used = self.used.to(inds)
|
86 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
87 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
88 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
89 |
+
return back.reshape(ishape)
|
90 |
+
|
91 |
+
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
92 |
+
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
93 |
+
assert rescale_logits is False, "Only for interface compatible with Gumbel"
|
94 |
+
assert return_logits is False, "Only for interface compatible with Gumbel"
|
95 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
96 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
97 |
+
z_flattened = z.view(-1, self.e_dim)
|
98 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
99 |
+
|
100 |
+
d = (
|
101 |
+
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
102 |
+
+ torch.sum(self.embedding.weight**2, dim=1)
|
103 |
+
- 2
|
104 |
+
* torch.einsum(
|
105 |
+
"bd,dn->bn", z_flattened, self.embedding.weight.transpose(0, 1)
|
106 |
+
)
|
107 |
+
)
|
108 |
+
|
109 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
110 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
111 |
+
perplexity = None
|
112 |
+
min_encodings = None
|
113 |
+
|
114 |
+
# compute loss for embedding
|
115 |
+
if not self.legacy:
|
116 |
+
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
|
117 |
+
(z_q - z.detach()) ** 2
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
121 |
+
(z_q - z.detach()) ** 2
|
122 |
+
)
|
123 |
+
|
124 |
+
# preserve gradients
|
125 |
+
z_q = z + (z_q - z).detach()
|
126 |
+
|
127 |
+
# reshape back to match original input shape
|
128 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
129 |
+
|
130 |
+
if self.remap is not None:
|
131 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
132 |
+
z.shape[0], -1
|
133 |
+
) # add batch axis
|
134 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
135 |
+
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
136 |
+
|
137 |
+
if self.sane_index_shape:
|
138 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
139 |
+
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
140 |
+
)
|
141 |
+
|
142 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
143 |
+
|
144 |
+
def get_codebook_entry(self, indices, shape):
|
145 |
+
# shape specifying (batch, height, width, channel)
|
146 |
+
if self.remap is not None:
|
147 |
+
indices = indices.reshape(shape[0], -1) # add batch axis
|
148 |
+
indices = self.unmap_to_all(indices)
|
149 |
+
indices = indices.reshape(-1) # flatten again
|
150 |
+
|
151 |
+
# get quantized latent vectors
|
152 |
+
z_q = self.embedding(indices)
|
153 |
+
|
154 |
+
if shape is not None:
|
155 |
+
z_q = z_q.view(shape)
|
156 |
+
# reshape back to match original input shape
|
157 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
158 |
+
|
159 |
+
return z_q
|
160 |
+
|
161 |
+
|
162 |
+
# Alias
|
163 |
+
VectorQuantizer = VectorQuantizer2
|
164 |
+
|
165 |
+
|
166 |
+
def nonlinearity(x):
|
167 |
+
# swish
|
168 |
+
return x * torch.sigmoid(x)
|
169 |
+
|
170 |
+
|
171 |
+
def Normalize(in_channels, num_groups=32):
|
172 |
+
return torch.nn.GroupNorm(
|
173 |
+
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
174 |
+
)
|
175 |
+
|
176 |
+
|
177 |
+
class Upsample(nn.Module):
|
178 |
+
def __init__(self, in_channels, with_conv):
|
179 |
+
super().__init__()
|
180 |
+
self.with_conv = with_conv
|
181 |
+
if self.with_conv:
|
182 |
+
self.conv = torch.nn.Conv2d(
|
183 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
184 |
+
)
|
185 |
+
|
186 |
+
def forward(self, x):
|
187 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
188 |
+
if self.with_conv:
|
189 |
+
x = self.conv(x)
|
190 |
+
return x
|
191 |
+
|
192 |
+
|
193 |
+
class Downsample(nn.Module):
|
194 |
+
def __init__(self, in_channels, with_conv):
|
195 |
+
super().__init__()
|
196 |
+
self.with_conv = with_conv
|
197 |
+
if self.with_conv:
|
198 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
199 |
+
self.conv = torch.nn.Conv2d(
|
200 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
201 |
+
)
|
202 |
+
|
203 |
+
def forward(self, x):
|
204 |
+
if self.with_conv:
|
205 |
+
pad = (0, 1, 0, 1)
|
206 |
+
x = F.pad(x, pad, mode="constant", value=0)
|
207 |
+
x = self.conv(x)
|
208 |
+
else:
|
209 |
+
x = F.avg_pool2d(x, kernel_size=2, stride=2)
|
210 |
+
return x
|
211 |
+
|
212 |
+
|
213 |
+
class ResnetBlock(nn.Module):
|
214 |
+
def __init__(
|
215 |
+
self,
|
216 |
+
*,
|
217 |
+
in_channels,
|
218 |
+
out_channels=None,
|
219 |
+
conv_shortcut=False,
|
220 |
+
dropout,
|
221 |
+
temb_channels=512,
|
222 |
+
):
|
223 |
+
super().__init__()
|
224 |
+
self.in_channels = in_channels
|
225 |
+
out_channels = in_channels if out_channels is None else out_channels
|
226 |
+
self.out_channels = out_channels
|
227 |
+
self.use_conv_shortcut = conv_shortcut
|
228 |
+
|
229 |
+
self.norm1 = Normalize(in_channels)
|
230 |
+
self.conv1 = torch.nn.Conv2d(
|
231 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
232 |
+
)
|
233 |
+
if temb_channels > 0:
|
234 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
235 |
+
self.norm2 = Normalize(out_channels)
|
236 |
+
self.dropout = torch.nn.Dropout(dropout)
|
237 |
+
self.conv2 = torch.nn.Conv2d(
|
238 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
239 |
+
)
|
240 |
+
if self.in_channels != self.out_channels:
|
241 |
+
if self.use_conv_shortcut:
|
242 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
243 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
247 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
248 |
+
)
|
249 |
+
|
250 |
+
def forward(self, x, temb):
|
251 |
+
h = x
|
252 |
+
h = self.norm1(h)
|
253 |
+
h = nonlinearity(h)
|
254 |
+
h = self.conv1(h)
|
255 |
+
|
256 |
+
if temb is not None:
|
257 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
258 |
+
|
259 |
+
h = self.norm2(h)
|
260 |
+
h = nonlinearity(h)
|
261 |
+
h = self.dropout(h)
|
262 |
+
h = self.conv2(h)
|
263 |
+
|
264 |
+
if self.in_channels != self.out_channels:
|
265 |
+
if self.use_conv_shortcut:
|
266 |
+
x = self.conv_shortcut(x)
|
267 |
+
else:
|
268 |
+
x = self.nin_shortcut(x)
|
269 |
+
|
270 |
+
return x + h
|
271 |
+
|
272 |
+
|
273 |
+
class AttnBlock(nn.Module):
|
274 |
+
def __init__(self, in_channels):
|
275 |
+
super().__init__()
|
276 |
+
self.in_channels = in_channels
|
277 |
+
|
278 |
+
self.norm = Normalize(in_channels)
|
279 |
+
self.q = torch.nn.Conv2d(
|
280 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
281 |
+
)
|
282 |
+
self.k = torch.nn.Conv2d(
|
283 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
284 |
+
)
|
285 |
+
self.v = torch.nn.Conv2d(
|
286 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
287 |
+
)
|
288 |
+
self.proj_out = torch.nn.Conv2d(
|
289 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
290 |
+
)
|
291 |
+
|
292 |
+
def forward(self, x):
|
293 |
+
h_ = x
|
294 |
+
h_ = self.norm(h_)
|
295 |
+
q = self.q(h_)
|
296 |
+
k = self.k(h_)
|
297 |
+
v = self.v(h_)
|
298 |
+
|
299 |
+
# compute attention
|
300 |
+
b, c, h, w = q.shape
|
301 |
+
q = q.reshape(b, c, h * w)
|
302 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
303 |
+
k = k.reshape(b, c, h * w) # b,c,hw
|
304 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
305 |
+
w_ = w_ * (int(c) ** (-0.5))
|
306 |
+
w_ = F.softmax(w_, dim=2)
|
307 |
+
|
308 |
+
# attend to values
|
309 |
+
v = v.reshape(b, c, h * w)
|
310 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
311 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
312 |
+
h_ = h_.reshape(b, c, h, w)
|
313 |
+
|
314 |
+
h_ = self.proj_out(h_)
|
315 |
+
|
316 |
+
return x + h_
|
317 |
+
|
318 |
+
|
319 |
+
def make_attn(in_channels, attn_type="vanilla"):
|
320 |
+
assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
|
321 |
+
# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
322 |
+
if attn_type == "vanilla":
|
323 |
+
return AttnBlock(in_channels)
|
324 |
+
elif attn_type == "none":
|
325 |
+
return nn.Identity(in_channels)
|
326 |
+
else:
|
327 |
+
raise ValueError("Unexpected attention type")
|
328 |
+
|
329 |
+
|
330 |
+
class Encoder(nn.Module):
|
331 |
+
def __init__(
|
332 |
+
self,
|
333 |
+
*,
|
334 |
+
ch,
|
335 |
+
out_ch,
|
336 |
+
ch_mult=(1, 2, 4, 8),
|
337 |
+
num_res_blocks,
|
338 |
+
attn_resolutions,
|
339 |
+
dropout=0.0,
|
340 |
+
resamp_with_conv=True,
|
341 |
+
in_channels,
|
342 |
+
resolution,
|
343 |
+
z_channels,
|
344 |
+
double_z=True,
|
345 |
+
use_linear_attn=False,
|
346 |
+
attn_type="vanilla",
|
347 |
+
**ignore_kwargs,
|
348 |
+
):
|
349 |
+
super().__init__()
|
350 |
+
if use_linear_attn:
|
351 |
+
attn_type = "linear"
|
352 |
+
self.ch = ch
|
353 |
+
self.temb_ch = 0
|
354 |
+
self.num_resolutions = len(ch_mult)
|
355 |
+
self.num_res_blocks = num_res_blocks
|
356 |
+
self.resolution = resolution
|
357 |
+
self.in_channels = in_channels
|
358 |
+
|
359 |
+
# downsampling
|
360 |
+
self.conv_in = torch.nn.Conv2d(
|
361 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
362 |
+
)
|
363 |
+
|
364 |
+
curr_res = resolution
|
365 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
366 |
+
self.in_ch_mult = in_ch_mult
|
367 |
+
self.down = nn.ModuleList()
|
368 |
+
for i_level in range(self.num_resolutions):
|
369 |
+
block = nn.ModuleList()
|
370 |
+
attn = nn.ModuleList()
|
371 |
+
block_in = ch * in_ch_mult[i_level]
|
372 |
+
block_out = ch * ch_mult[i_level]
|
373 |
+
for i_block in range(self.num_res_blocks):
|
374 |
+
block.append(
|
375 |
+
ResnetBlock(
|
376 |
+
in_channels=block_in,
|
377 |
+
out_channels=block_out,
|
378 |
+
temb_channels=self.temb_ch,
|
379 |
+
dropout=dropout,
|
380 |
+
)
|
381 |
+
)
|
382 |
+
block_in = block_out
|
383 |
+
if curr_res in attn_resolutions:
|
384 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
385 |
+
down = nn.Module()
|
386 |
+
down.block = block
|
387 |
+
down.attn = attn
|
388 |
+
if i_level != self.num_resolutions - 1:
|
389 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
390 |
+
curr_res = curr_res // 2
|
391 |
+
self.down.append(down)
|
392 |
+
|
393 |
+
# middle
|
394 |
+
self.mid = nn.Module()
|
395 |
+
self.mid.block_1 = ResnetBlock(
|
396 |
+
in_channels=block_in,
|
397 |
+
out_channels=block_in,
|
398 |
+
temb_channels=self.temb_ch,
|
399 |
+
dropout=dropout,
|
400 |
+
)
|
401 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
402 |
+
self.mid.block_2 = ResnetBlock(
|
403 |
+
in_channels=block_in,
|
404 |
+
out_channels=block_in,
|
405 |
+
temb_channels=self.temb_ch,
|
406 |
+
dropout=dropout,
|
407 |
+
)
|
408 |
+
|
409 |
+
# end
|
410 |
+
self.norm_out = Normalize(block_in)
|
411 |
+
self.conv_out = torch.nn.Conv2d(
|
412 |
+
block_in,
|
413 |
+
2 * z_channels if double_z else z_channels,
|
414 |
+
kernel_size=3,
|
415 |
+
stride=1,
|
416 |
+
padding=1,
|
417 |
+
)
|
418 |
+
|
419 |
+
def forward(self, x):
|
420 |
+
# timestep embedding
|
421 |
+
temb = None
|
422 |
+
|
423 |
+
# downsampling
|
424 |
+
hs = [self.conv_in(x)]
|
425 |
+
for i_level in range(self.num_resolutions):
|
426 |
+
for i_block in range(self.num_res_blocks):
|
427 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
428 |
+
if len(self.down[i_level].attn) > 0:
|
429 |
+
h = self.down[i_level].attn[i_block](h)
|
430 |
+
hs.append(h)
|
431 |
+
if i_level != self.num_resolutions - 1:
|
432 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
433 |
+
|
434 |
+
# middle
|
435 |
+
h = hs[-1]
|
436 |
+
h = self.mid.block_1(h, temb)
|
437 |
+
h = self.mid.attn_1(h)
|
438 |
+
h = self.mid.block_2(h, temb)
|
439 |
+
|
440 |
+
# end
|
441 |
+
h = self.norm_out(h)
|
442 |
+
h = nonlinearity(h)
|
443 |
+
h = self.conv_out(h)
|
444 |
+
return h
|
445 |
+
|
446 |
+
|
447 |
+
class Decoder(nn.Module):
|
448 |
+
def __init__(
|
449 |
+
self,
|
450 |
+
*,
|
451 |
+
ch,
|
452 |
+
out_ch,
|
453 |
+
ch_mult=(1, 2, 4, 8),
|
454 |
+
num_res_blocks,
|
455 |
+
attn_resolutions,
|
456 |
+
dropout=0.0,
|
457 |
+
resamp_with_conv=True,
|
458 |
+
in_channels,
|
459 |
+
resolution,
|
460 |
+
z_channels,
|
461 |
+
give_pre_end=False,
|
462 |
+
tanh_out=False,
|
463 |
+
use_linear_attn=False,
|
464 |
+
attn_type="vanilla",
|
465 |
+
**ignorekwargs,
|
466 |
+
):
|
467 |
+
super().__init__()
|
468 |
+
if use_linear_attn:
|
469 |
+
attn_type = "linear"
|
470 |
+
self.ch = ch
|
471 |
+
self.temb_ch = 0
|
472 |
+
self.num_resolutions = len(ch_mult)
|
473 |
+
self.num_res_blocks = num_res_blocks
|
474 |
+
self.resolution = resolution
|
475 |
+
self.in_channels = in_channels
|
476 |
+
self.give_pre_end = give_pre_end
|
477 |
+
self.tanh_out = tanh_out
|
478 |
+
|
479 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
480 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
481 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
482 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
483 |
+
|
484 |
+
# z to block_in
|
485 |
+
self.conv_in = torch.nn.Conv2d(
|
486 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
487 |
+
)
|
488 |
+
|
489 |
+
# middle
|
490 |
+
self.mid = nn.Module()
|
491 |
+
self.mid.block_1 = ResnetBlock(
|
492 |
+
in_channels=block_in,
|
493 |
+
out_channels=block_in,
|
494 |
+
temb_channels=self.temb_ch,
|
495 |
+
dropout=dropout,
|
496 |
+
)
|
497 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
498 |
+
self.mid.block_2 = ResnetBlock(
|
499 |
+
in_channels=block_in,
|
500 |
+
out_channels=block_in,
|
501 |
+
temb_channels=self.temb_ch,
|
502 |
+
dropout=dropout,
|
503 |
+
)
|
504 |
+
|
505 |
+
# upsampling
|
506 |
+
self.up = nn.ModuleList()
|
507 |
+
for i_level in reversed(range(self.num_resolutions)):
|
508 |
+
block = nn.ModuleList()
|
509 |
+
attn = nn.ModuleList()
|
510 |
+
block_out = ch * ch_mult[i_level]
|
511 |
+
for i_block in range(self.num_res_blocks + 1):
|
512 |
+
block.append(
|
513 |
+
ResnetBlock(
|
514 |
+
in_channels=block_in,
|
515 |
+
out_channels=block_out,
|
516 |
+
temb_channels=self.temb_ch,
|
517 |
+
dropout=dropout,
|
518 |
+
)
|
519 |
+
)
|
520 |
+
block_in = block_out
|
521 |
+
if curr_res in attn_resolutions:
|
522 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
523 |
+
up = nn.Module()
|
524 |
+
up.block = block
|
525 |
+
up.attn = attn
|
526 |
+
if i_level != 0:
|
527 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
528 |
+
curr_res = curr_res * 2
|
529 |
+
self.up.insert(0, up) # prepend to get consistent order
|
530 |
+
|
531 |
+
# end
|
532 |
+
self.norm_out = Normalize(block_in)
|
533 |
+
self.conv_out = torch.nn.Conv2d(
|
534 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
535 |
+
)
|
536 |
+
|
537 |
+
def forward(self, z):
|
538 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
539 |
+
self.last_z_shape = z.shape
|
540 |
+
|
541 |
+
# timestep embedding
|
542 |
+
temb = None
|
543 |
+
|
544 |
+
# z to block_in
|
545 |
+
h = self.conv_in(z)
|
546 |
+
|
547 |
+
# middle
|
548 |
+
h = self.mid.block_1(h, temb)
|
549 |
+
h = self.mid.attn_1(h)
|
550 |
+
h = self.mid.block_2(h, temb)
|
551 |
+
|
552 |
+
# upsampling
|
553 |
+
for i_level in reversed(range(self.num_resolutions)):
|
554 |
+
for i_block in range(self.num_res_blocks + 1):
|
555 |
+
h = self.up[i_level].block[i_block](h, temb)
|
556 |
+
if len(self.up[i_level].attn) > 0:
|
557 |
+
h = self.up[i_level].attn[i_block](h)
|
558 |
+
if i_level != 0:
|
559 |
+
h = self.up[i_level].upsample(h)
|
560 |
+
|
561 |
+
# end
|
562 |
+
if self.give_pre_end:
|
563 |
+
return h
|
564 |
+
|
565 |
+
h = self.norm_out(h)
|
566 |
+
h = nonlinearity(h)
|
567 |
+
h = self.conv_out(h)
|
568 |
+
if self.tanh_out:
|
569 |
+
h = torch.tanh(h)
|
570 |
+
return h
|
571 |
+
|
572 |
+
|
573 |
+
class VQModel(nn.Module):
|
574 |
+
def __init__(
|
575 |
+
self,
|
576 |
+
ddconfig,
|
577 |
+
n_embed,
|
578 |
+
embed_dim,
|
579 |
+
ckpt_path=None,
|
580 |
+
ignore_keys=[],
|
581 |
+
image_key="image",
|
582 |
+
colorize_nlabels=None,
|
583 |
+
monitor=None,
|
584 |
+
scheduler_config=None,
|
585 |
+
lr_g_factor=1.0,
|
586 |
+
remap=None,
|
587 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
588 |
+
):
|
589 |
+
super().__init__()
|
590 |
+
self.image_key = image_key
|
591 |
+
self.encoder = Encoder(**ddconfig)
|
592 |
+
self.decoder = Decoder(**ddconfig)
|
593 |
+
self.quantize = VectorQuantizer(
|
594 |
+
n_embed,
|
595 |
+
embed_dim,
|
596 |
+
beta=0.25,
|
597 |
+
remap=remap,
|
598 |
+
sane_index_shape=sane_index_shape,
|
599 |
+
)
|
600 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
601 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
602 |
+
if ckpt_path is not None:
|
603 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
604 |
+
self.image_key = image_key
|
605 |
+
if colorize_nlabels is not None:
|
606 |
+
assert isinstance(colorize_nlabels, int)
|
607 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
608 |
+
if monitor is not None:
|
609 |
+
self.monitor = monitor
|
610 |
+
self.scheduler_config = scheduler_config
|
611 |
+
self.lr_g_factor = lr_g_factor
|
612 |
+
|
613 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
614 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
615 |
+
keys = list(sd.keys())
|
616 |
+
for k in keys:
|
617 |
+
for ik in ignore_keys:
|
618 |
+
if k.startswith(ik):
|
619 |
+
print("Deleting key {} from state_dict.".format(k))
|
620 |
+
del sd[k]
|
621 |
+
self.load_state_dict(sd, strict=False)
|
622 |
+
print(f"VQModel loaded from {path}")
|
623 |
+
|
624 |
+
def encode(self, x):
|
625 |
+
h = self.encoder(x)
|
626 |
+
h = self.quant_conv(h)
|
627 |
+
quant, emb_loss, info = self.quantize(h)
|
628 |
+
return quant, emb_loss, info
|
629 |
+
|
630 |
+
def decode(self, quant):
|
631 |
+
quant = self.post_quant_conv(quant)
|
632 |
+
dec = self.decoder(quant)
|
633 |
+
return dec
|
634 |
+
|
635 |
+
def decode_code(self, code_b):
|
636 |
+
quant_b = self.quantize.embed_code(code_b)
|
637 |
+
dec = self.decode(quant_b)
|
638 |
+
return dec
|
639 |
+
|
640 |
+
def forward(self, input):
|
641 |
+
quant, diff, _ = self.encode(input)
|
642 |
+
dec = self.decode(quant)
|
643 |
+
return dec, diff
|
644 |
+
|
645 |
+
def get_input(self, batch, k):
|
646 |
+
x = batch[k]
|
647 |
+
if len(x.shape) == 3:
|
648 |
+
x = x[..., None]
|
649 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
650 |
+
return x.float()
|
651 |
+
|
652 |
+
def get_last_layer(self):
|
653 |
+
return self.decoder.conv_out.weight
|
654 |
+
|
655 |
+
def log_images(self, batch, **kwargs):
|
656 |
+
log = dict()
|
657 |
+
x = self.get_input(batch, self.image_key)
|
658 |
+
x = x.to(self.device)
|
659 |
+
xrec, _ = self(x)
|
660 |
+
if x.shape[1] > 3:
|
661 |
+
# colorize with random projection
|
662 |
+
assert xrec.shape[1] > 3
|
663 |
+
x = self.to_rgb(x)
|
664 |
+
xrec = self.to_rgb(xrec)
|
665 |
+
log["inputs"] = x
|
666 |
+
log["reconstructions"] = xrec
|
667 |
+
return log
|
668 |
+
|
669 |
+
def to_rgb(self, x):
|
670 |
+
assert self.image_key == "segmentation"
|
671 |
+
if not hasattr(self, "colorize"):
|
672 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
673 |
+
x = F.conv2d(x, weight=self.colorize)
|
674 |
+
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
675 |
+
return x
|
chameleon/vqgan.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4ede986bf6b171db3081ce171ad88e4ac970793cea14c180b3e5ac5105f4cb43
|
3 |
+
size 281270377
|
chameleon/vqgan.yaml
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-06
|
3 |
+
target: taming.models.vqgan.VQModel
|
4 |
+
params:
|
5 |
+
embed_dim: 256
|
6 |
+
n_embed: 8192
|
7 |
+
ddconfig:
|
8 |
+
double_z: false
|
9 |
+
z_channels: 256
|
10 |
+
resolution: 512
|
11 |
+
in_channels: 3
|
12 |
+
out_ch: 3
|
13 |
+
ch: 128
|
14 |
+
ch_mult:
|
15 |
+
- 1
|
16 |
+
- 1
|
17 |
+
- 2
|
18 |
+
- 2
|
19 |
+
- 4
|
20 |
+
num_res_blocks: 2
|
21 |
+
attn_resolutions: []
|
22 |
+
dropout: 0.0
|
23 |
+
lossconfig:
|
24 |
+
target: taming.modules.losses.vqperceptual_vit_vqgan.VQLPIPSWithDiscriminator
|
25 |
+
params:
|
26 |
+
disc_start: 100001
|
27 |
+
perceptual_weight: 1.0
|
28 |
+
adversarial_weight: 0.5
|
29 |
+
disc_params:
|
30 |
+
size: 512
|
31 |
+
ckpt_path: manifold://fair_onellm_checkpoints/tree/v2/tokenizer/vqgan_wm_0209.ckpt
|
32 |
+
data:
|
33 |
+
target: main.DataModuleFromConfig
|
34 |
+
params:
|
35 |
+
batch_size: 4
|
36 |
+
num_workers: 10
|
37 |
+
image_size: 512
|
38 |
+
filter_image_size: 512
|
39 |
+
dataset: coco
|
40 |
+
aesthetics_th: 0
|
41 |
+
clipsim_th: 0
|
42 |
+
--distributed-world-size: null
|
43 |
+
'32': null
|
44 |
+
--distributed-port: null
|
45 |
+
'17338': null
|
46 |
+
--save-dir: null
|
47 |
+
/checkpoint/shellysheynin/shutterstock/512x512_1024tokens_4node_shutterstock_laion_no_attn_styleGAN:
|
48 |
+
log_every-500:
|
49 |
+
ngpu32: null
|
50 |
+
--tensorboard-logdir: null
|
51 |
+
/checkpoint/shellysheynin/tensorboard_logs/2023-03-30/512x512_1024tokens_4node_shutterstock_laion_no_attn_styleGAN:
|
52 |
+
log_every-500:
|
53 |
+
ngpu32: null
|
54 |
+
'14561': null
|
55 |
+
/checkpoint/shellysheynin/tensorboard_logs/2023-04-02/512x512_1024tokens_4node_shutterstock_laion_no_attn_styleGAN:
|
56 |
+
log_every-500:
|
57 |
+
ngpu32: null
|
conversation.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
import base64
|
5 |
+
from io import BytesIO
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
class SeparatorStyle(Enum):
|
10 |
+
"""Different separator style."""
|
11 |
+
SINGLE = auto()
|
12 |
+
TWO = auto()
|
13 |
+
MPT = auto()
|
14 |
+
PLAIN = auto()
|
15 |
+
LLAMA_2 = auto()
|
16 |
+
GEMMA = auto()
|
17 |
+
|
18 |
+
|
19 |
+
@dataclasses.dataclass
|
20 |
+
class Conversation:
|
21 |
+
"""A class that keeps all conversation history."""
|
22 |
+
system: str
|
23 |
+
roles: List[str]
|
24 |
+
messages: List[List[str]]
|
25 |
+
offset: int
|
26 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
27 |
+
sep: str = "###"
|
28 |
+
sep2: str = None
|
29 |
+
version: str = "Unknown"
|
30 |
+
|
31 |
+
skip_next: bool = False
|
32 |
+
|
33 |
+
def get_prompt(self):
|
34 |
+
messages = self.messages
|
35 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
36 |
+
messages = self.messages.copy()
|
37 |
+
init_role, init_msg = messages[0].copy()
|
38 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
39 |
+
if 'mmtag' in self.version:
|
40 |
+
messages[0] = (init_role, init_msg)
|
41 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
42 |
+
messages.insert(1, (self.roles[1], "Received."))
|
43 |
+
else:
|
44 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
45 |
+
|
46 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
47 |
+
ret = self.system + self.sep
|
48 |
+
for role, message in messages:
|
49 |
+
if message:
|
50 |
+
if type(message) is tuple:
|
51 |
+
message = message[0]
|
52 |
+
ret += role + ": " + message + self.sep
|
53 |
+
else:
|
54 |
+
ret += role + ":"
|
55 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
56 |
+
seps = [self.sep, self.sep2]
|
57 |
+
ret = self.system + seps[0]
|
58 |
+
for i, (role, message) in enumerate(messages):
|
59 |
+
if message:
|
60 |
+
if type(message) is tuple:
|
61 |
+
message = message[0]
|
62 |
+
ret += role + ": " + message + seps[i % 2]
|
63 |
+
else:
|
64 |
+
ret += role + ":"
|
65 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
66 |
+
ret = self.system + self.sep
|
67 |
+
for role, message in messages:
|
68 |
+
if message:
|
69 |
+
if type(message) is tuple:
|
70 |
+
message = message[0]
|
71 |
+
ret += role + message + self.sep
|
72 |
+
else:
|
73 |
+
ret += role
|
74 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
75 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
76 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
77 |
+
ret = ""
|
78 |
+
|
79 |
+
for i, (role, message) in enumerate(messages):
|
80 |
+
if i == 0:
|
81 |
+
assert message, "first message should not be none"
|
82 |
+
assert role == self.roles[0], "first message should come from user"
|
83 |
+
if message:
|
84 |
+
if type(message) is tuple:
|
85 |
+
message, _, _ = message
|
86 |
+
if i == 0: message = wrap_sys(self.system) + message
|
87 |
+
if i % 2 == 0:
|
88 |
+
message = wrap_inst(message)
|
89 |
+
ret += self.sep + message
|
90 |
+
else:
|
91 |
+
ret += " " + message + " " + self.sep2
|
92 |
+
else:
|
93 |
+
ret += ""
|
94 |
+
ret = ret.lstrip(self.sep)
|
95 |
+
elif self.sep_style == SeparatorStyle.GEMMA:
|
96 |
+
seps = [self.sep, self.sep2]
|
97 |
+
ret = self.system + seps[0]
|
98 |
+
for i, (role, message) in enumerate(messages):
|
99 |
+
if message:
|
100 |
+
if type(message) is tuple:
|
101 |
+
message, _, _ = message
|
102 |
+
ret += "<start_of_turn>" + role + "\n" + message + "<end_of_turn>\n" + seps[i % 2]
|
103 |
+
else:
|
104 |
+
ret += "<start_of_turn>" + role + "\n"
|
105 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
106 |
+
seps = [self.sep, self.sep2]
|
107 |
+
ret = self.system
|
108 |
+
for i, (role, message) in enumerate(messages):
|
109 |
+
if message:
|
110 |
+
if type(message) is tuple:
|
111 |
+
message, _, _ = message
|
112 |
+
ret += message + seps[i % 2]
|
113 |
+
else:
|
114 |
+
ret += ""
|
115 |
+
else:
|
116 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
117 |
+
|
118 |
+
return ret
|
119 |
+
|
120 |
+
def append_message(self, role, message):
|
121 |
+
self.messages.append([role, message])
|
122 |
+
|
123 |
+
def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
|
124 |
+
if image_process_mode == "Pad":
|
125 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
126 |
+
width, height = pil_img.size
|
127 |
+
if width == height:
|
128 |
+
return pil_img
|
129 |
+
elif width > height:
|
130 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
131 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
132 |
+
return result
|
133 |
+
else:
|
134 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
135 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
136 |
+
return result
|
137 |
+
image = expand2square(image)
|
138 |
+
elif image_process_mode in ["Default", "Crop"]:
|
139 |
+
pass
|
140 |
+
elif image_process_mode == "Resize":
|
141 |
+
image = image.resize((336, 336))
|
142 |
+
else:
|
143 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
144 |
+
if max(image.size) > max_len:
|
145 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
146 |
+
aspect_ratio = max_hw / min_hw
|
147 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
148 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
149 |
+
W, H = image.size
|
150 |
+
if H > W:
|
151 |
+
H, W = longest_edge, shortest_edge
|
152 |
+
else:
|
153 |
+
H, W = shortest_edge, longest_edge
|
154 |
+
image = image.resize((W, H))
|
155 |
+
if return_pil:
|
156 |
+
return image
|
157 |
+
else:
|
158 |
+
buffered = BytesIO()
|
159 |
+
image.save(buffered, format=image_format)
|
160 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
161 |
+
return img_b64_str
|
162 |
+
|
163 |
+
def get_images(self, return_pil=False):
|
164 |
+
images = []
|
165 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
166 |
+
if i % 2 == 0:
|
167 |
+
if type(msg) is tuple:
|
168 |
+
msg, image, image_process_mode = msg
|
169 |
+
image = self.process_image(image, image_process_mode, return_pil=return_pil)
|
170 |
+
images.append(image)
|
171 |
+
return images
|
172 |
+
|
173 |
+
def to_gradio_chatbot(self):
|
174 |
+
ret = []
|
175 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
176 |
+
if i % 2 == 0:
|
177 |
+
if type(msg) is tuple:
|
178 |
+
msg, image, image_process_mode = msg
|
179 |
+
img_b64_str = self.process_image(
|
180 |
+
image, "Default", return_pil=False,
|
181 |
+
image_format='JPEG')
|
182 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
|
183 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
184 |
+
ret.append([msg, None])
|
185 |
+
else:
|
186 |
+
ret.append([msg, None])
|
187 |
+
else:
|
188 |
+
if type(msg) is tuple and len(msg) == 2:
|
189 |
+
msg, img_b64_str = msg
|
190 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
|
191 |
+
msg = msg.strip() + img_str
|
192 |
+
ret[-1][-1] = msg
|
193 |
+
return ret
|
194 |
+
|
195 |
+
def copy(self):
|
196 |
+
return Conversation(
|
197 |
+
system=self.system,
|
198 |
+
roles=self.roles,
|
199 |
+
messages=[[x, y] for x, y in self.messages],
|
200 |
+
offset=self.offset,
|
201 |
+
sep_style=self.sep_style,
|
202 |
+
sep=self.sep,
|
203 |
+
sep2=self.sep2,
|
204 |
+
version=self.version)
|
205 |
+
|
206 |
+
def dict(self):
|
207 |
+
if len(self.get_images()) > 0:
|
208 |
+
return {
|
209 |
+
"system": self.system,
|
210 |
+
"roles": self.roles,
|
211 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
212 |
+
"offset": self.offset,
|
213 |
+
"sep": self.sep,
|
214 |
+
"sep2": self.sep2,
|
215 |
+
}
|
216 |
+
return {
|
217 |
+
"system": self.system,
|
218 |
+
"roles": self.roles,
|
219 |
+
"messages": self.messages,
|
220 |
+
"offset": self.offset,
|
221 |
+
"sep": self.sep,
|
222 |
+
"sep2": self.sep2,
|
223 |
+
}
|
224 |
+
|
225 |
+
|
226 |
+
conv_vicuna_v0 = Conversation(
|
227 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
228 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
229 |
+
roles=("Human", "Assistant"),
|
230 |
+
messages=(
|
231 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
232 |
+
("Assistant",
|
233 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
234 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
235 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
236 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
237 |
+
"renewable and non-renewable energy sources:\n"
|
238 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
239 |
+
"energy sources are finite and will eventually run out.\n"
|
240 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
241 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
242 |
+
"and other negative effects.\n"
|
243 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
244 |
+
"have lower operational costs than non-renewable sources.\n"
|
245 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
246 |
+
"locations than non-renewable sources.\n"
|
247 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
248 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
249 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
250 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
251 |
+
),
|
252 |
+
offset=2,
|
253 |
+
sep_style=SeparatorStyle.SINGLE,
|
254 |
+
sep="###",
|
255 |
+
)
|
256 |
+
|
257 |
+
conv_vicuna_v1 = Conversation(
|
258 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
259 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
260 |
+
roles=("USER", "ASSISTANT"),
|
261 |
+
version="v1",
|
262 |
+
messages=(),
|
263 |
+
offset=0,
|
264 |
+
sep_style=SeparatorStyle.TWO,
|
265 |
+
sep=" ",
|
266 |
+
sep2="</s>",
|
267 |
+
)
|
268 |
+
|
269 |
+
conv_llama_2 = Conversation(
|
270 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
271 |
+
|
272 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
273 |
+
roles=("USER", "ASSISTANT"),
|
274 |
+
version="llama_v2",
|
275 |
+
messages=(),
|
276 |
+
offset=0,
|
277 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
278 |
+
sep="<s>",
|
279 |
+
sep2="</s>",
|
280 |
+
)
|
281 |
+
|
282 |
+
conv_llava_llama_2 = Conversation(
|
283 |
+
system="You are a helpful language and vision assistant. "
|
284 |
+
"You are able to understand the visual content that the user provides, "
|
285 |
+
"and assist the user with a variety of tasks using natural language.",
|
286 |
+
roles=("USER", "ASSISTANT"),
|
287 |
+
version="llama_v2",
|
288 |
+
messages=(),
|
289 |
+
offset=0,
|
290 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
291 |
+
sep="<s>",
|
292 |
+
sep2="</s>",
|
293 |
+
)
|
294 |
+
|
295 |
+
conv_mpt = Conversation(
|
296 |
+
system="""<|im_start|>system
|
297 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
298 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
299 |
+
version="mpt",
|
300 |
+
messages=(),
|
301 |
+
offset=0,
|
302 |
+
sep_style=SeparatorStyle.MPT,
|
303 |
+
sep="<|im_end|>",
|
304 |
+
)
|
305 |
+
|
306 |
+
conv_llava_plain = Conversation(
|
307 |
+
system="",
|
308 |
+
roles=("", ""),
|
309 |
+
messages=(
|
310 |
+
),
|
311 |
+
offset=0,
|
312 |
+
sep_style=SeparatorStyle.PLAIN,
|
313 |
+
sep="\n",
|
314 |
+
)
|
315 |
+
|
316 |
+
conv_llava_v0 = Conversation(
|
317 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
318 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
319 |
+
roles=("Human", "Assistant"),
|
320 |
+
messages=(
|
321 |
+
),
|
322 |
+
offset=0,
|
323 |
+
sep_style=SeparatorStyle.SINGLE,
|
324 |
+
sep="###",
|
325 |
+
)
|
326 |
+
|
327 |
+
conv_llava_v0_mmtag = Conversation(
|
328 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
329 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
330 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
331 |
+
roles=("Human", "Assistant"),
|
332 |
+
messages=(
|
333 |
+
),
|
334 |
+
offset=0,
|
335 |
+
sep_style=SeparatorStyle.SINGLE,
|
336 |
+
sep="###",
|
337 |
+
version="v0_mmtag",
|
338 |
+
)
|
339 |
+
|
340 |
+
conv_llava_v1 = Conversation(
|
341 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
342 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
343 |
+
roles=("USER", "ASSISTANT"),
|
344 |
+
version="v1",
|
345 |
+
messages=(),
|
346 |
+
offset=0,
|
347 |
+
sep_style=SeparatorStyle.TWO,
|
348 |
+
sep=" ",
|
349 |
+
sep2="</s>",
|
350 |
+
)
|
351 |
+
|
352 |
+
conv_vicuna_imgsp_v1 = Conversation(
|
353 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
354 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
355 |
+
roles=("USER", "ASSISTANT"),
|
356 |
+
version="imgsp_v1",
|
357 |
+
messages=(),
|
358 |
+
offset=0,
|
359 |
+
sep_style=SeparatorStyle.TWO,
|
360 |
+
sep=" ",
|
361 |
+
sep2="</s>",
|
362 |
+
)
|
363 |
+
|
364 |
+
conv_llava_plain_guided = Conversation(
|
365 |
+
system="",
|
366 |
+
roles=("", ""),
|
367 |
+
version="plain_guided",
|
368 |
+
messages=(
|
369 |
+
),
|
370 |
+
offset=0,
|
371 |
+
sep_style=SeparatorStyle.PLAIN,
|
372 |
+
sep="\n",
|
373 |
+
)
|
374 |
+
|
375 |
+
conv_llava_v1_mmtag = Conversation(
|
376 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
377 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
378 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
379 |
+
roles=("USER", "ASSISTANT"),
|
380 |
+
messages=(),
|
381 |
+
offset=0,
|
382 |
+
sep_style=SeparatorStyle.TWO,
|
383 |
+
sep=" ",
|
384 |
+
sep2="</s>",
|
385 |
+
version="v1_mmtag",
|
386 |
+
)
|
387 |
+
|
388 |
+
conv_phi_2 = Conversation(
|
389 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
390 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
391 |
+
roles=("USER", "ASSISTANT"),
|
392 |
+
version="phi2",
|
393 |
+
messages=(),
|
394 |
+
offset=0,
|
395 |
+
sep_style=SeparatorStyle.TWO,
|
396 |
+
sep=" ",
|
397 |
+
sep2="<|endoftext|>",
|
398 |
+
)
|
399 |
+
|
400 |
+
conv_mistral_instruct = Conversation(
|
401 |
+
system="",
|
402 |
+
roles=("USER", "ASSISTANT"),
|
403 |
+
version="llama_v2",
|
404 |
+
messages=(),
|
405 |
+
offset=0,
|
406 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
407 |
+
sep="<s>",
|
408 |
+
sep2="</s>",
|
409 |
+
)
|
410 |
+
|
411 |
+
conv_gemma = Conversation(
|
412 |
+
system="",
|
413 |
+
roles=("user", "model"),
|
414 |
+
version="gemma",
|
415 |
+
messages=(),
|
416 |
+
offset=0,
|
417 |
+
sep_style=SeparatorStyle.GEMMA,
|
418 |
+
sep="",
|
419 |
+
sep2="<eos>",
|
420 |
+
)
|
421 |
+
|
422 |
+
conv_chatml_direct = Conversation(
|
423 |
+
system="""<|im_start|>system
|
424 |
+
Answer the questions.""",
|
425 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
426 |
+
version="mpt",
|
427 |
+
messages=(),
|
428 |
+
offset=0,
|
429 |
+
sep_style=SeparatorStyle.MPT,
|
430 |
+
sep="<|im_end|>",
|
431 |
+
)
|
432 |
+
|
433 |
+
default_conversation = conv_vicuna_v1
|
434 |
+
conv_templates = {
|
435 |
+
"default": conv_vicuna_v0,
|
436 |
+
"v0": conv_vicuna_v0,
|
437 |
+
"v1": conv_vicuna_v1,
|
438 |
+
"vicuna_v1": conv_vicuna_v1,
|
439 |
+
"phi_2": conv_phi_2,
|
440 |
+
"gemma": conv_gemma,
|
441 |
+
"llama_2": conv_llama_2,
|
442 |
+
"imgsp_v1": conv_vicuna_imgsp_v1,
|
443 |
+
"plain_guided": conv_llava_plain_guided,
|
444 |
+
"mistral_instruct": conv_mistral_instruct,
|
445 |
+
"chatml_direct": conv_chatml_direct,
|
446 |
+
"mistral_direct": conv_chatml_direct,
|
447 |
+
"plain": conv_llava_plain,
|
448 |
+
"v0_plain": conv_llava_plain,
|
449 |
+
"llava_v0": conv_llava_v0,
|
450 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
451 |
+
"llava_v1": conv_llava_v1,
|
452 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
453 |
+
"llava_llama_2": conv_llava_llama_2,
|
454 |
+
|
455 |
+
"mpt": conv_mpt,
|
456 |
+
}
|
457 |
+
|
458 |
+
|
459 |
+
if __name__ == "__main__":
|
460 |
+
print(default_conversation.get_prompt())
|
helpers.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn import functional as F
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
|
6 |
+
### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
|
7 |
+
def top_k_top_p_filtering(
|
8 |
+
logits,
|
9 |
+
top_k: int = 0,
|
10 |
+
top_p: float = 1.0,
|
11 |
+
filter_value: float = -float("Inf"),
|
12 |
+
min_tokens_to_keep: int = 1,
|
13 |
+
):
|
14 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
15 |
+
Args:
|
16 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
17 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
18 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
19 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
20 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
21 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
22 |
+
"""
|
23 |
+
|
24 |
+
logits[:,:256000]=filter_value
|
25 |
+
if top_k > 0:
|
26 |
+
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
|
27 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
28 |
+
|
29 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
30 |
+
logits[indices_to_remove] = filter_value
|
31 |
+
|
32 |
+
if top_p < 1.0:
|
33 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
34 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
35 |
+
|
36 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
37 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
38 |
+
if min_tokens_to_keep > 1:
|
39 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
40 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
41 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
42 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
43 |
+
sorted_indices_to_remove[..., 0] = 0
|
44 |
+
|
45 |
+
# scatter sorted tensors to original indexing
|
46 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
47 |
+
logits[indices_to_remove] = filter_value
|
48 |
+
# import pdb;pdb.set_trace()
|
49 |
+
return logits
|
50 |
+
|
51 |
+
|
52 |
+
def sample(logits, temperature: float=1.0, top_k: int=0, top_p: float=1.0, sample_logits=True):
|
53 |
+
logits = logits[:, -1, :] / max(temperature, 1e-5)
|
54 |
+
if top_k > 0 or top_p < 1.0:
|
55 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
56 |
+
probs = F.softmax(logits, dim=-1)
|
57 |
+
if sample_logits:
|
58 |
+
idx = torch.multinomial(probs, num_samples=1)
|
59 |
+
else:
|
60 |
+
_, idx = torch.topk(probs, k=1, dim=-1)
|
61 |
+
return idx, probs
|
62 |
+
|
63 |
+
|
64 |
+
def expand2square(pil_img, background_color):
|
65 |
+
width, height = pil_img.size
|
66 |
+
if width == height:
|
67 |
+
return pil_img
|
68 |
+
elif width > height:
|
69 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
70 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
71 |
+
return result
|
72 |
+
else:
|
73 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
74 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
75 |
+
return result
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=-200, return_tensors=None):
|
80 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
81 |
+
|
82 |
+
def insert_separator(X, sep):
|
83 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
84 |
+
|
85 |
+
input_ids = []
|
86 |
+
offset = 0
|
87 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
88 |
+
offset = 1
|
89 |
+
input_ids.append(prompt_chunks[0][0])
|
90 |
+
|
91 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
92 |
+
input_ids.extend(x[offset:])
|
93 |
+
|
94 |
+
if return_tensors is not None:
|
95 |
+
if return_tensors == 'pt':
|
96 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
97 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
98 |
+
return input_ids
|
99 |
+
|
requirements.txt
CHANGED
@@ -1 +1,6 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers==4.39.2
|
3 |
+
spaces
|
4 |
+
pillow
|
5 |
+
accelerate
|
6 |
+
tqdm
|