infgrad commited on
Commit
da9b77d
·
verified ·
1 Parent(s): c6991f5

Upload 4 files

Browse files
scripts/original_stella_jasper_training_codes/run_train_align_image_text_stage4.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf8
2
+ from PIL import ImageFile
3
+
4
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
5
+ import copy
6
+ import json
7
+ import os
8
+ import sys
9
+ import yaml
10
+ import torch
11
+ import shutil
12
+ import math
13
+ import random
14
+ import tarfile
15
+ import io
16
+ import accelerate
17
+ from loguru import logger
18
+ from torch.utils.data import DataLoader, Dataset
19
+ from transformers import get_scheduler, SiglipImageProcessor
20
+ from accelerate import Accelerator
21
+ from accelerate.utils import set_seed, ProjectConfiguration
22
+ from tqdm import tqdm
23
+ from typing import List, Union
24
+ from os.path import join
25
+ import torch.nn.functional as F
26
+ from jasper_model.modeling_jasper_vl import JasperVL
27
+ from jasper_model.tokenization_qwen import Qwen2TokenizerFast
28
+ from jasper_model.configuration_jasper_vl import JasperVLConfig
29
+ from PIL import Image
30
+
31
+
32
class JasperVLDataset_TAR(Dataset):
    """Image-text dataset that reads (caption-json, image-bytes) pairs directly from tar archives.

    ``tar_names_list[i]`` holds, for archive ``file_path_list[i]``, a list of dicts
    with the tar member names ``{"text_name": ..., "img_name": ...}``.
    Sample order is interleaved round-robin across windows of 11 archives so that
    neighbouring indices come from different source archives.
    NOTE(review): uses the module-level ``accelerator`` for rank-0 logging, so the
    dataset must be constructed after the Accelerator.
    """

    def __init__(self, file_path_list: Union[List[str], str], tar_names_list: List[str]):
        self.file_path_list = file_path_list
        # One persistent open tarfile handle per archive for the whole run.
        self.tar_fr_list = [tarfile.open(file_path) for file_path in file_path_list]
        self.tar_names_list = tar_names_list
        self.num_data_of_tar = [len(i) for i in self.tar_names_list]
        self.num_all_data = sum(self.num_data_of_tar)

        # Build a flat index of (archive_index, sample_index) pairs by popping one
        # sample at a time from each archive in a window of 11, until the window
        # is drained, then move to the next window.
        self.all_ids = []
        ids_list = [list(range(i)) for i in self.num_data_of_tar]
        for start in range(0, len(ids_list), 11):
            end = start + 11
            if end > len(ids_list):
                end = len(ids_list)
            while True:
                pre_num = len(self.all_ids)
                for file_idx in range(start, end):
                    if ids_list[file_idx]:
                        self.all_ids.append((file_idx, ids_list[file_idx].pop()))
                # No archive in this window yielded a sample -> window exhausted.
                if len(self.all_ids) == pre_num:
                    break
        assert len(self.all_ids) == self.num_all_data
        # Cumulative sample counts per archive (kept for compatibility; not used below).
        self.accumulation_numbers = [sum(self.num_data_of_tar[:idx + 1]) for idx in range(len(self.num_data_of_tar))]
        if accelerator.is_main_process:
            logger.info(f"file_path_list:{file_path_list}")
            logger.info(f"num_data_of_tar:{self.num_data_of_tar}")
            logger.info(f"number of data:{self.num_all_data}")

    def __len__(self):
        return self.num_all_data

    def __getitem__(self, item):

        file_idx, item = self.all_ids[item]
        file_path = self.file_path_list[file_idx]
        tar_fr = self.tar_fr_list[file_idx]
        text_item = json.loads(
            tar_fr.extractfile(self.tar_names_list[file_idx][item]["text_name"]).read()
        )
        img_bytes = tar_fr.extractfile(self.tar_names_list[file_idx][item]["img_name"]).read()
        # Choose the parsing scheme based on which corpus the file path belongs to.
        if "DocStruct4M_struct_aware_parse" in file_path:
            user_text = text_item["conversations"][0]["value"]
            assistant_text = text_item["conversations"][1]["value"]
            idx = user_text.find("<doc>")
            if idx == -1:
                user_text = ""
            else:
                user_text = user_text[idx + 5:]
            # [:-6] strips the trailing "</doc>" tag from the assistant text.
            return {"text": user_text + assistant_text[:-6], "img_bytes": img_bytes}
        else:
            return {"text": text_item["conversations"][1]["value"], "img_bytes": img_bytes}
86
+
87
+
88
+ # modelscope download --dataset 'BAAI/Infinity-MM' --include 'stage2/DocStruct4M/DocStruct4M_struct_aware_parse*' --local_dir infinity_mm
89
+ # modelscope download --dataset 'BAAI/Infinity-MM' --include 'stage2/llava-onevision-mid-stage/synthdog_en_100k--synthdog_en_processed_new/*.tar' --local_dir infinity_mm
90
+ # modelscope download --dataset 'BAAI/Infinity-MM' --include 'stage2/MMC-Alignment/MMC-Alignment-mmc_chart_text_alignment_arxiv_text/*.tar' --local_dir infinity_mm
91
def collate_fn(batch):
    """Collate (text, image-bytes) samples into teacher/student model inputs.

    Returns a dict with:
      * ``teacher_ipt`` – tokenized raw captions (the teacher sees the text).
      * ``student_ipt`` – image-placeholder token ids plus ``pixel_values``
        (the student sees only the image).

    If the image processor fails on the whole batch, broken images are dropped
    and the batch is refilled by repeating the healthy samples, so the batch
    size stays constant. Raises the original processor error only when every
    image in the batch is unusable.

    Depends on the module-level ``processor``, ``tokenizer``, ``model_conf``,
    ``padding`` and ``max_length`` created in the training script.
    """
    texts = [item["text"] for item in batch]
    images = [Image.open(io.BytesIO(item["img_bytes"])).convert("RGB") for item in batch]

    try:
        pixel_values = processor(
            images=images,
            return_tensors="pt"
        )["pixel_values"].bfloat16()
    except Exception as e:
        logger.error(f"转换成pixel_values失败:{e}, 会选取一些重复的数据进行替代")
        # Probe each image individually to find the ones the processor accepts.
        normal_ids = []
        for idx, img in enumerate(images):
            try:
                _ = processor(images=[img], return_tensors="pt")
                normal_ids.append(idx)
            # FIX: was a bare `except:` which also swallowed KeyboardInterrupt /
            # SystemExit; only processor failures should be skipped here.
            except Exception:
                continue
        if not normal_ids:
            # Every image in the batch is broken; re-raise the original error.
            raise
        # Refill the batch to its original size by cycling the healthy samples.
        normal_texts, norm_images = [], []
        while True:
            for idx in normal_ids:
                normal_texts.append(copy.deepcopy(texts[idx]))
                norm_images.append(copy.deepcopy(images[idx]))
                if len(normal_texts) == len(batch):
                    break
            if len(normal_texts) == len(batch):
                break
        # Rebind and re-run the processor on the sanitized batch.
        texts, images = normal_texts, norm_images
        pixel_values = processor(
            images=images,
            return_tensors="pt"
        )["pixel_values"].bfloat16()
    teacher_ipt = tokenizer(texts, padding=padding, truncation=True, max_length=max_length, return_tensors="pt")
    student_ipt = tokenizer(
        # +2 accounts for the image start token and end token.
        ["<|jasper_img_token|>" * (model_conf.num_img_tokens + 2)] * len(batch),
        padding="longest", return_tensors="pt"
    )
    student_ipt["pixel_values"] = pixel_values
    return {"teacher_ipt": teacher_ipt, "student_ipt": student_ipt}
142
+
143
+
144
def save_model():
    """Write a complete sentence-transformers style checkpoint to
    ``output_dir/step_{completed_steps}``.

    Saves the unwrapped model weights and the image processor, ships the custom
    modelling source files, patches ``config.json`` with ``auto_map`` entries,
    writes ``modules.json``, and copies every tokenizer sidecar file from
    ``model_dir``. All ranks synchronize first; only the main process writes.
    """
    checkpoint_dir = join(output_dir, f"step_{completed_steps}")
    # accelerator.save_state(checkpoint_dir, safe_serialization=True)
    accelerator.wait_for_everyone()
    if not accelerator.is_main_process:
        return
    logger.info(f"保存模型{checkpoint_dir}")
    accelerator.unwrap_model(model).save_pretrained(checkpoint_dir, max_shard_size="32GB", safe_serialization=True)
    # torch.save(accelerator.unwrap_model(optimizer).state_dict(), join(checkpoint_dir, "optimizer.bin"))
    processor.save_pretrained(checkpoint_dir)
    # tokenizer.save_pretrained(checkpoint_dir)
    # Ship the custom model source files next to the weights.
    for code_file in ("configuration_jasper_vl.py", "modeling_jasper_vl.py", "tokenization_qwen.py"):
        shutil.copy("./jasper_model/" + code_file, join(checkpoint_dir, code_file))
    # Patch config.json so AutoModel/AutoConfig resolve to the shipped code.
    config_path = join(checkpoint_dir, "config.json")
    with open(config_path, "r", encoding="utf8") as fr:
        config = json.load(fr)
    config.pop("_name_or_path", None)
    config["auto_map"] = {
        "AutoModel": "modeling_jasper_vl.JasperVL",
        "AutoConfig": "configuration_jasper_vl.JasperVLConfig",
    }
    with open(config_path, "w", encoding="utf8") as fw:
        json.dump(config, fw, ensure_ascii=False, indent=1)

    # modules.json for the sentence-transformers loader.
    modules = [
        {
            "idx": 0,
            "name": "0",
            "path": "",
            "type": "sentence_transformers.models.Transformer"
        }
    ]
    with open(os.path.join(checkpoint_dir, "modules.json"), "w", encoding="utf8") as fw:
        json.dump(modules, fw, ensure_ascii=False, indent=1)
    # Tokenizer / sentence-transformers sidecar files copied verbatim.
    for side_file in (
        "added_tokens.json",
        "config_sentence_transformers.json",
        "merges.txt",
        "sentence_bert_config.json",
        "special_tokens_map.json",
        "tokenizer_config.json",
        "tokenizer.json",
        "vocab.json",
    ):
        shutil.copy(join(model_dir, side_file), join(checkpoint_dir, side_file))
195
+
196
+
197
def get_score_diff(vectors):
    """Return all pairwise differences between the off-diagonal similarity scores.

    ``vectors`` is an (n, d) matrix. The n*(n-1)/2 upper-triangular entries of
    ``vectors @ vectors.T`` form a score vector ``s``; the result is the vector
    of ``s[j] - s[i]`` for every pair i < j.
    """
    sim = vectors @ vectors.T
    pair_mask = torch.triu(torch.ones_like(sim), diagonal=1).bool()
    pair_scores = sim[pair_mask]
    # Broadcast row-vector minus column-vector: entry [i, j] = s[j] - s[i].
    diff_matrix = pair_scores.unsqueeze(0) - pair_scores.unsqueeze(1)
    diff_mask = torch.triu(torch.ones_like(diff_matrix), diagonal=1).bool()
    return diff_matrix[diff_mask]
203
+
204
+
205
# Stage-4 entry point: align the image branch (student) of JasperVL with its
# frozen text branch (teacher) using cosine, similarity-MSE and rank losses.
if __name__ == "__main__":
    # Read the YAML configuration whose path is passed as argv[1].
    with open(sys.argv[1].strip(), "r", encoding="utf8") as fr:
        conf = yaml.safe_load(fr)
    model_dir = conf["model_path_or_name"]
    max_length = conf["max_length"]
    resume_model_dir = conf["resume_model_dir"]
    output_dir = conf["output_dir"]
    save_steps = conf["save_steps"]
    batch_size = conf["batch_size"]
    project_name = conf["project_name"]
    log_with = conf["log_with"]
    log_init_kwargs = conf["log_init_kwargs"]
    file_path_list_path = conf["file_path_list_path"]
    print_debug_info_prob = conf["print_debug_info_prob"]
    gradient_accumulation_steps = conf["gradient_accumulation_steps"]
    continue_train = conf["continue_train"]
    num_train_epochs = conf["num_train_epochs"]
    lr_scheduler_type = conf["lr_scheduler_type"]
    mse_loss_scale = conf["mse_loss_scale"]
    cosine_loss_scale = conf["cosine_loss_scale"]
    padding = conf["padding"]
    rank_margin = conf["rank_margin"]
    rank_loss_scale = conf["rank_loss_scale"]
    start_index, end_index = conf["start_index"], conf["end_index"]
    scheduler_kwargs = conf.get("scheduler_kwargs", {})

    seed = conf["seed"]
    # Initialize the accelerator. find_unused_parameters must be off when
    # gradient checkpointing is on (DDP conflicts with re-entrant checkpoints).
    accelerator = Accelerator(
        project_config=ProjectConfiguration(
            project_dir=output_dir,
            logging_dir=join(output_dir, "logs"),
        ),
        gradient_accumulation_steps=gradient_accumulation_steps,
        log_with=log_with,
        kwargs_handlers=[
            accelerate.DistributedDataParallelKwargs(find_unused_parameters=not conf["gradient_checkpointing"])]
    )

    # Create output_dir, attach a file sink to the logger and archive the config.
    with accelerator.main_process_first():
        if accelerator.is_main_process:
            os.makedirs(output_dir, exist_ok=True)
            os.makedirs(join(output_dir, "logs/wandb_logs"), exist_ok=True)
            logger.add(
                join(output_dir, "train_logs.txt"),
                level="DEBUG",
                compression="zip",
                rotation="500 MB",
                # format="{message}"
            )
            shutil.copy(sys.argv[1].strip(), join(output_dir, "train_config.yml"))

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info(f"accelerator.state:{accelerator.state}")

    # Seed all processes identically.
    set_seed(seed=seed)
    # Load model and tokenizer.
    processor = SiglipImageProcessor.from_pretrained(model_dir)
    model_conf = JasperVLConfig.from_pretrained(model_dir)
    model = JasperVL.from_pretrained(model_dir, is_text_encoder=False)
    tokenizer = Qwen2TokenizerFast.from_pretrained(model_dir, padding_side="right")

    # Freeze the language backbone and the vector projection heads; only the
    # remaining (vision-side) parameters are trained in this stage.
    for k, v in model.named_parameters():
        if k.startswith("model.") or k.startswith("vector_linear_"):
            v.requires_grad = False
    # Training the last three special tokens:
    ## seemed to have no effect, so left disabled.
    # model.get_input_embeddings().weight.data[-1].requires_grad = True
    # model.get_input_embeddings().weight.data[-3].requires_grad = True

    if accelerator.is_main_process:
        logger.debug("参数冻结情况")
        for k, v in model.named_parameters():
            logger.debug(f"{k}:{v.shape, v.requires_grad}")
    if conf["gradient_checkpointing"]:
        model.gradient_checkpointing_enable()

    # Load the tar file list and the tar member-name index, sliced to this job's
    # [start_index, end_index) shard.
    with open(file_path_list_path, "r", encoding="utf8") as fr:
        file_path_list = json.load(fr)[start_index:end_index]
    with open(conf["tar_names_path"], "r", encoding="utf8") as fr:
        tar_names_list = json.load(fr)[start_index:end_index]
    train_dataset = JasperVLDataset_TAR(file_path_list=file_path_list, tar_names_list=tar_names_list)
    train_dataloader = DataLoader(
        dataset=train_dataset,
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=1,  # >1 raises errors (not debugged further), keep at 1
        drop_last=True,
        # pin_memory=True,
        # pin_memory_device="cuda",
        # prefetch_factor=4,
    )
    # Restore previous training state happens further below (continue_train).
    accelerator.wait_for_everyone()
    # Initialize experiment tracking.
    if "wandb" in log_init_kwargs:
        log_init_kwargs["wandb"]["dir"] = join(output_dir, "logs/wandb_logs")
        log_init_kwargs["wandb"]["config"] = {k: json.dumps(v, ensure_ascii=False) for k, v in conf.items()}
    accelerator.init_trackers(
        project_name=project_name,
        init_kwargs=log_init_kwargs
    )
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=conf["learning_rate"])
    # if os.path.exists(join(model_path_or_name, "optimizer.bin")):
    #     logger.info("加载优化器权重")
    #     optimizer.load_state_dict(torch.load(join(model_path_or_name, "optimizer.bin"), weights_only=False, map_location="cpu"))
    # LR scheduler: warmup may be given as a fraction (float) or absolute steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    max_train_steps = num_update_steps_per_epoch * num_train_epochs
    logger.info(f"max_train_steps:{max_train_steps}")
    if isinstance(conf["num_warmup_steps"], float):
        num_warmup_steps = int(max_train_steps * conf["num_warmup_steps"])
    else:
        num_warmup_steps = conf["num_warmup_steps"]
    lr_scheduler = get_scheduler(
        name=lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=conf.get("max_train_steps") if conf.get("max_train_steps", -1) > 0 else max_train_steps,
        scheduler_specific_kwargs=scheduler_kwargs,
    )
    logger.debug(f"before prepare, len(train_dataloader): {len(train_dataloader)}")
    # Hand everything to accelerate (wraps model in DDP, shards the dataloader).
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    logger.debug(f"after prepare, len(train_dataloader): {len(train_dataloader)}")
    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    ## PS: multi-node/multi-GPU detail — the earlier computation ignored the number
    ## of processes; after prepare, len(train_dataloader) shrinks, so the counts
    ## below are per-card.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
    logger.debug(f"max_train_steps for each card:{max_train_steps}")
    starting_epoch, completed_steps = 0, 0

    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)

    # Resume from a "step_N" checkpoint directory if requested.
    if continue_train:
        logger.info(f"Continue train from {model_dir}")
        accelerator.load_state(resume_model_dir)
        resume_step = int(os.path.basename(resume_model_dir).replace("step_", ""))
        completed_steps = resume_step
        starting_epoch = resume_step // num_update_steps_per_epoch
        # Steps already done inside the resumed epoch.
        resume_step -= starting_epoch * num_update_steps_per_epoch
        progress_bar.update(completed_steps)
    # Start training.
    for epoch in range(starting_epoch, num_train_epochs):
        model.train()
        # Use `skip_first_batches` to skip the batches when resuming from ckpt.
        if continue_train and epoch == starting_epoch:
            # We need to skip steps until we reach the resumed step.
            active_dataloader = accelerator.skip_first_batches(
                train_dataloader,
                resume_step * gradient_accumulation_steps
            )
        else:
            # After the first iteration though, we need to go back to the original dataloader.
            active_dataloader = train_dataloader
        logger.debug(f"len(active_dataloader): {len(active_dataloader)}")

        for batch in active_dataloader:
            # Teacher forward pass (text branch) without gradients; its vectors
            # are the distillation targets for this batch.
            with torch.no_grad():
                model.eval()
                all_teacher_vectors = model(**batch["teacher_ipt"])["all_vectors"]
                all_teacher_vectors = [F.normalize(vector.float(), p=2, dim=-1) for vector in all_teacher_vectors]
                # The highest-dimensional vector (index 0) serves as the label.
                target_sim_values = torch.matmul(all_teacher_vectors[0],
                                                 all_teacher_vectors[0].T)
                rank_label = torch.where(get_score_diff(all_teacher_vectors[0]) < 0, 1, -1)
            model.train()
            with accelerator.accumulate(model):
                # Student forward pass (image branch).
                all_student_vectors = model(**batch["student_ipt"])["all_vectors"]
                all_student_vectors = [F.normalize(vector.float(), p=2, dim=-1) for vector in all_student_vectors]
                cosine_loss_list, sim_value_loss_list, rank_loss_list = [], [], []
                for teacher_vectors, student_vectors in zip(all_teacher_vectors, all_student_vectors):
                    # Cosine loss: 1 - mean cosine similarity (vectors are unit-norm).
                    cosine_loss_list.append(
                        (1 - (student_vectors * teacher_vectors).sum(axis=1).mean()) * cosine_loss_scale
                    )
                    # Similarity-value (MSE) loss between student and teacher gram matrices.
                    sim_value_loss_list.append(
                        F.mse_loss(
                            input=torch.matmul(student_vectors, student_vectors.T),
                            target=target_sim_values,
                        ) * mse_loss_scale
                    )
                    # print(sim_value_loss_list)
                    # Ranking (margin) loss on pairwise score differences.
                    rank_loss_list.append(
                        F.relu(get_score_diff(student_vectors) * rank_label + rank_margin).mean() * rank_loss_scale
                    )
                cosine_loss = sum(cosine_loss_list) / len(cosine_loss_list)
                sim_value_loss = sum(sim_value_loss_list) / len(sim_value_loss_list)
                rank_loss = sum(rank_loss_list) / len(rank_loss_list)
                loss = cosine_loss + sim_value_loss + rank_loss

                ########################## debug info #######################################################
                # Dump one random sample's tokenization at step 10 and then with
                # probability print_debug_info_prob per step.
                if accelerator.is_main_process and (completed_steps == 10 or random.random() < print_debug_info_prob):
                    debug_index = random.randint(0, batch_size - 1)

                    teacher_input_ids = batch["teacher_ipt"]["input_ids"].cpu().numpy()
                    teacher_attention_mask = batch["teacher_ipt"]["attention_mask"].cpu().numpy()

                    for debug_k, debug_v in batch["teacher_ipt"].items():
                        logger.debug(f"teacher_ipt_{debug_k}.shape: {debug_v.shape}")
                    logger.debug(f"teacher input_ids: {teacher_input_ids[debug_index].tolist()}")
                    logger.debug(f"teacher input_tokens: {tokenizer.decode(teacher_input_ids[debug_index])}")
                    logger.debug(f"teacher attention_mask: {teacher_attention_mask[debug_index].tolist()}")

                    student_input_ids = batch["student_ipt"]["input_ids"].cpu().numpy()
                    student_attention_mask = batch["student_ipt"]["attention_mask"].cpu().numpy()

                    for debug_k, debug_v in batch["student_ipt"].items():
                        logger.debug(f"student_ipt_{debug_k}.shape: {debug_v.shape}")
                    logger.debug(f"student input_ids: {student_input_ids[debug_index].tolist()}")
                    logger.debug(f"student input_tokens: {tokenizer.decode(student_input_ids[debug_index])}")
                    logger.debug(f"student attention_mask: {student_attention_mask[debug_index].tolist()}")

                    logger.debug(f"teacher_vectors.shape: {teacher_vectors.shape}")
                    logger.debug(f"student_vectors.shape: {student_vectors.shape}")
                ###############################################################################################

                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            # Only count a step when gradients were actually synchronized
            # (i.e. once per gradient_accumulation_steps micro-batches).
            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1
                # if completed_steps == 15:
                #     save_model()
                if completed_steps % save_steps == 0 and completed_steps > 0:
                    save_model()
                # Log losses and current learning rate.
                if accelerator.is_main_process:
                    curr_lr = float(lr_scheduler.get_last_lr()[-1])
                    logger.info(
                        f"epoch-{epoch},completed_steps-{completed_steps},lr:{curr_lr},cosine_loss:{cosine_loss.item()},sim_value_loss:{sim_value_loss.item()},rank_loss:{rank_loss.item()}"
                    )
                    accelerator.log(
                        {
                            "cosine_loss": cosine_loss.item(),
                            "sim_value_loss": sim_value_loss.item(),
                            "rank_loss": rank_loss.item(),
                            "lr": curr_lr
                        },
                        step=completed_steps
                    )
        # if accelerator.is_main_process:
        #     print("model.vs_token_emb[:,:,:4]", model.vs_token_emb[:, :, :4])
        #     print("model.ve_token_emb[:,:,:4]", model.ve_token_emb[:, :, :4])
    # Save the model once more after training finishes.
    save_model()
    accelerator.end_training()
scripts/original_stella_jasper_training_codes/run_train_distill_stage1.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf8
2
+ import json
3
+ import os
4
+ import sys
5
+ import yaml
6
+ import torch
7
+ import shutil
8
+ import math
9
+ import random
10
+ import lmdb
11
+ import pickle
12
+ import accelerate
13
+ from loguru import logger
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from transformers import get_scheduler
16
+ from accelerate import Accelerator
17
+ from accelerate.utils import set_seed, ProjectConfiguration
18
+ from tqdm import tqdm
19
+ from typing import List, Union
20
+ from os.path import join
21
+ import torch.nn.functional as F
22
+ from safetensors.torch import load_file, save_file
23
+
24
+
25
class JasperDataset_LMDB_RANDOM_ACCESS(Dataset):
    """Random-access distillation dataset backed by one or more LMDB databases.

    Accepts either a directory (scanned for ``*-lock`` files to discover the
    sibling LMDB database files) or an explicit list of LMDB file paths.
    Each record is a pickled dict keyed by its integer index; its ``extra``
    field is a JSON string whose ``prompt_student`` value is prepended to the
    record's text.

    NOTE(review): ``pickle.loads`` on the LMDB payloads assumes the databases
    are trusted local artifacts — never point this at untrusted data.
    NOTE(review): uses the module-level ``accelerator`` for rank-0 logging.
    """

    def __init__(self, file_path_list_or_dir: Union[List[str], str]):
        if isinstance(file_path_list_or_dir, str):
            # A directory: every "<name>-lock" file marks an LMDB db "<name>".
            file_path_list = []
            for name in os.listdir(file_path_list_or_dir):
                if not name.endswith('-lock'):
                    continue
                file_path_list.append(join(file_path_list_or_dir, name[:-5]))
        else:
            file_path_list = file_path_list_or_dir
        # Sort, then shuffle with a fixed seed so every rank sees the same order.
        file_path_list.sort()
        random.seed(42)
        random.shuffle(file_path_list)
        # file_path_list = file_path_list[:20]
        self.lmdb_env_list = [
            lmdb.open(file_path, readonly=True, readahead=False, subdir=False, lock=False)
            for file_path in file_path_list
        ]
        # One long-lived read transaction per environment.
        self.lmdb_txn_list = [lmdb_env.begin(write=False, buffers=True) for lmdb_env in self.lmdb_env_list]
        self.num_data_of_env = [lmdb_env.stat()["entries"] for lmdb_env in self.lmdb_env_list]
        self.num_all_data = sum(self.num_data_of_env)
        # Cumulative entry counts, used to map a global index to (env, local index).
        self.accumulation_numbers = [sum(self.num_data_of_env[:idx + 1]) for idx in range(len(self.num_data_of_env))]
        if accelerator.is_main_process:
            logger.info(f"file_path_list:{file_path_list}")
            logger.info(f"number of data:{self.num_all_data}")

    def __len__(self):
        return self.num_all_data

    def __getitem__(self, item):

        # print("accelerator.local_process_index,item", accelerator.local_process_index, item)
        # Locate which LMDB environment the global index falls into.
        for env_idx, accum_num in enumerate(self.accumulation_numbers):
            if item < accum_num:
                break
        txn = self.lmdb_txn_list[env_idx]
        # Convert the global index to an index local to the chosen environment.
        item -= self.accumulation_numbers[env_idx - 1] if env_idx > 0 else 0
        data_item = pickle.loads(bytes(txn.get(f"{item}".encode())))
        # FIX: the original `text, extra = data_item["text"], json.loads(...)`
        # bound an unused local `text`; only `extra` is needed here.
        extra = json.loads(data_item["extra"])
        # Prepend the student prompt to the raw text.
        data_item["text"] = extra["prompt_student"] + data_item["text"]
        return data_item
68
+
69
+
70
def collate_fn_jasper_text(batch):
    """Tokenize a batch of texts and attach the concatenated teacher vectors.

    The module-level ``teacher_vector_cols`` names the per-item vector fields to
    concatenate into one teacher vector per sample. When more than one teacher
    column is concatenated, the result is L2-normalized so the student distills
    against unit vectors.

    :param batch: List[data_set[i]]
    :return: tokenizer encoding dict with an extra ``teacher_vectors`` tensor
    """
    all_texts = [item["text"] for item in batch]
    teacher_vectors = torch.tensor(
        [
            [value for col in teacher_vector_cols for value in item[col]]
            for item in batch
        ]
    )
    if len(teacher_vector_cols) > 1:
        # BUG FIX: was `F.normalize(teacher_vector_cols, p=2, dim=-1)`, which
        # passed the list of column *names* to normalize (a TypeError) and
        # discarded the actual vectors. Normalize the concatenated vectors.
        teacher_vectors = F.normalize(teacher_vectors, p=2, dim=-1)
    ipt = tokenizer(all_texts, padding=padding, truncation=True, max_length=max_length, return_tensors="pt")
    ipt["teacher_vectors"] = teacher_vectors
    return ipt
88
+
89
+
90
def save_model():
    """Save the current checkpoint in sentence-transformers layout.

    Writes model weights, the pooling and modules configs, copies the tokenizer
    sidecar files from ``model_dir``, and splices the frozen stella dense
    projection weights into ``model.safetensors``. All ranks synchronize first;
    only the main process writes to disk.
    """
    checkpoint_dir = join(output_dir, f"step_{completed_steps}")
    # accelerator.save_state(checkpoint_dir, safe_serialization=True)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info(f"保存模型{checkpoint_dir}")
        accelerator.unwrap_model(model).save_pretrained(checkpoint_dir, max_shard_size="32GB", safe_serialization=True)
        # Ship the custom model source files next to the weights.
        shutil.copy("./jasper_model/modeling_qwen.py", join(checkpoint_dir, "modeling_qwen.py"))
        shutil.copy("./jasper_model/tokenization_qwen.py", join(checkpoint_dir, "tokenization_qwen.py"))
        # Patch config.json so AutoModel resolves to the shipped code.
        with open(join(checkpoint_dir, "config.json"), "r", encoding="utf8") as fr:
            config = json.load(fr)
        # ROBUSTNESS FIX: `_name_or_path` is not always present; the unguarded
        # pop raised KeyError. The stage-4 script already guards this key.
        config.pop("_name_or_path", None)
        config["auto_map"] = {"AutoModel": "modeling_qwen.JasperTextStella_1_5"}
        with open(join(checkpoint_dir, "config.json"), "w", encoding="utf8") as fw:
            json.dump(config, fw, ensure_ascii=False, indent=1)

        # Pooling config: CLS-token pooling, prompt excluded from pooling.
        os.makedirs(join(checkpoint_dir, "1_Pooling"), exist_ok=True)
        config = {
            "word_embedding_dimension": 4096,
            "pooling_mode_cls_token": True,
            "pooling_mode_mean_tokens": False,
            "pooling_mode_max_tokens": False,
            "pooling_mode_mean_sqrt_len_tokens": False,
            "pooling_mode_weightedmean_tokens": False,
            "pooling_mode_lasttoken": False,
            "include_prompt": False
        }
        with open(join(checkpoint_dir, "1_Pooling/config.json"), "w", encoding="utf8") as fw:
            json.dump(config, fw, ensure_ascii=False, indent=1)
        ## modules.json for the sentence-transformers loader
        with open(os.path.join(checkpoint_dir, "modules.json"), "w", encoding="utf8") as fw:
            json.dump(
                [
                    {
                        "idx": 0,
                        "name": "0",
                        "path": "",
                        "type": "sentence_transformers.models.Transformer"
                    },
                    {
                        "idx": 1,
                        "name": "1",
                        "path": "1_Pooling",
                        "type": "sentence_transformers.models.Pooling"
                    }
                ],
                fw,
                ensure_ascii=False,
                indent=1
            )
        ## Tokenizer / sentence-transformers sidecar files, copied verbatim.
        for side_file in (
            "added_tokens.json",
            "config_sentence_transformers.json",
            "merges.txt",
            "sentence_bert_config.json",
            "special_tokens_map.json",
            "tokenizer_config.json",
            "tokenizer.json",
            "vocab.json",
        ):
            shutil.copy(join(model_dir, side_file), join(checkpoint_dir, side_file))
        # Splice the frozen stella dense-projection weights into the checkpoint.
        # NOTE(review): hard-coded absolute path — should come from the config.
        ori_di = load_file(join(checkpoint_dir, "model.safetensors"))
        stella_dense_di = load_file(
            "/home/wcm/jasper/public_models/stella_en_1_5B_v5/2_Dense_8192/model.safetensors"
        )
        ori_di["stella_dense.weight"] = stella_dense_di["linear.weight"].clone().detach().bfloat16()
        ori_di["stella_dense.bias"] = stella_dense_di["linear.bias"].clone().detach().bfloat16()
        save_file(ori_di, join(checkpoint_dir, "model.safetensors"), metadata={"format": "pt"})
160
+
161
+
162
def get_score_diff(vectors):
    """Return all pairwise differences between off-diagonal similarity scores.

    Given an (n, d) matrix, the upper-triangular entries of the gram matrix
    ``vectors @ vectors.T`` form a score vector ``s``; the result contains
    ``s[j] - s[i]`` for every index pair i < j.
    """
    gram = vectors @ vectors.T
    upper = torch.triu(torch.ones_like(gram), diagonal=1).bool()
    flat_scores = gram[upper]
    # Row minus column broadcast: entry [i, j] = flat_scores[j] - flat_scores[i].
    deltas = flat_scores.unsqueeze(0) - flat_scores.unsqueeze(1)
    delta_upper = torch.triu(torch.ones_like(deltas), diagonal=1).bool()
    return deltas[delta_upper]
168
+
169
+
170
+
171
+
172
+ if __name__ == "__main__":
173
+ # read the configration
174
+ with open(sys.argv[1].strip(), "r", encoding="utf8") as fr:
175
+ conf = yaml.safe_load(fr)
176
+ model_name = conf["model_name"]
177
+ model_dir = conf["model_path_or_name"]
178
+ max_length = conf["max_length"]
179
+ resume_model_dir = conf["resume_model_dir"]
180
+ output_dir = conf["output_dir"]
181
+ save_steps = conf["save_steps"]
182
+ batch_size = conf["batch_size"]
183
+ project_name = conf["project_name"]
184
+ log_with = conf["log_with"]
185
+ log_init_kwargs = conf["log_init_kwargs"]
186
+ file_path_list_or_dir = conf["file_path_list"]
187
+ print_debug_info_prob = conf["print_debug_info_prob"]
188
+ gradient_accumulation_steps = conf["gradient_accumulation_steps"]
189
+ continue_train = conf["continue_train"]
190
+ num_train_epochs = conf["num_train_epochs"]
191
+ lr_scheduler_type = conf["lr_scheduler_type"]
192
+ mse_loss_scale = conf["mse_loss_scale"]
193
+ cosine_loss_scale = conf["cosine_loss_scale"]
194
+ padding = conf["padding"]
195
+ teacher_vector_cols = conf["teacher_vector_cols"]
196
+ rank_margin = conf["rank_margin"]
197
+ rank_loss_scale = conf["rank_loss_scale"]
198
+ used_loss = set(conf["used_loss"].split(";"))
199
+ scheduler_kwargs = conf.get("scheduler_kwargs", {})
200
+ os.environ["ADAPTER_TYPE"] = conf["adapter_type"]
201
+ os.environ["MERGE_VECS"] = "0"
202
+
203
+ seed = conf["seed"]
204
+ CL_LABELS = torch.LongTensor(range(max_length))
205
+ # initialize accelerator
206
+ accelerator = Accelerator(
207
+ project_config=ProjectConfiguration(
208
+ project_dir=output_dir,
209
+ logging_dir=join(output_dir, "logs"),
210
+ ),
211
+ gradient_accumulation_steps=gradient_accumulation_steps,
212
+ log_with=log_with,
213
+ kwargs_handlers=[
214
+ accelerate.DistributedDataParallelKwargs(find_unused_parameters=not conf["gradient_checkpointing"])]
215
+ )
216
+
217
+ # output_dir and sth
218
+ with accelerator.main_process_first():
219
+ if accelerator.is_main_process:
220
+ os.makedirs(output_dir, exist_ok=True)
221
+ os.makedirs(join(output_dir, "logs/wandb_logs"), exist_ok=True)
222
+ logger.add(
223
+ join(output_dir, "train_logs.txt"),
224
+ level="DEBUG",
225
+ compression="zip",
226
+ rotation="500 MB",
227
+ # format="{message}"
228
+ )
229
+ shutil.copy(sys.argv[1].strip(), join(output_dir, "train_config.yml"))
230
+
231
+ accelerator.wait_for_everyone()
232
+ if accelerator.is_main_process:
233
+ logger.info(f"accelerator.state:{accelerator.state}")
234
+
235
+ # seed
236
+ set_seed(seed=seed)
237
+ # 加载模型、tokenizer
238
+ model = MODEL_NAME_INFO[model_name][0].from_pretrained(
239
+ model_dir,
240
+ use_cache=False,
241
+ )
242
+ tokenizer = MODEL_NAME_INFO[model_name][1].from_pretrained(model_dir, padding_side="right")
243
+ model_conf = model.config
244
+ model.padding_side = "right"
245
+
246
+ for k, v in model.named_parameters():
247
+ if k.startswith("model."):
248
+ v.requires_grad = False
249
+
250
+ if accelerator.is_main_process:
251
+ logger.debug("参数冻结情况")
252
+ for k, v in model.named_parameters():
253
+ logger.debug(f"{k}:{v.shape, v.requires_grad}")
254
+ if conf["gradient_checkpointing"]:
255
+ model.gradient_checkpointing_enable()
256
+
257
+ # 加载数据和teacher vector
258
+ train_dataset = JasperDataset_LMDB_RANDOM_ACCESS(file_path_list_or_dir=file_path_list_or_dir)
259
+ train_dataloader = DataLoader(
260
+ dataset=train_dataset,
261
+ shuffle=False,
262
+ collate_fn=collate_fn_jasper_text,
263
+ batch_size=batch_size,
264
+ num_workers=6,
265
+ drop_last=True,
266
+ # pin_memory=True,
267
+ # pin_memory_device="cuda",
268
+ prefetch_factor=4,
269
+ )
270
+ # 加载上次的训练状态
271
+ accelerator.wait_for_everyone()
272
+ # init log
273
+ if "wandb" in log_init_kwargs:
274
+ log_init_kwargs["wandb"]["dir"] = join(output_dir, "logs/wandb_logs")
275
+ log_init_kwargs["wandb"]["config"] = {k: json.dumps(v, ensure_ascii=False) for k, v in conf.items()}
276
+ accelerator.init_trackers(
277
+ project_name=project_name,
278
+ init_kwargs=log_init_kwargs
279
+ )
280
+ # Optimizer
281
+ optimizer = torch.optim.AdamW(model.parameters(), lr=conf["learning_rate"])
282
+ # if os.path.exists(join(model_path_or_name, "optimizer.bin")):
283
+ # optimizer.load_state_dict(torch.load(join(model_path_or_name, "optimizer.bin"), weights_only=False, map_location="cpu"))
284
+ # scheduler
285
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
286
+ max_train_steps = num_update_steps_per_epoch * num_train_epochs
287
+ if isinstance(conf["num_warmup_steps"], float):
288
+ num_warmup_steps = int(max_train_steps * conf["num_warmup_steps"])
289
+ else:
290
+ num_warmup_steps = conf["num_warmup_steps"]
291
+ lr_scheduler = get_scheduler(
292
+ name=lr_scheduler_type,
293
+ optimizer=optimizer,
294
+ num_warmup_steps=num_warmup_steps,
295
+ num_training_steps=max_train_steps,
296
+ scheduler_specific_kwargs=scheduler_kwargs,
297
+ )
298
+ logger.debug(f"before prepare, len(train_dataloader): {len(train_dataloader)}")
299
+ # prepare everything
300
+ model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
301
+ model, optimizer, train_dataloader, lr_scheduler
302
+ )
303
+ logger.debug(f"after prepare, len(train_dataloader): {len(train_dataloader)}")
304
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
305
+ ## PS: 多机多卡的问题,之前的计算没有考虑num_process,多机读卡下len(train_dataloader)会变小, 接下来的相当于是每张卡的数量
306
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
307
+ max_train_steps = num_train_epochs * num_update_steps_per_epoch
308
+ logger.debug(f"max_train_steps for each card:{max_train_steps}")
309
+ starting_epoch, completed_steps = 0, 0
310
+
311
+ progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
312
+
313
+ if continue_train:
314
+ logger.info(f"Continue train from {model_dir}")
315
+ accelerator.load_state(resume_model_dir)
316
+ resume_step = int(os.path.basename(resume_model_dir).replace("step_", ""))
317
+ completed_steps = resume_step
318
+ starting_epoch = resume_step // num_update_steps_per_epoch
319
+ resume_step -= starting_epoch * num_update_steps_per_epoch
320
+ progress_bar.update(completed_steps)
321
+ # 开始训练
322
+ CL_LABELS = CL_LABELS.to(accelerator.device)
323
+ for epoch in range(starting_epoch, num_train_epochs):
324
+ model.train()
325
+ # skip new `skip_first_batches` to skip the batches when resuming from ckpt
326
+ if continue_train and epoch == starting_epoch:
327
+ # We need to skip steps until we reach the resumed step
328
+ active_dataloader = accelerator.skip_first_batches(
329
+ train_dataloader,
330
+ resume_step * gradient_accumulation_steps
331
+ )
332
+ else:
333
+ # After the first iteration though, we need to go back to the original dataloader
334
+ active_dataloader = train_dataloader
335
+ logger.debug(f"len(active_dataloader): {len(active_dataloader)}")
336
+
337
+ for batch in active_dataloader:
338
+ teacher_vectors = batch.pop("teacher_vectors")
339
+ with accelerator.accumulate(model):
340
+ attention_mask = batch["attention_mask"]
341
+ model_output = model(**batch)
342
+ student_vectors = model_output["token_embeddings"].float()[:, 0]
343
+ student_vectors = F.normalize(student_vectors, p=2, dim=-1)
344
+ # 计算cosine loss
345
+ cosine_loss = (1 - (student_vectors * teacher_vectors).sum(axis=1).mean()) * cosine_loss_scale
346
+ # 计算老师和学生的相似度值损失
347
+ sim_value_loss = F.mse_loss(
348
+ input=torch.matmul(student_vectors, student_vectors.T),
349
+ target=torch.matmul(teacher_vectors, teacher_vectors.T),
350
+ ) * mse_loss_scale
351
+ # 计算 排序损失函数
352
+ ## 首先获取 rank_labellabel
353
+ rank_label = torch.where(get_score_diff(teacher_vectors) < 0, 1, -1)
354
+ rank_loss = F.relu(get_score_diff(student_vectors) * rank_label + rank_margin).mean() * rank_loss_scale
355
+
356
+ loss = cosine_loss
357
+ if "sim_value_loss" in used_loss:
358
+ loss = loss + sim_value_loss
359
+ if "rank_loss" in used_loss:
360
+ loss = loss + rank_loss
361
+ ########################## debug 信息 #######################################################
362
+ if accelerator.is_main_process and (completed_steps == 10 or random.random() < print_debug_info_prob):
363
+ input_ids = batch["input_ids"].cpu().numpy()
364
+ attention_mask = batch["attention_mask"].cpu().numpy()
365
+ debug_index = random.randint(0, len(input_ids) - 1)
366
+ for debug_k, debug_v in batch.items():
367
+ logger.debug(f"{debug_k}.shape: {debug_v.shape}")
368
+ logger.debug(f"debug_index: {debug_index}")
369
+ logger.debug(f"input_ids: {input_ids[debug_index].tolist()}")
370
+ logger.debug(f"input_tokens: {tokenizer.decode(input_ids[debug_index])}")
371
+ logger.debug(f"attention_mask: {attention_mask[debug_index].tolist()}")
372
+ logger.debug(f"teacher_vectors.shape: {teacher_vectors.shape}")
373
+ logger.debug(f"student_vectors.shape: {student_vectors.shape}")
374
+ ###############################################################################################
375
+
376
+ accelerator.backward(loss)
377
+ optimizer.step()
378
+ lr_scheduler.step()
379
+ optimizer.zero_grad()
380
+ if accelerator.sync_gradients:
381
+ progress_bar.update(1)
382
+ completed_steps += 1
383
+ if completed_steps == 15:
384
+ save_model()
385
+ if completed_steps % save_steps == 0 and completed_steps > 0:
386
+ save_model()
387
+ # log
388
+ if accelerator.is_main_process:
389
+ curr_lr = float(lr_scheduler.get_last_lr()[-1])
390
+ logger.info(
391
+ f"epoch-{epoch},completed_steps-{completed_steps},lr:{curr_lr},cosine_loss:{cosine_loss.item()},sim_value_loss:{sim_value_loss.item()},rank_loss:{rank_loss.item()}"
392
+ )
393
+ accelerator.log(
394
+ {
395
+ "cosine_loss": cosine_loss.item(),
396
+ "sim_value_loss": sim_value_loss.item(),
397
+ "rank_loss": rank_loss.item(),
398
+ "lr": curr_lr
399
+ },
400
+ step=completed_steps
401
+ )
402
+
403
+
404
+ save_model()
405
+ accelerator.end_training()
scripts/original_stella_jasper_training_codes/run_train_distill_stage2.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf8
2
+ import json
3
+ import os
4
+ import sys
5
+ import yaml
6
+ import torch
7
+ import shutil
8
+ import math
9
+ import random
10
+ import lmdb
11
+ import pickle
12
+ import accelerate
13
+ from loguru import logger
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from transformers import get_scheduler
16
+ from accelerate import Accelerator
17
+ from accelerate.utils import set_seed, ProjectConfiguration
18
+ from tqdm import tqdm
19
+ from typing import List, Union
20
+ from os.path import join
21
+ import torch.nn.functional as F
22
+ from jasper_model.modeling_jasper_vl import JasperVL
23
+ from jasper_model.tokenization_qwen import Qwen2TokenizerFast
24
+
25
+
26
class JasperDataset_LMDB_RANDOM_ACCESS(Dataset):
    """Random-access dataset over one or more LMDB files holding distillation records.

    Each record is a pickled dict with at least a ``text`` field and an
    ``extra`` field (a JSON string containing ``prompt_student``).
    Relies on module-level globals ``seed`` and ``accelerator`` defined in
    ``__main__``.
    """

    def __init__(self, file_path_list_or_dir: Union[List[str], str]):
        # A directory argument is scanned for "<name>-lock" files; each lock
        # file marks a single-file LMDB database "<name>" (opened with
        # subdir=False below), so name[:-5] strips the "-lock" suffix.
        if isinstance(file_path_list_or_dir, str):
            file_path_list = []
            for name in os.listdir(file_path_list_or_dir):
                if not name.endswith('-lock'):
                    continue
                file_path_list.append(join(file_path_list_or_dir, name[:-5]))
        else:
            file_path_list = file_path_list_or_dir
        file_path_list.sort()
        # Seed before shuffling so every distributed rank produces the same
        # file order (uses the module-level `seed` from the YAML config).
        random.seed(seed)
        random.shuffle(file_path_list)
        # file_path_list = file_path_list[:20]
        self.lmdb_env_list = [
            lmdb.open(file_path, readonly=True, readahead=False, subdir=False, lock=False)
            for file_path in file_path_list
        ]
        self.lmdb_txn_list = [lmdb_env.begin(write=False, buffers=True) for lmdb_env in self.lmdb_env_list]
        # Number of entries per environment; used to route a global index to
        # the right database in __getitem__.
        self.num_data_of_env = [lmdb_env.stat()["entries"] for lmdb_env in self.lmdb_env_list]
        self.num_all_data = sum(self.num_data_of_env)
        # accumulation_numbers[i] = total entries in envs[0..i] (prefix sums).
        self.accumulation_numbers = [sum(self.num_data_of_env[:idx + 1]) for idx in range(len(self.num_data_of_env))]
        if accelerator.is_main_process:
            logger.info(f"file_path_list:{file_path_list}")
            logger.info(f"number of data:{self.num_all_data}")

    def __len__(self):
        # Total records across all LMDB environments.
        return self.num_all_data

    def __getitem__(self, item):

        # print("accelerator.local_process_index,item", accelerator.local_process_index, item)
        # Map the global index to (environment, local index) via prefix sums.
        for env_idx, accum_num in enumerate(self.accumulation_numbers):
            if item < accum_num:
                break
        txn = self.lmdb_txn_list[env_idx]
        item -= self.accumulation_numbers[env_idx - 1] if env_idx > 0 else 0
        # NOTE(review): pickle.loads on stored data — safe only because the
        # LMDB files are produced by this project's own pipeline.
        data_item = pickle.loads(bytes(txn.get(f"{item}".encode())))
        # `text` is unpacked but unused; only `extra` matters here.
        text, extra = data_item["text"], json.loads(data_item["extra"])
        # Prepend the student prompt to the raw text before tokenization.
        data_item["text"] = extra["prompt_student"] + data_item["text"]
        return data_item
69
+
70
+
71
def collate_fn_jasper_text(batch):
    """Collate dataset items into a tokenized batch with teacher vectors attached.

    Uses the module-level ``tokenizer``, ``teacher_vector_cols``, ``padding``
    and ``max_length`` defined in ``__main__``.

    :param batch: list of dataset items (dicts with "text" and teacher columns)
    :return: tokenizer output dict plus a "teacher_vectors" tensor
    """
    texts = []
    vector_rows = []
    for record in batch:
        texts.append(record["text"])
        concatenated = []
        for col in teacher_vector_cols:
            concatenated.extend(record[col])
        vector_rows.append(concatenated)
    teacher_vectors = torch.tensor(vector_rows)
    # When vectors from several teacher columns are concatenated, L2-normalize
    # so the concatenation is a unit vector again.
    if len(teacher_vector_cols) > 1:
        teacher_vectors = F.normalize(teacher_vectors, p=2, dim=-1)
    encoded = tokenizer(texts, padding=padding, truncation=True, max_length=max_length, return_tensors="pt")
    encoded["teacher_vectors"] = teacher_vectors
    return encoded
89
+
90
+
91
def save_model():
    """Export the current model to <output_dir>/step_<completed_steps>.

    Writes a sentence-transformers-compatible checkpoint layout: weights,
    custom modeling source files, patched config.json, pooling config,
    modules.json, and tokenizer assets copied from the base model directory.
    Reads module-level globals: accelerator, model, output_dir,
    completed_steps, model_dir.
    """
    checkpoint_dir = join(output_dir, f"step_{completed_steps}")
    # accelerator.save_state(checkpoint_dir, safe_serialization=True)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info(f"保存模型{checkpoint_dir}")
        # Save the unwrapped model so the checkpoint can be loaded directly.
        accelerator.unwrap_model(model).save_pretrained(checkpoint_dir, max_shard_size="32GB", safe_serialization=True)

        # Copy the custom modeling/tokenizer source files into the checkpoint.
        shutil.copy("./jasper_model/configuration_jasper_vl.py", join(checkpoint_dir, "configuration_jasper_vl.py"))
        shutil.copy("./jasper_model/modeling_jasper_vl.py", join(checkpoint_dir, "modeling_jasper_vl.py"))
        shutil.copy("./jasper_model/tokenization_qwen.py", join(checkpoint_dir, "tokenization_qwen.py"))
        # Patch config.json: drop the local path and register the auto_map so
        # AutoModel/AutoConfig resolve to the copied custom classes.
        with open(join(checkpoint_dir, "config.json"), "r", encoding="utf8") as fr:
            config = json.load(fr)
        if "_name_or_path" in config:
            config.pop("_name_or_path")
        config["auto_map"] = {
            "AutoModel": "modeling_jasper_vl.JasperVL",
            "AutoConfig": "configuration_jasper_vl.JasperVLConfig",
        }
        with open(join(checkpoint_dir, "config.json"), "w", encoding="utf8") as fw:
            json.dump(config, fw, ensure_ascii=False, indent=1)

        # Pooling config: CLS-token pooling over a 12288-dim embedding,
        # prompt tokens excluded from pooling.
        os.makedirs(join(checkpoint_dir, "1_Pooling"), exist_ok=True)
        config = {
            "word_embedding_dimension": 12288,
            "pooling_mode_cls_token": True,
            "pooling_mode_mean_tokens": False,
            "pooling_mode_max_tokens": False,
            "pooling_mode_mean_sqrt_len_tokens": False,
            "pooling_mode_weightedmean_tokens": False,
            "pooling_mode_lasttoken": False,
            "include_prompt": False
        }
        with open(join(checkpoint_dir, "1_Pooling/config.json"), "w", encoding="utf8") as fw:
            json.dump(config, fw, ensure_ascii=False, indent=1)
        ## modules.json: the sentence-transformers module pipeline
        ## (Transformer -> Pooling).
        with open(os.path.join(checkpoint_dir, "modules.json"), "w", encoding="utf8") as fw:
            json.dump(
                [
                    {
                        "idx": 0,
                        "name": "0",
                        "path": "",
                        "type": "sentence_transformers.models.Transformer"
                    },
                    {
                        "idx": 1,
                        "name": "1",
                        "path": "1_Pooling",
                        "type": "sentence_transformers.models.Pooling"
                    }
                ],
                fw,
                ensure_ascii=False,
                indent=1
            )
        ## Copy tokenizer and sentence-transformers assets from the base model.
        shutil.copy(join(model_dir, "added_tokens.json"), join(checkpoint_dir, "added_tokens.json"))
        shutil.copy(join(model_dir, "config_sentence_transformers.json"),
                    join(checkpoint_dir, "config_sentence_transformers.json"))
        shutil.copy(join(model_dir, "merges.txt"), join(checkpoint_dir, "merges.txt"))
        shutil.copy(join(model_dir, "sentence_bert_config.json"), join(checkpoint_dir, "sentence_bert_config.json"))
        shutil.copy(join(model_dir, "special_tokens_map.json"), join(checkpoint_dir, "special_tokens_map.json"))
        shutil.copy(join(model_dir, "tokenizer_config.json"), join(checkpoint_dir, "tokenizer_config.json"))
        shutil.copy(join(model_dir, "tokenizer.json"), join(checkpoint_dir, "tokenizer.json"))
        shutil.copy(join(model_dir, "vocab.json"), join(checkpoint_dir, "vocab.json"))
160
+
161
+
162
def get_score_diff(vectors):
    """Differences between all unordered pairs of pairwise similarity scores.

    Computes the Gram matrix of ``vectors``, keeps the strict upper triangle
    (each unordered pair's score once, row-major order), then returns
    ``scores[j] - scores[i]`` for every pair i < j of those scores.
    """
    sim = vectors @ vectors.T
    n = sim.shape[0]
    row_idx, col_idx = torch.triu_indices(n, n, offset=1)
    pair_scores = sim[row_idx, col_idx]
    m = pair_scores.shape[0]
    first, second = torch.triu_indices(m, m, offset=1)
    return pair_scores[second] - pair_scores[first]
168
+
169
+
170
if __name__ == "__main__":
    # Read the YAML configuration whose path is given as argv[1].
    with open(sys.argv[1].strip(), "r", encoding="utf8") as fr:
        conf = yaml.safe_load(fr)
    model_dir = conf["model_path_or_name"]
    max_length = conf["max_length"]
    resume_model_dir = conf["resume_model_dir"]
    output_dir = conf["output_dir"]
    save_steps = conf["save_steps"]
    batch_size = conf["batch_size"]
    project_name = conf["project_name"]
    log_with = conf["log_with"]
    log_init_kwargs = conf["log_init_kwargs"]
    file_path_list_or_dir = conf["file_path_list"]
    print_debug_info_prob = conf["print_debug_info_prob"]
    gradient_accumulation_steps = conf["gradient_accumulation_steps"]
    continue_train = conf["continue_train"]
    num_train_epochs = conf["num_train_epochs"]
    lr_scheduler_type = conf["lr_scheduler_type"]
    mse_loss_scale = conf["mse_loss_scale"]
    cosine_loss_scale = conf["cosine_loss_scale"]
    padding = conf["padding"]
    teacher_vector_cols = conf["teacher_vector_cols"]
    rank_margin = conf["rank_margin"]
    rank_loss_scale = conf["rank_loss_scale"]
    scheduler_kwargs = conf.get("scheduler_kwargs", {})

    seed = conf["seed"]
    # Initialize the accelerator. find_unused_parameters is only needed for
    # DDP when gradient checkpointing is off (frozen params are "unused").
    accelerator = Accelerator(
        project_config=ProjectConfiguration(
            project_dir=output_dir,
            logging_dir=join(output_dir, "logs"),
        ),
        gradient_accumulation_steps=gradient_accumulation_steps,
        log_with=log_with,
        kwargs_handlers=[
            accelerate.DistributedDataParallelKwargs(find_unused_parameters=not conf["gradient_checkpointing"])]
    )

    # Create output/log directories and archive the config (main process only).
    with accelerator.main_process_first():
        if accelerator.is_main_process:
            os.makedirs(output_dir, exist_ok=True)
            os.makedirs(join(output_dir, "logs/wandb_logs"), exist_ok=True)
            logger.add(
                join(output_dir, "train_logs.txt"),
                level="DEBUG",
                compression="zip",
                rotation="500 MB",
                # format="{message}"
            )
            shutil.copy(sys.argv[1].strip(), join(output_dir, "train_config.yml"))

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info(f"accelerator.state:{accelerator.state}")

    # Seed everything for reproducibility.
    set_seed(seed=seed)
    # Load model and tokenizer.
    model = JasperVL.from_pretrained(model_dir)
    tokenizer = Qwen2TokenizerFast.from_pretrained(model_dir, padding_side="right")
    # Freeze the backbone; keep only the final norm and the last transformer
    # layer (layers.27) trainable.
    for k, v in model.named_parameters():
        if k.startswith("model."):
            v.requires_grad = False
        if "model.norm.weight" in k or "layers.27" in k:
            v.requires_grad = True

    if accelerator.is_main_process:
        logger.debug("被训练的参数如下:")
        for k, v in model.named_parameters():
            if v.requires_grad:
                logger.debug(f"{k}:{v.shape, v.requires_grad}")
    if conf["gradient_checkpointing"]:
        model.gradient_checkpointing_enable()

    # Load training data together with the precomputed teacher vectors.
    train_dataset = JasperDataset_LMDB_RANDOM_ACCESS(file_path_list_or_dir=file_path_list_or_dir)
    train_dataloader = DataLoader(
        dataset=train_dataset,
        shuffle=False,
        collate_fn=collate_fn_jasper_text,
        batch_size=batch_size,
        num_workers=6,
        drop_last=True,
        # pin_memory=True,
        # pin_memory_device="cuda",
        prefetch_factor=4,
    )
    # (Resuming of a previous training state happens further below.)
    accelerator.wait_for_everyone()
    # Initialize experiment trackers (wandb dir/config patched if present).
    if "wandb" in log_init_kwargs:
        log_init_kwargs["wandb"]["dir"] = join(output_dir, "logs/wandb_logs")
        log_init_kwargs["wandb"]["config"] = {k: json.dumps(v, ensure_ascii=False) for k, v in conf.items()}
    accelerator.init_trackers(
        project_name=project_name,
        init_kwargs=log_init_kwargs
    )
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=conf["learning_rate"])
    # if os.path.exists(join(model_path_or_name, "optimizer.bin")):
    # optimizer.load_state_dict(torch.load(join(model_path_or_name, "optimizer.bin"), weights_only=False, map_location="cpu"))
    # LR scheduler; a float num_warmup_steps is interpreted as a fraction of
    # the total number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    max_train_steps = num_update_steps_per_epoch * num_train_epochs
    if isinstance(conf["num_warmup_steps"], float):
        num_warmup_steps = int(max_train_steps * conf["num_warmup_steps"])
    else:
        num_warmup_steps = conf["num_warmup_steps"]
    lr_scheduler = get_scheduler(
        name=lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_train_steps,
        scheduler_specific_kwargs=scheduler_kwargs,
    )
    logger.debug(f"before prepare, len(train_dataloader): {len(train_dataloader)}")
    # Prepare everything for distributed training.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    ## PS: the earlier computation ignored num_processes; with multiple
    ## machines/GPUs len(train_dataloader) shrinks, so the values below are
    ## per-card quantities.
    logger.debug(f"after prepare, len(train_dataloader): {len(train_dataloader)}")
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
    logger.debug(f"max_train_steps for each card:{max_train_steps}")
    starting_epoch, completed_steps = 0, 0

    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)

    # Resume from a checkpoint directory named "step_<N>" if requested.
    if continue_train:
        logger.info(f"Continue train from {model_dir}")
        accelerator.load_state(resume_model_dir)
        resume_step = int(os.path.basename(resume_model_dir).replace("step_", ""))
        completed_steps = resume_step
        starting_epoch = resume_step // num_update_steps_per_epoch
        resume_step -= starting_epoch * num_update_steps_per_epoch
        progress_bar.update(completed_steps)
    # Training loop.
    for epoch in range(starting_epoch, num_train_epochs):
        model.train()
        # skip new `skip_first_batches` to skip the batches when resuming from ckpt
        if continue_train and epoch == starting_epoch:
            # We need to skip steps until we reach the resumed step
            active_dataloader = accelerator.skip_first_batches(
                train_dataloader,
                resume_step * gradient_accumulation_steps
            )
        else:
            # After the first iteration though, we need to go back to the original dataloader
            active_dataloader = train_dataloader
        logger.debug(f"len(active_dataloader): {len(active_dataloader)}")

        for batch in active_dataloader:
            teacher_vectors = batch.pop("teacher_vectors")
            with accelerator.accumulate(model):
                # NOTE(review): this binding is unused before being re-bound
                # inside the debug branch below.
                attention_mask = batch["attention_mask"]
                model_output = model(**batch)
                # CLS-token embedding, L2-normalized, as the student vector.
                student_vectors = model_output["token_embeddings"].float()[:, 0]
                student_vectors = F.normalize(student_vectors, p=2, dim=-1)
                # Cosine loss: 1 - mean cosine similarity to the teacher.
                cosine_loss = (1 - (student_vectors * teacher_vectors).sum(axis=1).mean()) * cosine_loss_scale
                # MSE between student and teacher in-batch similarity matrices.
                sim_value_loss = F.mse_loss(
                    input=torch.matmul(student_vectors, student_vectors.T),
                    target=torch.matmul(teacher_vectors, teacher_vectors.T),
                ) * mse_loss_scale
                # Ranking loss: penalize pairs whose score ordering disagrees
                # with the teacher's ordering by more than rank_margin.
                ## First derive the rank labels from the teacher score diffs.
                rank_label = torch.where(get_score_diff(teacher_vectors) < 0, 1, -1)
                rank_loss = F.relu(get_score_diff(student_vectors) * rank_label + rank_margin).mean() * rank_loss_scale

                loss = cosine_loss + sim_value_loss + rank_loss
                ########################## debug info #######################################################
                if accelerator.is_main_process and (completed_steps == 10 or random.random() < print_debug_info_prob):
                    input_ids = batch["input_ids"].cpu().numpy()
                    attention_mask = batch["attention_mask"].cpu().numpy()
                    debug_index = random.randint(0, len(input_ids) - 1)
                    for debug_k, debug_v in batch.items():
                        logger.debug(f"{debug_k}.shape: {debug_v.shape}")
                    logger.debug(f"debug_index: {debug_index}")
                    logger.debug(f"input_ids: {input_ids[debug_index].tolist()}")
                    logger.debug(f"input_tokens: {tokenizer.decode(input_ids[debug_index])}")
                    logger.debug(f"attention_mask: {attention_mask[debug_index].tolist()}")
                    logger.debug(f"teacher_vectors.shape: {teacher_vectors.shape}")
                    logger.debug(f"student_vectors.shape: {student_vectors.shape}")
                ###############################################################################################

                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    completed_steps += 1
                    # Early sanity-check save at step 15, then periodic saves.
                    if completed_steps == 15:
                        save_model()
                    if completed_steps % save_steps == 0 and completed_steps > 0:
                        save_model()
                    # Log losses and learning rate (main process only).
                    if accelerator.is_main_process:
                        curr_lr = float(lr_scheduler.get_last_lr()[-1])
                        logger.info(
                            f"epoch-{epoch},completed_steps-{completed_steps},lr:{curr_lr},cosine_loss:{cosine_loss.item()},sim_value_loss:{sim_value_loss.item()},rank_loss:{rank_loss.item()}"
                        )
                        accelerator.log(
                            {
                                "cosine_loss": cosine_loss.item(),
                                "sim_value_loss": sim_value_loss.item(),
                                "rank_loss": rank_loss.item(),
                                "lr": curr_lr
                            },
                            step=completed_steps
                        )

    # Save the model one final time after training finishes.
    save_model()
    accelerator.end_training()
scripts/original_stella_jasper_training_codes/run_train_mrl_stage3.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf8
2
+ import json
3
+ import os
4
+ import sys
5
+ import yaml
6
+ import torch
7
+ import shutil
8
+ import math
9
+ import random
10
+ import lmdb
11
+ import pickle
12
+ import accelerate
13
+ from loguru import logger
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from transformers import get_scheduler
16
+ from accelerate import Accelerator
17
+ from accelerate.utils import set_seed, ProjectConfiguration
18
+ from tqdm import tqdm
19
+ from typing import List, Union
20
+ from os.path import join
21
+ import torch.nn.functional as F
22
+ from jasper_model.modeling_jasper_vl import JasperVL
23
+ from jasper_model.tokenization_qwen import Qwen2TokenizerFast
24
+ from jasper_model.configuration_jasper_vl import JasperVLConfig
25
+ from safetensors.torch import load_file
26
+
27
+
28
class JasperDataset_LMDB_RANDOM_ACCESS(Dataset):
    """Random-access dataset over one or more LMDB files for MRL (stage 3) training.

    Each record is a pickled dict with a ``text`` field and an ``extra`` field
    (a JSON string containing ``prompt_student``). Relies on module-level
    globals ``seed`` and ``accelerator`` defined in ``__main__``.
    """

    def __init__(self, file_path_list_or_dir: Union[List[str], str]):
        # A directory argument is scanned for "<name>-lock" files; each lock
        # file marks a single-file LMDB database "<name>" (subdir=False below).
        if isinstance(file_path_list_or_dir, str):
            file_path_list = []
            for name in os.listdir(file_path_list_or_dir):
                if not name.endswith('-lock'):
                    continue
                file_path_list.append(join(file_path_list_or_dir, name[:-5]))
        else:
            file_path_list = file_path_list_or_dir
        file_path_list.sort()
        # Seed before shuffling so every distributed rank sees the same order.
        random.seed(seed)
        random.shuffle(file_path_list)
        # TODO: earlier stages trained sequentially for speed and consumed only
        # a prefix of the files; for MRL training the list is reversed so the
        # previously unseen data is read first.
        file_path_list = file_path_list[::-1]
        self.lmdb_env_list = [
            lmdb.open(file_path, readonly=True, readahead=False, subdir=False, lock=False)
            for file_path in file_path_list
        ]
        self.lmdb_txn_list = [lmdb_env.begin(write=False, buffers=True) for lmdb_env in self.lmdb_env_list]
        # Number of entries per environment, with prefix sums used to route a
        # global index to (env, local index) in __getitem__.
        self.num_data_of_env = [lmdb_env.stat()["entries"] for lmdb_env in self.lmdb_env_list]
        self.num_all_data = sum(self.num_data_of_env)
        self.accumulation_numbers = [sum(self.num_data_of_env[:idx + 1]) for idx in range(len(self.num_data_of_env))]
        if accelerator.is_main_process:
            logger.info(f"file_path_list:{file_path_list}")
            logger.info(f"number of data:{self.num_all_data}")

    def __len__(self):
        # Total records across all LMDB environments.
        return self.num_all_data

    def __getitem__(self, item):

        # print("accelerator.local_process_index,item", accelerator.local_process_index, item)
        # Map the global index to the owning environment via prefix sums.
        for env_idx, accum_num in enumerate(self.accumulation_numbers):
            if item < accum_num:
                break
        txn = self.lmdb_txn_list[env_idx]
        item -= self.accumulation_numbers[env_idx - 1] if env_idx > 0 else 0
        # NOTE(review): pickle.loads on stored data — safe only because the
        # LMDB files are produced by this project's own pipeline.
        data_item = pickle.loads(bytes(txn.get(f"{item}".encode())))
        # `text` is unpacked but unused; only `extra` matters here.
        text, extra = data_item["text"], json.loads(data_item["extra"])
        # Prepend the student prompt to the raw text before tokenization.
        data_item["text"] = extra["prompt_student"] + data_item["text"]
        return data_item
72
+
73
+
74
def collate_fn_jasper_text(batch):
    """Build a tokenized batch and attach the concatenated teacher vectors.

    Depends on module-level ``tokenizer``, ``teacher_vector_cols``,
    ``padding`` and ``max_length`` set up in ``__main__``.

    :param batch: list of dataset items
    :return: tokenizer output dict with an added "teacher_vectors" tensor
    """
    all_texts = [example["text"] for example in batch]
    rows = []
    for example in batch:
        flat = []
        for col in teacher_vector_cols:
            flat.extend(example[col])
        rows.append(flat)
    teacher_vectors = torch.tensor(rows)
    # Concatenations of several teacher vectors are re-normalized to unit
    # L2 norm; a single-column vector is assumed already normalized.
    if len(teacher_vector_cols) > 1:
        teacher_vectors = F.normalize(teacher_vectors, p=2, dim=-1)
    batch_inputs = tokenizer(all_texts, padding=padding, truncation=True, max_length=max_length, return_tensors="pt")
    batch_inputs["teacher_vectors"] = teacher_vectors
    return batch_inputs
92
+
93
+
94
def save_model():
    """Export the current model to <output_dir>/step_<completed_steps>.

    Writes a sentence-transformers-compatible checkpoint: weights, custom
    modeling source files, patched config.json, CLS pooling config (4096-dim
    at this stage), modules.json, and tokenizer assets copied from the base
    model directory. Reads module-level globals: accelerator, model,
    output_dir, completed_steps, model_dir.
    """
    checkpoint_dir = join(output_dir, f"step_{completed_steps}")
    # accelerator.save_state(checkpoint_dir, safe_serialization=True)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info(f"保存模型{checkpoint_dir}")
        accelerator.unwrap_model(model).save_pretrained(checkpoint_dir, max_shard_size="32GB", safe_serialization=True)
        # Copy the custom modeling/tokenizer source files into the checkpoint.
        shutil.copy("./jasper_model/configuration_jasper_vl.py", join(checkpoint_dir, "configuration_jasper_vl.py"))
        shutil.copy("./jasper_model/modeling_jasper_vl.py", join(checkpoint_dir, "modeling_jasper_vl.py"))
        shutil.copy("./jasper_model/tokenization_qwen.py", join(checkpoint_dir, "tokenization_qwen.py"))
        # Patch config.json: drop the local path and register the auto_map so
        # AutoModel/AutoConfig resolve to the copied custom classes.
        with open(join(checkpoint_dir, "config.json"), "r", encoding="utf8") as fr:
            config = json.load(fr)
        if "_name_or_path" in config:
            config.pop("_name_or_path")
        config["auto_map"] = {
            "AutoModel": "modeling_jasper_vl.JasperVL",
            "AutoConfig": "configuration_jasper_vl.JasperVLConfig",
        }
        with open(join(checkpoint_dir, "config.json"), "w", encoding="utf8") as fw:
            json.dump(config, fw, ensure_ascii=False, indent=1)

        # Pooling config: CLS-token pooling over a 4096-dim embedding,
        # prompt tokens excluded from pooling.
        os.makedirs(join(checkpoint_dir, "1_Pooling"), exist_ok=True)
        config = {
            "word_embedding_dimension": 4096,
            "pooling_mode_cls_token": True,
            "pooling_mode_mean_tokens": False,
            "pooling_mode_max_tokens": False,
            "pooling_mode_mean_sqrt_len_tokens": False,
            "pooling_mode_weightedmean_tokens": False,
            "pooling_mode_lasttoken": False,
            "include_prompt": False
        }
        with open(join(checkpoint_dir, "1_Pooling/config.json"), "w", encoding="utf8") as fw:
            json.dump(config, fw, ensure_ascii=False, indent=1)
        ## modules.json: sentence-transformers pipeline (Transformer -> Pooling).
        with open(os.path.join(checkpoint_dir, "modules.json"), "w", encoding="utf8") as fw:
            json.dump(
                [
                    {
                        "idx": 0,
                        "name": "0",
                        "path": "",
                        "type": "sentence_transformers.models.Transformer"
                    },
                    {
                        "idx": 1,
                        "name": "1",
                        "path": "1_Pooling",
                        "type": "sentence_transformers.models.Pooling"
                    }
                ],
                fw,
                ensure_ascii=False,
                indent=1
            )
        ## Copy tokenizer and sentence-transformers assets from the base model.
        shutil.copy(join(model_dir, "added_tokens.json"), join(checkpoint_dir, "added_tokens.json"))
        shutil.copy(join(model_dir, "config_sentence_transformers.json"),
                    join(checkpoint_dir, "config_sentence_transformers.json"))
        shutil.copy(join(model_dir, "merges.txt"), join(checkpoint_dir, "merges.txt"))
        shutil.copy(join(model_dir, "sentence_bert_config.json"), join(checkpoint_dir, "sentence_bert_config.json"))
        shutil.copy(join(model_dir, "special_tokens_map.json"), join(checkpoint_dir, "special_tokens_map.json"))
        shutil.copy(join(model_dir, "tokenizer_config.json"), join(checkpoint_dir, "tokenizer_config.json"))
        shutil.copy(join(model_dir, "tokenizer.json"), join(checkpoint_dir, "tokenizer.json"))
        shutil.copy(join(model_dir, "vocab.json"), join(checkpoint_dir, "vocab.json"))
161
+
162
+
163
def get_score_diff(vectors):
    """Differences between all unordered pairs of pairwise similarity scores.

    Builds the Gram matrix of ``vectors``, extracts the strict upper triangle
    (row-major — one score per unordered pair), then returns
    ``scores[j] - scores[i]`` for every index pair i < j of those scores.
    """
    gram = torch.matmul(vectors, vectors.T)
    size = gram.shape[0]
    upper = torch.triu_indices(size, size, offset=1)
    flat_scores = gram[upper[0], upper[1]]
    count = flat_scores.shape[0]
    pairs = torch.triu_indices(count, count, offset=1)
    return flat_scores[pairs[1]] - flat_scores[pairs[0]]
169
+
170
+
171
if __name__ == "__main__":
    # Stage-4 distillation entry point: train the JasperVL student vector heads
    # against pre-computed teacher vectors, combining a cosine loss (12288-d head
    # only), a pairwise-similarity MSE loss and a pairwise ranking loss, all
    # orchestrated with HuggingFace Accelerate.
    #
    # Usage: python run_train_align_image_text_stage4.py <train_config.yml>

    # Read the YAML training configuration from the first CLI argument.
    with open(sys.argv[1].strip(), "r", encoding="utf8") as fr:
        conf = yaml.safe_load(fr)
    model_dir = conf["model_path_or_name"]
    max_length = conf["max_length"]  # NOTE(review): not used in this script — confirm whether still needed
    resume_model_dir = conf["resume_model_dir"]
    output_dir = conf["output_dir"]
    save_steps = conf["save_steps"]
    batch_size = conf["batch_size"]
    project_name = conf["project_name"]
    log_with = conf["log_with"]
    log_init_kwargs = conf["log_init_kwargs"]
    file_path_list_or_dir = conf["file_path_list"]
    print_debug_info_prob = conf["print_debug_info_prob"]  # per-step probability of dumping a debug sample
    gradient_accumulation_steps = conf["gradient_accumulation_steps"]
    continue_train = conf["continue_train"]
    num_train_epochs = conf["num_train_epochs"]
    lr_scheduler_type = conf["lr_scheduler_type"]
    mse_loss_scale = conf["mse_loss_scale"]
    cosine_loss_scale = conf["cosine_loss_scale"]
    padding = conf["padding"]  # NOTE(review): unused here — presumably consumed by the collate fn; verify
    teacher_vector_cols = conf["teacher_vector_cols"]  # NOTE(review): unused here — verify
    rank_margin = conf["rank_margin"]
    rank_loss_scale = conf["rank_loss_scale"]
    scheduler_kwargs = conf.get("scheduler_kwargs", {})

    seed = conf["seed"]
    # Initialize the Accelerator. find_unused_parameters is disabled when
    # gradient checkpointing is on (DDP unused-parameter detection conflicts
    # with checkpointed graphs), hence the negation below.
    accelerator = Accelerator(
        project_config=ProjectConfiguration(
            project_dir=output_dir,
            logging_dir=join(output_dir, "logs"),
        ),
        gradient_accumulation_steps=gradient_accumulation_steps,
        log_with=log_with,
        kwargs_handlers=[
            accelerate.DistributedDataParallelKwargs(find_unused_parameters=not conf["gradient_checkpointing"])]
    )

    # Create output dirs, attach the file logger, and snapshot the training
    # config — on the main process only.
    with accelerator.main_process_first():
        if accelerator.is_main_process:
            os.makedirs(output_dir, exist_ok=True)
            os.makedirs(join(output_dir, "logs/wandb_logs"), exist_ok=True)
            logger.add(
                join(output_dir, "train_logs.txt"),
                level="DEBUG",
                compression="zip",
                rotation="500 MB",
                # format="{message}"
            )
            shutil.copy(sys.argv[1].strip(), join(output_dir, "train_config.yml"))

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info(f"accelerator.state:{accelerator.state}")

    # Seed everything for reproducibility.
    set_seed(seed=seed)
    # Load the model and tokenizer.

    model_conf = JasperVLConfig.from_pretrained(model_dir)
    model = JasperVL(model_conf)
    w_di = load_file(filename=join(model_dir, "model.safetensors"), device="cpu")
    w, b = w_di["vector_linear_12288.weight"].detach(), w_di["vector_linear_12288.bias"].detach()

    # Initialize the smaller (Matryoshka-style) projection heads by mean-pooling
    # groups of consecutive output rows of the 12288-d head:
    # (12288, 1536) -> (dim, 12288/dim, 1536).mean(dim=1) -> (dim, 1536).
    w_di["vector_linear_1024.weight"] = w.reshape(1024, -1, 1536).mean(dim=1, keepdim=False)
    w_di["vector_linear_1024.bias"] = b.reshape(1024, -1).mean(dim=1, keepdim=False)

    w_di["vector_linear_512.weight"] = w.reshape(512, -1, 1536).mean(dim=1, keepdim=False)
    w_di["vector_linear_512.bias"] = b.reshape(512, -1).mean(dim=1, keepdim=False)

    w_di["vector_linear_256.weight"] = w.reshape(256, -1, 1536).mean(dim=1, keepdim=False)
    w_di["vector_linear_256.bias"] = b.reshape(256, -1).mean(dim=1, keepdim=False)
    model.load_state_dict(state_dict=w_di, strict=True)
    tokenizer = Qwen2TokenizerFast.from_pretrained(model_dir, padding_side="right")

    # Freeze the language backbone ("model." prefix), then re-enable gradients
    # for the final norm and the top three decoder layers (25/26/27).
    # Note: setting requires_grad = True on an already-trainable parameter is a
    # no-op, so this matching-by-substring is safe for non-backbone params.
    for k, v in model.named_parameters():
        if k.startswith("model."):
            v.requires_grad = False
        if "model.norm.weight" in k or "layers.27" in k or "layers.26" in k or "layers.25" in k:
            v.requires_grad = True
    if accelerator.is_main_process:
        logger.debug("参数冻结情况")
        for k, v in model.named_parameters():
            logger.debug(f"{k}:{v.shape, v.requires_grad}")
    if conf["gradient_checkpointing"]:
        model.gradient_checkpointing_enable()

    # Load training data; batches carry pre-computed teacher vectors.
    train_dataset = JasperDataset_LMDB_RANDOM_ACCESS(file_path_list_or_dir=file_path_list_or_dir)
    train_dataloader = DataLoader(
        dataset=train_dataset,
        shuffle=False,
        collate_fn=collate_fn_jasper_text,
        batch_size=batch_size,
        num_workers=6,
        drop_last=True,
        # pin_memory=True,
        # pin_memory_device="cuda",
        prefetch_factor=4,
    )
    # Synchronize before (potentially) resuming a previous training state.
    accelerator.wait_for_everyone()
    # Initialize experiment trackers (wandb gets its dir/config injected).
    if "wandb" in log_init_kwargs:
        log_init_kwargs["wandb"]["dir"] = join(output_dir, "logs/wandb_logs")
        log_init_kwargs["wandb"]["config"] = {k: json.dumps(v, ensure_ascii=False) for k, v in conf.items()}
    accelerator.init_trackers(
        project_name=project_name,
        init_kwargs=log_init_kwargs
    )
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=conf["learning_rate"])
    # if os.path.exists(join(model_path_or_name, "optimizer.bin")):
    #     optimizer.load_state_dict(torch.load(join(model_path_or_name, "optimizer.bin"), weights_only=False, map_location="cpu"))
    # LR scheduler: a float num_warmup_steps is interpreted as a fraction of
    # the total training steps, an int as an absolute step count.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    max_train_steps = num_update_steps_per_epoch * num_train_epochs
    if isinstance(conf["num_warmup_steps"], float):
        num_warmup_steps = int(max_train_steps * conf["num_warmup_steps"])
    else:
        num_warmup_steps = conf["num_warmup_steps"]
    lr_scheduler = get_scheduler(
        name=lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_train_steps,
        scheduler_specific_kwargs=scheduler_kwargs,
    )
    logger.debug(f"before prepare, len(train_dataloader): {len(train_dataloader)}")
    # Let Accelerate wrap model/optimizer/dataloader/scheduler for the
    # current distributed setup.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    logger.debug(f"after prepare, len(train_dataloader): {len(train_dataloader)}")
    # Recalculate total steps: after prepare() the dataloader is sharded across
    # processes, so len(train_dataloader) shrinks in multi-GPU/multi-node runs
    # and the counts below are per-card.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
    logger.debug(f"max_train_steps for each card:{max_train_steps}")
    starting_epoch, completed_steps = 0, 0

    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)

    # Resume from a checkpoint directory named like ".../step_<N>".
    if continue_train:
        logger.info(f"Continue train from {model_dir}")
        accelerator.load_state(resume_model_dir)
        resume_step = int(os.path.basename(resume_model_dir).replace("step_", ""))
        completed_steps = resume_step
        starting_epoch = resume_step // num_update_steps_per_epoch
        resume_step -= starting_epoch * num_update_steps_per_epoch
        progress_bar.update(completed_steps)
    # Training loop.
    for epoch in range(starting_epoch, num_train_epochs):
        model.train()
        # Use `skip_first_batches` to skip already-consumed batches when
        # resuming from a checkpoint mid-epoch.
        if continue_train and epoch == starting_epoch:
            # We need to skip steps until we reach the resumed step
            active_dataloader = accelerator.skip_first_batches(
                train_dataloader,
                resume_step * gradient_accumulation_steps
            )
        else:
            # After the first iteration though, we need to go back to the original dataloader
            active_dataloader = train_dataloader
        logger.debug(f"len(active_dataloader): {len(active_dataloader)}")

        for batch in active_dataloader:
            teacher_vectors = batch.pop("teacher_vectors")
            with accelerator.accumulate(model):
                attention_mask = batch["attention_mask"]  # NOTE(review): unused until rebound in the debug block below
                model_output = model(**batch)
                # Teacher pairwise similarities (MSE target) and teacher-derived
                # sign labels for the pairwise ranking loss.
                target_sim_values = torch.matmul(teacher_vectors, teacher_vectors.T)
                rank_label = torch.where(get_score_diff(teacher_vectors) < 0, 1, -1)
                sim_value_loss_list, rank_loss_list = [], []
                # One student vector per projection head (256/512/1024/12288 …).
                all_vectors = [v for k, v in model_output.items() if k.startswith("student_vectors_")]
                for student_vectors in all_vectors:
                    # Take the first-token embedding and L2-normalize.
                    # NOTE(review): assumes [:, 0] is the pooled (CLS) position — confirm against the model.
                    student_vectors = student_vectors.float()[:, 0]
                    student_vectors = F.normalize(student_vectors, p=2, dim=-1)
                    if student_vectors.shape[-1] == 12288:
                        # Cosine loss against the teacher: only the full-size
                        # 12288-d head matches the teacher dimensionality.
                        # NOTE(review): if no head emits 12288-d vectors,
                        # `cosine_loss` is never bound and the `loss = ...`
                        # line below raises NameError.
                        cosine_loss = (1 - (student_vectors * teacher_vectors).sum(axis=1).mean()) * cosine_loss_scale

                    # MSE between student and teacher pairwise similarity matrices.
                    sim_value_loss_list.append(
                        F.mse_loss(
                            input=torch.matmul(student_vectors, student_vectors.T),
                            target=target_sim_values,
                        ) * mse_loss_scale
                    )
                    # print(sim_value_loss_list)
                    # Margin-based pairwise ranking loss.
                    rank_loss_list.append(
                        F.relu(get_score_diff(student_vectors) * rank_label + rank_margin).mean() * rank_loss_scale
                    )
                # Average the per-head losses and combine.
                sim_value_loss = sum(sim_value_loss_list) / len(sim_value_loss_list)
                rank_loss = sum(rank_loss_list) / len(rank_loss_list)
                loss = cosine_loss + sim_value_loss + rank_loss
                ########################## debug info #######################################################
                # Dump one random sample (main process only): fixed at step 10,
                # otherwise with probability `print_debug_info_prob`.
                if accelerator.is_main_process and (completed_steps == 10 or random.random() < print_debug_info_prob):
                    input_ids = batch["input_ids"].cpu().numpy()
                    attention_mask = batch["attention_mask"].cpu().numpy()
                    debug_index = random.randint(0, len(input_ids) - 1)
                    for debug_k, debug_v in batch.items():
                        logger.debug(f"{debug_k}.shape: {debug_v.shape}")
                    logger.debug(f"debug_index: {debug_index}")
                    logger.debug(f"input_ids: {input_ids[debug_index].tolist()}")
                    logger.debug(f"input_tokens: {tokenizer.decode(input_ids[debug_index])}")
                    logger.debug(f"attention_mask: {attention_mask[debug_index].tolist()}")
                    logger.debug(f"teacher_vectors.shape: {teacher_vectors.shape}")
                    logger.debug(f"student_vectors.shape: {student_vectors.shape}")
                ###############################################################################################

                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            # Only count a step when gradients were actually synchronized
            # (i.e. once per accumulation cycle).
            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1
                # Early smoke-test save at step 15, then periodic saves.
                if completed_steps == 15:
                    save_model()
                if completed_steps % save_steps == 0 and completed_steps > 0:
                    save_model()
                # Log losses and current LR (main process only).
                if accelerator.is_main_process:
                    curr_lr = float(lr_scheduler.get_last_lr()[-1])
                    logger.info(
                        f"epoch-{epoch},completed_steps-{completed_steps},lr:{curr_lr},cosine_loss:{cosine_loss.item()},sim_value_loss:{sim_value_loss.item()},rank_loss:{rank_loss.item()}"
                    )
                    accelerator.log(
                        {
                            "cosine_loss": cosine_loss.item(),
                            "sim_value_loss": sim_value_loss.item(),
                            "rank_loss": rank_loss.item(),
                            "lr": curr_lr
                        },
                        step=completed_steps
                    )

    # Save the model once more after training completes.
    save_model()
    accelerator.end_training()