stevengrove committed
Commit 18d050b · verified · 1 Parent(s): 067ef85

initial commit
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/example_outputs/case_1.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_outputs/case_2.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_outputs/case_3.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_outputs/case_4.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_outputs/case_5.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_outputs/case_6.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_outputs/case_7.png filter=lfs diff=lfs merge=lfs -text
+ assets/framework.png filter=lfs diff=lfs merge=lfs -text
+ assets/grpo_curve.png filter=lfs diff=lfs merge=lfs -text
+ assets/inference.png filter=lfs diff=lfs merge=lfs -text
+ assets/reasoning_case_com.png filter=lfs diff=lfs merge=lfs -text
License.txt ADDED
@@ -0,0 +1,14 @@
+ Tencent is pleased to support the open source community by making MindOmni available.
+
+ Copyright (C) 2025 Tencent. All rights reserved.
+
+ MindOmni is licensed under the MIT License.
+
+
+ Terms of the MIT License:
+ --------------------------------------------------------------------
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
app.py CHANGED
@@ -1,7 +1,179 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import os
+ import argparse
+ from functools import partial
+
+ import torch
+ import random
+ import spaces
  import gradio as gr
+ from src import MindOmni
+
+ NEGATIVE_PROMPT = '''
+ low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.
+ '''
+
+
+ def parse_args():
+     args = argparse.ArgumentParser(description='MindOmni')
+     args.add_argument('--device', type=str, default='cuda')
+     args.add_argument('--dtype', type=str, default='bf16')
+     args.add_argument('--server_name', type=str, default='127.0.0.1')
+     args.add_argument('--port', type=int, default=8080)
+     args.add_argument('--model_path', type=str,
+                       default='your_path/MindOmni')
+     args = args.parse_args()
+     return args
+
+
+ def build_model(args):
+     device = args.device
+     MindOmni_model = MindOmni.from_pretrained(args.model_path)
+     if args.dtype == "bf16":
+         dtype = torch.bfloat16
+     MindOmni_model.to(device=device, dtype=dtype)
+     MindOmni_model.eval()
+     return MindOmni_model
+
+
+ @spaces.GPU(duration=180)
+ def understand_func(
+         MindOmni_model, text, do_sample, temperature,
+         max_new_tokens, input_llm_images):
+     if input_llm_images is not None and not isinstance(input_llm_images, list):
+         input_llm_images = [input_llm_images]
+     answer = MindOmni_model.generate_text(
+         text, input_llm_images, do_sample, temperature,
+         max_new_tokens, only_understand=True)
+     return answer
+
+
+ @spaces.GPU(duration=180)
+ def generate_func(
+         MindOmni_model, text, use_cot, height, width, guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model, max_input_image_size, randomize_seed, save_images, do_sample, temperature, max_new_tokens, input_llm_images, only_understand):
+     if input_llm_images is not None and not isinstance(input_llm_images, list):
+         input_llm_images = [input_llm_images]
+
+     if randomize_seed:
+         seed = random.randint(0, 10000000)
+
+     os.makedirs(os.path.dirname('/tmp/.unhold'), exist_ok=True)
+     with open('/tmp/.unhold', 'w') as f:
+         f.write('')
+     output, prompt_ = MindOmni_model.generate_image(
+         height, width, guidance_scale, inference_steps, separate_cfg_infer, offload_model, seed, max_input_image_size,
+         text, NEGATIVE_PROMPT, input_llm_images, do_sample, temperature, max_new_tokens, only_understand, use_cot=use_cot)
+     os.remove('/tmp/.unhold')
+
+     img = output[0]
+
+     if save_images:
+         # Save All Generated Images
+         from datetime import datetime
+         # Create outputs directory if it doesn't exist
+         os.makedirs('assets/outputs', exist_ok=True)
+         # Generate unique filename with timestamp
+         timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
+         output_path = os.path.join('assets/outputs', f'{timestamp}.png')
+         # Save the image
+         img.save(output_path)
+
+     return img, prompt_, seed
+
+
+ def build_gradio(args, MindOmni_model):
+     with gr.Blocks() as demo:
+         gr.Markdown("## 🪄 MindOmni Demo")
+
+         with gr.Tabs():
+             # ---------- GENERATE ----------
+             with gr.TabItem("🎨 Generate"):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         g_prompt = gr.Textbox(label="Text prompt")
+                         g_image = gr.Image(label="Condition image (optional)", type="filepath")
+                         g_btn = gr.Button("🚀 Generate Image")
+
+                         with gr.Accordion("📚 Image Generation Args"):
+                             g_use_cot = gr.Checkbox(label="With thinking", value=False)
+                             g_do_sample = gr.Checkbox(label="Do sample", value=False)
+                             g_temperature = gr.Slider(0, 10, value=1, label="Temperature")
+                             g_max_new_tok = gr.Slider(32, 8192, value=512, label="Max new tokens")
+
+                             g_height = gr.Slider(128, 2048, value=1024, step=16, label="Height")
+                             g_width = gr.Slider(128, 2048, value=1024, step=16, label="Width")
+                             g_scale = gr.Slider(1.0, 5.0, value=3.0, step=0.1, label="Guidance Scale")
+                             g_steps = gr.Slider(1, 100, value=50, label="Inference Steps")
+                             g_seed = gr.Slider(0, 2**31 - 1, value=42, label="Seed")
+                             g_rand = gr.Checkbox(label="Randomize seed", value=False)
+                             g_max_img = gr.Slider(128, 2048, value=1024, step=16,
+                                                   label="Max input image size")
+                             g_sep_cfg = gr.Checkbox(label="Separate-CFG infer", value=True)
+                             g_offload = gr.Checkbox(label="Offload model to CPU", value=False)
+                             g_save = gr.Checkbox(label="Save generated images", value=False)
+
+                     with gr.Column(scale=1):
+                         g_out_img = gr.Image(label="Generated Image")
+                         g_prompt_out = gr.Textbox(label="MindOmni CoT Content")
+                         g_seed_out = gr.Textbox(label="Used seed")
+
+                 with gr.Accordion("🖼️ Prompt Examples: Text-only"):
+                     gr.Examples(
+                         examples=[
+                             ["Futuristic city skyline at sunset, digital art", 42, False, False, False, 1024, 1024, "assets/example_outputs/case_1.png"],
+                             ["An image of multiple apples, the quantity of apples is the solution of '2x + 6 = 16'.", 1723284, False, True, False, 512, 1024, "assets/example_outputs/case_2.png"],
+                             ["A park with benches equal to the solution of 'x^2 -2x = 8'.", 4318852, False, True, False, 512, 512, "assets/example_outputs/case_3.png"],
+                             ["An image of China's national treasure animal.", 42, False, True, False, 1024, 1024, "assets/example_outputs/case_4.png"],
+                             ["Scene in the Sydney Opera House when New York is at noon.", 42, False, True, False, 1024, 1024, "assets/example_outputs/case_5.png"],
+                             ["Generate an image of an animal with (3 + 6) lives", 7393438, False, True, False, 1024, 1024, "assets/example_outputs/case_6.png"],
+                         ],
+                         inputs=[g_prompt, g_seed, g_rand, g_use_cot, g_do_sample, g_height, g_width, g_out_img],
+                     )
+                 with gr.Accordion("🖼️ Prompt Examples: With reference image"):
+                     gr.Examples(
+                         examples=[
+                             ["An image of the animal growing up", "assets/tapdole.jpeg", 42, False, True, True, 1024, 1024, "assets/example_outputs/case_7.png"]
+                         ],
+                         inputs=[g_prompt, g_image, g_seed, g_rand, g_use_cot, g_do_sample, g_height, g_width, g_out_img],
+                     )
+
+                 g_btn.click(
+                     partial(generate_func, MindOmni_model),
+                     inputs=[g_prompt, g_use_cot, g_height, g_width, g_scale, g_steps,
+                             g_seed, g_sep_cfg, g_offload, g_max_img, g_rand, g_save,
+                             g_do_sample, g_temperature, g_max_new_tok,
+                             g_image, gr.State(False)],  # only_understand=False
+                     outputs=[g_out_img, g_prompt_out, g_seed_out])
+
+             # ---------- UNDERSTAND ----------
+             with gr.TabItem("🧠 Understand"):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         u_prompt = gr.Textbox(label="Text prompt")
+                         u_image = gr.Image(label="Image (optional)", type="filepath")
+                         u_btn = gr.Button("🔍 Understand")
+                         with gr.Accordion("📚 Text Generation Args"):
+                             u_do_sample = gr.Checkbox(label="Do sample", value=False)
+                             u_temperature = gr.Slider(0, 10, value=1, label="Temperature")
+                             u_max_new_tok = gr.Slider(32, 8192, value=512, label="Max new tokens")
+
+                     with gr.Column(scale=1):
+                         u_answer = gr.Textbox(label="Answer", lines=8)
+
+                 u_btn.click(
+                     partial(understand_func, MindOmni_model),
+                     inputs=[u_prompt, u_do_sample,
+                             u_temperature, u_max_new_tok, u_image],
+                     outputs=u_answer)
+
+     demo.launch(server_name=args.server_name, server_port=args.port)
+
+
+ def main():
+     args = parse_args()
+     print(f'running args: {args}')
+     MindOmni_model = build_model(args)
+     build_gradio(args, MindOmni_model)


+ if __name__ == '__main__':
+     main()
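Editor's note (not part of the commit): the demo above is launched as `python app.py --model_path <local MindOmni checkpoint>`. The same generation path can also be driven programmatically, as in the sketch below; the positional argument order of `generate_image` is copied from the call inside `generate_func`, and the checkpoint path and prompt are placeholders.

    import torch
    from src import MindOmni

    model = MindOmni.from_pretrained('your_path/MindOmni')   # placeholder path, as in parse_args()
    model.to(device='cuda', dtype=torch.bfloat16)
    model.eval()

    # Same positional order as MindOmni_model.generate_image(...) in generate_func above.
    output, cot_text = model.generate_image(
        1024, 1024,             # height, width
        3.0, 50,                # guidance_scale, inference_steps
        True, False,            # separate_cfg_infer, offload_model
        42, 1024,               # seed, max_input_image_size
        'Futuristic city skyline at sunset, digital art',
        '',                     # negative prompt (app.py passes NEGATIVE_PROMPT here)
        None,                   # input_llm_images
        False, 1, 512,          # do_sample, temperature, max_new_tokens
        False,                  # only_understand
        use_cot=True)
    output[0].save('demo.png')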
assets/example_outputs/case_1.png ADDED

Git LFS Details

  • SHA256: fc5d5930caa9582f75f622e2d22fe7ff41ed2d8d324d3fb2a452cdf6fe4b3d7d
  • Pointer size: 131 Bytes
  • Size of remote file: 903 kB
assets/example_outputs/case_2.png ADDED

Git LFS Details

  • SHA256: 9a912bf42c39967cf473cf0a79bc9c5ceb5294f677a4c31abd51dcd644fc861b
  • Pointer size: 131 Bytes
  • Size of remote file: 643 kB
assets/example_outputs/case_3.png ADDED

Git LFS Details

  • SHA256: 5c9f373ff129c5553be29bcabef537390e6790c82011592494b57f2ccb64fd67
  • Pointer size: 131 Bytes
  • Size of remote file: 483 kB
assets/example_outputs/case_4.png ADDED

Git LFS Details

  • SHA256: 82f75c052fed92ffa698dcaf8669e675ece8160b6b7f6a40f54de8eaf95e00c9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
assets/example_outputs/case_5.png ADDED

Git LFS Details

  • SHA256: 3124064afc37df34f7e4544881ce56d2d7c70a52095f5eb4e1a79d5e3b68ebd3
  • Pointer size: 131 Bytes
  • Size of remote file: 851 kB
assets/example_outputs/case_6.png ADDED

Git LFS Details

  • SHA256: a7f5806c38c4b11b0cd1d4ecf4a540fa4e5f5a30a485f18574ef01cf1d29e9cd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
assets/example_outputs/case_7.png ADDED

Git LFS Details

  • SHA256: e8fdba59a52d503b50785fad20a68d968cb8085829ab0f645a61c8bf5842e89b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
assets/framework.png ADDED

Git LFS Details

  • SHA256: db7bd9d42517f5c5ca029caa2d0df470fe726deb49159c135200d0b94bc8af7e
  • Pointer size: 131 Bytes
  • Size of remote file: 628 kB
assets/grpo_curve.png ADDED

Git LFS Details

  • SHA256: 50f7a896152152c034bcbfc685c31593ce28d1f4f790f96a1c3b0d4bf5487303
  • Pointer size: 131 Bytes
  • Size of remote file: 192 kB
assets/inference.png ADDED

Git LFS Details

  • SHA256: d0d385f141ab67ef2297667d157a5ffbab15f6bfd86e1e7ad2363babd9e3ae61
  • Pointer size: 131 Bytes
  • Size of remote file: 846 kB
assets/reasoning_case_com.png ADDED

Git LFS Details

  • SHA256: 56e354f1ca5bac5c6b4aa2bc728474093bb028335cfbf1d0deacea94bec0c2f0
  • Pointer size: 132 Bytes
  • Size of remote file: 2.74 MB
assets/tapdole.jpeg ADDED
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ accelerate==1.7.0
+ datasets==2.20.0
+ decord==0.6.0
+ deepspeed==0.16.5
+ diffusers==0.30.3
+ gradio==4.44.1
+ gradio_client==1.3.0
+ huggingface-hub==0.32.0
+ numpy==1.26.3
+ omegaconf==2.3.0
+ pandas==2.2.3
+ pathvalidate==3.2.1
+ peft==0.13.2
+ qwen-vl-utils==0.0.8
+ safetensors==0.4.5
+ scipy==1.13.1
+ sympy==1.13.3
+ timm==0.9.16
+ tokenizers==0.21.1
+ torch==2.4.0
+ transformers==4.51.1
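For reference, a small editor's sketch (not in the repository) that checks an installed environment against the pins above using only the standard library:

    from importlib.metadata import version, PackageNotFoundError

    def check_requirements(path='requirements.txt'):
        # Compare installed package versions against the `name==version` pins in the file.
        for line in open(path):
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            name, _, pinned = line.partition('==')
            try:
                installed = version(name)
            except PackageNotFoundError:
                print(f'{name}: not installed (expected {pinned})')
                continue
            print(f'{name}=={pinned}: ' + ('ok' if installed == pinned else f'mismatch, installed {installed}'))

    check_requirements()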
src/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .image_decoder import *  # noqa
+ from .mllm import MindOmniMLLM, MindOmniMLLM_Model
+ from .mindomni import MindOmni
+
+ __all__ = ["MindOmniMLLM", "MindOmniMLLM_Model", "MindOmni"]
src/image_decoder/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .image_pipeline import ImageDecoderPipeline
+ from .model import OmniGen
+ from .modeling_phi3 import Phi3DecoderLayer
+ from .processor import OmniGenProcessor
+
+ __all__ = ["ImageDecoderPipeline", "OmniGen", "Phi3DecoderLayer", "OmniGenProcessor"]
src/image_decoder/image_pipeline.py ADDED
@@ -0,0 +1,273 @@
+ # This code is based on OmniGen
+ from typing import List, Union
+ import gc
+
+ from PIL import Image
+ import torch
+ try:
+     import torch_npu
+ except Exception as e:
+     print(e)
+ from diffusers.models import AutoencoderKL
+ from diffusers.utils import logging
+ import torch.nn as nn
+ from .processor import OmniGenProcessor
+ from .model import OmniGen
+ from .scheduler import OmniGenScheduler
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class ImageDecoderPipeline:
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         model: OmniGen,
+         connector: nn.Module,
+         processor: OmniGenProcessor,
+         device: Union[str, torch.device] = None,
+     ):
+         self.vae = vae
+         self.model = model
+         self.connector = connector
+         self.processor = processor
+         self.device = device
+
+         if device is None:
+             if torch.cuda.is_available():
+                 self.device = torch.device("cuda")
+             elif torch_npu.npu.is_available():
+                 self.device = torch.device("npu")
+             elif torch.backends.mps.is_available():
+                 self.device = torch.device("mps")
+             else:
+                 logger.info("No available GPU detected, using CPU instead; this may take a long time to generate an image!!!")
+                 self.device = torch.device("cpu")
+
+         # self.model.to(torch.bfloat16)
+         self.model.eval()
+         self.vae.eval()
+
+         self.model_cpu_offload = False
+
+     def to(self, device: Union[str, torch.device]):
+         if isinstance(device, str):
+             device = torch.device(device)
+         self.model.to(device)
+         self.vae.to(device)
+         self.device = device
+
+     def vae_encode(self, x, dtype):
+         if self.vae.config.shift_factor is not None:
+             x = self.vae.encode(x).latent_dist.sample()
+             x = (x - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+         else:
+             x = self.vae.encode(x).latent_dist.sample().mul_(self.vae.config.scaling_factor)
+         x = x.to(dtype)
+         return x
+
+     def move_to_device(self, data):
+         if isinstance(data, list):
+             return [x.to(self.device) for x in data]
+         return data.to(self.device)
+
+     def enable_model_cpu_offload(self):
+         self.model_cpu_offload = True
+         self.model.to("cpu")
+         self.vae.to("cpu")
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()  # Clear VRAM
+         elif torch_npu.npu.is_available():
+             torch_npu.npu.empty_cache()  # Clear VRAM
+         gc.collect()  # Run garbage collection to free system RAM
+
+     def disable_model_cpu_offload(self):
+         self.model_cpu_offload = False
+         self.model.to(self.device)
+         self.vae.to(self.device)
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         context_hidden_state: Union[str, List[str]] = None,
+         neg_context_hidden_state: Union[str, List[str]] = None,
+         height: int = 1024,
+         width: int = 1024,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 3,
+         max_input_image_size: int = 1024,
+         separate_cfg_infer: bool = True,
+         offload_model: bool = False,
+         use_kv_cache: bool = True,
+         offload_kv_cache: bool = True,
+         dtype: torch.dtype = torch.bfloat16,
+         seed: int = None,
+         output_type: str = "pil",
+         tqdm_disable: bool = False,
+     ):
+         r"""
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             prompt (`str` or `List[str]`):
+                 The prompt or prompts to guide the image generation.
+             input_images (`List[str]` or `List[List[str]]`, *optional*):
+                 The list of input images. We will replace the "<|image_i|>" in the prompt with the i-th image in the list.
+             height (`int`, *optional*, defaults to 1024):
+                 The height in pixels of the generated image. The number must be a multiple of 16.
+             width (`int`, *optional*, defaults to 1024):
+                 The width in pixels of the generated image. The number must be a multiple of 16.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 4.0):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             use_img_guidance (`bool`, *optional*, defaults to True):
+                 Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
+             img_guidance_scale (`float`, *optional*, defaults to 1.6):
+                 Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
+             max_input_image_size (`int`, *optional*, defaults to 1024): the maximum size of the input image, which will be used to crop the input image to the maximum size
+             separate_cfg_infer (`bool`, *optional*, defaults to False):
+                 Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
+             use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
+             offload_kv_cache (`bool`, *optional*, defaults to True): offload the cached key and value to cpu, which can save memory but slow down the generation slightly
+             offload_model (`bool`, *optional*, defaults to False): offload the model to cpu, which can save memory but slow down the generation
+             use_input_image_size_as_output (bool, defaults to False): whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task
+             seed (`int`, *optional*):
+                 A random seed for generating output.
+             dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
+                 data type for the model
+             output_type (`str`, *optional*, defaults to "pil"):
+                 The type of the output image, which can be "pt" or "pil"
+         Examples:
+
+         Returns:
+             A list with the generated images.
+         """
+         # check inputs:
+         assert height % 16 == 0 and width % 16 == 0, "The height and width must be a multiple of 16."
+         if context_hidden_state is not None and not isinstance(context_hidden_state, list):
+             context_hidden_state = [context_hidden_state]
+             neg_context_hidden_state = [neg_context_hidden_state]
+
+         # set model and processor
+         if max_input_image_size != self.processor.max_image_size:
+             self.processor = OmniGenProcessor(max_image_size=max_input_image_size)
+         self.model.to(dtype)
+         if offload_model:
+             self.enable_model_cpu_offload()
+         else:
+             self.disable_model_cpu_offload()
+
+         input_data = self.processor(context_hidden_state, neg_context_hidden_state, height=height, width=width, separate_cfg_input=separate_cfg_infer)
+
+         num_prompt = len(context_hidden_state)
+         num_cfg = 1
+         latent_size_h, latent_size_w = height // 8, width // 8
+
+         if seed is not None:
+             generator = torch.Generator(device=self.device).manual_seed(seed)
+         else:
+             generator = None
+         latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
+         latents = torch.cat([latents] * (1 + num_cfg), 0).to(dtype)
+
+         model_kwargs = dict(cfg_scale=guidance_scale,
+                             use_kv_cache=use_kv_cache,
+                             offload_model=offload_model,
+                             )
+         # obtain the qwen feature
+         # if self.llm_processor is not None:
+         llm_input_embeds = []
+         with torch.no_grad():
+             # for separate cfg infer mode
+             for i in range(len(input_data['context_hidden_state'])):
+
+                 context_hidden_state = input_data['context_hidden_state'][i]
+                 hidden_states = self.connector[0](context_hidden_state)
+                 cache_position = torch.arange(0, hidden_states.shape[1], device=hidden_states.device)
+
+                 mask_func = self.model.llm._update_causal_mask
+                 cond_causal_mask = mask_func(
+                     input_data['connector_attention_mask'][i].to(self.device), hidden_states, cache_position, None, None)
+                 for decoder_layer in self.connector[1:]:
+                     layer_out = decoder_layer(
+                         hidden_states,
+                         attention_mask=cond_causal_mask,
+                         position_ids=input_data['connector_position_ids'][i].to(self.device),
+                     )
+                     hidden_states = layer_out[0]
+
+                 llm_input_embeds.append(hidden_states)
+
+         # import ipdb; ipdb.set_trace()
+         model_kwargs['llm_input_embeds'] = llm_input_embeds
+         model_kwargs['llm_attention_mask'] = self.move_to_device(input_data['llm_attention_mask'])
+         model_kwargs['llm_position_ids'] = self.move_to_device(input_data['llm_position_ids'])
+
+         if separate_cfg_infer:
+             func = self.model.forward_with_separate_cfg
+         else:
+             func = self.model.forward_with_cfg
+
+         if self.model_cpu_offload:
+             for name, param in self.model.named_parameters():
+                 if 'layers' in name and 'layers.0' not in name:
+                     param.data = param.data.cpu()
+                 else:
+                     param.data = param.data.to(self.device)
+             for buffer_name, buffer in self.model.named_buffers():
+                 setattr(self.model, buffer_name, buffer.to(self.device))
+         # else:
+         #     self.model.to(self.device)
+
+         scheduler = OmniGenScheduler(num_steps=num_inference_steps)
+         samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache, tqdm_disable=tqdm_disable)
+         samples = samples.chunk((1 + num_cfg), dim=0)[0]
+
+         if self.model_cpu_offload:
+             self.model.to('cpu')
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()  # Clear VRAM
+             elif torch_npu.npu.is_available():
+                 torch_npu.npu.empty_cache()  # Clear VRAM
+             gc.collect()
+
+         self.vae.to(self.device)
+         samples = samples.to(torch.float32)
+         if self.vae.config.shift_factor is not None:
+             samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
+         else:
+             samples = samples / self.vae.config.scaling_factor
+         samples = self.vae.decode(samples).sample
+
+         if self.model_cpu_offload:
+             self.vae.to('cpu')
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()  # Clear VRAM
+             elif torch_npu.npu.is_available():
+                 torch_npu.npu.empty_cache()  # Clear VRAM
+             gc.collect()
+
+         samples = (samples * 0.5 + 0.5).clamp(0, 1)
+
+         if output_type == "pt":
+             output_images = samples
+         else:
+             output_samples = (samples * 255).to("cpu", dtype=torch.uint8)
+             output_samples = output_samples.permute(0, 2, 3, 1).numpy()
+             output_images = []
+             for i, sample in enumerate(output_samples):
+                 output_images.append(Image.fromarray(sample))
+
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()  # Clear VRAM
+         elif torch_npu.npu.is_available():
+             torch_npu.npu.empty_cache()  # Clear VRAM
+         gc.collect()  # Run garbage collection to free system RAM
+
+         return output_images
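To make the tensor bookkeeping in `__call__` easier to follow, here is a small self-contained editor's sketch (not part of the commit): latents are sampled at 1/8 of the output resolution, duplicated for the extra classifier-free-guidance branch, and the decoded samples are mapped from [-1, 1] to 8-bit PIL images. The `torch.tanh(torch.randn(...))` tensor is only a stand-in for the real `self.vae.decode(samples).sample`.

    import torch
    from PIL import Image

    height, width, num_prompt, num_cfg = 1024, 1024, 1, 1
    generator = torch.Generator().manual_seed(42)
    latents = torch.randn(num_prompt, 4, height // 8, width // 8, generator=generator)
    latents = torch.cat([latents] * (1 + num_cfg), 0)        # (2, 4, 128, 128): conditional + CFG branch
    # ... the OmniGenScheduler would iteratively denoise `latents` here ...
    samples = latents.chunk(1 + num_cfg, dim=0)[0]           # keep the conditional branch: (1, 4, 128, 128)
    decoded = torch.tanh(torch.randn(1, 3, height, width))   # stand-in for the VAE decode
    decoded = (decoded * 0.5 + 0.5).clamp(0, 1)              # [-1, 1] -> [0, 1]
    array = (decoded * 255).to(torch.uint8).permute(0, 2, 3, 1).numpy()
    img = Image.fromarray(array[0])                          # same post-processing as output_type="pil"
    print(img.size)                                          # (1024, 1024)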
src/image_decoder/model.py ADDED
@@ -0,0 +1,395 @@
+ # The code is revised from DiT
+ import os
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import math
+ from diffusers.loaders import PeftAdapterMixin
+ from huggingface_hub import snapshot_download
+ from safetensors.torch import load_file
+
+ from .transformer import Phi3Transformer
+ from transformers import Phi3Config
+
+
+ def modulate(x, shift, scale):
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+ class TimestepEmbedder(nn.Module):
+     """
+     Embeds scalar timesteps into vector representations.
+     """
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t, dim, max_period=10000):
+         """
+         Create sinusoidal timestep embeddings.
+         :param t: a 1-D Tensor of N indices, one per batch element.
+                   These may be fractional.
+         :param dim: the dimension of the output.
+         :param max_period: controls the minimum frequency of the embeddings.
+         :return: an (N, D) Tensor of positional embeddings.
+         """
+         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+         half = dim // 2
+         freqs = torch.exp(
+             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+         ).to(device=t.device)
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t, dtype=torch.float32):
+         t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+         t_emb = self.mlp(t_freq)
+         return t_emb
+
+
+ class FinalLayer(nn.Module):
+     """
+     The final layer of DiT.
+     """
+     def __init__(self, hidden_size, patch_size, out_channels):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+         )
+
+     def forward(self, x, c):
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+         x = modulate(self.norm_final(x), shift, scale)
+         x = self.linear(x)
+         return x
+
+
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
+     """
+     grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
+     [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     if isinstance(grid_size, int):
+         grid_size = (grid_size, grid_size)
+
+     grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+     grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.
+     omega = 1. / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ class PatchEmbedMR(nn.Module):
+     """ 2D Image to Patch Embedding
+     """
+     def __init__(
+         self,
+         patch_size: int = 2,
+         in_chans: int = 4,
+         embed_dim: int = 768,
+         bias: bool = True,
+     ):
+         super().__init__()
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+
+     def forward(self, x):
+         x = self.proj(x)
+         x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+         return x
+
+
+ class OmniGen(nn.Module, PeftAdapterMixin):
+     """
+     Diffusion model with a Transformer backbone.
+     """
+     def __init__(
+         self,
+         transformer_config: Phi3Config,
+         patch_size=2,
+         in_channels=4,
+         pe_interpolation: float = 1.0,
+         pos_embed_max_size: int = 192,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = in_channels
+         self.patch_size = patch_size
+         self.pos_embed_max_size = pos_embed_max_size
+
+         hidden_size = transformer_config.hidden_size
+
+         self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
+         self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
+
+         self.time_token = TimestepEmbedder(hidden_size)
+         self.t_embedder = TimestepEmbedder(hidden_size)
+
+         self.pe_interpolation = pe_interpolation
+         pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
+         self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
+
+         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+
+         self.initialize_weights()
+
+         self.llm = Phi3Transformer(config=transformer_config)
+         self.llm.config.use_cache = False
+
+     @classmethod
+     def from_pretrained(cls, model_name):
+         if not os.path.exists(model_name):
+             cache_folder = os.getenv('HF_HUB_CACHE')
+             model_name = snapshot_download(repo_id=model_name,
+                                            cache_dir=cache_folder,
+                                            ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+         config = Phi3Config.from_pretrained(model_name)
+         model = cls(config)
+         if os.path.exists(os.path.join(model_name, 'model.safetensors')):
+             print("Loading safetensors")
+             ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
+         else:
+             ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
+
+         module_keys = list(model.state_dict().keys())
+         pretrained_keys = list(ckpt.keys())
+         all_keys = module_keys + pretrained_keys
+         missing_modules = []
+         unexpected_modules = []
+         for item in all_keys:
+             if item in module_keys and item not in ckpt.keys():
+                 missing_modules.append(item)
+             if item not in module_keys and item in ckpt.keys():
+                 unexpected_modules.append(item)
+
+         print(f"loading {model.__class__.__name__} but missing modules: {missing_modules}, unexpected modules: {unexpected_modules}")
+         model.load_state_dict(ckpt, strict=False)
+         return model
+
+     def initialize_weights(self):
+         assert not hasattr(self, "llama")
+
+         # Initialize transformer layers:
+         def _basic_init(module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+         self.apply(_basic_init)
+
+         # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+         w = self.x_embedder.proj.weight.data
+         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+         nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+         w = self.input_x_embedder.proj.weight.data
+         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+         nn.init.constant_(self.input_x_embedder.proj.bias, 0)
+
+         # Initialize timestep embedding MLP:
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+         nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
+
+         # Zero-out output layers:
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+     def unpatchify(self, x, h, w):
+         """
+         x: (N, T, patch_size**2 * C)
+         imgs: (N, H, W, C)
+         """
+         c = self.out_channels
+
+         x = x.reshape(shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c))
+         x = torch.einsum('nhwpqc->nchpwq', x)
+         imgs = x.reshape(shape=(x.shape[0], c, h, w))
+         return imgs
+
+     def cropped_pos_embed(self, height, width):
+         """Crops positional embeddings for SD3 compatibility."""
+         if self.pos_embed_max_size is None:
+             raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+         height = height // self.patch_size
+         width = width // self.patch_size
+         if height > self.pos_embed_max_size:
+             raise ValueError(
+                 f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+             )
+         if width > self.pos_embed_max_size:
+             raise ValueError(
+                 f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+             )
+
+         top = (self.pos_embed_max_size - height) // 2
+         left = (self.pos_embed_max_size - width) // 2
+         spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+         spatial_pos_embed = spatial_pos_embed[:, top: top + height, left: left + width, :]
+         spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+         return spatial_pos_embed
+
+     def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images=False):
+         if isinstance(latents, list):
+             return_list = False
+             if padding_latent is None:
+                 padding_latent = [None] * len(latents)
+                 return_list = True
+             patched_latents, num_tokens, shapes = [], [], []
+             for latent, padding in zip(latents, padding_latent):
+                 height, width = latent.shape[-2:]
+                 if is_input_images:
+                     latent = self.input_x_embedder(latent)
+                 else:
+                     latent = self.x_embedder(latent)
+                 pos_embed = self.cropped_pos_embed(height, width)
+                 latent = latent + pos_embed
+                 if padding is not None:
+                     latent = torch.cat([latent, padding], dim=-2)
+                 patched_latents.append(latent)
+
+                 num_tokens.append(pos_embed.size(1))
+                 shapes.append([height, width])
+             if not return_list:
+                 latents = torch.cat(patched_latents, dim=0)
+             else:
+                 latents = patched_latents
+         else:
+             height, width = latents.shape[-2:]
+             if is_input_images:
+                 latents = self.input_x_embedder(latents)
+             else:
+                 latents = self.x_embedder(latents)
+             pos_embed = self.cropped_pos_embed(height, width)
+             latents = latents + pos_embed
+             num_tokens = latents.size(1)
+             shapes = [height, width]
+         return latents, num_tokens, shapes
+
+     def forward(self, x, timestep, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model: bool = False,
+                 llm_input_embeds=None, llm_attention_mask=None, llm_position_ids=None, use_dist=False):
+         input_is_list = isinstance(x, list)
+         x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
+         time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
+
+         if llm_input_embeds is not None:
+             condition_embeds_llm = llm_input_embeds
+             input_emb = torch.cat([condition_embeds_llm, time_token, x], dim=1)
+             attention_mask = llm_attention_mask
+             position_ids = llm_position_ids
+         else:
+             input_emb = torch.cat([time_token, x], dim=1)
+             attention_mask = llm_attention_mask
+             position_ids = llm_position_ids
+
+         output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model, output_hidden_states=True)
+         output, past_key_values, all_hidden_states = output.last_hidden_state, output.past_key_values, output.hidden_states
+         if not use_dist:
+             all_states_noise = None
+         if input_is_list:
+             image_embedding = output[:, -max(num_tokens):]
+             time_emb = self.t_embedder(timestep, dtype=x.dtype)
+             x = self.final_layer(image_embedding, time_emb)
+             latents = []
+             if use_dist:
+                 all_states = torch.stack([hidden_states[:, -max(num_tokens):] for hidden_states in all_hidden_states], dim=1)  # b l s d
+                 all_states_noise = []
+             for i in range(x.size(0)):
+                 latent = x[i: i + 1, :num_tokens[i]]
+                 latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
+                 latents.append(latent)
+                 if use_dist:
+                     all_states_noise.append(all_states[i, :, :num_tokens[i]])
+         else:
+             image_embedding = output[:, -num_tokens:]
+             time_emb = self.t_embedder(timestep, dtype=x.dtype)
+             x = self.final_layer(image_embedding, time_emb)
+             latents = self.unpatchify(x, shapes[0], shapes[1])
+             if use_dist:
+                 all_states_noise = torch.stack([hidden_states[:, -num_tokens:] for hidden_states in all_hidden_states], dim=1)  # b l s d
+
+         if return_past_key_values:
+             return latents, past_key_values, all_states_noise
+         return latents, all_states_noise
+
+     @torch.no_grad()
+     def forward_with_separate_cfg(self, x, timestep, cfg_scale, past_key_values, use_kv_cache, offload_model,
+                                   llm_input_embeds=None, llm_attention_mask=None, llm_position_ids=None, llm_padded_input_ids=None, llm_image_sizes=None):
+         self.llm.config.use_cache = use_kv_cache
+         if past_key_values is None:
+             past_key_values = [None] * len(llm_attention_mask)
+
+         x = torch.split(x, len(x) // len(llm_attention_mask), dim=0)
+         timestep = timestep.to(x[0].dtype)
+         timestep = torch.split(timestep, len(timestep) // len(llm_input_embeds), dim=0)
+
+         model_out, pask_key_values = [], []
+         for i in range(len(llm_input_embeds)):
+             if llm_input_embeds is not None:
+                 temp_out, temp_pask_key_values, _ = self.forward(x[i], timestep[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model,
+                                                                  llm_input_embeds=llm_input_embeds[i], llm_attention_mask=llm_attention_mask[i], llm_position_ids=llm_position_ids[i])
+             else:
+                 temp_out, temp_pask_key_values, _ = self.forward(x[i], timestep[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
+             model_out.append(temp_out)
+             pask_key_values.append(temp_pask_key_values)
+
+         if len(model_out) == 2:
+             cond, uncond = model_out
+             cond = uncond + cfg_scale * (cond - uncond)
+             model_out = [cond, cond]
+         else:
+             return model_out[0]
+
+         return torch.cat(model_out, dim=0), pask_key_values
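Two pieces of numerics from the model above can be checked in isolation; the following is an editor's sketch (not in the repository). It reproduces the sinusoidal embedding formula from `TimestepEmbedder.timestep_embedding` and the classifier-free-guidance mix applied at the end of `forward_with_separate_cfg`.

    import math
    import torch

    def timestep_embedding(t, dim, max_period=10000):
        # Same formula as TimestepEmbedder.timestep_embedding above.
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half)
        args = t[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    emb = timestep_embedding(torch.tensor([0.0, 500.0, 999.0]), 256)
    print(emb.shape)        # torch.Size([3, 256])

    # CFG mix from forward_with_separate_cfg: guided = uncond + cfg_scale * (cond - uncond).
    cond, uncond, cfg_scale = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8), 3.0
    guided = uncond + cfg_scale * (cond - uncond)
    print(torch.allclose(uncond + 1.0 * (cond - uncond), cond))   # True: scale 1 recovers the conditional output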
src/image_decoder/modeling_phi3.py ADDED
@@ -0,0 +1,1611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """PyTorch Phi-3 model."""
17
+
18
+ import math
19
+ import warnings
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
29
+ from transformers.generation import GenerationMixin
30
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutputWithPast,
33
+ CausalLMOutputWithPast,
34
+ SequenceClassifierOutputWithPast,
35
+ TokenClassifierOutput,
36
+ )
37
+ from transformers.modeling_utils import PreTrainedModel
38
+ from transformers.utils import (
39
+ add_code_sample_docstrings,
40
+ add_start_docstrings,
41
+ add_start_docstrings_to_model_forward,
42
+ is_flash_attn_2_available,
43
+ is_flash_attn_greater_or_equal_2_10,
44
+ is_torchdynamo_compiling,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from transformers import Phi3Config
49
+
50
+
51
+ if is_flash_attn_2_available():
52
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+ _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
57
+ _CONFIG_FOR_DOC = "Phi3Config"
58
+
59
+
60
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
61
+ def _prepare_4d_causal_attention_mask_with_cache_position(
62
+ attention_mask: torch.Tensor,
63
+ sequence_length: int,
64
+ target_length: int,
65
+ dtype: torch.dtype,
66
+ device: torch.device,
67
+ min_dtype: float,
68
+ cache_position: torch.Tensor,
69
+ batch_size: int,
70
+ ):
71
+ """
72
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
73
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
74
+
75
+ Args:
76
+ attention_mask (`torch.Tensor`):
77
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
78
+ sequence_length (`int`):
79
+ The sequence length being processed.
80
+ target_length (`int`):
81
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
82
+ dtype (`torch.dtype`):
83
+ The dtype to use for the 4D attention mask.
84
+ device (`torch.device`):
85
+ The device to plcae the 4D attention mask on.
86
+ min_dtype (`float`):
87
+ The minimum value representable with the dtype `dtype`.
88
+ cache_position (`torch.Tensor`):
89
+ Indices depicting the position of the input sequence tokens in the sequence.
90
+ batch_size (`torch.Tensor`):
91
+ Batch size.
92
+ """
93
+ if attention_mask is not None and attention_mask.dim() == 4:
94
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
95
+ causal_mask = attention_mask
96
+ else:
97
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
98
+ if sequence_length != 1:
99
+ causal_mask = torch.triu(causal_mask, diagonal=1)
100
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
101
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
102
+ if attention_mask is not None:
103
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
104
+ mask_length = attention_mask.shape[-1]
105
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
106
+ padding_mask = padding_mask == 0
107
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
108
+ padding_mask, min_dtype
109
+ )
110
+
111
+ return causal_mask
112
+
113
+
114
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
115
+ class Phi3RMSNorm(nn.Module):
116
+ def __init__(self, hidden_size, eps=1e-6):
117
+ """
118
+ Phi3RMSNorm is equivalent to T5LayerNorm
119
+ """
120
+ super().__init__()
121
+ self.weight = nn.Parameter(torch.ones(hidden_size))
122
+ self.variance_epsilon = eps
123
+
124
+ def forward(self, hidden_states):
125
+ input_dtype = hidden_states.dtype
126
+ hidden_states = hidden_states.to(torch.float32)
127
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
128
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
129
+ return self.weight * hidden_states.to(input_dtype)
130
+
131
+ def extra_repr(self):
132
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
133
+
134
+
135
+ # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
136
+ class Phi3RotaryEmbedding(nn.Module):
137
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
138
+ super().__init__()
139
+
140
+ self.dim = dim
141
+ self.max_position_embeddings = max_position_embeddings
142
+ self.base = base
143
+
144
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
145
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
146
+
147
+ @torch.no_grad()
148
+ def forward(self, x, position_ids, seq_len=None):
149
+ # x: [bs, num_attention_heads, seq_len, head_size]
150
+ self.inv_freq.to(x.device)
151
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
152
+ position_ids_expanded = position_ids[:, None, :].float()
153
+ # Force float32 since bfloat16 loses precision on long contexts
154
+ # See https://github.com/huggingface/transformers/pull/29285
155
+ device_type = x.device.type
156
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
157
+ with torch.autocast(device_type=device_type, enabled=False):
158
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
159
+ emb = torch.cat((freqs, freqs), dim=-1)
160
+ cos = emb.cos()
161
+ sin = emb.sin()
162
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
163
+
164
+
165
+ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
166
+ def __init__(self, dim, config, device=None):
167
+ warnings.warn(
168
+ "The class Phi3SuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. Please"
169
+ " use Phi3LongRoPEScaledRotaryEmbedding instead.",
170
+ FutureWarning,
171
+ )
172
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
173
+
174
+ self.short_factor = config.rope_scaling["short_factor"]
175
+ self.long_factor = config.rope_scaling["long_factor"]
176
+ self.original_max_position_embeddings = config.original_max_position_embeddings
177
+
178
+ @torch.no_grad()
179
+ def forward(self, x, position_ids, seq_len=None):
180
+ seq_len = torch.max(position_ids) + 1
181
+ if seq_len > self.original_max_position_embeddings:
182
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
183
+ else:
184
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
185
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
186
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
187
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
188
+ position_ids_expanded = position_ids[:, None, :].float()
189
+ # Force float32 since bfloat16 loses precision on long contexts
190
+ # See https://github.com/huggingface/transformers/pull/29285
191
+ device_type = x.device.type
192
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
193
+ with torch.autocast(device_type=device_type, enabled=False):
194
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
195
+ emb = torch.cat((freqs, freqs), dim=-1)
196
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
197
+ if scale <= 1.0:
198
+ scaling_factor = 1.0
199
+ else:
200
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
201
+ cos = emb.cos() * scaling_factor
202
+ sin = emb.sin() * scaling_factor
203
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
204
+
205
+
206
+ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
207
+ def __init__(self, dim, config, device=None):
208
+ warnings.warn(
209
+ "The class Phi3YarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers",
210
+ FutureWarning,
211
+ )
212
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
213
+
214
+ self.short_factor = config.rope_scaling["short_factor"]
215
+ self.long_factor = config.rope_scaling["long_factor"]
216
+ self.original_max_position_embeddings = config.original_max_position_embeddings
217
+
218
+ @torch.no_grad()
219
+ def forward(self, x, position_ids, seq_len=None):
220
+ seq_len = torch.max(position_ids) + 1
221
+ if seq_len > self.original_max_position_embeddings:
222
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
223
+ else:
224
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
225
+
226
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
227
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
228
+
229
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
230
+ position_ids_expanded = position_ids[:, None, :].float()
231
+
232
+ # Force float32 since bfloat16 loses precision on long contexts
233
+ # See https://github.com/huggingface/transformers/pull/29285
234
+ device_type = x.device.type
235
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
236
+ with torch.autocast(device_type=device_type, enabled=False):
237
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
238
+ emb = torch.cat((freqs, freqs), dim=-1)
239
+
240
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
241
+ if scale <= 1.0:
242
+ scaling_factor = 1.0
243
+ else:
244
+ scaling_factor = 0.1 * math.log(scale) + 1.0
245
+
246
+ cos = emb.cos() * scaling_factor
247
+ sin = emb.sin() * scaling_factor
248
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
249
+
250
+
251
+ class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
252
+ def __init__(self, dim, config, device=None):
253
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
254
+
255
+ self.short_factor = config.rope_scaling["short_factor"]
256
+ self.long_factor = config.rope_scaling["long_factor"]
257
+ self.original_max_position_embeddings = config.original_max_position_embeddings
258
+
259
+ @torch.no_grad()
260
+ def forward(self, x, position_ids, seq_len=None):
261
+ seq_len = seq_len or torch.max(position_ids) + 1
262
+ if seq_len > self.original_max_position_embeddings:
263
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
264
+ else:
265
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
266
+
267
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
268
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
269
+
270
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
271
+ position_ids_expanded = position_ids[:, None, :].float()
272
+
273
+ # Force float32 since bfloat16 loses precision on long contexts
274
+ # See https://github.com/huggingface/transformers/pull/29285
275
+ device_type = x.device.type
276
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
277
+ with torch.autocast(device_type=device_type, enabled=False):
278
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
279
+ emb = torch.cat((freqs, freqs), dim=-1)
280
+
281
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
282
+ if scale <= 1.0:
283
+ scaling_factor = 1.0
284
+ else:
285
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
286
+
287
+ cos = emb.cos() * scaling_factor
288
+ sin = emb.sin() * scaling_factor
289
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
290
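The long-context branch above scales cos/sin by sqrt(1 + ln(scale) / ln(original_max_position_embeddings)) once the sequence exceeds the original training length. A minimal sketch of that factor, using illustrative context lengths (4k original, 128k extended; assumptions, not values read from this repo's config):

```python
import math

# Illustrative context lengths only (not read from the repo's config).
original_max_position_embeddings = 4096
max_position_embeddings = 131072

scale = max_position_embeddings / original_max_position_embeddings  # 32.0
scaling_factor = 1.0 if scale <= 1.0 else math.sqrt(
    1 + math.log(scale) / math.log(original_max_position_embeddings)
)
print(f"scale={scale:.0f}, scaling_factor={scaling_factor:.4f}")  # scale=32, scaling_factor~1.1902
```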
+
291
+
292
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
293
+ def rotate_half(x):
294
+ """Rotates half the hidden dims of the input."""
295
+ x1 = x[..., : x.shape[-1] // 2]
296
+ x2 = x[..., x.shape[-1] // 2 :]
297
+ return torch.cat((-x2, x1), dim=-1)
298
+
299
+
300
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
301
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
302
+ """Applies Rotary Position Embedding to the query and key tensors.
303
+
304
+ Args:
305
+ q (`torch.Tensor`): The query tensor.
306
+ k (`torch.Tensor`): The key tensor.
307
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
308
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
309
+ position_ids (`torch.Tensor`, *optional*):
310
+ Deprecated and unused.
311
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
312
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
313
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
314
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
315
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
316
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
317
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
318
+ Returns:
319
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
320
+ """
321
+ cos = cos.unsqueeze(unsqueeze_dim)
322
+ sin = sin.unsqueeze(unsqueeze_dim)
323
+ q_embed = (q * cos) + (rotate_half(q) * sin)
324
+ k_embed = (k * cos) + (rotate_half(k) * sin)
325
+ return q_embed, k_embed
326
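As a quick, self-contained sanity check of `rotate_half` and `apply_rotary_pos_emb` (toy shapes chosen for illustration; the two helpers are repeated here so the sketch runs on its own):

```python
import torch

def rotate_half(x):
    # Same helper as above: split the last dim in half and rotate.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

batch, heads, seq_len, head_dim = 1, 2, 4, 8  # toy sizes
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# Plain (unscaled) RoPE frequencies for the toy head_dim.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                       # (seq_len, head_dim)
cos, sin = emb.cos()[None], emb.sin()[None]                   # (1, seq_len, head_dim)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 4, 8]) torch.Size([1, 2, 4, 8])
```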
+
327
+
328
+ class Phi3MLP(nn.Module):
329
+ def __init__(self, config):
330
+ super().__init__()
331
+
332
+ self.config = config
333
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
334
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
335
+
336
+ self.activation_fn = ACT2FN[config.hidden_act]
337
+
338
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
339
+ up_states = self.gate_up_proj(hidden_states)
340
+
341
+ gate, up_states = up_states.chunk(2, dim=-1)
342
+ up_states = up_states * self.activation_fn(gate)
343
+
344
+ return self.down_proj(up_states)
345
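`Phi3MLP` fuses the gate and up projections into a single `gate_up_proj` matmul and splits the result with `chunk`. A standalone sketch of the same pattern, with toy sizes and SiLU assumed for `hidden_act` (Phi-3 configs typically use `"silu"`, but that is an assumption here):

```python
import torch
import torch.nn as nn

hidden_size, intermediate_size = 16, 32       # toy sizes for illustration
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act = nn.SiLU()                               # assumed activation

x = torch.randn(2, 5, hidden_size)            # (batch, seq_len, hidden)
gate, up = gate_up_proj(x).chunk(2, dim=-1)   # one matmul, two halves
y = down_proj(up * act(gate))
print(y.shape)                                # torch.Size([2, 5, 16])
```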
+
346
+
347
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
348
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
349
+ """
350
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
351
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
352
+ """
353
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
354
+ if n_rep == 1:
355
+ return hidden_states
356
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
357
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
358
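`repeat_kv` broadcasts the key/value heads so grouped-query attention can reuse the standard multi-head math. A quick shape check with toy sizes (2 KV heads repeated 4 times, i.e. 8 query heads assumed):

```python
import torch

batch, num_kv_heads, seq_len, head_dim = 1, 2, 4, 8
n_rep = 4  # num_attention_heads // num_key_value_heads in the real model

kv = torch.randn(batch, num_kv_heads, seq_len, head_dim)
expanded = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
repeated = expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

print(repeated.shape)                               # torch.Size([1, 8, 4, 8])
print(torch.equal(repeated[0, 0], repeated[0, 3]))  # True: heads 0..3 all share KV head 0
```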
+
359
+
360
+ class Phi3Attention(nn.Module):
361
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
362
+
363
+ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
364
+ super().__init__()
365
+ self.config = config
366
+ self.layer_idx = layer_idx
367
+ if layer_idx is None:
368
+ logger.warning_once(
369
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
370
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
371
+ "when creating this class."
372
+ )
373
+
374
+ self.attention_dropout = config.attention_dropout
375
+ self.hidden_size = config.hidden_size
376
+ self.num_heads = config.num_attention_heads
377
+ self.head_dim = self.hidden_size // self.num_heads
378
+ self.num_key_value_heads = config.num_key_value_heads
379
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
380
+ self.max_position_embeddings = config.max_position_embeddings
381
+ self.original_max_position_embeddings = config.original_max_position_embeddings
382
+ self.rope_theta = config.rope_theta
383
+ self.rope_scaling = config.rope_scaling
384
+ self.is_causal = True
385
+
386
+ if (self.head_dim * self.num_heads) != self.hidden_size:
387
+ raise ValueError(
388
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
389
+ f" and `num_heads`: {self.num_heads})."
390
+ )
391
+
392
+ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
393
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
394
+ self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
395
+ self._init_rope()
396
+
397
+ def _init_rope(self):
398
+ if self.rope_scaling is None:
399
+ self.rotary_emb = Phi3RotaryEmbedding(
400
+ self.head_dim,
401
+ max_position_embeddings=self.max_position_embeddings,
402
+ base=self.rope_theta,
403
+ )
404
+ else:
405
+ scaling_type = self.config.rope_scaling["type"]
406
+ if scaling_type == "longrope":
407
+ self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
408
+ else:
409
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
410
+
411
+ def forward(
412
+ self,
413
+ hidden_states: torch.Tensor,
414
+ attention_mask: Optional[torch.Tensor] = None,
415
+ position_ids: Optional[torch.LongTensor] = None,
416
+ past_key_value: Optional[Cache] = None,
417
+ output_attentions: bool = False,
418
+ use_cache: bool = False,
419
+ cache_position: Optional[torch.LongTensor] = None,
420
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
421
+ logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.")
422
+
423
+ bsz, q_len, _ = hidden_states.size()
424
+
425
+ qkv = self.qkv_proj(hidden_states)
426
+ query_pos = self.num_heads * self.head_dim
427
+ query_states = qkv[..., :query_pos]
428
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
429
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
430
+
431
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
432
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
433
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
434
+
435
+ kv_seq_len = key_states.shape[-2]
436
+ if past_key_value is not None:
437
+ if self.layer_idx is None:
438
+ raise ValueError(
439
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
440
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
441
+ "with a layer index."
442
+ )
443
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
444
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
445
+
446
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
447
+
448
+ if past_key_value is not None:
449
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
450
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
451
+
452
+ # repeat k/v heads if n_kv_heads < n_heads
453
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
454
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
455
+
456
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
457
+
458
+ if attention_mask is not None:
459
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
460
+ attn_weights += causal_mask
461
+
462
+ # upcast attention to fp32
463
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
464
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
465
+
466
+ attn_output = torch.matmul(attn_weights, value_states)
467
+
468
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
469
+ raise ValueError(
470
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
471
+ f" {attn_output.size()}"
472
+ )
473
+
474
+ attn_output = attn_output.transpose(1, 2).contiguous()
475
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
476
+
477
+ attn_output = self.o_proj(attn_output)
478
+
479
+ if not output_attentions:
480
+ attn_weights = None
481
+
482
+ return attn_output, attn_weights, past_key_value
483
+
484
+
485
+ class Phi3FlashAttention2(Phi3Attention):
486
+ """
487
+ Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stay
488
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
489
+ flash attention and deal with padding tokens in case the input contains any of them.
490
+ """
491
+
492
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
493
+ def __init__(self, *args, **kwargs):
494
+ super().__init__(*args, **kwargs)
495
+
496
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
497
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
498
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
499
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
500
+
501
+ def forward(
502
+ self,
503
+ hidden_states: torch.Tensor,
504
+ attention_mask: Optional[torch.LongTensor] = None,
505
+ position_ids: Optional[torch.LongTensor] = None,
506
+ past_key_value: Optional[Cache] = None,
507
+ output_attentions: bool = False,
508
+ use_cache: bool = False,
509
+ cache_position: Optional[torch.LongTensor] = None,
510
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
511
+ # Phi3FlashAttention2 attention does not support output_attentions
512
+
513
+ output_attentions = False
514
+
515
+ bsz, q_len, _ = hidden_states.size()
516
+
517
+ qkv = self.qkv_proj(hidden_states)
518
+ query_pos = self.num_heads * self.head_dim
519
+ query_states = qkv[..., :query_pos]
520
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
521
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
522
+
523
+ # Flash attention requires the input to have the shape
524
+ # batch_size x seq_length x num_heads x head_dim
525
+ # therefore we just need to keep the original shape
526
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
527
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
528
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
529
+
530
+ kv_seq_len = key_states.shape[-2]
531
+ if past_key_value is not None:
532
+ if self.layer_idx is None:
533
+ raise ValueError(
534
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
535
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
536
+ "with a layer index."
537
+ )
538
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
539
+
540
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
541
+ rotary_seq_len = (
542
+ max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
543
+ )
544
+
545
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids)
546
+
547
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
548
+
549
+ if past_key_value is not None:
550
+ # Activate cache slicing only if the config has a `sliding_window` attribute
551
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
552
+ if (
553
+ getattr(self.config, "sliding_window", None) is not None
554
+ and kv_seq_len > self.config.sliding_window
555
+ and cache_has_contents
556
+ ):
557
+ slicing_tokens = 1 - self.config.sliding_window
558
+
559
+ past_key = past_key_value[self.layer_idx][0]
560
+ past_value = past_key_value[self.layer_idx][1]
561
+
562
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
563
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
564
+
565
+ if past_key.shape[-2] != self.config.sliding_window - 1:
566
+ raise ValueError(
567
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
568
+ f" {past_key.shape}"
569
+ )
570
+
571
+ if attention_mask is not None:
572
+ attention_mask = attention_mask[:, slicing_tokens:]
573
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
574
+
575
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
576
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
577
+
578
+ # repeat k/v heads if n_kv_heads < n_heads
579
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
580
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
581
+
582
+ attn_dropout = self.attention_dropout if self.training else 0.0
583
+
584
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
585
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
586
+ # cast them back to the correct dtype just to be sure everything works as expected.
587
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
588
+ # in fp32.
589
+
590
+ if query_states.dtype == torch.float32:
591
+ if torch.is_autocast_enabled():
592
+ target_dtype = torch.get_autocast_gpu_dtype()
593
+ # Handle the case where the model is quantized
594
+ elif hasattr(self.config, "_pre_quantization_dtype"):
595
+ target_dtype = self.config._pre_quantization_dtype
596
+ else:
597
+ target_dtype = self.qkv_proj.weight.dtype
598
+
599
+ logger.warning_once(
600
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
601
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
602
+ f" {target_dtype}."
603
+ )
604
+
605
+ query_states = query_states.to(target_dtype)
606
+ key_states = key_states.to(target_dtype)
607
+ value_states = value_states.to(target_dtype)
608
+
609
+ # Reshape to the expected shape for Flash Attention
610
+ query_states = query_states.transpose(1, 2)
611
+ key_states = key_states.transpose(1, 2)
612
+ value_states = value_states.transpose(1, 2)
613
+
614
+ attn_output = _flash_attention_forward(
615
+ query_states,
616
+ key_states,
617
+ value_states,
618
+ attention_mask,
619
+ q_len,
620
+ position_ids=position_ids,
621
+ dropout=attn_dropout,
622
+ sliding_window=getattr(self.config, "sliding_window", None),
623
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
624
+ is_causal=self.is_causal,
625
+ )
626
+
627
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
628
+ attn_output = self.o_proj(attn_output)
629
+
630
+ if not output_attentions:
631
+ attn_weights = None
632
+
633
+ return attn_output, attn_weights, past_key_value
634
+
635
+
636
+ # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
637
+ # TODO @Arthur no longer copied from LLama after static cache
638
+ class Phi3SdpaAttention(Phi3Attention):
639
+ """
640
+ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
641
+ `Phi3Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
642
+ SDPA API.
643
+ """
644
+
645
+ # Adapted from Phi3Attention.forward
646
+ def forward(
647
+ self,
648
+ hidden_states: torch.Tensor,
649
+ attention_mask: Optional[torch.Tensor] = None,
650
+ position_ids: Optional[torch.LongTensor] = None,
651
+ past_key_value: Optional[Cache] = None,
652
+ output_attentions: bool = False,
653
+ use_cache: bool = False,
654
+ cache_position: Optional[torch.LongTensor] = None,
655
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
656
+ if output_attentions:
657
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
658
+ logger.warning_once(
659
+ "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
660
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
661
+ )
662
+ return super().forward(
663
+ hidden_states=hidden_states,
664
+ attention_mask=attention_mask,
665
+ position_ids=position_ids,
666
+ past_key_value=past_key_value,
667
+ output_attentions=output_attentions,
668
+ use_cache=use_cache,
669
+ )
670
+
671
+ bsz, q_len, _ = hidden_states.size()
672
+
673
+ qkv = self.qkv_proj(hidden_states)
674
+ query_pos = self.num_heads * self.head_dim
675
+ query_states = qkv[..., :query_pos]
676
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
677
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
678
+
679
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
680
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
681
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
682
+
683
+ kv_seq_len = key_states.shape[-2]
684
+ if past_key_value is not None:
685
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
686
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
687
+
688
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
689
+
690
+ if past_key_value is not None:
691
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
692
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
693
+
694
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
695
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
696
+
697
+ causal_mask = attention_mask
698
+ if attention_mask is not None:
699
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
700
+
701
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
702
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
703
+ if query_states.device.type == "cuda" and attention_mask is not None:
704
+ query_states = query_states.contiguous()
705
+ key_states = key_states.contiguous()
706
+ value_states = value_states.contiguous()
707
+
708
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
709
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
710
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
711
+ is_causal = True if causal_mask is None and q_len > 1 else False
712
+
713
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
714
+ query_states,
715
+ key_states,
716
+ value_states,
717
+ attn_mask=causal_mask,
718
+ dropout_p=self.attention_dropout if self.training else 0.0,
719
+ is_causal=is_causal,
720
+ )
721
+
722
+ attn_output = attn_output.transpose(1, 2).contiguous()
723
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
724
+
725
+ attn_output = self.o_proj(attn_output)
726
+
727
+ return attn_output, None, past_key_value
728
+
729
+
730
+ PHI3_ATTENTION_CLASSES = {
731
+ "eager": Phi3Attention,
732
+ "flash_attention_2": Phi3FlashAttention2,
733
+ "sdpa": Phi3SdpaAttention,
734
+ }
735
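The `PHI3_ATTENTION_CLASSES` dictionary is what `Phi3DecoderLayer` indexes with `config._attn_implementation` below. A hedged usage sketch of picking a backend when loading a checkpoint (the model id matches the docstring example elsewhere in this file; `flash_attention_2` additionally requires the flash-attn package and a supported GPU):

```python
from transformers import AutoModelForCausalLM

# Illustrative only: the chosen string is routed through the dispatch dict above.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-3-mini-4k-instruct",
    attn_implementation="sdpa",   # or "eager" / "flash_attention_2"
    torch_dtype="auto",
)
```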
+
736
+
737
+ class Phi3DecoderLayer(nn.Module):
738
+ def __init__(self, config: Phi3Config, layer_idx: int):
739
+ super().__init__()
740
+
741
+ self.config = config
742
+ self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
743
+
744
+ self.mlp = Phi3MLP(config)
745
+ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
746
+
747
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
748
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
749
+ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
750
+
751
+ def forward(
752
+ self,
753
+ hidden_states: torch.Tensor,
754
+ attention_mask: Optional[torch.Tensor] = None,
755
+ position_ids: Optional[torch.LongTensor] = None,
756
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
757
+ output_attentions: Optional[bool] = False,
758
+ use_cache: Optional[bool] = False,
759
+ cache_position: Optional[torch.LongTensor] = None,
760
+ **kwargs,
761
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
762
+ """
763
+ Args:
764
+ hidden_states (`torch.FloatTensor`):
765
+ input to the layer of shape `(batch, seq_len, embed_dim)`
766
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
767
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
768
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
769
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
770
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
771
+ output_attentions (`bool`, *optional*):
772
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
773
+ returned tensors for more detail.
774
+ use_cache (`bool`, *optional*):
775
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
776
+ (see `past_key_values`).
777
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
778
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
779
+ Indices depicting the position of the input sequence tokens in the sequence
780
+ kwargs (`dict`, *optional*):
781
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
782
+ into the model
783
+ """
784
+
785
+ residual = hidden_states
786
+
787
+ hidden_states = self.input_layernorm(hidden_states)
788
+
789
+ # Self Attention
790
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
791
+ hidden_states=hidden_states,
792
+ attention_mask=attention_mask,
793
+ position_ids=position_ids,
794
+ past_key_value=past_key_value,
795
+ output_attentions=output_attentions,
796
+ use_cache=use_cache,
797
+ cache_position=cache_position,
798
+ )
799
+
800
+ hidden_states = residual + self.resid_attn_dropout(attn_outputs)
801
+
802
+ residual = hidden_states
803
+ hidden_states = self.post_attention_layernorm(hidden_states)
804
+ hidden_states = self.mlp(hidden_states)
805
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states)
806
+
807
+ outputs = (hidden_states,)
808
+
809
+ if output_attentions:
810
+ outputs += (self_attn_weights,)
811
+
812
+ if use_cache:
813
+ outputs += (present_key_value,)
814
+
815
+ return outputs
816
+
817
+
818
+ PHI3_START_DOCSTRING = r"""
819
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
820
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
821
+ etc.)
822
+
823
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
824
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
825
+ and behavior.
826
+
827
+ Parameters:
828
+ config ([`Phi3Config`]):
829
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
830
+ load the weights associated with the model, only the configuration. Check out the
831
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
832
+ """
833
+
834
+
835
+ @add_start_docstrings(
836
+ "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
837
+ PHI3_START_DOCSTRING,
838
+ )
839
+ class Phi3PreTrainedModel(PreTrainedModel):
840
+ config_class = Phi3Config
841
+ base_model_prefix = "model"
842
+ supports_gradient_checkpointing = True
843
+ _no_split_modules = ["Phi3DecoderLayer"]
844
+ _skip_keys_device_placement = "past_key_values"
845
+ _supports_flash_attn_2 = True
846
+ _supports_sdpa = True
847
+ _supports_cache_class = True
848
+
849
+ _version = "0.0.5"
850
+
851
+ def _init_weights(self, module):
852
+ std = self.config.initializer_range
853
+ if isinstance(module, nn.Linear):
854
+ module.weight.data.normal_(mean=0.0, std=std)
855
+ if module.bias is not None:
856
+ module.bias.data.zero_()
857
+ elif isinstance(module, nn.Embedding):
858
+ module.weight.data.normal_(mean=0.0, std=std)
859
+ if module.padding_idx is not None:
860
+ module.weight.data[module.padding_idx].zero_()
861
+
862
+
863
+ PHI3_INPUTS_DOCSTRING = r"""
864
+ Args:
865
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
866
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
867
+ it.
868
+
869
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
870
+ [`PreTrainedTokenizer.__call__`] for details.
871
+
872
+ [What are input IDs?](../glossary#input-ids)
873
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
874
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
875
+
876
+ - 1 for tokens that are **not masked**,
877
+ - 0 for tokens that are **masked**.
878
+
879
+ [What are attention masks?](../glossary#attention-mask)
880
+
881
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
882
+ [`PreTrainedTokenizer.__call__`] for details.
883
+
884
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
885
+ `past_key_values`).
886
+
887
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
888
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
889
+ information on the default strategy.
890
+
891
+ - 1 indicates the head is **not masked**,
892
+ - 0 indicates the head is **masked**.
893
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
894
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
895
+ config.n_positions - 1]`.
896
+
897
+ [What are position IDs?](../glossary#position-ids)
898
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
899
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
900
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
901
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
902
+
903
+ Two formats are allowed:
904
+ - a [`~cache_utils.Cache`] instance, see our
905
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
906
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
907
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
908
+ cache format.
909
+
910
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
911
+ legacy cache format will be returned.
912
+
913
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
914
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
915
+ of shape `(batch_size, sequence_length)`.
916
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
917
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
918
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
919
+ model's internal embedding lookup matrix.
920
+ use_cache (`bool`, *optional*):
921
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
922
+ `past_key_values`).
923
+ output_attentions (`bool`, *optional*):
924
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
925
+ tensors for more detail.
926
+ output_hidden_states (`bool`, *optional*):
927
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
928
+ more detail.
929
+ return_dict (`bool`, *optional*):
930
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
931
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
932
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
933
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
934
+ the complete sequence length.
935
+ """
936
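Both `past_key_values` formats mentioned above (a `Cache` instance or the legacy tuple of per-layer `(key, value)` tuples) can be converted with `DynamicCache`; a small sketch with toy shapes:

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each of shape
# (batch, num_heads, seq_len, head_dim). Toy sizes for illustration.
legacy = tuple((torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)  # -> Cache instance
print(cache.get_seq_length())                   # 3

round_trip = cache.to_legacy_cache()            # back to tuple of tuples
print(len(round_trip), round_trip[0][0].shape)  # 2 torch.Size([1, 4, 3, 8])
```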
+
937
+
938
+ @add_start_docstrings(
939
+ "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
940
+ PHI3_START_DOCSTRING,
941
+ )
942
+ class Phi3Model(Phi3PreTrainedModel):
943
+ """
944
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
945
+
946
+ Args:
947
+ config: Phi3Config
948
+ """
949
+
950
+ def __init__(self, config: Phi3Config):
951
+ super().__init__(config)
952
+ self.padding_idx = config.pad_token_id
953
+ self.vocab_size = config.vocab_size
954
+
955
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
956
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
957
+ self.layers = nn.ModuleList(
958
+ [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
959
+ )
960
+ self._attn_implementation = config._attn_implementation
961
+ self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
962
+
963
+ self.gradient_checkpointing = False
964
+ # Initialize weights and apply final processing
965
+ self.post_init()
966
+
967
+ def get_input_embeddings(self):
968
+ return self.embed_tokens
969
+
970
+ def set_input_embeddings(self, value):
971
+ self.embed_tokens = value
972
+
973
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
974
+ def forward(
975
+ self,
976
+ input_ids: torch.LongTensor = None,
977
+ attention_mask: Optional[torch.Tensor] = None,
978
+ position_ids: Optional[torch.LongTensor] = None,
979
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
980
+ inputs_embeds: Optional[torch.FloatTensor] = None,
981
+ use_cache: Optional[bool] = None,
982
+ output_attentions: Optional[bool] = None,
983
+ output_hidden_states: Optional[bool] = None,
984
+ return_dict: Optional[bool] = None,
985
+ cache_position: Optional[torch.LongTensor] = None,
986
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
987
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
988
+ output_hidden_states = (
989
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
990
+ )
991
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
992
+
993
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
994
+
995
+ if (input_ids is None) ^ (inputs_embeds is not None):
996
+ raise ValueError(
997
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
998
+ )
999
+
1000
+ if self.gradient_checkpointing and self.training:
1001
+ if use_cache:
1002
+ logger.warning_once(
1003
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1004
+ )
1005
+ use_cache = False
1006
+
1007
+ # kept for BC (non `Cache` `past_key_values` inputs)
1008
+ return_legacy_cache = False
1009
+ if use_cache and not isinstance(past_key_values, Cache):
1010
+ return_legacy_cache = True
1011
+ if past_key_values is None:
1012
+ past_key_values = DynamicCache()
1013
+ else:
1014
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1015
+ logger.warning_once(
1016
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
1017
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
1018
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
1019
+ )
1020
+
1021
+ if inputs_embeds is None:
1022
+ inputs_embeds = self.embed_tokens(input_ids)
1023
+
1024
+ if cache_position is None:
1025
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1026
+ cache_position = torch.arange(
1027
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1028
+ )
1029
+ if position_ids is None:
1030
+ position_ids = cache_position.unsqueeze(0)
1031
+
1032
+ causal_mask = self._update_causal_mask(
1033
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1034
+ )
1035
+
1036
+ hidden_states = inputs_embeds
1037
+
1038
+ # decoder layers
1039
+ all_hidden_states = () if output_hidden_states else None
1040
+ all_self_attns = () if output_attentions else None
1041
+ next_decoder_cache = None
1042
+
1043
+ for decoder_layer in self.layers:
1044
+ if output_hidden_states:
1045
+ all_hidden_states += (hidden_states,)
1046
+
1047
+ if self.gradient_checkpointing and self.training:
1048
+ layer_outputs = self._gradient_checkpointing_func(
1049
+ decoder_layer.__call__,
1050
+ hidden_states,
1051
+ causal_mask,
1052
+ position_ids,
1053
+ past_key_values,
1054
+ output_attentions,
1055
+ use_cache,
1056
+ cache_position,
1057
+ )
1058
+ else:
1059
+ layer_outputs = decoder_layer(
1060
+ hidden_states,
1061
+ attention_mask=causal_mask,
1062
+ position_ids=position_ids,
1063
+ past_key_value=past_key_values,
1064
+ output_attentions=output_attentions,
1065
+ use_cache=use_cache,
1066
+ cache_position=cache_position,
1067
+ )
1068
+
1069
+ hidden_states = layer_outputs[0]
1070
+
1071
+ if use_cache:
1072
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1073
+
1074
+ if output_attentions:
1075
+ all_self_attns += (layer_outputs[1],)
1076
+
1077
+ hidden_states = self.norm(hidden_states)
1078
+
1079
+ # add hidden states from the last decoder layer
1080
+ if output_hidden_states:
1081
+ all_hidden_states += (hidden_states,)
1082
+
1083
+ next_cache = next_decoder_cache if use_cache else None
1084
+ if return_legacy_cache:
1085
+ next_cache = next_cache.to_legacy_cache()
1086
+
1087
+ if not return_dict:
1088
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1089
+ return BaseModelOutputWithPast(
1090
+ last_hidden_state=hidden_states,
1091
+ past_key_values=next_cache,
1092
+ hidden_states=all_hidden_states,
1093
+ attentions=all_self_attns,
1094
+ )
1095
+
1096
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1097
+ def _update_causal_mask(
1098
+ self,
1099
+ attention_mask: torch.Tensor,
1100
+ input_tensor: torch.Tensor,
1101
+ cache_position: torch.Tensor,
1102
+ past_key_values: Cache,
1103
+ output_attentions: bool,
1104
+ ):
1105
+ if self.config._attn_implementation == "flash_attention_2":
1106
+ if attention_mask is not None and 0.0 in attention_mask:
1107
+ return attention_mask
1108
+ return None
1109
+
1110
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1111
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1112
+ # to infer the attention mask.
1113
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1114
+ using_static_cache = isinstance(past_key_values, StaticCache)
1115
+
1116
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1117
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1118
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1119
+ attention_mask,
1120
+ inputs_embeds=input_tensor,
1121
+ past_key_values_length=past_seen_tokens,
1122
+ is_training=self.training,
1123
+ ):
1124
+ return None
1125
+
1126
+ dtype, device = input_tensor.dtype, input_tensor.device
1127
+ min_dtype = torch.finfo(dtype).min
1128
+ sequence_length = input_tensor.shape[1]
1129
+ if using_static_cache:
1130
+ target_length = past_key_values.get_max_length()
1131
+ else:
1132
+ target_length = (
1133
+ attention_mask.shape[-1]
1134
+ if isinstance(attention_mask, torch.Tensor)
1135
+ else past_seen_tokens + sequence_length + 1
1136
+ )
1137
+
1138
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
1139
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1140
+ attention_mask,
1141
+ sequence_length=sequence_length,
1142
+ target_length=target_length,
1143
+ dtype=dtype,
1144
+ device=device,
1145
+ min_dtype=min_dtype,
1146
+ cache_position=cache_position,
1147
+ batch_size=input_tensor.shape[0],
1148
+ )
1149
+
1150
+ if (
1151
+ self.config._attn_implementation == "sdpa"
1152
+ and attention_mask is not None
1153
+ and attention_mask.device.type == "cuda"
1154
+ and not output_attentions
1155
+ ):
1156
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1157
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1158
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1159
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1160
+
1161
+ return causal_mask
1162
+
1163
+
1164
+ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
1165
+ _tied_weights_keys = ["lm_head.weight"]
1166
+
1167
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
1168
+ def __init__(self, config):
1169
+ super().__init__(config)
1170
+ self.model = Phi3Model(config)
1171
+ self.vocab_size = config.vocab_size
1172
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1173
+
1174
+ # Initialize weights and apply final processing
1175
+ self.post_init()
1176
+
1177
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1178
+ def get_input_embeddings(self):
1179
+ return self.model.embed_tokens
1180
+
1181
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1182
+ def set_input_embeddings(self, value):
1183
+ self.model.embed_tokens = value
1184
+
1185
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1186
+ def get_output_embeddings(self):
1187
+ return self.lm_head
1188
+
1189
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1190
+ def set_output_embeddings(self, new_embeddings):
1191
+ self.lm_head = new_embeddings
1192
+
1193
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1194
+ def set_decoder(self, decoder):
1195
+ self.model = decoder
1196
+
1197
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1198
+ def get_decoder(self):
1199
+ return self.model
1200
+
1201
+ # Ignore copy
1202
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1203
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1204
+ def forward(
1205
+ self,
1206
+ input_ids: torch.LongTensor = None,
1207
+ attention_mask: Optional[torch.Tensor] = None,
1208
+ position_ids: Optional[torch.LongTensor] = None,
1209
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1210
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1211
+ labels: Optional[torch.LongTensor] = None,
1212
+ use_cache: Optional[bool] = None,
1213
+ output_attentions: Optional[bool] = None,
1214
+ output_hidden_states: Optional[bool] = None,
1215
+ return_dict: Optional[bool] = None,
1216
+ cache_position: Optional[torch.LongTensor] = None,
1217
+ num_logits_to_keep: int = 0,
1218
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1219
+ r"""
1220
+ Args:
1221
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1222
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1223
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1224
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1225
+
1226
+ num_logits_to_keep (`int`, *optional*):
1227
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1228
+ `input_ids` (special case). Only the last token's logits are needed for generation, and calculating them only for that
1229
+ token can save memory, which becomes quite significant for long sequences or a large vocabulary size.
1230
+
1231
+ Returns:
1232
+
1233
+ Example:
1234
+
1235
+ ```python
1236
+ >>> from transformers import AutoTokenizer, Phi3ForCausalLM
1237
+
1238
+ >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1239
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1240
+
1241
+ >>> prompt = "This is an example script ."
1242
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1243
+
1244
+ >>> # Generate
1245
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1246
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1247
+ 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
1248
+ ```"""
1249
+ if (
1250
+ use_cache
1251
+ and self.config.rope_scaling
1252
+ and cache_position is not None
1253
+ and cache_position[0] == self.config.original_max_position_embeddings
1254
+ ):
1255
+ logger.warning(
1256
+ f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
1257
+ )
1258
+
1259
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1260
+ output_hidden_states = (
1261
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1262
+ )
1263
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1264
+
1265
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1266
+ outputs = self.model(
1267
+ input_ids=input_ids,
1268
+ attention_mask=attention_mask,
1269
+ position_ids=position_ids,
1270
+ past_key_values=past_key_values,
1271
+ inputs_embeds=inputs_embeds,
1272
+ use_cache=use_cache,
1273
+ output_attentions=output_attentions,
1274
+ output_hidden_states=output_hidden_states,
1275
+ return_dict=return_dict,
1276
+ )
1277
+
1278
+ hidden_states = outputs[0]
1279
+ if labels is None and not is_torchdynamo_compiling():
1280
+ logger.warning_once(
1281
+ "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
1282
+ )
1283
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1284
+ # TODO: remove the float() operation in v4.46
1285
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1286
+
1287
+ loss = None
1288
+ if labels is not None:
1289
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1290
+ logits = logits.float()
1291
+ # Shift so that tokens < n predict n
1292
+ shift_logits = logits[..., :-1, :].contiguous()
1293
+ shift_labels = labels[..., 1:].contiguous()
1294
+ # Flatten the tokens
1295
+ loss_fct = CrossEntropyLoss()
1296
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1297
+ shift_labels = shift_labels.view(-1)
1298
+ # Enable model parallelism
1299
+ shift_labels = shift_labels.to(shift_logits.device)
1300
+ loss = loss_fct(shift_logits, shift_labels)
1301
+
1302
+ if not return_dict:
1303
+ output = (logits,) + outputs[1:]
1304
+ return (loss,) + output if loss is not None else output
1305
+
1306
+ return CausalLMOutputWithPast(
1307
+ loss=loss,
1308
+ logits=logits,
1309
+ past_key_values=outputs.past_key_values,
1310
+ hidden_states=outputs.hidden_states,
1311
+ attentions=outputs.attentions,
1312
+ )
1313
+
1314
+ def prepare_inputs_for_generation(
1315
+ self,
1316
+ input_ids,
1317
+ past_key_values=None,
1318
+ attention_mask=None,
1319
+ inputs_embeds=None,
1320
+ cache_position=None,
1321
+ position_ids=None,
1322
+ use_cache=True,
1323
+ num_logits_to_keep=None,
1324
+ **kwargs,
1325
+ ):
1326
+ # The first time the input length crosses the long/short factor switching point, force the cache to be recomputed.
1327
+ # This slows down generation at that single token position, but it is better than the current failure mode.
1328
+ if (
1329
+ past_key_values
1330
+ and self.config.rope_scaling
1331
+ and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
1332
+ ):
1333
+ past_length = cache_position[0]
1334
+ if past_length <= self.config.original_max_position_embeddings:
1335
+ past_key_values = None
1336
+
1337
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1338
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1339
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1340
+ if past_key_values is not None:
1341
+ if inputs_embeds is not None: # Exception 1
1342
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1343
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1344
+ input_ids = input_ids[:, cache_position]
1345
+
1346
+ if attention_mask is not None and position_ids is None:
1347
+ # create position_ids on the fly for batch generation
1348
+ position_ids = attention_mask.long().cumsum(-1) - 1
1349
+ position_ids.masked_fill_(attention_mask == 0, 1)
1350
+ if past_key_values:
1351
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1352
+
1353
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides during decoding. Here, simply using `.contiguous()` is not sufficient, as in the batch size = 1 case `position_ids` is already contiguous but with a varying stride, which retriggers a capture.
1354
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1355
+
1356
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1357
+ if inputs_embeds is not None and cache_position[0] == 0:
1358
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1359
+ else:
1360
+ # The clone here is for the same reason as for `position_ids`.
1361
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
1362
+
1363
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1364
+ if model_inputs["inputs_embeds"] is not None:
1365
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1366
+ device = model_inputs["inputs_embeds"].device
1367
+ else:
1368
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1369
+ device = model_inputs["input_ids"].device
1370
+
1371
+ dtype = self.lm_head.weight.dtype
1372
+ min_dtype = torch.finfo(dtype).min
1373
+
1374
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1375
+ attention_mask,
1376
+ sequence_length=sequence_length,
1377
+ target_length=past_key_values.get_max_length(),
1378
+ dtype=dtype,
1379
+ device=device,
1380
+ min_dtype=min_dtype,
1381
+ cache_position=cache_position,
1382
+ batch_size=batch_size,
1383
+ )
1384
+
1385
+ if num_logits_to_keep is not None:
1386
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
1387
+
1388
+ model_inputs.update(
1389
+ {
1390
+ "position_ids": position_ids,
1391
+ "cache_position": cache_position,
1392
+ "past_key_values": past_key_values,
1393
+ "use_cache": use_cache,
1394
+ "attention_mask": attention_mask,
1395
+ }
1396
+ )
1397
+ return model_inputs
1398
+
1399
+
1400
+ @add_start_docstrings(
1401
+ """
1402
+ The [`Phi3Model`] with a sequence classification head on top (linear layer).
1403
+
1404
+ [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1405
+ (e.g. GPT-2) do.
1406
+
1407
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1408
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1409
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1410
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1411
+ each row of the batch).
1412
+ """,
1413
+ PHI3_START_DOCSTRING,
1414
+ )
1415
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
1416
+ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
1417
+ def __init__(self, config):
1418
+ super().__init__(config)
1419
+ self.num_labels = config.num_labels
1420
+ self.model = Phi3Model(config)
1421
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1422
+
1423
+ # Initialize weights and apply final processing
1424
+ self.post_init()
1425
+
1426
+ def get_input_embeddings(self):
1427
+ return self.model.embed_tokens
1428
+
1429
+ def set_input_embeddings(self, value):
1430
+ self.model.embed_tokens = value
1431
+
1432
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1433
+ def forward(
1434
+ self,
1435
+ input_ids: Optional[torch.LongTensor] = None,
1436
+ attention_mask: Optional[torch.Tensor] = None,
1437
+ position_ids: Optional[torch.LongTensor] = None,
1438
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1439
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1440
+ labels: Optional[torch.LongTensor] = None,
1441
+ use_cache: Optional[bool] = None,
1442
+ output_attentions: Optional[bool] = None,
1443
+ output_hidden_states: Optional[bool] = None,
1444
+ return_dict: Optional[bool] = None,
1445
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1446
+ r"""
1447
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1448
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1449
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1450
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1451
+ """
1452
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1453
+
1454
+ model_outputs = self.model(
1455
+ input_ids,
1456
+ attention_mask=attention_mask,
1457
+ position_ids=position_ids,
1458
+ past_key_values=past_key_values,
1459
+ inputs_embeds=inputs_embeds,
1460
+ use_cache=use_cache,
1461
+ output_attentions=output_attentions,
1462
+ output_hidden_states=output_hidden_states,
1463
+ return_dict=return_dict,
1464
+ )
1465
+ hidden_states = model_outputs[0]
1466
+ logits = self.score(hidden_states)
1467
+
1468
+ if input_ids is not None:
1469
+ batch_size = input_ids.shape[0]
1470
+ else:
1471
+ batch_size = inputs_embeds.shape[0]
1472
+
1473
+ if self.config.pad_token_id is None and batch_size != 1:
1474
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1475
+ if self.config.pad_token_id is None:
1476
+ sequence_lengths = -1
1477
+ else:
1478
+ if input_ids is not None:
1479
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1480
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1481
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1482
+ sequence_lengths = sequence_lengths.to(logits.device)
1483
+ else:
1484
+ sequence_lengths = -1
1485
+
1486
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1487
+
1488
+ loss = None
1489
+ if labels is not None:
1490
+ labels = labels.to(logits.device)
1491
+ if self.config.problem_type is None:
1492
+ if self.num_labels == 1:
1493
+ self.config.problem_type = "regression"
1494
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1495
+ self.config.problem_type = "single_label_classification"
1496
+ else:
1497
+ self.config.problem_type = "multi_label_classification"
1498
+
1499
+ if self.config.problem_type == "regression":
1500
+ loss_fct = MSELoss()
1501
+ if self.num_labels == 1:
1502
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1503
+ else:
1504
+ loss = loss_fct(pooled_logits, labels)
1505
+ elif self.config.problem_type == "single_label_classification":
1506
+ loss_fct = CrossEntropyLoss()
1507
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1508
+ elif self.config.problem_type == "multi_label_classification":
1509
+ loss_fct = BCEWithLogitsLoss()
1510
+ loss = loss_fct(pooled_logits, labels)
1511
+ if not return_dict:
1512
+ output = (pooled_logits,) + model_outputs[1:]
1513
+ return ((loss,) + output) if loss is not None else output
1514
+
1515
+ return SequenceClassifierOutputWithPast(
1516
+ loss=loss,
1517
+ logits=pooled_logits,
1518
+ past_key_values=model_outputs.past_key_values,
1519
+ hidden_states=model_outputs.hidden_states,
1520
+ attentions=model_outputs.attentions,
1521
+ )
1522
+
1523
+
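The pooling above selects the logits of the last non-padding token in each row. As a quick illustration of that index trick in isolation (toy tensors, with `0` assumed to be the pad id for this sketch), the ONNX-friendly `argmax`/modulo computation works like this:

```python
import torch

# Toy batch; pad_token_id = 0 is an assumption for this sketch.
pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 6, 8, 2]])          # second row has no padding
logits = torch.randn(2, 5, 4)                        # (batch, seq_len, num_labels)

# First pad position minus one gives the last real token;
# the modulo maps the fully unpadded row (-1) back to the final position.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]

pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(sequence_lengths.tolist())  # [2, 4]
print(pooled_logits.shape)        # torch.Size([2, 4])
```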
1524
+ @add_start_docstrings(
1525
+ """
1526
+ [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1527
+ Named-Entity-Recognition (NER) tasks.
1528
+ """,
1529
+ PHI3_START_DOCSTRING,
1530
+ )
1531
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
1532
+ class Phi3ForTokenClassification(Phi3PreTrainedModel):
1533
+ def __init__(self, config: Phi3Config):
1534
+ super().__init__(config)
1535
+ self.num_labels = config.num_labels
1536
+
1537
+ self.model = Phi3Model(config)
1538
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1539
+ classifier_dropout = config.classifier_dropout
1540
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1541
+ classifier_dropout = config.hidden_dropout
1542
+ else:
1543
+ classifier_dropout = 0.1
1544
+ self.dropout = nn.Dropout(classifier_dropout)
1545
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1546
+
1547
+ # Initialize weights and apply final processing
1548
+ self.post_init()
1549
+
1550
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1551
+ @add_code_sample_docstrings(
1552
+ checkpoint=_CHECKPOINT_FOR_DOC,
1553
+ output_type=TokenClassifierOutput,
1554
+ config_class=_CONFIG_FOR_DOC,
1555
+ )
1556
+ def forward(
1557
+ self,
1558
+ input_ids: Optional[torch.LongTensor] = None,
1559
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1560
+ attention_mask: Optional[torch.Tensor] = None,
1561
+ inputs_embeds: Optional[torch.Tensor] = None,
1562
+ labels: Optional[torch.Tensor] = None,
1563
+ use_cache: Optional[bool] = None,
1564
+ output_attentions: Optional[bool] = None,
1565
+ output_hidden_states: Optional[bool] = None,
1566
+ return_dict: Optional[bool] = None,
1567
+ **deprecated_arguments,
1568
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1569
+ r"""
1570
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1571
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1572
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1573
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1574
+ """
1575
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1576
+
1577
+ model_outputs = self.model(
1578
+ input_ids,
1579
+ past_key_values=past_key_values,
1580
+ attention_mask=attention_mask,
1581
+ inputs_embeds=inputs_embeds,
1582
+ use_cache=use_cache,
1583
+ output_attentions=output_attentions,
1584
+ output_hidden_states=output_hidden_states,
1585
+ return_dict=return_dict,
1586
+ )
1587
+
1588
+ hidden_states = model_outputs[0]
1589
+ hidden_states = self.dropout(hidden_states)
1590
+ logits = self.classifier(hidden_states)
1591
+
1592
+ loss = None
1593
+ if labels is not None:
1594
+ # move labels to correct device to enable model parallelism
1595
+ labels = labels.to(logits.device)
1596
+ batch_size, seq_length = labels.shape
1597
+ loss_fct = CrossEntropyLoss()
1598
+ loss = loss_fct(
1599
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1600
+ )
1601
+
1602
+ if not return_dict:
1603
+ output = (logits,) + model_outputs[2:]
1604
+ return ((loss,) + output) if loss is not None else output
1605
+
1606
+ return TokenClassifierOutput(
1607
+ loss=loss,
1608
+ logits=logits,
1609
+ hidden_states=model_outputs.hidden_states,
1610
+ attentions=model_outputs.attentions,
1611
+ )
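For completeness, the token-classification head above reduces to dropout, a per-token linear projection, and a cross-entropy over the flattened tokens; a self-contained sketch with made-up sizes:

```python
import torch
import torch.nn as nn

# Made-up sizes for illustration only.
hidden_size, num_labels = 16, 3
batch_size, seq_length = 2, 5

hidden_states = torch.randn(batch_size, seq_length, hidden_size)  # decoder output
labels = torch.randint(0, num_labels, (batch_size, seq_length))

dropout = nn.Dropout(0.1)
classifier = nn.Linear(hidden_size, num_labels)

logits = classifier(dropout(hidden_states))            # (batch, seq, num_labels)
loss = nn.CrossEntropyLoss()(                           # flatten tokens, as above
    logits.view(batch_size * seq_length, num_labels),
    labels.view(batch_size * seq_length),
)
print(logits.shape, loss.item())
```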
src/image_decoder/processor.py ADDED
@@ -0,0 +1,221 @@
1
+ import os
2
+ from typing import Dict, List
3
+
4
+ import torch
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from transformers import AutoTokenizer
8
+ from huggingface_hub import snapshot_download
9
+ import numpy as np
10
+
11
+
12
+ def crop_arr(pil_image, max_image_size):
13
+ while min(*pil_image.size) >= 2 * max_image_size:
14
+ pil_image = pil_image.resize(
15
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
16
+ )
17
+
18
+ if max(*pil_image.size) > max_image_size:
19
+ scale = max_image_size / max(*pil_image.size)
20
+ pil_image = pil_image.resize(
21
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
22
+ )
23
+
24
+ if min(*pil_image.size) < 16:
25
+ scale = 16 / min(*pil_image.size)
26
+ pil_image = pil_image.resize(
27
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
28
+ )
29
+
30
+ arr = np.array(pil_image)
31
+ crop_y1 = (arr.shape[0] % 16) // 2
32
+ crop_y2 = arr.shape[0] % 16 - crop_y1
33
+
34
+ crop_x1 = (arr.shape[1] % 16) // 2
35
+ crop_x2 = arr.shape[1] % 16 - crop_x1
36
+
37
+ arr = arr[crop_y1:arr.shape[0] - crop_y2, crop_x1:arr.shape[1] - crop_x2]
38
+ return Image.fromarray(arr)
39
+
40
+
41
+ class OmniGenProcessor:
42
+ def __init__(self, max_image_size: int = 1024):
43
+ self.max_image_size = max_image_size
44
+
45
+ self.image_transform = transforms.Compose([
46
+ transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
49
+ ])
50
+
51
+ self.collator = OmniGenCollator()
52
+ self.separate_collator = OmniGenSeparateCollator()
53
+
54
+ @classmethod
55
+ def from_pretrained(cls, model_name):
56
+ if not os.path.exists(model_name):
57
+ cache_folder = os.getenv('HF_HUB_CACHE')
58
+ model_name = snapshot_download(repo_id=model_name,
59
+ cache_dir=cache_folder,
60
+ allow_patterns="*.json")
61
+ text_tokenizer = AutoTokenizer.from_pretrained(model_name)
62
+
63
+ return cls(text_tokenizer)
64
+
65
+ def process_image(self, image):
66
+ image = Image.open(image).convert('RGB')
67
+ return self.image_transform(image)
68
+
69
+ def __call__(self,
70
+ context_hidden_state: List[torch.Tensor],
71
+ neg_context_hidden_state: List[torch.Tensor],
72
+ height: int = 1024,
73
+ width: int = 1024,
74
+ separate_cfg_input: bool = False,
75
+ ) -> Dict:
76
+
77
+ input_data = []
78
+ for i in range(len(context_hidden_state)):
79
+ cur_context_hidden_state = context_hidden_state[i]
80
+ cur_neg_context_hidden_state = neg_context_hidden_state[i]
81
+
82
+ input_data.append((cur_context_hidden_state, cur_neg_context_hidden_state, [height, width]))
83
+
84
+ if separate_cfg_input:
85
+ return self.separate_collator(input_data)
86
+ return self.collator(input_data)
87
+
88
+
89
+ class OmniGenCollator:
90
+ def __init__(self, pad_token_id=2, llm_pad_token_id=151643, hidden_size=3072):
91
+ self.llm_pad_token_id = llm_pad_token_id
92
+ self.pad_token_id = pad_token_id
93
+ self.hidden_size = hidden_size
94
+
95
+ def create_position(self, attention_mask, num_tokens_for_output_images):
96
+ position_ids = []
97
+ text_length = attention_mask.size(-1)
98
+ img_length = max(num_tokens_for_output_images)
99
+ for mask in attention_mask:
100
+ temp_l = torch.sum(mask)
101
+ temp_position = [0] * (text_length - temp_l) + [i for i in range(temp_l + img_length + 1)] # we add a time embedding into the sequence, so add one more token
102
+ position_ids.append(temp_position)
103
+ return torch.LongTensor(position_ids)
104
+
105
+ def create_connector_position(self, llm_2d_attention_mask):
106
+ position_ids = []
107
+ text_length = llm_2d_attention_mask.size(-1)
108
+ # img_length = max(num_tokens_for_output_images)
109
+ for batch_idx, mask in enumerate(llm_2d_attention_mask):
110
+ temp_l = torch.sum(llm_2d_attention_mask[batch_idx])
111
+ # temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
112
+ temp_position = [0] * (text_length - temp_l) + [i for i in range(temp_l)] # only condition for mllm like qwen
113
+ position_ids.append(temp_position)
114
+ return torch.LongTensor(position_ids)
115
+
116
+ def create_mask(self, attention_mask, num_tokens_for_output_images):
117
+ extended_mask = []
118
+ padding_images = []
119
+ text_length = attention_mask.size(-1)
120
+ img_length = max(num_tokens_for_output_images)
121
+ seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
122
+ inx = 0
123
+ for mask in attention_mask:
124
+ temp_l = torch.sum(mask)
125
+ pad_l = text_length - temp_l
126
+
127
+ temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1)))
128
+
129
+ image_mask = torch.zeros(size=(temp_l + 1, img_length))
130
+ temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
131
+
132
+ image_mask = torch.ones(size=(img_length, temp_l + img_length + 1))
133
+ temp_mask = torch.cat([temp_mask, image_mask], dim=0)
134
+
135
+ if pad_l > 0:
136
+ pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l))
137
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
138
+
139
+ pad_mask = torch.ones(size=(pad_l, seq_len))
140
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
141
+
142
+ true_img_length = num_tokens_for_output_images[inx]
143
+ pad_img_length = img_length - true_img_length
144
+ if pad_img_length > 0:
145
+ temp_mask[:, -pad_img_length:] = 0
146
+ temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
147
+ else:
148
+ temp_padding_imgs = None
149
+
150
+ extended_mask.append(temp_mask.unsqueeze(0))
151
+ padding_images.append(temp_padding_imgs)
152
+ inx += 1
153
+ return torch.cat(extended_mask, dim=0), padding_images
154
+
155
+ def adjust_attention_for_input_images(self, attention_mask, image_sizes):
156
+ for b_inx in image_sizes.keys():
157
+ for start_inx, end_inx in image_sizes[b_inx]:
158
+ attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
159
+
160
+ return attention_mask
161
+
162
+ def pad_input(self, context_hidden_state):
163
+ # pad_token_id = self.llm_pad_token_id # 151642 <|endoftext|> in qwen2.5vl
164
+ max_l = max([x.shape[1] for x in context_hidden_state])
165
+ attention_mask = []
166
+
167
+ for i in range(len(context_hidden_state)):
168
+ temp_hidden = context_hidden_state[i]
169
+ temp_l = temp_hidden.shape[1]
170
+ pad_l = max_l - temp_l
171
+ if pad_l == 0:
172
+ attention_mask.append([1] * max_l)
173
+ else:
174
+ attention_mask.append([0] * pad_l + [1] * temp_l)
175
+
176
+ return torch.LongTensor(attention_mask)
177
+
178
+ def process_mllm_input(self, context_hidden_state, target_img_size):
179
+ num_tokens_for_output_images = []
180
+ for img_size in target_img_size:
181
+ num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16)
182
+
183
+ llm_2d_attention_mask = self.pad_input(context_hidden_state)
184
+ connector_position_ids = self.create_connector_position(llm_2d_attention_mask)
185
+ llm_position_ids = self.create_position(llm_2d_attention_mask, num_tokens_for_output_images)
186
+ llm_attention_mask, _ = self.create_mask(llm_2d_attention_mask, num_tokens_for_output_images)
187
+
188
+ return llm_2d_attention_mask, connector_position_ids, llm_attention_mask, llm_position_ids
189
+
190
+
191
+ class OmniGenSeparateCollator(OmniGenCollator):
192
+ def __call__(self, features):
193
+ context_hidden_state = [f[0] for f in features]
194
+ neg_context_hidden_state = [f[1] for f in features]
195
+ target_img_size = [f[2] for f in features]
196
+
197
+ all_context_hidden_state, all_connector_attention_mask, all_connector_position_ids, all_llm_attention_mask, all_llm_position_ids = [], [], [], [], []
198
+ connector_attention_mask, connector_position_ids, llm_attention_mask, llm_position_ids = self.process_mllm_input(context_hidden_state, target_img_size)
199
+
200
+ all_context_hidden_state.append(context_hidden_state[0])
201
+ all_connector_attention_mask.append(connector_attention_mask)
202
+ all_connector_position_ids.append(connector_position_ids)
203
+ all_llm_attention_mask.append(llm_attention_mask)
204
+ all_llm_position_ids.append(llm_position_ids)
205
+
206
+ if neg_context_hidden_state[0] is not None:
207
+ connector_attention_mask, connector_position_ids, llm_attention_mask, llm_position_ids = self.process_mllm_input(neg_context_hidden_state, target_img_size)
208
+ all_context_hidden_state.append(neg_context_hidden_state[0])
209
+ all_connector_attention_mask.append(connector_attention_mask)
210
+ all_connector_position_ids.append(connector_position_ids)
211
+ all_llm_attention_mask.append(llm_attention_mask)
212
+ all_llm_position_ids.append(llm_position_ids)
213
+
214
+ data = {
215
+ "context_hidden_state": all_context_hidden_state,
216
+ "connector_attention_mask": all_connector_attention_mask,
217
+ "connector_position_ids": all_connector_position_ids,
218
+ "llm_attention_mask": all_llm_attention_mask,
219
+ "llm_position_ids": all_llm_position_ids,
220
+ }
221
+ return data
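The `create_mask` logic above builds a block attention pattern: the MLLM condition tokens (plus one extra time-embedding slot) attend causally to each other, while the output-image tokens attend to everything. A toy reconstruction of that geometry, with made-up lengths, makes the layout easier to see:

```python
import torch

# Made-up lengths: 4 condition tokens, 6 image tokens, plus 1 time-embedding token.
temp_l, img_length = 4, 6
seq_len = temp_l + img_length + 1

# Condition tokens (and the time token) see each other causally...
mask = torch.tril(torch.ones(temp_l + 1, temp_l + 1))
# ...but never the image tokens,
mask = torch.cat([mask, torch.zeros(temp_l + 1, img_length)], dim=-1)
# while every image token attends to all condition tokens and all image tokens.
mask = torch.cat([mask, torch.ones(img_length, seq_len)], dim=0)

print(mask.shape)   # torch.Size([11, 11])
print(mask.int())
```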
src/image_decoder/scheduler.py ADDED
@@ -0,0 +1,194 @@
1
+ from tqdm import tqdm
2
+ from typing import Optional, Dict, Any, Tuple, List
3
+ import gc
4
+
5
+ import torch
6
+ try:
7
+ import torch_npu
8
+ except Exception as e:
9
+ print(e)
10
+ from transformers.cache_utils import DynamicCache
11
+
12
+
13
+ class OmniGenCache(DynamicCache):
14
+ def __init__(self, num_tokens_for_img: int, offload_kv_cache: bool = False) -> None:
15
+ # if not torch.cuda.is_available():
16
+ # # print("No available GPU, offload_kv_cache will be set to False, which will result in large memory usage and long runtime when multiple images are given as input!")
17
+ # # offload_kv_cache = False
18
+ # raise RuntimeError("OffloadedCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!")
19
+ super().__init__()
20
+ self.original_device = []
21
+ self.prefetch_stream = torch.cuda.Stream() if torch.cuda.is_available() else torch_npu.npu.Stream()
22
+ self.num_tokens_for_img = num_tokens_for_img
23
+ self.offload_kv_cache = offload_kv_cache
24
+
25
+ def prefetch_layer(self, layer_idx: int):
26
+ "Starts prefetching the next layer cache"
27
+ if layer_idx < len(self):
28
+ if torch.cuda.is_available():
29
+ with torch.cuda.stream(self.prefetch_stream):
30
+ # Prefetch next layer tensors to GPU
31
+ device = self.original_device[layer_idx]
32
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
33
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
34
+ else:
35
+ with torch_npu.npu.stream(self.prefetch_stream):
36
+ # Prefetch next layer tensors to GPU
37
+ device = self.original_device[layer_idx]
38
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
39
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
40
+
41
+ def evict_previous_layer(self, layer_idx: int):
42
+ "Moves the previous layer cache to the CPU"
43
+ if len(self) > 2:
44
+ # We do it on the default stream so it occurs after all earlier computations on these tensors are done
45
+ if layer_idx == 0:
46
+ prev_layer_idx = -1
47
+ else:
48
+ prev_layer_idx = (layer_idx - 1) % len(self)
49
+ self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
50
+ self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
51
+
52
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
53
+ "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
54
+ if layer_idx < len(self):
55
+ if self.offload_kv_cache:
56
+ # Evict the previous layer if necessary
57
+ if torch.cuda.is_available():
58
+ torch.cuda.current_stream().synchronize()
59
+ else:
60
+ torch_npu.npu.current_stream().synchronize()
61
+ self.evict_previous_layer(layer_idx)
62
+ # Load current layer cache to its original device if not already there
63
+ # self.prefetch_stream.synchronize(original_device)
64
+ if torch.cuda.is_available():
65
+ torch.cuda.synchronize(self.prefetch_stream)
66
+ else:
67
+ torch_npu.npu.synchronize(self.prefetch_stream)
68
+ key_tensor = self.key_cache[layer_idx]
69
+ value_tensor = self.value_cache[layer_idx]
70
+
71
+ # Prefetch the next layer
72
+ self.prefetch_layer((layer_idx + 1) % len(self))
73
+ else:
74
+ key_tensor = self.key_cache[layer_idx]
75
+ value_tensor = self.value_cache[layer_idx]
76
+ return (key_tensor, value_tensor)
77
+ else:
78
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
79
+
80
+ def update(
81
+ self,
82
+ key_states: torch.Tensor,
83
+ value_states: torch.Tensor,
84
+ layer_idx: int,
85
+ cache_kwargs: Optional[Dict[str, Any]] = None,
86
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
87
+ """
88
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
89
+ Parameters:
90
+ key_states (`torch.Tensor`):
91
+ The new key states to cache.
92
+ value_states (`torch.Tensor`):
93
+ The new value states to cache.
94
+ layer_idx (`int`):
95
+ The index of the layer to cache the states for.
96
+ cache_kwargs (`Dict[str, Any]`, `optional`):
97
+ Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
98
+ Return:
99
+ A tuple containing the updated key and value states.
100
+ """
101
+ # Update the cache
102
+ if len(self.key_cache) < layer_idx:
103
+ raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
104
+ elif len(self.key_cache) == layer_idx:
105
+ # only cache the states for condition tokens
106
+ key_states = key_states[..., :-(self.num_tokens_for_img + 1), :]
107
+ value_states = value_states[..., :-(self.num_tokens_for_img + 1), :]
108
+
109
+ # Update the number of seen tokens
110
+ if layer_idx == 0:
111
+ self._seen_tokens += key_states.shape[-2]
112
+
113
+ self.key_cache.append(key_states)
114
+ self.value_cache.append(value_states)
115
+ self.original_device.append(key_states.device)
116
+ if self.offload_kv_cache:
117
+ self.evict_previous_layer(layer_idx)
118
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
119
+ else:
120
+ # only cache the states for condition tokens
121
+ key_tensor, value_tensor = self[layer_idx]
122
+ k = torch.cat([key_tensor, key_states], dim=-2)
123
+ v = torch.cat([value_tensor, value_states], dim=-2)
124
+ return k, v
125
+
126
+
127
+ class OmniGenScheduler:
128
+ def __init__(self, num_steps: int = 50, time_shifting_factor: int = 1):
129
+ self.num_steps = num_steps
130
+ self.time_shift = time_shifting_factor
131
+
132
+ t = torch.linspace(0, 1, num_steps + 1)
133
+ t = t / (t + time_shifting_factor - time_shifting_factor * t)
134
+ self.sigma = t
135
+
136
+ def crop_kv_cache(self, past_key_values, num_tokens_for_img):
137
+ # return
138
+ crop_past_key_values = ()
139
+ for layer_idx in range(len(past_key_values)):
140
+ key_states, value_states = past_key_values[layer_idx][:2]
141
+ crop_past_key_values += ((key_states[..., :-(num_tokens_for_img + 1), :], value_states[..., :-(num_tokens_for_img + 1), :], ),)
142
+ # return crop_past_key_values
143
+ return DynamicCache.from_legacy_cache(crop_past_key_values)
144
+
145
+ def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
146
+ if isinstance(position_ids, list):
147
+ for i in range(len(position_ids)):
148
+ position_ids[i] = position_ids[i][:, -(num_tokens_for_img + 1):]
149
+ else:
150
+ position_ids = position_ids[:, -(num_tokens_for_img + 1):]
151
+ return position_ids
152
+
153
+ def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
154
+ if isinstance(attention_mask, list):
155
+ return [x[..., -(num_tokens_for_img + 1):, :] for x in attention_mask]
156
+ return attention_mask[..., -(num_tokens_for_img + 1):, :]
157
+
158
+ def crop_cache(self, cache, num_tokens_for_img):
159
+ for i in range(len(cache.key_cache)):
160
+ cache.key_cache[i] = cache.key_cache[i][..., :-(num_tokens_for_img + 1), :]
161
+ cache.value_cache[i] = cache.value_cache[i][..., :-(num_tokens_for_img + 1), :]
162
+
163
+ return cache
164
+
165
+ def __call__(self, z, func, model_kwargs, use_kv_cache: bool = True, offload_kv_cache: bool = True, tqdm_disable: bool = False):
166
+
167
+ num_tokens_for_img = z.size(-1) * z.size(-2) // 4
168
+ if isinstance(model_kwargs['llm_input_embeds'], list):
169
+ cache = [OmniGenCache(num_tokens_for_img, offload_kv_cache) for _ in range(len(model_kwargs['llm_input_embeds']))] if use_kv_cache else None
170
+ else:
171
+ cache = OmniGenCache(num_tokens_for_img, offload_kv_cache) if use_kv_cache else None
172
+ for i in tqdm(range(self.num_steps), disable=tqdm_disable):
173
+ timesteps = torch.zeros(size=(len(z), )).to(z.device) + self.sigma[i]
174
+ pred, cache = func(z, timesteps, past_key_values=cache, **model_kwargs)
175
+ sigma_next = self.sigma[i + 1]
176
+ sigma = self.sigma[i]
177
+ z = z + (sigma_next - sigma) * pred
178
+ if i == 0 and use_kv_cache:
179
+ num_tokens_for_img = z.size(-1) * z.size(-2) // 4
180
+ if isinstance(cache, list):
181
+ model_kwargs['llm_input_embeds'] = [None] * len(cache)
182
+ else:
183
+ model_kwargs['llm_input_embeds'] = None
184
+
185
+ model_kwargs['llm_position_ids'] = self.crop_position_ids_for_cache(model_kwargs['llm_position_ids'], num_tokens_for_img)
186
+ model_kwargs['llm_attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['llm_attention_mask'], num_tokens_for_img)
187
+
188
+ del cache
189
+ if torch.cuda.is_available():
190
+ torch.cuda.empty_cache()
191
+ else:
192
+ torch_npu.npu.empty_cache()
193
+ gc.collect()
194
+ return z
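`OmniGenScheduler` is a plain Euler integrator over a time-shifted sigma schedule: `t` is warped by `t / (t + s - s*t)` and the latent is advanced by `(sigma_next - sigma) * pred`. A minimal sketch of that loop, with a dummy velocity function standing in for the image decoder, looks like this:

```python
import torch

num_steps, time_shifting_factor = 4, 1

# Same schedule as OmniGenScheduler.__init__.
t = torch.linspace(0, 1, num_steps + 1)
sigma = t / (t + time_shifting_factor - time_shifting_factor * t)

def dummy_velocity(z, timestep):
    # Stand-in for the image decoder's prediction; not the real model.
    return -z

z = torch.randn(1, 4, 8, 8)  # toy latent
for i in range(num_steps):
    pred = dummy_velocity(z, sigma[i])
    z = z + (sigma[i + 1] - sigma[i]) * pred  # Euler step

print(z.shape)
```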
src/image_decoder/transformer.py ADDED
@@ -0,0 +1,179 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+
5
+ from transformers.modeling_outputs import BaseModelOutputWithPast
6
+ from .modeling_phi3 import Phi3Model
7
+ from transformers.cache_utils import Cache, DynamicCache
8
+ from transformers.utils import logging
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ class Phi3Transformer(Phi3Model):
14
+ """
15
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
16
+ We only modified the attention mask
17
+ Args:
18
+ config: Phi3Config
19
+ """
20
+ def prefetch_layer(self, layer_idx: int, device: torch.device):
21
+ "Starts prefetching the next layer cache"
22
+ with torch.cuda.stream(self.prefetch_stream):
23
+ # Prefetch next layer tensors to GPU
24
+ for name, param in self.layers[layer_idx].named_parameters():
25
+ param.data = param.data.to(device, non_blocking=True)
26
+
27
+ def evict_previous_layer(self, layer_idx: int):
28
+ "Moves the previous layer cache to the CPU"
29
+ prev_layer_idx = layer_idx - 1
30
+ for name, param in self.layers[prev_layer_idx].named_parameters():
31
+ param.data = param.data.to("cpu", non_blocking=True)
32
+
33
+ def get_offload_layer(self, layer_idx: int, device: torch.device):
34
+ # init stream
35
+ if not hasattr(self, "prefetch_stream"):
36
+ self.prefetch_stream = torch.cuda.Stream()
37
+
38
+ # delete previous layer
39
+ torch.cuda.current_stream().synchronize()
40
+ self.evict_previous_layer(layer_idx)
41
+
42
+ # make sure the current layer is ready
43
+ torch.cuda.synchronize(self.prefetch_stream)
44
+
45
+ # load next layer
46
+ self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
47
+
48
+ def forward(
49
+ self,
50
+ input_ids: torch.LongTensor = None,
51
+ attention_mask: Optional[torch.Tensor] = None,
52
+ position_ids: Optional[torch.LongTensor] = None,
53
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
54
+ inputs_embeds: Optional[torch.FloatTensor] = None,
55
+ use_cache: Optional[bool] = None,
56
+ output_attentions: Optional[bool] = None,
57
+ output_hidden_states: Optional[bool] = None,
58
+ return_dict: Optional[bool] = None,
59
+ cache_position: Optional[torch.LongTensor] = None,
60
+ offload_model: Optional[bool] = False,
61
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
62
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
63
+ output_hidden_states = (
64
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
65
+ )
66
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
67
+
68
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
69
+
70
+ if (input_ids is None) ^ (inputs_embeds is not None):
71
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
72
+
73
+ if self.gradient_checkpointing and self.training:
74
+ if use_cache:
75
+ logger.warning_once(
76
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
77
+ )
78
+ use_cache = False
79
+
80
+ # kept for BC (non `Cache` `past_key_values` inputs)
81
+ return_legacy_cache = False
82
+ if use_cache and not isinstance(past_key_values, Cache):
83
+ return_legacy_cache = True
84
+ if past_key_values is None:
85
+ past_key_values = DynamicCache()
86
+ else:
87
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
88
+ logger.warning_once(
89
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
90
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
91
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
92
+ )
93
+
94
+ # if inputs_embeds is None:
95
+ # inputs_embeds = self.embed_tokens(input_ids)
96
+
97
+ # if cache_position is None:
98
+ # past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
99
+ # cache_position = torch.arange(
100
+ # past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
101
+ # )
102
+ # if position_ids is None:
103
+ # position_ids = cache_position.unsqueeze(0)
104
+
105
+ if attention_mask is not None and attention_mask.dim() == 3:
106
+ dtype = inputs_embeds.dtype
107
+ min_dtype = torch.finfo(dtype).min
108
+ attention_mask = (1 - attention_mask) * min_dtype
109
+ attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
110
+ else:
111
+ raise ValueError("A 3-D attention_mask is required, but it was missing or had the wrong shape")
112
+ # causal_mask = self._update_causal_mask(
113
+ # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
114
+ # )
115
+
116
+ hidden_states = inputs_embeds
117
+
118
+ # decoder layers
119
+ all_hidden_states = () if output_hidden_states else None
120
+ all_self_attns = () if output_attentions else None
121
+ next_decoder_cache = None
122
+
123
+ layer_idx = -1
124
+ for decoder_layer in self.layers:
125
+ layer_idx += 1
126
+
127
+ if output_hidden_states:
128
+ all_hidden_states += (hidden_states,)
129
+
130
+ if self.gradient_checkpointing and self.training:
131
+ layer_outputs = self._gradient_checkpointing_func(
132
+ decoder_layer.__call__,
133
+ hidden_states,
134
+ attention_mask,
135
+ position_ids,
136
+ past_key_values,
137
+ output_attentions,
138
+ use_cache,
139
+ cache_position,
140
+ )
141
+ else:
142
+ if offload_model and not self.training:
143
+ self.get_offload_layer(layer_idx, device=inputs_embeds.device)
144
+ layer_outputs = decoder_layer(
145
+ hidden_states,
146
+ attention_mask=attention_mask,
147
+ position_ids=position_ids,
148
+ past_key_value=past_key_values,
149
+ output_attentions=output_attentions,
150
+ use_cache=use_cache,
151
+ cache_position=cache_position,
152
+ )
153
+
154
+ hidden_states = layer_outputs[0]
155
+
156
+ if use_cache:
157
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
158
+
159
+ if output_attentions:
160
+ all_self_attns += (layer_outputs[1],)
161
+
162
+ hidden_states = self.norm(hidden_states)
163
+
164
+ # add hidden states from the last decoder layer
165
+ if output_hidden_states:
166
+ all_hidden_states += (hidden_states,)
167
+
168
+ next_cache = next_decoder_cache if use_cache else None
169
+ if return_legacy_cache:
170
+ next_cache = next_cache.to_legacy_cache()
171
+
172
+ if not return_dict:
173
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
174
+ return BaseModelOutputWithPast(
175
+ last_hidden_state=hidden_states,
176
+ past_key_values=next_cache,
177
+ hidden_states=all_hidden_states,
178
+ attentions=all_self_attns,
179
+ )
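The `forward` above expects a dense 3-D 0/1 attention mask and converts it into an additive bias before the decoder layers run; the conversion on its own (with a toy lower-triangular mask) is:

```python
import torch

# Toy 0/1 mask of shape (batch, seq, seq); fp16 chosen to show the large negative fill.
attention_mask = torch.tril(torch.ones(1, 5, 5))
dtype = torch.float16

min_dtype = torch.finfo(dtype).min
# Allowed positions become 0, masked positions become the dtype minimum,
# and a head dimension is inserted: (batch, 1, seq, seq).
bias = ((1 - attention_mask) * min_dtype).unsqueeze(1).to(dtype)

print(bias.shape)         # torch.Size([1, 1, 5, 5])
print(bias[0, 0, 0, -1])  # tensor(-65504., dtype=torch.float16)
```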
src/mindomni.py ADDED
@@ -0,0 +1,219 @@
1
+ from .mllm import MindOmniMLLM
2
+ from .image_decoder import OmniGen
3
+ import torch.nn as nn
4
+ from .image_decoder import Phi3DecoderLayer, ImageDecoderPipeline, OmniGenProcessor
5
+ import os
6
+ import torch
7
+ from safetensors.torch import load_file
8
+ from typing import Union
9
+ from diffusers.utils import logging
10
+ from diffusers.models import AutoencoderKL
11
+ from transformers import AutoProcessor
12
+ import re
13
+ from qwen_vl_utils import process_vision_info
14
+ try:
15
+ import torch_npu
16
+ except Exception as e:
17
+ print(e)
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class MindOmniConnector(nn.Module):
23
+ def __init__(self, pre_config, post_config, layer_num: int = 2):
24
+ super().__init__()
25
+ connector_decoder = nn.ModuleList(
26
+ [Phi3DecoderLayer(post_config, layer_idx) for layer_idx in range(layer_num)]
27
+ )
28
+ self.connector = nn.ModuleList(
29
+ [nn.Linear(pre_config.hidden_size, post_config.hidden_size)] # qwen2.5vl-7b: 3584
30
+ )
31
+ self.connector.extend(connector_decoder)
32
+
33
+
34
+ class MindOmni:
35
+ def __init__(self, mllm, image_decoder, connector, vae, processor, mllm_processor, device: Union[str, torch.device] = None):
36
+ self.mllm = mllm
37
+ self.image_decoder = image_decoder
38
+ self.connector = connector
39
+ self.vae = vae
40
+ self.processor = processor
41
+ self.mllm_processor = mllm_processor
42
+
43
+ self.vae.to(torch.float32)
44
+ self.device = device
45
+ if device is None:
46
+ if torch.cuda.is_available():
47
+ self.device = torch.device("cuda")
48
+ elif torch_npu.npu.is_available():
49
+ self.device = torch.device("npu")
50
+ elif torch.backends.mps.is_available():
51
+ self.device = torch.device("mps")
52
+ else:
53
+ logger.info("No available GPU detected; falling back to CPU, which may make image generation very slow.")
54
+ self.device = torch.device("cpu")
55
+
56
+ @classmethod
57
+ def from_pretrained(cls, model_path):
58
+ mllm = MindOmniMLLM.from_pretrained(os.path.join(model_path, 'mllm'))
59
+ image_decoder = OmniGen.from_pretrained(os.path.join(model_path, 'image_decoder'))
60
+ connector = MindOmniConnector(mllm.config, image_decoder.llm.config, 2).connector
61
+ connector_state = load_file(os.path.join(model_path, 'connector.safetensors'))
62
+ connector.load_state_dict(connector_state)
63
+ vae = AutoencoderKL.from_pretrained(os.path.join(model_path, "vae"))
64
+ processor = OmniGenProcessor.from_pretrained(os.path.join(model_path, 'image_decoder'))
65
+ mllm_processor = AutoProcessor.from_pretrained(os.path.join(model_path, 'mllm'))
66
+ logger.info("Preparing MindOmni")
67
+ return cls(mllm, image_decoder, connector, vae, processor, mllm_processor)
68
+
69
+ def to(self, device: Union[str, torch.device] = None, dtype: Union[str, torch.dtype] = None):
70
+ if device is not None:
71
+ if isinstance(device, str):
72
+ device = torch.device(device)
73
+ self.mllm.to(device)
74
+ self.image_decoder.to(device)
75
+ self.connector.to(device)
76
+ self.vae.to(device)
77
+ self.device = device
78
+ if dtype is not None:
79
+ self.mllm.to(dtype)
80
+ self.image_decoder.to(dtype)
81
+ self.connector.to(dtype)
82
+
83
+ def eval(self):
84
+ self.mllm.eval()
85
+ self.image_decoder.eval()
86
+ self.connector.eval()
87
+ self.vae.eval()
88
+
89
+ @torch.no_grad()
90
+ def get_mllm_hidden_state(self, user_input, input_images, do_sample, temperature, max_new_tokens, only_understand=False, use_cot=False):
91
+ input_llm_images = input_images
92
+ processor = self.mllm_processor
93
+ model = self.mllm
94
+ if only_understand or not use_cot:
95
+ system_prompt = (
96
+ "You are a helpful assistant."
97
+ )
98
+ else:
99
+ system_prompt = (
100
+ "You are a helpful assistant. When the user requests an image, the assistant "
101
+ "first thinks about the reasoning process in the mind and then provides the user with concise prompt as the answer. "
102
+ "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
103
+ "<think> reasoning process here </think><answer> answer here </answer>."
104
+ )
105
+
106
+ messages = [
107
+ {
108
+ "role": "system",
109
+ "content": [
110
+ {"type": "text", "text": system_prompt},
111
+ ],
112
+ },
113
+ {
114
+ "role": "user",
115
+ "content": [
116
+ {"type": "text", "text": "Generate an image according to the following instructions\n"},
117
+ {"type": "text", "text": user_input},
118
+ ],
119
+ }
120
+ ]
121
+
122
+ if input_llm_images is not None:
123
+ if only_understand:
124
+ assert len(input_llm_images) == 1, "only support single image when multimodal understanding"
125
+ messages[1]['content'][0] = {"type": "image", "image": input_llm_images[0]}
126
+ else:
127
+ user_input = f'<img><|image_1|></img> {user_input}'
128
+ messages[1]['content'][1] = {"type": "text", "text": user_input}
129
+ image_tags = re.findall(r'<\|image_\d+\|>', messages[1]['content'][1]['text'])
130
+ image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
131
+ pattern = r"<img><\|image_\d+\|></img>"
132
+ prompt_chunks = [chunk for chunk in re.split(pattern, messages[1]['content'][1]['text'])]
133
+ assert len(prompt_chunks) == len(input_llm_images) + 1
134
+ new_content = []
135
+ for idx, per_prompt in enumerate(prompt_chunks):
136
+ if idx != len(prompt_chunks) - 1:
137
+ item_text = {"type": "text", "text": per_prompt}
138
+ # resized_height, resized_width = input_images_shape[image_ids[idx] - 1]
139
+ image_path = input_llm_images[image_ids[idx] - 1]
140
+ # item_vit = {"type": "image", "image": image_path, "resized_height": resized_height, "resized_width": resized_width}
141
+ item_vit = {"type": "image", "image": image_path}
142
+ item_tag = {"type": "text", "text": f"<img>{image_tags[idx]}</img>"}
143
+ new_content.append(item_text)
144
+ new_content.append(item_vit)
145
+ new_content.append(item_tag)
146
+ else:
147
+ item_text = {"type": "text", "text": per_prompt}
148
+ new_content.append(item_text)
149
+ messages[1]['content'] = messages[1]['content'][:1] + new_content
150
+
151
+ text = processor.apply_chat_template(
152
+ messages, tokenize=False, add_generation_prompt=True
153
+ )
154
+ image_inputs, video_inputs = process_vision_info(messages)
155
+ inputs = processor(
156
+ text=[text],
157
+ images=image_inputs,
158
+ videos=video_inputs,
159
+ padding=True,
160
+ return_tensors="pt",
161
+ )
162
+ inputs = inputs.to(self.device)
163
+
164
+ if use_cot:
165
+ # Inference: Generation of the output
166
+ temperature = temperature if do_sample else None
167
+ generated_dict = model.generate(**inputs, do_sample=do_sample, temperature=temperature, max_new_tokens=max_new_tokens, output_hidden_states=True, return_dict_in_generate=True)
168
+ generated_ids_trimmed = [
169
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_dict.sequences)
170
+ ]
171
+ output_hidden_state = [hidden_state[-1] for hidden_state in generated_dict.hidden_states]
172
+ context_hidden_state = torch.cat(output_hidden_state, dim=1)
173
+
174
+ output_text = processor.batch_decode(
175
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
176
+ )
177
+
178
+ prompt_ = output_text[0]
179
+
180
+ assistant_content = [
181
+ {
182
+ "role": "assistant",
183
+ "content": [
184
+ {"type": "text", "text": prompt_},
185
+ ],
186
+ }
187
+ ]
188
+
189
+ messages += assistant_content
190
+ else:
191
+ prompt_ = user_input
192
+ context_hidden_state = model(**inputs, output_hidden_states=True).hidden_states[-1]
193
+ return messages, prompt_, context_hidden_state
194
+
195
+ def generate_image(self, height, width, guidance_scale, inference_steps, separate_cfg_infer, offload_model, seed, max_input_image_size,
196
+ text, NEGATIVE_PROMPT, input_llm_images, do_sample, temperature, max_new_tokens, only_understand, use_cot=False):
197
+ gen_pipe = ImageDecoderPipeline(self.vae, self.image_decoder, self.connector, self.processor)
198
+ message, prompt_, context_hidden_state = self.get_mllm_hidden_state(text, input_llm_images, do_sample, temperature, max_new_tokens, only_understand, use_cot=use_cot)
199
+ neg_message, neg_prompt_, neg_context_hidden_state = self.get_mllm_hidden_state(NEGATIVE_PROMPT, None, do_sample, temperature, max_new_tokens, only_understand, use_cot=False)
200
+ print(message)
201
+ output = gen_pipe(
202
+ context_hidden_state=context_hidden_state,
203
+ neg_context_hidden_state=neg_context_hidden_state,
204
+ height=height,
205
+ width=width,
206
+ guidance_scale=guidance_scale,
207
+ num_inference_steps=inference_steps,
208
+ separate_cfg_infer=separate_cfg_infer,
209
+ use_kv_cache=True,
210
+ offload_kv_cache=True,
211
+ offload_model=offload_model,
212
+ seed=seed,
213
+ max_input_image_size=max_input_image_size,
214
+ )
215
+ return output, prompt_
216
+
217
+ def generate_text(self, text, input_llm_images, do_sample, temperature, max_new_tokens, only_understand):
218
+ _, answer, _ = self.get_mllm_hidden_state(text, input_llm_images, do_sample, temperature, max_new_tokens, only_understand=True, use_cot=True)
219
+ return answer
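Putting the pieces together, a minimal end-to-end sketch of this wrapper (the checkpoint path, prompt, and every numeric setting below are placeholders chosen for illustration, not documented defaults):

```python
import torch
from src import MindOmni

model = MindOmni.from_pretrained("path/to/MindOmni")   # placeholder checkpoint path
model.to(device="cuda", dtype=torch.bfloat16)
model.eval()

images, cot_prompt = model.generate_image(
    height=1024, width=1024, guidance_scale=3.0, inference_steps=50,
    separate_cfg_infer=True, offload_model=False, seed=0,
    max_input_image_size=1024,
    text="A cat wearing sunglasses on the beach",
    NEGATIVE_PROMPT="low quality, blurry",
    input_llm_images=None, do_sample=False, temperature=1.0,
    max_new_tokens=512, only_understand=False, use_cot=True,
)
print(cot_prompt)   # reasoning / refined prompt returned by the MLLM
```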
src/mllm.py ADDED
@@ -0,0 +1,245 @@
1
+ import torch
2
+ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel
3
+ from typing import List, Optional, Tuple, Union
4
+ from transformers.modeling_outputs import BaseModelOutputWithPast
5
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import logger
6
+ from transformers.cache_utils import DynamicCache
7
+
8
+
9
+ class MindOmniMLLM_Model(Qwen2_5_VLModel):
10
+
11
+ def forward(
12
+ self,
13
+ input_ids: torch.LongTensor = None,
14
+ attention_mask: Optional[torch.Tensor] = None,
15
+ position_ids: Optional[torch.LongTensor] = None,
16
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
17
+ inputs_embeds: Optional[torch.FloatTensor] = None,
18
+ use_cache: Optional[bool] = None,
19
+ output_attentions: Optional[bool] = None,
20
+ output_hidden_states: Optional[bool] = None,
21
+ return_dict: Optional[bool] = None,
22
+ cache_position: Optional[torch.LongTensor] = None,
23
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
24
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
25
+ output_hidden_states = (
26
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
27
+ )
28
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
29
+
30
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
31
+
32
+ if (input_ids is None) ^ (inputs_embeds is not None):
33
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
34
+
35
+ if self.gradient_checkpointing and self.training:
36
+ if use_cache:
37
+ logger.warning_once(
38
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
39
+ )
40
+ use_cache = False
41
+
42
+ # torch.jit.trace() doesn't support cache objects in the output
43
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
44
+ past_key_values = DynamicCache()
45
+
46
+ if inputs_embeds is None:
47
+ inputs_embeds = self.embed_tokens(input_ids)
48
+
49
+ if cache_position is None:
50
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
51
+ cache_position = torch.arange(
52
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
53
+ )
54
+
55
+ # the hard coded `3` is for temporal, height and width.
56
+ if position_ids is None:
57
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
58
+ elif position_ids.dim() == 2:
59
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
60
+
61
+ causal_mask = self._update_causal_mask(
62
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
63
+ )
64
+ hidden_states = inputs_embeds
65
+
66
+ # create position embeddings to be shared across the decoder layers
67
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
68
+
69
+ # decoder layers
70
+ all_hidden_states = () if output_hidden_states else None
71
+ all_self_attns = () if output_attentions else None
72
+ next_decoder_cache = None
73
+
74
+ for decoder_layer in self.layers:
75
+ if output_hidden_states:
76
+ all_hidden_states += (hidden_states,)
77
+
78
+ if self.gradient_checkpointing and self.training:
79
+ layer_outputs = self._gradient_checkpointing_func(
80
+ decoder_layer.__call__,
81
+ hidden_states,
82
+ causal_mask,
83
+ position_ids,
84
+ past_key_values,
85
+ output_attentions,
86
+ use_cache,
87
+ cache_position,
88
+ position_embeddings,
89
+ )
90
+ else:
91
+ layer_outputs = decoder_layer(
92
+ hidden_states,
93
+ attention_mask=causal_mask,
94
+ position_ids=position_ids,
95
+ past_key_value=past_key_values,
96
+ output_attentions=output_attentions,
97
+ use_cache=use_cache,
98
+ cache_position=cache_position,
99
+ position_embeddings=position_embeddings,
100
+ )
101
+
102
+ hidden_states = layer_outputs[0]
103
+
104
+ if use_cache:
105
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
106
+
107
+ if output_attentions:
108
+ all_self_attns += (layer_outputs[1],)
109
+
110
+ # add hidden states from the last decoder layer before the self.norm
111
+ # import ipdb; ipdb.set_trace()
112
+ if output_hidden_states:
113
+ all_hidden_states += (hidden_states,)
114
+
115
+ hidden_states = self.norm(hidden_states)
116
+
117
+ next_cache = next_decoder_cache if use_cache else None
118
+
119
+ if not return_dict:
120
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
121
+ return BaseModelOutputWithPast(
122
+ last_hidden_state=hidden_states,
123
+ past_key_values=next_cache,
124
+ hidden_states=all_hidden_states,
125
+ attentions=all_self_attns,
126
+ )
127
+
128
+
129
+ class MindOmniMLLM(Qwen2_5_VLForConditionalGeneration):
130
+
131
+ def __init__(self, config):
132
+ super().__init__(config)
133
+ self.model = MindOmniMLLM_Model(config)
134
+
135
+ # @staticmethod
136
+ # def _update_model_kwargs_for_generation(
137
+ # outputs, model_kwargs, past_key_values_field="past_key_values"
138
+ # ):
139
+ # if past_key_values_field in outputs:
140
+ # model_kwargs[past_key_values_field] = outputs[past_key_values_field]
141
+
142
+ # if "attention_mask" in model_kwargs:
143
+ # bs, _ = model_kwargs["attention_mask"].shape
144
+ # new_mask = torch.ones(bs, 1, dtype=model_kwargs["attention_mask"].dtype,
145
+ # device=model_kwargs["attention_mask"].device)
146
+ # model_kwargs["attention_mask"] = torch.cat(
147
+ # [model_kwargs["attention_mask"], new_mask], dim=-1
148
+ # )
149
+ # return model_kwargs
150
+
151
+ # @staticmethod
152
+ # def _sample_token(
153
+ # logits: torch.Tensor,
154
+ # do_sample: bool,
155
+ # logits_processors: LogitsProcessorList,
156
+ # temperature: float,
157
+ # top_p: float,
158
+ # ):
159
+ # """do sample / greedy"""
160
+ # logits = logits_processors(None, logits)
161
+ # if do_sample:
162
+ # # temperature scaling
163
+ # if temperature != 1.0 and temperature > 0:
164
+ # logits = logits / temperature
165
+ # # nucleus
166
+ # if top_p < 1.0:
167
+ # logits = TopPLogitsWarper(top_p=top_p)(None, logits)
168
+ # probs = nn.functional.softmax(logits, dim=-1, dtype=torch.float32)
169
+ # next_token = torch.multinomial(probs, num_samples=1)
170
+ # else: # greedy
171
+ # next_token = torch.argmax(logits, dim=-1, keepdim=True)
172
+ # return next_token
173
+
174
+ # @torch.no_grad()
175
+ # def generate(
176
+ # self,
177
+ # pixel_values: Optional[torch.FloatTensor] = None,
178
+ # input_ids: Optional[torch.LongTensor] = None,
179
+ # attention_mask: Optional[torch.LongTensor] = None,
180
+ # max_new_tokens: int = 64,
181
+ # do_sample: bool = False,
182
+ # temperature: float = 1.0,
183
+ # top_p: float = 0.95,
184
+ # device: Union[str, torch.device] = "cuda",
185
+ # ) -> torch.LongTensor:
186
+
187
+ # assert input_ids is not None
188
+ # eos_token_id = self.config.eos_token_id
189
+
190
+ # generated = [input_ids]
191
+
192
+ # input_ids = input_ids.to(device)
193
+ # if pixel_values is not None:
194
+ # pixel_values = pixel_values.to(device)
195
+ # if attention_mask is None:
196
+ # attention_mask = torch.ones_like(input_ids, dtype=torch.long)
197
+
198
+ # logits_processors = LogitsProcessorList()
199
+ # if temperature != 1.0 and do_sample:
200
+ # logits_processors.append(TemperatureLogitsWarper(temperature))
201
+ # if top_p < 1.0 and do_sample:
202
+ # logits_processors.append(TopPLogitsWarper(top_p=top_p))
203
+
204
+ # # ---- inference loop ---- #
205
+ # model_kwargs = {
206
+ # "attention_mask": attention_mask,
207
+ # "use_cache": True,
208
+ # "past_key_values": None,
209
+ # "cache_position": torch.arange(attention_mask.shape[-1]).to(attention_mask)
210
+ # }
211
+
212
+ # for _ in range(max_new_tokens):
213
+ # model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
214
+
215
+ # outputs = self(
216
+ # input_ids=input_ids,
217
+ # use_cache=True,
218
+ # **model_kwargs,
219
+ # )
220
+
221
+ # next_token = self._sample_token(
222
+ # outputs.logits[:, -1, :],
223
+ # do_sample=do_sample,
224
+ # logits_processors=logits_processors,
225
+ # temperature=temperature,
226
+ # top_p=top_p,
227
+ # ) # (bs, 1)
228
+
229
+ # # append the newly generated token
230
+ # input_ids = next_token
231
+ # generated.append(next_token)
232
+
233
+ # # update the kv cache / attention_mask
234
+ # model_kwargs = self._update_model_kwargs_for_generation(
235
+ # outputs, model_kwargs
236
+ # )
237
+
238
+ # # termination check: stop once every sequence in the batch has produced eos
239
+ # if eos_token_id is not None:
240
+ # if (next_token == eos_token_id).all():
241
+ # break
242
+
243
+ # generated_ids = torch.cat(generated, dim=1)
244
+
245
+ # return generated_ids
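One detail worth isolating from the `forward` above: when no explicit `position_ids` are given, the flat `cache_position` is broadcast into the three rotary axes (temporal, height, width) used by Qwen2.5-VL's multimodal RoPE. In isolation the expansion is just:

```python
import torch

batch_size, seq_len = 2, 6
cache_position = torch.arange(seq_len)

# One copy of the positions per rotary axis (temporal, height, width),
# shared across the batch, as in MindOmniMLLM_Model.forward.
position_ids = cache_position.view(1, 1, -1).expand(3, batch_size, -1)
print(position_ids.shape)  # torch.Size([3, 2, 6])
```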