"""Finetune / evaluate SurgLLaVA (LLaVA-NeXT LLaMA-3 8B) with optional LoRA.

Run with HuggingFace `accelerate launch`. `--test` switches to inference mode,
which writes per-task prediction JSON files; otherwise the script trains,
periodically evaluates, and checkpoints via `accelerator.save_state`.
"""
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from peft import LoraConfig, get_peft_model, PeftModel
from PIL import Image
import requests
import copy
import torch
import argparse
from dataset.SurgDataset import SurgDataset
from accelerate import Accelerator
from llava.model.SurgLLaVA import SurgLLaVA
import os
from tqdm import tqdm
import json

# Enable device-side assertions and verbose distributed debugging; helps when
# a CUDA kernel asserts inside the multi-process accelerate run.
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = '1'
os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'


def parse_args():
    """Parse command-line arguments for training / testing SurgLLaVA."""
    parser = argparse.ArgumentParser()
    # General arguments
    parser.add_argument('--data_path', type=str, default='/mnt1/lyc/llava_finetune/data_json/instruct_sample_18430_0713_rephrase', help='Data path')
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--wandb', action='store_true')
    parser.add_argument('--wandb_project', type=str, default='SurgLlaVA')
    parser.add_argument('--wandb_process_name', type=str, default='finetune')
    parser.add_argument('--lora_rank', type=int, default=64, help='Rank of the LoRA matrix')
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--log_interval', type=int, default=1)
    parser.add_argument('--eval_interval', type=int, default=3)
    parser.add_argument('--save_interval', type=int, default=3)
    parser.add_argument('--ckpt_dir', type=str, default='model_ckpt', help='Model directory to store checkpoints')
    parser.add_argument('--model_name', type=str, default='llava3_mix_instr',
                        help='Model name. This will be used to create a directory in ckpt_dir and show in wandb')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=12)
    parser.add_argument('--step_size', type=int, default=300)
    # FIX: help text said "gemma value" — it is the StepLR decay factor gamma.
    parser.add_argument('--gamma', type=float, default=0.95, help='gamma value of scheduler')
    parser.add_argument('--num_epochs', type=int, default=1000)
    parser.add_argument('--lora', action='store_true', help='Use LoRA if True')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--lora_ckpt_path', type=str, default=None)
    parser.add_argument('--ckpt_path', type=str, default=None)
    parser.add_argument('--output_dir', type=str, default='4dor_output', help='output file path, which will store output text.')
    return parser.parse_args()


def main():
    """Build the model, then either train (default) or run test inference."""
    args = parse_args()

    accelerator = Accelerator(
        project_dir=os.path.join(args.ckpt_dir, args.model_name),
        log_with="wandb" if args.wandb else None,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )
    if args.wandb:
        print('[INFO] Using wandb for logging...')
        accelerator.init_trackers(
            project_name=args.wandb_project,
            config=args,
            init_kwargs={"wandb": {"name": args.wandb_process_name}},
        )
        accelerator.print("[Info] Using wandb for logging...")

    pretrained = "lmms-lab/llama3-llava-next-8b"
    model_name = "llava_llama3"
    tokenizer, llm_model, image_processor, max_length = load_pretrained_model(
        pretrained, None, model_name, device_map='cuda')

    if tokenizer.pad_token is None:
        # FIX: the original first registered a brand-new '[PAD]' token and then
        # immediately overwrote it with EOS, leaving a dead vocab entry. Reuse
        # the EOS token for padding so the embedding table is untouched.
        tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

    train_dataset = SurgDataset(args, image_processor, llm_model.config, mode='train')
    test_dataset = SurgDataset(args, image_processor, llm_model.config, mode='test')
    train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=4)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, num_workers=4)

    # The base LLM is frozen; only LoRA adapters (and whatever SurgLLaVA adds)
    # receive gradients.
    print('[INFO] Freezing llm model')
    for param in llm_model.parameters():
        param.requires_grad = False
    llm_model.eval()

    if args.lora:
        if args.lora_ckpt_path is not None:
            print('[INFO] Loading LoRA model checkpoint...')
            # FIX: the checkpoint path was hard-coded to
            # './model_ckpt/llama3-llava-next-8b-task-lora', silently ignoring
            # the --lora_ckpt_path argument that gates this branch.
            llm_model = PeftModel.from_pretrained(llm_model, args.lora_ckpt_path)
            llm_model = llm_model.merge_and_unload()
        else:
            print('[INFO] Creating LoRA ...')
            peft_config = LoraConfig(
                lora_alpha=args.lora_rank,
                lora_dropout=0.05,
                r=args.lora_rank,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=[
                    "q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj", "lm_head",
                ],
            )
            lora_llm = get_peft_model(llm_model, peft_config)
            # NOTE(review): taking `.model` unwraps the PeftModel wrapper;
            # presumably intentional so SurgLLaVA wraps the raw module — confirm.
            llm_model = lora_llm.model
    train_params = llm_model.parameters()

    print('[INFO] Creating Model ...')
    model = SurgLLaVA(args, llm_model, tokenizer)
    model = model.to(torch.bfloat16)

    optimizer = torch.optim.AdamW(train_params, lr=args.lr, eps=1e-7)
    # Decay the LR by `gamma` every `step_size` epochs' worth of optimizer steps.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=len(train_dataloader) * args.step_size // args.gradient_accumulation_steps,
        gamma=args.gamma,
    )

    if args.ckpt_path is not None:
        print('[INFO] Load whole pretrained checkpoint...')
        whole_model = torch.load(os.path.join(args.ckpt_path, 'pytorch_model.bin'), map_location='cpu')
        model.load_state_dict(whole_model)

    print('[INFO] Preparing accelerator...')
    # NOTE(review): passing the tokenizer through accelerator.prepare is
    # unusual (it is not a model/optimizer/dataloader) — verify it round-trips.
    model, tokenizer, optimizer, scheduler, train_dataloader, test_dataloader = accelerator.prepare(
        model, tokenizer, optimizer, scheduler, train_dataloader, test_dataloader)

    if args.test:
        # ---------------- testing ----------------
        accelerator.print('[INFO] Start testing...')
        model.eval()
        with torch.no_grad():
            os.makedirs(args.output_dir, exist_ok=True)
            output_list = []
            output_tasks = {}
            for i, batch in tqdm(enumerate(test_dataloader)):
                raw_data, question, answer, image, image_sizes = batch
                image = [img for img in image]
                image_sizes = image_sizes[0]
                # With a partial final batch the sizes come as separate tensors;
                # concatenate them so the model sees one sizes tensor.
                if len(image_sizes) != args.batch_size:
                    image_sizes = [torch.cat(image_sizes)]
                output = model(image, image_sizes, question)
                text_output = tokenizer.batch_decode(output, skip_special_tokens=True)
                # raw_data is assumed to be a dict of per-field lists; convert
                # to one record per batch element after attaching predictions.
                output_data = raw_data
                output_data.update({'answer': text_output, 'question': question})
                output_data = [dict(zip(output_data, t)) for t in zip(*output_data.values())]
                # Each process writes its own temp file to avoid write conflicts.
                with open(f'./temp_{accelerator.process_index}.json', 'w') as f:
                    json.dump(output_data, f, indent=4)
                accelerator.wait_for_everyone()
                # Only the main process merges the per-process results.
                if accelerator.is_main_process:
                    for j in range(accelerator.num_processes):
                        with open(f'./temp_{j}.json', 'r') as f:
                            temp_output = json.load(f)
                        for t in temp_output:
                            if t['task'] not in output_tasks.keys():
                                output_tasks[t['task']] = []
                            output_tasks[t['task']].append(t)
                            output_list.append(t)
                        os.remove(f'./temp_{j}.json')
                    # Rewrite the aggregate files with everything seen so far.
                    with open(os.path.join(args.output_dir, 'preds.json'), 'w') as f:
                        json.dump(output_list, f, indent=4)
                    for k in output_tasks.keys():
                        with open(os.path.join(args.output_dir, f'preds_{k}.json'), 'w') as f:
                            json.dump(output_tasks[k], f, indent=4)
                accelerator.wait_for_everyone()
    else:
        # ---------------- training ----------------
        accelerator.print('[INFO] Start training...')
        for epoch in tqdm(range(args.num_epochs)):
            model.train()
            total_train_loss = 0
            for i, batch in enumerate(train_dataloader):
                optimizer.zero_grad()
                img_id, question, answer, image, image_sizes = batch
                image = [img for img in image]
                image_sizes = image_sizes[0]
                if len(image_sizes) != args.batch_size:
                    image_sizes = [torch.cat(image_sizes)]
                output = model(image, image_sizes, question, answer)
                loss = output.loss
                # Accelerator requires all params to involve gradient descend.
                # This 'dummy loss' can avoid this issue.
                for param in model.parameters():
                    loss += param.sum() * 0.0
                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
                total_train_loss += loss.item()
                if i % 100 == 0:
                    accelerator.print(f'[Epoch {epoch} Iter {i}] loss: {loss.item()}')
                    if args.wandb:
                        accelerator.log({
                            'train_loss': loss.item(),
                            'lr': scheduler.get_last_lr()[0],
                        })
            total_train_loss /= len(train_dataloader)

            total_test_loss = None
            if epoch % args.eval_interval == 0:
                total_test_loss = 0
                model.eval()
                with torch.no_grad():
                    for i, batch in enumerate(test_dataloader):
                        raw_data, question, answer, image, image_sizes = batch
                        image = [img for img in image]
                        image_sizes = image_sizes[0]
                        if len(image_sizes) != args.batch_size:
                            image_sizes = [torch.cat(image_sizes)]
                        output = model(image, image_sizes, question)
                        text_output = tokenizer.batch_decode(output, skip_special_tokens=True)
                        if i % 100 == 0:
                            img_id = raw_data[0]['id']
                            accelerator.print(f'[Epoch {epoch} ID {img_id}] pred text: {text_output[0]}')
                            accelerator.print(f'[Epoch {epoch} ID {img_id}] G T text: {answer[0]}')
                            accelerator.print()
                # NOTE(review): total_test_loss is never accumulated in the
                # loop above (inference returns tokens, not a loss), so the
                # logged eval_loss is always 0 — wire in a real loss or drop it.
                total_test_loss /= len(test_dataloader)

            if epoch % args.save_interval == 0:
                accelerator.print('[INFO] Save model...')
                save_model_dir = os.path.join(args.ckpt_dir, args.model_name, 'checkpoints', f'checkpoint_{epoch:05d}')
                lora_save_dir = os.path.join(args.ckpt_dir, args.model_name, 'lora')
                accelerator.save_state(save_model_dir, safe_serialization=False, total_limit=5)
                unwrapped_model = accelerator.unwrap_model(model)
                # Save only the inner LLM (with LoRA weights) in HF format.
                unwrapped_model.model.save_pretrained(
                    lora_save_dir,
                    save_function=accelerator.save,
                    safe_serialization=False,
                )
                accelerator.print(f"Model saved at {save_model_dir}")

            if args.wandb:
                accelerator.log({
                    'train_loss': total_train_loss,
                    'eval_loss': total_test_loss if total_test_loss is not None else None,
                    'lr': scheduler.get_last_lr()[0],
                })


if __name__ == '__main__':
    main()