import argparse
import json
import os
import sys
from typing import List

import torch
import transformers
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig

from utils import *
from collator import TestCollator
from prompt import all_prompt
from evaluate import get_topk_results, get_metrics_results


def test(args):
    set_seed(args.seed)
    print(vars(args))

    device_map = {"": args.gpu_id}
    device = torch.device("cuda", args.gpu_id)

    tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path)

    # Load the model for evaluation: either a base LLaMA checkpoint with a LoRA
    # adapter on top, or a fully fine-tuned checkpoint.
    if args.lora:
        model = LlamaForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map=device_map,
        )
        model.resize_token_embeddings(len(tokenizer))
        model = PeftModel.from_pretrained(
            model,
            args.ckpt_path,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
        )
    else:
        model = LlamaForCausalLM.from_pretrained(
            args.ckpt_path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map=device_map,
        )
    # assert model.config.vocab_size == len(tokenizer)

    # Select which evaluation prompts to run: all prompts of the chosen task,
    # or an explicit comma-separated list of prompt ids.
    if args.test_prompt_ids == "all":
        if args.test_task.lower() == "seqrec":
            prompt_ids = range(len(all_prompt["seqrec"]))
        elif args.test_task.lower() == "itemsearch":
            prompt_ids = range(len(all_prompt["itemsearch"]))
        elif args.test_task.lower() == "fusionseqrec":
            prompt_ids = range(len(all_prompt["fusionseqrec"]))
        else:
            raise ValueError(f"Unknown test task: {args.test_task}")
    else:
        prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")]

    test_data = load_test_dataset(args)
    collator = TestCollator(args, tokenizer)
    all_items = test_data.get_all_items()
    # Constrain beam search so that only valid item identifiers can be generated.
    prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(tokenizer)

    test_loader = DataLoader(test_data, batch_size=args.test_batch_size, collate_fn=collator,
                             shuffle=True, num_workers=4, pin_memory=True)

    print("data num:", len(test_data))

    model.eval()

    metrics = args.metrics.split(",")
    all_prompt_results = []
    with torch.no_grad():
        for prompt_id in prompt_ids:
            test_loader.dataset.set_prompt(prompt_id)
            metrics_results = {}
            total = 0
            for step, batch in enumerate(tqdm(test_loader)):
                inputs = batch[0].to(device)
                targets = batch[1]
                total += len(targets)
                # Generate top-k candidates with constrained beam search.
                output = model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=10,
                    # max_length=10,
                    prefix_allowed_tokens_fn=prefix_allowed_tokens,
                    num_beams=args.num_beams,
                    num_return_sequences=args.num_beams,
                    output_scores=True,
                    return_dict_in_generate=True,
                    early_stopping=True,
                )
                output_ids = output["sequences"]
                scores = output["sequences_scores"]
                output = tokenizer.batch_decode(
                    output_ids, skip_special_tokens=True
                )
                # print(output)
                topk_res = get_topk_results(output, scores, targets, args.num_beams,
                                            all_items=all_items if args.filter_items else None)
                # Accumulate per-batch metric sums; they are averaged over `total` below.
                batch_metrics_res = get_metrics_results(topk_res, metrics)
                # print(batch_metrics_res)
                for m, res in batch_metrics_res.items():
                    if m not in metrics_results:
                        metrics_results[m] = res
                    else:
                        metrics_results[m] += res

                # Log running averages every 10 batches.
                if (step + 1) % 10 == 0:
                    temp = {}
                    for m in metrics_results:
                        temp[m] = metrics_results[m] / total
                    print(temp)

            for m in metrics_results:
                metrics_results[m] = metrics_results[m] / total

            all_prompt_results.append(metrics_results)
            print("======================================================")
            print("Prompt {} results: ".format(prompt_id), metrics_results)
            print("======================================================")
            print("")

    # Aggregate results across all evaluated prompts.
    mean_results = {}
    min_results = {}
    max_results = {}
    for m in metrics:
        all_res = [_[m] for _ in all_prompt_results]
        mean_results[m] = sum(all_res) / len(all_res)
        min_results[m] = min(all_res)
        max_results[m] = max(all_res)

    print("======================================================")
    print("Mean results: ", mean_results)
    print("Min results: ", min_results)
    print("Max results: ", max_results)
    print("======================================================")

    save_data = {}
    save_data["test_prompt_ids"] = args.test_prompt_ids
    save_data["mean_results"] = mean_results
    save_data["min_results"] = min_results
    save_data["max_results"] = max_results
    save_data["all_prompt_results"] = all_prompt_results

    with open(args.results_file, "w") as f:
        json.dump(save_data, f, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LLMRec_test")
    parser = parse_global_args(parser)
    parser = parse_dataset_args(parser)
    parser = parse_test_args(parser)
    args = parser.parse_args()

    test(args)
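# Example invocation (a sketch only: the actual CLI flags are defined in
# parse_global_args / parse_dataset_args / parse_test_args elsewhere in this repo,
# so the flag names below are assumptions inferred from the `args.*` attributes
# used in this script):
#
#   python test.py --ckpt_path ./ckpt/lora-adapter --base_model <llama-checkpoint> \
#       --lora --test_task seqrec --test_prompt_ids all \
#       --num_beams 20 --test_batch_size 8 --gpu_id 0 \
#       --metrics hit@5,hit@10,ndcg@5,ndcg@10 --results_file ./results/seqrec.json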