# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.
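#
# Example launch (illustrative sketch only: the module path, checkpoint name and
# dataset list below are placeholders, not values shipped with this script):
#
#   torchrun --nproc_per_node 8 -m videomind.train.train \
#       --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
#       --base_model qwen2_vl \
#       --role grounder \
#       --datasets <comma_separated_dataset_names> \
#       --lora_enable True \
#       --bf16 True \
#       --output_dir work_dirs/grounder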

from dataclasses import dataclass, field
from typing import Optional

import nncore
import torch
import torch.nn as nn
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoProcessor, HfArgumentParser, TrainingArguments

from videomind.constants import REG_TOKEN, SEG_E_TOKEN, SEG_S_TOKEN
from videomind.dataset import HybridDataCollator, HybridDataset
from videomind.model import MODELS
from videomind.model.builder import build_model
from videomind.train.custom_trainer import CustomTrainer
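
# Each field in the dataclasses below becomes a '--<field_name>' command-line option
# via HfArgumentParser in train() (e.g. '--role grounder', '--lora_r 64').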
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default=None)
    base_model: Optional[str] = field(default=None)
    conv_type: Optional[str] = field(default=None)
    role: Optional[str] = field(default=None)


@dataclass
class DataArguments:
    datasets: Optional[str] = field(default=None)
    min_video_len: Optional[int] = field(default=-1)
    max_video_len: Optional[int] = field(default=-1)
    min_num_words: Optional[int] = field(default=-1)
    max_num_words: Optional[int] = field(default=-1)
    max_retries: Optional[int] = field(default=10)


@dataclass
class CustomArguments:
    optim: Optional[str] = field(default='adamw_torch')
    group_by_data_type: Optional[bool] = field(default=True)
    merge_adapter: Optional[bool] = field(default=False)
    lora_enable: Optional[bool] = field(default=False)
    lora_type: Optional[str] = field(default='qkvo')
    lora_r: Optional[int] = field(default=64)
    lora_alpha: Optional[int] = field(default=64)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_bias: Optional[str] = field(default='none')
    lora_lr: Optional[float] = field(default=None)
    head_lr: Optional[float] = field(default=None)
    tuning_modules: Optional[str] = field(default=None)
    save_full_model: Optional[bool] = field(default=False)
    remove_unused_columns: Optional[bool] = field(default=False)
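

# The combined TrainingArguments below merges the custom options above into
# transformers' TrainingArguments; note that remove_unused_columns is re-defaulted
# to False, which (presumably) keeps dataset-specific fields from being dropped by
# the Trainer before they reach the collator.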
@dataclass
class TrainingArguments(CustomArguments, TrainingArguments):
    pass
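

# get_target_modules collects the names of nn.Linear modules to wrap with LoRA.
# With lora_type='qkvo' only the q/k/v/o projections of the LLM are kept, 'linear'
# keeps every non-visual nn.Linear, and adding '_all' also includes the visual
# encoder. The exact module names depend on the underlying model class.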
def get_target_modules(model, lora_type, base_model):
    lora_type = lora_type.split('_')
    assert all(t in ('qkvo', 'linear', 'all') for t in lora_type)

    if base_model == 'qwen2_vl':
        # all qkvo layers in the visual encoder and the llm
        qkvo_keys = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'attn.qkv', 'attn.proj']

        target_modules = set()
        for n, m in model.named_modules():
            if not isinstance(m, nn.Linear):
                continue
            if 'all' not in lora_type and 'visual' in n:
                continue
            if 'qkvo' in lora_type and not any(n.endswith(k) for k in qkvo_keys):
                continue
            target_modules.add(n)
    else:
        raise ValueError(f'unknown base model: {base_model}')

    return target_modules


def train(TrainingArguments, Trainer):
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    assert model_args.role in ('all_in_one', 'planner', 'grounder', 'verifier', 'answerer')

    config_cls, model_cls = MODELS[model_args.base_model]

    dtype = torch.bfloat16 if training_args.bf16 else torch.float32

    config = config_cls.from_pretrained(model_args.model_name_or_path, torch_dtype=dtype)
    config.update(model_args.__dict__)

    if config.model_type == 'agent_qwen2_vl':
        model, processor = build_model(
            model_args.model_name_or_path,
            config=config,
            is_trainable=True,
            merge_adapter=training_args.merge_adapter,
            dtype=dtype)
    else:
        # set do_resize to false to avoid duplicated resizing
        # https://github.com/huggingface/transformers/tree/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, do_resize=False)

        # eager attention has known & unknown bugs
        # [4.46.2] broken causality fp16: https://github.com/huggingface/transformers/issues/35151
        # [4.48.1] broken sliding window: https://github.com/huggingface/transformers/issues/35924
        model = model_cls.from_pretrained(model_args.model_name_or_path, config=config, attn_implementation='sdpa')

        # save base model path for inference
        model.config.base_model_path = model_args.model_name_or_path

        # conv parameters may become inf after casting to fp16
        model.reset_conv_parameters()

    model.requires_grad_(False)
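
    # Everything is frozen at this point; the blocks below selectively re-enable
    # LoRA adapters, newly added token embeddings, and the grounding head parameters.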
    if training_args.lora_enable and not isinstance(model, PeftModel):
        target_modules = get_target_modules(model, training_args.lora_type, model.config.base_model)
        tune_lm_head = model.config.role in ('all_in_one', 'grounder', 'verifier')
        print(f'LoRA target modules: {target_modules}')

        lora_config = LoraConfig(
            task_type='CAUSAL_LM',
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            target_modules=target_modules,
            modules_to_save=['embed_tokens', 'lm_head'] if tune_lm_head else None)

        # transformers integration does not support merge_and_unload, use peft instead
        model = get_peft_model(model, lora_config, adapter_name=model_args.role)

    new_tokens = processor.tokenizer.add_special_tokens(
        dict(additional_special_tokens=[REG_TOKEN, SEG_S_TOKEN, SEG_E_TOKEN]))
    print(f'Added {new_tokens} new token(s)')

    model.config.reg_token_id = processor.tokenizer.convert_tokens_to_ids(REG_TOKEN)
    model.config.seg_s_token_id = processor.tokenizer.convert_tokens_to_ids(SEG_S_TOKEN)
    model.config.seg_e_token_id = processor.tokenizer.convert_tokens_to_ids(SEG_E_TOKEN)
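
    # If the tokenizer grew beyond the model vocab, resize the embeddings and
    # initialize the new rows with the mean of the existing embeddings.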
    if new_tokens > 0 and len(processor.tokenizer) > model.config.vocab_size:
        print(f'Expanding vocab size: {model.config.vocab_size} -> {len(processor.tokenizer)}')
        model.resize_token_embeddings(len(processor.tokenizer))

        i_emb = model.get_input_embeddings().weight.data
        o_emb = model.get_output_embeddings().weight.data
        i_emb[-new_tokens:] = i_emb[:-new_tokens].mean(0, keepdim=True)
        o_emb[-new_tokens:] = o_emb[:-new_tokens].mean(0, keepdim=True)
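
    # tuning_modules optionally unfreezes extra modules by name (e.g. 'projector'
    # for the visual merger), while head_keys lists the grounding-head parameters
    # kept trainable for grounder-style roles; head_keys is also handed to the
    # trainer, presumably so that head_lr can target these parameters separately.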
    tuning_modules = [] if training_args.tuning_modules is None else training_args.tuning_modules.split(',')

    head_keys = [
        'vis_proj', 'reg_proj', 'vis_fuse', 'vis_norm', 'vis_pos', 'vis_emb', 'reg_emb', 'pyramid', 'class_head',
        'coord_head', 'coef', 'bundle_loss'
    ]

    for n, p in model.named_parameters():
        # embed_tokens and lm_head might be handled by lora
        if not training_args.lora_enable and new_tokens > 0 and any(k in n for k in ('embed_tokens', 'lm_head')):
            p.requires_grad = True
        if 'projector' in tuning_modules and 'visual.merger' in n:
            p.requires_grad = True
        if model_args.role in ('all_in_one', 'grounder') and any(k in n for k in head_keys):
            p.requires_grad = True
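
    # Rank 0 (or single-process) only: log per-parameter trainable status, report
    # overall parameter counts, and sanity-check the embedding sizes after resizing.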
    if training_args.local_rank in (0, -1):
        for n, p in model.named_parameters():
            print(p.requires_grad, p.dtype, p.shape, n)

        total_params = sum(p.numel() for p in model.parameters())
        learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        ratio = round(learnable_params / total_params * 100, 2) if total_params > 0 else 0
        print(f'Total params: {total_params} Learnable params: {learnable_params} ({ratio}%)')

        i_size = model.get_input_embeddings().num_embeddings
        o_size = model.get_output_embeddings().out_features
        assert i_size == o_size, (i_size, o_size)
        print(f'Tokenizer size: {len(processor.tokenizer)} Vocab size: {model.config.vocab_size} Embed size: {i_size}')

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=HybridDataCollator(processor.tokenizer),
        train_dataset=HybridDataset(processor, model.config, model_args, data_args, training_args),
        processor=processor,
        head_keys=head_keys)
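
    # Resume automatically if a previous checkpoint exists in the output directory.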
    has_ckpt = bool(nncore.find(training_args.output_dir, 'checkpoint-*'))
    trainer.train(resume_from_checkpoint=has_ckpt)
    trainer.save_state()
    trainer.gather_and_save_model()


if __name__ == '__main__':
    train(TrainingArguments, CustomTrainer)