Spaces:
initial commit
Browse files
- .gitattributes +11 -0
- License.txt +14 -0
- app.py +176 -4
- assets/example_outputs/case_1.png +3 -0
- assets/example_outputs/case_2.png +3 -0
- assets/example_outputs/case_3.png +3 -0
- assets/example_outputs/case_4.png +3 -0
- assets/example_outputs/case_5.png +3 -0
- assets/example_outputs/case_6.png +3 -0
- assets/example_outputs/case_7.png +3 -0
- assets/framework.png +3 -0
- assets/grpo_curve.png +3 -0
- assets/inference.png +3 -0
- assets/reasoning_case_com.png +3 -0
- assets/tapdole.jpeg +0 -0
- requirements.txt +21 -0
- src/__init__.py +5 -0
- src/image_decoder/__init__.py +6 -0
- src/image_decoder/image_pipeline.py +273 -0
- src/image_decoder/model.py +395 -0
- src/image_decoder/modeling_phi3.py +1611 -0
- src/image_decoder/processor.py +221 -0
- src/image_decoder/scheduler.py +194 -0
- src/image_decoder/transformer.py +179 -0
- src/mindomni.py +219 -0
- src/mllm.py +245 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/example_outputs/case_1.png filter=lfs diff=lfs merge=lfs -text
+assets/example_outputs/case_2.png filter=lfs diff=lfs merge=lfs -text
+assets/example_outputs/case_3.png filter=lfs diff=lfs merge=lfs -text
+assets/example_outputs/case_4.png filter=lfs diff=lfs merge=lfs -text
+assets/example_outputs/case_5.png filter=lfs diff=lfs merge=lfs -text
+assets/example_outputs/case_6.png filter=lfs diff=lfs merge=lfs -text
+assets/example_outputs/case_7.png filter=lfs diff=lfs merge=lfs -text
+assets/framework.png filter=lfs diff=lfs merge=lfs -text
+assets/grpo_curve.png filter=lfs diff=lfs merge=lfs -text
+assets/inference.png filter=lfs diff=lfs merge=lfs -text
+assets/reasoning_case_com.png filter=lfs diff=lfs merge=lfs -text
License.txt ADDED
@@ -0,0 +1,14 @@
Tencent is pleased to support the open source community by making MindOmni available.

Copyright (C) 2025 Tencent. All rights reserved.

MindOmni is licensed under the MIT License.


Terms of the MIT License:
--------------------------------------------------------------------
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
app.py CHANGED
@@ -1,7 +1,179 @@
+import os
+import argparse
+from functools import partial
+
+import torch
+import random
+import spaces
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-
-
+from src import MindOmni
+
+NEGATIVE_PROMPT = '''
+low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.
+'''
+
+
+def parse_args():
+    args = argparse.ArgumentParser(description='MindOmni')
+    args.add_argument('--device', type=str, default='cuda')
+    args.add_argument('--dtype', type=str, default='bf16')
+    args.add_argument('--server_name', type=str, default='127.0.0.1')
+    args.add_argument('--port', type=int, default=8080)
+    args.add_argument('--model_path', type=str,
+                      default='your_path/MindOmni')
+    args = args.parse_args()
+    return args
+
+
+def build_model(args):
+    device = args.device
+    MindOmni_model = MindOmni.from_pretrained(args.model_path)
+    if args.dtype == "bf16":
+        dtype = torch.bfloat16
+    MindOmni_model.to(device=device, dtype=dtype)
+    MindOmni_model.eval()
+    return MindOmni_model
+
+
+@spaces.GPU(duration=180)
+def understand_func(
+        MindOmni_model, text, do_sample, temperature,
+        max_new_tokens, input_llm_images):
+    if input_llm_images is not None and not isinstance(input_llm_images, list):
+        input_llm_images = [input_llm_images]
+    answer = MindOmni_model.generate_text(
+        text, input_llm_images, do_sample, temperature,
+        max_new_tokens, only_understand=True)
+    return answer
+
+
+@spaces.GPU(duration=180)
+def generate_func(
+        MindOmni_model, text, use_cot, height, width, guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model, max_input_image_size, randomize_seed, save_images, do_sample, temperature, max_new_tokens, input_llm_images, only_understand):
+    if input_llm_images is not None and not isinstance(input_llm_images, list):
+        input_llm_images = [input_llm_images]
+
+    if randomize_seed:
+        seed = random.randint(0, 10000000)
+
+    os.makedirs(os.path.dirname('/tmp/.unhold'), exist_ok=True)
+    with open('/tmp/.unhold', 'w') as f:
+        f.write('')
+    output, prompt_ = MindOmni_model.generate_image(
+        height, width, guidance_scale, inference_steps, separate_cfg_infer, offload_model, seed, max_input_image_size,
+        text, NEGATIVE_PROMPT, input_llm_images, do_sample, temperature, max_new_tokens, only_understand, use_cot=use_cot)
+    os.remove('/tmp/.unhold')
+
+    img = output[0]
+
+    if save_images:
+        # Save All Generated Images
+        from datetime import datetime
+        # Create outputs directory if it doesn't exist
+        os.makedirs('assets/outputs', exist_ok=True)
+        # Generate unique filename with timestamp
+        timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
+        output_path = os.path.join('assets/outputs', f'{timestamp}.png')
+        # Save the image
+        img.save(output_path)
+
+    return img, prompt_, seed
+
+
+def build_gradio(args, MindOmni_model):
+    with gr.Blocks() as demo:
+        gr.Markdown("## 🪄 MindOmni Demo")
+
+        with gr.Tabs():
+            # ---------- GENERATE ----------
+            with gr.TabItem("🎨 Generate"):
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        g_prompt = gr.Textbox(label="Text prompt")
+                        g_image = gr.Image(label="Condition image (optional)", type="filepath")
+                        g_btn = gr.Button("🚀 Generate Image")
+
+                        with gr.Accordion("📚 Image Generation Args"):
+                            g_use_cot = gr.Checkbox(label="With thinking", value=False)
+                            g_do_sample = gr.Checkbox(label="Do sample", value=False)
+                            g_temperature = gr.Slider(0, 10, value=1, label="Temperature")
+                            g_max_new_tok = gr.Slider(32, 8192, value=512, label="Max new tokens")
+
+                            g_height = gr.Slider(128, 2048, value=1024, step=16, label="Height")
+                            g_width = gr.Slider(128, 2048, value=1024, step=16, label="Width")
+                            g_scale = gr.Slider(1.0, 5.0, value=3.0, step=0.1, label="Guidance Scale")
+                            g_steps = gr.Slider(1, 100, value=50, label="Inference Steps")
+                            g_seed = gr.Slider(0, 2**31 - 1, value=42, label="Seed")
+                            g_rand = gr.Checkbox(label="Randomize seed", value=False)
+                            g_max_img = gr.Slider(128, 2048, value=1024, step=16,
+                                                  label="Max input image size")
+                            g_sep_cfg = gr.Checkbox(label="Separate-CFG infer", value=True)
+                            g_offload = gr.Checkbox(label="Offload model to CPU", value=False)
+                            g_save = gr.Checkbox(label="Save generated images", value=False)
+
+                    with gr.Column(scale=1):
+                        g_out_img = gr.Image(label="Generated Image")
+                        g_prompt_out = gr.Textbox(label="MindOmni CoT Content")
+                        g_seed_out = gr.Textbox(label="Used seed")
+
+                with gr.Accordion("🖼️ Prompt Examples: Text-only"):
+                    gr.Examples(
+                        examples=[
+                            ["Futuristic city skyline at sunset, digital art", 42, False, False, False, 1024, 1024, "assets/example_outputs/case_1.png"],
+                            ["An image of multiple apples, the quantity of apples is the solution of '2x + 6 = 16'.", 1723284, False, True, False, 512, 1024, "assets/example_outputs/case_2.png"],
+                            ["A park with benches equal to the solution of 'x^2 -2x = 8'.", 4318852, False, True, False, 512, 512, "assets/example_outputs/case_3.png"],
+                            ["An image of China's national treasure animal.", 42, False, True, False, 1024, 1024, "assets/example_outputs/case_4.png"],
+                            ["Scene in the Sydney Opera House when New York is at noon.", 42, False, True, False, 1024, 1024, "assets/example_outputs/case_5.png"],
+                            ["Generate an image of an animal with (3 + 6) lives", 7393438, False, True, False, 1024, 1024, "assets/example_outputs/case_6.png"],
+                        ],
+                        inputs=[g_prompt, g_seed, g_rand, g_use_cot, g_do_sample, g_height, g_width, g_out_img],
+                    )
+                with gr.Accordion("🖼️ Prompt Examples: With reference image"):
+                    gr.Examples(
+                        examples=[
+                            ["An image of the animal growing up", "assets/tapdole.jpeg", 42, False, True, True, 1024, 1024, "assets/example_outputs/case_7.png"]
+                        ],
+                        inputs=[g_prompt, g_image, g_seed, g_rand, g_use_cot, g_do_sample, g_height, g_width, g_out_img],
+                    )
+
+                g_btn.click(
+                    partial(generate_func, MindOmni_model),
+                    inputs=[g_prompt, g_use_cot, g_height, g_width, g_scale, g_steps,
+                            g_seed, g_sep_cfg, g_offload, g_max_img, g_rand, g_save,
+                            g_do_sample, g_temperature, g_max_new_tok,
+                            g_image, gr.State(False)],  # only_understand=False
+                    outputs=[g_out_img, g_prompt_out, g_seed_out])
+
+            # ---------- UNDERSTAND ----------
+            with gr.TabItem("🧠 Understand"):
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        u_prompt = gr.Textbox(label="Text prompt")
+                        u_image = gr.Image(label="Image (optional)", type="filepath")
+                        u_btn = gr.Button("🔍 Understand")
+                        with gr.Accordion("📚 Text Generation Args"):
+                            u_do_sample = gr.Checkbox(label="Do sample", value=False)
+                            u_temperature = gr.Slider(0, 10, value=1, label="Temperature")
+                            u_max_new_tok = gr.Slider(32, 8192, value=512, label="Max new tokens")
+
+                    with gr.Column(scale=1):
+                        u_answer = gr.Textbox(label="Answer", lines=8)
+
+                u_btn.click(
+                    partial(understand_func, MindOmni_model),
+                    inputs=[u_prompt, u_do_sample,
+                            u_temperature, u_max_new_tok, u_image],
+                    outputs=u_answer)
+
+    demo.launch(server_name=args.server_name, server_port=args.port)
+
+
+def main():
+    args = parse_args()
+    print(f'running args: {args}')
+    MindOmni_model = build_model(args)
+    build_gradio(args, MindOmni_model)
+
+
+if __name__ == '__main__':
+    main()
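For reference, a minimal sketch of driving the same entry points without the command line, reusing build_model and build_gradio from the app.py diff above; the checkpoint path is the same placeholder used in parse_args(), and the other values mirror its defaults:

# Minimal launch sketch (assumes app.py is importable and a local
# MindOmni checkpoint exists at the placeholder path below).
import argparse

from app import build_model, build_gradio

args = argparse.Namespace(
    device='cuda', dtype='bf16',
    server_name='127.0.0.1', port=8080,
    model_path='your_path/MindOmni',  # placeholder, as in parse_args()
)
model = build_model(args)    # loads MindOmni, casts to bf16, moves it to CUDA
build_gradio(args, model)    # builds the two-tab demo and calls demo.launch()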
assets/example_outputs/case_1.png ADDED (Git LFS)
assets/example_outputs/case_2.png ADDED (Git LFS)
assets/example_outputs/case_3.png ADDED (Git LFS)
assets/example_outputs/case_4.png ADDED (Git LFS)
assets/example_outputs/case_5.png ADDED (Git LFS)
assets/example_outputs/case_6.png ADDED (Git LFS)
assets/example_outputs/case_7.png ADDED (Git LFS)
assets/framework.png ADDED (Git LFS)
assets/grpo_curve.png ADDED (Git LFS)
assets/inference.png ADDED (Git LFS)
assets/reasoning_case_com.png ADDED (Git LFS)
assets/tapdole.jpeg ADDED
requirements.txt ADDED
@@ -0,0 +1,21 @@
accelerate==1.7.0
datasets==2.20.0
decord==0.6.0
deepspeed==0.16.5
diffusers==0.30.3
gradio==4.44.1
gradio_client==1.3.0
huggingface-hub==0.32.0
numpy==1.26.3
omegaconf==2.3.0
pandas==2.2.3
pathvalidate==3.2.1
peft==0.13.2
qwen-vl-utils==0.0.8
safetensors==0.4.5
scipy==1.13.1
sympy==1.13.3
timm==0.9.16
tokenizers==0.21.1
torch==2.4.0
transformers==4.51.1
src/__init__.py ADDED
@@ -0,0 +1,5 @@
from .image_decoder import *  # noqa
from .mllm import MindOmniMLLM, MindOmniMLLM_Model
from .mindomni import MindOmni

__all__ = ["MindOmniMLLM", "MindOmniMLLM_Model", "MindOmni"]
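The package surface re-exported here matches what app.py consumes; a small sketch of the intended import path (the checkpoint path is a placeholder, and from_pretrained is used exactly as in build_model above):

# Sketch of the public imports exposed by src/__init__.py.
from src import MindOmni, MindOmniMLLM, MindOmniMLLM_Model  # noqa: F401

model = MindOmni.from_pretrained('your_path/MindOmni')  # placeholder path
model.eval()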
src/image_decoder/__init__.py ADDED
@@ -0,0 +1,6 @@
from .image_pipeline import ImageDecoderPipeline
from .model import OmniGen
from .modeling_phi3 import Phi3DecoderLayer
from .processor import OmniGenProcessor

__all__ = ["ImageDecoderPipeline", "OmniGen", "Phi3DecoderLayer", "OmniGenProcessor"]
src/image_decoder/image_pipeline.py ADDED
@@ -0,0 +1,273 @@
# This code is based on OmniGen
from typing import List, Union
import gc

from PIL import Image
import torch
try:
    import torch_npu
except Exception as e:
    print(e)
from diffusers.models import AutoencoderKL
from diffusers.utils import logging
import torch.nn as nn
from .processor import OmniGenProcessor
from .model import OmniGen
from .scheduler import OmniGenScheduler


logger = logging.get_logger(__name__)


class ImageDecoderPipeline:
    def __init__(
        self,
        vae: AutoencoderKL,
        model: OmniGen,
        connector: nn.Module,
        processor: OmniGenProcessor,
        device: Union[str, torch.device] = None,
    ):
        self.vae = vae
        self.model = model
        self.connector = connector
        self.processor = processor
        self.device = device

        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif torch_npu.npu.is_available():
                self.device = torch.device("npu")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                logger.info("No available GPU detected, using CPU instead; image generation may take a long time!!!")
                self.device = torch.device("cpu")

        # self.model.to(torch.bfloat16)
        self.model.eval()
        self.vae.eval()

        self.model_cpu_offload = False

    def to(self, device: Union[str, torch.device]):
        if isinstance(device, str):
            device = torch.device(device)
        self.model.to(device)
        self.vae.to(device)
        self.device = device

    def vae_encode(self, x, dtype):
        if self.vae.config.shift_factor is not None:
            x = self.vae.encode(x).latent_dist.sample()
            x = (x - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        else:
            x = self.vae.encode(x).latent_dist.sample().mul_(self.vae.config.scaling_factor)
        x = x.to(dtype)
        return x

    def move_to_device(self, data):
        if isinstance(data, list):
            return [x.to(self.device) for x in data]
        return data.to(self.device)

    def enable_model_cpu_offload(self):
        self.model_cpu_offload = True
        self.model.to("cpu")
        self.vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear VRAM
        elif torch_npu.npu.is_available():
            torch_npu.npu.empty_cache()  # Clear VRAM
        gc.collect()  # Run garbage collection to free system RAM

    def disable_model_cpu_offload(self):
        self.model_cpu_offload = False
        self.model.to(self.device)
        self.vae.to(self.device)

    @torch.no_grad()
    def __call__(
        self,
        context_hidden_state: Union[str, List[str]] = None,
        neg_context_hidden_state: Union[str, List[str]] = None,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 3,
        max_input_image_size: int = 1024,
        separate_cfg_infer: bool = True,
        offload_model: bool = False,
        use_kv_cache: bool = True,
        offload_kv_cache: bool = True,
        dtype: torch.dtype = torch.bfloat16,
        seed: int = None,
        output_type: str = "pil",
        tqdm_disable: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            input_images (`List[str]` or `List[List[str]]`, *optional*):
                The list of input images. We will replace the "<|image_i|>" in the prompt with the i-th image in the list.
            height (`int`, *optional*, defaults to 1024):
                The height in pixels of the generated image. The number must be a multiple of 16.
            width (`int`, *optional*, defaults to 1024):
                The width in pixels of the generated image. The number must be a multiple of 16.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 3):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            use_img_guidance (`bool`, *optional*, defaults to True):
                Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
            img_guidance_scale (`float`, *optional*, defaults to 1.6):
                Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
            max_input_image_size (`int`, *optional*, defaults to 1024): the maximum size of the input image, which will be used to crop the input image to the maximum size
            separate_cfg_infer (`bool`, *optional*, defaults to True):
                Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
            use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
            offload_kv_cache (`bool`, *optional*, defaults to True): offload the cached key and value to cpu, which can save memory but slow down the generation slightly
            offload_model (`bool`, *optional*, defaults to False): offload the model to cpu, which can save memory but slow down the generation
            use_input_image_size_as_output (bool, defaults to False): whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task
            seed (`int`, *optional*):
                A random seed for generating output.
            dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
                data type for the model
            output_type (`str`, *optional*, defaults to "pil"):
                The type of the output image, which can be "pt" or "pil"
        Examples:

        Returns:
            A list with the generated images.
        """
        # check inputs:
        assert height % 16 == 0 and width % 16 == 0, "The height and width must be a multiple of 16."
        if context_hidden_state is not None and not isinstance(context_hidden_state, list):
            context_hidden_state = [context_hidden_state]
            neg_context_hidden_state = [neg_context_hidden_state]

        # set model and processor
        if max_input_image_size != self.processor.max_image_size:
            self.processor = OmniGenProcessor(max_image_size=max_input_image_size)
        self.model.to(dtype)
        if offload_model:
            self.enable_model_cpu_offload()
        else:
            self.disable_model_cpu_offload()

        input_data = self.processor(context_hidden_state, neg_context_hidden_state, height=height, width=width, separate_cfg_input=separate_cfg_infer)

        num_prompt = len(context_hidden_state)
        num_cfg = 1
        latent_size_h, latent_size_w = height // 8, width // 8

        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
        else:
            generator = None
        latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
        latents = torch.cat([latents] * (1 + num_cfg), 0).to(dtype)

        model_kwargs = dict(cfg_scale=guidance_scale,
                            use_kv_cache=use_kv_cache,
                            offload_model=offload_model,
                            )
        # obtain the qwen feature
        # if self.llm_processor is not None:
        llm_input_embeds = []
        with torch.no_grad():
            # for separate cfg infer mode
            for i in range(len(input_data['context_hidden_state'])):

                context_hidden_state = input_data['context_hidden_state'][i]
                hidden_states = self.connector[0](context_hidden_state)
                cache_position = torch.arange(0, hidden_states.shape[1], device=hidden_states.device)

                mask_func = self.model.llm._update_causal_mask
                cond_causal_mask = mask_func(
                    input_data['connector_attention_mask'][i].to(self.device), hidden_states, cache_position, None, None)
                for decoder_layer in self.connector[1:]:
                    layer_out = decoder_layer(
                        hidden_states,
                        attention_mask=cond_causal_mask,
                        position_ids=input_data['connector_position_ids'][i].to(self.device),
                    )
                    hidden_states = layer_out[0]

                llm_input_embeds.append(hidden_states)

        # import ipdb; ipdb.set_trace()
        model_kwargs['llm_input_embeds'] = llm_input_embeds
        model_kwargs['llm_attention_mask'] = self.move_to_device(input_data['llm_attention_mask'])
        model_kwargs['llm_position_ids'] = self.move_to_device(input_data['llm_position_ids'])

        if separate_cfg_infer:
            func = self.model.forward_with_separate_cfg
        else:
            func = self.model.forward_with_cfg

        if self.model_cpu_offload:
            for name, param in self.model.named_parameters():
                if 'layers' in name and 'layers.0' not in name:
                    param.data = param.data.cpu()
                else:
                    param.data = param.data.to(self.device)
            for buffer_name, buffer in self.model.named_buffers():
                setattr(self.model, buffer_name, buffer.to(self.device))
        # else:
        #     self.model.to(self.device)

        scheduler = OmniGenScheduler(num_steps=num_inference_steps)
        samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache, tqdm_disable=tqdm_disable)
        samples = samples.chunk((1 + num_cfg), dim=0)[0]

        if self.model_cpu_offload:
            self.model.to('cpu')
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # Clear VRAM
            elif torch_npu.npu.is_available():
                torch_npu.npu.empty_cache()  # Clear VRAM
            gc.collect()

        self.vae.to(self.device)
        samples = samples.to(torch.float32)
        if self.vae.config.shift_factor is not None:
            samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
        else:
            samples = samples / self.vae.config.scaling_factor
        samples = self.vae.decode(samples).sample

        if self.model_cpu_offload:
            self.vae.to('cpu')
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # Clear VRAM
            elif torch_npu.npu.is_available():
                torch_npu.npu.empty_cache()  # Clear VRAM
            gc.collect()

        samples = (samples * 0.5 + 0.5).clamp(0, 1)

        if output_type == "pt":
            output_images = samples
        else:
            output_samples = (samples * 255).to("cpu", dtype=torch.uint8)
            output_samples = output_samples.permute(0, 2, 3, 1).numpy()
            output_images = []
            for i, sample in enumerate(output_samples):
                output_images.append(Image.fromarray(sample))

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear VRAM
        elif torch_npu.npu.is_available():
            torch_npu.npu.empty_cache()  # Clear VRAM
        gc.collect()  # Run garbage collection to free system RAM

        return output_images
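Note that the docstring above is carried over from OmniGen, so it still describes prompt/input_images/img_guidance_scale arguments that this __call__ signature no longer takes; the pipeline instead consumes pre-computed connector hidden states. A minimal call sketch, assuming pipe is an assembled ImageDecoderPipeline and cond/neg_cond are hidden-state tensors produced upstream by the MLLM (hypothetical names):

# Hypothetical usage; `pipe`, `cond` and `neg_cond` are assumed to exist upstream.
images = pipe(
    context_hidden_state=cond,          # conditional hidden states from the MLLM
    neg_context_hidden_state=neg_cond,  # hidden states for the negative prompt
    height=1024, width=1024,            # both must be multiples of 16
    num_inference_steps=50,
    guidance_scale=3,
    separate_cfg_infer=True,            # run cond/uncond passes separately to save memory
    seed=42,
)                                       # returns a list of PIL images by default
images[0].save('sample.png')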
src/image_decoder/model.py ADDED
@@ -0,0 +1,395 @@
# The code is revised from DiT
import os
import torch
import torch.nn as nn
import numpy as np
import math
from diffusers.loaders import PeftAdapterMixin
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from .transformer import Phi3Transformer
from transformers import Phi3Config


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype=torch.float32):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
    """
    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class PatchEmbedMR(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(
        self,
        patch_size: int = 2,
        in_chans: int = 4,
        embed_dim: int = 768,
        bias: bool = True,
    ):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        return x


class OmniGen(nn.Module, PeftAdapterMixin):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        transformer_config: Phi3Config,
        patch_size=2,
        in_channels=4,
        pe_interpolation: float = 1.0,
        pos_embed_max_size: int = 192,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.pos_embed_max_size = pos_embed_max_size

        hidden_size = transformer_config.hidden_size

        self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
        self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)

        self.time_token = TimestepEmbedder(hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)

        self.pe_interpolation = pe_interpolation
        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)

        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)

        self.initialize_weights()

        self.llm = Phi3Transformer(config=transformer_config)
        self.llm.config.use_cache = False

    @classmethod
    def from_pretrained(cls, model_name):
        if not os.path.exists(model_name):
            cache_folder = os.getenv('HF_HUB_CACHE')
            model_name = snapshot_download(repo_id=model_name,
                                           cache_dir=cache_folder,
                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
        config = Phi3Config.from_pretrained(model_name)
        model = cls(config)
        if os.path.exists(os.path.join(model_name, 'model.safetensors')):
            print("Loading safetensors")
            ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
        else:
            ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')

        module_keys = list(model.state_dict().keys())
        pretrained_keys = list(ckpt.keys())
        all_keys = module_keys + pretrained_keys
        missing_modules = []
        unexpected_modules = []
        for item in all_keys:
            if item in module_keys and item not in ckpt.keys():
                missing_modules.append(item)
            if item not in module_keys and item in ckpt.keys():
                unexpected_modules.append(item)

        print(f"loading {model.__class__.__name__} but missing modules: {missing_modules}, unexpected modules: {unexpected_modules}")
        model.load_state_dict(ckpt, strict=False)
        return model

    def initialize_weights(self):
        assert not hasattr(self, "llama")

        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        w = self.input_x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.input_x_embedder.proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels

        x = x.reshape(shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h, w))
        return imgs

    def cropped_pos_embed(self, height, width):
        """Crops positional embeddings for SD3 compatibility."""
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
        spatial_pos_embed = spatial_pos_embed[:, top: top + height, left: left + width, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images=False):
        if isinstance(latents, list):
            return_list = False
            if padding_latent is None:
                padding_latent = [None] * len(latents)
                return_list = True
            patched_latents, num_tokens, shapes = [], [], []
            for latent, padding in zip(latents, padding_latent):
                height, width = latent.shape[-2:]
                if is_input_images:
                    latent = self.input_x_embedder(latent)
                else:
                    latent = self.x_embedder(latent)
                pos_embed = self.cropped_pos_embed(height, width)
                latent = latent + pos_embed
                if padding is not None:
                    latent = torch.cat([latent, padding], dim=-2)
                patched_latents.append(latent)

                num_tokens.append(pos_embed.size(1))
                shapes.append([height, width])
            if not return_list:
                latents = torch.cat(patched_latents, dim=0)
            else:
                latents = patched_latents
        else:
            height, width = latents.shape[-2:]
            if is_input_images:
                latents = self.input_x_embedder(latents)
            else:
                latents = self.x_embedder(latents)
            pos_embed = self.cropped_pos_embed(height, width)
            latents = latents + pos_embed
            num_tokens = latents.size(1)
            shapes = [height, width]
        return latents, num_tokens, shapes

    def forward(self, x, timestep, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model: bool = False,
                llm_input_embeds=None, llm_attention_mask=None, llm_position_ids=None, use_dist=False):
        input_is_list = isinstance(x, list)
        x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
        time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)

        if llm_input_embeds is not None:
            condition_embeds_llm = llm_input_embeds
            input_emb = torch.cat([condition_embeds_llm, time_token, x], dim=1)
            attention_mask = llm_attention_mask
            position_ids = llm_position_ids
        else:
            input_emb = torch.cat([time_token, x], dim=1)
            attention_mask = llm_attention_mask
            position_ids = llm_position_ids

        output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model, output_hidden_states=True)
        output, past_key_values, all_hidden_states = output.last_hidden_state, output.past_key_values, output.hidden_states
        if not use_dist:
            all_states_noise = None
        if input_is_list:
            image_embedding = output[:, -max(num_tokens):]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = []
            if use_dist:
                all_states = torch.stack([hidden_states[:, -max(num_tokens):] for hidden_states in all_hidden_states], dim=1)  # b l s d
                all_states_noise = []
            for i in range(x.size(0)):
                latent = x[i: i + 1, :num_tokens[i]]
                latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
                latents.append(latent)
                if use_dist:
                    all_states_noise.append(all_states[i, :, :num_tokens[i]])
        else:
            image_embedding = output[:, -num_tokens:]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = self.unpatchify(x, shapes[0], shapes[1])
            if use_dist:
                all_states_noise = torch.stack([hidden_states[:, -num_tokens:] for hidden_states in all_hidden_states], dim=1)  # b l s d

        if return_past_key_values:
            return latents, past_key_values, all_states_noise
        return latents, all_states_noise

    @torch.no_grad()
    def forward_with_separate_cfg(self, x, timestep, cfg_scale, past_key_values, use_kv_cache, offload_model,
                                  llm_input_embeds=None, llm_attention_mask=None, llm_position_ids=None, llm_padded_input_ids=None, llm_image_sizes=None):
        self.llm.config.use_cache = use_kv_cache
        if past_key_values is None:
            past_key_values = [None] * len(llm_attention_mask)

        x = torch.split(x, len(x) // len(llm_attention_mask), dim=0)
        timestep = timestep.to(x[0].dtype)
        timestep = torch.split(timestep, len(timestep) // len(llm_input_embeds), dim=0)

        model_out, pask_key_values = [], []
        for i in range(len(llm_input_embeds)):
            if llm_input_embeds is not None:
                temp_out, temp_pask_key_values, _ = self.forward(x[i], timestep[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model,
                                                                 llm_input_embeds=llm_input_embeds[i], llm_attention_mask=llm_attention_mask[i], llm_position_ids=llm_position_ids[i])
            else:
                temp_out, temp_pask_key_values, _ = self.forward(x[i], timestep[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
            model_out.append(temp_out)
            pask_key_values.append(temp_pask_key_values)

        if len(model_out) == 2:
            cond, uncond = model_out
            cond = uncond + cfg_scale * (cond - uncond)
            model_out = [cond, cond]
        else:
            return model_out[0]

        return torch.cat(model_out, dim=0), pask_key_values
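As a quick shape check on the sin/cos positional-embedding helpers above (a sketch; the 3072 hidden size is an assumption about the Phi-3 config, not something stated in this commit):

# Shape sketch for the positional-embedding table that OmniGen registers as `pos_embed`.
from src.image_decoder.model import get_2d_sincos_pos_embed

table = get_2d_sincos_pos_embed(embed_dim=3072, grid_size=192, base_size=64)  # 3072 is a hypothetical hidden size
print(table.shape)  # (192 * 192, 3072); cropped_pos_embed() later slices a
                    # (height // patch_size) x (width // patch_size) window out of this grid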
src/image_decoder/modeling_phi3.py ADDED
@@ -0,0 +1,1611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""PyTorch Phi-3 model."""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
import warnings
|
| 20 |
+
from typing import List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.checkpoint
|
| 24 |
+
from torch import nn
|
| 25 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 26 |
+
|
| 27 |
+
from transformers.activations import ACT2FN
|
| 28 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
| 29 |
+
from transformers.generation import GenerationMixin
|
| 30 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 31 |
+
from transformers.modeling_outputs import (
|
| 32 |
+
BaseModelOutputWithPast,
|
| 33 |
+
CausalLMOutputWithPast,
|
| 34 |
+
SequenceClassifierOutputWithPast,
|
| 35 |
+
TokenClassifierOutput,
|
| 36 |
+
)
|
| 37 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 38 |
+
from transformers.utils import (
|
| 39 |
+
add_code_sample_docstrings,
|
| 40 |
+
add_start_docstrings,
|
| 41 |
+
add_start_docstrings_to_model_forward,
|
| 42 |
+
is_flash_attn_2_available,
|
| 43 |
+
is_flash_attn_greater_or_equal_2_10,
|
| 44 |
+
is_torchdynamo_compiling,
|
| 45 |
+
logging,
|
| 46 |
+
replace_return_docstrings,
|
| 47 |
+
)
|
| 48 |
+
from transformers import Phi3Config
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if is_flash_attn_2_available():
|
| 52 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 53 |
+
|
| 54 |
+
logger = logging.get_logger(__name__)
|
| 55 |
+
|
| 56 |
+
_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
|
| 57 |
+
_CONFIG_FOR_DOC = "Phi3Config"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
| 61 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 62 |
+
attention_mask: torch.Tensor,
|
| 63 |
+
sequence_length: int,
|
| 64 |
+
target_length: int,
|
| 65 |
+
dtype: torch.dtype,
|
| 66 |
+
device: torch.device,
|
| 67 |
+
min_dtype: float,
|
| 68 |
+
cache_position: torch.Tensor,
|
| 69 |
+
batch_size: int,
|
| 70 |
+
):
|
| 71 |
+
"""
|
| 72 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 73 |
+
`(batch_size, key_value_length)`, or, if the input `attention_mask` is already 4D, returns it unchanged.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
attention_mask (`torch.Tensor`):
|
| 77 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
| 78 |
+
sequence_length (`int`):
|
| 79 |
+
The sequence length being processed.
|
| 80 |
+
target_length (`int`):
|
| 81 |
+
The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding, i.e. the part of the cache that is not yet filled.
|
| 82 |
+
dtype (`torch.dtype`):
|
| 83 |
+
The dtype to use for the 4D attention mask.
|
| 84 |
+
device (`torch.device`):
|
| 85 |
+
The device to place the 4D attention mask on.
|
| 86 |
+
min_dtype (`float`):
|
| 87 |
+
The minimum value representable with the dtype `dtype`.
|
| 88 |
+
cache_position (`torch.Tensor`):
|
| 89 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 90 |
+
batch_size (`int`):
|
| 91 |
+
Batch size.
|
| 92 |
+
"""
|
| 93 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 94 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 95 |
+
causal_mask = attention_mask
|
| 96 |
+
else:
|
| 97 |
+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
| 98 |
+
if sequence_length != 1:
|
| 99 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 100 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 101 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 102 |
+
if attention_mask is not None:
|
| 103 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 104 |
+
mask_length = attention_mask.shape[-1]
|
| 105 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
| 106 |
+
padding_mask = padding_mask == 0
|
| 107 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 108 |
+
padding_mask, min_dtype
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
return causal_mask
|
| 112 |
+
|
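# Minimal usage sketch (hypothetical helper, not part of the upstream module): for a batch of
# two length-4 sequences where the first token of sequence 0 is padding, the helper above
# returns a float mask of shape (batch_size, 1, query_length, key_value_length) whose
# disallowed positions hold the dtype minimum.
def _causal_mask_shape_demo() -> None:
    mask_2d = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])
    mask_4d = _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask=mask_2d,
        sequence_length=4,
        target_length=4,
        dtype=torch.float32,
        device=torch.device("cpu"),
        min_dtype=torch.finfo(torch.float32).min,
        cache_position=torch.arange(4),
        batch_size=2,
    )
    assert mask_4d.shape == (2, 1, 4, 4)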
| 113 |
+
|
| 114 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
|
| 115 |
+
class Phi3RMSNorm(nn.Module):
|
| 116 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 117 |
+
"""
|
| 118 |
+
Phi3RMSNorm is equivalent to T5LayerNorm
|
| 119 |
+
"""
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 122 |
+
self.variance_epsilon = eps
|
| 123 |
+
|
| 124 |
+
def forward(self, hidden_states):
|
| 125 |
+
input_dtype = hidden_states.dtype
|
| 126 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 127 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 128 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 129 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 130 |
+
|
| 131 |
+
def extra_repr(self):
|
| 132 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 133 |
+
|
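# Minimal reference sketch (hypothetical helper, not part of the upstream module): Phi3RMSNorm
# above rescales each hidden vector by the reciprocal of its root-mean-square and then applies
# a learned per-channel weight; unlike LayerNorm, no mean subtraction is performed.
def _rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return weight * (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)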
| 134 |
+
|
| 135 |
+
# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
|
| 136 |
+
class Phi3RotaryEmbedding(nn.Module):
|
| 137 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 138 |
+
super().__init__()
|
| 139 |
+
|
| 140 |
+
self.dim = dim
|
| 141 |
+
self.max_position_embeddings = max_position_embeddings
|
| 142 |
+
self.base = base
|
| 143 |
+
|
| 144 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
|
| 145 |
+
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
|
| 146 |
+
|
| 147 |
+
@torch.no_grad()
|
| 148 |
+
def forward(self, x, position_ids, seq_len=None):
|
| 149 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 150 |
+
self.inv_freq.to(x.device)
|
| 151 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
| 152 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 153 |
+
# Force float32 since bfloat16 loses precision on long contexts
|
| 154 |
+
# See https://github.com/huggingface/transformers/pull/29285
|
| 155 |
+
device_type = x.device.type
|
| 156 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 157 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 158 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 159 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 160 |
+
cos = emb.cos()
|
| 161 |
+
sin = emb.sin()
|
| 162 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
|
| 166 |
+
def __init__(self, dim, config, device=None):
|
| 167 |
+
warnings.warn(
|
| 168 |
+
"The class Phi3SuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. Please"
|
| 169 |
+
" use Phi3LongRoPEScaledRotaryEmbedding instead.",
|
| 170 |
+
FutureWarning,
|
| 171 |
+
)
|
| 172 |
+
super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
|
| 173 |
+
|
| 174 |
+
self.short_factor = config.rope_scaling["short_factor"]
|
| 175 |
+
self.long_factor = config.rope_scaling["long_factor"]
|
| 176 |
+
self.original_max_position_embeddings = config.original_max_position_embeddings
|
| 177 |
+
|
| 178 |
+
@torch.no_grad()
|
| 179 |
+
def forward(self, x, position_ids, seq_len=None):
|
| 180 |
+
seq_len = torch.max(position_ids) + 1
|
| 181 |
+
if seq_len > self.original_max_position_embeddings:
|
| 182 |
+
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
|
| 183 |
+
else:
|
| 184 |
+
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
|
| 185 |
+
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
|
| 186 |
+
self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
|
| 187 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
| 188 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 189 |
+
# Force float32 since bfloat16 loses precision on long contexts
|
| 190 |
+
# See https://github.com/huggingface/transformers/pull/29285
|
| 191 |
+
device_type = x.device.type
|
| 192 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 193 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 194 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 195 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 196 |
+
scale = self.max_position_embeddings / self.original_max_position_embeddings
|
| 197 |
+
if scale <= 1.0:
|
| 198 |
+
scaling_factor = 1.0
|
| 199 |
+
else:
|
| 200 |
+
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
|
| 201 |
+
cos = emb.cos() * scaling_factor
|
| 202 |
+
sin = emb.sin() * scaling_factor
|
| 203 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
|
| 207 |
+
def __init__(self, dim, config, device=None):
|
| 208 |
+
warnings.warn(
|
| 209 |
+
"The class Phi3YarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers",
|
| 210 |
+
FutureWarning,
|
| 211 |
+
)
|
| 212 |
+
super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
|
| 213 |
+
|
| 214 |
+
self.short_factor = config.rope_scaling["short_factor"]
|
| 215 |
+
self.long_factor = config.rope_scaling["long_factor"]
|
| 216 |
+
self.original_max_position_embeddings = config.original_max_position_embeddings
|
| 217 |
+
|
| 218 |
+
@torch.no_grad()
|
| 219 |
+
def forward(self, x, position_ids, seq_len=None):
|
| 220 |
+
seq_len = torch.max(position_ids) + 1
|
| 221 |
+
if seq_len > self.original_max_position_embeddings:
|
| 222 |
+
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
|
| 223 |
+
else:
|
| 224 |
+
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
|
| 225 |
+
|
| 226 |
+
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
|
| 227 |
+
self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
|
| 228 |
+
|
| 229 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
| 230 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 231 |
+
|
| 232 |
+
# Force float32 since bfloat16 loses precision on long contexts
|
| 233 |
+
# See https://github.com/huggingface/transformers/pull/29285
|
| 234 |
+
device_type = x.device.type
|
| 235 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 236 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 237 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 238 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 239 |
+
|
| 240 |
+
scale = self.max_position_embeddings / self.original_max_position_embeddings
|
| 241 |
+
if scale <= 1.0:
|
| 242 |
+
scaling_factor = 1.0
|
| 243 |
+
else:
|
| 244 |
+
scaling_factor = 0.1 * math.log(scale) + 1.0
|
| 245 |
+
|
| 246 |
+
cos = emb.cos() * scaling_factor
|
| 247 |
+
sin = emb.sin() * scaling_factor
|
| 248 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
|
| 252 |
+
def __init__(self, dim, config, device=None):
|
| 253 |
+
super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
|
| 254 |
+
|
| 255 |
+
self.short_factor = config.rope_scaling["short_factor"]
|
| 256 |
+
self.long_factor = config.rope_scaling["long_factor"]
|
| 257 |
+
self.original_max_position_embeddings = config.original_max_position_embeddings
|
| 258 |
+
|
| 259 |
+
@torch.no_grad()
|
| 260 |
+
def forward(self, x, position_ids, seq_len=None):
|
| 261 |
+
seq_len = seq_len or torch.max(position_ids) + 1
|
| 262 |
+
if seq_len > self.original_max_position_embeddings:
|
| 263 |
+
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
|
| 264 |
+
else:
|
| 265 |
+
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
|
| 266 |
+
|
| 267 |
+
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
|
| 268 |
+
self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
|
| 269 |
+
|
| 270 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
| 271 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 272 |
+
|
| 273 |
+
# Force float32 since bfloat16 loses precision on long contexts
|
| 274 |
+
# See https://github.com/huggingface/transformers/pull/29285
|
| 275 |
+
device_type = x.device.type
|
| 276 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 277 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 278 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 279 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 280 |
+
|
| 281 |
+
scale = self.max_position_embeddings / self.original_max_position_embeddings
|
| 282 |
+
if scale <= 1.0:
|
| 283 |
+
scaling_factor = 1.0
|
| 284 |
+
else:
|
| 285 |
+
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
|
| 286 |
+
|
| 287 |
+
cos = emb.cos() * scaling_factor
|
| 288 |
+
sin = emb.sin() * scaling_factor
|
| 289 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 290 |
+
|
| 291 |
+
|
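# Minimal sketch (hypothetical numbers): the LongRoPE variant above rescales cos/sin by
# sqrt(1 + ln(scale) / ln(original_max_position_embeddings)) once the context window is
# extended. For example, extending a 4k-trained model to a 128k window gives
# scale = 131072 / 4096 = 32 and a scaling factor of sqrt(1 + ln(32) / ln(4096)) ≈ 1.19.
def _longrope_attention_scale(max_pos: int = 131072, original_max_pos: int = 4096) -> float:
    scale = max_pos / original_max_pos
    return 1.0 if scale <= 1.0 else math.sqrt(1 + math.log(scale) / math.log(original_max_pos))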
| 292 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 293 |
+
def rotate_half(x):
|
| 294 |
+
"""Rotates half the hidden dims of the input."""
|
| 295 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 296 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 297 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
| 301 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 302 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 303 |
+
|
| 304 |
+
Args:
|
| 305 |
+
q (`torch.Tensor`): The query tensor.
|
| 306 |
+
k (`torch.Tensor`): The key tensor.
|
| 307 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 308 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 309 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 310 |
+
Deprecated and unused.
|
| 311 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 312 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 313 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 314 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 315 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 316 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 317 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 318 |
+
Returns:
|
| 319 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 320 |
+
"""
|
| 321 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 322 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 323 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 324 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 325 |
+
return q_embed, k_embed
|
| 326 |
+
|
| 327 |
+
|
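# Minimal sketch (hypothetical helper, not part of the upstream module): rotary embeddings
# rotate pairs of channels by a position-dependent angle, so applying them leaves the norm of
# each head vector unchanged while encoding relative position into the q·k dot products.
def _rope_preserves_norm_demo() -> None:
    bsz, heads, seq, head_dim = 1, 2, 5, 8
    q = torch.randn(bsz, heads, seq, head_dim)
    k = torch.randn(bsz, heads, seq, head_dim)
    rope = Phi3RotaryEmbedding(head_dim, max_position_embeddings=seq)
    position_ids = torch.arange(seq)[None, :]
    cos, sin = rope(q, position_ids)
    q_rot, _k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)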
| 328 |
+
class Phi3MLP(nn.Module):
|
| 329 |
+
def __init__(self, config):
|
| 330 |
+
super().__init__()
|
| 331 |
+
|
| 332 |
+
self.config = config
|
| 333 |
+
self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
|
| 334 |
+
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
| 335 |
+
|
| 336 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
| 337 |
+
|
| 338 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 339 |
+
up_states = self.gate_up_proj(hidden_states)
|
| 340 |
+
|
| 341 |
+
gate, up_states = up_states.chunk(2, dim=-1)
|
| 342 |
+
up_states = up_states * self.activation_fn(gate)
|
| 343 |
+
|
| 344 |
+
return self.down_proj(up_states)
|
| 345 |
+
|
| 346 |
+
|
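# Minimal shape sketch (hypothetical helper, not part of the upstream module): Phi3MLP fuses
# the gate and up projections into a single `gate_up_proj` matmul, splits the result in two,
# and multiplies the up half by the activated gate half (a SwiGLU-style gated MLP).
def _phi3_mlp_shapes_demo(config: Phi3Config) -> None:
    mlp = Phi3MLP(config)
    x = torch.randn(2, 3, config.hidden_size)
    fused = mlp.gate_up_proj(x)                       # (2, 3, 2 * intermediate_size)
    gate, up = fused.chunk(2, dim=-1)                 # two (2, 3, intermediate_size) halves
    out = mlp.down_proj(up * mlp.activation_fn(gate))
    assert out.shape == x.shape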
| 347 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
|
| 348 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 349 |
+
"""
|
| 350 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 351 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 352 |
+
"""
|
| 353 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 354 |
+
if n_rep == 1:
|
| 355 |
+
return hidden_states
|
| 356 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 357 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 358 |
+
|
| 359 |
+
|
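# Minimal shape sketch (hypothetical helper, not part of the upstream module): with 2 key/value
# heads each repeated 4 times, grouped-query attention exposes 8 effective heads to the
# attention matmul while only the 2 underlying KV heads are stored in the cache.
def _repeat_kv_shapes_demo() -> None:
    kv = torch.randn(1, 2, 16, 64)            # (batch, num_key_value_heads, seq_len, head_dim)
    expanded = repeat_kv(kv, n_rep=4)
    assert expanded.shape == (1, 8, 16, 64)   # (batch, num_attention_heads, seq_len, head_dim)
    assert torch.equal(expanded[:, 0], expanded[:, 1])  # repeated copies of the same KV head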
| 360 |
+
class Phi3Attention(nn.Module):
|
| 361 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 362 |
+
|
| 363 |
+
def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
|
| 364 |
+
super().__init__()
|
| 365 |
+
self.config = config
|
| 366 |
+
self.layer_idx = layer_idx
|
| 367 |
+
if layer_idx is None:
|
| 368 |
+
logger.warning_once(
|
| 369 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
| 370 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
| 371 |
+
"when creating this class."
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
self.attention_dropout = config.attention_dropout
|
| 375 |
+
self.hidden_size = config.hidden_size
|
| 376 |
+
self.num_heads = config.num_attention_heads
|
| 377 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 378 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 379 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 380 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 381 |
+
self.original_max_position_embeddings = config.original_max_position_embeddings
|
| 382 |
+
self.rope_theta = config.rope_theta
|
| 383 |
+
self.rope_scaling = config.rope_scaling
|
| 384 |
+
self.is_causal = True
|
| 385 |
+
|
| 386 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 387 |
+
raise ValueError(
|
| 388 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
| 389 |
+
f" and `num_heads`: {self.num_heads})."
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
|
| 393 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 394 |
+
self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
|
| 395 |
+
self._init_rope()
|
| 396 |
+
|
| 397 |
+
def _init_rope(self):
|
| 398 |
+
if self.rope_scaling is None:
|
| 399 |
+
self.rotary_emb = Phi3RotaryEmbedding(
|
| 400 |
+
self.head_dim,
|
| 401 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 402 |
+
base=self.rope_theta,
|
| 403 |
+
)
|
| 404 |
+
else:
|
| 405 |
+
scaling_type = self.config.rope_scaling["type"]
|
| 406 |
+
if scaling_type == "longrope":
|
| 407 |
+
self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
|
| 408 |
+
else:
|
| 409 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
| 410 |
+
|
| 411 |
+
def forward(
|
| 412 |
+
self,
|
| 413 |
+
hidden_states: torch.Tensor,
|
| 414 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 415 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 416 |
+
past_key_value: Optional[Cache] = None,
|
| 417 |
+
output_attentions: bool = False,
|
| 418 |
+
use_cache: bool = False,
|
| 419 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 420 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 421 |
+
logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.")
|
| 422 |
+
|
| 423 |
+
bsz, q_len, _ = hidden_states.size()
|
| 424 |
+
|
| 425 |
+
qkv = self.qkv_proj(hidden_states)
|
| 426 |
+
query_pos = self.num_heads * self.head_dim
|
| 427 |
+
query_states = qkv[..., :query_pos]
|
| 428 |
+
key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
|
| 429 |
+
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
|
| 430 |
+
|
| 431 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 432 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 433 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 434 |
+
|
| 435 |
+
kv_seq_len = key_states.shape[-2]
|
| 436 |
+
if past_key_value is not None:
|
| 437 |
+
if self.layer_idx is None:
|
| 438 |
+
raise ValueError(
|
| 439 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
| 440 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 441 |
+
"with a layer index."
|
| 442 |
+
)
|
| 443 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 444 |
+
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
| 445 |
+
|
| 446 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 447 |
+
|
| 448 |
+
if past_key_value is not None:
|
| 449 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
| 450 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 451 |
+
|
| 452 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
| 453 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 454 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 455 |
+
|
| 456 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 457 |
+
|
| 458 |
+
if attention_mask is not None:
|
| 459 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 460 |
+
attn_weights += causal_mask
|
| 461 |
+
|
| 462 |
+
# upcast attention to fp32
|
| 463 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
|
| 464 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 465 |
+
|
| 466 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 467 |
+
|
| 468 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 469 |
+
raise ValueError(
|
| 470 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 471 |
+
f" {attn_output.size()}"
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 475 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 476 |
+
|
| 477 |
+
attn_output = self.o_proj(attn_output)
|
| 478 |
+
|
| 479 |
+
if not output_attentions:
|
| 480 |
+
attn_weights = None
|
| 481 |
+
|
| 482 |
+
return attn_output, attn_weights, past_key_value
|
| 483 |
+
|
| 484 |
+
|
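# Minimal sketch (hypothetical numbers and helper, not part of the upstream module): the fused
# `qkv_proj` above packs queries, keys and values into one matmul; the output is split at
# offsets num_heads*head_dim and num_key_value_heads*head_dim. E.g. with hidden_size=3072,
# 32 query heads and 32 KV heads of head_dim=96, op_size = 32*96 + 2*(32*96) = 9216 and the
# query slice is the first 3072 columns.
def _split_fused_qkv(qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int):
    query_pos = num_heads * head_dim
    kv_width = num_kv_heads * head_dim
    return (
        qkv[..., :query_pos],                          # queries
        qkv[..., query_pos : query_pos + kv_width],    # keys
        qkv[..., query_pos + kv_width :],              # values
    )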
| 485 |
+
class Phi3FlashAttention2(Phi3Attention):
|
| 486 |
+
"""
|
| 487 |
+
Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stay
|
| 488 |
+
untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
|
| 489 |
+
flash attention and deal with padding tokens if the input contains any of them.
|
| 490 |
+
"""
|
| 491 |
+
|
| 492 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
| 493 |
+
def __init__(self, *args, **kwargs):
|
| 494 |
+
super().__init__(*args, **kwargs)
|
| 495 |
+
|
| 496 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 497 |
+
# flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 498 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 499 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
| 500 |
+
|
| 501 |
+
def forward(
|
| 502 |
+
self,
|
| 503 |
+
hidden_states: torch.Tensor,
|
| 504 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 505 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 506 |
+
past_key_value: Optional[Cache] = None,
|
| 507 |
+
output_attentions: bool = False,
|
| 508 |
+
use_cache: bool = False,
|
| 509 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 510 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 511 |
+
# Phi3FlashAttention2 attention does not support output_attentions
|
| 512 |
+
|
| 513 |
+
output_attentions = False
|
| 514 |
+
|
| 515 |
+
bsz, q_len, _ = hidden_states.size()
|
| 516 |
+
|
| 517 |
+
qkv = self.qkv_proj(hidden_states)
|
| 518 |
+
query_pos = self.num_heads * self.head_dim
|
| 519 |
+
query_states = qkv[..., :query_pos]
|
| 520 |
+
key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
|
| 521 |
+
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
|
| 522 |
+
|
| 523 |
+
# Flash attention requires the input to have the shape
|
| 524 |
+
# batch_size x seq_length x num_heads x head_dim
|
| 525 |
+
# therefore we just need to keep the original shape
|
| 526 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 527 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 528 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 529 |
+
|
| 530 |
+
kv_seq_len = key_states.shape[-2]
|
| 531 |
+
if past_key_value is not None:
|
| 532 |
+
if self.layer_idx is None:
|
| 533 |
+
raise ValueError(
|
| 534 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
| 535 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 536 |
+
"with a layer index."
|
| 537 |
+
)
|
| 538 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 539 |
+
|
| 540 |
+
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
| 541 |
+
rotary_seq_len = (
|
| 542 |
+
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids)
|
| 546 |
+
|
| 547 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 548 |
+
|
| 549 |
+
if past_key_value is not None:
|
| 550 |
+
# Activate slicing of the cache only if the config has a `sliding_window` attribute set
|
| 551 |
+
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
|
| 552 |
+
if (
|
| 553 |
+
getattr(self.config, "sliding_window", None) is not None
|
| 554 |
+
and kv_seq_len > self.config.sliding_window
|
| 555 |
+
and cache_has_contents
|
| 556 |
+
):
|
| 557 |
+
slicing_tokens = 1 - self.config.sliding_window
|
| 558 |
+
|
| 559 |
+
past_key = past_key_value[self.layer_idx][0]
|
| 560 |
+
past_value = past_key_value[self.layer_idx][1]
|
| 561 |
+
|
| 562 |
+
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
| 563 |
+
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
| 564 |
+
|
| 565 |
+
if past_key.shape[-2] != self.config.sliding_window - 1:
|
| 566 |
+
raise ValueError(
|
| 567 |
+
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
| 568 |
+
f" {past_key.shape}"
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
if attention_mask is not None:
|
| 572 |
+
attention_mask = attention_mask[:, slicing_tokens:]
|
| 573 |
+
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
|
| 574 |
+
|
| 575 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
| 576 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 577 |
+
|
| 578 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
| 579 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 580 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 581 |
+
|
| 582 |
+
attn_dropout = self.attention_dropout if self.training else 0.0
|
| 583 |
+
|
| 584 |
+
# In PEFT, we usually cast the layer norms to float32 for training stability reasons;
|
| 585 |
+
# therefore the input hidden states get silently cast to float32. Hence, we need to
|
| 586 |
+
# cast them back to the correct dtype just to be sure everything works as expected.
|
| 587 |
+
# This might slow down training & inference, so it is recommended not to cast the LayerNorms
|
| 588 |
+
# in fp32.
|
| 589 |
+
|
| 590 |
+
if query_states.dtype == torch.float32:
|
| 591 |
+
if torch.is_autocast_enabled():
|
| 592 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 593 |
+
# Handle the case where the model is quantized
|
| 594 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 595 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 596 |
+
else:
|
| 597 |
+
target_dtype = self.qkv_proj.weight.dtype
|
| 598 |
+
|
| 599 |
+
logger.warning_once(
|
| 600 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 601 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 602 |
+
f" {target_dtype}."
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
query_states = query_states.to(target_dtype)
|
| 606 |
+
key_states = key_states.to(target_dtype)
|
| 607 |
+
value_states = value_states.to(target_dtype)
|
| 608 |
+
|
| 609 |
+
# Reshape to the expected shape for Flash Attention
|
| 610 |
+
query_states = query_states.transpose(1, 2)
|
| 611 |
+
key_states = key_states.transpose(1, 2)
|
| 612 |
+
value_states = value_states.transpose(1, 2)
|
| 613 |
+
|
| 614 |
+
attn_output = _flash_attention_forward(
|
| 615 |
+
query_states,
|
| 616 |
+
key_states,
|
| 617 |
+
value_states,
|
| 618 |
+
attention_mask,
|
| 619 |
+
q_len,
|
| 620 |
+
position_ids=position_ids,
|
| 621 |
+
dropout=attn_dropout,
|
| 622 |
+
sliding_window=getattr(self.config, "sliding_window", None),
|
| 623 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 624 |
+
is_causal=self.is_causal,
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
| 628 |
+
attn_output = self.o_proj(attn_output)
|
| 629 |
+
|
| 630 |
+
if not output_attentions:
|
| 631 |
+
attn_weights = None
|
| 632 |
+
|
| 633 |
+
return attn_output, attn_weights, past_key_value
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
|
| 637 |
+
# TODO @Arthur no longer copied from LLama after static cache
|
| 638 |
+
class Phi3SdpaAttention(Phi3Attention):
|
| 639 |
+
"""
|
| 640 |
+
Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 641 |
+
`Phi3Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
|
| 642 |
+
the SDPA API.
|
| 643 |
+
"""
|
| 644 |
+
|
| 645 |
+
# Adapted from Phi3Attention.forward
|
| 646 |
+
def forward(
|
| 647 |
+
self,
|
| 648 |
+
hidden_states: torch.Tensor,
|
| 649 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 650 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 651 |
+
past_key_value: Optional[Cache] = None,
|
| 652 |
+
output_attentions: bool = False,
|
| 653 |
+
use_cache: bool = False,
|
| 654 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 655 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 656 |
+
if output_attentions:
|
| 657 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
| 658 |
+
logger.warning_once(
|
| 659 |
+
"Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
| 660 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 661 |
+
)
|
| 662 |
+
return super().forward(
|
| 663 |
+
hidden_states=hidden_states,
|
| 664 |
+
attention_mask=attention_mask,
|
| 665 |
+
position_ids=position_ids,
|
| 666 |
+
past_key_value=past_key_value,
|
| 667 |
+
output_attentions=output_attentions,
|
| 668 |
+
use_cache=use_cache,
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
bsz, q_len, _ = hidden_states.size()
|
| 672 |
+
|
| 673 |
+
qkv = self.qkv_proj(hidden_states)
|
| 674 |
+
query_pos = self.num_heads * self.head_dim
|
| 675 |
+
query_states = qkv[..., :query_pos]
|
| 676 |
+
key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
|
| 677 |
+
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
|
| 678 |
+
|
| 679 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 680 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 681 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 682 |
+
|
| 683 |
+
kv_seq_len = key_states.shape[-2]
|
| 684 |
+
if past_key_value is not None:
|
| 685 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 686 |
+
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
| 687 |
+
|
| 688 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 689 |
+
|
| 690 |
+
if past_key_value is not None:
|
| 691 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
| 692 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 693 |
+
|
| 694 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 695 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 696 |
+
|
| 697 |
+
causal_mask = attention_mask
|
| 698 |
+
if attention_mask is not None:
|
| 699 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 700 |
+
|
| 701 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 702 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 703 |
+
if query_states.device.type == "cuda" and attention_mask is not None:
|
| 704 |
+
query_states = query_states.contiguous()
|
| 705 |
+
key_states = key_states.contiguous()
|
| 706 |
+
value_states = value_states.contiguous()
|
| 707 |
+
|
| 708 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 709 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 710 |
+
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
| 711 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
| 712 |
+
|
| 713 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 714 |
+
query_states,
|
| 715 |
+
key_states,
|
| 716 |
+
value_states,
|
| 717 |
+
attn_mask=causal_mask,
|
| 718 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 719 |
+
is_causal=is_causal,
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 723 |
+
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
| 724 |
+
|
| 725 |
+
attn_output = self.o_proj(attn_output)
|
| 726 |
+
|
| 727 |
+
return attn_output, None, past_key_value
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
PHI3_ATTENTION_CLASSES = {
|
| 731 |
+
"eager": Phi3Attention,
|
| 732 |
+
"flash_attention_2": Phi3FlashAttention2,
|
| 733 |
+
"sdpa": Phi3SdpaAttention,
|
| 734 |
+
}
|
| 735 |
+
|
| 736 |
+
|
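# Minimal sketch (hypothetical helper; assumes a Phi3Config instance is available): the
# attention implementation used by every decoder layer is selected from the mapping above via
# `config._attn_implementation`, which `from_pretrained(..., attn_implementation=...)` sets.
def _select_attention_class(config: Phi3Config, layer_idx: int) -> Phi3Attention:
    attn_cls = PHI3_ATTENTION_CLASSES[config._attn_implementation]  # "eager", "sdpa", or "flash_attention_2"
    return attn_cls(config, layer_idx=layer_idx)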
| 737 |
+
class Phi3DecoderLayer(nn.Module):
|
| 738 |
+
def __init__(self, config: Phi3Config, layer_idx: int):
|
| 739 |
+
super().__init__()
|
| 740 |
+
|
| 741 |
+
self.config = config
|
| 742 |
+
self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
|
| 743 |
+
|
| 744 |
+
self.mlp = Phi3MLP(config)
|
| 745 |
+
self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 746 |
+
|
| 747 |
+
self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
|
| 748 |
+
self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
|
| 749 |
+
self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 750 |
+
|
| 751 |
+
def forward(
|
| 752 |
+
self,
|
| 753 |
+
hidden_states: torch.Tensor,
|
| 754 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 755 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 756 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 757 |
+
output_attentions: Optional[bool] = False,
|
| 758 |
+
use_cache: Optional[bool] = False,
|
| 759 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 760 |
+
**kwargs,
|
| 761 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 762 |
+
"""
|
| 763 |
+
Args:
|
| 764 |
+
hidden_states (`torch.FloatTensor`):
|
| 765 |
+
input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 766 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
| 767 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
| 768 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 769 |
+
Indices of positions of each input sequence token in the position embeddings. Selected in the range
|
| 770 |
+
`[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
| 771 |
+
output_attentions (`bool`, *optional*):
|
| 772 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 773 |
+
returned tensors for more detail.
|
| 774 |
+
use_cache (`bool`, *optional*):
|
| 775 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 776 |
+
(see `past_key_values`).
|
| 777 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 778 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 779 |
+
Indices depicting the position of the input sequence tokens in the sequence
|
| 780 |
+
kwargs (`dict`, *optional*):
|
| 781 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
|
| 782 |
+
into the model
|
| 783 |
+
"""
|
| 784 |
+
|
| 785 |
+
residual = hidden_states
|
| 786 |
+
|
| 787 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 788 |
+
|
| 789 |
+
# Self Attention
|
| 790 |
+
attn_outputs, self_attn_weights, present_key_value = self.self_attn(
|
| 791 |
+
hidden_states=hidden_states,
|
| 792 |
+
attention_mask=attention_mask,
|
| 793 |
+
position_ids=position_ids,
|
| 794 |
+
past_key_value=past_key_value,
|
| 795 |
+
output_attentions=output_attentions,
|
| 796 |
+
use_cache=use_cache,
|
| 797 |
+
cache_position=cache_position,
|
| 798 |
+
)
|
| 799 |
+
|
| 800 |
+
hidden_states = residual + self.resid_attn_dropout(attn_outputs)
|
| 801 |
+
|
| 802 |
+
residual = hidden_states
|
| 803 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 804 |
+
hidden_states = self.mlp(hidden_states)
|
| 805 |
+
hidden_states = residual + self.resid_mlp_dropout(hidden_states)
|
| 806 |
+
|
| 807 |
+
outputs = (hidden_states,)
|
| 808 |
+
|
| 809 |
+
if output_attentions:
|
| 810 |
+
outputs += (self_attn_weights,)
|
| 811 |
+
|
| 812 |
+
if use_cache:
|
| 813 |
+
outputs += (present_key_value,)
|
| 814 |
+
|
| 815 |
+
return outputs
|
| 816 |
+
|
| 817 |
+
|
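# Minimal shape sketch (hypothetical helper; assumes a Phi3Config instance is available): each
# decoder layer is two pre-norm residual blocks, attention then MLP, each followed by residual
# dropout:
#   h = h + Dropout(Attn(RMSNorm(h)));  h = h + Dropout(MLP(RMSNorm(h)))
def _decoder_layer_shape_demo(config: Phi3Config) -> None:
    layer = Phi3DecoderLayer(config, layer_idx=0)
    hidden = torch.randn(1, 4, config.hidden_size)
    position_ids = torch.arange(4)[None, :]
    (out,) = layer(hidden, position_ids=position_ids)
    assert out.shape == hidden.shape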
| 818 |
+
PHI3_START_DOCSTRING = r"""
|
| 819 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 820 |
+
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 821 |
+
etc.)
|
| 822 |
+
|
| 823 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 824 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
|
| 825 |
+
and behavior.
|
| 826 |
+
|
| 827 |
+
Parameters:
|
| 828 |
+
config ([`Phi3Config`]):
|
| 829 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 830 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 831 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 832 |
+
"""
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
@add_start_docstrings(
|
| 836 |
+
"The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
|
| 837 |
+
PHI3_START_DOCSTRING,
|
| 838 |
+
)
|
| 839 |
+
class Phi3PreTrainedModel(PreTrainedModel):
|
| 840 |
+
config_class = Phi3Config
|
| 841 |
+
base_model_prefix = "model"
|
| 842 |
+
supports_gradient_checkpointing = True
|
| 843 |
+
_no_split_modules = ["Phi3DecoderLayer"]
|
| 844 |
+
_skip_keys_device_placement = "past_key_values"
|
| 845 |
+
_supports_flash_attn_2 = True
|
| 846 |
+
_supports_sdpa = True
|
| 847 |
+
_supports_cache_class = True
|
| 848 |
+
|
| 849 |
+
_version = "0.0.5"
|
| 850 |
+
|
| 851 |
+
def _init_weights(self, module):
|
| 852 |
+
std = self.config.initializer_range
|
| 853 |
+
if isinstance(module, nn.Linear):
|
| 854 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 855 |
+
if module.bias is not None:
|
| 856 |
+
module.bias.data.zero_()
|
| 857 |
+
elif isinstance(module, nn.Embedding):
|
| 858 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 859 |
+
if module.padding_idx is not None:
|
| 860 |
+
module.weight.data[module.padding_idx].zero_()
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
PHI3_INPUTS_DOCSTRING = r"""
|
| 864 |
+
Args:
|
| 865 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 866 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 867 |
+
it.
|
| 868 |
+
|
| 869 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 870 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 871 |
+
|
| 872 |
+
[What are input IDs?](../glossary#input-ids)
|
| 873 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 874 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 875 |
+
|
| 876 |
+
- 1 for tokens that are **not masked**,
|
| 877 |
+
- 0 for tokens that are **masked**.
|
| 878 |
+
|
| 879 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 880 |
+
|
| 881 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 882 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 883 |
+
|
| 884 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 885 |
+
`past_key_values`).
|
| 886 |
+
|
| 887 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 888 |
+
and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 889 |
+
information on the default strategy.
|
| 890 |
+
|
| 891 |
+
- 1 indicates the head is **not masked**,
|
| 892 |
+
- 0 indicates the head is **masked**.
|
| 893 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 894 |
+
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
|
| 895 |
+
config.n_positions - 1]`.
|
| 896 |
+
|
| 897 |
+
[What are position IDs?](../glossary#position-ids)
|
| 898 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 899 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 900 |
+
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
|
| 901 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 902 |
+
|
| 903 |
+
Two formats are allowed:
|
| 904 |
+
- a [`~cache_utils.Cache`] instance, see our
|
| 905 |
+
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
| 906 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 907 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
|
| 908 |
+
cache format.
|
| 909 |
+
|
| 910 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 911 |
+
legacy cache format will be returned.
|
| 912 |
+
|
| 913 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 914 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 915 |
+
of shape `(batch_size, sequence_length)`.
|
| 916 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 917 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 918 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 919 |
+
model's internal embedding lookup matrix.
|
| 920 |
+
use_cache (`bool`, *optional*):
|
| 921 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 922 |
+
`past_key_values`).
|
| 923 |
+
output_attentions (`bool`, *optional*):
|
| 924 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 925 |
+
tensors for more detail.
|
| 926 |
+
output_hidden_states (`bool`, *optional*):
|
| 927 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 928 |
+
more detail.
|
| 929 |
+
return_dict (`bool`, *optional*):
|
| 930 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 931 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 932 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
|
| 933 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 934 |
+
the complete sequence length.
|
| 935 |
+
"""
|
| 936 |
+
|
| 937 |
+
|
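# Minimal sketch (hypothetical values; assumes the `DynamicCache` API imported above behaves as
# in recent transformers releases): `past_key_values` may be passed either as a `Cache` instance
# or in the legacy tuple-of-tuples format described above; the legacy format is converted
# internally with `DynamicCache.from_legacy_cache`.
def _legacy_cache_roundtrip_demo(num_layers: int = 2) -> None:
    legacy = tuple(
        (torch.zeros(1, 4, 3, 8), torch.zeros(1, 4, 3, 8))  # (key, value) per layer
        for _ in range(num_layers)
    )
    cache = DynamicCache.from_legacy_cache(legacy)
    assert cache.get_seq_length() == 3
    assert len(cache.to_legacy_cache()) == num_layers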
| 938 |
+
@add_start_docstrings(
|
| 939 |
+
"The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
|
| 940 |
+
PHI3_START_DOCSTRING,
|
| 941 |
+
)
|
| 942 |
+
class Phi3Model(Phi3PreTrainedModel):
|
| 943 |
+
"""
|
| 944 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
|
| 945 |
+
|
| 946 |
+
Args:
|
| 947 |
+
config: Phi3Config
|
| 948 |
+
"""
|
| 949 |
+
|
| 950 |
+
def __init__(self, config: Phi3Config):
|
| 951 |
+
super().__init__(config)
|
| 952 |
+
self.padding_idx = config.pad_token_id
|
| 953 |
+
self.vocab_size = config.vocab_size
|
| 954 |
+
|
| 955 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 956 |
+
self.embed_dropout = nn.Dropout(config.embd_pdrop)
|
| 957 |
+
self.layers = nn.ModuleList(
|
| 958 |
+
[Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 959 |
+
)
|
| 960 |
+
self._attn_implementation = config._attn_implementation
|
| 961 |
+
self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 962 |
+
|
| 963 |
+
self.gradient_checkpointing = False
|
| 964 |
+
# Initialize weights and apply final processing
|
| 965 |
+
self.post_init()
|
| 966 |
+
|
| 967 |
+
def get_input_embeddings(self):
|
| 968 |
+
return self.embed_tokens
|
| 969 |
+
|
| 970 |
+
def set_input_embeddings(self, value):
|
| 971 |
+
self.embed_tokens = value
|
| 972 |
+
|
| 973 |
+
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
|
| 974 |
+
def forward(
|
| 975 |
+
self,
|
| 976 |
+
input_ids: torch.LongTensor = None,
|
| 977 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 978 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 979 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 980 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 981 |
+
use_cache: Optional[bool] = None,
|
| 982 |
+
output_attentions: Optional[bool] = None,
|
| 983 |
+
output_hidden_states: Optional[bool] = None,
|
| 984 |
+
return_dict: Optional[bool] = None,
|
| 985 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 986 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 987 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 988 |
+
output_hidden_states = (
|
| 989 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 990 |
+
)
|
| 991 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 992 |
+
|
| 993 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 994 |
+
|
| 995 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 996 |
+
raise ValueError(
|
| 997 |
+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
| 998 |
+
)
|
| 999 |
+
|
| 1000 |
+
if self.gradient_checkpointing and self.training:
|
| 1001 |
+
if use_cache:
|
| 1002 |
+
logger.warning_once(
|
| 1003 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 1004 |
+
)
|
| 1005 |
+
use_cache = False
|
| 1006 |
+
|
| 1007 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
| 1008 |
+
return_legacy_cache = False
|
| 1009 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
| 1010 |
+
return_legacy_cache = True
|
| 1011 |
+
if past_key_values is None:
|
| 1012 |
+
past_key_values = DynamicCache()
|
| 1013 |
+
else:
|
| 1014 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 1015 |
+
logger.warning_once(
|
| 1016 |
+
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
| 1017 |
+
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
| 1018 |
+
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
| 1019 |
+
)
|
| 1020 |
+
|
| 1021 |
+
if inputs_embeds is None:
|
| 1022 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 1023 |
+
|
| 1024 |
+
if cache_position is None:
|
| 1025 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1026 |
+
cache_position = torch.arange(
|
| 1027 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 1028 |
+
)
|
| 1029 |
+
if position_ids is None:
|
| 1030 |
+
position_ids = cache_position.unsqueeze(0)
|
| 1031 |
+
|
| 1032 |
+
causal_mask = self._update_causal_mask(
|
| 1033 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 1034 |
+
)
|
| 1035 |
+
|
| 1036 |
+
hidden_states = inputs_embeds
|
| 1037 |
+
|
| 1038 |
+
# decoder layers
|
| 1039 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1040 |
+
all_self_attns = () if output_attentions else None
|
| 1041 |
+
next_decoder_cache = None
|
| 1042 |
+
|
| 1043 |
+
for decoder_layer in self.layers:
|
| 1044 |
+
if output_hidden_states:
|
| 1045 |
+
all_hidden_states += (hidden_states,)
|
| 1046 |
+
|
| 1047 |
+
if self.gradient_checkpointing and self.training:
|
| 1048 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1049 |
+
decoder_layer.__call__,
|
| 1050 |
+
hidden_states,
|
| 1051 |
+
causal_mask,
|
| 1052 |
+
position_ids,
|
| 1053 |
+
past_key_values,
|
| 1054 |
+
output_attentions,
|
| 1055 |
+
use_cache,
|
| 1056 |
+
cache_position,
|
| 1057 |
+
)
|
| 1058 |
+
else:
|
| 1059 |
+
layer_outputs = decoder_layer(
|
| 1060 |
+
hidden_states,
|
| 1061 |
+
attention_mask=causal_mask,
|
| 1062 |
+
position_ids=position_ids,
|
| 1063 |
+
past_key_value=past_key_values,
|
| 1064 |
+
output_attentions=output_attentions,
|
| 1065 |
+
use_cache=use_cache,
|
| 1066 |
+
cache_position=cache_position,
|
| 1067 |
+
)
|
| 1068 |
+
|
| 1069 |
+
hidden_states = layer_outputs[0]
|
| 1070 |
+
|
| 1071 |
+
if use_cache:
|
| 1072 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 1073 |
+
|
| 1074 |
+
if output_attentions:
|
| 1075 |
+
all_self_attns += (layer_outputs[1],)
|
| 1076 |
+
|
| 1077 |
+
hidden_states = self.norm(hidden_states)
|
| 1078 |
+
|
| 1079 |
+
# add hidden states from the last decoder layer
|
| 1080 |
+
if output_hidden_states:
|
| 1081 |
+
all_hidden_states += (hidden_states,)
|
| 1082 |
+
|
| 1083 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 1084 |
+
if return_legacy_cache:
|
| 1085 |
+
next_cache = next_cache.to_legacy_cache()
|
| 1086 |
+
|
| 1087 |
+
if not return_dict:
|
| 1088 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 1089 |
+
return BaseModelOutputWithPast(
|
| 1090 |
+
last_hidden_state=hidden_states,
|
| 1091 |
+
past_key_values=next_cache,
|
| 1092 |
+
hidden_states=all_hidden_states,
|
| 1093 |
+
attentions=all_self_attns,
|
| 1094 |
+
)
|
| 1095 |
+
|
| 1096 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
| 1097 |
+
def _update_causal_mask(
|
| 1098 |
+
self,
|
| 1099 |
+
attention_mask: torch.Tensor,
|
| 1100 |
+
input_tensor: torch.Tensor,
|
| 1101 |
+
cache_position: torch.Tensor,
|
| 1102 |
+
past_key_values: Cache,
|
| 1103 |
+
output_attentions: bool,
|
| 1104 |
+
):
|
| 1105 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 1106 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
| 1107 |
+
return attention_mask
|
| 1108 |
+
return None
|
| 1109 |
+
|
| 1110 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 1111 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 1112 |
+
# to infer the attention mask.
|
| 1113 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1114 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 1115 |
+
|
| 1116 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 1117 |
+
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
| 1118 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 1119 |
+
attention_mask,
|
| 1120 |
+
inputs_embeds=input_tensor,
|
| 1121 |
+
past_key_values_length=past_seen_tokens,
|
| 1122 |
+
is_training=self.training,
|
| 1123 |
+
):
|
| 1124 |
+
return None
|
| 1125 |
+
|
| 1126 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
| 1127 |
+
min_dtype = torch.finfo(dtype).min
|
| 1128 |
+
sequence_length = input_tensor.shape[1]
|
| 1129 |
+
if using_static_cache:
|
| 1130 |
+
target_length = past_key_values.get_max_length()
|
| 1131 |
+
else:
|
| 1132 |
+
target_length = (
|
| 1133 |
+
attention_mask.shape[-1]
|
| 1134 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 1135 |
+
else past_seen_tokens + sequence_length + 1
|
| 1136 |
+
)
|
| 1137 |
+
|
| 1138 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 1139 |
+
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
| 1140 |
+
attention_mask,
|
| 1141 |
+
sequence_length=sequence_length,
|
| 1142 |
+
target_length=target_length,
|
| 1143 |
+
dtype=dtype,
|
| 1144 |
+
device=device,
|
| 1145 |
+
min_dtype=min_dtype,
|
| 1146 |
+
cache_position=cache_position,
|
| 1147 |
+
batch_size=input_tensor.shape[0],
|
| 1148 |
+
)
|
| 1149 |
+
|
| 1150 |
+
if (
|
| 1151 |
+
self.config._attn_implementation == "sdpa"
|
| 1152 |
+
and attention_mask is not None
|
| 1153 |
+
and attention_mask.device.type == "cuda"
|
| 1154 |
+
and not output_attentions
|
| 1155 |
+
):
|
| 1156 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 1157 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 1158 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 1159 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 1160 |
+
|
| 1161 |
+
return causal_mask
|
| 1162 |
+
|
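# Illustrative sketch (not part of the original commit): how a 2D padding mask becomes the
# 4D additive mask that `_update_causal_mask` above returns for the eager/SDPA paths. Masked
# positions hold the dtype's most negative value so they vanish after softmax; visible
# positions hold 0. Shapes and values here are toy assumptions.
import torch

def toy_causal_mask(attention_mask_2d: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    bsz, seq_len = attention_mask_2d.shape
    min_dtype = torch.finfo(dtype).min
    causal = torch.triu(torch.full((seq_len, seq_len), min_dtype, dtype=dtype), diagonal=1)
    mask = causal[None, None, :, :].expand(bsz, 1, seq_len, seq_len)   # (bsz, 1, q_len, kv_len)
    padding = attention_mask_2d[:, None, None, :] == 0                 # broadcast over the query dim
    return mask.masked_fill(padding, min_dtype)

# toy_causal_mask(torch.tensor([[0, 1, 1, 1]]))  # left-padded row: key column 0 is fully masked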
| 1163 |
+
|
| 1164 |
+
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
|
| 1165 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1166 |
+
|
| 1167 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
|
| 1168 |
+
def __init__(self, config):
|
| 1169 |
+
super().__init__(config)
|
| 1170 |
+
self.model = Phi3Model(config)
|
| 1171 |
+
self.vocab_size = config.vocab_size
|
| 1172 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1173 |
+
|
| 1174 |
+
# Initialize weights and apply final processing
|
| 1175 |
+
self.post_init()
|
| 1176 |
+
|
| 1177 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
|
| 1178 |
+
def get_input_embeddings(self):
|
| 1179 |
+
return self.model.embed_tokens
|
| 1180 |
+
|
| 1181 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
|
| 1182 |
+
def set_input_embeddings(self, value):
|
| 1183 |
+
self.model.embed_tokens = value
|
| 1184 |
+
|
| 1185 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
|
| 1186 |
+
def get_output_embeddings(self):
|
| 1187 |
+
return self.lm_head
|
| 1188 |
+
|
| 1189 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
|
| 1190 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1191 |
+
self.lm_head = new_embeddings
|
| 1192 |
+
|
| 1193 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
|
| 1194 |
+
def set_decoder(self, decoder):
|
| 1195 |
+
self.model = decoder
|
| 1196 |
+
|
| 1197 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
|
| 1198 |
+
def get_decoder(self):
|
| 1199 |
+
return self.model
|
| 1200 |
+
|
| 1201 |
+
# Ignore copy
|
| 1202 |
+
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
|
| 1203 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 1204 |
+
def forward(
|
| 1205 |
+
self,
|
| 1206 |
+
input_ids: torch.LongTensor = None,
|
| 1207 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1208 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1209 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1210 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1211 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1212 |
+
use_cache: Optional[bool] = None,
|
| 1213 |
+
output_attentions: Optional[bool] = None,
|
| 1214 |
+
output_hidden_states: Optional[bool] = None,
|
| 1215 |
+
return_dict: Optional[bool] = None,
|
| 1216 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1217 |
+
num_logits_to_keep: int = 0,
|
| 1218 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1219 |
+
r"""
|
| 1220 |
+
Args:
|
| 1221 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1222 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1223 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1224 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1225 |
+
|
| 1226 |
+
num_logits_to_keep (`int`, *optional*):
|
| 1227 |
+
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
| 1228 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 1229 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 1230 |
+
|
| 1231 |
+
Returns:
|
| 1232 |
+
|
| 1233 |
+
Example:
|
| 1234 |
+
|
| 1235 |
+
```python
|
| 1236 |
+
>>> from transformers import AutoTokenizer, Phi3ForCausalLM
|
| 1237 |
+
|
| 1238 |
+
>>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
|
| 1239 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
|
| 1240 |
+
|
| 1241 |
+
>>> prompt = "This is an example script ."
|
| 1242 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1243 |
+
|
| 1244 |
+
>>> # Generate
|
| 1245 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1246 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1247 |
+
'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
|
| 1248 |
+
```"""
|
| 1249 |
+
if (
|
| 1250 |
+
use_cache
|
| 1251 |
+
and self.config.rope_scaling
|
| 1252 |
+
and cache_position is not None
|
| 1253 |
+
and cache_position[0] == self.config.original_max_position_embeddings
|
| 1254 |
+
):
|
| 1255 |
+
logger.warning(
|
| 1256 |
+
f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
|
| 1257 |
+
)
|
| 1258 |
+
|
| 1259 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1260 |
+
output_hidden_states = (
|
| 1261 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1262 |
+
)
|
| 1263 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1264 |
+
|
| 1265 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1266 |
+
outputs = self.model(
|
| 1267 |
+
input_ids=input_ids,
|
| 1268 |
+
attention_mask=attention_mask,
|
| 1269 |
+
position_ids=position_ids,
|
| 1270 |
+
past_key_values=past_key_values,
|
| 1271 |
+
inputs_embeds=inputs_embeds,
|
| 1272 |
+
use_cache=use_cache,
|
| 1273 |
+
output_attentions=output_attentions,
|
| 1274 |
+
output_hidden_states=output_hidden_states,
|
| 1275 |
+
return_dict=return_dict,
|
| 1276 |
+
)
|
| 1277 |
+
|
| 1278 |
+
hidden_states = outputs[0]
|
| 1279 |
+
if labels is None and not is_torchdynamo_compiling():
|
| 1280 |
+
logger.warning_once(
|
| 1281 |
+
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
|
| 1282 |
+
)
|
| 1283 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1284 |
+
# TODO: remove the float() operation in v4.46
|
| 1285 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
|
| 1286 |
+
|
| 1287 |
+
loss = None
|
| 1288 |
+
if labels is not None:
|
| 1289 |
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
| 1290 |
+
logits = logits.float()
|
| 1291 |
+
# Shift so that tokens < n predict n
|
| 1292 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 1293 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 1294 |
+
# Flatten the tokens
|
| 1295 |
+
loss_fct = CrossEntropyLoss()
|
| 1296 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 1297 |
+
shift_labels = shift_labels.view(-1)
|
| 1298 |
+
# Enable model parallelism
|
| 1299 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 1300 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 1301 |
+
|
| 1302 |
+
if not return_dict:
|
| 1303 |
+
output = (logits,) + outputs[1:]
|
| 1304 |
+
return (loss,) + output if loss is not None else output
|
| 1305 |
+
|
| 1306 |
+
return CausalLMOutputWithPast(
|
| 1307 |
+
loss=loss,
|
| 1308 |
+
logits=logits,
|
| 1309 |
+
past_key_values=outputs.past_key_values,
|
| 1310 |
+
hidden_states=outputs.hidden_states,
|
| 1311 |
+
attentions=outputs.attentions,
|
| 1312 |
+
)
|
| 1313 |
+
|
| 1314 |
+
def prepare_inputs_for_generation(
|
| 1315 |
+
self,
|
| 1316 |
+
input_ids,
|
| 1317 |
+
past_key_values=None,
|
| 1318 |
+
attention_mask=None,
|
| 1319 |
+
inputs_embeds=None,
|
| 1320 |
+
cache_position=None,
|
| 1321 |
+
position_ids=None,
|
| 1322 |
+
use_cache=True,
|
| 1323 |
+
num_logits_to_keep=None,
|
| 1324 |
+
**kwargs,
|
| 1325 |
+
):
|
| 1326 |
+
# When the input length first crosses the long/short RoPE-factor switching point, force the KV cache to be recomputed.
|
| 1327 |
+
# This slows generation down at that single token position, but it is better than the current failure mode.
|
| 1328 |
+
if (
|
| 1329 |
+
past_key_values
|
| 1330 |
+
and self.config.rope_scaling
|
| 1331 |
+
and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
|
| 1332 |
+
):
|
| 1333 |
+
past_length = cache_position[0]
|
| 1334 |
+
if past_length <= self.config.original_max_position_embeddings:
|
| 1335 |
+
past_key_values = None
|
| 1336 |
+
|
| 1337 |
+
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
| 1338 |
+
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
| 1339 |
+
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
| 1340 |
+
if past_key_values is not None:
|
| 1341 |
+
if inputs_embeds is not None: # Exception 1
|
| 1342 |
+
input_ids = input_ids[:, -cache_position.shape[0] :]
|
| 1343 |
+
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
| 1344 |
+
input_ids = input_ids[:, cache_position]
|
| 1345 |
+
|
| 1346 |
+
if attention_mask is not None and position_ids is None:
|
| 1347 |
+
# create position_ids on the fly for batch generation
|
| 1348 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1349 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1350 |
+
if past_key_values:
|
| 1351 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
| 1352 |
+
|
| 1353 |
+
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
| 1354 |
+
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
| 1355 |
+
|
| 1356 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 1357 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
| 1358 |
+
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
| 1359 |
+
else:
|
| 1360 |
+
# The clone here is for the same reason as for `position_ids`.
|
| 1361 |
+
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
| 1362 |
+
|
| 1363 |
+
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
| 1364 |
+
if model_inputs["inputs_embeds"] is not None:
|
| 1365 |
+
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
| 1366 |
+
device = model_inputs["inputs_embeds"].device
|
| 1367 |
+
else:
|
| 1368 |
+
batch_size, sequence_length = model_inputs["input_ids"].shape
|
| 1369 |
+
device = model_inputs["input_ids"].device
|
| 1370 |
+
|
| 1371 |
+
dtype = self.lm_head.weight.dtype
|
| 1372 |
+
min_dtype = torch.finfo(dtype).min
|
| 1373 |
+
|
| 1374 |
+
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
| 1375 |
+
attention_mask,
|
| 1376 |
+
sequence_length=sequence_length,
|
| 1377 |
+
target_length=past_key_values.get_max_length(),
|
| 1378 |
+
dtype=dtype,
|
| 1379 |
+
device=device,
|
| 1380 |
+
min_dtype=min_dtype,
|
| 1381 |
+
cache_position=cache_position,
|
| 1382 |
+
batch_size=batch_size,
|
| 1383 |
+
)
|
| 1384 |
+
|
| 1385 |
+
if num_logits_to_keep is not None:
|
| 1386 |
+
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
| 1387 |
+
|
| 1388 |
+
model_inputs.update(
|
| 1389 |
+
{
|
| 1390 |
+
"position_ids": position_ids,
|
| 1391 |
+
"cache_position": cache_position,
|
| 1392 |
+
"past_key_values": past_key_values,
|
| 1393 |
+
"use_cache": use_cache,
|
| 1394 |
+
"attention_mask": attention_mask,
|
| 1395 |
+
}
|
| 1396 |
+
)
|
| 1397 |
+
return model_inputs
|
| 1398 |
+
|
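# Illustrative sketch (not part of the original commit): the on-the-fly position_ids built in
# prepare_inputs_for_generation above for batched generation with left padding. cumsum(-1) - 1
# numbers the real tokens 0, 1, 2, ... in each row, and the padded slots are overwritten with 1.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# row 0 -> [1, 1, 0, 1, 2]   (two padding slots, then positions 0..2)
# row 1 -> [0, 1, 2, 3, 4]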
| 1399 |
+
|
| 1400 |
+
@add_start_docstrings(
|
| 1401 |
+
"""
|
| 1402 |
+
The [`Phi3Model`] with a sequence classification head on top (linear layer).
|
| 1403 |
+
|
| 1404 |
+
[`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
| 1405 |
+
(e.g. GPT-2) do.
|
| 1406 |
+
|
| 1407 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
| 1408 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
| 1409 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
| 1410 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
| 1411 |
+
each row of the batch).
|
| 1412 |
+
""",
|
| 1413 |
+
PHI3_START_DOCSTRING,
|
| 1414 |
+
)
|
| 1415 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
|
| 1416 |
+
class Phi3ForSequenceClassification(Phi3PreTrainedModel):
|
| 1417 |
+
def __init__(self, config):
|
| 1418 |
+
super().__init__(config)
|
| 1419 |
+
self.num_labels = config.num_labels
|
| 1420 |
+
self.model = Phi3Model(config)
|
| 1421 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 1422 |
+
|
| 1423 |
+
# Initialize weights and apply final processing
|
| 1424 |
+
self.post_init()
|
| 1425 |
+
|
| 1426 |
+
def get_input_embeddings(self):
|
| 1427 |
+
return self.model.embed_tokens
|
| 1428 |
+
|
| 1429 |
+
def set_input_embeddings(self, value):
|
| 1430 |
+
self.model.embed_tokens = value
|
| 1431 |
+
|
| 1432 |
+
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
|
| 1433 |
+
def forward(
|
| 1434 |
+
self,
|
| 1435 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1436 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1437 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1438 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
| 1439 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1440 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1441 |
+
use_cache: Optional[bool] = None,
|
| 1442 |
+
output_attentions: Optional[bool] = None,
|
| 1443 |
+
output_hidden_states: Optional[bool] = None,
|
| 1444 |
+
return_dict: Optional[bool] = None,
|
| 1445 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
| 1446 |
+
r"""
|
| 1447 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1448 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1449 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1450 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1451 |
+
"""
|
| 1452 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1453 |
+
|
| 1454 |
+
model_outputs = self.model(
|
| 1455 |
+
input_ids,
|
| 1456 |
+
attention_mask=attention_mask,
|
| 1457 |
+
position_ids=position_ids,
|
| 1458 |
+
past_key_values=past_key_values,
|
| 1459 |
+
inputs_embeds=inputs_embeds,
|
| 1460 |
+
use_cache=use_cache,
|
| 1461 |
+
output_attentions=output_attentions,
|
| 1462 |
+
output_hidden_states=output_hidden_states,
|
| 1463 |
+
return_dict=return_dict,
|
| 1464 |
+
)
|
| 1465 |
+
hidden_states = model_outputs[0]
|
| 1466 |
+
logits = self.score(hidden_states)
|
| 1467 |
+
|
| 1468 |
+
if input_ids is not None:
|
| 1469 |
+
batch_size = input_ids.shape[0]
|
| 1470 |
+
else:
|
| 1471 |
+
batch_size = inputs_embeds.shape[0]
|
| 1472 |
+
|
| 1473 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
| 1474 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
| 1475 |
+
if self.config.pad_token_id is None:
|
| 1476 |
+
sequence_lengths = -1
|
| 1477 |
+
else:
|
| 1478 |
+
if input_ids is not None:
|
| 1479 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
| 1480 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
| 1481 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
| 1482 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
| 1483 |
+
else:
|
| 1484 |
+
sequence_lengths = -1
|
| 1485 |
+
|
| 1486 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
| 1487 |
+
|
| 1488 |
+
loss = None
|
| 1489 |
+
if labels is not None:
|
| 1490 |
+
labels = labels.to(logits.device)
|
| 1491 |
+
if self.config.problem_type is None:
|
| 1492 |
+
if self.num_labels == 1:
|
| 1493 |
+
self.config.problem_type = "regression"
|
| 1494 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 1495 |
+
self.config.problem_type = "single_label_classification"
|
| 1496 |
+
else:
|
| 1497 |
+
self.config.problem_type = "multi_label_classification"
|
| 1498 |
+
|
| 1499 |
+
if self.config.problem_type == "regression":
|
| 1500 |
+
loss_fct = MSELoss()
|
| 1501 |
+
if self.num_labels == 1:
|
| 1502 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
| 1503 |
+
else:
|
| 1504 |
+
loss = loss_fct(pooled_logits, labels)
|
| 1505 |
+
elif self.config.problem_type == "single_label_classification":
|
| 1506 |
+
loss_fct = CrossEntropyLoss()
|
| 1507 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
| 1508 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 1509 |
+
loss_fct = BCEWithLogitsLoss()
|
| 1510 |
+
loss = loss_fct(pooled_logits, labels)
|
| 1511 |
+
if not return_dict:
|
| 1512 |
+
output = (pooled_logits,) + model_outputs[1:]
|
| 1513 |
+
return ((loss,) + output) if loss is not None else output
|
| 1514 |
+
|
| 1515 |
+
return SequenceClassifierOutputWithPast(
|
| 1516 |
+
loss=loss,
|
| 1517 |
+
logits=pooled_logits,
|
| 1518 |
+
past_key_values=model_outputs.past_key_values,
|
| 1519 |
+
hidden_states=model_outputs.hidden_states,
|
| 1520 |
+
attentions=model_outputs.attentions,
|
| 1521 |
+
)
|
| 1522 |
+
|
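# Illustrative sketch (not part of the original commit): how Phi3ForSequenceClassification below
# locates the last non-padding token of each row with argmax and a modulo instead of reverse
# indexing (the ONNX-friendly trick noted in the code). Values here are toy assumptions.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],      # right-padded: last real token at index 2
                          [5, 6, 7, 8, 9]])     # no padding: argmax of all-False is 0, (0 - 1) % 5 = 4
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
# tensor([2, 4]) -> index of the hidden state that gets pooled into the classification logits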
| 1523 |
+
|
| 1524 |
+
@add_start_docstrings(
|
| 1525 |
+
"""
|
| 1526 |
+
[`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
| 1527 |
+
Named-Entity-Recognition (NER) tasks.
|
| 1528 |
+
""",
|
| 1529 |
+
PHI3_START_DOCSTRING,
|
| 1530 |
+
)
|
| 1531 |
+
# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
|
| 1532 |
+
class Phi3ForTokenClassification(Phi3PreTrainedModel):
|
| 1533 |
+
def __init__(self, config: Phi3Config):
|
| 1534 |
+
super().__init__(config)
|
| 1535 |
+
self.num_labels = config.num_labels
|
| 1536 |
+
|
| 1537 |
+
self.model = Phi3Model(config)
|
| 1538 |
+
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
|
| 1539 |
+
classifier_dropout = config.classifier_dropout
|
| 1540 |
+
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
|
| 1541 |
+
classifier_dropout = config.hidden_dropout
|
| 1542 |
+
else:
|
| 1543 |
+
classifier_dropout = 0.1
|
| 1544 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
| 1545 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1546 |
+
|
| 1547 |
+
# Initialize weights and apply final processing
|
| 1548 |
+
self.post_init()
|
| 1549 |
+
|
| 1550 |
+
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
|
| 1551 |
+
@add_code_sample_docstrings(
|
| 1552 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1553 |
+
output_type=TokenClassifierOutput,
|
| 1554 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1555 |
+
)
|
| 1556 |
+
def forward(
|
| 1557 |
+
self,
|
| 1558 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1559 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
| 1560 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1561 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1562 |
+
labels: Optional[torch.Tensor] = None,
|
| 1563 |
+
use_cache: Optional[bool] = None,
|
| 1564 |
+
output_attentions: Optional[bool] = None,
|
| 1565 |
+
output_hidden_states: Optional[bool] = None,
|
| 1566 |
+
return_dict: Optional[bool] = None,
|
| 1567 |
+
**deprecated_arguments,
|
| 1568 |
+
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
| 1569 |
+
r"""
|
| 1570 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1571 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1572 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1573 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1574 |
+
"""
|
| 1575 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1576 |
+
|
| 1577 |
+
model_outputs = self.model(
|
| 1578 |
+
input_ids,
|
| 1579 |
+
past_key_values=past_key_values,
|
| 1580 |
+
attention_mask=attention_mask,
|
| 1581 |
+
inputs_embeds=inputs_embeds,
|
| 1582 |
+
use_cache=use_cache,
|
| 1583 |
+
output_attentions=output_attentions,
|
| 1584 |
+
output_hidden_states=output_hidden_states,
|
| 1585 |
+
return_dict=return_dict,
|
| 1586 |
+
)
|
| 1587 |
+
|
| 1588 |
+
hidden_states = model_outputs[0]
|
| 1589 |
+
hidden_states = self.dropout(hidden_states)
|
| 1590 |
+
logits = self.classifier(hidden_states)
|
| 1591 |
+
|
| 1592 |
+
loss = None
|
| 1593 |
+
if labels is not None:
|
| 1594 |
+
# move labels to correct device to enable model parallelism
|
| 1595 |
+
labels = labels.to(logits.device)
|
| 1596 |
+
batch_size, seq_length = labels.shape
|
| 1597 |
+
loss_fct = CrossEntropyLoss()
|
| 1598 |
+
loss = loss_fct(
|
| 1599 |
+
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
|
| 1600 |
+
)
|
| 1601 |
+
|
| 1602 |
+
if not return_dict:
|
| 1603 |
+
output = (logits,) + model_outputs[2:]
|
| 1604 |
+
return ((loss,) + output) if loss is not None else output
|
| 1605 |
+
|
| 1606 |
+
return TokenClassifierOutput(
|
| 1607 |
+
loss=loss,
|
| 1608 |
+
logits=logits,
|
| 1609 |
+
hidden_states=model_outputs.hidden_states,
|
| 1610 |
+
attentions=model_outputs.attentions,
|
| 1611 |
+
)
|
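# Illustrative sketch (not part of the original commit): the memory effect of `num_logits_to_keep`
# in Phi3ForCausalLM.forward above. `-0:` slices the whole sequence, which is the documented
# special case; keeping a single position is enough for greedy or sampled decoding.
import torch

batch, seq_len, hidden, vocab = 1, 4096, 3072, 32064
hidden_states = torch.zeros(batch, seq_len, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)
full_logits = lm_head(hidden_states[:, -0:, :])   # (1, 4096, 32064), ~0.5 GB of fp32 logits
last_logits = lm_head(hidden_states[:, -1:, :])   # (1, 1, 32064)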
src/image_decoder/processor.py
ADDED
|
@@ -0,0 +1,221 @@
|
| 1 |
+
import os
|
| 2 |
+
from typing import Dict, List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
from huggingface_hub import snapshot_download
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def crop_arr(pil_image, max_image_size):
|
| 13 |
+
while min(*pil_image.size) >= 2 * max_image_size:
|
| 14 |
+
pil_image = pil_image.resize(
|
| 15 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
if max(*pil_image.size) > max_image_size:
|
| 19 |
+
scale = max_image_size / max(*pil_image.size)
|
| 20 |
+
pil_image = pil_image.resize(
|
| 21 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
if min(*pil_image.size) < 16:
|
| 25 |
+
scale = 16 / min(*pil_image.size)
|
| 26 |
+
pil_image = pil_image.resize(
|
| 27 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
arr = np.array(pil_image)
|
| 31 |
+
crop_y1 = (arr.shape[0] % 16) // 2
|
| 32 |
+
crop_y2 = arr.shape[0] % 16 - crop_y1
|
| 33 |
+
|
| 34 |
+
crop_x1 = (arr.shape[1] % 16) // 2
|
| 35 |
+
crop_x2 = arr.shape[1] % 16 - crop_x1
|
| 36 |
+
|
| 37 |
+
arr = arr[crop_y1:arr.shape[0] - crop_y2, crop_x1:arr.shape[1] - crop_x2]
|
| 38 |
+
return Image.fromarray(arr)
|
| 39 |
+
|
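# Illustrative sketch (not part of the original commit): crop_arr above returns an image whose
# longer side is at most max_image_size and whose height and width are both multiples of 16,
# matching the latent patching assumed elsewhere in this decoder.
from PIL import Image

example = Image.new("RGB", (1500, 997))
resized = crop_arr(example, 1024)
# 1500x997 is first scaled to 1024x681, then cropped (roughly centered) to 1024x672 (both sides % 16 == 0)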
| 40 |
+
|
| 41 |
+
class OmniGenProcessor:
|
| 42 |
+
def __init__(self, max_image_size: int = 1024):
|
| 43 |
+
self.max_image_size = max_image_size
|
| 44 |
+
|
| 45 |
+
self.image_transform = transforms.Compose([
|
| 46 |
+
transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
|
| 47 |
+
transforms.ToTensor(),
|
| 48 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
| 49 |
+
])
|
| 50 |
+
|
| 51 |
+
self.collator = OmniGenCollator()
|
| 52 |
+
self.separate_collator = OmniGenSeparateCollator()
|
| 53 |
+
|
| 54 |
+
@classmethod
|
| 55 |
+
def from_pretrained(cls, model_name):
|
| 56 |
+
if not os.path.exists(model_name):
|
| 57 |
+
cache_folder = os.getenv('HF_HUB_CACHE')
|
| 58 |
+
model_name = snapshot_download(repo_id=model_name,
|
| 59 |
+
cache_dir=cache_folder,
|
| 60 |
+
allow_patterns="*.json")
|
| 61 |
+
text_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 62 |
+
|
| 63 |
+
return cls(text_tokenizer)
|
| 64 |
+
|
| 65 |
+
def process_image(self, image):
|
| 66 |
+
image = Image.open(image).convert('RGB')
|
| 67 |
+
return self.image_transform(image)
|
| 68 |
+
|
| 69 |
+
def __call__(self,
|
| 70 |
+
context_hidden_state: List[torch.tensor],
|
| 71 |
+
neg_context_hidden_state: List[torch.tensor],
|
| 72 |
+
height: int = 1024,
|
| 73 |
+
width: int = 1024,
|
| 74 |
+
separate_cfg_input: bool = False,
|
| 75 |
+
) -> Dict:
|
| 76 |
+
|
| 77 |
+
input_data = []
|
| 78 |
+
for i in range(len(context_hidden_state)):
|
| 79 |
+
cur_context_hidden_state = context_hidden_state[i]
|
| 80 |
+
cur_neg_context_hidden_state = neg_context_hidden_state[i]
|
| 81 |
+
|
| 82 |
+
input_data.append((cur_context_hidden_state, cur_neg_context_hidden_state, [height, width]))
|
| 83 |
+
|
| 84 |
+
if separate_cfg_input:
|
| 85 |
+
return self.separate_collator(input_data)
|
| 86 |
+
return self.collator(input_data)
|
| 87 |
+
|
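# Illustrative sketch (not part of the original commit): how a caller is expected to feed the
# processor with MLLM hidden states for the positive and the negative (CFG) prompt. The tensor
# shapes and token counts are hypothetical.
import torch

processor = OmniGenProcessor(max_image_size=1024)
context = [torch.randn(1, 77, 3072)]         # hidden states of the conditioning prompt
neg_context = [torch.randn(1, 12, 3072)]     # hidden states of the negative prompt
data = processor(context, neg_context, height=1024, width=1024, separate_cfg_input=True)
# data keys: context_hidden_state, connector_attention_mask, connector_position_ids,
#            llm_attention_mask, llm_position_ids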
| 88 |
+
|
| 89 |
+
class OmniGenCollator:
|
| 90 |
+
def __init__(self, pad_token_id=2, llm_pad_token_id=151643, hidden_size=3072):
|
| 91 |
+
self.llm_pad_token_id = llm_pad_token_id
|
| 92 |
+
self.pad_token_id = pad_token_id
|
| 93 |
+
self.hidden_size = hidden_size
|
| 94 |
+
|
| 95 |
+
def create_position(self, attention_mask, num_tokens_for_output_images):
|
| 96 |
+
position_ids = []
|
| 97 |
+
text_length = attention_mask.size(-1)
|
| 98 |
+
img_length = max(num_tokens_for_output_images)
|
| 99 |
+
for mask in attention_mask:
|
| 100 |
+
temp_l = torch.sum(mask)
|
| 101 |
+
temp_position = [0] * (text_length - temp_l) + [i for i in range(temp_l + img_length + 1)] # we add a time embedding into the sequence, so add one more token
|
| 102 |
+
position_ids.append(temp_position)
|
| 103 |
+
return torch.LongTensor(position_ids)
|
| 104 |
+
|
| 105 |
+
def create_connector_position(self, llm_2d_attention_mask):
|
| 106 |
+
position_ids = []
|
| 107 |
+
text_length = llm_2d_attention_mask.size(-1)
|
| 108 |
+
# img_length = max(num_tokens_for_output_images)
|
| 109 |
+
for batch_idx, mask in enumerate(llm_2d_attention_mask):
|
| 110 |
+
temp_l = torch.sum(llm_2d_attention_mask[batch_idx])
|
| 111 |
+
# temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
|
| 112 |
+
temp_position = [0] * (text_length - temp_l) + [i for i in range(temp_l)]  # only the condition tokens for an MLLM such as Qwen
|
| 113 |
+
position_ids.append(temp_position)
|
| 114 |
+
return torch.LongTensor(position_ids)
|
| 115 |
+
|
| 116 |
+
def create_mask(self, attention_mask, num_tokens_for_output_images):
|
| 117 |
+
extended_mask = []
|
| 118 |
+
padding_images = []
|
| 119 |
+
text_length = attention_mask.size(-1)
|
| 120 |
+
img_length = max(num_tokens_for_output_images)
|
| 121 |
+
seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
|
| 122 |
+
inx = 0
|
| 123 |
+
for mask in attention_mask:
|
| 124 |
+
temp_l = torch.sum(mask)
|
| 125 |
+
pad_l = text_length - temp_l
|
| 126 |
+
|
| 127 |
+
temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1)))
|
| 128 |
+
|
| 129 |
+
image_mask = torch.zeros(size=(temp_l + 1, img_length))
|
| 130 |
+
temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
|
| 131 |
+
|
| 132 |
+
image_mask = torch.ones(size=(img_length, temp_l + img_length + 1))
|
| 133 |
+
temp_mask = torch.cat([temp_mask, image_mask], dim=0)
|
| 134 |
+
|
| 135 |
+
if pad_l > 0:
|
| 136 |
+
pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l))
|
| 137 |
+
temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
|
| 138 |
+
|
| 139 |
+
pad_mask = torch.ones(size=(pad_l, seq_len))
|
| 140 |
+
temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
|
| 141 |
+
|
| 142 |
+
true_img_length = num_tokens_for_output_images[inx]
|
| 143 |
+
pad_img_length = img_length - true_img_length
|
| 144 |
+
if pad_img_length > 0:
|
| 145 |
+
temp_mask[:, -pad_img_length:] = 0
|
| 146 |
+
temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
|
| 147 |
+
else:
|
| 148 |
+
temp_padding_imgs = None
|
| 149 |
+
|
| 150 |
+
extended_mask.append(temp_mask.unsqueeze(0))
|
| 151 |
+
padding_images.append(temp_padding_imgs)
|
| 152 |
+
inx += 1
|
| 153 |
+
return torch.cat(extended_mask, dim=0), padding_images
|
| 154 |
+
|
| 155 |
+
def adjust_attention_for_input_images(self, attention_mask, image_sizes):
|
| 156 |
+
for b_inx in image_sizes.keys():
|
| 157 |
+
for start_inx, end_inx in image_sizes[b_inx]:
|
| 158 |
+
attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
|
| 159 |
+
|
| 160 |
+
return attention_mask
|
| 161 |
+
|
| 162 |
+
def pad_input(self, context_hidden_state):
|
| 163 |
+
# pad_token_id = self.llm_pad_token_id # 151642 <|endoftext|> in qwen2.5vl
|
| 164 |
+
max_l = max([x.shape[1] for x in context_hidden_state])
|
| 165 |
+
attention_mask = []
|
| 166 |
+
|
| 167 |
+
for i in range(len(context_hidden_state)):
|
| 168 |
+
temp_hidden = context_hidden_state[i]
|
| 169 |
+
temp_l = temp_hidden.shape[1]
|
| 170 |
+
pad_l = max_l - temp_l
|
| 171 |
+
if pad_l == 0:
|
| 172 |
+
attention_mask.append([1] * max_l)
|
| 173 |
+
else:
|
| 174 |
+
attention_mask.append([0] * pad_l + [1] * temp_l)
|
| 175 |
+
|
| 176 |
+
return torch.LongTensor(attention_mask)
|
| 177 |
+
|
| 178 |
+
def process_mllm_input(self, context_hidden_state, target_img_size):
|
| 179 |
+
num_tokens_for_output_images = []
|
| 180 |
+
for img_size in target_img_size:
|
| 181 |
+
num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16)
|
| 182 |
+
|
| 183 |
+
llm_2d_attention_mask = self.pad_input(context_hidden_state)
|
| 184 |
+
connector_position_ids = self.create_connector_position(llm_2d_attention_mask)
|
| 185 |
+
llm_position_ids = self.create_position(llm_2d_attention_mask, num_tokens_for_output_images)
|
| 186 |
+
llm_attention_mask, _ = self.create_mask(llm_2d_attention_mask, num_tokens_for_output_images)
|
| 187 |
+
|
| 188 |
+
return llm_2d_attention_mask, connector_position_ids, llm_attention_mask, llm_position_ids
|
| 189 |
+
|
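# Illustrative sketch (not part of the original commit): the token bookkeeping that
# process_mllm_input above performs for a single 1024x1024 target image and a 77-token prompt.
height, width, text_len = 1024, 1024, 77
num_img_tokens = height * width // 16 // 16       # 4096 latent tokens
seq_len = text_len + num_img_tokens + 1           # 4174: +1 for the time-embedding token
# llm_attention_mask has shape (1, 4174, 4174): causal over the prompt and time tokens,
# full attention from the 4096 image tokens to the whole sequence.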
| 190 |
+
|
| 191 |
+
class OmniGenSeparateCollator(OmniGenCollator):
|
| 192 |
+
def __call__(self, features):
|
| 193 |
+
context_hidden_state = [f[0] for f in features]
|
| 194 |
+
neg_context_hidden_state = [f[1] for f in features]
|
| 195 |
+
target_img_size = [f[2] for f in features]
|
| 196 |
+
|
| 197 |
+
all_context_hidden_state, all_connector_attention_mask, all_connector_position_ids, all_llm_attention_mask, all_llm_position_ids = [], [], [], [], []
|
| 198 |
+
connector_attention_mask, connector_position_ids, llm_attention_mask, llm_position_ids = self.process_mllm_input(context_hidden_state, target_img_size)
|
| 199 |
+
|
| 200 |
+
all_context_hidden_state.append(context_hidden_state[0])
|
| 201 |
+
all_connector_attention_mask.append(connector_attention_mask)
|
| 202 |
+
all_connector_position_ids.append(connector_position_ids)
|
| 203 |
+
all_llm_attention_mask.append(llm_attention_mask)
|
| 204 |
+
all_llm_position_ids.append(llm_position_ids)
|
| 205 |
+
|
| 206 |
+
if neg_context_hidden_state[0] is not None:
|
| 207 |
+
connector_attention_mask, connector_position_ids, llm_attention_mask, llm_position_ids = self.process_mllm_input(neg_context_hidden_state, target_img_size)
|
| 208 |
+
all_context_hidden_state.append(neg_context_hidden_state[0])
|
| 209 |
+
all_connector_attention_mask.append(connector_attention_mask)
|
| 210 |
+
all_connector_position_ids.append(connector_position_ids)
|
| 211 |
+
all_llm_attention_mask.append(llm_attention_mask)
|
| 212 |
+
all_llm_position_ids.append(llm_position_ids)
|
| 213 |
+
|
| 214 |
+
data = {
|
| 215 |
+
"context_hidden_state": all_context_hidden_state,
|
| 216 |
+
"connector_attention_mask": all_connector_attention_mask,
|
| 217 |
+
"connector_position_ids": all_connector_position_ids,
|
| 218 |
+
"llm_attention_mask": all_llm_attention_mask,
|
| 219 |
+
"llm_position_ids": all_llm_position_ids,
|
| 220 |
+
}
|
| 221 |
+
return data
|
src/image_decoder/scheduler.py
ADDED
|
@@ -0,0 +1,194 @@
|
| 1 |
+
from tqdm import tqdm
|
| 2 |
+
from typing import Optional, Dict, Any, Tuple, List
|
| 3 |
+
import gc
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
try:
|
| 7 |
+
import torch_npu
|
| 8 |
+
except Exception as e:
|
| 9 |
+
print(e)
|
| 10 |
+
from transformers.cache_utils import DynamicCache
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class OmniGenCache(DynamicCache):
|
| 14 |
+
def __init__(self, num_tokens_for_img: int, offload_kv_cache: bool = False) -> None:
|
| 15 |
+
# if not torch.cuda.is_available():
|
| 16 |
+
# # print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
|
| 17 |
+
# # offload_kv_cache = False
|
| 18 |
+
# raise RuntimeError("OffloadedCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!")
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.original_device = []
|
| 21 |
+
self.prefetch_stream = torch.cuda.Stream() if torch.cuda.is_available() else torch_npu.npu.Stream()
|
| 22 |
+
self.num_tokens_for_img = num_tokens_for_img
|
| 23 |
+
self.offload_kv_cache = offload_kv_cache
|
| 24 |
+
|
| 25 |
+
def prefetch_layer(self, layer_idx: int):
|
| 26 |
+
"Starts prefetching the next layer cache"
|
| 27 |
+
if layer_idx < len(self):
|
| 28 |
+
if torch.cuda.is_available():
|
| 29 |
+
with torch.cuda.stream(self.prefetch_stream):
|
| 30 |
+
# Prefetch next layer tensors to GPU
|
| 31 |
+
device = self.original_device[layer_idx]
|
| 32 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
|
| 33 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
|
| 34 |
+
else:
|
| 35 |
+
with torch_npu.npu.stream(self.prefetch_stream):
|
| 36 |
+
# Prefetch next layer tensors to GPU
|
| 37 |
+
device = self.original_device[layer_idx]
|
| 38 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
|
| 39 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
|
| 40 |
+
|
| 41 |
+
def evict_previous_layer(self, layer_idx: int):
|
| 42 |
+
"Moves the previous layer cache to the CPU"
|
| 43 |
+
if len(self) > 2:
|
| 44 |
+
# We do it on the default stream so it occurs after all earlier computations on these tensors are done
|
| 45 |
+
if layer_idx == 0:
|
| 46 |
+
prev_layer_idx = -1
|
| 47 |
+
else:
|
| 48 |
+
prev_layer_idx = (layer_idx - 1) % len(self)
|
| 49 |
+
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
|
| 50 |
+
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
|
| 51 |
+
|
| 52 |
+
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
| 53 |
+
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
|
| 54 |
+
if layer_idx < len(self):
|
| 55 |
+
if self.offload_kv_cache:
|
| 56 |
+
# Evict the previous layer if necessary
|
| 57 |
+
if torch.cuda.is_available():
|
| 58 |
+
torch.cuda.current_stream().synchronize()
|
| 59 |
+
else:
|
| 60 |
+
torch_npu.npu.current_stream().synchronize()
|
| 61 |
+
self.evict_previous_layer(layer_idx)
|
| 62 |
+
# Load current layer cache to its original device if not already there
|
| 63 |
+
# self.prefetch_stream.synchronize(original_device)
|
| 64 |
+
if torch.cuda.is_available():
|
| 65 |
+
torch.cuda.synchronize(self.prefetch_stream)
|
| 66 |
+
else:
|
| 67 |
+
torch_npu.npu.synchronize(self.prefetch_stream)
|
| 68 |
+
key_tensor = self.key_cache[layer_idx]
|
| 69 |
+
value_tensor = self.value_cache[layer_idx]
|
| 70 |
+
|
| 71 |
+
# Prefetch the next layer
|
| 72 |
+
self.prefetch_layer((layer_idx + 1) % len(self))
|
| 73 |
+
else:
|
| 74 |
+
key_tensor = self.key_cache[layer_idx]
|
| 75 |
+
value_tensor = self.value_cache[layer_idx]
|
| 76 |
+
return (key_tensor, value_tensor)
|
| 77 |
+
else:
|
| 78 |
+
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
| 79 |
+
|
| 80 |
+
def update(
|
| 81 |
+
self,
|
| 82 |
+
key_states: torch.Tensor,
|
| 83 |
+
value_states: torch.Tensor,
|
| 84 |
+
layer_idx: int,
|
| 85 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
| 86 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 87 |
+
"""
|
| 88 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
| 89 |
+
Parameters:
|
| 90 |
+
key_states (`torch.Tensor`):
|
| 91 |
+
The new key states to cache.
|
| 92 |
+
value_states (`torch.Tensor`):
|
| 93 |
+
The new value states to cache.
|
| 94 |
+
layer_idx (`int`):
|
| 95 |
+
The index of the layer to cache the states for.
|
| 96 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
| 97 |
+
Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
|
| 98 |
+
Return:
|
| 99 |
+
A tuple containing the updated key and value states.
|
| 100 |
+
"""
|
| 101 |
+
# Update the cache
|
| 102 |
+
if len(self.key_cache) < layer_idx:
|
| 103 |
+
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
|
| 104 |
+
elif len(self.key_cache) == layer_idx:
|
| 105 |
+
# only cache the states for condition tokens
|
| 106 |
+
key_states = key_states[..., :-(self.num_tokens_for_img + 1), :]
|
| 107 |
+
value_states = value_states[..., :-(self.num_tokens_for_img + 1), :]
|
| 108 |
+
|
| 109 |
+
# Update the number of seen tokens
|
| 110 |
+
if layer_idx == 0:
|
| 111 |
+
self._seen_tokens += key_states.shape[-2]
|
| 112 |
+
|
| 113 |
+
self.key_cache.append(key_states)
|
| 114 |
+
self.value_cache.append(value_states)
|
| 115 |
+
self.original_device.append(key_states.device)
|
| 116 |
+
if self.offload_kv_cache:
|
| 117 |
+
self.evict_previous_layer(layer_idx)
|
| 118 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
| 119 |
+
else:
|
| 120 |
+
# only cache the states for condition tokens
|
| 121 |
+
key_tensor, value_tensor = self[layer_idx]
|
| 122 |
+
k = torch.cat([key_tensor, key_states], dim=-2)
|
| 123 |
+
v = torch.cat([value_tensor, value_states], dim=-2)
|
| 124 |
+
return k, v
|
| 125 |
+
|
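# Illustrative sketch (not part of the original commit): what OmniGenCache.update above keeps.
# Only the condition tokens are cached after the first denoising step; the image tokens and the
# time-embedding token are recomputed at every step and concatenated on the fly.
num_condition_tokens, num_tokens_for_img = 77, 4096
full_seq = num_condition_tokens + num_tokens_for_img + 1       # 4174 tokens seen per step
cached = full_seq - (num_tokens_for_img + 1)                   # 77 tokens survive in the cache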
| 126 |
+
|
| 127 |
+
class OmniGenScheduler:
|
| 128 |
+
def __init__(self, num_steps: int = 50, time_shifting_factor: int = 1):
|
| 129 |
+
self.num_steps = num_steps
|
| 130 |
+
self.time_shift = time_shifting_factor
|
| 131 |
+
|
| 132 |
+
t = torch.linspace(0, 1, num_steps + 1)
|
| 133 |
+
t = t / (t + time_shifting_factor - time_shifting_factor * t)
|
| 134 |
+
self.sigma = t
|
| 135 |
+
|
| 136 |
+
def crop_kv_cache(self, past_key_values, num_tokens_for_img):
|
| 137 |
+
# return
|
| 138 |
+
crop_past_key_values = ()
|
| 139 |
+
for layer_idx in range(len(past_key_values)):
|
| 140 |
+
key_states, value_states = past_key_values[layer_idx][:2]
|
| 141 |
+
crop_past_key_values += ((key_states[..., :-(num_tokens_for_img + 1), :], value_states[..., :-(num_tokens_for_img + 1), :], ),)
|
| 142 |
+
# return crop_past_key_values
|
| 143 |
+
return DynamicCache.from_legacy_cache(crop_past_key_values)
|
| 144 |
+
|
| 145 |
+
def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
|
| 146 |
+
if isinstance(position_ids, list):
|
| 147 |
+
for i in range(len(position_ids)):
|
| 148 |
+
position_ids[i] = position_ids[i][:, -(num_tokens_for_img + 1):]
|
| 149 |
+
else:
|
| 150 |
+
position_ids = position_ids[:, -(num_tokens_for_img + 1):]
|
| 151 |
+
return position_ids
|
| 152 |
+
|
| 153 |
+
def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
|
| 154 |
+
if isinstance(attention_mask, list):
|
| 155 |
+
return [x[..., -(num_tokens_for_img + 1):, :] for x in attention_mask]
|
| 156 |
+
return attention_mask[..., -(num_tokens_for_img + 1):, :]
|
| 157 |
+
|
| 158 |
+
def crop_cache(self, cache, num_tokens_for_img):
|
| 159 |
+
for i in range(len(cache.key_cache)):
|
| 160 |
+
cache.key_cache[i] = cache.key_cache[i][..., :-(num_tokens_for_img + 1), :]
|
| 161 |
+
cache.value_cache[i] = cache.value_cache[i][..., :-(num_tokens_for_img + 1), :]
|
| 162 |
+
|
| 163 |
+
return cache
|
| 164 |
+
|
| 165 |
+
def __call__(self, z, func, model_kwargs, use_kv_cache: bool = True, offload_kv_cache: bool = True, tqdm_disable: bool = False):
|
| 166 |
+
|
| 167 |
+
num_tokens_for_img = z.size(-1) * z.size(-2) // 4
|
| 168 |
+
if isinstance(model_kwargs['llm_input_embeds'], list):
|
| 169 |
+
cache = [OmniGenCache(num_tokens_for_img, offload_kv_cache) for _ in range(len(model_kwargs['llm_input_embeds']))] if use_kv_cache else None
|
| 170 |
+
else:
|
| 171 |
+
cache = OmniGenCache(num_tokens_for_img, offload_kv_cache) if use_kv_cache else None
|
| 172 |
+
for i in tqdm(range(self.num_steps), disable=tqdm_disable):
|
| 173 |
+
timesteps = torch.zeros(size=(len(z), )).to(z.device) + self.sigma[i]
|
| 174 |
+
pred, cache = func(z, timesteps, past_key_values=cache, **model_kwargs)
|
| 175 |
+
sigma_next = self.sigma[i + 1]
|
| 176 |
+
sigma = self.sigma[i]
|
| 177 |
+
z = z + (sigma_next - sigma) * pred
|
| 178 |
+
if i == 0 and use_kv_cache:
|
| 179 |
+
num_tokens_for_img = z.size(-1) * z.size(-2) // 4
|
| 180 |
+
if isinstance(cache, list):
|
| 181 |
+
model_kwargs['llm_input_embeds'] = [None] * len(cache)
|
| 182 |
+
else:
|
| 183 |
+
model_kwargs['llm_input_embeds'] = None
|
| 184 |
+
|
| 185 |
+
model_kwargs['llm_position_ids'] = self.crop_position_ids_for_cache(model_kwargs['llm_position_ids'], num_tokens_for_img)
|
| 186 |
+
model_kwargs['llm_attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['llm_attention_mask'], num_tokens_for_img)
|
| 187 |
+
|
| 188 |
+
del cache
|
| 189 |
+
if torch.cuda.is_available():
|
| 190 |
+
torch.cuda.empty_cache()
|
| 191 |
+
else:
|
| 192 |
+
torch_npu.npu.empty_cache()
|
| 193 |
+
gc.collect()
|
| 194 |
+
return z
|
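# Illustrative sketch (not part of the original commit): the time-shifted sigma schedule and the
# Euler update that OmniGenScheduler.__call__ above applies at every step.
import torch

num_steps, shift = 50, 1                 # shift == 1 leaves the schedule uniform on [0, 1]
t = torch.linspace(0, 1, num_steps + 1)
sigma = t / (t + shift - shift * t)      # a larger shift concentrates steps at small sigma
# one step of the loop: z = z + (sigma[i + 1] - sigma[i]) * pred, with pred = func(z, sigma[i], ...)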
src/image_decoder/transformer.py
ADDED
|
@@ -0,0 +1,179 @@
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 6 |
+
from .modeling_phi3 import Phi3Model
|
| 7 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 8 |
+
from transformers.utils import logging
|
| 9 |
+
|
| 10 |
+
logger = logging.get_logger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Phi3Transformer(Phi3Model):
|
| 14 |
+
"""
|
| 15 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
|
| 16 |
+
We only modify how the attention mask is handled.
|
| 17 |
+
Args:
|
| 18 |
+
config: Phi3Config
|
| 19 |
+
"""
|
| 20 |
+
def prefetch_layer(self, layer_idx: int, device: torch.device):
|
| 21 |
+
"Starts prefetching the next layer cache"
|
| 22 |
+
with torch.cuda.stream(self.prefetch_stream):
|
| 23 |
+
# Prefetch next layer tensors to GPU
|
| 24 |
+
for name, param in self.layers[layer_idx].named_parameters():
|
| 25 |
+
param.data = param.data.to(device, non_blocking=True)
|
| 26 |
+
|
| 27 |
+
def evict_previous_layer(self, layer_idx: int):
|
| 28 |
+
"Moves the previous layer cache to the CPU"
|
| 29 |
+
prev_layer_idx = layer_idx - 1
|
| 30 |
+
for name, param in self.layers[prev_layer_idx].named_parameters():
|
| 31 |
+
param.data = param.data.to("cpu", non_blocking=True)
|
| 32 |
+
|
| 33 |
+
def get_offlaod_layer(self, layer_idx: int, device: torch.device):
|
| 34 |
+
# init stream
|
| 35 |
+
if not hasattr(self, "prefetch_stream"):
|
| 36 |
+
self.prefetch_stream = torch.cuda.Stream()
|
| 37 |
+
|
| 38 |
+
# delete previous layer
|
| 39 |
+
torch.cuda.current_stream().synchronize()
|
| 40 |
+
self.evict_previous_layer(layer_idx)
|
| 41 |
+
|
| 42 |
+
# make sure the current layer is ready
|
| 43 |
+
torch.cuda.synchronize(self.prefetch_stream)
|
| 44 |
+
|
| 45 |
+
# load next layer
|
| 46 |
+
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
|
| 47 |
+
|
| 48 |
+
def forward(
|
| 49 |
+
self,
|
| 50 |
+
input_ids: torch.LongTensor = None,
|
| 51 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 52 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 53 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 54 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 55 |
+
use_cache: Optional[bool] = None,
|
| 56 |
+
output_attentions: Optional[bool] = None,
|
| 57 |
+
output_hidden_states: Optional[bool] = None,
|
| 58 |
+
return_dict: Optional[bool] = None,
|
| 59 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 60 |
+
offload_model: Optional[bool] = False,
|
| 61 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 62 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 63 |
+
output_hidden_states = (
|
| 64 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 65 |
+
)
|
| 66 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 67 |
+
|
| 68 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 69 |
+
|
| 70 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 71 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 72 |
+
|
| 73 |
+
if self.gradient_checkpointing and self.training:
|
| 74 |
+
if use_cache:
|
| 75 |
+
logger.warning_once(
|
| 76 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 77 |
+
)
|
| 78 |
+
use_cache = False
|
| 79 |
+
|
| 80 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
| 81 |
+
return_legacy_cache = False
|
| 82 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
| 83 |
+
return_legacy_cache = True
|
| 84 |
+
if past_key_values is None:
|
| 85 |
+
past_key_values = DynamicCache()
|
| 86 |
+
else:
|
| 87 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 88 |
+
logger.warning_once(
|
| 89 |
+
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
| 90 |
+
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
| 91 |
+
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# if inputs_embeds is None:
|
| 95 |
+
# inputs_embeds = self.embed_tokens(input_ids)
|
| 96 |
+
|
| 97 |
+
# if cache_position is None:
|
| 98 |
+
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 99 |
+
# cache_position = torch.arange(
|
| 100 |
+
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 101 |
+
# )
|
| 102 |
+
# if position_ids is None:
|
| 103 |
+
# position_ids = cache_position.unsqueeze(0)
|
| 104 |
+
|
| 105 |
+
if attention_mask is not None and attention_mask.dim() == 3:
|
| 106 |
+
dtype = inputs_embeds.dtype
|
| 107 |
+
min_dtype = torch.finfo(dtype).min
|
| 108 |
+
attention_mask = (1 - attention_mask) * min_dtype
|
| 109 |
+
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
|
| 110 |
+
else:
|
| 111 |
+
raise Exception("attention_mask parameter was unavailable or invalid")
|
| 112 |
+
# causal_mask = self._update_causal_mask(
|
| 113 |
+
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 114 |
+
# )
|
| 115 |
+
|
| 116 |
+
hidden_states = inputs_embeds
|
| 117 |
+
|
| 118 |
+
# decoder layers
|
| 119 |
+
all_hidden_states = () if output_hidden_states else None
|
| 120 |
+
all_self_attns = () if output_attentions else None
|
| 121 |
+
next_decoder_cache = None
|
| 122 |
+
|
| 123 |
+
layer_idx = -1
|
| 124 |
+
for decoder_layer in self.layers:
|
| 125 |
+
layer_idx += 1
|
| 126 |
+
|
| 127 |
+
if output_hidden_states:
|
| 128 |
+
all_hidden_states += (hidden_states,)
|
| 129 |
+
|
| 130 |
+
if self.gradient_checkpointing and self.training:
|
| 131 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 132 |
+
decoder_layer.__call__,
|
| 133 |
+
hidden_states,
|
| 134 |
+
attention_mask,
|
| 135 |
+
position_ids,
|
| 136 |
+
past_key_values,
|
| 137 |
+
output_attentions,
|
| 138 |
+
use_cache,
|
| 139 |
+
cache_position,
|
| 140 |
+
)
|
| 141 |
+
else:
|
| 142 |
+
if offload_model and not self.training:
|
| 143 |
+
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
|
| 144 |
+
layer_outputs = decoder_layer(
|
| 145 |
+
hidden_states,
|
| 146 |
+
attention_mask=attention_mask,
|
| 147 |
+
position_ids=position_ids,
|
| 148 |
+
past_key_value=past_key_values,
|
| 149 |
+
output_attentions=output_attentions,
|
| 150 |
+
use_cache=use_cache,
|
| 151 |
+
cache_position=cache_position,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
hidden_states = layer_outputs[0]
|
| 155 |
+
|
| 156 |
+
if use_cache:
|
| 157 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 158 |
+
|
| 159 |
+
if output_attentions:
|
| 160 |
+
all_self_attns += (layer_outputs[1],)
|
| 161 |
+
|
| 162 |
+
hidden_states = self.norm(hidden_states)
|
| 163 |
+
|
| 164 |
+
# add hidden states from the last decoder layer
|
| 165 |
+
if output_hidden_states:
|
| 166 |
+
all_hidden_states += (hidden_states,)
|
| 167 |
+
|
| 168 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 169 |
+
if return_legacy_cache:
|
| 170 |
+
next_cache = next_cache.to_legacy_cache()
|
| 171 |
+
|
| 172 |
+
if not return_dict:
|
| 173 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 174 |
+
return BaseModelOutputWithPast(
|
| 175 |
+
last_hidden_state=hidden_states,
|
| 176 |
+
past_key_values=next_cache,
|
| 177 |
+
hidden_states=all_hidden_states,
|
| 178 |
+
attentions=all_self_attns,
|
| 179 |
+
)
|
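A note on the mask convention used by `Phi3Transformer.forward`: it requires a 3-D `(batch, seq, seq)` attention mask of 0/1 values and converts it into an additive bias with a broadcastable head dimension before it reaches the decoder layers. A minimal sketch of just that conversion, with illustrative shapes:

    import torch

    batch, seq = 2, 4                      # illustrative sizes
    dtype = torch.bfloat16

    # 1 = attend, 0 = masked, as supplied by the caller.
    mask = torch.tril(torch.ones(batch, seq, seq))

    # Same transform as in Phi3Transformer.forward: masked positions become the most
    # negative representable value; unsqueeze(1) adds a head axis so the bias broadcasts
    # to (batch, num_heads, seq, seq) inside attention.
    min_dtype = torch.finfo(dtype).min
    bias = ((1 - mask) * min_dtype).unsqueeze(1).to(dtype)
    print(bias.shape)  # torch.Size([2, 1, 4, 4])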
src/mindomni.py
ADDED
@@ -0,0 +1,219 @@
+from .mllm import MindOmniMLLM
+from .image_decoder import OmniGen
+import torch.nn as nn
+from .image_decoder import Phi3DecoderLayer, ImageDecoderPipeline, OmniGenProcessor
+import os
+import torch
+from safetensors.torch import load_file
+from typing import Union
+from diffusers.utils import logging
+from diffusers.models import AutoencoderKL
+from transformers import AutoProcessor
+import re
+from qwen_vl_utils import process_vision_info
+try:
+    # torch_npu (Ascend NPU support) is optional; CUDA/MPS/CPU setups run without it.
+    import torch_npu
+except Exception as e:
+    print(e)
+
+logger = logging.get_logger(__name__)
+
+
+class MindOmniConnector(nn.Module):
+    # Bridges the MLLM and the image decoder: a linear projection from the MLLM hidden size
+    # to the decoder hidden size, followed by `layer_num` Phi3 decoder layers.
+    def __init__(self, pre_config, post_config, layer_num: int = 2):
+        super().__init__()
+        connector_decoder = nn.ModuleList(
+            [Phi3DecoderLayer(post_config, layer_idx) for layer_idx in range(layer_num)]
+        )
+        self.connector = nn.ModuleList(
+            [nn.Linear(pre_config.hidden_size, post_config.hidden_size)]  # qwen2.5vl-7b: 3584
+        )
+        self.connector.extend(connector_decoder)
+
+
+class MindOmni:
+    def __init__(self, mllm, image_decoder, connector, vae, processor, mllm_processor, device: Union[str, torch.device] = None):
+        self.mllm = mllm
+        self.image_decoder = image_decoder
+        self.connector = connector
+        self.vae = vae
+        self.processor = processor
+        self.mllm_processor = mllm_processor
+
+        self.vae.to(torch.float32)
+        self.device = device
+        if device is None:
+            if torch.cuda.is_available():
+                self.device = torch.device("cuda")
+            elif torch_npu.npu.is_available():
+                self.device = torch.device("npu")
+            elif torch.backends.mps.is_available():
+                self.device = torch.device("mps")
+            else:
+                logger.info("No available GPU detected; falling back to CPU. Image generation may be very slow.")
+                self.device = torch.device("cpu")
+
+    @classmethod
+    def from_pretrained(cls, model_path):
+        mllm = MindOmniMLLM.from_pretrained(os.path.join(model_path, 'mllm'))
+        image_decoder = OmniGen.from_pretrained(os.path.join(model_path, 'image_decoder'))
+        connector = MindOmniConnector(mllm.config, image_decoder.llm.config, 2).connector
+        connector_state = load_file(os.path.join(model_path, 'connector.safetensors'))
+        connector.load_state_dict(connector_state)
+        vae = AutoencoderKL.from_pretrained(os.path.join(model_path, "vae"))
+        processor = OmniGenProcessor.from_pretrained(os.path.join(model_path, 'image_decoder'))
+        mllm_processor = AutoProcessor.from_pretrained(os.path.join(model_path, 'mllm'))
+        logger.info("Preparing MindOmni")
+        return cls(mllm, image_decoder, connector, vae, processor, mllm_processor)
+
+    def to(self, device: Union[str, torch.device] = None, dtype: Union[str, torch.dtype] = None):
+        if device is not None:
+            if isinstance(device, str):
+                device = torch.device(device)
+            self.mllm.to(device)
+            self.image_decoder.to(device)
+            self.connector.to(device)
+            self.vae.to(device)
+            self.device = device
+        if dtype is not None:
+            self.mllm.to(dtype)
+            self.image_decoder.to(dtype)
+            self.connector.to(dtype)
+
+    def eval(self):
+        self.mllm.eval()
+        self.image_decoder.eval()
+        self.connector.eval()
+        self.vae.eval()
+
+    @torch.no_grad()
+    def get_mllm_hidden_state(self, user_input, input_images, do_sample, temperature, max_new_tokens, only_understand=False, use_cot=False):
+        """Run the Qwen2.5-VL MLLM on the (optionally multimodal) prompt and return the chat messages,
+        the prompt text that was used, and the last-layer hidden states that condition the image decoder."""
+        input_llm_images = input_images
+        processor = self.mllm_processor
+        model = self.mllm
+        if only_understand or not use_cot:
+            system_prompt = (
+                "You are a helpful assistant."
+            )
+        else:
+            system_prompt = (
+                "You are a helpful assistant. When the user requests an image, the assistant "
+                "first thinks about the reasoning process in the mind and then provides the user with concise prompt as the answer. "
+                "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
+                "<think> reasoning process here </think><answer> answer here </answer>."
+            )
+
+        messages = [
+            {
+                "role": "system",
+                "content": [
+                    {"type": "text", "text": system_prompt},
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Generate an image according to the following instructions\n"},
+                    {"type": "text", "text": user_input},
+                ],
+            }
+        ]
+
+        if input_llm_images is not None:
+            if only_understand:
+                assert len(input_llm_images) == 1, "only a single image is supported for multimodal understanding"
+                messages[1]['content'][0] = {"type": "image", "image": input_llm_images[0]}
+            else:
+                user_input = f'<img><|image_1|></img> {user_input}'
+                messages[1]['content'][1] = {"type": "text", "text": user_input}
+                image_tags = re.findall(r'<\|image_\d+\|>', messages[1]['content'][1]['text'])
+                image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
+                pattern = r"<img><\|image_\d+\|></img>"
+                prompt_chunks = [chunk for chunk in re.split(pattern, messages[1]['content'][1]['text'])]
+                assert len(prompt_chunks) == len(input_llm_images) + 1
+                new_content = []
+                for idx, per_prompt in enumerate(prompt_chunks):
+                    if idx != len(prompt_chunks) - 1:
+                        item_text = {"type": "text", "text": per_prompt}
+                        # resized_height, resized_width = input_images_shape[image_ids[idx] - 1]
+                        image_path = input_llm_images[image_ids[idx] - 1]
+                        # item_vit = {"type": "image", "image": image_path, "resized_height": resized_height, "resized_width": resized_width}
+                        item_vit = {"type": "image", "image": image_path}
+                        item_tag = {"type": "text", "text": f"<img>{image_tags[idx]}</img>"}
+                        new_content.append(item_text)
+                        new_content.append(item_vit)
+                        new_content.append(item_tag)
+                    else:
+                        item_text = {"type": "text", "text": per_prompt}
+                        new_content.append(item_text)
+                messages[1]['content'] = messages[1]['content'][:1] + new_content
+
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        image_inputs, video_inputs = process_vision_info(messages)
+        inputs = processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to(self.device)  # move inputs to the active accelerator
+
+        if use_cot:
+            # Inference: Generation of the output
+            temperature = temperature if do_sample else None
+            generated_dict = model.generate(**inputs, do_sample=do_sample, temperature=temperature, max_new_tokens=max_new_tokens, output_hidden_states=True, return_dict_in_generate=True)
+            generated_ids_trimmed = [
+                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_dict.sequences)
+            ]
+            output_hidden_state = [hidden_state[-1] for hidden_state in generated_dict.hidden_states]
+            context_hidden_state = torch.cat(output_hidden_state, dim=1)
+
+            output_text = processor.batch_decode(
+                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+            )
+
+            prompt_ = output_text[0]
+
+            assistant_content = [
+                {
+                    "role": "assistant",
+                    "content": [
+                        {"type": "text", "text": prompt_},
+                    ],
+                }
+            ]
+
+            messages += assistant_content
+        else:
+            prompt_ = user_input
+            context_hidden_state = model(**inputs, output_hidden_states=True).hidden_states[-1]
+        return messages, prompt_, context_hidden_state
+
+    def generate_image(self, height, width, guidance_scale, inference_steps, separate_cfg_infer, offload_model, seed, max_input_image_size,
+                       text, NEGATIVE_PROMPT, input_llm_images, do_sample, temperature, max_new_tokens, only_understand, use_cot=False):
+        gen_pipe = ImageDecoderPipeline(self.vae, self.image_decoder, self.connector, self.processor)
+        # Classifier-free guidance: run the MLLM once on the prompt and once on the negative prompt.
+        message, prompt_, context_hidden_state = self.get_mllm_hidden_state(text, input_llm_images, do_sample, temperature, max_new_tokens, only_understand, use_cot=use_cot)
+        neg_message, neg_prompt_, neg_context_hidden_state = self.get_mllm_hidden_state(NEGATIVE_PROMPT, None, do_sample, temperature, max_new_tokens, only_understand, use_cot=False)
+        print(message)
+        output = gen_pipe(
+            context_hidden_state=context_hidden_state,
+            neg_context_hidden_state=neg_context_hidden_state,
+            height=height,
+            width=width,
+            guidance_scale=guidance_scale,
+            num_inference_steps=inference_steps,
+            separate_cfg_infer=separate_cfg_infer,
+            use_kv_cache=True,
+            offload_kv_cache=True,
+            offload_model=offload_model,
+            seed=seed,
+            max_input_image_size=max_input_image_size,
+        )
+        return output, prompt_
+
+    def generate_text(self, text, input_llm_images, do_sample, temperature, max_new_tokens, only_understand):
+        _, answer, _ = self.get_mllm_hidden_state(text, input_llm_images, do_sample, temperature, max_new_tokens, only_understand=True, use_cot=True)
+        return answer
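For orientation, a hedged usage sketch of the class above, assuming the repository root is on `PYTHONPATH`; the checkpoint path and all generation settings are placeholders, while the method names and keyword arguments come from the code:

    import torch

    from src.mindomni import MindOmni

    # Hypothetical checkpoint directory containing mllm/, image_decoder/, vae/ and connector.safetensors.
    model = MindOmni.from_pretrained("checkpoints/MindOmni")
    model.to(device="cuda", dtype=torch.bfloat16)
    model.eval()

    # Text-to-image with chain-of-thought prompting enabled (use_cot=True).
    images, cot_text = model.generate_image(
        height=1024, width=1024,
        guidance_scale=3.0, inference_steps=50,
        separate_cfg_infer=True, offload_model=False,
        seed=0, max_input_image_size=1024,
        text="A tabby cat reading a newspaper in a sunlit cafe",
        NEGATIVE_PROMPT="low quality, blurry",
        input_llm_images=None,
        do_sample=False, temperature=1.0, max_new_tokens=512,
        only_understand=False, use_cot=True,
    )
    print(cot_text)  # the MLLM's <think>...</think><answer>...</answer> output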
src/mllm.py
ADDED
@@ -0,0 +1,245 @@
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel
+from typing import List, Optional, Tuple, Union
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import logger
+from transformers.cache_utils import DynamicCache
+
+
+class MindOmniMLLM_Model(Qwen2_5_VLModel):
+    # Qwen2_5_VLModel with `forward` overridden so that the last entry of `hidden_states`
+    # is the final decoder layer's output *before* the closing RMSNorm.
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # torch.jit.trace() doesn't support cache objects in the output
+        if use_cache and past_key_values is None and not torch.jit.is_tracing():
+            past_key_values = DynamicCache()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        # the hard coded `3` is for temporal, height and width.
+        if position_ids is None:
+            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+        elif position_ids.dim() == 2:
+            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    causal_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        # add hidden states from the last decoder layer before the self.norm
+        # import ipdb; ipdb.set_trace()
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        hidden_states = self.norm(hidden_states)
+
+        next_cache = next_decoder_cache if use_cache else None
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class MindOmniMLLM(Qwen2_5_VLForConditionalGeneration):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = MindOmniMLLM_Model(config)
+
+    # Commented-out reference implementation of a manual decoding loop (kept for reference):
+    # @staticmethod
+    # def _update_model_kwargs_for_generation(
+    #     outputs, model_kwargs, past_key_values_field="past_key_values"
+    # ):
+    #     if past_key_values_field in outputs:
+    #         model_kwargs[past_key_values_field] = outputs[past_key_values_field]
+
+    #     if "attention_mask" in model_kwargs:
+    #         bs, _ = model_kwargs["attention_mask"].shape
+    #         new_mask = torch.ones(bs, 1, dtype=model_kwargs["attention_mask"].dtype,
+    #                               device=model_kwargs["attention_mask"].device)
+    #         model_kwargs["attention_mask"] = torch.cat(
+    #             [model_kwargs["attention_mask"], new_mask], dim=-1
+    #         )
+    #     return model_kwargs
+
+    # @staticmethod
+    # def _sample_token(
+    #     logits: torch.Tensor,
+    #     do_sample: bool,
+    #     logits_processors: LogitsProcessorList,
+    #     temperature: float,
+    #     top_p: float,
+    # ):
+    #     """do sample / greedy"""
+    #     logits = logits_processors(None, logits)
+    #     if do_sample:
+    #         # temperature scaling
+    #         if temperature != 1.0 and temperature > 0:
+    #             logits = logits / temperature
+    #         # nucleus
+    #         if top_p < 1.0:
+    #             logits = TopPLogitsWarper(top_p=top_p)(None, logits)
+    #         probs = nn.functional.softmax(logits, dim=-1, dtype=torch.float32)
+    #         next_token = torch.multinomial(probs, num_samples=1)
+    #     else:  # greedy
+    #         next_token = torch.argmax(logits, dim=-1, keepdim=True)
+    #     return next_token
+
+    # @torch.no_grad()
+    # def generate(
+    #     self,
+    #     pixel_values: Optional[torch.FloatTensor] = None,
+    #     input_ids: Optional[torch.LongTensor] = None,
+    #     attention_mask: Optional[torch.LongTensor] = None,
+    #     max_new_tokens: int = 64,
+    #     do_sample: bool = False,
+    #     temperature: float = 1.0,
+    #     top_p: float = 0.95,
+    #     device: Union[str, torch.device] = "cuda",
+    # ) -> torch.LongTensor:
+
+    #     assert input_ids is not None
+    #     eos_token_id = self.config.eos_token_id
+
+    #     generated = [input_ids]
+
+    #     input_ids = input_ids.to(device)
+    #     if pixel_values is not None:
+    #         pixel_values = pixel_values.to(device)
+    #     if attention_mask is None:
+    #         attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+
+    #     logits_processors = LogitsProcessorList()
+    #     if temperature != 1.0 and do_sample:
+    #         logits_processors.append(TemperatureLogitsWarper(temperature))
+    #     if top_p < 1.0 and do_sample:
+    #         logits_processors.append(TopPLogitsWarper(top_p=top_p))
+
+    #     # ---- inference loop ---- #
+    #     model_kwargs = {
+    #         "attention_mask": attention_mask,
+    #         "use_cache": True,
+    #         "past_key_values": None,
+    #         "cache_position": torch.arange(attention_mask.shape[-1]).to(attention_mask)
+    #     }
+
+    #     for _ in range(max_new_tokens):
+    #         model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+    #         outputs = self(
+    #             input_ids=input_ids,
+    #             use_cache=True,
+    #             **model_kwargs,
+    #         )
+
+    #         next_token = self._sample_token(
+    #             outputs.logits[:, -1, :],
+    #             do_sample=do_sample,
+    #             logits_processors=logits_processors,
+    #             temperature=temperature,
+    #             top_p=top_p,
+    #         )  # (bs, 1)
+
+    #         # append the newly generated token
+    #         input_ids = next_token
+    #         generated.append(next_token)
+
+    #         # update kv cache / attention_mask
+    #         model_kwargs = self._update_model_kwargs_for_generation(
+    #             outputs, model_kwargs
+    #         )
+
+    #         # stop once every sequence in the batch has produced eos
+    #         if eos_token_id is not None:
+    #             if (next_token == eos_token_id).all():
+    #                 break
+
+    #     generated_ids = torch.cat(generated, dim=1)
+
+    #     return generated_ids
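The reason for overriding `Qwen2_5_VLModel.forward` is that `MindOmni.get_mllm_hidden_state` consumes the last decoder layer's pre-norm hidden states via `output_hidden_states=True`, both in a single forward pass and step by step during generation. A small sketch of how the per-step states are stitched back together (illustrative names; the same logic lives in `src/mindomni.py`):

    import torch

    def collect_last_layer_states(generated_dict) -> torch.Tensor:
        # generated_dict comes from model.generate(..., output_hidden_states=True,
        # return_dict_in_generate=True): hidden_states holds one tuple per decoding step,
        # and each tuple holds one tensor per layer.
        per_step = [step_states[-1] for step_states in generated_dict.hidden_states]
        # The first entry spans the whole prompt; every later entry covers the single
        # newly generated token, so concatenating along dim=1 yields
        # (batch, prompt_len + new_tokens, hidden_size).
        return torch.cat(per_step, dim=1)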