Junfeng5 committed on
Commit 0d7e8be · verified · 1 Parent(s): c1f6040

Upload 31 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ baklava.png filter=lfs diff=lfs merge=lfs -text
+ Liquid_icon.png filter=lfs diff=lfs merge=lfs -text
Liquid_icon.png ADDED

Git LFS Details

  • SHA256: 7d65c5aa3ed6ebc4d9327b3962690cda4ada81b9359daf5dcbe9528f0635f0b6
  • Pointer size: 131 Bytes
  • Size of remote file: 113 kB
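
For reference, what Git actually stores in the repository for an LFS-tracked file like this one is a small text pointer; a sketch of its format (the size field records the exact byte count, roughly 113 kB here):

    version https://git-lfs.github.com/spec/v1
    oid sha256:7d65c5aa3ed6ebc4d9327b3962690cda4ada81b9359daf5dcbe9528f0635f0b6
    size <exact byte count, about 113 kB>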
README.md CHANGED
@@ -1,14 +0,0 @@
- ---
- title: Liquid Demo
- emoji: 💬
- colorFrom: yellow
- colorTo: purple
- sdk: gradio
- sdk_version: 5.0.1
- app_file: app.py
- pinned: false
- license: mit
- short_description: A unified understanding and generation multimodal model
- ---
-
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
app.py CHANGED
@@ -1,64 +1,356 @@
  import gradio as gr
- from huggingface_hub import InferenceClient

  """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
  """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


- def respond(
- message,
- history: list[tuple[str, str]],
- system_message,
- max_tokens,
- temperature,
- top_p,
- ):
- messages = [{"role": "system", "content": system_message}]

- for val in history:
- if val[0]:
- messages.append({"role": "user", "content": val[0]})
- if val[1]:
- messages.append({"role": "assistant", "content": val[1]})

- messages.append({"role": "user", "content": message})

- response = ""

- for message in client.chat_completion(
- messages,
- max_tokens=max_tokens,
- stream=True,
- temperature=temperature,
- top_p=top_p,
- ):
- token = message.choices[0].delta.content

- response += token
- yield response


- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
- respond,
- additional_inputs=[
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
- gr.Slider(
- minimum=0.1,
- maximum=1.0,
- value=0.95,
- step=0.05,
- label="Top-p (nucleus sampling)",
- ),
- ],
- )
-
-
- if __name__ == "__main__":
- demo.launch()
1
+ import time
2
+ from threading import Thread
3
+
4
  import gradio as gr
5
+ import torch
6
+ import PIL
7
+ from PIL import Image
8
+ from transformers import AutoConfig, AutoModelForCausalLM
+ from transformers import AutoTokenizer
+ import os
12
+ from tqdm import tqdm
13
+ from chameleon.inference.image_tokenizer import ImageTokenizer
14
+ from helpers import sample, expand2square, tokenizer_image_token
15
+
16
+ # from transformers import AutoProcessor, LlavaForConditionalGeneration
17
+ from transformers import TextIteratorStreamer
18
+ from conversation import conv_templates
19
+ import spaces
20
+
21
+
22
+ import os
23
+ os.system("pip uninstall -y gradio")
24
+ os.system("pip install gradio==4.44.1")
25
+ os.system("pip install gradio_client==1.3.0")
26
+
27
 
28
+ IMAGE_TOKEN_INDEX=-200
29
+ PLACEHOLDER = """
30
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
31
+ <img src='file/Liquid_icon.png' style="width: 80%; max-width: 600px; height: auto; opacity: 0.5;">
32
+ <h1 style="font-size: 20px; margin-bottom: 1px; opacity: 0.55;">Liquid-7B</h1>
33
+ </div>
34
  """
35
+
36
+ CSS ="""
37
+ .contain { display: flex; flex-direction: column; }
38
+ #component-0 { height: 100%; }
39
+ #chatbot { flex-grow: 1; }
40
  """
 
41
 
42
 
43
+ title_html = """
44
+ <div style="display: flex; flex-direction: column; align-items: center; gap: 10px;">
45
+ <h1 style="margin: 0; line-height: 1; text-align: center;"> Liquid: Language Models are Scalable Multi-modal <br> Generators via Unified Understanding and Generation</h1>
46
+ </div>
47
+ """
48
 
49
+ links_html = f"""
50
+ <center><font size=3><a href='https://foundationvision.github.io/Liquid/'>Liquid</a> has been open-sourced on <a href='https://huggingface.co/Junfeng5/Liquid_V1_7B'>😊 Huggingface</a> and <a href='https://github.com/FoundationVision/Liquid'>🌟 GitHub</a>. If you find Liquid useful, a like❤️ or a star🌟 would be appreciated.</font></center>
51
+ """
 
 
52
 
53
+ introduction = f"""
54
+ Liquid explores the potential of a single LLM as a multimodal generator and its scaling laws. It achieves the level of diffusion models in visual generation and discovers the mutual enhancement between understanding and generation. More details can be found on the project <a href='https://foundationvision.github.io/Liquid/'> homepage</a> and in the <a href='https://arxiv.org/abs/2412.04332'> paper</a>. """
55
 
 
56
 
57
 
58
+ model_id = 'Junfeng5/Liquid_V1_7B'
59
+ tokenizer = AutoTokenizer.from_pretrained(model_id,padding_side='left')
60
+ vqllm = AutoModelForCausalLM.from_pretrained(
61
+ model_id,
62
+ attn_implementation='flash_attention_2',
63
+ torch_dtype=torch.bfloat16,
64
+ load_in_8bit=True,
65
+ max_memory={0: "40GiB" },
66
+ ) # .to("cuda:0")
67
 
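
A side note on the 8-bit load above: newer transformers releases prefer routing load_in_8bit through a BitsAndBytesConfig. A minimal equivalent sketch, assuming bitsandbytes is installed (not part of this commit):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    vqllm = AutoModelForCausalLM.from_pretrained(
        'Junfeng5/Liquid_V1_7B',
        attn_implementation='flash_attention_2',
        torch_dtype=torch.bfloat16,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # replaces the bare load_in_8bit flag
    )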
68
+ stop_flag = False
69
 
70
+ ori_vocabe_size = len(tokenizer)
71
+
72
+ vqgan_cfg_path = "chameleon/vqgan.yaml"
73
+ vqgan_ckpt_path = "chameleon/vqgan.ckpt"
74
+ image_tokenizer = ImageTokenizer( cfg_path=vqgan_cfg_path, ckpt_path=vqgan_ckpt_path, device="cuda:0",)
75
+
76
+ @spaces.GPU
77
+ def bot_streaming_I2T(message, history):
78
+ print(message)
79
+ global stop_flag
80
+ stop_flag = True
81
+ time.sleep(0.2)
82
+ stop_flag = False
83
+ torch.cuda.empty_cache()
84
+ if message["files"]:
85
+ # message["files"][-1] is a Dict or just a string
86
+ if type(message["files"][-1]) == dict:
87
+ image = message["files"][-1]["path"]
88
+ else:
89
+ image = message["files"][-1]
90
+ else:
91
+ # if there's no image uploaded for this turn, look for images in the past turns
92
+ # kept inside tuples, take the last one
93
+ for hist in history:
94
+ if type(hist[0]) == tuple:
95
+ image = hist[0][0]
96
+ try:
97
+ if image is None:
98
+ # Handle the case where image is None
99
+ gr.Error("You need to upload an image for LLaVA to work.")
100
+ except NameError:
101
+ # Handle the case where 'image' is not defined at all
102
+ gr.Error("You need to upload an image for LLaVA to work.")
103
+
104
+ qs = message['text']
105
+ qs = '<boi><image><eoi>' + '\n' + qs
106
+ conv = conv_templates['gemma'].copy()
107
+ conv.append_message(conv.roles[0], qs)
108
+ conv.append_message(conv.roles[1], None)
109
+ prompt = conv.get_prompt()
110
+
111
+
112
+ print(prompt)
113
+ image = Image.open(image).convert('RGB')
114
+ pad_image = expand2square(image, (122, 116, 104) )
115
+ input_image = pad_image.resize((512,512), PIL.Image.LANCZOS)
116
+ with torch.no_grad():
117
+ vq_code = image_tokenizer.img_tokens_from_pil(input_image)
118
+ vqcode = vq_code.cpu()
119
+ vqcode = vqcode+ len(tokenizer)
120
+
121
+ text_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
122
+ num_images = (text_ids == IMAGE_TOKEN_INDEX).sum()
123
+ image_token_indices = [-1] + torch.where(text_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [text_ids.shape[0]]
124
+ cur_input_ids = []
125
+ for i in range(num_images + 1):
126
+ cur_input_ids.append(text_ids[image_token_indices[i]+1:image_token_indices[i+1]])
127
+ if i < num_images:
128
+ cur_input_ids.append( vqcode )
129
+ input_ids = torch.cat(cur_input_ids, dim=0)
130
+ # input_embeddings = vqllm.embed_tokens(input_ids)
131
+ inputs = {
132
+ "input_ids":input_ids.unsqueeze(0).to("cuda:0"),
133
+ "max_new_tokens":1024,
134
+ "bos_token_id":tokenizer.bos_token_id, # Begin of sequence token
135
+ "eos_token_id":tokenizer.eos_token_id, # End of sequence token
136
+ "pad_token_id":tokenizer.pad_token_id, # Pad token
137
+ }
138
+ streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": False, "skip_prompt": True})
139
+
140
+ # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
141
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
142
+ thread = Thread(target=vqllm.generate, kwargs=generation_kwargs)
143
+ thread.start()
144
+ generated_text = ""
145
+ for new_text in streamer:
146
+ generated_text += new_text
147
+ time.sleep(0.06)
148
+ yield generated_text
149
+
150
+
151
+
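
The splicing in bot_streaming_I2T above is the heart of the image-to-text path: the uploaded image becomes 1024 VQGAN codes, the codes are shifted past the text vocabulary, and they are dropped in where the image placeholder (id -200) sat. A toy, runnable sketch of that splice (ids, sizes, and the vocabulary offset are made up):

    import torch

    text_ids = torch.tensor([2, 7, -200, 8, 9])   # -200 marks the image slot
    vq_code = torch.tensor([5, 11, 3])             # stand-in for the 1024 VQGAN codes
    text_vocab_size = 32000                        # stand-in for len(tokenizer)
    shifted = vq_code + text_vocab_size            # image ids live after the text vocabulary

    slot = (text_ids == -200).nonzero(as_tuple=True)[0].item()
    input_ids = torch.cat([text_ids[:slot], shifted, text_ids[slot + 1:]])
    print(input_ids)  # [2, 7, 32005, 32011, 32003, 8, 9]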
152
+ def show_gallery(images):
153
+ gallery = gr.Gallery(images, label="Gallery", columns=4, height="auto",preview=True,scale=0.05)  # lay the generated images out in a gallery grid
154
+ return gallery
155
+
156
+ @spaces.GPU
157
+ def bot_streaming_T2I(message, history,guidance_scale, temperature, top_K, top_P):
158
+
159
+ global stop_flag
160
+ stop_flag = True
161
+ time.sleep(0.2)
162
+ stop_flag = False
163
+
164
+ text_inputs = [message]*4 # generate 4 samples once
165
+ uncondition_text_inputs = ['<unconditional><boi>']*len(text_inputs)
166
+ for i in range(len(text_inputs)):
167
+ text_inputs[i] = text_inputs[i]+' Generate an image based on this description.<boi>'
168
+
169
+ ori_batchsize = len(text_inputs)
170
+
171
+ if guidance_scale>1:
172
+ model_inputs = tokenizer(text_inputs+uncondition_text_inputs, return_tensors="pt",padding=True).to("cuda:0")
173
+ else:
174
+ model_inputs = tokenizer(text_inputs, return_tensors="pt",padding=True).to("cuda:0")
175
+ with torch.no_grad():
176
+ sampling_kwargs={'temperature': temperature, 'top_k': top_K, 'top_p': top_P, 'sample_logits': True}
177
+ input_ids = model_inputs['input_ids']
178
+ cur_len = input_ids.shape[1]
179
+ model_kwargs = {'attention_mask':model_inputs['attention_mask'] , 'use_cache': True}
180
+ model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
181
+
182
+ pred_tokens = []
183
+ for i in tqdm(range(1024)):
184
+ if stop_flag:
185
+ print("generation is stoped")
186
+ del sampling_kwargs
187
+ del model_inputs
188
+ del outputs
189
+ torch.cuda.empty_cache()
190
+ break
191
+ model_inputs = vqllm.prepare_inputs_for_generation(input_ids, **model_kwargs)
192
+
193
+ if i > 0 and guidance_scale>1:
194
+ outputs = vqllm(
195
+ **model_inputs,
196
+ return_dict=True,
197
+ output_attentions=False,
198
+ output_hidden_states=False,
199
+ )
200
+ else:
201
+ outputs = vqllm(
202
+ **model_inputs,
203
+ return_dict=True,
204
+ output_attentions=False,
205
+ output_hidden_states=False,
206
+ )
207
+
208
+ next_token_logits = outputs.logits[:, -1:, :]
209
+
210
+ if guidance_scale>1:
211
+ cond_logits, uncond_logits = torch.split(next_token_logits, len(next_token_logits) // 2, dim=0)
212
+ cfg_logits = uncond_logits + (cond_logits - uncond_logits) * guidance_scale
213
+ half_next_token, _ = sample(cfg_logits, **sampling_kwargs)
214
+ pred_tokens.append(half_next_token)
215
+ next_token = torch.cat([half_next_token,half_next_token])
216
+
217
+
218
+ else:
219
+ next_token, next_prob = sample(next_token_logits, **sampling_kwargs)
220
+ pred_tokens.append(next_token)
221
+
222
+ # update generated ids, model inputs, and length for next step
223
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
224
+ model_kwargs = vqllm._update_model_kwargs_for_generation(
225
+ outputs,
226
+ model_kwargs,
227
+ is_encoder_decoder=vqllm.config.is_encoder_decoder,
228
+ )
229
+
230
+ del sampling_kwargs
231
+ del model_inputs
232
+ del outputs
233
+ image_vq_id = torch.cat(pred_tokens,dim=1)-ori_vocabe_size
234
+ image_vq_id = torch.clamp(image_vq_id, min=0, max=8191)
235
+
236
+ generated_image_list = []
237
+ for index, generate_id in enumerate(image_vq_id):
238
+ rec_img = image_tokenizer.pil_from_img_toks(generate_id)
239
+ generated_image_list.append(rec_img)
240
+ # rec_img.save('{}/{}.jpg'.format(image_save_pth,str(idx)))
241
+
242
+ torch.cuda.empty_cache()
243
+ # yield gr.Image(value=generated_image_list[0], label="Generated Image", show_download_button=True)
244
+ yield show_gallery(generated_image_list)
245
+
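
The guidance step inside bot_streaming_T2I runs the conditional and unconditional prompts in one batch and then pushes the conditional logits away from the unconditional ones. A small numeric sketch of that formula (the values are invented):

    import torch

    cond_logits = torch.tensor([[2.0, 0.5, -1.0]])
    uncond_logits = torch.tensor([[1.0, 1.0, 0.0]])
    guidance_scale = 5.0

    cfg_logits = uncond_logits + (cond_logits - uncond_logits) * guidance_scale
    print(cfg_logits)  # [[6.0, -1.5, -5.0]]

With guidance_scale equal to 1 the formula reduces to the conditional logits; larger values sharpen adherence to the text prompt at the cost of diversity.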
246
+ @spaces.GPU
247
+ def bot_streaming_T2T(message, history,temperature):
248
+ print(message)
249
+ global stop_flag
250
+ stop_flag = True
251
+ time.sleep(0.2)
252
+ stop_flag = False
253
+ torch.cuda.empty_cache()
254
+ qs = message
255
+ conv = conv_templates['gemma'].copy()
256
+ conv.append_message(conv.roles[0], qs)
257
+ conv.append_message(conv.roles[1], None)
258
+ prompt = conv.get_prompt()
259
+
260
+ print(prompt)
261
+ with torch.no_grad():
262
+ inputs = tokenizer([prompt], return_tensors="pt").to('cuda')
263
+ streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": False, "skip_prompt": True})
264
+
265
+ # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
266
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
267
+ thread = Thread(target=vqllm.generate, kwargs=generation_kwargs)
268
+ thread.start()
269
+ generated_text = ""
270
+ for new_text in streamer:
271
+ generated_text += new_text
272
+ yield generated_text
273
+
274
+
275
+ chatbot_T2I=gr.Chatbot(placeholder=PLACEHOLDER,height=600)
276
+ chat_input_T2I = gr.Textbox(placeholder="Enter text prompts...", show_label=False)
277
+
278
+ chatbot_I2T=gr.Chatbot(placeholder=PLACEHOLDER, scale=1)
279
+ chat_input_I2T = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
280
+
281
+ chatbot_T2T=gr.Chatbot(placeholder=PLACEHOLDER, scale=1)
282
+ chat_input_T2T = gr.Textbox(placeholder="Enter text prompts...", show_label=False)
283
+
284
+
285
+ with gr.Blocks(fill_height=True) as demo:
286
+
287
+ gr.Markdown(title_html)
288
+ gr.Markdown(links_html)
289
+ gr.Markdown(introduction)
290
+
291
+ with gr.Tab("Text To Image"):
292
+
293
+ description="Enter a text prompt or simply try one of the examples below to generate 4 images at once. Click to display the full image. You can configure hyperparameters for image generation in the Advanced Settings. "
294
+ gr.Markdown(description)
295
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
296
+ with gr.Row():
297
+ guidance_scale = gr.Slider(1.0, 20.0, value=7.0, label="Guidance Scale")
298
+ temperature = gr.Slider(0.0, 1.0, value=0.9, label="temperature")
299
+ top_K = gr.Slider(1, 8192, value=4096, label="Top K")
300
+ top_P = gr.Slider(0.0, 1.0, value=0.99, label="Top P")
301
+
302
+ aaa = gr.ChatInterface(
303
+ fn=bot_streaming_T2I,
304
+ examples=[
305
+ ["young blue dragon with horn lightning in the style of dd fantasy full body",5.0, 0.9,4096,0.99],
306
+ ["A majestic Goddes of beauty, charming dressed in a regal, jeweled gown and ornate crown, her golden hair cascading down her back, in the style of Pino Daeni",5.0, 0.9,4096,0.99],
307
+ ["A highly realistic, closeup photograph of a beautiful 35 year old redread woman writing in her journal, sitting on her balcony wearing warm, stylish outfits. Shot on a Canon EOS R5, the image boasts sharp focus and intricate details. The heartwarming scene conveys love, connection, and the crisp winter atmosphere, dramatic lighting.",5.0, 0.9,4096,0.99],
308
+ ["Portrait of an asian woman. She has pink violet hair style with modern complex hairdressing. The background is dark with cyberpunk neon lights. Inspired by Cyberpunk 2077 and Blade Runner. Ultra realistic picture. To capture the image, you will use a fullframe DSLR or mirrorless camera with a highresolution sensor, an aperture of f2.8 or wider, and a shutter speed of 1500 second or faster. You will use natural light and reflectors to create a balanced and welllit image, and will experiment with different angles and compositions to create the most i",5.0, 0.9,4096,0.99],
309
+ ["female character fantasy world, for fantasy story, protagonist, interesting and detailed clothes, beautiful, medieval fantasy cinematic shot photo taken by canon, photo taken by fuji, photo taken by kodak incredibly detailed, sharpen, details professional lighting , film lighting 350mm lightroom cinematography, hyper realism, cinematic, film quality",5.0, 0.9,4096,0.99],
310
+ ["strawberries splashing, swirling liquid, realism, octane render, raytracing",5.0, 0.9,4096,0.99],
311
+ ["hedgehog face, floating in space, wearing space suit no helmet, cinematic, 50mm f1.8, unreal engine 5",5.0, 0.9,4096,0.99],
312
+ ["artificial intelligence, revolution, publishing, writer, hyperrealistic",5.0, 0.9,4096,0.99],
313
+ ["A pig dressed as a mason, by Bill Gekas",5.0, 0.9,4096,0.99],
314
+ ],
315
+ stop_btn="Stop Generation",
316
+ additional_inputs = [guidance_scale, temperature, top_K, top_P],
317
+ additional_inputs_accordion="⚙️ Advanced Settings",
318
+ multimodal=False,
319
+ textbox=chat_input_T2I,
320
+ chatbot=chatbot_T2I,
321
+ fill_height=True,
322
+ )
323
+
324
+
325
+
326
+
327
+ with gr.Tab("Image To Text"):
328
+ bbb = gr.ChatInterface(
329
+ fn=bot_streaming_I2T,
330
+ examples=[ {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
331
+ description="Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
332
+ stop_btn="Stop Generation",
333
+ multimodal=True,
334
+ textbox=chat_input_I2T,
335
+ chatbot=chatbot_I2T,
336
+ )
337
+
338
+ with gr.Tab("Text To Text"):
339
+
340
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
341
+ with gr.Row():
342
+ texttemperature = gr.Slider(0.0, 1.0, value=0.9, label="texttemperature")
343
+
344
+ gr.ChatInterface(
345
+ fn=bot_streaming_T2T,
346
+ examples=[["a dog", 0.9]],
347
+ description="Chat with Liquid without images.",
348
+ stop_btn="Stop Generation",
349
+ additional_inputs = [texttemperature],
350
+ additional_inputs_accordion="⚙️ Advanced Settings",
351
+ multimodal=False,
352
+ textbox=chat_input_T2T,
353
+ chatbot=chatbot_T2T,
354
+ )
355
+ demo.queue(api_open=False)
356
+ demo.launch(allowed_paths=["./"], share=False )
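
The app pins Gradio at runtime through the os.system calls near the top of the file; for a reproducible environment, a hypothetical requirements pin (inferred from the imports and model settings, not shipped in this commit) might look like:

    gradio==4.44.1
    gradio_client==1.3.0
    torch
    transformers
    accelerate          # assumed, needed for load_in_8bit
    bitsandbytes        # assumed, needed for load_in_8bit
    flash-attn          # assumed, needed for attn_implementation='flash_attention_2'
    spaces
    tqdm
    Pillow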
baklava.png ADDED

Git LFS Details

  • SHA256: 7839e93dd753e5356176bf70d38c43bc56355099d8891ead7aaa342029369268
  • Pointer size: 132 Bytes
  • Size of remote file: 2.04 MB
chameleon/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
chameleon/download_data.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Chameleon License Agreement.
3
+
4
+ import hashlib
5
+ from pathlib import Path
6
+ import subprocess
7
+ import sys
8
+
9
+
10
+ def download_file(url: str, output_path: Path):
11
+ print(f"Downloading {output_path}")
12
+ subprocess.check_call(["wget", "--continue", url, "-O", str(output_path)])
13
+
14
+
15
+ def validate_checksum(folder: Path):
16
+ chks_parts = (folder / "checklist.chk").read_text().split()
17
+ for expected_checksum, file in zip(chks_parts[::2], chks_parts[1::2]):
18
+ file_path = folder / file
19
+ checksum = hashlib.md5(file_path.read_bytes()).hexdigest()
20
+ if checksum != expected_checksum:
21
+ print(f"Checksum mismatch for {file_path}")
22
+ sys.exit(1)
23
+
24
+
25
+ def download_tokenizer(presigned_url: str, target_folder: Path):
26
+ tokenizer_folder = target_folder / "tokenizer"
27
+ tokenizer_folder.mkdir(parents=True, exist_ok=True)
28
+
29
+ for filename in [
30
+ "text_tokenizer.json",
31
+ "vqgan.ckpt",
32
+ "vqgan.yaml",
33
+ "checklist.chk",
34
+ ]:
35
+ download_file(
36
+ presigned_url.replace("*", f"tokenizer/{filename}"),
37
+ tokenizer_folder / filename,
38
+ )
39
+
40
+ validate_checksum(tokenizer_folder)
41
+
42
+
43
+ def download_model(presigned_url: str, target_folder: Path, model: str):
44
+ model_folder = target_folder / "models" / model
45
+ model_folder.mkdir(parents=True, exist_ok=True)
46
+
47
+ download_filenames = ["params.json", "consolidate_params.json", "checklist.chk"]
48
+
49
+ if model == "7b":
50
+ download_filenames += ["consolidated.pth"]
51
+ elif model == "30b":
52
+ download_filenames += [f"consolidated.{i:02}.pth" for i in range(4)]
53
+ else:
54
+ print(f"Unknown model: {model}")
55
+ sys.exit(1)
56
+
57
+ for filename in download_filenames:
58
+ download_file(
59
+ presigned_url.replace("*", f"{model}/{filename}"),
60
+ model_folder / filename,
61
+ )
62
+
63
+ validate_checksum(model_folder)
64
+
65
+
66
+ def main():
67
+ presigned_url = (
68
+ sys.argv[1] if len(sys.argv) > 1 else input("Enter the URL from email: ")
69
+ )
70
+
71
+ target_folder = Path("./data")
72
+ target_folder.mkdir(parents=True, exist_ok=True)
73
+
74
+ download_tokenizer(presigned_url, target_folder)
75
+
76
+ model_size = input(
77
+ "Enter the list of models to download without spaces (7B,30B), or press Enter for all: "
78
+ )
79
+ if not model_size:
80
+ model_size = "7B,30B"
81
+
82
+ for model in model_size.split(","):
83
+ model = model.strip().lower()
84
+ download_model(presigned_url, target_folder, model)
85
+
86
+
87
+ if __name__ == "__main__":
88
+ main()
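
validate_checksum above reads checklist.chk as whitespace-separated (md5, filename) pairs. A hypothetical example of the file it expects (the hashes here are placeholders, not the real ones):

    d41d8cd98f00b204e9800998ecf8427e  text_tokenizer.json
    9e107d9d372bb6826bd81d3542a419d6  vqgan.yaml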
chameleon/inference/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
chameleon/inference/alignment.py ADDED
@@ -0,0 +1,79 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import torch
9
+
10
+
11
+ class PromptAlignment(ABC):
12
+ @abstractmethod
13
+ def start_index(self, input_ids: list[list[int]]) -> int:
14
+ ...
15
+
16
+ @abstractmethod
17
+ def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
18
+ ...
19
+
20
+ @abstractmethod
21
+ def postprocess_inputs(
22
+ self, inputs: torch.Tensor, original_inputs: torch.Tensor
23
+ ) -> torch.Tensor:
24
+ ...
25
+
26
+
27
+ class AlignPromptRight(PromptAlignment):
28
+ def __init__(self, pad_id: int):
29
+ self.pad_id = pad_id
30
+
31
+ def start_index(self, input_ids: list[list[int]]) -> int:
32
+ return max(len(sublist) for sublist in input_ids)
33
+
34
+ def prepare_inputs(self, input_ids: list[list[int]]) -> torch.LongTensor:
35
+ max_length = max(len(sublist) for sublist in input_ids)
36
+ return torch.tensor(
37
+ [
38
+ ([self.pad_id] * (max_length - len(sublist))) + sublist
39
+ for sublist in input_ids
40
+ ],
41
+ requires_grad=False,
42
+ )
43
+
44
+ def postprocess_inputs(
45
+ self,
46
+ inputs: torch.Tensor,
47
+ original_inputs: torch.Tensor,
48
+ ) -> torch.Tensor:
49
+ return inputs
50
+
51
+
52
+ class AlignPromptLeft(PromptAlignment):
53
+ def __init__(self, pad_id: int = -1):
54
+ self.pad_id = pad_id
55
+
56
+ def start_index(self, input_ids: list[list[int]]) -> int:
57
+ return min(len(sublist) for sublist in input_ids)
58
+
59
+ def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
60
+ max_length = max(len(sublist) for sublist in input_ids)
61
+ return torch.tensor(
62
+ [
63
+ sublist + ([self.pad_id] * (max_length - len(sublist)))
64
+ for sublist in input_ids
65
+ ],
66
+ requires_grad=False,
67
+ )
68
+
69
+ def postprocess_inputs(
70
+ self,
71
+ inputs: torch.Tensor,
72
+ original_inputs: torch.Tensor,
73
+ ) -> torch.Tensor:
74
+ max_init_len = original_inputs.shape[1]
75
+ if inputs.shape[1] <= max_init_len:
76
+ original_inputs_limited = original_inputs[:, : inputs.shape[1]]
77
+ mask = original_inputs_limited != self.pad_id
78
+ inputs[mask] = original_inputs_limited[mask]
79
+ return inputs
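
The two alignment strategies differ only in where padding goes and where generation starts. A quick worked sketch with a ragged two-prompt batch (toy ids):

    from chameleon.inference.alignment import AlignPromptLeft, AlignPromptRight

    batch = [[5, 6], [7, 8, 9]]

    right = AlignPromptRight(pad_id=0)
    print(right.prepare_inputs(batch))  # [[0, 5, 6], [7, 8, 9]], prompts end flush right
    print(right.start_index(batch))     # 3, generation starts after the longest prompt

    left = AlignPromptLeft(pad_id=-1)
    print(left.prepare_inputs(batch))   # [[5, 6, -1], [7, 8, 9]], padding trails
    print(left.start_index(batch))      # 2, the shortest prompt length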
chameleon/inference/chameleon.py ADDED
@@ -0,0 +1,673 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import base64
7
+ import io
8
+ import json
9
+ import math
10
+ import queue
11
+ import threading
12
+ from dataclasses import dataclass, field
13
+ from enum import Enum
14
+ from multiprocessing import managers, queues, synchronize
15
+ from typing import Literal, Union
16
+
17
+ import PIL
18
+ import torch
19
+ import torch.distributed as dist
20
+ import torch.multiprocessing as mp
21
+ from PIL.Image import Image
22
+ from tokenizers import Tokenizer
23
+ from transformers import (
24
+ LogitsProcessor,
25
+ RepetitionPenaltyLogitsProcessor,
26
+ TemperatureLogitsWarper,
27
+ TopPLogitsWarper,
28
+ enable_full_determinism,
29
+ )
30
+
31
+ from chameleon.inference import loader
32
+ from chameleon.inference.alignment import AlignPromptRight
33
+ from chameleon.inference.generation import ChameleonGenerator
34
+ from chameleon.inference.image_tokenizer import ImageTokenizer
35
+ from chameleon.inference.logits_processor import (
36
+ AllowOnlyTokensLogitsProcessor,
37
+ DisallowTokensAtOrAfterIndexLogitsProcessor,
38
+ InBatchInstructCFGLogitsProcessor,
39
+ )
40
+ from chameleon.inference.model_adapter import ChameleonModelAdapter
41
+ from chameleon.inference.stopping_criteria import (
42
+ MaxLengthCriteria,
43
+ StopOnEOSAfterBatchIndex,
44
+ )
45
+ from chameleon.inference.token_selector import (
46
+ ArgmaxTokenSelector,
47
+ MultinomialTokenSelector,
48
+ ReplicatedInputTokenSelector,
49
+ )
50
+ from chameleon.inference.transformer import Transformer
51
+ from chameleon.inference.utils import DynamicGenerator, advance, random_unused_port
52
+ from chameleon.inference.vocab import VocabInfo, VocabTranslation
53
+
54
+
55
+ @dataclass
56
+ class Options:
57
+ @dataclass
58
+ class Text:
59
+ repetition_penalty: float = 1.2
60
+ temp: float = 0.7
61
+ top_p: float = 0.9
62
+ greedy: bool = False
63
+
64
+ @dataclass
65
+ class Image:
66
+ @dataclass
67
+ class CFG:
68
+ guidance_scale_text: float = 3.0
69
+ guidance_scale_image: float = 1.2
70
+
71
+ cfg: CFG = field(default_factory=CFG)
72
+ temp: float = 0.7
73
+ top_p: float = 0.9
74
+ greedy: bool = False
75
+
76
+ max_seq_len: int = 4096
77
+ max_gen_len: int = 4096
78
+ seed: int | None = None
79
+ txt: Text | bool = True
80
+ img: Image | bool = False
81
+ extra_eos_tokens: list[int | str] = field(default_factory=lambda: ["<racm3:break>"])
82
+
83
+ def __post_init__(self):
84
+ if self.txt == True:
85
+ self.txt = Options.Text()
86
+ if self.img == True:
87
+ self.img = Options.Image()
88
+
89
+
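
A small usage sketch of Options (illustrative values): __post_init__ expands the boolean txt/img switches into their dataclass counterparts, so image generation is enabled simply by passing img=True.

    from chameleon.inference.chameleon import Options

    text_opts = Options()                        # txt -> Options.Text(), img stays False
    image_opts = Options(txt=False, img=True)    # img -> Options.Image() with its CFG sub-config
    print(image_opts.img.cfg.guidance_scale_text)  # 3.0 by default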
90
+ class TokenManager:
91
+ def __init__(
92
+ self,
93
+ tokenizer_path: str,
94
+ vqgan_cfg_path: str,
95
+ vqgan_ckpt_path: str,
96
+ device: str | None = None,
97
+ ):
98
+ self.tokenizer = Tokenizer.from_file(tokenizer_path)
99
+ self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
100
+ self.translation = VocabTranslation(self.vocab, device=device)
101
+ self.image_tokenizer = ImageTokenizer(
102
+ cfg_path=vqgan_cfg_path, ckpt_path=vqgan_ckpt_path, device=device
103
+ )
104
+
105
+ def pil_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image:
106
+ image_tensor = self.translation.convert_bpe2img(bpe_tokens)
107
+ if image_tensor.shape[0] < 1024:
108
+ padding = (
109
+ torch.ones(
110
+ [1024 - image_tensor.shape[0]],
111
+ dtype=int,
112
+ device=image_tensor.device,
113
+ )
114
+ * image_tensor[0]
115
+ )
116
+ image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)
117
+
118
+ return self.image_tokenizer.pil_from_img_toks(image_tensor)
119
+
120
+ def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
121
+ pil = self.pil_from_bpe_tokens(bpe_tokens)
122
+ img_io = io.BytesIO()
123
+ pil.save(img_io, format="PNG")
124
+ return img_io.getvalue()
125
+
126
+ def tokenize_text(self, text: str) -> list[int]:
127
+ return self.tokenizer.encode(text).ids
128
+
129
+ def tokenize_image(self, img: Image) -> list[int]:
130
+ return (
131
+ [self.vocab.begin_image]
132
+ + self.translation.convert_img2bp2(
133
+ self.image_tokenizer.img_tokens_from_pil(img)
134
+ ).tolist()
135
+ + [self.vocab.end_image]
136
+ )
137
+
138
+ def tokenize_b64img(self, b64img: str) -> list[int]:
139
+ image_data = base64.b64decode(b64img)
140
+ image_file = io.BytesIO(image_data)
141
+ return self.tokenize_image(PIL.Image.open(image_file))
142
+
143
+ def tokens_from_ui(self, inputs: list[dict]) -> list[int]:
144
+ tokens = [self.vocab.bos_id]
145
+ for input_ in inputs:
146
+ if input_["type"] == "text":
147
+ tokens += self.tokenize_text(input_["value"])
148
+ elif input_["type"] == "image":
149
+ if type(input_["value"]) == str:
150
+ if input_["value"].startswith("data:"):
151
+ # Value Format: 'data:image/[^;]+;base64,[A-Za-z0-9+/]+={0,2}'
152
+ tokens += self.tokenize_b64img(input_["value"].split(",", 1)[1])
153
+ elif input_["value"].startswith("file:"):
154
+ tokens += self.tokenize_image(
155
+ PIL.Image.open(input_["value"].split(":", 1)[1])
156
+ )
157
+ else:
158
+ raise ValueError("Unknown image format.")
159
+ elif type(input_["value"]) == Image:
160
+ tokens += self.tokenize_image(input_["value"])
161
+ else:
162
+ raise ValueError("Unknown image type.")
163
+ elif input_["type"] == "sentinel":
164
+ tokens += [
165
+ {
166
+ "<START-OF-IMAGE>": self.vocab.begin_image,
167
+ "<END-OF-TURN>": self.vocab.eot_id,
168
+ }[input_["value"]]
169
+ ]
170
+ elif input_["type"] == "ids":
171
+ tokens += input_["value"]
172
+ else:
173
+ raise ValueError("Unknown input type.")
174
+ return tokens
175
+
176
+ def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
177
+ if isinstance(ids, torch.Tensor):
178
+ ids = ids.tolist()
179
+
180
+ for row, values in enumerate(ids):
181
+ try:
182
+ ids[row] = values[: values.index(self.vocab.eos_id)]
183
+ except ValueError:
184
+ pass
185
+
186
+ return self.tokenizer.decode_batch(ids)
187
+
188
+ def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image]:
189
+ return [self.pil_from_bpe_tokens(sample) for sample in ids]
190
+
191
+
192
+ @dataclass
193
+ class DecodePiece:
194
+ token: ChameleonGenerator.Token
195
+ next_decoder: type["Decoder"] | None
196
+
197
+
198
+ class Decoder:
199
+ def __init__(
200
+ self,
201
+ model: Transformer,
202
+ vocab: VocabInfo,
203
+ options: Options,
204
+ input_ids: list[int],
205
+ ): ...
206
+
207
+ def __next__(self) -> DecodePiece: ...
208
+
209
+
210
+ class TextDecoder(Decoder):
211
+ def __init__(
212
+ self,
213
+ model: Transformer,
214
+ vocab: VocabInfo,
215
+ options: Options,
216
+ input_ids: list[list[int]],
217
+ ):
218
+ self.vocab = vocab
219
+ self.options = options
220
+ assert vocab.eos_id is not None
221
+
222
+ prompt_lens = [len(inp) for inp in input_ids]
223
+ max_prompt_len = max(prompt_lens)
224
+ max_seq_len = min(options.max_seq_len, max_prompt_len + options.max_gen_len)
225
+
226
+ self.eos_ids = [vocab.eos_id]
227
+ for extra_eos_token in options.extra_eos_tokens:
228
+ if isinstance(extra_eos_token, str):
229
+ extra_eos_token = vocab.name2val[extra_eos_token]
230
+ assert isinstance(extra_eos_token, int)
231
+ self.eos_ids.append(extra_eos_token)
232
+
233
+ stopping_criteria = [
234
+ MaxLengthCriteria(max_seq_len),
235
+ ] + [StopOnEOSAfterBatchIndex(eos_id, [max_prompt_len] * len(prompt_lens)) for eos_id in self.eos_ids]
236
+
237
+ self.gen = ChameleonGenerator(
238
+ model=ChameleonModelAdapter(model, max_seq_len=max_seq_len),
239
+ input_ids=input_ids,
240
+ stopping_criteria=stopping_criteria,
241
+ logits_processors=self._logits_processors(),
242
+ alignment=AlignPromptRight(vocab.pad_id),
243
+ token_selector=(
244
+ ArgmaxTokenSelector()
245
+ if options.txt.greedy
246
+ else MultinomialTokenSelector()
247
+ ),
248
+ )
249
+ advance(self.gen, max_prompt_len)
250
+
251
+ def _allowed_tokens(self) -> list[int]:
252
+ allowed_tokens = [self.vocab.eos_id]
253
+ if self.options.txt:
254
+ allowed_tokens += self.vocab.text_tokens
255
+ if self.options.img:
256
+ allowed_tokens += [self.vocab.begin_image]
257
+ return allowed_tokens
258
+
259
+ def _logits_processors(self) -> list[LogitsProcessor]:
260
+ logits_processors = [
261
+ AllowOnlyTokensLogitsProcessor(self._allowed_tokens()),
262
+ ]
263
+ if isinstance(self.options.img, Options.Image):
264
+ logits_processors += [
265
+ DisallowTokensAtOrAfterIndexLogitsProcessor(
266
+ [self.vocab.begin_image],
267
+ self.options.max_seq_len - 1026,
268
+ ),
269
+ ]
270
+ if isinstance(self.options.txt, Options.Text):
271
+ logits_processors += [
272
+ RepetitionPenaltyLogitsProcessor(self.options.txt.repetition_penalty),
273
+ TemperatureLogitsWarper(self.options.txt.temp),
274
+ TopPLogitsWarper(self.options.txt.top_p),
275
+ ]
276
+ return logits_processors
277
+
278
+ def __next__(self) -> DecodePiece:
279
+ tok = next(self.gen)
280
+ next_decoder = None
281
+ if (
282
+ self.vocab.begin_image not in self.eos_ids
283
+ and (tok.id == self.vocab.begin_image).all()
284
+ ):
285
+ next_decoder = ImageDecoder
286
+ return DecodePiece(tok, next_decoder)
287
+
288
+
289
+ class ImageDecoder(Decoder):
290
+ def __init__(
291
+ self,
292
+ model: Transformer,
293
+ vocab: VocabInfo,
294
+ options: Options,
295
+ input_ids: list[list[int]],
296
+ ):
297
+ assert isinstance(options.img, Options.Image)
298
+ self.vocab = vocab
299
+ self.options = options
300
+ self.batch_size = len(input_ids)
301
+ logits_processors = [
302
+ InBatchInstructCFGLogitsProcessor(
303
+ options.img.cfg.guidance_scale_text,
304
+ options.img.cfg.guidance_scale_image,
305
+ ),
306
+ AllowOnlyTokensLogitsProcessor(vocab.image_tokens),
307
+ TemperatureLogitsWarper(options.img.temp),
308
+ TopPLogitsWarper(options.img.top_p),
309
+ ]
310
+
311
+ for inp in input_ids:
312
+ if inp[-1] != self.vocab.begin_image:
313
+ inp.append(self.vocab.begin_image)
314
+
315
+ max_prompt_len = max(len(inp) for inp in input_ids)
316
+ self.gen = ChameleonGenerator(
317
+ model=ChameleonModelAdapter(model, max_seq_len=max_prompt_len + 1024),
318
+ input_ids=self._split_inputs_for_cfg(input_ids),
319
+ logits_processors=logits_processors,
320
+ alignment=AlignPromptRight(vocab.pad_id),
321
+ token_selector=ReplicatedInputTokenSelector(
322
+ (
323
+ ArgmaxTokenSelector()
324
+ if options.img.greedy
325
+ else MultinomialTokenSelector()
326
+ ),
327
+ n=3,
328
+ ),
329
+ )
330
+ advance(self.gen, max_prompt_len)
331
+ self.gen_count = 0
332
+
333
+ def _split_inputs_for_cfg(self, input_ids: list[list[int]]) -> list[list[int]]:
334
+ image_conditioned_allowed = set(self.vocab.image_tokens) | {
335
+ self.vocab.bos_id,
336
+ self.vocab.begin_image,
337
+ self.vocab.end_image,
338
+ }
339
+
340
+ full_conditioned = input_ids
341
+
342
+ image_conditioned = [
343
+ [id for id in sample if id in image_conditioned_allowed]
344
+ for sample in input_ids
345
+ ]
346
+
347
+ unconditioned = [
348
+ [
349
+ self.vocab.bos_id,
350
+ self.vocab.begin_image,
351
+ ]
352
+ ] * self.batch_size
353
+
354
+ return full_conditioned + image_conditioned + unconditioned
355
+
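
The split above triples the batch for in-batch instruct CFG: every prompt contributes a fully conditioned row, an image-token-only row, and an unconditioned row, which is why the generator wraps its selector in ReplicatedInputTokenSelector(n=3) and __next__ keeps only the first chunk of each sampled token. Roughly, for one text prompt:

    # input_ids          = [[bos, t1, t2, begin_image]]      (t1, t2 are text tokens)
    # full_conditioned   -> [bos, t1, t2, begin_image]
    # image_conditioned  -> [bos, begin_image]               (text tokens filtered out)
    # unconditioned      -> [bos, begin_image]
    # rows fed to the model: the three groups concatenated, in that order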
356
+ def __next__(self) -> DecodePiece:
357
+ if self.gen_count == 1024:
358
+ id = torch.tensor([self.vocab.end_image] * self.batch_size)
359
+ logits = torch.full(
360
+ (self.batch_size, len(self.vocab.all_tokens)), -math.inf
361
+ )
362
+ logits[:, self.vocab.end_image] = 0
363
+ return DecodePiece(
364
+ ChameleonGenerator.Token(id=id, logits=logits),
365
+ TextDecoder,
366
+ )
367
+
368
+ tok = next(self.gen)
369
+ tok.id = tok.id.chunk(3)[0]
370
+ self.gen_count += 1
371
+ return DecodePiece(tok, None)
372
+
373
+
374
+ class Generator(Decoder):
375
+ def __init__(
376
+ self,
377
+ model: Transformer,
378
+ vocab: VocabInfo,
379
+ options: Options,
380
+ input_ids: list[list[int]],
381
+ ):
382
+ if options.seed is not None:
383
+ enable_full_determinism(options.seed, warn_only=True)
384
+
385
+ self.model = model
386
+ self.vocab = vocab
387
+ self.input_ids = input_ids[:]
388
+ self.generated_token_ids: list[torch.LongTensor] = []
389
+ self.options = options
390
+ if not self.options.txt:
391
+ self.dyngen = DynamicGenerator(
392
+ ImageDecoder(model, vocab, options, input_ids)
393
+ )
394
+ else:
395
+ self.dyngen = DynamicGenerator(
396
+ TextDecoder(model, vocab, options, input_ids)
397
+ )
398
+
399
+ def __iter__(self):
400
+ return self
401
+
402
+ def __next__(self) -> ChameleonGenerator.Token:
403
+ piece = next(self.dyngen)
404
+ self.generated_token_ids.append(piece.token.id)
405
+ if piece.next_decoder is not None:
406
+ if not self.options.txt:
407
+ raise StopIteration
408
+
409
+ self.input_ids = [
410
+ old_list + generated
411
+ for old_list, generated in zip(
412
+ self.input_ids, torch.stack(self.generated_token_ids).T.tolist()
413
+ )
414
+ ]
415
+ self.generated_token_ids = []
416
+ self.dyngen.gen = piece.next_decoder(
417
+ self.model,
418
+ self.vocab,
419
+ self.options,
420
+ self.input_ids,
421
+ )
422
+ return piece.token
423
+
424
+
425
+ class DistributedMode(Enum):
426
+ AUTO = 0
427
+ THREAD = 1
428
+ PROCESS = 2
429
+
430
+
431
+ @dataclass
432
+ class _DistributedContext:
433
+ req_q: Union[queue.Queue, queues.Queue]
434
+ res_q: Union[queue.Queue, queues.Queue]
435
+ active_key: Union[dict[int, Literal[True]], managers.DictProxy]
436
+ active_key_lock: Union[threading.Lock, synchronize.Lock]
437
+ ready_barrier: Union[threading.Barrier, synchronize.Barrier]
438
+ worker_launcher: Union[type[threading.Thread], type[mp.Process]]
439
+
440
+ @staticmethod
441
+ def make_for_threading(world_size: int):
442
+ return _DistributedContext(
443
+ req_q=queue.Queue(),
444
+ res_q=queue.Queue(),
445
+ active_key={},
446
+ active_key_lock=threading.Lock(),
447
+ ready_barrier=threading.Barrier(world_size + 1),
448
+ worker_launcher=threading.Thread,
449
+ )
450
+
451
+ @staticmethod
452
+ def make_for_multiprocessing(world_size: int):
453
+ local_mp = mp.get_context("spawn")
454
+ return _DistributedContext(
455
+ req_q=local_mp.Queue(),
456
+ res_q=local_mp.Queue(),
457
+ active_key=local_mp.Manager().dict(),
458
+ active_key_lock=local_mp.Lock(),
459
+ ready_barrier=local_mp.Barrier(world_size + 1),
460
+ worker_launcher=local_mp.Process,
461
+ )
462
+
463
+ @staticmethod
464
+ def make(mode: DistributedMode, world_size: int):
465
+ if mode == DistributedMode.AUTO:
466
+ mode = DistributedMode.PROCESS
467
+
468
+ if mode == DistributedMode.THREAD:
469
+ return _DistributedContext.make_for_threading(world_size)
470
+ elif mode == DistributedMode.PROCESS:
471
+ return _DistributedContext.make_for_multiprocessing(world_size)
472
+ else:
473
+ raise ValueError("Unknown DistributedMode")
474
+
475
+
476
+ def _worker_impl(
477
+ init_method: str,
478
+ model: Transformer | str,
479
+ world_size: int,
480
+ rank: int,
481
+ vocab: VocabInfo,
482
+ dctx: _DistributedContext,
483
+ ):
484
+ dist.init_process_group(
485
+ "nccl",
486
+ init_method=init_method,
487
+ world_size=world_size,
488
+ rank=rank,
489
+ )
490
+
491
+ torch.set_default_device(f"cuda:{rank}")
492
+ torch.cuda.set_device(rank)
493
+ if isinstance(model, str):
494
+ model = loader.load_model(model, rank=rank)
495
+ dctx.ready_barrier.wait()
496
+
497
+ is_coord = rank == 0
498
+
499
+ while True:
500
+ req = [Options(), [], 0, False]
501
+ if is_coord:
502
+ req = dctx.req_q.get()
503
+
504
+ dist.broadcast_object_list(req, src=0)
505
+ options, input_ids, key, shutdown = req
506
+ if shutdown:
507
+ break
508
+
509
+ for token in Generator(
510
+ model=model,
511
+ vocab=vocab,
512
+ options=options,
513
+ input_ids=input_ids,
514
+ ):
515
+ if is_coord:
516
+ dctx.res_q.put((key, token))
517
+
518
+ to_continue = [True]
519
+ if is_coord:
520
+ with dctx.active_key_lock:
521
+ to_continue = [key in dctx.active_key]
522
+ dist.broadcast_object_list(to_continue, src=0)
523
+ if not to_continue[0]:
524
+ break
525
+
526
+ if is_coord:
527
+ dctx.res_q.put((key, None))
528
+
529
+
530
+ class ChameleonInferenceModel:
531
+ def __init__(
532
+ self,
533
+ model: Transformer | str,
534
+ tokenizer_path: str,
535
+ vqgan_cfg_path: str,
536
+ vqgan_ckpt_path: str,
537
+ *,
538
+ options: Options | None = None,
539
+ distributed_mode: DistributedMode = DistributedMode.AUTO,
540
+ ):
541
+ self.options = options or Options()
542
+ self.next_key = 0
543
+
544
+ self.token_manager = TokenManager(
545
+ tokenizer_path=tokenizer_path,
546
+ vqgan_cfg_path=vqgan_cfg_path,
547
+ vqgan_ckpt_path=vqgan_ckpt_path,
548
+ device="cuda",
549
+ )
550
+ self.vocab = self.token_manager.vocab
551
+
552
+ world_size = 1
553
+ if isinstance(model, str):
554
+ world_size = loader.detect_shard_count(model)
555
+ self.dctx = _DistributedContext.make(distributed_mode, world_size)
556
+
557
+ init_method = f"tcp://0.0.0.0:{random_unused_port()}"
558
+ self.workers = [
559
+ self.dctx.worker_launcher(
560
+ target=_worker_impl,
561
+ args=(init_method, model, world_size, i, self.vocab, self.dctx),
562
+ daemon=True,
563
+ )
564
+ for i in range(world_size)
565
+ ]
566
+ for w in self.workers:
567
+ w.start()
568
+ self.dctx.ready_barrier.wait()
569
+
570
+ def __del__(self):
571
+ try:
572
+ with self.dctx.active_key_lock:
573
+ self.dctx.active_key.clear()
574
+ self.dctx.req_q.put([None, None, None, True])
575
+ for w in self.workers:
576
+ w.join()
577
+ except FileNotFoundError:
578
+ pass
579
+
580
+ def stream(
581
+ self,
582
+ *,
583
+ input_ids: list[int] | None = None,
584
+ prompt_text: str | None = None,
585
+ prompt_ui: list[dict] | None = None,
586
+ batch_input_ids: list[list[int]] | None = None,
587
+ batch_prompt_text: list[str] | None = None,
588
+ batch_prompt_ui: list[list[dict]] | None = None,
589
+ options: Options | None = None,
590
+ ):
591
+ # NOTE: Not thread-safe! Only one instance of generate may be run at a time.
592
+
593
+ if (
594
+ sum(
595
+ x is not None
596
+ for x in [
597
+ input_ids,
598
+ prompt_text,
599
+ prompt_ui,
600
+ batch_input_ids,
601
+ batch_prompt_text,
602
+ batch_prompt_ui,
603
+ ]
604
+ )
605
+ != 1
606
+ ):
607
+ raise ValueError(
608
+ "Must specify exactly one of: input_ids, prompt_text, prompt_ui, batch_input_ids, batch_prompt_text, batch_prompt_ui"
609
+ )
610
+
611
+ options = options or self.options
612
+
613
+ if prompt_text is not None:
614
+ batch_prompt_text = [prompt_text]
615
+ if prompt_ui is not None:
616
+ batch_prompt_ui = [prompt_ui]
617
+ if input_ids is not None:
618
+ batch_input_ids = [input_ids]
619
+ if batch_prompt_text is not None:
620
+ batch_prompt_ui = [
621
+ [{"type": "text", "value": prompt_text}]
622
+ for prompt_text in batch_prompt_text
623
+ ]
624
+ if batch_prompt_ui is not None:
625
+ batch_input_ids = [
626
+ self.token_manager.tokens_from_ui(prompt_ui)
627
+ for prompt_ui in batch_prompt_ui
628
+ ]
629
+
630
+ assert batch_input_ids
631
+
632
+ if not options.txt and not options.img:
633
+ raise ValueError("Must specify at least one modality.")
634
+ if options.txt and options.img and len(batch_input_ids) > 1:
635
+ raise ValueError(
636
+ "Batch generation only supported for one modality at a time."
637
+ )
638
+
639
+ req_key = self.next_key
640
+ self.next_key += 1
641
+
642
+ with self.dctx.active_key_lock:
643
+ self.dctx.active_key[req_key] = True
644
+
645
+ self.dctx.req_q.put([options, batch_input_ids, req_key, False])
646
+
647
+ try:
648
+ while key_token := self.dctx.res_q.get():
649
+ key, token = key_token
650
+ if key != req_key:
651
+ # Residual from prior calls to generation. Skip.
652
+ continue
653
+ if token is None:
654
+ break
655
+ yield token
656
+ finally:
657
+ with self.dctx.active_key_lock:
658
+ del self.dctx.active_key[req_key]
659
+
660
+ def step(self, *args, **kwargs) -> ChameleonGenerator.Token:
661
+ return next(self.stream(*args, **kwargs))
662
+
663
+ def generate(self, *args, **kwargs) -> torch.LongTensor:
664
+ tokens = [t.id for t in self.stream(*args, **kwargs)]
665
+ if not tokens:
666
+ return torch.LongTensor()
667
+ return torch.stack(tokens).T
668
+
669
+ def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
670
+ return self.token_manager.decode_text(ids)
671
+
672
+ def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image]:
673
+ return self.token_manager.decode_image(ids)
chameleon/inference/cudagraph.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import functools
7
+ from typing import Any, Callable, TypeVar
8
+
9
+ import torch
10
+
11
+ T = TypeVar("T")
12
+ FN = Callable[..., T] # type: ignore
13
+
14
+
15
+ class CUDAGraphWrapper:
16
+ def __init__(
17
+ self,
18
+ fn: FN[T],
19
+ warmup_iter: int = 1,
20
+ debug_dump_path: str | None = None,
21
+ ):
22
+ self.fn = fn
23
+ self.warmup_iter = warmup_iter
24
+ self.debug_dump_path = debug_dump_path
25
+ self.graph: torch.cuda.CUDAGraph | None = None
26
+ self.result: T | None = None
27
+
28
+ def __call__(self, *args, **kwargs) -> Any: # type: ignore
29
+ if self.warmup_iter > 0:
30
+ self.warmup_iter -= 1
31
+ return self.fn(*args, **kwargs)
32
+
33
+ if self.graph is None:
34
+ self.graph = torch.cuda.CUDAGraph()
35
+ if self.debug_dump_path is not None:
36
+ self.graph.enable_debug_mode()
37
+ recording_kwargs = {}
38
+ if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
39
+ # In PyTorch 2.1+ and nightlies from late Aug 2023,
40
+ # we can do this to maybe avoid watchdog-related crashes
41
+ recording_kwargs["capture_error_mode"] = "thread_local"
42
+ with torch.cuda.graph(self.graph, **recording_kwargs):
43
+ self.result = self.fn(*args, **kwargs)
44
+ torch.cuda.synchronize()
45
+ if self.debug_dump_path is not None:
46
+ self.graph.debug_dump(self.debug_dump_path)
47
+
48
+ assert self.graph is not None
49
+ self.graph.replay()
50
+ return self.result
51
+
52
+
53
+ def cudagraph_wrap(
54
+ *args,
55
+ warmup_iter: int = 1,
56
+ debug_dump_path: str | None = None,
57
+ ) -> Callable[[FN[T]], FN[T]]:
58
+ def wrapper(fn: FN[T]) -> FN[T]:
59
+ graph_wrapper = CUDAGraphWrapper(
60
+ fn, warmup_iter=warmup_iter, debug_dump_path=debug_dump_path
61
+ )
62
+
63
+ @functools.wraps(fn)
64
+ def call_wrapper(*inner_args, **inner_kwargs):
65
+ return graph_wrapper(*inner_args, **inner_kwargs)
66
+
67
+ return call_wrapper
68
+
69
+ # @cudagraph_wrap
70
+ # def fn(...):
71
+ # ...
72
+ #
73
+ # - or -
74
+ #
75
+ # fast_fn = cudagraph_wrap(slow_fn, warmup_iter=2)
76
+ if len(args) == 1 and callable(args[0]):
77
+ return wrapper(args[0])
78
+
79
+ # @cudagraph_wrap(warmup_iter=3)
80
+ # def fn(...):
81
+ # ...
82
+ def decorator(fn: FN[T]) -> FN[T]:
83
+ return wrapper(fn)
84
+
85
+ return decorator
chameleon/inference/examples/batch.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from chameleon.inference.chameleon import ChameleonInferenceModel
7
+
8
+
9
+ def main():
10
+ model = ChameleonInferenceModel(
11
+ "./data/models/7b/",
12
+ "./data/tokenizer/text_tokenizer.json",
13
+ "./data/tokenizer/vqgan.yaml",
14
+ "./data/tokenizer/vqgan.ckpt",
15
+ )
16
+
17
+ batch_tokens = model.generate(batch_prompt_text=["All your base", "import asyncio"])
18
+ for text in model.decode_text(batch_tokens):
19
+ print(text)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ main()
chameleon/inference/examples/multimodal_input.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from chameleon.inference.chameleon import ChameleonInferenceModel
7
+
8
+
9
+ def main():
10
+ model = ChameleonInferenceModel(
11
+ "./data/models/7b/",
12
+ "./data/tokenizer/text_tokenizer.json",
13
+ "./data/tokenizer/vqgan.yaml",
14
+ "./data/tokenizer/vqgan.ckpt",
15
+ )
16
+
17
+ tokens = model.generate(
18
+ prompt_ui=[
19
+ {"type": "image", "value": "file:/path/to/image.jpeg"},
20
+ {"type": "text", "value": "What do you see?"},
21
+ {"type": "sentinel", "value": "<END-OF-TURN>"},
22
+ ]
23
+ )
24
+ print(model.decode_text(tokens)[0])
25
+
26
+
27
+ if __name__ == "__main__":
28
+ main()
chameleon/inference/examples/simple.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from chameleon.inference.chameleon import ChameleonInferenceModel
7
+
8
+
9
+ def main():
10
+ model = ChameleonInferenceModel(
11
+ "./data/models/7b/",
12
+ "./data/tokenizer/text_tokenizer.json",
13
+ "./data/tokenizer/vqgan.yaml",
14
+ "./data/tokenizer/vqgan.ckpt",
15
+ )
16
+
17
+ tokens = model.generate(prompt_text="All your base")
18
+ print(model.decode_text(tokens)[0])
19
+
20
+
21
+ if __name__ == "__main__":
22
+ main()
chameleon/inference/examples/streaming.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from chameleon.inference.chameleon import ChameleonInferenceModel
7
+
8
+
9
+ def main():
10
+ model = ChameleonInferenceModel(
11
+ "./data/models/7b/",
12
+ "./data/tokenizer/text_tokenizer.json",
13
+ "./data/tokenizer/vqgan.yaml",
14
+ "./data/tokenizer/vqgan.ckpt",
15
+ )
16
+
17
+ for tokens in model.stream(prompt_text="All your base"):
18
+ print(model.decode_text(tokens.id.view(-1, 1))[0], end="")
19
+
20
+
21
+ if __name__ == "__main__":
22
+ main()
chameleon/inference/examples/streaming_batch.py ADDED
@@ -0,0 +1,24 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from chameleon.inference.chameleon import ChameleonInferenceModel
7
+
8
+
9
+ def main():
10
+ model = ChameleonInferenceModel(
11
+ "./data/models/7b/",
12
+ "./data/tokenizer/text_tokenizer.json",
13
+ "./data/tokenizer/vqgan.yaml",
14
+ "./data/tokenizer/vqgan.ckpt",
15
+ )
16
+
17
+ for i, batch_tokens in enumerate(
18
+ model.stream(batch_prompt_text=["All your base", "import asyncio"])
19
+ ):
20
+ print(model.decode_text(batch_tokens.id.view(-1, 1)))
21
+
22
+
23
+ if __name__ == "__main__":
24
+ main()
chameleon/inference/generation.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+ from transformers import (
10
+ LogitsProcessor,
11
+ LogitsProcessorList,
12
+ )
13
+ from transformers.generation.streamers import BaseStreamer
14
+
15
+ from chameleon.inference.alignment import AlignPromptLeft, PromptAlignment
16
+ from chameleon.inference.model_adapter import ModelAdapter
17
+ from chameleon.inference.stopping_criteria import StoppingCriteria, StoppingCriteriaList
18
+ from chameleon.inference.token_selector import MultinomialTokenSelector, TokenSelector
19
+
20
+
21
+ class ChameleonGenerator:
22
+ @dataclass
23
+ class Token:
24
+ id: torch.LongTensor
25
+ logits: torch.Tensor | None
26
+
27
+ def __init__(
28
+ self,
29
+ model: ModelAdapter,
30
+ input_ids: list[list[int]],
31
+ stopping_criteria: StoppingCriteriaList | list[StoppingCriteria] | None = None,
32
+ logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
33
+ probability_processors: LogitsProcessorList
34
+ | list[LogitsProcessor]
35
+ | None = None,
36
+ token_selector: TokenSelector | None = None,
37
+ alignment: PromptAlignment = AlignPromptLeft(),
38
+ ):
39
+ assert model.supports_alignment(alignment)
40
+
41
+ self.model = model
42
+
43
+ self.stopping_criteria = stopping_criteria
44
+ self.logits_processors = logits_processors
45
+ self.probability_processors = probability_processors
46
+ self.token_selector: TokenSelector = (
47
+ token_selector or MultinomialTokenSelector()
48
+ )
49
+
50
+ self.alignment = alignment
51
+
52
+ self.model.initialize(input_ids)
53
+
54
+ self._inputs = self.alignment.prepare_inputs(
55
+ input_ids
56
+ ) # inputs.shape = [batch, seq-len]
57
+
58
+ self._idx = 0
59
+ self._start_idx = self.alignment.start_index(input_ids)
60
+
61
+ self._original_inputs = self._inputs.clone()
62
+ self._inputs = self._inputs[:, : self._start_idx]
63
+
64
+ def __iter__(self):
65
+ return self
66
+
67
+ @torch.inference_mode()
68
+ def __next__(self) -> Token:
69
+ # Are we done?
70
+ if self.stopping_criteria(self._inputs, None):
71
+ raise StopIteration
72
+
73
+ # Emit initial tokens.
74
+ # Model is not run for these.
75
+ # If you want the logits, you can do a separate forward pass outside generation.
76
+ if self._idx < self._start_idx:
77
+ idx, self._idx = self._idx, self._idx + 1
78
+ return ChameleonGenerator.Token(id=self._inputs[:, idx], logits=None)
79
+
80
+ # Run the model for the next token.
81
+ self._inputs = self._inputs.contiguous()
82
+ outputs = self.model(self._inputs) # outputs.shape = [batch, seq-len, vocab]
83
+
84
+ # Pull out and process the logits.
85
+ logits = outputs[:, -1, :] # logits.shape = [batch, vocab]
86
+ logits = self.logits_processors(self._inputs, logits)
87
+ probs = logits.softmax(dim=1) # probs.shape = [batch, vocab]
88
+ probs = self.probability_processors(self._inputs, probs)
89
+
90
+ # Select a token and add it to the inputs.
91
+ next_tokens = self.token_selector(
92
+ self._inputs, probs
93
+ ) # next_tokens.shape = [batch]
94
+ self._inputs = torch.cat([self._inputs, next_tokens[:, None]], dim=1)
95
+
96
+ # Run alignment specific postprocessing.
97
+ self._inputs = self.alignment.postprocess_inputs(
98
+ self._inputs, self._original_inputs
99
+ )
100
+
101
+ # Return the next step result.
102
+ return ChameleonGenerator.Token(id=self._inputs[:, -1], logits=logits)
103
+
104
+ @property
105
+ def stopping_criteria(self) -> StoppingCriteriaList:
106
+ return self._stopping_criteria
107
+
108
+ @stopping_criteria.setter
109
+ def stopping_criteria(
110
+ self, value: StoppingCriteriaList | list[StoppingCriteria] | None
111
+ ):
112
+ self._stopping_criteria = StoppingCriteriaList(value or [])
113
+
114
+ @property
115
+ def logits_processors(self) -> LogitsProcessorList:
116
+ return self._logits_processors
117
+
118
+ @logits_processors.setter
119
+ def logits_processors(
120
+ self, value: LogitsProcessorList | list[LogitsProcessor] | None
121
+ ):
122
+ self._logits_processors = LogitsProcessorList(value or [])
123
+
124
+ @property
125
+ def probability_processors(self) -> LogitsProcessorList:
126
+ return self._probability_processors
127
+
128
+ @probability_processors.setter
129
+ def probability_processors(
130
+ self, value: LogitsProcessorList | list[LogitsProcessor] | None
131
+ ):
132
+ self._probability_processors = LogitsProcessorList(value or [])
133
+
134
+
135
+ def run_generation(
136
+ model: torch.nn.Module,
137
+ input_ids: list[list[int]],
138
+ stopping_criteria: StoppingCriteriaList | list[StoppingCriteria],
139
+ logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
140
+ probability_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
141
+ token_selector: TokenSelector | None = None,
142
+ alignment: PromptAlignment = AlignPromptLeft(),
143
+ streamer: BaseStreamer | None = None,
144
+ ) -> torch.LongTensor:
145
+ result = torch.empty((len(input_ids), 0), dtype=int)
146
+ for tok in ChameleonGenerator(
147
+ model=model,
148
+ input_ids=input_ids,
149
+ stopping_criteria=stopping_criteria,
150
+ logits_processors=logits_processors,
151
+ probability_processors=probability_processors,
152
+ token_selector=token_selector,
153
+ alignment=alignment,
154
+ ):
155
+ if streamer is not None:
156
+ streamer.put(tok.id)
157
+ result = torch.cat([result, tok.id.view(-1, 1)], dim=1)
158
+
159
+ if streamer is not None:
160
+ streamer.end()
161
+
162
+ return result
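A minimal usage sketch of the generator above. The `adapter` handle and the prompt token ids are hypothetical stand-ins: in the real pipeline the model comes from loader.load_model wrapped in a ChameleonModelAdapter (see model_adapter.py below), and prompts come from the tokenizer.

# Hedged sketch -- `adapter` and the prompt ids are placeholders.
from chameleon.inference.generation import run_generation
from chameleon.inference.logits_processor import TopPProbabilityProcessor
from chameleon.inference.stopping_criteria import MaxLengthCriteria, StopOnEOS
from chameleon.inference.token_selector import MultinomialTokenSelector

out = run_generation(
    model=adapter,                                   # a ModelAdapter instance
    input_ids=[[0, 101, 2025], [0, 7, 9]],           # one tokenized prompt per batch entry
    stopping_criteria=[MaxLengthCriteria(4096), StopOnEOS(eos_id=1)],
    probability_processors=[TopPProbabilityProcessor(top_p=0.9)],
    token_selector=MultinomialTokenSelector(),
)
# out is a LongTensor of shape [batch, seq-len]; the prompt tokens are re-emitted
# by the generator before any new tokens are sampled.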
chameleon/inference/image_tokenizer.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import PIL
8
+ import torch
9
+ import yaml
10
+ from PIL import Image
11
+
12
+ from chameleon.inference.vqgan import VQModel
13
+
14
+
15
+ class ImageTokenizer:
16
+ def __init__(
17
+ self,
18
+ cfg_path: str,
19
+ ckpt_path: str,
20
+ device: str,
21
+ ):
22
+ with open(cfg_path) as f:
23
+ config = yaml.safe_load(f)
24
+
25
+ params = config["model"]["params"]
26
+ if "lossconfig" in params:
27
+ del params["lossconfig"]
28
+ params["ckpt_path"] = ckpt_path
29
+
30
+ self._vq_model = VQModel(**params)
31
+ self._vq_model.eval()
32
+
33
+ if device is None:
34
+ devices = {p.device for p in self._vq_model.parameters()}
35
+ assert len(devices) == 1
36
+ device = devices.pop()
37
+ else:
38
+ self._vq_model.to(device)
39
+ self._device = device
40
+
41
+ dtypes = {p.dtype for p in self._vq_model.parameters()}
42
+ assert len(dtypes) == 1
43
+ self._dtype = dtypes.pop()
44
+
45
+ def _whiten_transparency(self, img: PIL.Image) -> PIL.Image:
46
+ # Check if it's already in RGB format.
47
+ if img.mode == "RGB":
48
+ return img
49
+
50
+ vals_rgba = np.array(img.convert("RGBA"))
51
+
52
+ # If there is no transparency layer, simply convert and return.
53
+ if not (vals_rgba[:, :, 3] < 255).any():
54
+ return img.convert("RGB")
55
+
56
+ # There is a transparency layer, blend it with a white background.
57
+
58
+ # Calculate the alpha proportion for blending.
59
+ alpha = vals_rgba[:, :, 3] / 255.0
60
+ # Blend with white background.
61
+ vals_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[
62
+ :, :, np.newaxis
63
+ ] * vals_rgba[:, :, :3]
64
+ return PIL.Image.fromarray(vals_rgb.astype("uint8"), "RGB")
65
+
66
+ def _vqgan_input_from(self, img: PIL.Image, target_image_size=512) -> torch.Tensor:
67
+ # Resize with aspect ratio preservation.
68
+ s = min(img.size)
69
+ scale = target_image_size / s
70
+ new_size = (round(scale * img.size[0]), round(scale * img.size[1]))
71
+ img = img.resize(new_size, PIL.Image.LANCZOS)
72
+
73
+ # Center crop.
74
+ x0 = (img.width - target_image_size) // 2
75
+ y0 = (img.height - target_image_size) // 2
76
+ img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size))
77
+
78
+ # Convert to tensor.
79
+ np_img = np.array(img) / 255.0 # Normalize to [0, 1]
80
+ np_img = np_img * 2 - 1 # Scale to [-1, 1]
81
+ tensor_img = (
82
+ torch.from_numpy(np_img).permute(2, 0, 1).float()
83
+ ) # (Channels, Height, Width) format.
84
+
85
+ # Add batch dimension.
86
+ return tensor_img.unsqueeze(0)
87
+
88
+ def img_tokens_from_pil(self, image: PIL.Image) -> list[int]:
89
+ image = self._whiten_transparency(image)
90
+ vqgan_input = self._vqgan_input_from(image).to(self._device).to(self._dtype)
91
+ _, _, [_, _, img_toks] = self._vq_model.encode(vqgan_input)
92
+ return img_toks
93
+
94
+ def _pil_from_chw_tensor(self, chw_tensor: torch.Tensor) -> PIL.Image:
95
+ # Ensure detachment and move tensor to CPU.
96
+ detached_chw_tensor = chw_tensor.detach().cpu()
97
+
98
+ # Normalize tensor to [0, 1] range from [-1, 1] range.
99
+ normalized_chw_tensor = (
100
+ torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0
101
+ ) / 2.0
102
+
103
+ # Permute CHW tensor to HWC format and convert to NumPy array.
104
+ hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy()
105
+
106
+ # Convert to an 8-bit unsigned integer format.
107
+ image_array_uint8 = (hwc_array * 255).astype(np.uint8)
108
+
109
+ # Convert NumPy array to PIL Image.
110
+ pil_image = Image.fromarray(image_array_uint8)
111
+
112
+ # Convert image to RGB if it is not already.
113
+ if pil_image.mode != "RGB":
114
+ pil_image = pil_image.convert("RGB")
115
+
116
+ return pil_image
117
+
118
+ def pil_from_img_toks(self, img_tensor: torch.Tensor, height=32, width=32) -> PIL.Image:
119
+ emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
120
+ # import pdb;pdb.set_trace()
121
+ codebook_entry = self._vq_model.quantize.get_codebook_entry(
122
+ img_tensor, (1, height, width, emb_dim)
123
+ )
124
+ # import pdb;pdb.set_trace()
125
+ pixels = self._vq_model.decode(codebook_entry)
126
+ # import pdb;pdb.set_trace()
127
+ return self._pil_from_chw_tensor(pixels[0])
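A short round-trip sketch using the tokenizer above. The vqgan.yaml/vqgan.ckpt paths are the ones used in examples/streaming_batch.py; the input image path is a placeholder, a CUDA device is assumed, and the default 512px crop yields a 32x32 code grid.

import torch
from PIL import Image
from chameleon.inference.image_tokenizer import ImageTokenizer

tokenizer = ImageTokenizer(
    cfg_path="./data/tokenizer/vqgan.yaml",
    ckpt_path="./data/tokenizer/vqgan.ckpt",
    device="cuda",
)
toks = tokenizer.img_tokens_from_pil(Image.open("input.png"))  # VQ code indices (typically 32x32)
recon = tokenizer.pil_from_img_toks(toks.reshape(-1))          # decode back to a PIL image
recon.save("reconstruction.png")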
chameleon/inference/loader.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import glob
7
+ import inspect
8
+ import json
9
+ from pathlib import Path
10
+
11
+ import torch
12
+
13
+ from chameleon.inference.transformer import ModelArgs, Transformer
14
+
15
+
16
+ def _convert(model_args: ModelArgs, consolidated_path: Path) -> Transformer:
17
+ old_default_dtype = torch.get_default_dtype()
18
+ torch.set_default_dtype(torch.bfloat16)
19
+
20
+ model = Transformer(model_args)
21
+
22
+ transfer_results = model.load_state_dict(
23
+ torch.load(str(consolidated_path)),
24
+ strict=False,
25
+ )
26
+
27
+ # TODO: More generally, assert missing or unexpected keys are buffers.
28
+ assert transfer_results.missing_keys == []
29
+ assert transfer_results.unexpected_keys == ["rope.freqs"]
30
+
31
+ model.eval()
32
+
33
+ torch.set_default_dtype(old_default_dtype)
34
+ return model
35
+
36
+
37
+ def _get_checkpoint_path(src_dir: Path, rank: int | None) -> Path:
38
+ base_path = src_dir / "consolidated.pth"
39
+ if not rank and base_path.exists():
40
+ return base_path
41
+
42
+ alt_path = src_dir / f"consolidated.{rank:02}.pth"
43
+ if alt_path.exists():
44
+ return alt_path
45
+
46
+ raise ValueError("Consolidated checkpoint not found.")
47
+
48
+
49
+ def load_model(path: str, rank: int | None = None) -> Transformer:
50
+ src_dir = Path(path)
51
+
52
+ with open(src_dir / "params.json", "r") as f:
53
+ params = json.loads(f.read())
54
+ with open(src_dir / "consolidate_params.json", "r") as f:
55
+ consolidate_params = json.loads(f.read())
56
+ params = {**params, **params["model"], **consolidate_params}
57
+
58
+ known_param = inspect.signature(ModelArgs.__init__).parameters
59
+ filtered_params = {k: v for k, v in params.items() if k in known_param}
60
+
61
+ return _convert(
62
+ ModelArgs(**filtered_params),
63
+ _get_checkpoint_path(src_dir, rank),
64
+ )
65
+
66
+
67
+ def detect_shard_count(path: str) -> int:
68
+ src_dir = Path(path)
69
+ if (src_dir / "consolidated.pth").exists():
70
+ return 1
71
+ return len(glob.glob(str(src_dir / "consolidated.*.pth")))
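A small sketch of loading a consolidated checkpoint with the helpers above. The directory is the example path from examples/streaming_batch.py and is assumed to contain params.json, consolidate_params.json and consolidated.pth.

from chameleon.inference.loader import detect_shard_count, load_model

path = "./data/models/7b/"
print(detect_shard_count(path))   # 1 when a single consolidated.pth exists
model = load_model(path)          # builds a bfloat16 Transformer and loads the weights
# For sharded checkpoints, pass rank=r to pick up consolidated.{r:02}.pth instead.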
chameleon/inference/logits_processor.py ADDED
@@ -0,0 +1,336 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ from transformers import LogitsProcessor
10
+
11
+
12
+ class TopPProbabilityProcessor(LogitsProcessor):
13
+ # Modified version of TopPLogitsWarper to act on probabilities.
14
+ # Changes:
15
+ # * filter_value changed from -inf to 0
16
+ # * removed softmax
17
+ # * renormalize L1
18
+
19
+ def __init__(
20
+ self,
21
+ top_p: float,
22
+ min_tokens_to_keep: int = 1,
23
+ ):
24
+ top_p = float(top_p)
25
+ if top_p < 0 or top_p > 1.0:
26
+ raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
27
+ if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
28
+ raise ValueError(
29
+ f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
30
+ )
31
+
32
+ self.top_p = top_p
33
+ self.min_tokens_to_keep = min_tokens_to_keep
34
+
35
+ def __call__(
36
+ self, input_ids: torch.LongTensor, probs: torch.FloatTensor
37
+ ) -> torch.FloatTensor:
38
+ # input_ids.shape=[batch, seq-len]
39
+ # probs.shape=[batch, vocab]
40
+ sorted_probs, sorted_indices = torch.sort(probs, descending=False)
41
+ cumulative_probs = sorted_probs.cumsum(dim=-1)
42
+
43
+ # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
44
+ sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
45
+ # Keep at least min_tokens_to_keep
46
+ sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
47
+
48
+ # scatter sorted tensors to original indexing
49
+ indices_to_remove = sorted_indices_to_remove.scatter(
50
+ 1, sorted_indices, sorted_indices_to_remove
51
+ )
52
+ probs = probs.masked_fill(indices_to_remove, 0.0)
53
+ probs = probs / probs.sum(dim=-1, keepdim=True)
54
+ return probs
55
+
56
+
57
+ class DisallowTokensInIndexRangeLogitsProcessor(LogitsProcessor):
58
+ def __init__(
59
+ self, token_ids: list[int], start_index: int, end_index: int | None = None
60
+ ):
61
+ self.token_ids = torch.tensor(token_ids)
62
+ self.start_index = start_index
63
+ self.end_index = end_index if end_index is not None else math.inf
64
+
65
+ def __call__(
66
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
67
+ ) -> torch.FloatTensor:
68
+ current_index = input_ids.shape[1]
69
+ if self.start_index <= current_index < self.end_index:
70
+ logits[:, self.token_ids] = -math.inf
71
+ return logits
72
+
73
+
74
+ class DisallowTokensLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
75
+ def __init__(self, token_ids: list[int]):
76
+ super().__init__(token_ids, 0)
77
+
78
+
79
+ class DisallowTokensAtIndexLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
80
+ def __init__(self, token_ids: list[int], index: int):
81
+ super().__init__(token_ids, index, index + 1)
82
+
83
+
84
+ class DisallowTokensAfterIndexLogitsProcessor(
85
+ DisallowTokensInIndexRangeLogitsProcessor
86
+ ):
87
+ def __init__(self, token_ids: list[int], index: int):
88
+ super().__init__(token_ids, index + 1)
89
+
90
+
91
+ class DisallowTokensAtOrAfterIndexLogitsProcessor(
92
+ DisallowTokensInIndexRangeLogitsProcessor
93
+ ):
94
+ def __init__(self, token_ids: list[int], index: int):
95
+ super().__init__(token_ids, index)
96
+
97
+
98
+ class DisallowTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
99
+ def __init__(
100
+ self,
101
+ token_ids: list[int],
102
+ start_indices: list[int],
103
+ end_indices: list[int] | None = None,
104
+ ):
105
+ self.token_ids = torch.tensor(token_ids)
106
+ self.start_indices = torch.tensor(start_indices)
107
+ self.end_indices = (
108
+ torch.tensor(end_indices)
109
+ if end_indices is not None
110
+ else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
111
+ )
112
+
113
+ def __call__(
114
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
115
+ ) -> torch.FloatTensor:
116
+ # input_ids.shape = [batch, seq_len]
117
+ # logits.shape = [batch, vocab]
118
+ current_index = input_ids.shape[1]
119
+ mask = (self.start_indices <= current_index) & (
120
+ current_index < self.end_indices
121
+ )
122
+ # The following will fail if the mask is all False.
123
+ # logits[mask, self.token_ids] = -math.inf
124
+ logits[torch.where(mask)[0].unsqueeze(1), self.token_ids] = -math.inf
125
+ return logits
126
+
127
+
128
+ class DisallowTokensAtBatchIndexLogitsProcessor(
129
+ DisallowTokensInBatchIndexRangeLogitsProcessor
130
+ ):
131
+ def __init__(self, token_ids: list[int], batch_index: list[int]):
132
+ super().__init__(token_ids, batch_index, [i + 1 for i in batch_index])
133
+
134
+
135
+ class AllowOnlyTokensInIndexRangeLogitsProcessor(LogitsProcessor):
136
+ def __init__(
137
+ self, token_ids: list[int], start_index: int, end_index: int | None = None
138
+ ):
139
+ self.token_ids = torch.tensor(token_ids)
140
+ self.start_index = start_index
141
+ self.end_index = end_index if end_index is not None else math.inf
142
+
143
+ def __call__(
144
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
145
+ ) -> torch.FloatTensor:
146
+ current_index = input_ids.shape[1]
147
+ if self.start_index <= current_index < self.end_index:
148
+ replacement = torch.full_like(logits, -math.inf)
149
+ replacement[:, self.token_ids] = logits[:, self.token_ids]
150
+ logits[:] = replacement
151
+ return logits
152
+
153
+
154
+ class AllowOnlyTokensLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
155
+ def __init__(self, token_ids: list[int]):
156
+ super().__init__(token_ids, 0)
157
+
158
+
159
+ class AllowOnlyTokensAtIndexLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
160
+ def __init__(self, token_ids: list[int], index: int):
161
+ super().__init__(token_ids, index, index + 1)
162
+
163
+
164
+ class AllowOnlyTokensAfterIndexLogitsProcessor(
165
+ AllowOnlyTokensInIndexRangeLogitsProcessor
166
+ ):
167
+ def __init__(self, token_ids: list[int], index: int):
168
+ super().__init__(token_ids, index + 1)
169
+
170
+
171
+ class AllowOnlyTokensAtOrAfterIndexLogitsProcessor(
172
+ AllowOnlyTokensInIndexRangeLogitsProcessor
173
+ ):
174
+ def __init__(self, token_ids: list[int], index: int):
175
+ super().__init__(token_ids, index)
176
+
177
+
178
+ class AllowOnlyTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
179
+ def __init__(
180
+ self,
181
+ token_ids: list[int],
182
+ start_indices: list[int],
183
+ end_indices: list[int] | None = None,
184
+ ):
185
+ self.token_ids = torch.tensor(token_ids)
186
+ self.start_indices = torch.tensor(start_indices)
187
+ self.end_indices = (
188
+ torch.tensor(end_indices)
189
+ if end_indices is not None
190
+ else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
191
+ )
192
+
193
+ def __call__(
194
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
195
+ ) -> torch.FloatTensor:
196
+ # input_ids.shape = [batch, seq_len]
197
+ # logits.shape = [batch, vocab]
198
+ current_index = input_ids.shape[1]
199
+ mask = (self.start_indices <= current_index) & (
200
+ current_index < self.end_indices
201
+ )
202
+
203
+ valid_batch_indices = torch.where(mask)[0].unsqueeze(1)
204
+ full_mask = torch.full_like(logits, -math.inf)
205
+ full_mask[valid_batch_indices, self.token_ids] = logits[
206
+ valid_batch_indices, self.token_ids
207
+ ]
208
+
209
+ logits[:] = torch.where(full_mask != -math.inf, full_mask, logits)
210
+ return logits
211
+
212
+
213
+ class AllowOnlyTokensAtRelativeOffsetLogitsProcessor(LogitsProcessor):
214
+ def __init__(
215
+ self, trigger_token_id: int, subsequent_token_ids: list[int], offset: int
216
+ ):
217
+ self.trigger_token_id = trigger_token_id
218
+ self.subsequent_token_ids = torch.tensor(subsequent_token_ids)
219
+ self.offset = offset
220
+
221
+ def __call__(
222
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
223
+ ) -> torch.FloatTensor:
224
+ # input_ids.shape=[batch, seq_len]
225
+ # logits.shape=[batch, vocab]
226
+ if input_ids.shape[1] < self.offset:
227
+ return logits
228
+
229
+ trigger_positions = (
230
+ input_ids[:, -self.offset] == self.trigger_token_id
231
+ ).unsqueeze(-1)
232
+
233
+ disallowed_tokens_mask = torch.ones_like(logits, dtype=bool)
234
+ disallowed_tokens_mask[:, self.subsequent_token_ids] = False
235
+
236
+ return logits.masked_fill_(
237
+ disallowed_tokens_mask & trigger_positions,
238
+ -math.inf,
239
+ )
240
+
241
+
242
+ class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor):
243
+ def __init__(self, trigger_token_id: int, allowed_token_ids: list[int], width: int):
244
+ self.trigger_token_id = trigger_token_id
245
+ self.allowed_token_ids = torch.tensor(allowed_token_ids).unsqueeze(
246
+ 0
247
+ ) # shape: [1, num_allowed_tokens]
248
+ self.width = width
249
+
250
+ def __call__(
251
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
252
+ ) -> torch.FloatTensor:
253
+ # input_ids.shape=[batch, seq_len]
254
+ # logits.shape=[batch, vocab]
255
+ width = min(self.width, input_ids.shape[1])
256
+ trigger_positions = (
257
+ (input_ids[:, -width:] == self.trigger_token_id).any(dim=1).unsqueeze(-1)
258
+ )
259
+
260
+ disallowed_tokens_mask = torch.ones_like(logits, dtype=bool)
261
+ disallowed_tokens_mask[:, self.allowed_token_ids] = False
262
+
263
+ return logits.masked_fill_(
264
+ disallowed_tokens_mask & trigger_positions,
265
+ -math.inf,
266
+ )
267
+
268
+
269
+ class CFGLogitsProcessor(LogitsProcessor):
270
+ def __init__(
271
+ self,
272
+ guidance_scale: float,
273
+ unconditional_ids: torch.LongTensor,
274
+ model,
275
+ ):
276
+ self.guidance_scale = guidance_scale
277
+ self.unconditional_ids = unconditional_ids
278
+ self.model = model
279
+
280
+ def __call__(
281
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
282
+ ) -> torch.FloatTensor:
283
+ conditioned_logits = logits
284
+
285
+ self.unconditional_ids = torch.cat(
286
+ [self.unconditional_ids, input_ids[:, -1:]], dim=1
287
+ )
288
+ unconditioned_outputs = self.model(self.unconditional_ids)
289
+ unconditioned_logits = unconditioned_outputs[:, -1, :]
290
+ return (
291
+ self.guidance_scale * (conditioned_logits - unconditioned_logits)
292
+ + unconditioned_logits
293
+ )
294
+
295
+
296
+ class InBatchCFGLogitsProcessor(LogitsProcessor):
297
+ def __init__(self, guidance_scale: float):
298
+ self.guidance_scale = guidance_scale
299
+
300
+ def __call__(
301
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
302
+ ) -> torch.FloatTensor:
303
+ # input_ids.shape=[2*batch, seq-len]
304
+ # logits.shape=[2*batch, vocab]
305
+ conditioned_logits, unconditioned_logits = torch.chunk(logits, chunks=2, dim=0)
306
+ mixed_logits = unconditioned_logits + self.guidance_scale * (
307
+ conditioned_logits - unconditioned_logits
308
+ )
309
+ return mixed_logits.repeat(2, 1)
310
+
311
+
312
+ class InBatchInstructCFGLogitsProcessor(LogitsProcessor):
313
+ # See https://arxiv.org/abs/2211.09800
314
+
315
+ def __init__(self, guidance_scale_text: float, guidance_scale_image: float):
316
+ self.guidance_scale_text = guidance_scale_text
317
+ self.guidance_scale_image = guidance_scale_image
318
+
319
+ def __call__(
320
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
321
+ ) -> torch.FloatTensor:
322
+ # input_ids.shape=[3*batch, seq-len]
323
+ # logits.shape=[3*batch, vocab]
324
+ (
325
+ full_conditioned_logits,
326
+ image_conditioned_logits,
327
+ unconditioned_logits,
328
+ ) = logits.chunk(3)
329
+ mixed_logits = (
330
+ unconditioned_logits
331
+ + self.guidance_scale_image
332
+ * (image_conditioned_logits - unconditioned_logits)
333
+ + self.guidance_scale_text
334
+ * (full_conditioned_logits - image_conditioned_logits)
335
+ )
336
+ return mixed_logits.repeat(3, 1)
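A tiny, self-contained check of the in-batch classifier-free-guidance mixing above, run on dummy logits (the numbers are arbitrary):

import torch
from chameleon.inference.logits_processor import InBatchCFGLogitsProcessor

proc = InBatchCFGLogitsProcessor(guidance_scale=3.0)
input_ids = torch.zeros(2, 4, dtype=torch.long)   # [conditioned row; unconditioned row]
logits = torch.tensor([[2.0, 0.0, 1.0],           # conditioned
                       [1.0, 1.0, 1.0]])          # unconditioned
mixed = proc(input_ids, logits)
# uncond + 3.0 * (cond - uncond) = [4.0, -2.0, 1.0], repeated for both halves of the batch
print(mixed)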
chameleon/inference/model_adapter.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from abc import ABC, abstractmethod
8
+
9
+ import torch
10
+
11
+ from chameleon.inference import transformer
12
+ from chameleon.inference.alignment import (
13
+ AlignPromptLeft,
14
+ AlignPromptRight,
15
+ PromptAlignment,
16
+ )
17
+ from chameleon.inference.cudagraph import cudagraph_wrap
18
+
19
+
20
+ class ModelAdapter(ABC):
21
+ @abstractmethod
22
+ def initialize(self, prompt_tokens: list[list[int]]):
23
+ ...
24
+
25
+ @abstractmethod
26
+ def supports_alignment(self, alignment: PromptAlignment) -> bool:
27
+ ...
28
+
29
+ @abstractmethod
30
+ @torch.inference_mode()
31
+ def __call__(self, inputs: torch.LongTensor) -> torch.FloatTensor:
32
+ ...
33
+
34
+
35
+ class ChameleonModelAdapter(ModelAdapter):
36
+ """Adapter for Chameleon-style model that handles state, such as cache."""
37
+
38
+ def __init__(
39
+ self,
40
+ model: transformer.Transformer,
41
+ max_seq_len: int,
42
+ dtype: torch.dtype | None = None,
43
+ ):
44
+ super().__init__()
45
+ self._args = model.args
46
+ self._model = model
47
+ self._max_seq_len = max_seq_len
48
+ self._dtype = dtype or next(model.parameters()).data.dtype
49
+
50
+ def initialize(self, prompt_tokens: list[list[int]]):
51
+ self._prompt_lengths = [len(toks) for toks in prompt_tokens]
52
+ batch_size = len(prompt_tokens)
53
+
54
+ self._cache = transformer.make_cache(
55
+ args=self._args,
56
+ length=batch_size * self._max_seq_len,
57
+ dtype=self._dtype,
58
+ )
59
+
60
+ self._local_inputs = torch.zeros([batch_size], dtype=int, device="cuda")
61
+
62
+ self._forward = cudagraph_wrap(self._model.forward_with_attn_bias)
63
+
64
+ self._first_pass = True
65
+
66
+ def supports_alignment(self, alignment: PromptAlignment) -> bool:
67
+ return isinstance(alignment, AlignPromptLeft) or isinstance(
68
+ alignment, AlignPromptRight
69
+ )
70
+
71
+ def __call__(self, inputs: torch.LongTensor) -> torch.FloatTensor:
72
+ # inputs.shape=[batch, seq-len]
73
+ batch_size, seq_len = inputs.shape
74
+
75
+ if self._first_pass:
76
+ attn_seqlen = [min(pl, seq_len) for pl in self._prompt_lengths]
77
+ self._bias = transformer.AttnBias.from_seqlens(
78
+ q_seqlen=attn_seqlen,
79
+ kv_seqlen=attn_seqlen,
80
+ kv_padding=self._max_seq_len,
81
+ )
82
+
83
+ mask = torch.zeros_like(inputs, dtype=torch.bool)
84
+ for i, k in enumerate(self._prompt_lengths):
85
+ mask[i, -k:] = True
86
+
87
+ flat_outputs: torch.Tensor = self._forward( # type: ignore
88
+ token_values=inputs[mask],
89
+ attn_bias=self._bias,
90
+ cache=self._cache,
91
+ )
92
+ self._local_outputs = torch.full(
93
+ (inputs.shape[0], inputs.shape[1], flat_outputs.shape[-1]),
94
+ -math.inf,
95
+ )
96
+ self._local_outputs[mask] = flat_outputs
97
+
98
+ self._vocab_size = self._local_outputs.shape[-1]
99
+
100
+ self._bias.q_seqinfo.seqstart.copy_(
101
+ torch.arange(batch_size + 1, dtype=torch.int)
102
+ )
103
+ self._bias.q_seqinfo.max_seqlen = 1
104
+ self._bias.q_seqinfo.seqstart_py = self._bias.q_seqinfo.seqstart.tolist()
105
+
106
+ self._first_pass = False
107
+
108
+ else:
109
+ self._local_inputs.copy_(inputs[:, -1]) # type: ignore
110
+
111
+ self._local_outputs = self._forward( # type: ignore
112
+ token_values=self._local_inputs,
113
+ attn_bias=self._bias,
114
+ cache=self._cache,
115
+ )
116
+
117
+ self._bias.k_seqinfo.seqlen.add_(1)
118
+ return self._local_outputs.view(batch_size, -1, self._vocab_size)
chameleon/inference/stopping_criteria.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
+ class StoppingCriteria:
10
+ def __call__(
11
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
12
+ ) -> bool:
13
+ raise NotImplementedError("StoppingCriteria needs to be subclassed")
14
+
15
+
16
+ class StoppingCriteriaList(list):
17
+ def __call__(
18
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
19
+ ) -> bool:
20
+ return any(criteria(input_ids, scores, **kwargs) for criteria in self)
21
+
22
+
23
+ class MaxLengthCriteria(StoppingCriteria):
24
+ def __init__(self, max_length: int):
25
+ self.max_length = max_length
26
+
27
+ def __call__(
28
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
29
+ ) -> bool:
30
+ cur_len = input_ids.shape[-1]
31
+ return cur_len >= self.max_length
32
+
33
+
34
+ class StopOnEOS(StoppingCriteria):
35
+ def __init__(self, eos_id: int):
36
+ self._eos_id = eos_id
37
+
38
+ def __call__(self, input_ids: torch.LongTensor, _: torch.FloatTensor) -> bool:
39
+ # input_ids.shape=[batch, seq_len]
40
+ return (input_ids == self._eos_id).sum(dim=1).all()
41
+
42
+
43
+ class StopOnEOSAfterBatchIndex(StoppingCriteria):
44
+ def __init__(self, eos_id: int, batch_index: list[int]):
45
+ self._eos_id = eos_id
46
+ self.batch_index = torch.tensor(batch_index, dtype=torch.long).unsqueeze(1)
47
+
48
+ def __call__(self, input_ids: torch.LongTensor, _: torch.FloatTensor) -> bool:
49
+ # input_ids.shape=[batch, seq_len]
50
+ eos_mask = input_ids == self._eos_id
51
+ consider_eos_mask = (
52
+ torch.arange(input_ids.shape[1]).unsqueeze(0) >= self.batch_index
53
+ )
54
+ valid_eos = eos_mask & consider_eos_mask
55
+ return valid_eos.sum(dim=1).all()
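The criteria compose through the list wrapper above; a quick illustration on toy token ids (the EOS id 2 is a placeholder):

import torch
from chameleon.inference.stopping_criteria import (
    MaxLengthCriteria,
    StopOnEOS,
    StoppingCriteriaList,
)

stop = StoppingCriteriaList([MaxLengthCriteria(max_length=8), StopOnEOS(eos_id=2)])
print(stop(torch.tensor([[5, 6, 2], [7, 8, 9]]), None))  # False: second row has no EOS yet
print(stop(torch.tensor([[5, 6, 2], [7, 2, 9]]), None))  # True: every row now contains EOS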
chameleon/inference/token_selector.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
+ class TokenSelector:
10
+ def __call__(
11
+ self, input_ids: torch.LongTensor, probs: torch.FloatTensor
12
+ ) -> torch.FloatTensor:
13
+ # input_ids.shape=[batch, seq_len]
14
+ # probs.shape=[batch, vocab]
15
+ ...
16
+
17
+
18
+ class ArgmaxTokenSelector(TokenSelector):
19
+ def __call__(
20
+ self, _: torch.LongTensor, probs: torch.FloatTensor
21
+ ) -> torch.LongTensor:
22
+ # probs.shape=[batch, vocab]
23
+ return probs.argmax(dim=1)
24
+
25
+
26
+ class MultinomialTokenSelector(TokenSelector):
27
+ def __call__(
28
+ self, _: torch.LongTensor, probs: torch.FloatTensor
29
+ ) -> torch.LongTensor:
30
+ # probs.shape=[batch, vocab]
31
+ return probs.multinomial(num_samples=1).squeeze(1)
32
+
33
+
34
+ class ReplicatedInputTokenSelector(TokenSelector):
35
+ def __init__(self, token_selector: TokenSelector, n: int):
36
+ self.token_selector = token_selector
37
+ self.n = n
38
+
39
+ def __call__(
40
+ self, input_ids: torch.LongTensor, probs: torch.FloatTensor
41
+ ) -> torch.LongTensor:
42
+ # input_ids.shape=[n*batch, seq_len]
43
+ # probs.shape=[n*batch, vocab]
44
+ primary_input_ids = torch.chunk(input_ids, chunks=self.n, dim=0)[0]
45
+ primary_probs = torch.chunk(probs, chunks=self.n, dim=0)[0]
46
+ tokens = self.token_selector(primary_input_ids, primary_probs)
47
+ return tokens.repeat(self.n)
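ReplicatedInputTokenSelector samples from the primary copy of a replicated batch (as used for in-batch CFG) and repeats the choice for the replicas. A toy check with argmax selection:

import torch
from chameleon.inference.token_selector import ArgmaxTokenSelector, ReplicatedInputTokenSelector

selector = ReplicatedInputTokenSelector(ArgmaxTokenSelector(), n=2)
input_ids = torch.zeros(4, 3, dtype=torch.long)   # two copies of a batch of two
probs = torch.tensor([[0.1, 0.9], [0.8, 0.2],     # primary copy, used for selection
                      [0.5, 0.5], [0.5, 0.5]])    # replicated copy, ignored
print(selector(input_ids, probs))                 # tensor([1, 0, 1, 0])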
chameleon/inference/transformer.py ADDED
@@ -0,0 +1,421 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+ from torch import distributed as dist
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+ from xformers.ops import RMSNorm, fmha, rope_padded
13
+ from xformers.ops.fmha.attn_bias import (
14
+ BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
15
+ )
16
+
17
+
18
+ @dataclass
19
+ class ModelArgs:
20
+ model_parallel_size: int = 1
21
+ dim: int = 512
22
+ n_layers: int = 8
23
+ n_heads: int = 8
24
+ n_kv_heads: int | None = None
25
+ vocab_size: int = -1
26
+ ffn_dim_multiplier: float | None = None
27
+ multiple_of: int = 256
28
+ norm_eps: float = 1e-5
29
+ rope_theta: float = 10000.0
30
+ qk_normalization: bool = False
31
+ swin_norm: bool = False
32
+
33
+
34
+ LayerCache = tuple[torch.Tensor, torch.Tensor]
35
+
36
+
37
+ class Attention(nn.Module):
38
+ def __init__(
39
+ self,
40
+ model_parallel_size: int,
41
+ dim: int,
42
+ head_dim: int,
43
+ n_heads: int,
44
+ n_kv_heads: int,
45
+ rope_theta: float,
46
+ qk_normalization: bool = False,
47
+ ):
48
+ super().__init__()
49
+
50
+ self.model_parallel_size = model_parallel_size
51
+
52
+ self.head_dim = head_dim
53
+ self.rope_theta = rope_theta
54
+
55
+ self.n_local_heads = n_heads // model_parallel_size
56
+ self.n_local_kv_heads = n_kv_heads // model_parallel_size
57
+
58
+ self.wqkv = nn.Linear(
59
+ dim,
60
+ (self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
61
+ bias=False,
62
+ dtype=torch.bfloat16,
63
+ )
64
+ self.wo = nn.Linear(
65
+ self.n_local_heads * head_dim,
66
+ dim,
67
+ bias=False,
68
+ dtype=torch.bfloat16,
69
+ )
70
+
71
+ self.qk_normalization = qk_normalization
72
+ if qk_normalization:
73
+ self.q_normalization = torch.nn.LayerNorm(head_dim)
74
+ self.k_normalization = torch.nn.LayerNorm(head_dim)
75
+
76
+ self._register_load_state_dict_pre_hook(self.load_hook)
77
+
78
+ # This adapter makes sure we can load vanilla
79
+ # Llama checkpoints where wq, wk, and wv are
80
+ # not fused in a single parameter
81
+ def load_hook(
82
+ self,
83
+ state_dict,
84
+ prefix,
85
+ local_metadata,
86
+ strict,
87
+ missing_keys,
88
+ unexpected_keys,
89
+ error_msgs,
90
+ ):
91
+ if prefix + "wq.weight" in state_dict:
92
+ wq = state_dict.pop(prefix + "wq.weight")
93
+ wk = state_dict.pop(prefix + "wk.weight")
94
+ wv = state_dict.pop(prefix + "wv.weight")
95
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
96
+
97
+ def forward(
98
+ self,
99
+ x: torch.Tensor,
100
+ cache: LayerCache,
101
+ attn_bias: AttnBias,
102
+ group: dist.ProcessGroup | None = None,
103
+ ) -> torch.Tensor:
104
+ # x.shape is (sum(seq_lens), dim)
105
+ #
106
+ # Since we support heterogenous sequence
107
+ # lengths, the hidden states are all
108
+ # concatenated together along the usual
109
+ # sequence dimension. The attention below
110
+ # finds out where sequences start & end
111
+ # using the provided attention bias.
112
+ xqkv = self.wqkv(x)
113
+ xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
114
+ xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
115
+ xk, xv = xkv.chunk(2, 1)
116
+
117
+ if self.qk_normalization:
118
+ xq = xq.view(-1, self.n_local_heads, self.head_dim)
119
+ xq = self.q_normalization(xq)
120
+ xq = xq.view(-1, self.n_local_heads * self.head_dim)
121
+
122
+ xk = xk.view(-1, self.n_local_kv_heads, self.head_dim)
123
+ xk = self.k_normalization(xk)
124
+ xk = xk.view(-1, self.n_local_kv_heads * self.head_dim)
125
+
126
+ output_shape = xq.shape
127
+ xq = xq.view(1, xq.shape[0], self.n_local_heads, self.head_dim)
128
+ xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, self.head_dim)
129
+ xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, self.head_dim)
130
+ cache_k, cache_v = cache
131
+
132
+ xq = rope_padded(
133
+ xq=xq,
134
+ xk=xk,
135
+ xv=xv,
136
+ cache_k=cache_k,
137
+ cache_v=cache_v,
138
+ attn_bias=attn_bias,
139
+ theta=self.rope_theta,
140
+ )
141
+
142
+ # Handle GQA
143
+ # Q shape: [B, M, Hkv, Hq // Hkv, K]
144
+ heads_per_group = self.n_local_heads // self.n_local_kv_heads
145
+ cache_k = cache_k.unsqueeze(3).expand(-1, -1, -1, heads_per_group, -1)
146
+ cache_v = cache_v.unsqueeze(3).expand(-1, -1, -1, heads_per_group, -1)
147
+ xq = xq.reshape(
148
+ [*xq.shape[:2], self.n_local_kv_heads, heads_per_group, xq.shape[-1]]
149
+ )
150
+
151
+ # rope_padded() updated the caches, so we
152
+ # call attention directly
153
+ output = fmha.memory_efficient_attention_forward(
154
+ xq, cache_k, cache_v, attn_bias
155
+ )
156
+
157
+ output = self.wo(output.reshape(output_shape))
158
+ if self.model_parallel_size > 1:
159
+ dist.all_reduce(output, group=group)
160
+
161
+ return output
162
+
163
+
164
+ class FeedForward(nn.Module):
165
+ def __init__(
166
+ self,
167
+ model_parallel_size: int,
168
+ dim: int,
169
+ hidden_dim: int,
170
+ multiple_of: int,
171
+ ffn_dim_multiplier: float | None,
172
+ ):
173
+ super().__init__()
174
+
175
+ self.model_parallel_size = model_parallel_size
176
+
177
+ hidden_dim = int(2 * hidden_dim / 3)
178
+ if ffn_dim_multiplier is not None:
179
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
180
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
181
+ assert hidden_dim % model_parallel_size == 0
182
+
183
+ self.w13 = nn.Linear(
184
+ dim,
185
+ 2 * hidden_dim // model_parallel_size,
186
+ bias=False,
187
+ )
188
+ self.w2 = nn.Linear(
189
+ hidden_dim // model_parallel_size,
190
+ dim,
191
+ bias=False,
192
+ )
193
+ self._register_load_state_dict_pre_hook(self.load_hook)
194
+
195
+ # This adapter makes sure we can load vanilla
196
+ # Llama checkpoints where w1 and w3 are not
197
+ # fused in a single parameter
198
+ def load_hook(
199
+ self,
200
+ state_dict,
201
+ prefix,
202
+ local_metadata,
203
+ strict,
204
+ missing_keys,
205
+ unexpected_keys,
206
+ error_msgs,
207
+ ):
208
+ if prefix + "w1.weight" in state_dict:
209
+ w1 = state_dict.pop(prefix + "w1.weight")
210
+ w3 = state_dict.pop(prefix + "w3.weight")
211
+ state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])
212
+
213
+ def forward(
214
+ self, x: torch.Tensor, group: dist.ProcessGroup | None = None
215
+ ) -> torch.Tensor:
216
+ x13 = self.w13(x)
217
+ x1, x3 = x13.chunk(2, -1)
218
+ output = self.w2(F.silu(x1) * x3)
219
+ if self.model_parallel_size > 1:
220
+ dist.all_reduce(output, group=group)
221
+ return output
222
+
223
+
224
+ class TransformerBlock(nn.Module):
225
+ def __init__(self, args: ModelArgs):
226
+ super().__init__()
227
+
228
+ assert args.dim % args.n_heads == 0
229
+ head_dim = args.dim // args.n_heads
230
+ if args.n_kv_heads is not None:
231
+ n_kv_heads = args.n_kv_heads
232
+ else:
233
+ n_kv_heads = args.n_heads
234
+
235
+ model_parallel_size = args.model_parallel_size
236
+ assert args.n_heads % n_kv_heads == 0
237
+ assert args.n_heads % model_parallel_size == 0
238
+ assert n_kv_heads % model_parallel_size == 0
239
+
240
+ self.attention = Attention(
241
+ model_parallel_size=model_parallel_size,
242
+ dim=args.dim,
243
+ head_dim=head_dim,
244
+ n_heads=args.n_heads,
245
+ n_kv_heads=n_kv_heads,
246
+ rope_theta=args.rope_theta,
247
+ qk_normalization=args.qk_normalization,
248
+ )
249
+ self.feed_forward = FeedForward(
250
+ model_parallel_size=model_parallel_size,
251
+ dim=args.dim,
252
+ hidden_dim=4 * args.dim,
253
+ multiple_of=args.multiple_of,
254
+ ffn_dim_multiplier=args.ffn_dim_multiplier,
255
+ )
256
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
257
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
258
+ self.swin_norm = args.swin_norm
259
+
260
+ def forward(
261
+ self,
262
+ x: torch.Tensor,
263
+ cache: LayerCache,
264
+ attn_bias: AttnBias,
265
+ group: dist.ProcessGroup | None = None,
266
+ ) -> torch.Tensor:
267
+ if self.swin_norm:
268
+ h = x + self.attention_norm(
269
+ self.attention.forward(
270
+ x,
271
+ cache,
272
+ attn_bias,
273
+ group=group,
274
+ )
275
+ )
276
+ out = h + self.ffn_norm(self.feed_forward(h, group=group))
277
+ else:
278
+ h = x + self.attention.forward(
279
+ self.attention_norm(x),
280
+ cache,
281
+ attn_bias,
282
+ group=group,
283
+ )
284
+ out = h + self.feed_forward(self.ffn_norm(h), group=group)
285
+ return out
286
+
287
+
288
+ class Transformer(nn.Module):
289
+ def __init__(self, args: ModelArgs):
290
+ super().__init__()
291
+ self.args = args
292
+
293
+ self.model_parallel_size = args.model_parallel_size
294
+ assert args.dim % self.model_parallel_size == 0
295
+ assert args.vocab_size > 0
296
+ assert args.vocab_size % self.model_parallel_size == 0
297
+
298
+ self.tok_embeddings = nn.Embedding(
299
+ num_embeddings=args.vocab_size,
300
+ embedding_dim=args.dim // self.model_parallel_size,
301
+ )
302
+
303
+ self.layers = nn.ModuleList()
304
+ for _ in range(args.n_layers):
305
+ self.layers.append(TransformerBlock(args))
306
+
307
+ self.norm = RMSNorm(args.dim, eps=args.norm_eps)
308
+
309
+ self.output = nn.Linear(
310
+ args.dim,
311
+ args.vocab_size // self.model_parallel_size,
312
+ bias=False,
313
+ )
314
+
315
+ @torch.no_grad()
316
+ def forward_with_attn_bias(
317
+ self,
318
+ token_values: torch.Tensor,
319
+ attn_bias: AttnBias,
320
+ cache: list[LayerCache],
321
+ group: dist.ProcessGroup | None = None,
322
+ ) -> torch.Tensor:
323
+ h = self.tok_embeddings(token_values)
324
+ if self.model_parallel_size > 1:
325
+ gather = [torch.empty_like(h) for _ in range(self.model_parallel_size)]
326
+ dist.all_gather(gather, h, group=group)
327
+ h = torch.cat(gather, dim=-1)
328
+
329
+ for i, layer in enumerate(self.layers):
330
+ h = layer(h, cache[i], attn_bias, group=group)
331
+
332
+ logits = self.output(self.norm(h))
333
+ if self.model_parallel_size > 1:
334
+ gather = [torch.empty_like(logits) for _ in range(self.model_parallel_size)]
335
+ dist.all_gather(gather, logits, group=group)
336
+ logits = torch.cat(gather, dim=-1)
337
+ return logits.float()
338
+
339
+ def forward(
340
+ self,
341
+ token_values: torch.Tensor,
342
+ token_lengths: torch.Tensor,
343
+ start_pos: torch.Tensor,
344
+ cache: list[LayerCache],
345
+ kv_padding: int,
346
+ group: dist.ProcessGroup | None = None,
347
+ ) -> torch.Tensor:
348
+ attn_bias = AttnBias.from_seqlens(
349
+ q_seqlen=token_lengths.tolist(),
350
+ kv_seqlen=(start_pos + token_lengths).tolist(),
351
+ kv_padding=kv_padding,
352
+ )
353
+ return self.forward_with_attn_bias(token_values, attn_bias, cache, group=group)
354
+
355
+
356
+ def make_cache(
357
+ args: ModelArgs,
358
+ length: int,
359
+ device: str | torch.device | None = None,
360
+ n_layers: int | None = None,
361
+ dtype: torch.dtype | None = None,
362
+ ) -> list[LayerCache]:
363
+ """
364
+ Allocate a cache to be used with the Transformer module.
365
+
366
+ Args:
367
+ args (ModelArgs): the model configuration.
368
+ length (int): per layer cache size.
369
+ It is usually budgeted as ``max_batch * max_seq``
370
+ device (torch.device, optional): the device on which
371
+ the cache should be allocated.
372
+ n_layers (int, optional): the number of layers to
373
+ allocate a cache for (defaults to the model
374
+ settings).
375
+ dtype (torch.dtype, optional): the dtype to use for
376
+ cache entries (defaults to the default dtype).
377
+
378
+ Returns:
379
+ The cache object to pass to ``Transformer.forward``.
380
+ """
381
+
382
+ head_dim = args.dim // args.n_heads
383
+ n_kv_heads = args.n_kv_heads
384
+ if n_kv_heads is None:
385
+ n_kv_heads = args.n_heads
386
+ n_local_kv_heads = n_kv_heads // args.model_parallel_size
387
+
388
+ if n_layers is None:
389
+ n_layers = args.n_layers
390
+
391
+ shape = (1, length, n_local_kv_heads, head_dim)
392
+ return [
393
+ (
394
+ torch.zeros(shape, device=device, dtype=dtype),
395
+ torch.zeros(shape, device=device, dtype=dtype),
396
+ )
397
+ for _ in range(n_layers)
398
+ ]
399
+
400
+
401
+ def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
402
+ """
403
+ Take a prefix view of a larger cache.
404
+
405
+ The original cache object remains of identical size and valid
406
+ after the shrunken alias has been used. This function is useful
407
+ when a cache was allocated for a larger batch size than what is
408
+ necessary.
409
+
410
+ Args:
411
+ cache: the cache to take a view in.
412
+ length (int): the desired length
413
+
414
+ Returns:
415
+ A view in the input cache object.
416
+ """
417
+
418
+ if len(cache) > 0:
419
+ assert cache[0][0].shape[1] >= length
420
+
421
+ return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
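A small sketch of the cache helpers above (xformers must be installed for the module import; the sizes are arbitrary):

import torch
from chameleon.inference.transformer import ModelArgs, cache_prefix, make_cache

args = ModelArgs()                                               # defaults: 8 layers, dim 512, 8 heads
cache = make_cache(args, length=4 * 2048, dtype=torch.bfloat16)  # budget: max_batch * max_seq
print(len(cache), cache[0][0].shape)                             # 8 layers, each key/value (1, 8192, kv_heads, head_dim)

small = cache_prefix(cache, length=2 * 2048)                     # prefix view for a smaller batch, no reallocation
print(small[0][0].shape)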
chameleon/inference/utils.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import socket
7
+ from typing import Generator, Generic, Iterator, TypeVar
8
+
9
+ T = TypeVar("T")
10
+
11
+
12
+ class DynamicGenerator(Generic[T]):
13
+ def __init__(self, gen: Generator[T, None, None]):
14
+ self.gen = gen
15
+
16
+ def __iter__(self) -> Iterator[T]:
17
+ return self
18
+
19
+ def __next__(self) -> T:
20
+ return next(self.gen)
21
+
22
+
23
+ def advance(iterator: Iterator[T], steps: int):
24
+ try:
25
+ for _ in range(steps):
26
+ next(iterator)
27
+ except StopIteration:
28
+ pass
29
+
30
+
31
+ def random_unused_port():
32
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
33
+ s.bind(("", 0))
34
+ return s.getsockname()[1]
chameleon/inference/vocab.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from functools import cached_property
7
+
8
+ import torch
9
+
10
+
11
+ class VocabInfo:
12
+ def __init__(self, vocab_map: dict[str, int]):
13
+ self.name2val = vocab_map
14
+
15
+ self.bos_id = vocab_map.get("<s>")
16
+ self.eos_id = vocab_map.get("</s>")
17
+ self.boi_id = vocab_map.get("<racm3:break>")
18
+ self.eoi_id = vocab_map.get("<eoss>")
19
+ self.pad_id = vocab_map.get("<pad>")
20
+ self.eot_id = vocab_map.get("<reserved08706>")
21
+
22
+ @property
23
+ def begin_sequence(self) -> int:
24
+ return self.bos_id
25
+
26
+ @property
27
+ def end_sequence(self) -> int:
28
+ return self.eos_id
29
+
30
+ @property
31
+ def begin_image(self) -> int:
32
+ return self.boi_id
33
+
34
+ @property
35
+ def end_image(self) -> int:
36
+ return self.eoi_id
37
+
38
+ @property
39
+ def padding(self) -> int:
40
+ return self.pad_id
41
+
42
+ @property
43
+ def end_turn(self) -> int:
44
+ return self.eot_id
45
+
46
+ @cached_property
47
+ def val2name(self) -> dict[int, str]:
48
+ return {v: k for k, v in self.name2val.items()}
49
+
50
+ @cached_property
51
+ def all_tokens(self) -> list[int]:
52
+ return sorted(self.name2val.values())
53
+
54
+ @cached_property
55
+ def image_tokens(self) -> list[int]:
56
+ return sorted(
57
+ [val for name, val in self.name2val.items() if name.startswith("IMGIMG")]
58
+ )
59
+
60
+ @cached_property
61
+ def special_tokens(self) -> list[int]:
62
+ return sorted(
63
+ [
64
+ val
65
+ for name, val in self.name2val.items()
66
+ if name.startswith("<") and name != "<"
67
+ ]
68
+ )
69
+
70
+ @cached_property
71
+ def text_tokens(self) -> list[int]:
72
+ return sorted(
73
+ set(self.all_tokens) - set(self.image_tokens) - set(self.special_tokens)
74
+ )
75
+
76
+
77
+ class VocabTranslation:
78
+ def __init__(self, vocab_info: VocabInfo, device: str | None = None):
79
+ self._vocab = vocab_info
80
+ self._device = device
81
+
82
+ @cached_property
83
+ def bpe2img(self) -> dict[int, int]:
84
+ img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}
85
+
86
+ def remap(old_name: str) -> str:
87
+ return "".join(
88
+ img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]
89
+ )
90
+
91
+ return {
92
+ tok: int(remap(self._vocab.val2name[tok]))
93
+ for tok in self._vocab.image_tokens
94
+ }
95
+
96
+ @cached_property
97
+ def img2bpe(self) -> dict[int, int]:
98
+ return {v: k for k, v in self.bpe2img.items()}
99
+
100
+ @cached_property
101
+ def bpe2img_search_tensors(self) -> tuple[torch.Tensor, torch.Tensor]:
102
+ sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()), device=self._device)
103
+ sorted_img = torch.tensor(sorted(self.bpe2img.values()), device=self._device)
104
+ return sorted_bpe, sorted_img
105
+
106
+ @cached_property
107
+ def img2bpe_mapping_tensor(self) -> torch.LongTensor:
108
+ mapping = torch.zeros(
109
+ max(self.img2bpe.keys()) + 1,
110
+ dtype=torch.int,
111
+ device=self._device,
112
+ )
113
+ for k, v in self.img2bpe.items():
114
+ mapping[k] = v
115
+ return mapping
116
+
117
+ def convert_bpe2img(self, bpe_batch: torch.Tensor) -> torch.Tensor:
118
+ bpe_tok, img_tok = self.bpe2img_search_tensors
119
+ return img_tok[torch.searchsorted(bpe_tok, bpe_batch)]
120
+
121
+ def convert_img2bp2(self, img_batch: torch.Tensor) -> torch.Tensor:
122
+ return self.img2bpe_mapping_tensor[img_batch]
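The BPE-to-image-code translation relies on the IMGIMG naming scheme: letters A..J encode digits 0..9 and the trailing character is dropped. A toy vocabulary makes the mapping concrete:

from chameleon.inference.vocab import VocabInfo, VocabTranslation

vocab = VocabInfo({"<s>": 0, "</s>": 1, "IMGIMGBAZ": 5, "IMGIMGBBZ": 6})
trans = VocabTranslation(vocab)
print(trans.bpe2img)   # {5: 10, 6: 11} -- "BA" -> "10", "BB" -> "11"
print(trans.img2bpe)   # {10: 5, 11: 6}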
chameleon/inference/vqgan.py ADDED
@@ -0,0 +1,675 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ Contents of this file are taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/models/vqgan.py
8
+ [with minimal dependencies]
9
+
10
+ This implementation is inference-only -- training steps and optimizer components
11
+ introduce significant additional dependencies
12
+ """
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+
20
+ class VectorQuantizer2(nn.Module):
21
+ """
22
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
23
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
24
+ """
25
+
26
+ # NOTE: due to a bug the beta term was applied to the wrong term. For
27
+ # backwards compatibility we use the buggy version by default, but you can
28
+ # specify legacy=False to fix it.
29
+ def __init__(
30
+ self,
31
+ n_e,
32
+ e_dim,
33
+ beta,
34
+ remap=None,
35
+ unknown_index="random",
36
+ sane_index_shape=False,
37
+ legacy=True,
38
+ ):
39
+ super().__init__()
40
+ self.n_e = n_e
41
+ self.e_dim = e_dim
42
+ self.beta = beta
43
+ self.legacy = legacy
44
+
45
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
46
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
47
+
48
+ self.remap = remap
49
+ if self.remap is not None:
50
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
51
+ self.re_embed = self.used.shape[0]
52
+ self.unknown_index = unknown_index # "random" or "extra" or integer
53
+ if self.unknown_index == "extra":
54
+ self.unknown_index = self.re_embed
55
+ self.re_embed = self.re_embed + 1
56
+ print(
57
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
58
+ f"Using {self.unknown_index} for unknown indices."
59
+ )
60
+ else:
61
+ self.re_embed = n_e
62
+
63
+ self.sane_index_shape = sane_index_shape
64
+
65
+ def remap_to_used(self, inds):
66
+ ishape = inds.shape
67
+ assert len(ishape) > 1
68
+ inds = inds.reshape(ishape[0], -1)
69
+ used = self.used.to(inds)
70
+ match = (inds[:, :, None] == used[None, None, ...]).long()
71
+ new = match.argmax(-1)
72
+ unknown = match.sum(2) < 1
73
+ if self.unknown_index == "random":
74
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
75
+ device=new.device
76
+ )
77
+ else:
78
+ new[unknown] = self.unknown_index
79
+ return new.reshape(ishape)
80
+
81
+ def unmap_to_all(self, inds):
82
+ ishape = inds.shape
83
+ assert len(ishape) > 1
84
+ inds = inds.reshape(ishape[0], -1)
85
+ used = self.used.to(inds)
86
+ if self.re_embed > self.used.shape[0]: # extra token
87
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
88
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
89
+ return back.reshape(ishape)
90
+
91
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
92
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
93
+ assert rescale_logits is False, "Only for interface compatible with Gumbel"
94
+ assert return_logits is False, "Only for interface compatible with Gumbel"
95
+ # reshape z -> (batch, height, width, channel) and flatten
96
+ z = z.permute(0, 2, 3, 1).contiguous()
97
+ z_flattened = z.view(-1, self.e_dim)
98
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
99
+
100
+ d = (
101
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
102
+ + torch.sum(self.embedding.weight**2, dim=1)
103
+ - 2
104
+ * torch.einsum(
105
+ "bd,dn->bn", z_flattened, self.embedding.weight.transpose(0, 1)
106
+ )
107
+ )
108
+
109
+ min_encoding_indices = torch.argmin(d, dim=1)
110
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
111
+ perplexity = None
112
+ min_encodings = None
113
+
114
+ # compute loss for embedding
115
+ if not self.legacy:
116
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
117
+ (z_q - z.detach()) ** 2
118
+ )
119
+ else:
120
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
121
+ (z_q - z.detach()) ** 2
122
+ )
123
+
124
+ # preserve gradients
125
+ z_q = z + (z_q - z).detach()
126
+
127
+ # reshape back to match original input shape
128
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
129
+
130
+ if self.remap is not None:
131
+ min_encoding_indices = min_encoding_indices.reshape(
132
+ z.shape[0], -1
133
+ ) # add batch axis
134
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
135
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
136
+
137
+ if self.sane_index_shape:
138
+ min_encoding_indices = min_encoding_indices.reshape(
139
+ z_q.shape[0], z_q.shape[2], z_q.shape[3]
140
+ )
141
+
142
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
143
+
144
+ def get_codebook_entry(self, indices, shape):
145
+ # shape specifying (batch, height, width, channel)
146
+ if self.remap is not None:
147
+ indices = indices.reshape(shape[0], -1) # add batch axis
148
+ indices = self.unmap_to_all(indices)
149
+ indices = indices.reshape(-1) # flatten again
150
+
151
+ # get quantized latent vectors
152
+ z_q = self.embedding(indices)
153
+
154
+ if shape is not None:
155
+ z_q = z_q.view(shape)
156
+ # reshape back to match original input shape
157
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
158
+
159
+ return z_q
160
+
161
+
162
+ # Alias
163
+ VectorQuantizer = VectorQuantizer2
164
+
165
+
166
+ def nonlinearity(x):
167
+ # swish
168
+ return x * torch.sigmoid(x)
169
+
170
+
171
+ def Normalize(in_channels, num_groups=32):
172
+ return torch.nn.GroupNorm(
173
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
174
+ )
175
+
176
+
177
+ class Upsample(nn.Module):
178
+ def __init__(self, in_channels, with_conv):
179
+ super().__init__()
180
+ self.with_conv = with_conv
181
+ if self.with_conv:
182
+ self.conv = torch.nn.Conv2d(
183
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
184
+ )
185
+
186
+ def forward(self, x):
187
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
188
+ if self.with_conv:
189
+ x = self.conv(x)
190
+ return x
191
+
192
+
193
+ class Downsample(nn.Module):
194
+ def __init__(self, in_channels, with_conv):
195
+ super().__init__()
196
+ self.with_conv = with_conv
197
+ if self.with_conv:
198
+ # no asymmetric padding in torch conv, must do it ourselves
199
+ self.conv = torch.nn.Conv2d(
200
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
201
+ )
202
+
203
+ def forward(self, x):
204
+ if self.with_conv:
205
+ pad = (0, 1, 0, 1)
206
+ x = F.pad(x, pad, mode="constant", value=0)
207
+ x = self.conv(x)
208
+ else:
209
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
210
+ return x
211
+
212
+
213
+ class ResnetBlock(nn.Module):
214
+ def __init__(
215
+ self,
216
+ *,
217
+ in_channels,
218
+ out_channels=None,
219
+ conv_shortcut=False,
220
+ dropout,
221
+ temb_channels=512,
222
+ ):
223
+ super().__init__()
224
+ self.in_channels = in_channels
225
+ out_channels = in_channels if out_channels is None else out_channels
226
+ self.out_channels = out_channels
227
+ self.use_conv_shortcut = conv_shortcut
228
+
229
+ self.norm1 = Normalize(in_channels)
230
+ self.conv1 = torch.nn.Conv2d(
231
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
232
+ )
233
+ if temb_channels > 0:
234
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
235
+ self.norm2 = Normalize(out_channels)
236
+ self.dropout = torch.nn.Dropout(dropout)
237
+ self.conv2 = torch.nn.Conv2d(
238
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
239
+ )
240
+ if self.in_channels != self.out_channels:
241
+ if self.use_conv_shortcut:
242
+ self.conv_shortcut = torch.nn.Conv2d(
243
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
244
+ )
245
+ else:
246
+ self.nin_shortcut = torch.nn.Conv2d(
247
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
248
+ )
249
+
250
+ def forward(self, x, temb):
251
+ h = x
252
+ h = self.norm1(h)
253
+ h = nonlinearity(h)
254
+ h = self.conv1(h)
255
+
256
+ if temb is not None:
257
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
258
+
259
+ h = self.norm2(h)
260
+ h = nonlinearity(h)
261
+ h = self.dropout(h)
262
+ h = self.conv2(h)
263
+
264
+ if self.in_channels != self.out_channels:
265
+ if self.use_conv_shortcut:
266
+ x = self.conv_shortcut(x)
267
+ else:
268
+ x = self.nin_shortcut(x)
269
+
270
+ return x + h
271
+
272
+
273
+ class AttnBlock(nn.Module):
274
+ def __init__(self, in_channels):
275
+ super().__init__()
276
+ self.in_channels = in_channels
277
+
278
+ self.norm = Normalize(in_channels)
279
+ self.q = torch.nn.Conv2d(
280
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
281
+ )
282
+ self.k = torch.nn.Conv2d(
283
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
284
+ )
285
+ self.v = torch.nn.Conv2d(
286
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
287
+ )
288
+ self.proj_out = torch.nn.Conv2d(
289
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
290
+ )
291
+
292
+ def forward(self, x):
293
+ h_ = x
294
+ h_ = self.norm(h_)
295
+ q = self.q(h_)
296
+ k = self.k(h_)
297
+ v = self.v(h_)
298
+
299
+ # compute attention
300
+ b, c, h, w = q.shape
301
+ q = q.reshape(b, c, h * w)
302
+ q = q.permute(0, 2, 1) # b,hw,c
303
+ k = k.reshape(b, c, h * w) # b,c,hw
304
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
305
+ w_ = w_ * (int(c) ** (-0.5))
306
+ w_ = F.softmax(w_, dim=2)
307
+
308
+ # attend to values
309
+ v = v.reshape(b, c, h * w)
310
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
311
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
312
+ h_ = h_.reshape(b, c, h, w)
313
+
314
+ h_ = self.proj_out(h_)
315
+
316
+ return x + h_
317
+
318
+
319
+ def make_attn(in_channels, attn_type="vanilla"):
320
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
321
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
322
+ if attn_type == "vanilla":
323
+ return AttnBlock(in_channels)
324
+ elif attn_type == "none":
325
+ return nn.Identity(in_channels)
326
+ else:
327
+ raise ValueError("Unexpected attention type")
328
+
329
+
330
+ class Encoder(nn.Module):
331
+ def __init__(
332
+ self,
333
+ *,
334
+ ch,
335
+ out_ch,
336
+ ch_mult=(1, 2, 4, 8),
337
+ num_res_blocks,
338
+ attn_resolutions,
339
+ dropout=0.0,
340
+ resamp_with_conv=True,
341
+ in_channels,
342
+ resolution,
343
+ z_channels,
344
+ double_z=True,
345
+ use_linear_attn=False,
346
+ attn_type="vanilla",
347
+ **ignore_kwargs,
348
+ ):
349
+ super().__init__()
350
+ if use_linear_attn:
351
+ attn_type = "linear"
352
+ self.ch = ch
353
+ self.temb_ch = 0
354
+ self.num_resolutions = len(ch_mult)
355
+ self.num_res_blocks = num_res_blocks
356
+ self.resolution = resolution
357
+ self.in_channels = in_channels
358
+
359
+ # downsampling
360
+ self.conv_in = torch.nn.Conv2d(
361
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
362
+ )
363
+
364
+ curr_res = resolution
365
+ in_ch_mult = (1,) + tuple(ch_mult)
366
+ self.in_ch_mult = in_ch_mult
367
+ self.down = nn.ModuleList()
368
+ for i_level in range(self.num_resolutions):
369
+ block = nn.ModuleList()
370
+ attn = nn.ModuleList()
371
+ block_in = ch * in_ch_mult[i_level]
372
+ block_out = ch * ch_mult[i_level]
373
+ for i_block in range(self.num_res_blocks):
374
+ block.append(
375
+ ResnetBlock(
376
+ in_channels=block_in,
377
+ out_channels=block_out,
378
+ temb_channels=self.temb_ch,
379
+ dropout=dropout,
380
+ )
381
+ )
382
+ block_in = block_out
383
+ if curr_res in attn_resolutions:
384
+ attn.append(make_attn(block_in, attn_type=attn_type))
385
+ down = nn.Module()
386
+ down.block = block
387
+ down.attn = attn
388
+ if i_level != self.num_resolutions - 1:
389
+ down.downsample = Downsample(block_in, resamp_with_conv)
390
+ curr_res = curr_res // 2
391
+ self.down.append(down)
392
+
393
+ # middle
394
+ self.mid = nn.Module()
395
+ self.mid.block_1 = ResnetBlock(
396
+ in_channels=block_in,
397
+ out_channels=block_in,
398
+ temb_channels=self.temb_ch,
399
+ dropout=dropout,
400
+ )
401
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
402
+ self.mid.block_2 = ResnetBlock(
403
+ in_channels=block_in,
404
+ out_channels=block_in,
405
+ temb_channels=self.temb_ch,
406
+ dropout=dropout,
407
+ )
408
+
409
+ # end
410
+ self.norm_out = Normalize(block_in)
411
+ self.conv_out = torch.nn.Conv2d(
412
+ block_in,
413
+ 2 * z_channels if double_z else z_channels,
414
+ kernel_size=3,
415
+ stride=1,
416
+ padding=1,
417
+ )
418
+
419
+ def forward(self, x):
420
+ # timestep embedding
421
+ temb = None
422
+
423
+ # downsampling
424
+ hs = [self.conv_in(x)]
425
+ for i_level in range(self.num_resolutions):
426
+ for i_block in range(self.num_res_blocks):
427
+ h = self.down[i_level].block[i_block](hs[-1], temb)
428
+ if len(self.down[i_level].attn) > 0:
429
+ h = self.down[i_level].attn[i_block](h)
430
+ hs.append(h)
431
+ if i_level != self.num_resolutions - 1:
432
+ hs.append(self.down[i_level].downsample(hs[-1]))
433
+
434
+ # middle
435
+ h = hs[-1]
436
+ h = self.mid.block_1(h, temb)
437
+ h = self.mid.attn_1(h)
438
+ h = self.mid.block_2(h, temb)
439
+
440
+ # end
441
+ h = self.norm_out(h)
442
+ h = nonlinearity(h)
443
+ h = self.conv_out(h)
444
+ return h
445
+
446
+
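+ # Illustrative note: with the chameleon/vqgan.yaml settings below (resolution 512, ch_mult of
+ # length 5), the Encoder applies four Downsample stages, a total factor of 16, so a 512x512 RGB
+ # image becomes a 32x32 map of 256-dim latents, i.e. 1024 discrete codes after quantization.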
447
+ class Decoder(nn.Module):
448
+ def __init__(
449
+ self,
450
+ *,
451
+ ch,
452
+ out_ch,
453
+ ch_mult=(1, 2, 4, 8),
454
+ num_res_blocks,
455
+ attn_resolutions,
456
+ dropout=0.0,
457
+ resamp_with_conv=True,
458
+ in_channels,
459
+ resolution,
460
+ z_channels,
461
+ give_pre_end=False,
462
+ tanh_out=False,
463
+ use_linear_attn=False,
464
+ attn_type="vanilla",
465
+ **ignorekwargs,
466
+ ):
467
+ super().__init__()
468
+ if use_linear_attn:
469
+ attn_type = "linear"
470
+ self.ch = ch
471
+ self.temb_ch = 0
472
+ self.num_resolutions = len(ch_mult)
473
+ self.num_res_blocks = num_res_blocks
474
+ self.resolution = resolution
475
+ self.in_channels = in_channels
476
+ self.give_pre_end = give_pre_end
477
+ self.tanh_out = tanh_out
478
+
479
+ # compute in_ch_mult, block_in and curr_res at lowest res
480
+ block_in = ch * ch_mult[self.num_resolutions - 1]
481
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
482
+ self.z_shape = (1, z_channels, curr_res, curr_res)
483
+
484
+ # z to block_in
485
+ self.conv_in = torch.nn.Conv2d(
486
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
487
+ )
488
+
489
+ # middle
490
+ self.mid = nn.Module()
491
+ self.mid.block_1 = ResnetBlock(
492
+ in_channels=block_in,
493
+ out_channels=block_in,
494
+ temb_channels=self.temb_ch,
495
+ dropout=dropout,
496
+ )
497
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
498
+ self.mid.block_2 = ResnetBlock(
499
+ in_channels=block_in,
500
+ out_channels=block_in,
501
+ temb_channels=self.temb_ch,
502
+ dropout=dropout,
503
+ )
504
+
505
+ # upsampling
506
+ self.up = nn.ModuleList()
507
+ for i_level in reversed(range(self.num_resolutions)):
508
+ block = nn.ModuleList()
509
+ attn = nn.ModuleList()
510
+ block_out = ch * ch_mult[i_level]
511
+ for i_block in range(self.num_res_blocks + 1):
512
+ block.append(
513
+ ResnetBlock(
514
+ in_channels=block_in,
515
+ out_channels=block_out,
516
+ temb_channels=self.temb_ch,
517
+ dropout=dropout,
518
+ )
519
+ )
520
+ block_in = block_out
521
+ if curr_res in attn_resolutions:
522
+ attn.append(make_attn(block_in, attn_type=attn_type))
523
+ up = nn.Module()
524
+ up.block = block
525
+ up.attn = attn
526
+ if i_level != 0:
527
+ up.upsample = Upsample(block_in, resamp_with_conv)
528
+ curr_res = curr_res * 2
529
+ self.up.insert(0, up) # prepend to get consistent order
530
+
531
+ # end
532
+ self.norm_out = Normalize(block_in)
533
+ self.conv_out = torch.nn.Conv2d(
534
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
535
+ )
536
+
537
+ def forward(self, z):
538
+ # assert z.shape[1:] == self.z_shape[1:]
539
+ self.last_z_shape = z.shape
540
+
541
+ # timestep embedding
542
+ temb = None
543
+
544
+ # z to block_in
545
+ h = self.conv_in(z)
546
+
547
+ # middle
548
+ h = self.mid.block_1(h, temb)
549
+ h = self.mid.attn_1(h)
550
+ h = self.mid.block_2(h, temb)
551
+
552
+ # upsampling
553
+ for i_level in reversed(range(self.num_resolutions)):
554
+ for i_block in range(self.num_res_blocks + 1):
555
+ h = self.up[i_level].block[i_block](h, temb)
556
+ if len(self.up[i_level].attn) > 0:
557
+ h = self.up[i_level].attn[i_block](h)
558
+ if i_level != 0:
559
+ h = self.up[i_level].upsample(h)
560
+
561
+ # end
562
+ if self.give_pre_end:
563
+ return h
564
+
565
+ h = self.norm_out(h)
566
+ h = nonlinearity(h)
567
+ h = self.conv_out(h)
568
+ if self.tanh_out:
569
+ h = torch.tanh(h)
570
+ return h
571
+
572
+
573
+ class VQModel(nn.Module):
574
+ def __init__(
575
+ self,
576
+ ddconfig,
577
+ n_embed,
578
+ embed_dim,
579
+ ckpt_path=None,
580
+ ignore_keys=[],
581
+ image_key="image",
582
+ colorize_nlabels=None,
583
+ monitor=None,
584
+ scheduler_config=None,
585
+ lr_g_factor=1.0,
586
+ remap=None,
587
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
588
+ ):
589
+ super().__init__()
590
+ self.image_key = image_key
591
+ self.encoder = Encoder(**ddconfig)
592
+ self.decoder = Decoder(**ddconfig)
593
+ self.quantize = VectorQuantizer(
594
+ n_embed,
595
+ embed_dim,
596
+ beta=0.25,
597
+ remap=remap,
598
+ sane_index_shape=sane_index_shape,
599
+ )
600
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
601
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
602
+ if ckpt_path is not None:
603
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
604
+ self.image_key = image_key
605
+ if colorize_nlabels is not None:
606
+ assert isinstance(colorize_nlabels, int)
607
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
608
+ if monitor is not None:
609
+ self.monitor = monitor
610
+ self.scheduler_config = scheduler_config
611
+ self.lr_g_factor = lr_g_factor
612
+
613
+ def init_from_ckpt(self, path, ignore_keys=list()):
614
+ sd = torch.load(path, map_location="cpu")["state_dict"]
615
+ keys = list(sd.keys())
616
+ for k in keys:
617
+ for ik in ignore_keys:
618
+ if k.startswith(ik):
619
+ print("Deleting key {} from state_dict.".format(k))
620
+ del sd[k]
621
+ self.load_state_dict(sd, strict=False)
622
+ print(f"VQModel loaded from {path}")
623
+
624
+ def encode(self, x):
625
+ h = self.encoder(x)
626
+ h = self.quant_conv(h)
627
+ quant, emb_loss, info = self.quantize(h)
628
+ return quant, emb_loss, info
629
+
630
+ def decode(self, quant):
631
+ quant = self.post_quant_conv(quant)
632
+ dec = self.decoder(quant)
633
+ return dec
634
+
635
+ def decode_code(self, code_b):
636
+ quant_b = self.quantize.embed_code(code_b)
637
+ dec = self.decode(quant_b)
638
+ return dec
639
+
640
+ def forward(self, input):
641
+ quant, diff, _ = self.encode(input)
642
+ dec = self.decode(quant)
643
+ return dec, diff
644
+
645
+ def get_input(self, batch, k):
646
+ x = batch[k]
647
+ if len(x.shape) == 3:
648
+ x = x[..., None]
649
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
650
+ return x.float()
651
+
652
+ def get_last_layer(self):
653
+ return self.decoder.conv_out.weight
654
+
655
+ def log_images(self, batch, **kwargs):
656
+ log = dict()
657
+ x = self.get_input(batch, self.image_key)
658
+ x = x.to(next(self.parameters()).device)  # plain nn.Module has no .device attribute; use a parameter's device
659
+ xrec, _ = self(x)
660
+ if x.shape[1] > 3:
661
+ # colorize with random projection
662
+ assert xrec.shape[1] > 3
663
+ x = self.to_rgb(x)
664
+ xrec = self.to_rgb(xrec)
665
+ log["inputs"] = x
666
+ log["reconstructions"] = xrec
667
+ return log
668
+
669
+ def to_rgb(self, x):
670
+ assert self.image_key == "segmentation"
671
+ if not hasattr(self, "colorize"):
672
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
673
+ x = F.conv2d(x, weight=self.colorize)
674
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
675
+ return x
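For orientation, a minimal round-trip sketch for the VQModel defined above (not part of the uploaded file). It assumes the module-level imports at the top of vqgan.py (torch, nn, F), the ddconfig values from chameleon/vqgan.yaml below, and an illustrative checkpoint path; the input tensor stands in for a preprocessed image in [-1, 1]:

    import torch

    ddconfig = dict(
        double_z=False, z_channels=256, resolution=512, in_channels=3, out_ch=3,
        ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, attn_resolutions=[], dropout=0.0,
    )
    model = VQModel(ddconfig, n_embed=8192, embed_dim=256, ckpt_path="chameleon/vqgan.ckpt").eval()

    x = torch.randn(1, 3, 512, 512)                         # stand-in for a normalized image batch
    with torch.no_grad():
        quant, emb_loss, (_, _, indices) = model.encode(x)  # quantized latents and discrete code indices
        recon = model.decode(quant)                          # (1, 3, 512, 512) reconstruction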
chameleon/vqgan.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ede986bf6b171db3081ce171ad88e4ac970793cea14c180b3e5ac5105f4cb43
3
+ size 281270377
chameleon/vqgan.yaml ADDED
@@ -0,0 +1,57 @@
1
+ model:
2
+ base_learning_rate: 4.5e-06
3
+ target: taming.models.vqgan.VQModel
4
+ params:
5
+ embed_dim: 256
6
+ n_embed: 8192
7
+ ddconfig:
8
+ double_z: false
9
+ z_channels: 256
10
+ resolution: 512
11
+ in_channels: 3
12
+ out_ch: 3
13
+ ch: 128
14
+ ch_mult:
15
+ - 1
16
+ - 1
17
+ - 2
18
+ - 2
19
+ - 4
20
+ num_res_blocks: 2
21
+ attn_resolutions: []
22
+ dropout: 0.0
23
+ lossconfig:
24
+ target: taming.modules.losses.vqperceptual_vit_vqgan.VQLPIPSWithDiscriminator
25
+ params:
26
+ disc_start: 100001
27
+ perceptual_weight: 1.0
28
+ adversarial_weight: 0.5
29
+ disc_params:
30
+ size: 512
31
+ ckpt_path: manifold://fair_onellm_checkpoints/tree/v2/tokenizer/vqgan_wm_0209.ckpt
32
+ data:
33
+ target: main.DataModuleFromConfig
34
+ params:
35
+ batch_size: 4
36
+ num_workers: 10
37
+ image_size: 512
38
+ filter_image_size: 512
39
+ dataset: coco
40
+ aesthetics_th: 0
41
+ clipsim_th: 0
42
+ --distributed-world-size: null
43
+ '32': null
44
+ --distributed-port: null
45
+ '17338': null
46
+ --save-dir: null
47
+ /checkpoint/shellysheynin/shutterstock/512x512_1024tokens_4node_shutterstock_laion_no_attn_styleGAN:
48
+ log_every-500:
49
+ ngpu32: null
50
+ --tensorboard-logdir: null
51
+ /checkpoint/shellysheynin/tensorboard_logs/2023-03-30/512x512_1024tokens_4node_shutterstock_laion_no_attn_styleGAN:
52
+ log_every-500:
53
+ ngpu32: null
54
+ '14561': null
55
+ /checkpoint/shellysheynin/tensorboard_logs/2023-04-02/512x512_1024tokens_4node_shutterstock_laion_no_attn_styleGAN:
56
+ log_every-500:
57
+ ngpu32: null
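A small sketch of how this config is meant to be consumed by the VQModel class from vqgan.py (assuming PyYAML; only model.params is needed here, and the trailing keys such as --distributed-world-size and the /checkpoint/... paths look like training-launcher arguments serialized into the file, which can be ignored for inference):

    import yaml

    with open("chameleon/vqgan.yaml") as f:
        params = yaml.safe_load(f)["model"]["params"]

    vq = VQModel(
        ddconfig=params["ddconfig"],
        n_embed=params["n_embed"],
        embed_dim=params["embed_dim"],
        ckpt_path="chameleon/vqgan.ckpt",   # the LFS weight file added above
    )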
conversation.py ADDED
@@ -0,0 +1,460 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+
8
+
9
+ class SeparatorStyle(Enum):
10
+ """Different separator style."""
11
+ SINGLE = auto()
12
+ TWO = auto()
13
+ MPT = auto()
14
+ PLAIN = auto()
15
+ LLAMA_2 = auto()
16
+ GEMMA = auto()
17
+
18
+
19
+ @dataclasses.dataclass
20
+ class Conversation:
21
+ """A class that keeps all conversation history."""
22
+ system: str
23
+ roles: List[str]
24
+ messages: List[List[str]]
25
+ offset: int
26
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
27
+ sep: str = "###"
28
+ sep2: str = None
29
+ version: str = "Unknown"
30
+
31
+ skip_next: bool = False
32
+
33
+ def get_prompt(self):
34
+ messages = self.messages
35
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
36
+ messages = self.messages.copy()
37
+ init_role, init_msg = messages[0].copy()
38
+ init_msg = init_msg[0].replace("<image>", "").strip()
39
+ if 'mmtag' in self.version:
40
+ messages[0] = (init_role, init_msg)
41
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
42
+ messages.insert(1, (self.roles[1], "Received."))
43
+ else:
44
+ messages[0] = (init_role, "<image>\n" + init_msg)
45
+
46
+ if self.sep_style == SeparatorStyle.SINGLE:
47
+ ret = self.system + self.sep
48
+ for role, message in messages:
49
+ if message:
50
+ if type(message) is tuple:
51
+ message = message[0]
52
+ ret += role + ": " + message + self.sep
53
+ else:
54
+ ret += role + ":"
55
+ elif self.sep_style == SeparatorStyle.TWO:
56
+ seps = [self.sep, self.sep2]
57
+ ret = self.system + seps[0]
58
+ for i, (role, message) in enumerate(messages):
59
+ if message:
60
+ if type(message) is tuple:
61
+ message = message[0]
62
+ ret += role + ": " + message + seps[i % 2]
63
+ else:
64
+ ret += role + ":"
65
+ elif self.sep_style == SeparatorStyle.MPT:
66
+ ret = self.system + self.sep
67
+ for role, message in messages:
68
+ if message:
69
+ if type(message) is tuple:
70
+ message = message[0]
71
+ ret += role + message + self.sep
72
+ else:
73
+ ret += role
74
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
75
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
76
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
77
+ ret = ""
78
+
79
+ for i, (role, message) in enumerate(messages):
80
+ if i == 0:
81
+ assert message, "first message should not be none"
82
+ assert role == self.roles[0], "first message should come from user"
83
+ if message:
84
+ if type(message) is tuple:
85
+ message, _, _ = message
86
+ if i == 0: message = wrap_sys(self.system) + message
87
+ if i % 2 == 0:
88
+ message = wrap_inst(message)
89
+ ret += self.sep + message
90
+ else:
91
+ ret += " " + message + " " + self.sep2
92
+ else:
93
+ ret += ""
94
+ ret = ret.lstrip(self.sep)
95
+ elif self.sep_style == SeparatorStyle.GEMMA:
96
+ seps = [self.sep, self.sep2]
97
+ ret = self.system + seps[0]
98
+ for i, (role, message) in enumerate(messages):
99
+ if message:
100
+ if type(message) is tuple:
101
+ message, _, _ = message
102
+ ret += "<start_of_turn>" + role + "\n" + message + "<end_of_turn>\n" + seps[i % 2]
103
+ else:
104
+ ret += "<start_of_turn>" + role + "\n"
105
+ elif self.sep_style == SeparatorStyle.PLAIN:
106
+ seps = [self.sep, self.sep2]
107
+ ret = self.system
108
+ for i, (role, message) in enumerate(messages):
109
+ if message:
110
+ if type(message) is tuple:
111
+ message, _, _ = message
112
+ ret += message + seps[i % 2]
113
+ else:
114
+ ret += ""
115
+ else:
116
+ raise ValueError(f"Invalid style: {self.sep_style}")
117
+
118
+ return ret
119
+
120
+ def append_message(self, role, message):
121
+ self.messages.append([role, message])
122
+
123
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
124
+ if image_process_mode == "Pad":
125
+ def expand2square(pil_img, background_color=(122, 116, 104)):
126
+ width, height = pil_img.size
127
+ if width == height:
128
+ return pil_img
129
+ elif width > height:
130
+ result = Image.new(pil_img.mode, (width, width), background_color)
131
+ result.paste(pil_img, (0, (width - height) // 2))
132
+ return result
133
+ else:
134
+ result = Image.new(pil_img.mode, (height, height), background_color)
135
+ result.paste(pil_img, ((height - width) // 2, 0))
136
+ return result
137
+ image = expand2square(image)
138
+ elif image_process_mode in ["Default", "Crop"]:
139
+ pass
140
+ elif image_process_mode == "Resize":
141
+ image = image.resize((336, 336))
142
+ else:
143
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
144
+ if max(image.size) > max_len:
145
+ max_hw, min_hw = max(image.size), min(image.size)
146
+ aspect_ratio = max_hw / min_hw
147
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
148
+ longest_edge = int(shortest_edge * aspect_ratio)
149
+ W, H = image.size
150
+ if H > W:
151
+ H, W = longest_edge, shortest_edge
152
+ else:
153
+ H, W = shortest_edge, longest_edge
154
+ image = image.resize((W, H))
155
+ if return_pil:
156
+ return image
157
+ else:
158
+ buffered = BytesIO()
159
+ image.save(buffered, format=image_format)
160
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
161
+ return img_b64_str
162
+
163
+ def get_images(self, return_pil=False):
164
+ images = []
165
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
166
+ if i % 2 == 0:
167
+ if type(msg) is tuple:
168
+ msg, image, image_process_mode = msg
169
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
170
+ images.append(image)
171
+ return images
172
+
173
+ def to_gradio_chatbot(self):
174
+ ret = []
175
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
176
+ if i % 2 == 0:
177
+ if type(msg) is tuple:
178
+ msg, image, image_process_mode = msg
179
+ img_b64_str = self.process_image(
180
+ image, "Default", return_pil=False,
181
+ image_format='JPEG')
182
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
183
+ msg = img_str + msg.replace('<image>', '').strip()
184
+ ret.append([msg, None])
185
+ else:
186
+ ret.append([msg, None])
187
+ else:
188
+ if type(msg) is tuple and len(msg) == 2:
189
+ msg, img_b64_str = msg
190
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
191
+ msg = msg.strip() + img_str
192
+ ret[-1][-1] = msg
193
+ return ret
194
+
195
+ def copy(self):
196
+ return Conversation(
197
+ system=self.system,
198
+ roles=self.roles,
199
+ messages=[[x, y] for x, y in self.messages],
200
+ offset=self.offset,
201
+ sep_style=self.sep_style,
202
+ sep=self.sep,
203
+ sep2=self.sep2,
204
+ version=self.version)
205
+
206
+ def dict(self):
207
+ if len(self.get_images()) > 0:
208
+ return {
209
+ "system": self.system,
210
+ "roles": self.roles,
211
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
212
+ "offset": self.offset,
213
+ "sep": self.sep,
214
+ "sep2": self.sep2,
215
+ }
216
+ return {
217
+ "system": self.system,
218
+ "roles": self.roles,
219
+ "messages": self.messages,
220
+ "offset": self.offset,
221
+ "sep": self.sep,
222
+ "sep2": self.sep2,
223
+ }
224
+
225
+
226
+ conv_vicuna_v0 = Conversation(
227
+ system="A chat between a curious human and an artificial intelligence assistant. "
228
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
229
+ roles=("Human", "Assistant"),
230
+ messages=(
231
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
232
+ ("Assistant",
233
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
234
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
235
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
236
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
237
+ "renewable and non-renewable energy sources:\n"
238
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
239
+ "energy sources are finite and will eventually run out.\n"
240
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
241
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
242
+ "and other negative effects.\n"
243
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
244
+ "have lower operational costs than non-renewable sources.\n"
245
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
246
+ "locations than non-renewable sources.\n"
247
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
248
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
249
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
250
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
251
+ ),
252
+ offset=2,
253
+ sep_style=SeparatorStyle.SINGLE,
254
+ sep="###",
255
+ )
256
+
257
+ conv_vicuna_v1 = Conversation(
258
+ system="A chat between a curious user and an artificial intelligence assistant. "
259
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
260
+ roles=("USER", "ASSISTANT"),
261
+ version="v1",
262
+ messages=(),
263
+ offset=0,
264
+ sep_style=SeparatorStyle.TWO,
265
+ sep=" ",
266
+ sep2="</s>",
267
+ )
268
+
269
+ conv_llama_2 = Conversation(
270
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
271
+
272
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
273
+ roles=("USER", "ASSISTANT"),
274
+ version="llama_v2",
275
+ messages=(),
276
+ offset=0,
277
+ sep_style=SeparatorStyle.LLAMA_2,
278
+ sep="<s>",
279
+ sep2="</s>",
280
+ )
281
+
282
+ conv_llava_llama_2 = Conversation(
283
+ system="You are a helpful language and vision assistant. "
284
+ "You are able to understand the visual content that the user provides, "
285
+ "and assist the user with a variety of tasks using natural language.",
286
+ roles=("USER", "ASSISTANT"),
287
+ version="llama_v2",
288
+ messages=(),
289
+ offset=0,
290
+ sep_style=SeparatorStyle.LLAMA_2,
291
+ sep="<s>",
292
+ sep2="</s>",
293
+ )
294
+
295
+ conv_mpt = Conversation(
296
+ system="""<|im_start|>system
297
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
298
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
299
+ version="mpt",
300
+ messages=(),
301
+ offset=0,
302
+ sep_style=SeparatorStyle.MPT,
303
+ sep="<|im_end|>",
304
+ )
305
+
306
+ conv_llava_plain = Conversation(
307
+ system="",
308
+ roles=("", ""),
309
+ messages=(
310
+ ),
311
+ offset=0,
312
+ sep_style=SeparatorStyle.PLAIN,
313
+ sep="\n",
314
+ )
315
+
316
+ conv_llava_v0 = Conversation(
317
+ system="A chat between a curious human and an artificial intelligence assistant. "
318
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
319
+ roles=("Human", "Assistant"),
320
+ messages=(
321
+ ),
322
+ offset=0,
323
+ sep_style=SeparatorStyle.SINGLE,
324
+ sep="###",
325
+ )
326
+
327
+ conv_llava_v0_mmtag = Conversation(
328
+ system="A chat between a curious user and an artificial intelligence assistant. "
329
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
330
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
331
+ roles=("Human", "Assistant"),
332
+ messages=(
333
+ ),
334
+ offset=0,
335
+ sep_style=SeparatorStyle.SINGLE,
336
+ sep="###",
337
+ version="v0_mmtag",
338
+ )
339
+
340
+ conv_llava_v1 = Conversation(
341
+ system="A chat between a curious human and an artificial intelligence assistant. "
342
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
343
+ roles=("USER", "ASSISTANT"),
344
+ version="v1",
345
+ messages=(),
346
+ offset=0,
347
+ sep_style=SeparatorStyle.TWO,
348
+ sep=" ",
349
+ sep2="</s>",
350
+ )
351
+
352
+ conv_vicuna_imgsp_v1 = Conversation(
353
+ system="A chat between a curious user and an artificial intelligence assistant. "
354
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
355
+ roles=("USER", "ASSISTANT"),
356
+ version="imgsp_v1",
357
+ messages=(),
358
+ offset=0,
359
+ sep_style=SeparatorStyle.TWO,
360
+ sep=" ",
361
+ sep2="</s>",
362
+ )
363
+
364
+ conv_llava_plain_guided = Conversation(
365
+ system="",
366
+ roles=("", ""),
367
+ version="plain_guided",
368
+ messages=(
369
+ ),
370
+ offset=0,
371
+ sep_style=SeparatorStyle.PLAIN,
372
+ sep="\n",
373
+ )
374
+
375
+ conv_llava_v1_mmtag = Conversation(
376
+ system="A chat between a curious user and an artificial intelligence assistant. "
377
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
378
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
379
+ roles=("USER", "ASSISTANT"),
380
+ messages=(),
381
+ offset=0,
382
+ sep_style=SeparatorStyle.TWO,
383
+ sep=" ",
384
+ sep2="</s>",
385
+ version="v1_mmtag",
386
+ )
387
+
388
+ conv_phi_2 = Conversation(
389
+ system="A chat between a curious user and an artificial intelligence assistant. "
390
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
391
+ roles=("USER", "ASSISTANT"),
392
+ version="phi2",
393
+ messages=(),
394
+ offset=0,
395
+ sep_style=SeparatorStyle.TWO,
396
+ sep=" ",
397
+ sep2="<|endoftext|>",
398
+ )
399
+
400
+ conv_mistral_instruct = Conversation(
401
+ system="",
402
+ roles=("USER", "ASSISTANT"),
403
+ version="llama_v2",
404
+ messages=(),
405
+ offset=0,
406
+ sep_style=SeparatorStyle.LLAMA_2,
407
+ sep="<s>",
408
+ sep2="</s>",
409
+ )
410
+
411
+ conv_gemma = Conversation(
412
+ system="",
413
+ roles=("user", "model"),
414
+ version="gemma",
415
+ messages=(),
416
+ offset=0,
417
+ sep_style=SeparatorStyle.GEMMA,
418
+ sep="",
419
+ sep2="<eos>",
420
+ )
421
+
422
+ conv_chatml_direct = Conversation(
423
+ system="""<|im_start|>system
424
+ Answer the questions.""",
425
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
426
+ version="mpt",
427
+ messages=(),
428
+ offset=0,
429
+ sep_style=SeparatorStyle.MPT,
430
+ sep="<|im_end|>",
431
+ )
432
+
433
+ default_conversation = conv_vicuna_v1
434
+ conv_templates = {
435
+ "default": conv_vicuna_v0,
436
+ "v0": conv_vicuna_v0,
437
+ "v1": conv_vicuna_v1,
438
+ "vicuna_v1": conv_vicuna_v1,
439
+ "phi_2": conv_phi_2,
440
+ "gemma": conv_gemma,
441
+ "llama_2": conv_llama_2,
442
+ "imgsp_v1": conv_vicuna_imgsp_v1,
443
+ "plain_guided": conv_llava_plain_guided,
444
+ "mistral_instruct": conv_mistral_instruct,
445
+ "chatml_direct": conv_chatml_direct,
446
+ "mistral_direct": conv_chatml_direct,
447
+ "plain": conv_llava_plain,
448
+ "v0_plain": conv_llava_plain,
449
+ "llava_v0": conv_llava_v0,
450
+ "v0_mmtag": conv_llava_v0_mmtag,
451
+ "llava_v1": conv_llava_v1,
452
+ "v1_mmtag": conv_llava_v1_mmtag,
453
+ "llava_llama_2": conv_llava_llama_2,
454
+
455
+ "mpt": conv_mpt,
456
+ }
457
+
458
+
459
+ if __name__ == "__main__":
460
+ print(default_conversation.get_prompt())
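A hedged usage sketch for the templates above (not part of the file): it builds a single-turn multimodal prompt with the gemma template; the image path is illustrative, and the tuple layout (text, PIL image, process mode) follows process_image/get_images above:

    from PIL import Image

    conv = conv_templates["gemma"].copy()
    image = Image.open("baklava.png")    # illustrative; any RGB image works
    conv.append_message(conv.roles[0], ("Describe this image.\n<image>", image, "Default"))
    conv.append_message(conv.roles[1], None)

    prompt = conv.get_prompt()                 # turns joined with the GEMMA separators
    images = conv.get_images(return_pil=True)  # PIL images, resized via process_image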
helpers.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+ from PIL import Image
4
+
5
+
6
+ ### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
7
+ def top_k_top_p_filtering(
8
+ logits,
9
+ top_k: int = 0,
10
+ top_p: float = 1.0,
11
+ filter_value: float = -float("Inf"),
12
+ min_tokens_to_keep: int = 1,
13
+ ):
14
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
15
+ Args:
16
+ logits: logits distribution shape (batch size, vocabulary size)
17
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
18
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
19
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
20
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
21
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
22
+ """
23
+
24
+ # restrict sampling to image-token ids: mask out the first 256000 logits (assumed to be the text vocabulary)
+ logits[:, :256000] = filter_value
25
+ if top_k > 0:
26
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
27
+ # Remove all tokens with a probability less than the last token of the top-k
28
+
29
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
30
+ logits[indices_to_remove] = filter_value
31
+
32
+ if top_p < 1.0:
33
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
34
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
35
+
36
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
37
+ sorted_indices_to_remove = cumulative_probs > top_p
38
+ if min_tokens_to_keep > 1:
39
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
40
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
41
+ # Shift the indices to the right so the first token above the threshold is also kept
42
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
43
+ sorted_indices_to_remove[..., 0] = 0
44
+
45
+ # scatter sorted tensors to original indexing
46
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
47
+ logits[indices_to_remove] = filter_value
49
+ return logits
50
+
51
+
52
+ def sample(logits, temperature: float=1.0, top_k: int=0, top_p: float=1.0, sample_logits=True):
53
+ logits = logits[:, -1, :] / max(temperature, 1e-5)
54
+ if top_k > 0 or top_p < 1.0:
55
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
56
+ probs = F.softmax(logits, dim=-1)
57
+ if sample_logits:
58
+ idx = torch.multinomial(probs, num_samples=1)
59
+ else:
60
+ _, idx = torch.topk(probs, k=1, dim=-1)
61
+ return idx, probs
62
+
63
+
64
+ def expand2square(pil_img, background_color):
65
+ width, height = pil_img.size
66
+ if width == height:
67
+ return pil_img
68
+ elif width > height:
69
+ result = Image.new(pil_img.mode, (width, width), background_color)
70
+ result.paste(pil_img, (0, (width - height) // 2))
71
+ return result
72
+ else:
73
+ result = Image.new(pil_img.mode, (height, height), background_color)
74
+ result.paste(pil_img, ((height - width) // 2, 0))
75
+ return result
76
+
77
+
78
+
79
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=-200, return_tensors=None):
80
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
81
+
82
+ def insert_separator(X, sep):
83
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
84
+
85
+ input_ids = []
86
+ offset = 0
87
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
88
+ offset = 1
89
+ input_ids.append(prompt_chunks[0][0])
90
+
91
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
92
+ input_ids.extend(x[offset:])
93
+
94
+ if return_tensors is not None:
95
+ if return_tensors == 'pt':
96
+ return torch.tensor(input_ids, dtype=torch.long)
97
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
98
+ return input_ids
99
+
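To show how these helpers fit together, a hypothetical decoding loop for image-token generation (the model object, the 1024-step horizon, and the 256000-id offset are assumptions for illustration, not defined in this file):

    import torch

    tokens = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)  # prompt ids, e.g. from tokenizer_image_token
    generated = []
    for _ in range(1024):                          # one code per step, e.g. a 32x32 VQGAN grid
        logits = model(tokens).logits              # (1, seq_len, vocab_size); `model` is a placeholder LM
        idx, _ = sample(logits, temperature=0.9, top_k=0, top_p=0.95)
        generated.append(idx)
        tokens = torch.cat([tokens, idx], dim=1)
    codes = torch.cat(generated, dim=1) - 256000   # assumed offset back into the VQGAN codebook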
requirements.txt CHANGED
@@ -1 +1,6 @@
1
- huggingface_hub==0.25.2
 
 
 
 
 
 
1
+ torch
2
+ transformers==4.39.2
3
+ spaces
4
+ pillow
5
+ accelerate
6
+ tqdm