Junyin committed on
Commit 05744dc · verified · 1 Parent(s): bfcae9b

Add files using upload-large-folder tool

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+asset/model.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,124 @@
# LC-Rec

This is the official PyTorch implementation for the paper:

> [Adapting Large Language Models by Integrating Collaborative Semantics for Recommendation](https://arxiv.org/abs/2311.09049)

## Overview

We propose **LC-Rec**, a new approach that integrates **L**anguage and **C**ollaborative semantics to improve LLMs in **Rec**ommender systems. To bridge the large gap between the language semantics modeled by LLMs and the collaborative semantics implied by recommender systems, we make contributions in two aspects. For item indexing, we design a learning-based vector quantization method with uniform semantic mapping, which assigns meaningful and non-conflicting IDs (called item indices) to items. For alignment tuning, we propose a series of specially designed tuning tasks that enhance the integration of collaborative semantics into LLMs. These fine-tuning tasks push the LLM to deeply integrate language and collaborative semantics (characterized by the learned item indices), enabling an effective adaptation to recommender systems.

![model](./asset/model.jpg)

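In practice, each item index is a short sequence of newly added tokens, and their concatenation acts as the item's ID in every prompt. A purely illustrative mapping (the real tokens come from the learned quantizer and are stored in the dataset's `.index.json` file; the token names below are made up) might look like:

```json
{
  "0": ["<a_12>", "<b_37>", "<c_5>", "<d_241>"],
  "1": ["<a_12>", "<b_98>", "<c_71>", "<d_3>"]
}
```
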
## Requirements

```
torch==1.13.1+cu117
accelerate
bitsandbytes
deepspeed
evaluate
peft
sentencepiece
tqdm
transformers
```

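Exact versions for the remaining packages are not pinned in this repository; assuming a CUDA 11.7 environment, one way to set it up is:

```shell
pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
pip install accelerate bitsandbytes deepspeed evaluate peft sentencepiece tqdm transformers
```
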
## Model Checkpoint

The delta weights for the three datasets can be downloaded from the Hugging Face Hub ([Instruments](https://huggingface.co/bwzheng0324/lc-rec-instruments-delta), [Arts](https://huggingface.co/bwzheng0324/lc-rec-arts-delta), [Games](https://huggingface.co/bwzheng0324/lc-rec-games-delta)). After downloading, you can add our deltas to the original LLaMA weights to obtain the LC-Rec weights:

1. Get the original [LLaMA](https://huggingface.co/huggyllama/llama-7b) weights.
2. Use the following script to obtain the LC-Rec weights by applying our delta.

```shell
python convert/merge_delta.py \
    --base-model-path /path/to/llama-7b \
    --target-model-path /path/output/lc-rec \
    --delta-path bwzheng0324/lc-rec-games-delta
```

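Merging keeps both 7B models in CPU memory at once. If RAM is tight, `convert/merge_delta.py` also accepts a `--low-cpu-mem` flag that splits the weights into smaller shards and uses the disk as swap, keeping peak memory usage below roughly 10 GB:

```shell
python convert/merge_delta.py \
    --base-model-path /path/to/llama-7b \
    --target-model-path /path/output/lc-rec \
    --delta-path bwzheng0324/lc-rec-games-delta \
    --low-cpu-mem
```
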
## Dataset

We use three datasets in our paper, all of which have been uploaded to [Google Drive](https://drive.google.com/drive/folders/1RcJ2M1l5zWPHYuGd9l5Gibcs5w5aI3y6?usp=sharing).

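Place each downloaded dataset in its own folder under `$DATA_PATH`. Judging from the loaders in `data.py`, the minimal layout expected for, e.g., the Games dataset is (file names inferred from the code, not an exhaustive list):

```
data/
└── Games/
    ├── Games.inter.json   # user id -> chronological list of interacted item ids
    └── Games.index.json   # item id -> list of index tokens (matches --index_file .index.json)
```
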
## Train

The detailed scripts for all three datasets are in `run.sh`:

```shell
DATASET=Games
BASE_MODEL=huggyllama/llama-7b
DATA_PATH=./data
OUTPUT_DIR=./ckpt/$DATASET/

torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
    --base_model $BASE_MODEL \
    --output_dir $OUTPUT_DIR \
    --dataset $DATASET \
    --data_path $DATA_PATH \
    --per_device_batch_size 8 \
    --gradient_accumulation_steps 2 \
    --learning_rate 5e-5 \
    --epochs 4 \
    --weight_decay 0.01 \
    --save_and_eval_strategy epoch \
    --deepspeed ./config/ds_z3_bf16.json \
    --bf16 \
    --only_train_response \
    --tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
    --train_prompt_sample_num 1,1,1,1,1,1 \
    --train_data_sample_num 0,0,0,100000,0,0 \
    --index_file .index.json


cd convert
nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
cd ..
```

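With these settings, the effective batch size is 8 GPUs × 8 sequences per device × 2 gradient-accumulation steps = 128 sequences per optimizer step. The trailing `convert.sh` call consolidates the DeepSpeed ZeRO-3 shards of each saved checkpoint into standard Hugging Face checkpoints (see `convert/convert.sh` and `convert/zero_to_fp32.py`).
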
## Test

Test with a single GPU:

```shell
DATASET=Games
DATA_PATH=./data
CKPT_PATH=./ckpt/$DATASET/
RESULTS_FILE=./results/$DATASET/result.json

python test.py \
    --gpu_id 0 \
    --ckpt_path $CKPT_PATH \
    --dataset $DATASET \
    --data_path $DATA_PATH \
    --results_file $RESULTS_FILE \
    --test_batch_size 1 \
    --num_beams 20 \
    --test_prompt_ids all \
    --index_file .index.json
```

Test with multiple GPUs:

```shell
DATASET=Games
DATA_PATH=./data
CKPT_PATH=./ckpt/$DATASET/
RESULTS_FILE=./results/$DATASET/result.json

torchrun --nproc_per_node=8 --master_port=4324 test_ddp.py \
    --ckpt_path $CKPT_PATH \
    --dataset $DATASET \
    --data_path $DATA_PATH \
    --results_file $RESULTS_FILE \
    --test_batch_size 1 \
    --num_beams 20 \
    --test_prompt_ids all \
    --index_file .index.json
```

## Acknowledgement

The implementation is based on [Hugging Face Transformers](https://github.com/huggingface/transformers).

asset/model.jpg ADDED

Git LFS Details

  • SHA256: 52223d0ef7f3701a6e40db9997e78c0a7f0d6bfce7965b9f27637e0e25fd1097
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
collator.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import copy
3
+ import argparse
4
+ from dataclasses import dataclass
5
+
6
+ import transformers
7
+ import math
8
+ from torch.utils.data import Sampler
9
+ import torch.distributed as dist
10
+ from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, T5Tokenizer, T5Config, T5ForConditionalGeneration
11
+
12
+
13
+ class Collator(object):
14
+
15
+ def __init__(self, args, tokenizer):
16
+ self.args = args
17
+ self.only_train_response = args.only_train_response
18
+ self.tokenizer = tokenizer
19
+ if self.tokenizer.pad_token_id is None:
20
+ self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
21
+ # print(self.tokenizer.model_max_length)
22
+
23
+ def __call__(self, batch):
24
+
25
+ input_texts = [d["input_ids"] for d in batch]
26
+ full_texts = [d["labels"] + self.tokenizer.eos_token for d in batch]
27
+
28
+ inputs = self.tokenizer(
29
+ text = full_texts,
30
+ text_target = input_texts,
31
+ return_tensors="pt",
32
+ padding="longest",
33
+ max_length=self.tokenizer.model_max_length,
34
+ truncation=True,
35
+ return_attention_mask=True,
36
+ )
37
+ labels = copy.deepcopy(inputs["input_ids"])
38
+ if self.only_train_response:
39
+ # ignore padding
40
+ labels[labels == self.tokenizer.pad_token_id] = -100
41
+ # ignore input text
42
+ labels[torch.where(inputs["labels"] != self.tokenizer.pad_token_id)] = -100
43
+
44
+ inputs["labels"] = labels
45
+
46
+
47
+ return inputs
48
+
49
+
50
+
51
+ class TestCollator(object):
52
+
53
+ def __init__(self, args, tokenizer):
54
+ self.args = args
55
+ self.tokenizer = tokenizer
56
+ if self.tokenizer.pad_token_id is None:
57
+ self.tokenizer.pad_token_id = 0
58
+
59
+ if isinstance(self.tokenizer, LlamaTokenizer):
60
+ # Allow batched inference
61
+ self.tokenizer.padding_side = "left"
62
+
63
+ def __call__(self, batch):
64
+
65
+ input_texts = [d["input_ids"] for d in batch]
66
+ targets = [d["labels"] for d in batch]
67
+ inputs = self.tokenizer(
68
+ text=input_texts,
69
+ return_tensors="pt",
70
+ padding="longest",
71
+ max_length=self.tokenizer.model_max_length,
72
+ truncation=True,
73
+ return_attention_mask=True,
74
+ )
75
+
76
+ return (inputs, targets)
77
+
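A note on the training collator above: the tokenizer is called with the full prompt+response text as `text` and the prompt alone as `text_target`, and the positions covered by the prompt (plus padding) are then set to -100, so the loss is computed only on the response tokens. A minimal usage sketch, assuming a local LLaMA tokenizer and a hypothetical prompt format (the real templates live in `prompt.py`):

```python
from types import SimpleNamespace
from transformers import LlamaTokenizer
from collator import Collator

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", model_max_length=512)
args = SimpleNamespace(only_train_response=True)
collator = Collator(args, tokenizer)

batch = [{
    "input_ids": "### Instruction: recommend the next item.\n### Response:",            # prompt only
    "labels": "### Instruction: recommend the next item.\n### Response: <item index>",  # prompt + response
}]
inputs = collator(batch)
# inputs["labels"] mirrors inputs["input_ids"], except that the prompt and padding
# positions are masked with -100, so only the response contributes to the loss.
```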
config/ds_z2_bf16.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "bf16": {
3
+ "enabled": "auto"
4
+ },
5
+ "zero_optimization": {
6
+ "stage": 2,
7
+ "allgather_partitions": true,
8
+ "allgather_bucket_size": 5e8,
9
+ "overlap_comm": true,
10
+ "reduce_scatter": true,
11
+ "reduce_bucket_size": 5e8,
12
+ "contiguous_gradients": true
13
+ },
14
+ "gradient_accumulation_steps": "auto",
15
+ "gradient_clipping": "auto",
16
+ "steps_per_print": 2000,
17
+ "train_batch_size": "auto",
18
+ "train_micro_batch_size_per_gpu": "auto",
19
+ "wall_clock_breakdown": false,
20
+ "flops_profiler": {
21
+ "enabled": true,
22
+ "profile_step": 10,
23
+ "module_depth": -1,
24
+ "top_modules": 3,
25
+ "detailed": true,
26
+ "output_file": "flops_profiler.out"
27
+ }
28
+ }
config/ds_z2_fp16.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "auto_cast": false,
5
+ "loss_scale": 0,
6
+ "initial_scale_power": 16,
7
+ "loss_scale_window": 1000,
8
+ "hysteresis": 2,
9
+ "min_loss_scale": 1
10
+ },
11
+ "zero_optimization": {
12
+ "stage": 2,
13
+ "allgather_partitions": true,
14
+ "allgather_bucket_size": 5e8,
15
+ "overlap_comm": true,
16
+ "reduce_scatter": true,
17
+ "reduce_bucket_size": 5e8,
18
+ "contiguous_gradients": true
19
+ },
20
+ "gradient_accumulation_steps": "auto",
21
+ "gradient_clipping": "auto",
22
+ "steps_per_print": 2000,
23
+ "train_batch_size": "auto",
24
+ "train_micro_batch_size_per_gpu": "auto",
25
+ "wall_clock_breakdown": false,
26
+ "flops_profiler": {
27
+ "enabled": true,
28
+ "profile_step": 10,
29
+ "module_depth": -1,
30
+ "top_modules": 3,
31
+ "detailed": true,
32
+ "output_file": "flops_profiler.out"
33
+ }
34
+ }
config/ds_z3_bf16.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "bf16": {
3
+ "enabled": "auto"
4
+ },
5
+ "zero_optimization": {
6
+ "stage": 3,
7
+ "overlap_comm": true,
8
+ "contiguous_gradients": true,
9
+ "sub_group_size": 1e9,
10
+ "reduce_bucket_size": "auto",
11
+ "stage3_prefetch_bucket_size": "auto",
12
+ "stage3_param_persistence_threshold": "auto",
13
+ "stage3_max_live_parameters": 1e9,
14
+ "stage3_max_reuse_distance": 1e9,
15
+ "stage3_gather_16bit_weights_on_model_save": true
16
+ },
17
+ "gradient_accumulation_steps": "auto",
18
+ "gradient_clipping": "auto",
19
+ "steps_per_print": 2000,
20
+ "train_batch_size": "auto",
21
+ "train_micro_batch_size_per_gpu": "auto",
22
+ "wall_clock_breakdown": false,
23
+ "flops_profiler": {
24
+ "enabled": true,
25
+ "profile_step": 10,
26
+ "module_depth": -1,
27
+ "top_modules": 3,
28
+ "detailed": true,
29
+ "output_file": "flops_profiler.out"
30
+ }
31
+ }
config/ds_z3_bf16_save16bit.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "bf16": {
3
+ "enabled": "auto"
4
+ },
5
+ "zero_optimization": {
6
+ "stage": 3,
7
+ "overlap_comm": true,
8
+ "contiguous_gradients": true,
9
+ "sub_group_size": 1e9,
10
+ "reduce_bucket_size": "auto",
11
+ "stage3_prefetch_bucket_size": "auto",
12
+ "stage3_param_persistence_threshold": "auto",
13
+ "stage3_max_live_parameters": 1e9,
14
+ "stage3_max_reuse_distance": 1e9,
15
+ "stage3_gather_16bit_weights_on_model_save": true
16
+ },
17
+ "gradient_accumulation_steps": "auto",
18
+ "gradient_clipping": "auto",
19
+ "steps_per_print": 2000,
20
+ "train_batch_size": "auto",
21
+ "train_micro_batch_size_per_gpu": "auto",
22
+ "wall_clock_breakdown": false,
23
+ "flops_profiler": {
24
+ "enabled": true,
25
+ "profile_step": 10,
26
+ "module_depth": -1,
27
+ "top_modules": 3,
28
+ "detailed": true,
29
+ "output_file": "flops_profiler.out"
30
+ }
31
+ }
config/ds_z3_fp16.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "auto_cast": false,
5
+ "loss_scale": 0,
6
+ "initial_scale_power": 16,
7
+ "loss_scale_window": 1000,
8
+ "hysteresis": 2,
9
+ "min_loss_scale": 1
10
+ },
11
+ "zero_optimization": {
12
+ "stage": 3,
13
+ "overlap_comm": true,
14
+ "contiguous_gradients": true,
15
+ "sub_group_size": 1e9,
16
+ "reduce_bucket_size": "auto",
17
+ "stage3_prefetch_bucket_size": "auto",
18
+ "stage3_param_persistence_threshold": "auto",
19
+ "stage3_max_live_parameters": 1e9,
20
+ "stage3_max_reuse_distance": 1e9,
21
+ "stage3_gather_16bit_weights_on_model_save": true
22
+ },
23
+ "gradient_accumulation_steps": "auto",
24
+ "gradient_clipping": "auto",
25
+ "steps_per_print": 2000,
26
+ "train_batch_size": "auto",
27
+ "train_micro_batch_size_per_gpu": "auto",
28
+ "wall_clock_breakdown": false,
29
+ "flops_profiler": {
30
+ "enabled": true,
31
+ "profile_step": 10,
32
+ "module_depth": -1,
33
+ "top_modules": 3,
34
+ "detailed": true,
35
+ "output_file": "flops_profiler.out"
36
+ }
37
+ }
config/ds_z3_fp16_save16bit.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "auto_cast": false,
5
+ "loss_scale": 0,
6
+ "initial_scale_power": 16,
7
+ "loss_scale_window": 1000,
8
+ "hysteresis": 2,
9
+ "min_loss_scale": 1
10
+ },
11
+ "zero_optimization": {
12
+ "stage": 3,
13
+ "overlap_comm": true,
14
+ "contiguous_gradients": true,
15
+ "sub_group_size": 1e9,
16
+ "reduce_bucket_size": "auto",
17
+ "stage3_prefetch_bucket_size": "auto",
18
+ "stage3_param_persistence_threshold": "auto",
19
+ "stage3_max_live_parameters": 1e9,
20
+ "stage3_max_reuse_distance": 1e9,
21
+ "stage3_gather_16bit_weights_on_model_save": true
22
+ },
23
+ "gradient_accumulation_steps": "auto",
24
+ "gradient_clipping": "auto",
25
+ "steps_per_print": 2000,
26
+ "train_batch_size": "auto",
27
+ "train_micro_batch_size_per_gpu": "auto",
28
+ "wall_clock_breakdown": false,
29
+ "flops_profiler": {
30
+ "enabled": true,
31
+ "profile_step": 10,
32
+ "module_depth": -1,
33
+ "top_modules": 3,
34
+ "detailed": true,
35
+ "output_file": "flops_profiler.out"
36
+ }
37
+ }
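In the DeepSpeed configs above, every value set to "auto" (precision, batch sizes, gradient accumulation, gradient clipping, bucket sizes) is resolved at launch time by the Hugging Face Trainer's DeepSpeed integration from the corresponding command-line arguments, so the JSON files normally do not need to be edited per run.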
convert/convert.py ADDED
@@ -0,0 +1,16 @@
1
+ import transformers
2
+ import argparse
3
+ import os
4
+
5
+ if __name__ == '__main__':
6
+ parser = argparse.ArgumentParser()
7
+ parser.add_argument("--source", "-s", type=str, default="", help="source path of models")
8
+ parser.add_argument("--target", "-t", type=str, default="", help="target path of models")
9
+
10
+ args, _ = parser.parse_known_args()
11
+
12
+ assert os.path.exists(args.source)
13
+ assert args.target != ""
14
+
15
+ model = transformers.AutoModelForCausalLM.from_pretrained(args.source)
16
+ model.save_pretrained(args.target, state_dict=model.state_dict())
convert/convert.sh ADDED
@@ -0,0 +1,18 @@
1
+ model=$1
2
+
3
+ set -x
4
+
5
+ for step in `ls ${model} | grep checkpoint | awk -F'-' '{ print $2 }'`
6
+ do
7
+ mkdir ${model}/tmp-checkpoint-${step}
8
+ mkdir ${model}/final-checkpoint-${step}
9
+ python ./zero_to_fp32.py ${model}/checkpoint-${step}/ ${model}/tmp-checkpoint-${step}/pytorch_model.bin
10
+ cp ${model}/*.json ${model}/tmp-checkpoint-${step}
11
+ python ./convert.py -s ${model}/tmp-checkpoint-${step} -t ${model}/final-checkpoint-${step}
12
+ cp ${model}/checkpoint-${step}/*.json ${model}/final-checkpoint-${step}
13
+ cp ${model}/*.json ${model}/final-checkpoint-${step}
14
+ cp ${model}/tokenizer* ${model}/final-checkpoint-${step}
15
+ cp ${model}/train* ${model}/final-checkpoint-${step}
16
+ #rm -rf ${model}/tmp-checkpoint-${step} ${model}/checkpoint-${step} ${model}/global_step${step}
17
+ #mv ${model}/final-checkpoint-${step} ${model}/checkpoint-${step}
18
+ done
convert/convert_fp16.py ADDED
@@ -0,0 +1,23 @@
1
+
2
+ import argparse
3
+
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import torch
6
+
7
+
8
+ def convert_fp16(in_checkpoint, out_checkpoint):
9
+ tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False)
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True
12
+ )
13
+ model.save_pretrained(out_checkpoint)
14
+ tokenizer.save_pretrained(out_checkpoint)
15
+
16
+
17
+ if __name__ == "__main__":
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("--in-checkpoint", type=str, help="Path to the model")
20
+ parser.add_argument("--out-checkpoint", type=str, help="Path to the output model")
21
+ args = parser.parse_args()
22
+
23
+ convert_fp16(args.in_checkpoint, args.out_checkpoint)
convert/make_delta.py ADDED
@@ -0,0 +1,46 @@
1
+
2
+ import argparse
3
+
4
+ import torch
5
+ from tqdm import tqdm
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+
8
+
9
+ def make_delta(base_model_path, target_model_path, delta_path):
10
+ print(f"Loading the base model from {base_model_path}")
11
+ base = AutoModelForCausalLM.from_pretrained(
12
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
13
+ )
14
+
15
+ print(f"Loading the target model from {target_model_path}")
16
+ target = AutoModelForCausalLM.from_pretrained(
17
+ target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
18
+ )
19
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False)
20
+
21
+ print("Calculating the delta")
22
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
23
+ assert name in base.state_dict()
24
+ if param.shape == base.state_dict()[name].shape:
25
+ param.data -= base.state_dict()[name]
26
+ else:
27
+ print(name)
28
+
29
+ print(f"Saving the delta to {delta_path}")
30
+ if args.hub_repo_id:
31
+ kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id}
32
+ else:
33
+ kwargs = {}
34
+ target.save_pretrained(delta_path, **kwargs)
35
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
36
+
37
+
38
+ if __name__ == "__main__":
39
+ parser = argparse.ArgumentParser()
40
+ parser.add_argument("--base-model-path", type=str, required=True)
41
+ parser.add_argument("--target-model-path", type=str, required=True)
42
+ parser.add_argument("--delta-path", type=str, required=True)
43
+ parser.add_argument("--hub-repo-id", type=str)
44
+ args = parser.parse_args()
45
+
46
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path)
convert/merge_delta.py ADDED
@@ -0,0 +1,167 @@
1
+
2
+ import argparse
3
+ import gc
4
+ import glob
5
+ import json
6
+ import os
7
+ import shutil
8
+ import tempfile
9
+
10
+ from huggingface_hub import snapshot_download
11
+ import torch
12
+ from torch import nn
13
+ from tqdm import tqdm
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
15
+
16
+
17
+ GB = 1 << 30
18
+
19
+
20
+ def split_files(model_path, tmp_path, split_size):
21
+ if not os.path.exists(model_path):
22
+ model_path = snapshot_download(repo_id=model_path)
23
+ if not os.path.exists(tmp_path):
24
+ os.makedirs(tmp_path)
25
+
26
+ file_pattern = os.path.join(model_path, "pytorch_model-*.bin")
27
+ files = glob.glob(file_pattern)
28
+
29
+ part = 0
30
+ try:
31
+ for file_path in tqdm(files):
32
+ state_dict = torch.load(file_path)
33
+ new_state_dict = {}
34
+
35
+ current_size = 0
36
+ for name, param in state_dict.items():
37
+ param_size = param.numel() * param.element_size()
38
+
39
+ if current_size + param_size > split_size:
40
+ new_file_name = f"pytorch_model-{part}.bin"
41
+ new_file_path = os.path.join(tmp_path, new_file_name)
42
+ torch.save(new_state_dict, new_file_path)
43
+ current_size = 0
44
+ new_state_dict = None
45
+ gc.collect()
46
+ new_state_dict = {}
47
+ part += 1
48
+
49
+ new_state_dict[name] = param
50
+ current_size += param_size
51
+
52
+ new_file_name = f"pytorch_model-{part}.bin"
53
+ new_file_path = os.path.join(tmp_path, new_file_name)
54
+ torch.save(new_state_dict, new_file_path)
55
+ new_state_dict = None
56
+ gc.collect()
57
+ new_state_dict = {}
58
+ part += 1
59
+ except Exception as e:
60
+ print(f"An error occurred during split_files: {e}")
61
+ shutil.rmtree(tmp_path)
62
+ raise
63
+
64
+
65
+ def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path):
66
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
67
+ delta_config = AutoConfig.from_pretrained(delta_path)
68
+
69
+ if os.path.exists(target_model_path):
70
+ shutil.rmtree(target_model_path)
71
+ os.makedirs(target_model_path)
72
+
73
+ split_size = 4 * GB
74
+
75
+ with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path:
76
+ print(f"Split files for the base model to {tmp_base_path}")
77
+ split_files(base_model_path, tmp_base_path, split_size)
78
+ print(f"Split files for the delta weights to {tmp_delta_path}")
79
+ split_files(delta_path, tmp_delta_path, split_size)
80
+
81
+ base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin")
82
+ base_files = glob.glob(base_pattern)
83
+ base_state_dict = torch.load(base_files[0])
84
+ delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin")
85
+ delta_files = glob.glob(delta_pattern)
86
+ # delta_state_dict = torch.load(delta_files[0])
87
+
88
+ print("Applying the delta")
89
+ weight_map = {}
90
+ total_size = 0
91
+
92
+ for i, delta_file in tqdm(enumerate(delta_files)):
93
+ state_dict = torch.load(delta_file)
94
+ file_name = f"pytorch_model-{i}.bin"
95
+ for name, param in state_dict.items():
96
+ if name not in base_state_dict:
97
+ for base_file in base_files:
98
+ base_state_dict = torch.load(base_file)
99
+ gc.collect()
100
+ if name in base_state_dict:
101
+ break
102
+ if state_dict[name].shape == base_state_dict[name].shape:
103
+ state_dict[name] += base_state_dict[name]
104
+ else:
105
+ print(name)
106
+ weight_map[name] = file_name
107
+ total_size += param.numel() * param.element_size()
108
+ gc.collect()
109
+ torch.save(state_dict, os.path.join(target_model_path, file_name))
110
+
111
+ with open(
112
+ os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w"
113
+ ) as f:
114
+ json.dump(
115
+ {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f
116
+ )
117
+
118
+ print(f"Saving the target model to {target_model_path}")
119
+ delta_tokenizer.save_pretrained(target_model_path)
120
+ delta_config.save_pretrained(target_model_path)
121
+
122
+
123
+ def apply_delta(base_model_path, target_model_path, delta_path):
124
+ print(f"Loading the delta weights from {delta_path}")
125
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
126
+ delta = AutoModelForCausalLM.from_pretrained(
127
+ delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
128
+ )
129
+
130
+ print(f"Loading the base model from {base_model_path}")
131
+ base = AutoModelForCausalLM.from_pretrained(
132
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
133
+ )
134
+
135
+ print("Applying the delta")
136
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
137
+ assert name in base.state_dict()
138
+ if param.shape == base.state_dict()[name].shape:
139
+ param.data += base.state_dict()[name]
140
+ else:
141
+ print(name)
142
+
143
+
144
+ print(f"Saving the target model to {target_model_path}")
145
+ delta.save_pretrained(target_model_path)
146
+ delta_tokenizer.save_pretrained(target_model_path)
147
+
148
+
149
+ if __name__ == "__main__":
150
+ parser = argparse.ArgumentParser()
151
+ parser.add_argument("--base-model-path", type=str, required=True)
152
+ parser.add_argument("--target-model-path", type=str, required=True)
153
+ parser.add_argument("--delta-path", type=str, required=True)
154
+ parser.add_argument(
155
+ "--low-cpu-mem",
156
+ action="store_true",
157
+ help="Lower the cpu memory usage. This will split large files and use "
158
+ "disk as swap to reduce the memory usage below 10GB.",
159
+ )
160
+ args = parser.parse_args()
161
+
162
+ if args.low_cpu_mem:
163
+ apply_delta_low_cpu_mem(
164
+ args.base_model_path, args.target_model_path, args.delta_path
165
+ )
166
+ else:
167
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
convert/zero_to_fp32.py ADDED
@@ -0,0 +1,600 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from ZeRO 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example: python zero_to_fp32.py . pytorch_model.bin
14
+
15
+ import argparse
16
+ import torch
17
+ import glob
18
+ import math
19
+ import os
20
+ import re
21
+ from collections import OrderedDict
22
+ from dataclasses import dataclass
23
+ from tqdm import tqdm
24
+
25
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
26
+ # DeepSpeed data structures it has to be available in the current python environment.
27
+ from deepspeed.utils import logger
28
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
29
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
30
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
31
+
32
+
33
+ @dataclass
34
+ class zero_model_state:
35
+ buffers: dict()
36
+ param_shapes: dict()
37
+ shared_params: list
38
+ ds_version: int
39
+ frozen_param_shapes: dict()
40
+ frozen_param_fragments: dict()
41
+
42
+
43
+ debug = 0
44
+
45
+ # load to cpu
46
+ device = torch.device('cpu')
47
+
48
+
49
+ def atoi(text):
50
+ return int(text) if text.isdigit() else text
51
+
52
+
53
+ def natural_keys(text):
54
+ '''
55
+ alist.sort(key=natural_keys) sorts in human order
56
+ http://nedbatchelder.com/blog/200712/human_sorting.html
57
+ (See Toothy's implementation in the comments)
58
+ '''
59
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
60
+
61
+
62
+ def get_model_state_file(checkpoint_dir, zero_stage):
63
+ if not os.path.isdir(checkpoint_dir):
64
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
65
+
66
+ # there should be only one file
67
+ if zero_stage == 2:
68
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
69
+ elif zero_stage == 3:
70
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
71
+
72
+ if not os.path.exists(file):
73
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
74
+
75
+ return file
76
+
77
+
78
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
79
+ # XXX: need to test that this simple glob rule works for multi-node setup too
80
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
81
+
82
+ if len(ckpt_files) == 0:
83
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
84
+
85
+ return ckpt_files
86
+
87
+
88
+ def get_optim_files(checkpoint_dir):
89
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
90
+
91
+
92
+ def get_model_state_files(checkpoint_dir):
93
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
94
+
95
+
96
+ def parse_model_states(files):
97
+ zero_model_states = []
98
+ for file in files:
99
+ state_dict = torch.load(file, map_location=device)
100
+
101
+ if BUFFER_NAMES not in state_dict:
102
+ raise ValueError(f"{file} is not a model state checkpoint")
103
+ buffer_names = state_dict[BUFFER_NAMES]
104
+ if debug:
105
+ print("Found buffers:", buffer_names)
106
+
107
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
108
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
109
+ param_shapes = state_dict[PARAM_SHAPES]
110
+
111
+ # collect parameters that are included in param_shapes
112
+ param_names = []
113
+ for s in param_shapes:
114
+ for name in s.keys():
115
+ param_names.append(name)
116
+
117
+ # update with frozen parameters
118
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
119
+ if frozen_param_shapes is not None:
120
+ if debug:
121
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
122
+ param_names += list(frozen_param_shapes.keys())
123
+
124
+ # record shared parameters so that they can be recovered based on partners
125
+ # this is because such parameters holding reference only are not saved by optimizer
126
+ shared_params = []
127
+ for param in state_dict["module"]:
128
+ if param not in [*param_names, *buffer_names]:
129
+ for share_param in state_dict["module"]:
130
+ if (state_dict["module"][share_param].data_ptr() == state_dict["module"][param].data_ptr()
131
+ and share_param != param):
132
+ shared_params.append([param, share_param])
133
+ break
134
+
135
+ ds_version = state_dict.get(DS_VERSION, None)
136
+
137
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
138
+
139
+ z_model_state = zero_model_state(buffers=buffers,
140
+ param_shapes=param_shapes,
141
+ shared_params=shared_params,
142
+ ds_version=ds_version,
143
+ frozen_param_shapes=frozen_param_shapes,
144
+ frozen_param_fragments=frozen_param_fragments)
145
+ zero_model_states.append(z_model_state)
146
+
147
+ return zero_model_states
148
+
149
+
150
+ def parse_optim_states(files, ds_checkpoint_dir):
151
+
152
+ total_files = len(files)
153
+ state_dicts = []
154
+ for i, f in enumerate(tqdm(files)):
155
+ state_dicts.append(torch.load(f, map_location=device))
156
+ if i == 0:
157
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
158
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
159
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
160
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
161
+
162
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
163
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
164
+ # use the max of the partition_count to get the dp world_size.
165
+
166
+ if type(world_size) is list:
167
+ world_size = max(world_size)
168
+
169
+ if world_size != total_files:
170
+ raise ValueError(
171
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
172
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
173
+ )
174
+
175
+ # the groups are named differently in each stage
176
+ if zero_stage == 2:
177
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
178
+ elif zero_stage == 3:
179
+ fp32_groups_key = FP32_FLAT_GROUPS
180
+ else:
181
+ raise ValueError(f"unknown zero stage {zero_stage}")
182
+
183
+ key_list = list(state_dicts[-1][OPTIMIZER_STATE_DICT].keys())
184
+ for key in key_list:
185
+ if zero_stage == 2:
186
+ if key != fp32_groups_key:
187
+ del state_dicts[-1][OPTIMIZER_STATE_DICT][key]
188
+ elif zero_stage == 3:
189
+ if key == fp32_groups_key:
190
+ value = torch.cat(state_dicts[-1][OPTIMIZER_STATE_DICT][fp32_groups_key], 0)
191
+ del state_dicts[-1][OPTIMIZER_STATE_DICT][key]
192
+ if key == fp32_groups_key:
193
+ state_dicts[-1][OPTIMIZER_STATE_DICT][key] = value
194
+
195
+ print('zero_stage:', zero_stage)
196
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
197
+ # if zero_stage == 2:
198
+ # # fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
199
+ # elif zero_stage == 3:
200
+ # # if there is more than one param group, there will be multiple flattened tensors - one
201
+ # # flattened tensor per group - for simplicity merge them into a single tensor
202
+ # #
203
+ # # XXX: could make the script more memory efficient for when there are multiple groups - it
204
+ # # will require matching the sub-lists of param_shapes for each param group flattened tensor
205
+
206
+ # print('start!')
207
+ # # fp32_flat_groups = [
208
+ # # torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
209
+ # # ]
210
+
211
+ return zero_stage, world_size, fp32_flat_groups
212
+
213
+
214
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
215
+ """
216
+ Returns fp32 state_dict reconstructed from ds checkpoint
217
+
218
+ Args:
219
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
220
+
221
+ """
222
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
223
+
224
+ optim_files = get_optim_files(ds_checkpoint_dir)
225
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
226
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
227
+
228
+ model_files = get_model_state_files(ds_checkpoint_dir)
229
+
230
+ zero_model_states = parse_model_states(model_files)
231
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
232
+
233
+ if zero_stage == 2:
234
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
235
+ elif zero_stage == 3:
236
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
237
+
238
+
239
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
240
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
241
+ return
242
+
243
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
244
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
245
+
246
+ if debug:
247
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
248
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
249
+
250
+ wanted_params = len(frozen_param_shapes)
251
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
252
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
253
+ print(f'Frozen params: Have {avail_numel} numels to process.')
254
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
255
+
256
+ total_params = 0
257
+ total_numel = 0
258
+ for name, shape in frozen_param_shapes.items():
259
+ total_params += 1
260
+ unpartitioned_numel = shape.numel()
261
+ total_numel += unpartitioned_numel
262
+
263
+ state_dict[name] = frozen_param_fragments[name]
264
+
265
+ if debug:
266
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
267
+
268
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
269
+
270
+
271
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
272
+ param_shapes = zero_model_states[0].param_shapes
273
+
274
+ # Reconstruction protocol:
275
+ #
276
+ # XXX: document this
277
+
278
+ if debug:
279
+ for i in range(world_size):
280
+ for j in range(len(fp32_flat_groups[0])):
281
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
282
+
283
+ # XXX: memory usage doubles here (zero2)
284
+ num_param_groups = len(fp32_flat_groups[0])
285
+ merged_single_partition_of_fp32_groups = []
286
+ for i in range(num_param_groups):
287
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
288
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
289
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
290
+ avail_numel = sum(
291
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
292
+
293
+ if debug:
294
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
295
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
296
+ # not asserting if there is a mismatch due to possible padding
297
+ print(f"Have {avail_numel} numels to process.")
298
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
299
+
300
+ # params
301
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
302
+ # out-of-core computing solution
303
+ total_numel = 0
304
+ total_params = 0
305
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
306
+ offset = 0
307
+ avail_numel = full_single_fp32_vector.numel()
308
+ for name, shape in shapes.items():
309
+
310
+ unpartitioned_numel = shape.numel()
311
+ total_numel += unpartitioned_numel
312
+ total_params += 1
313
+
314
+ if debug:
315
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
316
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
317
+ offset += unpartitioned_numel
318
+
319
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
320
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
321
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
322
+ # live optimizer object, so we are checking that the numbers are within the right range
323
+ align_to = 2 * world_size
324
+
325
+ def zero2_align(x):
326
+ return align_to * math.ceil(x / align_to)
327
+
328
+ if debug:
329
+ print(f"original offset={offset}, avail_numel={avail_numel}")
330
+
331
+ offset = zero2_align(offset)
332
+ avail_numel = zero2_align(avail_numel)
333
+
334
+ if debug:
335
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
336
+
337
+ # Sanity check
338
+ if offset != avail_numel:
339
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
340
+
341
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
342
+
343
+
344
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
345
+ state_dict = OrderedDict()
346
+
347
+ # buffers
348
+ buffers = zero_model_states[0].buffers
349
+ state_dict.update(buffers)
350
+ if debug:
351
+ print(f"added {len(buffers)} buffers")
352
+
353
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
354
+
355
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
356
+
357
+ # recover shared parameters
358
+ for pair in zero_model_states[0].shared_params:
359
+ state_dict[pair[0]] = state_dict[pair[1]]
360
+
361
+ return state_dict
362
+
363
+
364
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
365
+ remainder = unpartitioned_numel % world_size
366
+ padding_numel = (world_size - remainder) if remainder else 0
367
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
368
+ return partitioned_numel, padding_numel
369
+
370
+
371
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
372
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
373
+ return
374
+
375
+ if debug:
376
+ for i in range(world_size):
377
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
378
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
379
+
380
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
381
+ wanted_params = len(frozen_param_shapes)
382
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
383
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
384
+ print(f'Frozen params: Have {avail_numel} numels to process.')
385
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
386
+
387
+ total_params = 0
388
+ total_numel = 0
389
+ for name, shape in tqdm(zero_model_states[0].frozen_param_shapes.items()):
390
+ total_params += 1
391
+ unpartitioned_numel = shape.numel()
392
+ total_numel += unpartitioned_numel
393
+
394
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
395
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
396
+
397
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
398
+
399
+ if debug:
400
+ print(
401
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
402
+ )
403
+
404
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
405
+
406
+
407
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
408
+ param_shapes = zero_model_states[0].param_shapes
409
+ avail_numel = fp32_flat_groups[0].numel() * world_size
410
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
411
+ # param, re-consolidating each param, while dealing with padding if any
412
+
413
+ # merge list of dicts, preserving order
414
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
415
+
416
+ if debug:
417
+ for i in range(world_size):
418
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
419
+
420
+ wanted_params = len(param_shapes)
421
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
422
+ # not asserting if there is a mismatch due to possible padding
423
+ avail_numel = fp32_flat_groups[0].numel() * world_size
424
+ print(f"Trainable params: Have {avail_numel} numels to process.")
425
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
426
+
427
+ # params
428
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
429
+ # out-of-core computing solution
430
+ offset = 0
431
+ total_numel = 0
432
+ total_params = 0
433
+ for name, shape in tqdm(param_shapes.items()):
434
+
435
+ unpartitioned_numel = shape.numel()
436
+ total_numel += unpartitioned_numel
437
+ total_params += 1
438
+
439
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
440
+
441
+ if debug:
442
+ print(
443
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
444
+ )
445
+
446
+ # XXX: memory usage doubles here
447
+ state_dict[name] = torch.cat(
448
+ tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
449
+ 0).narrow(0, 0, unpartitioned_numel).view(shape)
450
+ offset += partitioned_numel
451
+
452
+ offset *= world_size
453
+
454
+ # Sanity check
455
+ if offset != avail_numel:
456
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
457
+
458
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
459
+
460
+
461
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
462
+ state_dict = OrderedDict()
463
+
464
+ # buffers
465
+ buffers = zero_model_states[0].buffers
466
+ state_dict.update(buffers)
467
+ if debug:
468
+ print(f"added {len(buffers)} buffers")
469
+
470
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
471
+
472
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
473
+
474
+ # recover shared parameters
475
+ for pair in zero_model_states[0].shared_params:
476
+ state_dict[pair[0]] = state_dict[pair[1]]
477
+
478
+ return state_dict
479
+
480
+
481
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
482
+ """
483
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
484
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
485
+ via a model hub.
486
+
487
+ Args:
488
+ - ``checkpoint_dir``: path to the desired checkpoint folder
489
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
490
+
491
+ Returns:
492
+ - pytorch ``state_dict``
493
+
494
+ Note: this approach may not work if your application doesn't have sufficient free CPU memory and
495
+ you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
496
+ the checkpoint.
497
+
498
+ A typical usage might be ::
499
+
500
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
501
+ # do the training and checkpoint saving
502
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
503
+ model = model.cpu() # move to cpu
504
+ model.load_state_dict(state_dict)
505
+ # submit to model hub or save the model to share with others
506
+
507
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
508
+ application. i.e. you will need to re-initialize the deepspeed engine, since
509
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
510
+
511
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
512
+
513
+ """
514
+ if tag is None:
515
+ latest_path = os.path.join(checkpoint_dir, 'latest')
516
+ if os.path.isfile(latest_path):
517
+ with open(latest_path, 'r') as fd:
518
+ tag = fd.read().strip()
519
+ else:
520
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
521
+
522
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
523
+
524
+ if not os.path.isdir(ds_checkpoint_dir):
525
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
526
+
527
+ return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
528
+
529
+
530
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
531
+ """
532
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
533
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
534
+
535
+ Args:
536
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
537
+ - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
538
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
539
+ """
540
+
541
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
542
+ print(f"Saving fp32 state dict to {output_file}")
543
+ torch.save(state_dict, output_file)
544
+
545
+
546
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
547
+ """
548
+ 1. Put the provided model to cpu
549
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
550
+ 3. Load it into the provided model
551
+
552
+ Args:
553
+ - ``model``: the model object to update
554
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
555
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
556
+
557
+ Returns:
558
+ - ``model``: modified model
559
+
560
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
561
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
562
+ conveniently placed for you in the checkpoint folder.
563
+
564
+ A typical usage might be ::
565
+
566
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
567
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
568
+ # submit to model hub or save the model to share with others
569
+
570
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
571
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
572
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
573
+
574
+ """
575
+ logger.info(f"Extracting fp32 weights")
576
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
577
+
578
+ logger.info(f"Overwriting model with fp32 weights")
579
+ model = model.cpu()
580
+ model.load_state_dict(state_dict, strict=False)
581
+
582
+ return model
583
+
584
+
585
+ if __name__ == "__main__":
586
+
587
+ parser = argparse.ArgumentParser()
588
+ parser.add_argument("checkpoint_dir",
589
+ type=str,
590
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
591
+ parser.add_argument(
592
+ "output_file",
593
+ type=str,
594
+ help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
595
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
596
+ args = parser.parse_args()
597
+
598
+ debug = args.debug
599
+
600
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
data.py ADDED
@@ -0,0 +1,844 @@
1
+ import copy
2
+ import random
3
+ import argparse
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import Dataset
8
+ from tqdm import tqdm
9
+ from collections import defaultdict
10
+ import torch.distributed as dist
11
+ import logging
12
+ import re
13
+ import pdb
14
+ import json
15
+ from prompt import sft_prompt, all_prompt
16
+ import numpy as np
17
+
18
+
19
+ class BaseDataset(Dataset):
20
+
21
+ def __init__(self, args):
22
+ super().__init__()
23
+
24
+ self.args = args
25
+ self.dataset = args.dataset
26
+ self.data_path = os.path.join(args.data_path, self.dataset)
27
+
28
+ self.max_his_len = args.max_his_len
29
+ self.his_sep = args.his_sep
30
+ self.index_file = args.index_file
31
+ self.add_prefix = args.add_prefix
32
+
33
+ self.new_tokens = None
34
+ self.allowed_tokens = None
35
+ self.all_items = None
36
+
37
+
38
+ def _load_data(self):
39
+
40
+ with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
41
+ self.indices = json.load(f)
42
+
43
+ def get_new_tokens(self):
44
+
45
+ if self.new_tokens is not None:
46
+ return self.new_tokens
47
+
48
+ self.new_tokens = set()
49
+ for index in self.indices.values():
50
+ for token in index:
51
+ self.new_tokens.add(token)
52
+ self.new_tokens = sorted(list(self.new_tokens))
53
+
54
+ return self.new_tokens
55
+
56
+ def get_all_items(self):
57
+
58
+ if self.all_items is not None:
59
+ return self.all_items
60
+
61
+ self.all_items = set()
62
+ for index in self.indices.values():
63
+ self.all_items.add("".join(index))
64
+
65
+ return self.all_items
66
+
67
+ def get_prefix_allowed_tokens_fn(self, tokenizer):
68
+
69
+
70
+ if self.allowed_tokens is None:
71
+ self.allowed_tokens = {}
72
+ for index in self.indices.values():
73
+ for i, token in enumerate(index):
74
+ token_id = tokenizer(token)["input_ids"][1]
75
+ if i not in self.allowed_tokens.keys():
76
+ self.allowed_tokens[i] = set()
77
+ self.allowed_tokens[i].add(token_id)
78
+ self.allowed_tokens[len(self.allowed_tokens.keys())] = set([tokenizer.eos_token_id])
79
+ sep = tokenizer("Response:")["input_ids"][1:]
80
+
81
+ def prefix_allowed_tokens_fn(batch_id, sentence):
82
+ sentence = sentence.tolist()
83
+ reversed_sent = sentence[::-1]
84
+ for i in range(len(reversed_sent)):
85
+ if reversed_sent[i:i + len(sep)] == sep[::-1]:
86
+ # print(list(self.allowed_tokens[i]))
87
+ return list(self.allowed_tokens[i])
88
+
89
+ return prefix_allowed_tokens_fn
90
+
91
+ def _process_data(self):
92
+
93
+ raise NotImplementedError
94
+
95
+
96
+
97
+ class SeqRecDataset(BaseDataset):
98
+
99
+ def __init__(self, args, mode="train",
100
+ prompt_sample_num=1, prompt_id=0, sample_num=-1):
101
+ super().__init__(args)
102
+
103
+ self.mode = mode
104
+ self.prompt_sample_num = prompt_sample_num
105
+ self.prompt_id = prompt_id
106
+ self.sample_num = sample_num
107
+
108
+ self.prompts = all_prompt["seqrec"]
109
+
110
+
111
+ # load data
112
+ self._load_data()
113
+ self._remap_items()
114
+
115
+ # load data
116
+ if self.mode == 'train':
117
+ self.inter_data = self._process_train_data()
118
+ elif self.mode == 'valid':
119
+ self.sample_valid = args.sample_valid
120
+ self.valid_prompt_id = args.valid_prompt_id
121
+ self.inter_data = self._process_valid_data()
122
+ self._construct_valid_text()
123
+ elif self.mode == 'test':
124
+ self.inter_data = self._process_test_data()
125
+ else:
126
+ raise NotImplementedError
127
+
128
+
129
+
130
+ def _load_data(self):
131
+
132
+ with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
133
+ self.inters = json.load(f)
134
+ with open(self.index_file, 'r') as f:
135
+ self.indices = json.load(f)
136
+
137
+
138
+ def _remap_items(self):
139
+
140
+ self.remapped_inters = dict()
141
+ for uid, items in self.inters.items():
142
+ new_items = ["".join(self.indices[str(i)]) for i in items]
143
+ self.remapped_inters[uid] = new_items
144
+
145
+
146
+ def _process_train_data(self):
147
+
148
+ inter_data = []
149
+ for uid in self.remapped_inters:
150
+ items = self.remapped_inters[uid][:-2]
151
+ for i in range(1, len(items)):
152
+ one_data = dict()
153
+ # one_data["user"] = uid
154
+ one_data["item"] = items[i]
155
+ history = items[:i]
156
+ if self.max_his_len > 0:
157
+ history = history[-self.max_his_len:]
158
+ if self.add_prefix:
159
+ history = [str(k+1) + ". " + item_idx for k, item_idx in enumerate(history)]
160
+ one_data["inters"] = self.his_sep.join(history)
161
+ inter_data.append(one_data)
162
+
163
+ return inter_data
164
+
165
+ def _process_valid_data(self):
166
+
167
+ inter_data = []
168
+ for uid in self.remapped_inters:
169
+ items = self.remapped_inters[uid]
170
+ one_data = dict()
171
+ # one_data["user"] = uid
172
+ one_data["item"] = items[-2]
173
+ history = items[:-2]
174
+ if self.max_his_len > 0:
175
+ history = history[-self.max_his_len:]
176
+ if self.add_prefix:
177
+ history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
178
+ one_data["inters"] = self.his_sep.join(history)
179
+ inter_data.append(one_data)
180
+
181
+ return inter_data
182
+
183
+ def _process_test_data(self):
184
+
185
+ inter_data = []
186
+ for uid in self.remapped_inters:
187
+ items = self.remapped_inters[uid]
188
+ one_data = dict()
189
+ # one_data["user"] = uid
190
+ one_data["item"] = items[-1]
191
+ history = items[:-1]
192
+ if self.max_his_len > 0:
193
+ history = history[-self.max_his_len:]
194
+ if self.add_prefix:
195
+ history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
196
+ one_data["inters"] = self.his_sep.join(history)
197
+ inter_data.append(one_data)
198
+
199
+ if self.sample_num > 0:
200
+ all_inter_idx = range(len(inter_data))
201
+ sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
202
+ inter_data = np.array(inter_data)[sample_idx].tolist()
203
+
204
+ return inter_data
205
+
206
+ def set_prompt(self, prompt_id):
207
+
208
+ self.prompt_id = prompt_id
209
+
210
+ def __len__(self):
211
+ if self.mode == 'train':
212
+ return len(self.inter_data) * self.prompt_sample_num
213
+ elif self.mode == 'valid':
214
+ return len(self.valid_text_data)
215
+ elif self.mode == 'test':
216
+ return len(self.inter_data)
217
+ else:
218
+ raise NotImplementedError
219
+
220
+ def _construct_valid_text(self):
221
+ self.valid_text_data = []
222
+ if self.sample_valid:
223
+ all_prompt_ids = range(len(self.prompts))
224
+ for i in range(len(self.inter_data)):
225
+ d = self.inter_data[i]
226
+ prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
227
+ for prompt_id in prompt_ids:
228
+ prompt = self.prompts[prompt_id]
229
+ input, output = self._get_text_data(d, prompt)
230
+ self.valid_text_data.append({"input_ids": input, "labels": output})
231
+ else:
232
+ self.prompt_sample_num = 1
233
+ prompt = self.prompts[self.valid_prompt_id]
234
+ for i in range(len(self.inter_data)):
235
+ d = self.inter_data[i]
236
+ input, output = self._get_text_data(d, prompt)
237
+ self.valid_text_data.append({"input_ids": input, "labels": output})
238
+
239
+ def _get_text_data(self, data, prompt):
240
+
241
+ instruction = prompt["instruction"].format(**data)
242
+ response = prompt["response"].format(**data)
243
+
244
+ input = sft_prompt.format(instruction = instruction, response = "")
245
+ output = sft_prompt.format(instruction = instruction, response = response)
246
+
247
+ if self.mode == 'test':
248
+ return input, response
249
+
250
+ return input, output
251
+
252
+ def __getitem__(self, index):
253
+
254
+ if self.mode == 'valid':
255
+ return self.valid_text_data[index]
256
+
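+        # each interaction is repeated prompt_sample_num times, so map the flat index back to its interaction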
257
+ idx = index // self.prompt_sample_num
258
+ d = self.inter_data[idx]
259
+ # print(index, idx)
260
+
261
+ if self.mode == 'train':
262
+ prompt_id = random.randint(0, len(self.prompts) - 1)
263
+ elif self.mode == 'test':
264
+ prompt_id = self.prompt_id
265
+
266
+ prompt = self.prompts[prompt_id]
267
+
268
+ input, output = self._get_text_data(d, prompt)
269
+
270
+ # print({"input": input, "output": output})
271
+
272
+ return dict(input_ids=input, labels=output)
273
+
274
+
275
+ class FusionSeqRecDataset(BaseDataset):
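+     # sequential recommendation samples that combine item indices with item titles and descriptions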
276
+
277
+ def __init__(self, args, mode="train",
278
+ prompt_sample_num=1, prompt_id=0, sample_num=-1):
279
+ super().__init__(args)
280
+
281
+ self.mode = mode
282
+ self.prompt_sample_num = prompt_sample_num
283
+ self.prompt_id = prompt_id
284
+ self.sample_num = sample_num
285
+
286
+ self.prompts = all_prompt["fusionseqrec"]
287
+
288
+ # load data
289
+ self._load_data()
290
+ # self._remap_items()
291
+
292
+        # process data
293
+ if self.mode == 'train':
294
+ self.inter_data = self._process_train_data()
295
+ elif self.mode == 'valid':
296
+ self.sample_valid = args.sample_valid
297
+ self.valid_prompt_id = args.valid_prompt_id
298
+ self.inter_data = self._process_valid_data()
299
+ self._construct_valid_text()
300
+ elif self.mode == 'test':
301
+ self.inter_data = self._process_test_data()
302
+ else:
303
+ raise NotImplementedError
304
+
305
+
306
+ def _load_data(self):
307
+
308
+ with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
309
+ self.inters = json.load(f)
310
+ with open(self.index_file, 'r') as f:
311
+ self.indices = json.load(f)
312
+ with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
313
+ self.item_feat = json.load(f)
314
+
315
+ def _process_train_data(self):
316
+
317
+ inter_data = []
318
+ for uid in self.inters:
319
+ items = self.inters[uid][:-2]
320
+ for i in range(1, len(items)):
321
+ one_data = dict()
322
+ # one_data["user"] = uid
323
+ one_data["item"] = "".join(self.indices[str(items[i])])
324
+ one_data["title"] = self.item_feat[str(items[i])]["title"].strip().strip(".!?,;:`")
325
+ one_data["description"] = self.item_feat[str(items[i])]["description"]
326
+ history = items[:i]
327
+ if self.max_his_len > 0:
328
+ history = history[-self.max_his_len:]
329
+ inters = ["".join(self.indices[str(j)]) for j in history]
330
+ inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]
331
+
332
+
333
+ if self.add_prefix:
334
+ inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
335
+ inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]
336
+
337
+ one_data["inters"] = self.his_sep.join(inters)
338
+ one_data["inter_titles"] = self.his_sep.join(inter_titles)
339
+ inter_data.append(one_data)
340
+
341
+ if self.sample_num > 0:
342
+ all_inter_idx = range(len(inter_data))
343
+ sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
344
+ inter_data = np.array(inter_data)[sample_idx].tolist()
345
+
346
+ return inter_data
347
+
348
+ def _process_valid_data(self):
349
+
350
+ inter_data = []
351
+ for uid in self.inters:
352
+ items = self.inters[uid]
353
+ one_data = dict()
354
+ one_data["item"] = "".join(self.indices[str(items[-2])])
355
+ one_data["title"] = self.item_feat[str(items[-2])]["title"].strip().strip(".!?,;:`")
356
+ one_data["description"] = self.item_feat[str(items[-2])]["description"]
357
+
358
+
359
+ history = items[:-2]
360
+ if self.max_his_len > 0:
361
+ history = history[-self.max_his_len:]
362
+ inters = ["".join(self.indices[str(j)]) for j in history]
363
+ inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]
364
+
365
+ if self.add_prefix:
366
+ inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
367
+ inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]
368
+
369
+ one_data["inters"] = self.his_sep.join(inters)
370
+ one_data["inter_titles"] = self.his_sep.join(inter_titles)
371
+ inter_data.append(one_data)
372
+
373
+ if self.sample_num > 0:
374
+ all_inter_idx = range(len(inter_data))
375
+ sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
376
+ inter_data = np.array(inter_data)[sample_idx].tolist()
377
+
378
+ return inter_data
379
+
380
+ def _process_test_data(self):
381
+
382
+ inter_data = []
383
+ for uid in self.inters:
384
+ items = self.inters[uid]
385
+ one_data = dict()
386
+ one_data["item"] = "".join(self.indices[str(items[-1])])
387
+ one_data["title"] = self.item_feat[str(items[-1])]["title"].strip().strip(".!?,;:`")
388
+ one_data["description"] = self.item_feat[str(items[-1])]["description"]
389
+
390
+ history = items[:-1]
391
+ if self.max_his_len > 0:
392
+ history = history[-self.max_his_len:]
393
+ inters = ["".join(self.indices[str(j)]) for j in history]
394
+ inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]
395
+
396
+ if self.add_prefix:
397
+ inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
398
+ inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]
399
+
400
+ one_data["inters"] = self.his_sep.join(inters)
401
+ one_data["inter_titles"] = self.his_sep.join(inter_titles)
402
+ inter_data.append(one_data)
403
+
404
+ if self.sample_num > 0:
405
+ all_inter_idx = range(len(inter_data))
406
+ sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
407
+ inter_data = np.array(inter_data)[sample_idx].tolist()
408
+
409
+ return inter_data
410
+
411
+ def set_prompt(self, prompt_id):
412
+
413
+ self.prompt_id = prompt_id
414
+
415
+ def __len__(self):
416
+ if self.mode == 'train':
417
+ return len(self.inter_data) * self.prompt_sample_num
418
+ elif self.mode == 'valid':
419
+ return len(self.valid_text_data)
420
+ elif self.mode == 'test':
421
+ return len(self.inter_data)
422
+ else:
423
+ raise NotImplementedError
424
+
425
+ def _construct_valid_text(self):
426
+ self.valid_text_data = []
427
+ if self.sample_valid:
428
+ all_prompt_ids = range(len(self.prompts))
429
+ for i in range(len(self.inter_data)):
430
+ d = self.inter_data[i]
431
+ prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
432
+ for prompt_id in prompt_ids:
433
+ prompt = self.prompts[prompt_id]
434
+ input, output = self._get_text_data(d, prompt)
435
+ self.valid_text_data.append({"input_ids": input, "labels": output})
436
+ else:
437
+ self.prompt_sample_num = 1
438
+ prompt = self.prompts[self.valid_prompt_id]
439
+ for i in range(len(self.inter_data)):
440
+ d = self.inter_data[i]
441
+ input, output = self._get_text_data(d, prompt)
442
+ self.valid_text_data.append({"input_ids": input, "labels": output})
443
+
444
+ def _get_text_data(self, data, prompt):
445
+
446
+ instruction = prompt["instruction"].format(**data)
447
+ response = prompt["response"].format(**data)
448
+
449
+ input = sft_prompt.format(instruction=instruction, response="")
450
+ output = sft_prompt.format(instruction=instruction, response=response)
451
+
452
+ if self.mode == 'test':
453
+ return input, response
454
+
455
+ return input, output
456
+
457
+ def __getitem__(self, index):
458
+
459
+ if self.mode == 'valid':
460
+ return self.valid_text_data[index]
461
+
462
+ idx = index // self.prompt_sample_num
463
+ d = self.inter_data[idx]
464
+
465
+ if self.mode == 'train':
466
+ prompt_id = random.randint(0, len(self.prompts) - 1)
467
+ elif self.mode == 'test':
468
+ prompt_id = self.prompt_id
469
+
470
+ prompt = self.prompts[prompt_id]
471
+
472
+ input, output = self._get_text_data(d, prompt)
473
+
474
+
475
+ return dict(input_ids=input, labels=output)
476
+
477
+
478
+ class ItemFeatDataset(BaseDataset):
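+     # maps between an item's text features and its index (the task argument selects the prompt set, e.g. "item2index")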
479
+
480
+ def __init__(self, args, task="item2index", prompt_sample_num=1, sample_num=-1):
481
+ super().__init__(args)
482
+
483
+ self.task = task.lower()
484
+ self.prompt_sample_num = prompt_sample_num
485
+ self.sample_num = sample_num
486
+
487
+ self.prompts = all_prompt[self.task]
488
+
489
+ # load data
490
+ self._load_data()
491
+ self.feat_data = self._process_data()
492
+
493
+
494
+
495
+ def _load_data(self):
496
+
497
+ with open(self.index_file, 'r') as f:
498
+ self.indices = json.load(f)
499
+ with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
500
+ self.item_feat = json.load(f)
501
+
502
+
503
+ def _process_data(self):
504
+
505
+ feat_data = []
506
+ for iid in self.item_feat:
507
+ feat = self.item_feat[iid]
508
+ index = "".join(self.indices[iid])
509
+ feat["item"] = index
510
+ feat["title"] = feat["title"].strip().strip(".!?,;:`")
511
+ feat_data.append(feat)
512
+
513
+ if self.sample_num > 0:
514
+ all_idx = range(len(feat_data))
515
+ sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
516
+
517
+ feat_data = np.array(feat_data)[sample_idx].tolist()
518
+
519
+ return feat_data
520
+
521
+
522
+ def __len__(self):
523
+ return len(self.feat_data) * self.prompt_sample_num
524
+
525
+ def _get_text_data(self, data, prompt):
526
+
527
+ instruction = prompt["instruction"].format(**data)
528
+ response = prompt["response"].format(**data)
529
+
530
+ input = sft_prompt.format(instruction = instruction, response = "")
531
+ output = sft_prompt.format(instruction = instruction, response = response)
532
+
533
+ return input, output
534
+
535
+ def __getitem__(self, index):
536
+
537
+ idx = index // self.prompt_sample_num
538
+ d = self.feat_data[idx]
539
+
540
+ prompt_id = random.randint(0, len(self.prompts) - 1)
541
+
542
+ prompt = self.prompts[prompt_id]
543
+
544
+ input, output = self._get_text_data(d, prompt)
545
+
546
+ return dict(input_ids=input, labels=output)
547
+
548
+
549
+ class ItemSearchDataset(BaseDataset):
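+     # retrieves a target item index from a query built from the user's explicit preferences and vague intentions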
550
+
551
+ def __init__(self, args, mode="train",
552
+ prompt_sample_num=1, prompt_id=0, sample_num=-1):
553
+ super().__init__(args)
554
+
555
+ self.mode = mode
556
+ self.prompt_sample_num = prompt_sample_num
557
+ self.prompt_id = prompt_id
558
+ self.sample_num = sample_num
559
+
560
+ self.prompts = all_prompt["itemsearch"]
561
+
562
+ # load data
563
+ self._load_data()
564
+ self.search_data = self._process_data()
565
+
566
+
567
+
568
+ def _load_data(self):
569
+
570
+ with open(self.index_file, 'r') as f:
571
+ self.indices = json.load(f)
572
+ with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
573
+ self.user_info = json.load(f)
574
+
575
+
576
+ def _process_data(self):
577
+
578
+ search_data = []
579
+ user_explicit_preference = self.user_info["user_explicit_preference"]
580
+ user_vague_intention = self.user_info["user_vague_intention"]
581
+ if self.mode == 'train':
582
+ user_vague_intention = user_vague_intention["train"]
583
+ elif self.mode == 'test':
584
+ user_vague_intention = user_vague_intention["test"]
585
+ else:
586
+ raise NotImplementedError
587
+
588
+ for uid in user_explicit_preference.keys():
589
+ one_data = {}
590
+ user_ep = user_explicit_preference[uid]
591
+ user_vi = user_vague_intention[uid]["querys"]
592
+ one_data["explicit_preferences"] = user_ep
593
+ one_data["user_related_intention"] = user_vi[0]
594
+ one_data["item_related_intention"] = user_vi[1]
595
+
596
+ iid = user_vague_intention[uid]["item"]
597
+ inters = user_vague_intention[uid]["inters"]
598
+
599
+ index = "".join(self.indices[str(iid)])
600
+ one_data["item"] = index
601
+
602
+ if self.max_his_len > 0:
603
+ inters = inters[-self.max_his_len:]
604
+ inters = ["".join(self.indices[str(i)]) for i in inters]
605
+ if self.add_prefix:
606
+ inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
607
+
608
+ one_data["inters"] = self.his_sep.join(inters)
609
+
610
+ search_data.append(one_data)
611
+
612
+ if self.sample_num > 0:
613
+ all_idx = range(len(search_data))
614
+ sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
615
+
616
+ search_data = np.array(search_data)[sample_idx].tolist()
617
+
618
+ return search_data
619
+
620
+ def set_prompt(self, prompt_id):
621
+ self.prompt_id = prompt_id
622
+
623
+ def __len__(self):
624
+ if self.mode == 'train':
625
+ return len(self.search_data) * self.prompt_sample_num
626
+ elif self.mode == 'test':
627
+ return len(self.search_data)
628
+ else:
629
+ return len(self.search_data)
630
+
631
+
632
+ def _get_text_data(self, data, prompt):
633
+
634
+ instruction = prompt["instruction"].format(**data)
635
+ response = prompt["response"].format(**data)
636
+
637
+ input = sft_prompt.format(instruction = instruction, response = "")
638
+ output = sft_prompt.format(instruction = instruction, response = response)
639
+
640
+ if self.mode == 'test':
641
+ return input, response
642
+
643
+ return input, output
644
+
645
+ def __getitem__(self, index):
646
+
647
+ idx = index // self.prompt_sample_num
648
+
649
+ d = self.search_data[idx]
650
+ if self.mode == 'train':
651
+ prompt_id = random.randint(0, len(self.prompts) - 1)
652
+ elif self.mode == 'test':
653
+ prompt_id = self.prompt_id
654
+
655
+ prompt = self.prompts[prompt_id]
656
+
657
+ d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
658
+ all_querys = [d["user_related_intention"], d["item_related_intention"]]
659
+ d["query"] = random.choice(all_querys)
660
+
661
+ input, output = self._get_text_data(d, prompt)
662
+
663
+ return dict(input_ids=input, labels=output)
664
+
665
+
666
+
667
+ class PreferenceObtainDataset(BaseDataset):
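+     # builds samples that ask the model to infer a user's explicit preference from the interaction history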
668
+
669
+ def __init__(self, args, prompt_sample_num=1, sample_num=-1):
670
+ super().__init__(args)
671
+
672
+ self.prompt_sample_num = prompt_sample_num
673
+ self.sample_num = sample_num
674
+
675
+ self.prompts = all_prompt["preferenceobtain"]
676
+
677
+ # load data
678
+ self._load_data()
679
+ self._remap_items()
680
+
681
+ self.preference_data = self._process_data()
682
+
683
+
684
+
685
+ def _load_data(self):
686
+
687
+ with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
688
+ self.user_info = json.load(f)
689
+ with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
690
+ self.inters = json.load(f)
691
+ with open(self.index_file, 'r') as f:
692
+ self.indices = json.load(f)
693
+
694
+
695
+ def _remap_items(self):
696
+
697
+ self.remapped_inters = dict()
698
+ for uid, items in self.inters.items():
699
+ new_items = ["".join(self.indices[str(i)]) for i in items]
700
+ self.remapped_inters[uid] = new_items
701
+
702
+ def _process_data(self):
703
+
704
+ preference_data = []
705
+ user_explicit_preference = self.user_info["user_explicit_preference"]
706
+
707
+ for uid in user_explicit_preference.keys():
708
+ one_data = {}
709
+ inters = self.remapped_inters[uid][:-3]
710
+ user_ep = user_explicit_preference[uid]
711
+
712
+ if self.max_his_len > 0:
713
+ inters = inters[-self.max_his_len:]
714
+ if self.add_prefix:
715
+ inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
716
+
717
+ one_data["explicit_preferences"] = user_ep
718
+ one_data["inters"] = self.his_sep.join(inters)
719
+
720
+ preference_data.append(one_data)
721
+
722
+ if self.sample_num > 0:
723
+ all_idx = range(len(preference_data))
724
+ sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
725
+
726
+ preference_data = np.array(preference_data)[sample_idx].tolist()
727
+
728
+ return preference_data
729
+
730
+ def set_prompt(self, prompt_id):
731
+ self.prompt_id = prompt_id
732
+
733
+ def __len__(self):
734
+ return len(self.preference_data) * self.prompt_sample_num
735
+
736
+
737
+ def _get_text_data(self, data, prompt):
738
+
739
+ instruction = prompt["instruction"].format(**data)
740
+ response = prompt["response"].format(**data)
741
+
742
+ input = sft_prompt.format(instruction = instruction, response = "")
743
+ output = sft_prompt.format(instruction = instruction, response = response)
744
+
745
+ return input, output
746
+
747
+ def __getitem__(self, index):
748
+
749
+ idx = index // self.prompt_sample_num
750
+
751
+ d = self.preference_data[idx]
752
+ prompt_id = random.randint(0, len(self.prompts) - 1)
753
+
754
+ prompt = self.prompts[prompt_id]
755
+
756
+ d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
757
+
758
+ input, output = self._get_text_data(d, prompt)
759
+
760
+ return dict(input_ids=input, labels=output)
761
+
762
+
763
+
764
+
765
+
766
+ class SeqRecTestDataset(BaseDataset):
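+     # test-only variant of SeqRecDataset that evaluates with a single fixed prompt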
767
+
768
+ def __init__(self, args, prompt_id=0, sample_num=-1):
769
+ super().__init__(args)
770
+
771
+ self.prompt_id = prompt_id
772
+ self.sample_num = sample_num
773
+
774
+ self.prompt = all_prompt["seqrec"][self.prompt_id]
775
+
776
+ # load data
777
+ self._load_data()
778
+ self._remap_items()
779
+
780
+ self.inter_data = self._process_test_data()
781
+
782
+ def _load_data(self):
783
+
784
+ with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
785
+ self.inters = json.load(f)
786
+ with open(self.index_file, 'r') as f:
787
+ self.indices = json.load(f)
788
+
789
+
790
+ def _remap_items(self):
791
+
792
+ self.remapped_inters = dict()
793
+ for uid, items in self.inters.items():
794
+ new_items = ["".join(self.indices[str(i)]) for i in items]
795
+ self.remapped_inters[uid] = new_items
796
+
797
+ def _process_test_data(self):
798
+
799
+ inter_data = []
800
+ for uid in self.remapped_inters:
801
+ items = self.remapped_inters[uid]
802
+ one_data = dict()
803
+ # one_data["user"] = uid
804
+ one_data["item"] = items[-1]
805
+ history = items[:-1]
806
+ if self.max_his_len > 0:
807
+ history = history[-self.max_his_len:]
808
+ if self.add_prefix:
809
+ history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
810
+ one_data["inters"] = self.his_sep.join(history)
811
+ inter_data.append(one_data)
812
+
813
+ if self.sample_num > 0:
814
+ all_inter_idx = range(len(inter_data))
815
+ sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
816
+
817
+ inter_data = np.array(inter_data)[sample_idx].tolist()
818
+
819
+ return inter_data
820
+
821
+ def set_prompt(self, prompt_id):
822
+ self.prompt_id = prompt_id
823
+
824
+ self.prompt = all_prompt["seqrec"][self.prompt_id]
825
+
826
+ def __len__(self):
827
+
828
+ return len(self.inter_data)
829
+
830
+ def _get_text_data(self, data, prompt):
831
+
832
+ instruction = prompt["instruction"].format(**data)
833
+ response = prompt["response"].format(**data)
834
+
835
+ input = sft_prompt.format(instruction=instruction, response="")
836
+
837
+ return input, response
838
+
839
+ def __getitem__(self, index):
840
+
841
+ d = self.inter_data[index]
842
+ input, target = self._get_text_data(d, self.prompt)
843
+
844
+ return dict(input_ids=input, labels=target)
data_process/amazon18_data_process.py ADDED
@@ -0,0 +1,299 @@
1
+ import argparse
2
+ import collections
3
+ import gzip
4
+ import html
5
+ import json
6
+ import os
7
+ import random
8
+ import re
9
+ import torch
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ from utils import check_path, clean_text, amazon18_dataset2fullname, write_json_file, write_remap_index
13
+
14
+ def load_ratings(file):
15
+ users, items, inters = set(), set(), set()
16
+ with open(file, 'r') as fp:
17
+ for line in tqdm(fp, desc='Load ratings'):
18
+ try:
19
+ item, user, rating, time = line.strip().split(',')
20
+ users.add(user)
21
+ items.add(item)
22
+ inters.add((user, item, float(rating), int(time)))
23
+ except ValueError:
24
+ print(line)
25
+ return users, items, inters
26
+
27
+
28
+ def load_meta_items(file):
29
+ items = {}
30
+ with gzip.open(file, "r") as fp:
31
+ for line in tqdm(fp, desc="Load metas"):
32
+ data = json.loads(line)
33
+ item = data["asin"]
34
+ title = clean_text(data["title"])
35
+
36
+ descriptions = data["description"]
37
+ descriptions = clean_text(descriptions)
38
+
39
+ brand = data["brand"].replace("by\n", "").strip()
40
+
41
+ categories = data["category"]
42
+ new_categories = []
43
+ for category in categories:
44
+ if "</span>" in category:
45
+ break
46
+ new_categories.append(category.strip())
47
+ categories = ",".join(new_categories).strip()
48
+
49
+ items[item] = {"title": title, "description": descriptions, "brand": brand, "categories": categories}
50
+ # print(items[item])
51
+ return items
52
+
53
+
54
+ def load_review_data(args, user2id, item2id):
55
+
56
+ dataset_full_name = amazon18_dataset2fullname[args.dataset]
57
+ review_file_path = os.path.join(args.input_path, 'Review', dataset_full_name + '.json.gz')
58
+
59
+ reviews = {}
60
+
61
+ with gzip.open(review_file_path, "r") as fp:
62
+
63
+ for line in tqdm(fp,desc='Load reviews'):
64
+ inter = json.loads(line)
65
+ try:
66
+ user = inter['reviewerID']
67
+ item = inter['asin']
68
+ if user in user2id and item in item2id:
69
+ uid = user2id[user]
70
+ iid = item2id[item]
71
+ else:
72
+ continue
73
+ if 'reviewText' in inter:
74
+ review = clean_text(inter['reviewText'])
75
+ else:
76
+ review = ''
77
+ if 'summary' in inter:
78
+ summary = clean_text(inter['summary'])
79
+ else:
80
+ summary = ''
81
+ reviews[str((uid,iid))]={"review":review, "summary":summary}
82
+
83
+ except ValueError:
84
+ print(line)
85
+
86
+ return reviews
87
+
88
+
89
+ def get_user2count(inters):
90
+ user2count = collections.defaultdict(int)
91
+ for unit in inters:
92
+ user2count[unit[0]] += 1
93
+ return user2count
94
+
95
+
96
+ def get_item2count(inters):
97
+ item2count = collections.defaultdict(int)
98
+ for unit in inters:
99
+ item2count[unit[1]] += 1
100
+ return item2count
101
+
102
+
103
+ def generate_candidates(unit2count, threshold):
104
+ cans = set()
105
+ for unit, count in unit2count.items():
106
+ if count >= threshold:
107
+ cans.add(unit)
108
+ return cans, len(unit2count) - len(cans)
109
+
110
+
111
+ def filter_inters(inters, can_items=None,
112
+ user_k_core_threshold=0, item_k_core_threshold=0):
113
+ new_inters = []
114
+
115
+ # filter by meta items
116
+ if can_items:
117
+ print('\nFiltering by meta items: ')
118
+ for unit in inters:
119
+ if unit[1] in can_items.keys():
120
+ new_inters.append(unit)
121
+ inters, new_inters = new_inters, []
122
+ print(' The number of inters: ', len(inters))
123
+
124
+ # filter by k-core
125
+ if user_k_core_threshold or item_k_core_threshold:
126
+ print('\nFiltering by k-core:')
127
+ idx = 0
128
+ user2count = get_user2count(inters)
129
+ item2count = get_item2count(inters)
130
+
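+        # iteratively drop users/items below the k-core thresholds until the interaction set is stable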
131
+ while True:
132
+ new_user2count = collections.defaultdict(int)
133
+ new_item2count = collections.defaultdict(int)
134
+ users, n_filtered_users = generate_candidates( # users is set
135
+ user2count, user_k_core_threshold)
136
+ items, n_filtered_items = generate_candidates(
137
+ item2count, item_k_core_threshold)
138
+ if n_filtered_users == 0 and n_filtered_items == 0:
139
+ break
140
+ for unit in inters:
141
+ if unit[0] in users and unit[1] in items:
142
+ new_inters.append(unit)
143
+ new_user2count[unit[0]] += 1
144
+ new_item2count[unit[1]] += 1
145
+ idx += 1
146
+ inters, new_inters = new_inters, []
147
+ user2count, item2count = new_user2count, new_item2count
148
+ print(' Epoch %d The number of inters: %d, users: %d, items: %d'
149
+ % (idx, len(inters), len(user2count), len(item2count)))
150
+ return inters
151
+
152
+
153
+ def make_inters_in_order(inters):
154
+ user2inters, new_inters = collections.defaultdict(list), list()
155
+ for inter in inters:
156
+ user, item, rating, timestamp = inter
157
+ user2inters[user].append((user, item, rating, timestamp))
158
+ for user in user2inters:
159
+ user_inters = user2inters[user]
160
+ user_inters.sort(key=lambda d: d[3])
161
+ interacted_item = set()
162
+ for inter in user_inters:
163
+            if inter[1] in interacted_item:  # skip duplicate interactions
164
+ continue
165
+ interacted_item.add(inter[1])
166
+ new_inters.append(inter)
167
+ return new_inters
168
+
169
+
170
+ def preprocess_rating(args):
171
+ dataset_full_name = amazon18_dataset2fullname[args.dataset]
172
+
173
+ print('Process rating data: ')
174
+ print(' Dataset: ', args.dataset)
175
+
176
+ # load ratings
177
+ rating_file_path = os.path.join(args.input_path, 'Ratings', dataset_full_name + '.csv')
178
+ rating_users, rating_items, rating_inters = load_ratings(rating_file_path)
179
+
180
+ # load item IDs with meta data
181
+ meta_file_path = os.path.join(args.input_path, 'Metadata', f'meta_{dataset_full_name}.json.gz')
182
+ meta_items = load_meta_items(meta_file_path)
183
+
184
+ # 1. Filter items w/o meta data;
185
+ # 2. K-core filtering;
186
+ print('The number of raw inters: ', len(rating_inters))
187
+
188
+ rating_inters = make_inters_in_order(rating_inters)
189
+
190
+ rating_inters = filter_inters(rating_inters, can_items=meta_items,
191
+ user_k_core_threshold=args.user_k,
192
+ item_k_core_threshold=args.item_k)
193
+
194
+ # sort interactions chronologically for each user
195
+ rating_inters = make_inters_in_order(rating_inters)
196
+ print('\n')
197
+
198
+ # return: list of (user_ID, item_ID, rating, timestamp)
199
+ return rating_inters, meta_items
200
+
201
+ def convert_inters2dict(inters):
202
+ user2items = collections.defaultdict(list)
203
+ user2index, item2index = dict(), dict()
204
+ for inter in inters:
205
+ user, item, rating, timestamp = inter
206
+ if user not in user2index:
207
+ user2index[user] = len(user2index)
208
+ if item not in item2index:
209
+ item2index[item] = len(item2index)
210
+ user2items[user2index[user]].append(item2index[item])
211
+ return user2items, user2index, item2index
212
+
213
+ def generate_data(args, rating_inters):
214
+ print('Split dataset: ')
215
+ print(' Dataset: ', args.dataset)
216
+
217
+    # generate train/valid/test splits
218
+ user2items, user2index, item2index = convert_inters2dict(rating_inters)
219
+ train_inters, valid_inters, test_inters = dict(), dict(), dict()
220
+ for u_index in range(len(user2index)):
221
+ inters = user2items[u_index]
222
+ # leave one out
223
+ train_inters[u_index] = [str(i_index) for i_index in inters[:-2]]
224
+ valid_inters[u_index] = [str(inters[-2])]
225
+ test_inters[u_index] = [str(inters[-1])]
226
+ assert len(user2items[u_index]) == len(train_inters[u_index]) + \
227
+ len(valid_inters[u_index]) + len(test_inters[u_index])
228
+ return user2items, train_inters, valid_inters, test_inters, user2index, item2index
229
+
230
+ def convert_to_atomic_files(args, train_data, valid_data, test_data):
231
+ print('Convert dataset: ')
232
+ print(' Dataset: ', args.dataset)
233
+ uid_list = list(train_data.keys())
234
+ uid_list.sort(key=lambda t: int(t))
235
+
236
+ with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.train.inter'), 'w') as file:
237
+ file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n')
238
+ for uid in uid_list:
239
+ item_seq = train_data[uid]
240
+ seq_len = len(item_seq)
241
+ for target_idx in range(1, seq_len):
242
+ target_item = item_seq[-target_idx]
243
+ seq = item_seq[:-target_idx][-50:]
244
+ file.write(f'{uid}\t{" ".join(seq)}\t{target_item}\n')
245
+
246
+ with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.valid.inter'), 'w') as file:
247
+ file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n')
248
+ for uid in uid_list:
249
+ item_seq = train_data[uid][-50:]
250
+ target_item = valid_data[uid][0]
251
+ file.write(f'{uid}\t{" ".join(item_seq)}\t{target_item}\n')
252
+
253
+ with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.test.inter'), 'w') as file:
254
+ file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n')
255
+ for uid in uid_list:
256
+ item_seq = (train_data[uid] + valid_data[uid])[-50:]
257
+ target_item = test_data[uid][0]
258
+ file.write(f'{uid}\t{" ".join(item_seq)}\t{target_item}\n')
259
+
260
+ def parse_args():
261
+ parser = argparse.ArgumentParser()
262
+ parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games')
263
+ parser.add_argument('--user_k', type=int, default=5, help='user k-core filtering')
264
+ parser.add_argument('--item_k', type=int, default=5, help='item k-core filtering')
265
+ parser.add_argument('--input_path', type=str, default='')
266
+ parser.add_argument('--output_path', type=str, default='')
267
+ return parser.parse_args()
268
+
269
+
270
+ if __name__ == '__main__':
271
+ args = parse_args()
272
+
273
+ # load interactions from raw rating file
274
+ rating_inters, meta_items = preprocess_rating(args)
275
+
276
+
277
+    # split train/valid/test
278
+    all_inters, train_inters, valid_inters, test_inters, user2index, item2index = generate_data(args, rating_inters)
279
+
280
+ check_path(os.path.join(args.output_path, args.dataset))
281
+
282
+ write_json_file(all_inters, os.path.join(args.output_path, args.dataset, f'{args.dataset}.inter.json'))
283
+ convert_to_atomic_files(args, train_inters, valid_inters, test_inters)
284
+
285
+ item2feature = collections.defaultdict(dict)
286
+ for item, item_id in item2index.items():
287
+ item2feature[item_id] = meta_items[item]
288
+
289
+ # reviews = load_review_data(args, user2index, item2index)
290
+
291
+ print("user:",len(user2index))
292
+ print("item:",len(item2index))
293
+
294
+ write_json_file(item2feature, os.path.join(args.output_path, args.dataset, f'{args.dataset}.item.json'))
295
+ # write_json_file(reviews, os.path.join(args.output_path, args.dataset, f'{args.dataset}.review.json'))
296
+
297
+
298
+ write_remap_index(user2index, os.path.join(args.output_path, args.dataset, f'{args.dataset}.user2id'))
299
+ write_remap_index(item2index, os.path.join(args.output_path, args.dataset, f'{args.dataset}.item2id'))
data_process/amazon18_recbole_data_process.py ADDED
@@ -0,0 +1,226 @@
1
+ import argparse
2
+ import collections
3
+ import gzip
4
+ import html
5
+ import json
6
+ import os
7
+ import random
8
+ import re
9
+ import torch
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ from utils import check_path, clean_text, amazon18_dataset2fullname,write_json_file,write_remap_index
13
+
14
+ def load_ratings(file):
15
+ users, items, inters = set(), set(), set()
16
+ with open(file, 'r') as fp:
17
+ for line in tqdm(fp, desc='Load ratings'):
18
+ try:
19
+ item, user, rating, time = line.strip().split(',')
20
+ users.add(user)
21
+ items.add(item)
22
+ inters.add((user, item, float(rating), int(time)))
23
+ except ValueError:
24
+ print(line)
25
+ return users, items, inters
26
+
27
+
28
+ def load_meta_items(file):
29
+ items = {}
30
+ # re_tag = re.compile('</?\w+[^>]*>')
31
+ with gzip.open(file, "r") as fp:
32
+ for line in tqdm(fp, desc="Load metas"):
33
+ data = json.loads(line)
34
+ item = data["asin"]
35
+ title = clean_text(data["title"])
36
+
37
+ descriptions = data["description"]
38
+ descriptions = clean_text(descriptions)
39
+ # new_descriptions = []
40
+ # for description in descriptions:
41
+ # description = re.sub(re_tag, '', description)
42
+ # new_descriptions.append(description.strip())
43
+ # descriptions = " ".join(new_descriptions).strip()
44
+
45
+ brand = data["brand"].replace("by\n", "").strip()
46
+
47
+ categories = data["category"]
48
+ new_categories = []
49
+ for category in categories:
50
+ if "</span>" in category:
51
+ break
52
+ new_categories.append(category.strip())
53
+ categories = ",".join(new_categories[1:]).strip()
54
+
55
+ items[item] = {"title": title, "description": descriptions, "brand": brand, "categories": categories}
56
+ # print(items[item])
57
+ return items
58
+
59
+
60
+ def get_user2count(inters):
61
+ user2count = collections.defaultdict(int)
62
+ for unit in inters:
63
+ user2count[unit[0]] += 1
64
+ return user2count
65
+
66
+
67
+ def get_item2count(inters):
68
+ item2count = collections.defaultdict(int)
69
+ for unit in inters:
70
+ item2count[unit[1]] += 1
71
+ return item2count
72
+
73
+
74
+ def generate_candidates(unit2count, threshold):
75
+ cans = set()
76
+ for unit, count in unit2count.items():
77
+ if count >= threshold:
78
+ cans.add(unit)
79
+ return cans, len(unit2count) - len(cans)
80
+
81
+
82
+ def filter_inters(inters, can_items=None,
83
+ user_k_core_threshold=0, item_k_core_threshold=0):
84
+ new_inters = []
85
+
86
+ # filter by meta items
87
+ if can_items:
88
+ print('\nFiltering by meta items: ')
89
+ for unit in inters:
90
+ if unit[1] in can_items.keys():
91
+ new_inters.append(unit)
92
+ inters, new_inters = new_inters, []
93
+ print(' The number of inters: ', len(inters))
94
+
95
+ # filter by k-core
96
+ if user_k_core_threshold or item_k_core_threshold:
97
+ print('\nFiltering by k-core:')
98
+ idx = 0
99
+ user2count = get_user2count(inters)
100
+ item2count = get_item2count(inters)
101
+
102
+ while True:
103
+ new_user2count = collections.defaultdict(int)
104
+ new_item2count = collections.defaultdict(int)
105
+ users, n_filtered_users = generate_candidates( # users is set
106
+ user2count, user_k_core_threshold)
107
+ items, n_filtered_items = generate_candidates(
108
+ item2count, item_k_core_threshold)
109
+ if n_filtered_users == 0 and n_filtered_items == 0:
110
+ break
111
+ for unit in inters:
112
+ if unit[0] in users and unit[1] in items:
113
+ new_inters.append(unit)
114
+ new_user2count[unit[0]] += 1
115
+ new_item2count[unit[1]] += 1
116
+ idx += 1
117
+ inters, new_inters = new_inters, []
118
+ user2count, item2count = new_user2count, new_item2count
119
+ print(' Epoch %d The number of inters: %d, users: %d, items: %d'
120
+ % (idx, len(inters), len(user2count), len(item2count)))
121
+ return inters
122
+
123
+
124
+ def make_inters_in_order(inters):
125
+ user2inters, new_inters = collections.defaultdict(list), list()
126
+ for inter in inters:
127
+ user, item, rating, timestamp = inter
128
+ user2inters[user].append((user, item, rating, timestamp))
129
+ for user in user2inters:
130
+ user_inters = user2inters[user]
131
+ user_inters.sort(key=lambda d: d[3])
132
+ interacted_item = set()
133
+ for inter in user_inters:
134
+ if inter[1] in interacted_item: # 过滤重复交互
135
+ continue
136
+ interacted_item.add(inter[1])
137
+ new_inters.append(inter)
138
+ return new_inters
139
+
140
+
141
+ def preprocess_rating(args):
142
+ dataset_full_name = amazon18_dataset2fullname[args.dataset]
143
+
144
+ print('Process rating data: ')
145
+ print(' Dataset: ', args.dataset)
146
+
147
+ # load ratings
148
+ rating_file_path = os.path.join(args.input_path, 'Ratings', dataset_full_name + '.csv')
149
+ rating_users, rating_items, rating_inters = load_ratings(rating_file_path)
150
+
151
+ # load item IDs with meta data
152
+ meta_file_path = os.path.join(args.input_path, 'Metadata', f'meta_{dataset_full_name}.json.gz')
153
+ meta_items = load_meta_items(meta_file_path)
154
+
155
+ # 1. Filter items w/o meta data;
156
+ # 2. K-core filtering;
157
+ print('The number of raw inters: ', len(rating_inters))
158
+
159
+ rating_inters = make_inters_in_order(rating_inters)
160
+
161
+ rating_inters = filter_inters(rating_inters, can_items=meta_items,
162
+ user_k_core_threshold=args.user_k,
163
+ item_k_core_threshold=args.item_k)
164
+
165
+ # sort interactions chronologically for each user
166
+ rating_inters = make_inters_in_order(rating_inters)
167
+ print('\n')
168
+
169
+ # return: list of (user_ID, item_ID, rating, timestamp)
170
+ return rating_inters, meta_items
171
+
172
+ def save_inter(args, inters):
173
+ print('Convert dataset: ')
174
+ print(' Dataset: ', args.dataset)
175
+
176
+ with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.inter'), 'w') as file:
177
+ file.write('user_id:token\titem_id:token\trating:float\ttimestamp:float\n')
178
+ for inter in inters:
179
+ user, item, rating, timestamp = inter
180
+ file.write(f'{user}\t{item}\t{rating}\t{timestamp}\n')
181
+
182
+
183
+ def save_feat(args, feat, all_items):
184
+ iid_list = list(feat.keys())
185
+ num_item = 0
186
+ with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.item'), 'w') as file:
187
+ # "title": title, "description": descriptions, "brand": brand, "categories": categories
188
+ file.write('item_id:token\ttitle:token_seq\tbrand:token\tcategories:token_seq\n')
189
+ for iid in iid_list:
190
+ if iid in all_items:
191
+ num_item += 1
192
+ title, brand, categories = feat[iid]["title"], feat[iid]["brand"], feat[iid]["categories"]
193
+ file.write(f'{iid}\t{title}\t{brand}\t{categories}\n')
194
+ print("num_item: ", num_item)
195
+
196
+
197
+ def parse_args():
198
+ parser = argparse.ArgumentParser()
199
+ parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games')
200
+ parser.add_argument('--user_k', type=int, default=5, help='user k-core filtering')
201
+ parser.add_argument('--item_k', type=int, default=5, help='item k-core filtering')
202
+ parser.add_argument('--input_path', type=str, default='')
203
+ parser.add_argument('--output_path', type=str, default='')
204
+ return parser.parse_args()
205
+
206
+
207
+ if __name__ == '__main__':
208
+ args = parse_args()
209
+
210
+ # load interactions from raw rating file
211
+ rating_inters, meta_items = preprocess_rating(args)
212
+
213
+ check_path(os.path.join(args.output_path, args.dataset))
214
+
215
+
216
+ all_items = set()
217
+ for inter in rating_inters:
218
+ user, item, rating, timestamp = inter
219
+ all_items.add(item)
220
+
221
+ print("total item: ", len(list(all_items)))
222
+
223
+ save_inter(args,rating_inters)
224
+ save_feat(args,meta_items, all_items)
225
+
226
+
data_process/amazon_text_emb.py ADDED
@@ -0,0 +1,161 @@
1
+ import argparse
2
+ import collections
3
+ import gzip
4
+ import html
5
+ import json
6
+ import os
7
+ import random
8
+ import re
9
+ import torch
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ from utils import *
13
+ from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, AutoTokenizer, AutoModel
14
+
15
+
16
+ def load_data(args):
17
+
18
+ item2feature_path = args.data_path
19
+ item2feature = load_json(item2feature_path)
20
+
21
+ return item2feature
22
+
23
+ def generate_text(item2feature, features):
24
+ item_text_list = []
25
+
26
+ for item in item2feature:
27
+ data = item2feature[item]
28
+ text = []
29
+ for meta_key in features:
30
+ if meta_key in data:
31
+ meta_value = clean_text(data[meta_key])
32
+ text.append(meta_value.strip())
33
+
34
+ item_text_list.append([int(item), text])
35
+
36
+ return item_text_list
37
+
38
+ def preprocess_text(args):
39
+ print('Process text data ...')
40
+ # print('Dataset: ', args.dataset)
41
+
42
+ item2feature = load_data(args)
43
+ # load item text and clean
44
+ item_text_list = generate_text(item2feature, ['title'])
45
+ # item_text_list = generate_text(item2feature, ['title'])
46
+ # return: list of (item_ID, cleaned_item_text)
47
+ return item_text_list
48
+
49
+ def generate_item_embedding(args, item_text_list, tokenizer, model, word_drop_ratio=-1, save_path = ''):
50
+ print('Generate text embedding ...')
51
+ # print(' Dataset: ', args.dataset)
52
+
53
+ items, texts = zip(*item_text_list)
54
+ order_texts = [[0]] * len(items)
55
+ for item, text in zip(items, texts):
56
+ order_texts[item] = text
57
+ for text in order_texts:
58
+ assert text != [0]
59
+
60
+ embeddings = []
61
+ emb_result = []
62
+ start, batch_size = 0, 1
63
+ with torch.no_grad():
64
+ while start < len(order_texts):
65
+ if (start+1)%100==0:
66
+ print("==>",start+1)
67
+ field_texts = order_texts[start: start + batch_size]
68
+ # print(field_texts)
69
+ field_texts = zip(*field_texts)
70
+
71
+ field_embeddings = []
72
+ for sentences in field_texts:
73
+ sentences = list(sentences)
74
+ # print(sentences)
75
+ if word_drop_ratio > 0:
76
+ print(f'Word drop with p={word_drop_ratio}')
77
+ new_sentences = []
78
+ for sent in sentences:
79
+ new_sent = []
80
+ sent = sent.split(' ')
81
+ for wd in sent:
82
+ rd = random.random()
83
+ if rd > word_drop_ratio:
84
+ new_sent.append(wd)
85
+ new_sent = ' '.join(new_sent)
86
+ new_sentences.append(new_sent)
87
+ sentences = new_sentences
88
+ encoded_sentences = tokenizer(sentences, max_length=args.max_sent_len,
89
+ truncation=True, return_tensors='pt',padding="longest").to(args.device)
90
+ outputs = model(input_ids=encoded_sentences.input_ids,
91
+ attention_mask=encoded_sentences.attention_mask)
92
+
93
+ masked_output = outputs.last_hidden_state * encoded_sentences['attention_mask'].unsqueeze(-1)
94
+ mean_output = masked_output.sum(dim=1) / encoded_sentences['attention_mask'].sum(dim=-1, keepdim=True)
95
+ mean_output = mean_output.detach().cpu()
96
+ emb_result.append(mean_output.numpy().tolist())
97
+ field_embeddings.append(mean_output)
98
+
99
+ field_mean_embedding = torch.stack(field_embeddings, dim=0).mean(dim=0)
100
+ embeddings.append(field_mean_embedding)
101
+ start += batch_size
102
+
103
+ embeddings = torch.cat(embeddings, dim=0).numpy()
104
+ print('Embeddings shape: ', embeddings.shape)
105
+
106
+ all_results = {
107
+ 'text':[],
108
+ 'node_type':[],
109
+ 'emb':[]
110
+ }
111
+
112
+ all_results['text'] = [t[0] for t in texts]
113
+ all_results['node_type'] = [1] * len(all_results['text'])
114
+ for emb in emb_result:
115
+ str_emb = ''
116
+ for x in emb:
117
+ str_emb = str_emb + str(x) + ' '
118
+ all_results['emb'].append(str_emb[:-1])
119
+
120
+ import pandas as pd
121
+ df = pd.DataFrame(all_results)
122
+ # header = 0: w/o column name; index = False: w/o index column
123
+ df.to_csv(args.save_path, sep = '\t', header = 0, index = False)
124
+
125
+ # file = os.path.join(args.root, args.dataset + '.emb-' + args.plm_name + "-td" + ".npy")
126
+ # np.save(file, embeddings)
127
+
128
+
129
+ def parse_args():
130
+ parser = argparse.ArgumentParser()
131
+ parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games')
132
+ parser.add_argument('--root', type=str, default="")
133
+ parser.add_argument('--gpu_id', type=int, default=0, help='ID of running GPU')
134
+ parser.add_argument('--plm_name', type=str, default='llama')
135
+ parser.add_argument('--plm_checkpoint', type=str,
136
+ default='')
137
+ parser.add_argument('--max_sent_len', type=int, default=2048)
138
+ parser.add_argument('--word_drop_ratio', type=float, default=-1, help='word drop ratio, do not drop by default')
139
+ parser.add_argument('--data_path', type=str, default='')
140
+ parser.add_argument('--save_path', type=str, default='')
141
+ return parser.parse_args()
142
+
143
+
144
+ if __name__ == '__main__':
145
+ args = parse_args()
146
+
147
+ args.root = os.path.join(args.root, args.dataset)
148
+
149
+ device = set_device(args.gpu_id)
150
+ args.device = device
151
+
152
+ item_text_list = preprocess_text(args)
153
+
154
+ plm_tokenizer, plm_model = load_plm(args.plm_checkpoint)
155
+ if plm_tokenizer.pad_token_id is None:
156
+ plm_tokenizer.pad_token_id = 0
157
+ plm_model = plm_model.to(device)
158
+
159
+ generate_item_embedding(args, item_text_list, plm_tokenizer,
160
+ plm_model, word_drop_ratio = args.word_drop_ratio,
161
+ save_path = args.save_path)
data_process/get_llm_output.py ADDED
@@ -0,0 +1,374 @@
1
+
2
+
3
+ import argparse
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ import time
8
+ from logging import getLogger
9
+ import openai
10
+ from utils import get_res_batch, load_json, intention_prompt, preference_prompt_1, preference_prompt_2, amazon18_dataset2fullname, write_json_file
11
+ import json
12
+
13
+
14
+
15
+ def get_intention_train(args, inters, item2feature, reviews, api_info):
16
+
17
+ intention_train_output_file = os.path.join(args.root,"intention_train.json")
18
+
19
+
20
+    # it is recommended to adapt this prompt to the specific dataset
21
+ prompt = intention_prompt
22
+ dataset_full_name = amazon18_dataset2fullname[args.dataset]
23
+ dataset_full_name = dataset_full_name.replace("_", " ").lower()
24
+ print(dataset_full_name)
25
+
26
+ prompt_list = []
27
+
28
+ inter_data = []
29
+
30
+ for (user,item_list) in inters.items():
31
+ user = int(user)
32
+ item = int(item_list[-3])
33
+ history = item_list[:-3]
34
+
35
+ inter_data.append((user,item,history))
36
+
37
+ review = reviews[str((user, item))]["review"]
38
+ item_title = item2feature[str(item)]["title"]
39
+ input_prompt = prompt.format(item_title=item_title,dataset_full_name=dataset_full_name,review=review)
40
+ prompt_list.append(input_prompt)
41
+
42
+ st = 0
43
+ with open(intention_train_output_file, mode='a') as f:
44
+
45
+ while st < len(prompt_list):
46
+ # while st < 3:
47
+ print(st)
48
+ # if st < 25631:
49
+ # st += args.batchsize
50
+ # continue
51
+
52
+
53
+ res = get_res_batch(args.model_name, prompt_list[st:st+args.batchsize], args.max_tokens, api_info)
54
+
55
+ for i, answer in enumerate(res):
56
+ user, item, history = inter_data[st+i]
57
+ # print(answer)
58
+ # print("=============")
59
+
60
+ if answer == '':
61
+ print("answer null error")
62
+ answer = "I enjoy high-quality item."
63
+
64
+ if answer.strip().count('\n') != 1:
65
+ if 'haracteristics:' in answer:
66
+ answer = answer.strip().split("The item's characteristics:")
67
+ else:
68
+ answer = answer.strip().split("The item's characteristic:")
69
+ else:
70
+ answer = answer.strip().split('\n')
71
+
72
+ if '' in answer:
73
+ answer.remove('')
74
+
75
+ if len(answer) == 1:
76
+ print(answer)
77
+ user_preference = item_character = answer[0]
78
+ elif len(answer) >= 3:
79
+ print(answer)
80
+ answer = answer[-1]
81
+ user_preference = item_character = answer
82
+ else:
83
+ user_preference, item_character = answer
84
+
85
+ if ':' in user_preference:
86
+ idx = user_preference.index(':')
87
+ user_preference = user_preference[idx+1:]
88
+ user_preference = user_preference.strip().replace('}','')
89
+ user_preference = user_preference.replace('\n','')
90
+
91
+ if ':' in item_character:
92
+ idx = item_character.index(':')
93
+ item_character = item_character[idx+1:]
94
+ item_character = item_character.strip().replace('}','')
95
+ item_character = item_character.replace('\n','')
96
+
97
+
98
+ dict = {"user":user, "item":item, "inters": history,
99
+ "user_related_intention":user_preference, "item_related_intention": item_character}
100
+
101
+ json.dump(dict, f)
102
+ f.write("\n")
103
+
104
+ st += args.batchsize
105
+
106
+ return intention_train_output_file
107
+
108
+
109
+ def get_intention_test(args, inters, item2feature, reviews, api_info):
110
+
111
+ intention_test_output_file = os.path.join(args.root,"intention_test.json")
112
+
113
+ # Suggest modifying the prompt based on different datasets
114
+ prompt = intention_prompt
115
+ dataset_full_name = amazon18_dataset2fullname[args.dataset]
116
+ dataset_full_name = dataset_full_name.replace("_", " ").lower()
117
+ print(dataset_full_name)
118
+
119
+ prompt_list = []
120
+
121
+ inter_data = []
122
+
123
+ for (user,item_list) in inters.items():
124
+ user = int(user)
125
+ item = int(item_list[-1])
126
+ history = item_list[:-1]
127
+
128
+ inter_data.append((user,item,history))
129
+
130
+ review = reviews[str((user, item))]["review"]
131
+ item_title = item2feature[str(item)]["title"]
132
+ input_prompt = prompt.format(item_title=item_title,dataset_full_name=dataset_full_name,review=review)
133
+ prompt_list.append(input_prompt)
134
+
135
+ st = 0
136
+ with open(intention_test_output_file, mode='a') as f:
137
+
138
+ while st < len(prompt_list):
139
+ # while st < 3:
140
+ print(st)
141
+ # if st < 4623:
142
+ # st += args.batchsize
143
+ # continue
144
+
145
+ res = get_res_batch(args.model_name, prompt_list[st:st+args.batchsize], args.max_tokens, api_info)
146
+
147
+ for i, answer in enumerate(res):
148
+ user, item, history = inter_data[st+i]
149
+
150
+ if answer == '':
151
+ print("answer null error")
152
+ answer = "I enjoy high-quality item."
153
+
154
+ if answer.strip().count('\n') != 1:
155
+ if 'haracteristics:' in answer:
156
+ answer = answer.strip().split("The item's characteristics:")
157
+ else:
158
+ answer = answer.strip().split("The item's characteristic:")
159
+ else:
160
+ answer = answer.strip().split('\n')
161
+
162
+ if '' in answer:
163
+ answer.remove('')
164
+
165
+ if len(answer) == 1:
166
+ print(answer)
167
+ user_preference = item_character = answer[0]
168
+ elif len(answer) >= 3:
169
+ print(answer)
170
+ answer = answer[-1]
171
+ user_preference = item_character = answer
172
+ else:
173
+ user_preference, item_character = answer
174
+
175
+ if ':' in user_preference:
176
+ idx = user_preference.index(':')
177
+ user_preference = user_preference[idx+1:]
178
+ user_preference = user_preference.strip().replace('}','')
179
+ user_preference = user_preference.replace('\n','')
180
+
181
+ if ':' in item_character:
182
+ idx = item_character.index(':')
183
+ item_character = item_character[idx+1:]
184
+ item_character = item_character.strip().replace('}','')
185
+ item_character = item_character.replace('\n','')
186
+
187
+
188
+ dict = {"user":user, "item":item, "inters": history,
189
+ "user_related_intention":user_preference, "item_related_intention": item_character}
190
+
191
+ json.dump(dict, f)
192
+ f.write("\n")
193
+
194
+ st += args.batchsize
195
+
196
+ return intention_test_output_file
197
+
198
+
199
+
200
+
201
+ def get_user_preference(args, inters, item2feature, reviews, api_info):
202
+
203
+ preference_output_file = os.path.join(args.root,"user_preference.json")
204
+
205
+
206
+ # Suggest modifying the prompt based on different datasets
207
+ prompt_1 = preference_prompt_1
208
+ prompt_2 = preference_prompt_2
209
+
210
+
211
+ dataset_full_name = amazon18_dataset2fullname[args.dataset]
212
+ dataset_full_name = dataset_full_name.replace("_", " ").lower()
213
+ print(dataset_full_name)
214
+
215
+ prompt_list_1 = []
216
+ prompt_list_2 = []
217
+
218
+ users = []
219
+
220
+ for (user,item_list) in inters.items():
221
+ users.append(user)
222
+ history = item_list[:-3]
223
+ item_titles = []
224
+ for j, item in enumerate(history):
225
+ item_titles.append(str(j+1) + '.' + item2feature[str(item)]["title"])
226
+ if len(item_titles) > args.max_his_len:
227
+ item_titles = item_titles[-args.max_his_len:]
228
+ item_titles = ", ".join(item_titles)
229
+
230
+ input_prompt_1 = prompt_1.format(dataset_full_name=dataset_full_name, item_titles=item_titles)
231
+ input_prompt_2 = prompt_2.format(dataset_full_name=dataset_full_name, item_titles=item_titles)
232
+
233
+ prompt_list_1.append(input_prompt_1)
234
+ prompt_list_2.append(input_prompt_2)
235
+
236
+
237
+ st = 0
238
+ with open(preference_output_file, mode='a') as f:
239
+
240
+ while st < len(prompt_list_1):
241
+ # while st < 3:
242
+ print(st)
243
+ # if st < 22895:
244
+ # st += args.batchsize
245
+ # continue
246
+
247
+ res_1 = get_res_batch(args.model_name, prompt_list_1[st:st + args.batchsize], args.max_tokens, api_info)
248
+ res_2 = get_res_batch(args.model_name, prompt_list_2[st:st + args.batchsize], args.max_tokens, api_info)
249
+ for i, answers in enumerate(zip(res_1, res_2)):
250
+
251
+ user = users[st + i]
252
+
253
+ answer_1, answer_2 = answers
254
+ # print(answers)
255
+ # print("=============")
256
+
257
+ if answer_1 == '':
258
+ print("answer null error")
259
+ answer_1 = "I enjoy high-quality item."
260
+
261
+ if answer_2 == '':
262
+ print("answer null error")
263
+ answer_2 = "I enjoy high-quality item."
264
+
265
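+ # answer_2 is expected to contain "Long-term preferences: [...]" and "Short-term preferences: [...]" on two lines (cf. preference_prompt_2); handle replies where the newline separator is missing.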
+ if answer_2.strip().count('\n') != 1:
266
+ if 'references:' in answer_2:
267
+ answer_2 = answer_2.strip().split("Short-term preferences:")
268
+ else:
269
+ answer_2 = answer_2.strip().split("Short-term preference:")
270
+ else:
271
+ answer_2 = answer_2.strip().split('\n')
272
+
273
+ if '' in answer_2:
274
+ answer_2.remove('')
275
+
276
+ if len(answer_2) == 1:
277
+ print(answer_2)
278
+ long_preference = short_preference = answer_2[0]
279
+ elif len(answer_2) >= 3:
280
+ print(answer_2)
281
+ answer_2 = answer_2[-1]
282
+ long_preference = short_preference = answer_2
283
+ else:
284
+ long_preference, short_preference = answer_2
285
+
286
+ if ':' in long_preference:
287
+ idx = long_preference.index(':')
288
+ long_preference = long_preference[idx+1:]
289
+ long_preference = long_preference.strip().replace('}','')
290
+ long_preference = long_preference.replace('\n','')
291
+
292
+ if ':' in short_preference:
293
+ idx = short_preference.index(':')
294
+ short_preference = short_preference[idx+1:]
295
+ short_preference = short_preference.strip().replace('}','')
296
+ short_preference = short_preference.replace('\n','')
297
+
298
+ dict = {"user":user,"user_preference":[answer_1, long_preference, short_preference]}
299
+ # print(dict)
300
+ json.dump(dict, f)
301
+ f.write("\n")
302
+
303
+ st += args.batchsize
304
+
305
+ return preference_output_file
306
+
307
+ def parse_args():
308
+ parser = argparse.ArgumentParser()
309
+ parser.add_argument('--dataset', type=str, default='Instruments', help='Instruments / Arts / Games')
310
+ parser.add_argument('--root', type=str, default='')
311
+ parser.add_argument('--api_info', type=str, default='./api_info.json')
312
+ parser.add_argument('--model_name', type=str, default='text-davinci-003')
313
+ parser.add_argument('--max_tokens', type=int, default=512)
314
+ parser.add_argument('--batchsize', type=int, default=16)
315
+ parser.add_argument('--max_his_len', type=int, default=20)
316
+ return parser.parse_args()
317
+
318
+ if __name__ == "__main__":
319
+ args = parse_args()
320
+
321
+ args.root = os.path.join(args.root, args.dataset)
322
+
323
+ api_info = load_json(args.api_info)
324
+ openai.api_key = api_info["api_key_list"].pop()
325
+
326
+
327
+ inter_path = os.path.join(args.root, f'{args.dataset}.inter.json')
328
+ inters = load_json(inter_path)
329
+
330
+
331
+ item2feature_path = os.path.join(args.root, f'{args.dataset}.item.json')
332
+ item2feature = load_json(item2feature_path)
333
+
334
+ reviews_path = os.path.join(args.root, f'{args.dataset}.review.json')
335
+ reviews = load_json(reviews_path)
336
+
337
+ intention_train_output_file = get_intention_train(args, inters, item2feature, reviews, api_info)
338
+ intention_test_output_file = get_intention_test(args, inters, item2feature, reviews ,api_info)
339
+ preference_output_file = get_user_preference(args, inters, item2feature, reviews, api_info)
340
+
341
+ intention_train = {}
342
+ intention_test = {}
343
+ user_preference = {}
344
+
345
+ with open(intention_train_output_file, "r") as f:
346
+ for line in f:
347
+ # print(line)
348
+ content = json.loads(line)
349
+ if content["user"] not in intention_train:
350
+ intention_train[content["user"]] = {"item":content["item"],
351
+ "inters":content["inters"],
352
+ "querys":[ content["user_related_intention"], content["item_related_intention"] ]}
353
+
354
+
355
+ with open(intention_test_output_file, "r") as f:
356
+ for line in f:
357
+ content = json.loads(line)
358
+ if content["user"] not in intention_train:
359
+ intention_test[content["user"]] = {"item":content["item"],
360
+ "inters":content["inters"],
361
+ "querys":[ content["user_related_intention"], content["item_related_intention"] ]}
362
+
363
+
364
+ with open(preference_output_file, "r") as f:
365
+ for line in f:
366
+ content = json.loads(line)
367
+ user_preference[content["user"]] = content["user_preference"]
368
+
369
+ user_dict = {
370
+ "user_explicit_preference": user_preference,
371
+ "user_vague_intention": {"train": intention_train, "test": intention_test},
372
+ }
373
+
374
+ write_json_file(user_dict, os.path.join(args.root, f'{args.dataset}.user.json'))
data_process/utils.py ADDED
@@ -0,0 +1,238 @@
1
+ import html
2
+ import json
3
+ import os
4
+ import pickle
5
+ import re
6
+ import time
7
+
8
+ import torch
9
+ # import gensim
10
+ from transformers import AutoModel, AutoTokenizer
11
+ import collections
12
+ import openai
13
+
14
+
15
+
16
+ def get_res_batch(model_name, prompt_list, max_tokens, api_info):
17
+
18
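+ # Retry the request in a loop: rotate to the next API key on authentication/quota errors and sleep briefly on transient errors.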
+ while True:
19
+ try:
20
+ res = openai.Completion.create(
21
+ model=model_name,
22
+ prompt=prompt_list,
23
+ temperature=0.4,
24
+ max_tokens=max_tokens,
25
+ top_p=1,
26
+ frequency_penalty=0,
27
+ presence_penalty=0
28
+ )
29
+ output_list = []
30
+ for choice in res['choices']:
31
+ output = choice['text'].strip()
32
+ output_list.append(output)
33
+
34
+ return output_list
35
+
36
+ except openai.error.AuthenticationError as e:
37
+ print(e)
38
+ openai.api_key = api_info["api_key_list"].pop()
39
+ time.sleep(10)
40
+ except openai.error.RateLimitError as e:
41
+ print(e)
42
+ if str(e) == "You exceeded your current quota, please check your plan and billing details.":
43
+ openai.api_key = api_info["api_key_list"].pop()
44
+ time.sleep(10)
45
+ else:
46
+ print('\nopenai.error.RateLimitError\nRetrying...')
47
+ time.sleep(10)
48
+ except openai.error.ServiceUnavailableError as e:
49
+ print(e)
50
+ print('\nopenai.error.ServiceUnavailableError\nRetrying...')
51
+ time.sleep(10)
52
+ except openai.error.Timeout:
53
+ print('\nopenai.error.Timeout\nRetrying...')
54
+ time.sleep(10)
55
+ except openai.error.APIError as e:
56
+ print(e)
57
+ print('\nopenai.error.APIError\nRetrying...')
58
+ time.sleep(10)
59
+ except openai.error.APIConnectionError as e:
60
+ print(e)
61
+ print('\nopenai.error.APIConnectionError\nRetrying...')
62
+ time.sleep(10)
63
+ except Exception as e:
64
+ print(e)
65
+ return None
66
+
67
+
68
+
69
+
70
+ def check_path(path):
71
+ if not os.path.exists(path):
72
+ os.makedirs(path)
73
+
74
+
75
+ def set_device(gpu_id):
76
+ if gpu_id == -1:
77
+ return torch.device('cpu')
78
+ else:
79
+ return torch.device(
80
+ 'cuda:' + str(gpu_id) if torch.cuda.is_available() else 'cpu')
81
+
82
+ def load_plm(model_path='bert-base-uncased'):
83
+
84
+ tokenizer = AutoTokenizer.from_pretrained(model_path,)
85
+
86
+ print("Load Model:", model_path)
87
+
88
+ model = AutoModel.from_pretrained(model_path,low_cpu_mem_usage=True,)
89
+ return tokenizer, model
90
+
91
+ def load_json(file):
92
+ with open(file, 'r') as f:
93
+ data = json.load(f)
94
+ return data
95
+
96
+ def clean_text(raw_text):
97
+ if isinstance(raw_text, list):
98
+ new_raw_text=[]
99
+ for raw in raw_text:
100
+ raw = html.unescape(raw)
101
+ raw = re.sub(r'</?\w+[^>]*>', '', raw)
102
+ raw = re.sub(r'["\n\r]*', '', raw)
103
+ new_raw_text.append(raw.strip())
104
+ cleaned_text = ' '.join(new_raw_text)
105
+ else:
106
+ if isinstance(raw_text, dict):
107
+ cleaned_text = str(raw_text)[1:-1].strip()
108
+ else:
109
+ cleaned_text = raw_text.strip()
110
+ cleaned_text = html.unescape(cleaned_text)
111
+ cleaned_text = re.sub(r'</?\w+[^>]*>', '', cleaned_text)
112
+ cleaned_text = re.sub(r'["\n\r]*', '', cleaned_text)
113
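+ # Trim any run of trailing periods so that the cleaned text ends with exactly one '.'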
+ index = -1
114
+ while -index < len(cleaned_text) and cleaned_text[index] == '.':
115
+ index -= 1
116
+ index += 1
117
+ if index == 0:
118
+ cleaned_text = cleaned_text + '.'
119
+ else:
120
+ cleaned_text = cleaned_text[:index] + '.'
121
+ if len(cleaned_text) >= 2000:
122
+ cleaned_text = ''
123
+ return cleaned_text
124
+
125
+ def load_pickle(filename):
126
+ with open(filename, "rb") as f:
127
+ return pickle.load(f)
128
+
129
+
130
+ def make_inters_in_order(inters):
131
+ user2inters, new_inters = collections.defaultdict(list), list()
132
+ for inter in inters:
133
+ user, item, rating, timestamp = inter
134
+ user2inters[user].append((user, item, rating, timestamp))
135
+ for user in user2inters:
136
+ user_inters = user2inters[user]
137
+ user_inters.sort(key=lambda d: d[3])
138
+ for inter in user_inters:
139
+ new_inters.append(inter)
140
+ return new_inters
141
+
142
+ def write_json_file(dic, file):
143
+ print('Writing json file: ',file)
144
+ with open(file, 'w') as fp:
145
+ json.dump(dic, fp, indent=4)
146
+
147
+ def write_remap_index(unit2index, file):
148
+ print('Writing remap file: ',file)
149
+ with open(file, 'w') as fp:
150
+ for unit in unit2index:
151
+ fp.write(unit + '\t' + str(unit2index[unit]) + '\n')
152
+
153
+
154
+ intention_prompt = "After purchasing a {dataset_full_name} item named \"{item_title}\", the user left a comment expressing his opinion and personal preferences. The user's comment is as follows: \n\"{review}\" " \
155
+ "\nAs we all know, user comments often contain information about both their personal preferences and the characteristics of the item they interacted with. From this comment, you can infer both the user's personal preferences and the characteristics of the item. " \
156
+ "Please describe your inferred user preferences and item characteristics in the first person and in the following format:\n\nMy preferences: []\nThe item's characteristics: []\n\n" \
157
+ "Note that your inference of the personalized preferences should not include any information about the title of the item."
158
+
159
+
160
+ preference_prompt_1 = "Suppose the user has bought a variety of {dataset_full_name} items, they are: \n{item_titles}. \nAs we all know, these historically purchased items serve as a reflection of the user's personalized preferences. " \
161
+ "Please analyze the user's personalized preferences based on the items he has bought and provide a brief third-person summary of the user's preferences, highlighting the key factors that influence his choice of items. Avoid listing specific items and do not list multiple examples. " \
162
+ "Your analysis should be brief and in the third person."
163
+
164
+ preference_prompt_2 = "Given a chronological list of {dataset_full_name} items that a user has purchased, we can analyze his long-term and short-term preferences. Long-term preferences are inherent characteristics of the user, which are reflected in all the items he has interacted with over time. Short-term preferences are the user's recent preferences, which are reflected in some of the items he has bought more recently. " \
165
+ "To determine the user's long-term preferences, please analyze the contents of all the items he has bought. Look for common features that appear frequently across the user's shopping records. To determine the user's short-term preferences, focus on the items he has bought most recently. Identify any new or different features that have emerged in the user's shopping records. " \
166
+ "Here is a chronological list of items that the user has bought: \n{item_titles}. \nPlease provide separate analyses for the user's long-term and short-term preferences. Your answer should be concise and general, without listing specific items. Your answer should be in the third person and in the following format:\n\nLong-term preferences: []\nShort-term preferences: []\n\n"
167
+
168
+
169
+ # remove 'Magazine', 'Gift', 'Music', 'Kindle'
170
+ amazon18_dataset_list = [
171
+ 'Appliances', 'Beauty',
172
+ 'Fashion', 'Software', 'Luxury', 'Scientific', 'Pantry',
173
+ 'Instruments', 'Arts', 'Games', 'Office', 'Garden',
174
+ 'Food', 'Cell', 'CDs', 'Automotive', 'Toys',
175
+ 'Pet', 'Tools', 'Kindle', 'Sports', 'Movies',
176
+ 'Electronics', 'Home', 'Clothing', 'Books'
177
+ ]
178
+
179
+ amazon18_dataset2fullname = {
180
+ 'Beauty': 'All_Beauty',
181
+ 'Fashion': 'AMAZON_FASHION',
182
+ 'Appliances': 'Appliances',
183
+ 'Arts': 'Arts_Crafts_and_Sewing',
184
+ 'Automotive': 'Automotive',
185
+ 'Books': 'Books',
186
+ 'CDs': 'CDs_and_Vinyl',
187
+ 'Cell': 'Cell_Phones_and_Accessories',
188
+ 'Clothing': 'Clothing_Shoes_and_Jewelry',
189
+ 'Music': 'Digital_Music',
190
+ 'Electronics': 'Electronics',
191
+ 'Gift': 'Gift_Cards',
192
+ 'Food': 'Grocery_and_Gourmet_Food',
193
+ 'Home': 'Home_and_Kitchen',
194
+ 'Scientific': 'Industrial_and_Scientific',
195
+ 'Kindle': 'Kindle_Store',
196
+ 'Luxury': 'Luxury_Beauty',
197
+ 'Magazine': 'Magazine_Subscriptions',
198
+ 'Movies': 'Movies_and_TV',
199
+ 'Instruments': 'Musical_Instruments',
200
+ 'Office': 'Office_Products',
201
+ 'Garden': 'Patio_Lawn_and_Garden',
202
+ 'Pet': 'Pet_Supplies',
203
+ 'Pantry': 'Prime_Pantry',
204
+ 'Software': 'Software',
205
+ 'Sports': 'Sports_and_Outdoors',
206
+ 'Tools': 'Tools_and_Home_Improvement',
207
+ 'Toys': 'Toys_and_Games',
208
+ 'Games': 'Video_Games'
209
+ }
210
+
211
+ amazon14_dataset_list = [
212
+ 'Beauty','Toys','Sports'
213
+ ]
214
+
215
+ amazon14_dataset2fullname = {
216
+ 'Beauty': 'Beauty',
217
+ 'Sports': 'Sports_and_Outdoors',
218
+ 'Toys': 'Toys_and_Games',
219
+ }
220
+
221
+ # c1. c2. c3. c4.
222
+ amazon_text_feature1 = ['title', 'category', 'brand']
223
+
224
+ # re-order
225
+ amazon_text_feature1_ro1 = ['brand', 'main_cat', 'category', 'title']
226
+
227
+ # remove
228
+ amazon_text_feature1_re1 = ['title']
229
+
230
+ amazon_text_feature2 = ['title']
231
+
232
+ amazon_text_feature3 = ['description']
233
+
234
+ amazon_text_feature4 = ['description', 'main_cat', 'category', 'brand']
235
+
236
+ amazon_text_feature5 = ['title', 'description']
237
+
238
+
evaluate.py ADDED
@@ -0,0 +1,66 @@
1
+ import math
2
+
3
+ def get_topk_results(predictions, scores, targets, k, all_items=None):
4
+ results = []
5
+ B = len(targets)
6
+ predictions = [_.split("Response:")[-1] for _ in predictions]
7
+ predictions = [_.strip().replace(" ","") for _ in predictions]
8
+
9
+ if all_items is not None:
10
+ for i, seq in enumerate(predictions):
11
+ if seq not in all_items:
12
+ scores[i] = -1000
13
+
14
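+ # predictions/scores hold k generated candidates per target (e.g. k beam outputs); rank each group of k by score and record whether the target item appears.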
+ for b in range(B):
15
+ batch_seqs = predictions[b * k: (b + 1) * k]
16
+ batch_scores = scores[b * k: (b + 1) * k]
17
+
18
+ pairs = list(zip(batch_seqs, batch_scores))
19
+ sorted_pairs = sorted(pairs, key=lambda x: x[1], reverse=True)
20
+ target_item = targets[b]
21
+ one_results = []
22
+ for sorted_pred in sorted_pairs:
23
+ if sorted_pred[0] == target_item:
24
+ one_results.append(1)
25
+ else:
26
+ one_results.append(0)
27
+
28
+ results.append(one_results)
29
+
30
+ return results
31
+
32
+ def get_metrics_results(topk_results, metrics):
33
+ res = {}
34
+ for m in metrics:
35
+ if m.lower().startswith("hit"):
36
+ k = int(m.split("@")[1])
37
+ res[m] = hit_k(topk_results, k)
38
+ elif m.lower().startswith("ndcg"):
39
+ k = int(m.split("@")[1])
40
+ res[m] = ndcg_k(topk_results, k)
41
+ else:
42
+ raise NotImplementedError
43
+
44
+ return res
45
+
46
+
47
+ def ndcg_k(topk_results, k):
48
+
49
+ ndcg = 0.0
50
+ for row in topk_results:
51
+ res = row[:k]
52
+ one_ndcg = 0.0
53
+ for i in range(len(res)):
54
+ one_ndcg += res[i] / math.log(i + 2, 2)
55
+ ndcg += one_ndcg
56
+ return ndcg
57
+
58
+
59
+ def hit_k(topk_results, k):
60
+ hit = 0.0
61
+ for row in topk_results:
62
+ res = row[:k]
63
+ if sum(res) > 0:
64
+ hit += 1
65
+ return hit
66
+
finetune.py ADDED
@@ -0,0 +1,121 @@
1
+ import argparse
2
+ import os
3
+
4
+ import sys
5
+ from typing import List
6
+
7
+ import torch
8
+ import transformers
9
+
10
+ from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
11
+
12
+ from utils import *
13
+ from collator import Collator
14
+
15
+ def train(args):
16
+
17
+ set_seed(args.seed)
18
+ ensure_dir(args.output_dir)
19
+
20
+ device_map = "auto"
21
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
22
+ ddp = world_size != 1
23
+ local_rank = int(os.environ.get("LOCAL_RANK") or 0)
24
+ if local_rank == 0:
25
+ print(vars(args))
26
+
27
+ if ddp:
28
+ device_map = {"": local_rank}
29
+
30
+ config = LlamaConfig.from_pretrained(args.base_model)
31
+ tokenizer = LlamaTokenizer.from_pretrained(
32
+ args.base_model,
33
+ model_max_length = args.model_max_length,
34
+ padding_side="right",
35
+ )
36
+ tokenizer.pad_token_id = 0
37
+ gradient_checkpointing = True
38
+
39
+ train_data, valid_data = load_datasets(args)
40
+ add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
41
+ config.vocab_size = len(tokenizer)
42
+ if local_rank == 0:
43
+ print("added {} new tokens.".format(add_num))
44
+ print("data num:", len(train_data))
45
+ tokenizer.save_pretrained(args.output_dir)
46
+ config.save_pretrained(args.output_dir)
47
+
48
+ collator = Collator(args, tokenizer)
49
+
50
+
51
+ model = LlamaForCausalLM.from_pretrained(
52
+ args.base_model,
53
+ # torch_dtype=torch.float16,
54
+ device_map=device_map,
55
+ )
56
+ model.resize_token_embeddings(len(tokenizer))
57
+
58
+
59
+ if not ddp and torch.cuda.device_count() > 1:
60
+ model.is_parallelizable = True
61
+ model.model_parallel = True
62
+
63
+
64
+ trainer = transformers.Trainer(
65
+ model=model,
66
+ train_dataset=train_data,
67
+ eval_dataset=valid_data,
68
+ args=transformers.TrainingArguments(
69
+ seed=args.seed,
70
+ per_device_train_batch_size=args.per_device_batch_size,
71
+ per_device_eval_batch_size=args.per_device_batch_size,
72
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
73
+ warmup_ratio=args.warmup_ratio,
74
+ num_train_epochs=args.epochs,
75
+ learning_rate=args.learning_rate,
76
+ weight_decay=args.weight_decay,
77
+ lr_scheduler_type=args.lr_scheduler_type,
78
+ fp16=args.fp16,
79
+ bf16=args.bf16,
80
+ logging_steps=args.logging_step,
81
+ optim=args.optim,
82
+ gradient_checkpointing=gradient_checkpointing,
83
+ evaluation_strategy=args.save_and_eval_strategy,
84
+ save_strategy=args.save_and_eval_strategy,
85
+ eval_steps=args.save_and_eval_steps,
86
+ save_steps=args.save_and_eval_steps,
87
+ output_dir=args.output_dir,
88
+ save_total_limit=5,
89
+ load_best_model_at_end=True,
90
+ deepspeed=args.deepspeed,
91
+ ddp_find_unused_parameters=False if ddp else None,
92
+ report_to=None,
93
+ eval_delay= 1 if args.save_and_eval_strategy=="epoch" else 2000,
94
+ dataloader_num_workers = args.dataloader_num_workers,
95
+ dataloader_prefetch_factor = args.dataloader_prefetch_factor
96
+ ),
97
+ tokenizer=tokenizer,
98
+ data_collator=collator,
99
+ )
100
+ model.config.use_cache = False
101
+
102
+
103
+ trainer.train(
104
+ resume_from_checkpoint=args.resume_from_checkpoint,
105
+ )
106
+
107
+ trainer.save_state()
108
+ trainer.save_model(output_dir=args.output_dir)
109
+
110
+
111
+
112
+
113
+ if __name__ == "__main__":
114
+ parser = argparse.ArgumentParser(description='LLMRec')
115
+ parser = parse_global_args(parser)
116
+ parser = parse_train_args(parser)
117
+ parser = parse_dataset_args(parser)
118
+
119
+ args = parser.parse_args()
120
+
121
+ train(args)
index/datasets.py ADDED
@@ -0,0 +1,21 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.utils.data as data
4
+
5
+
6
+ class EmbDataset(data.Dataset):
7
+
8
+ def __init__(self,data_path):
9
+
10
+ self.data_path = data_path
11
+ # self.embeddings = np.fromfile(data_path, dtype=np.float32).reshape(16859,-1)
12
+ self.embeddings = np.load(data_path)
13
+ self.dim = self.embeddings.shape[-1]
14
+
15
+ def __getitem__(self, index):
16
+ emb = self.embeddings[index]
17
+ tensor_emb=torch.FloatTensor(emb)
18
+ return tensor_emb
19
+
20
+ def __len__(self):
21
+ return len(self.embeddings)
index/generate_indices.py ADDED
@@ -0,0 +1,155 @@
1
+ import collections
2
+ import json
3
+ import logging
4
+ import argparse
5
+ import numpy as np
6
+ import torch
7
+ from time import time
8
+ from torch import optim
9
+ from tqdm import tqdm
10
+
11
+ from torch.utils.data import DataLoader
12
+
13
+ from datasets import EmbDataset
14
+ from models.rqvae import RQVAE
15
+
16
+ import os
17
+
18
+ def check_collision(all_indices_str):
19
+ tot_item = len(all_indices_str)
20
+ tot_indice = len(set(all_indices_str.tolist()))
21
+ return tot_item==tot_indice
22
+
23
+ def get_indices_count(all_indices_str):
24
+ indices_count = collections.defaultdict(int)
25
+ for index in all_indices_str:
26
+ indices_count[index] += 1
27
+ return indices_count
28
+
29
+ def get_collision_item(all_indices_str):
30
+ index2id = {}
31
+ for i, index in enumerate(all_indices_str):
32
+ if index not in index2id:
33
+ index2id[index] = []
34
+ index2id[index].append(i)
35
+
36
+ collision_item_groups = []
37
+
38
+ for index in index2id:
39
+ if len(index2id[index]) > 1:
40
+ collision_item_groups.append(index2id[index])
41
+
42
+ return collision_item_groups
43
+
44
+ def parse_args():
45
+ parser = argparse.ArgumentParser(description = "Index")
46
+ parser.add_argument("--ckpt_path", type = str, default = "", help = "")
47
+ parser.add_argument("--data_path", type = str, default = "", help = "")
48
+ parser.add_argument("--save_path", type = str, default = "", help = "")
49
+ parser.add_argument("--device", type = str, default = "cuda:0", help = "gpu or cpu")
50
+ return parser.parse_args()
51
+
52
+ infer_args = parse_args()
53
+ print(infer_args)
54
+
55
+ # dataset = "Games"
56
+ # ckpt_path = "/zhengbowen/rqvae_ckpt/xxxx"
57
+ # output_dir = f"/zhengbowen/data/{dataset}/"
58
+ # output_file = f"{dataset}.index.json"
59
+ # output_file = os.path.join(output_dir,output_file)
60
+ # device = torch.device("cuda:1")
61
+
62
+ device = torch.device(infer_args.device)
63
+
64
+ ckpt = torch.load(infer_args.ckpt_path, map_location = torch.device('cpu'))
65
+ args = ckpt["args"]
66
+ state_dict = ckpt["state_dict"]
67
+
68
+ data = EmbDataset(infer_args.data_path)
69
+
70
+ model = RQVAE(in_dim=data.dim,
71
+ num_emb_list=args.num_emb_list,
72
+ e_dim=args.e_dim,
73
+ layers=args.layers,
74
+ dropout_prob=args.dropout_prob,
75
+ bn=args.bn,
76
+ loss_type=args.loss_type,
77
+ quant_loss_weight=args.quant_loss_weight,
78
+ kmeans_init=args.kmeans_init,
79
+ kmeans_iters=args.kmeans_iters,
80
+ sk_epsilons=args.sk_epsilons,
81
+ sk_iters=args.sk_iters,
82
+ )
83
+
84
+ model.load_state_dict(state_dict)
85
+ model = model.to(device)
86
+ model.eval()
87
+ print(model)
88
+
89
+ data_loader = DataLoader(data,num_workers=args.num_workers,
90
+ batch_size=64, shuffle=False,
91
+ pin_memory=True)
92
+
93
+ all_indices = []
94
+ all_indices_str = []
95
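+ # Each item index is a sequence of level-specific special tokens, e.g. ["<a_12>", "<b_7>", ...], one token per codebook level.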
+ prefix = ["<a_{}>","<b_{}>","<c_{}>","<d_{}>","<e_{}>"]
96
+
97
+ for d in tqdm(data_loader):
98
+ d = d.to(device)
99
+ indices = model.get_indices(d,use_sk=False)
100
+ indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
101
+ for index in indices:
102
+ code = []
103
+ for i, ind in enumerate(index):
104
+ code.append(prefix[i].format(int(ind)))
105
+
106
+ all_indices.append(code)
107
+ all_indices_str.append(str(code))
108
+ # break
109
+
110
+ all_indices = np.array(all_indices)
111
+ all_indices_str = np.array(all_indices_str)
112
+
113
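+ # Resolve index collisions: keep the greedy codes for all but the last level, then repeatedly re-quantize only the conflicting items with the Sinkhorn-constrained assignment on the last level.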
+ for vq in model.rq.vq_layers[:-1]:
114
+ vq.sk_epsilon=0.0
115
+
116
+ if model.rq.vq_layers[-1].sk_epsilon == 0.0:
117
+ model.rq.vq_layers[-1].sk_epsilon = 0.003
118
+
119
+ tt = 0
120
+ # There are often duplicate items in the dataset; we no longer differentiate between them.
121
+ while True:
122
+ if tt >= 10 or check_collision(all_indices_str):
123
+ break
124
+
125
+ collision_item_groups = get_collision_item(all_indices_str)
126
+ # print(collision_item_groups)
127
+ print(len(collision_item_groups))
128
+ for collision_items in collision_item_groups:
129
+ d = data[collision_items].to(device)
130
+
131
+ indices = model.get_indices(d, use_sk=True)
132
+ indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
133
+ for item, index in zip(collision_items, indices):
134
+ code = []
135
+ for i, ind in enumerate(index):
136
+ code.append(prefix[i].format(int(ind)))
137
+
138
+ all_indices[item] = code
139
+ all_indices_str[item] = str(code)
140
+ tt += 1
141
+
142
+
143
+ print("All indices number: ",len(all_indices))
144
+ print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values()))
145
+
146
+ tot_item = len(all_indices_str)
147
+ tot_indice = len(set(all_indices_str.tolist()))
148
+ print("Collision Rate",(tot_item-tot_indice)/tot_item)
149
+
150
+ all_indices_dict = {}
151
+ for item, indices in enumerate(all_indices.tolist()):
152
+ all_indices_dict[item] = list(indices)
153
+
154
+ with open(infer_args.save_path, 'w') as fp:
155
+ json.dump(all_indices_dict, fp)
index/main.py ADDED
@@ -0,0 +1,87 @@
1
+ import argparse
2
+ import random
3
+ import torch
4
+ import numpy as np
5
+ from time import time
6
+ import logging
7
+
8
+ from torch.utils.data import DataLoader
9
+
10
+ from datasets import EmbDataset
11
+ from models.rqvae import RQVAE
12
+ from trainer import Trainer
13
+
14
+ def parse_args():
15
+ parser = argparse.ArgumentParser(description="Index")
16
+
17
+ parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
18
+ parser.add_argument('--epochs', type=int, default=5000, help='number of epochs')
19
+ parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
20
+ parser.add_argument('--num_workers', type=int, default=4, )
21
+ parser.add_argument('--eval_step', type=int, default=50, help='eval step')
22
+ parser.add_argument('--learner', type=str, default="AdamW", help='optimizer')
23
+ parser.add_argument("--data_path", type=str,
24
+ default="../data/Games/Games.emb-llama-td.npy",
25
+ help="Input data path.")
26
+
27
+ parser.add_argument('--weight_decay', type=float, default=1e-4, help='l2 regularization weight')
28
+ parser.add_argument("--dropout_prob", type=float, default=0.0, help="dropout ratio")
29
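+ # Note: argparse's type=bool treats any non-empty string (including "False") as True, so these boolean flags are best left at their defaults.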
+ parser.add_argument("--bn", type=bool, default=False, help="use bn or not")
30
+ parser.add_argument("--loss_type", type=str, default="mse", help="loss_type")
31
+ parser.add_argument("--kmeans_init", type=bool, default=True, help="use kmeans_init or not")
32
+ parser.add_argument("--kmeans_iters", type=int, default=100, help="max kmeans iters")
33
+ parser.add_argument('--sk_epsilons', type=float, nargs='+', default=[0.0, 0.0, 0.0], help="sinkhorn epsilons")
34
+ parser.add_argument("--sk_iters", type=int, default=50, help="max sinkhorn iters")
35
+
36
+ parser.add_argument("--device", type=str, default="cuda:1", help="gpu or cpu")
37
+
38
+ parser.add_argument('--num_emb_list', type=int, nargs='+', default=[256,256,256], help='emb num of every vq')
39
+ parser.add_argument('--e_dim', type=int, default=32, help='vq codebook embedding size')
40
+ parser.add_argument('--quant_loss_weight', type=float, default=1.0, help='vq quantion loss weight')
41
+ parser.add_argument('--layers', type=int, nargs='+', default=[2048,1024,512,256,128,64], help='hidden sizes of every layer')
42
+
43
+ parser.add_argument("--ckpt_dir", type=str, default="", help="output directory for model")
44
+
45
+ return parser.parse_args()
46
+
47
+
48
+ if __name__ == '__main__':
49
+ """fix the random seed"""
50
+ seed = 2023
51
+ random.seed(seed)
52
+ np.random.seed(seed)
53
+ torch.manual_seed(seed)
54
+ torch.cuda.manual_seed_all(seed)
55
+ torch.backends.cudnn.deterministic = True
56
+ torch.backends.cudnn.benchmark = False
57
+
58
+ args = parse_args()
59
+ print(args)
60
+
61
+ logging.basicConfig(level=logging.DEBUG)
62
+
63
+ """build dataset"""
64
+ data = EmbDataset(args.data_path)
65
+ model = RQVAE(in_dim=data.dim,
66
+ num_emb_list=args.num_emb_list,
67
+ e_dim=args.e_dim,
68
+ layers=args.layers,
69
+ dropout_prob=args.dropout_prob,
70
+ bn=args.bn,
71
+ loss_type=args.loss_type,
72
+ quant_loss_weight=args.quant_loss_weight,
73
+ kmeans_init=args.kmeans_init,
74
+ kmeans_iters=args.kmeans_iters,
75
+ sk_epsilons=args.sk_epsilons,
76
+ sk_iters=args.sk_iters,
77
+ )
78
+ print(model)
79
+ data_loader = DataLoader(data,num_workers=args.num_workers,
80
+ batch_size=args.batch_size, shuffle=True,
81
+ pin_memory=True)
82
+ trainer = Trainer(args,model)
83
+ best_loss, best_collision_rate = trainer.fit(data_loader)
84
+
85
+ print("Best Loss",best_loss)
86
+ print("Best Collision Rate", best_collision_rate)
87
+
index/models/layers.py ADDED
@@ -0,0 +1,106 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.init import xavier_normal_
4
+ from sklearn.cluster import KMeans
5
+
6
+
7
+ class MLPLayers(nn.Module):
8
+
9
+ def __init__(
10
+ self, layers, dropout=0.0, activation="relu", bn=False
11
+ ):
12
+ super(MLPLayers, self).__init__()
13
+ self.layers = layers
14
+ self.dropout = dropout
15
+ self.activation = activation
16
+ self.use_bn = bn
17
+
18
+ mlp_modules = []
19
+ for idx, (input_size, output_size) in enumerate(
20
+ zip(self.layers[:-1], self.layers[1:])
21
+ ):
22
+ mlp_modules.append(nn.Dropout(p=self.dropout))
23
+ mlp_modules.append(nn.Linear(input_size, output_size))
24
+ if self.use_bn:
25
+ mlp_modules.append(nn.BatchNorm1d(num_features=output_size))
26
+ activation_func = activation_layer(self.activation, output_size)
27
+ if activation_func is not None and idx != (len(self.layers)-2):
28
+ mlp_modules.append(activation_func)
29
+
30
+ self.mlp_layers = nn.Sequential(*mlp_modules)
31
+ self.apply(self.init_weights)
32
+
33
+ def init_weights(self, module):
34
+ # Initialize linear layers with Xavier normal initialization, as described in the paper.
35
+ if isinstance(module, nn.Linear):
36
+ xavier_normal_(module.weight.data)
37
+ if module.bias is not None:
38
+ module.bias.data.fill_(0.0)
39
+
40
+ def forward(self, input_feature):
41
+ return self.mlp_layers(input_feature)
42
+
43
+ def activation_layer(activation_name="relu", emb_dim=None):
44
+
45
+ if activation_name is None:
46
+ activation = None
47
+ elif isinstance(activation_name, str):
48
+ if activation_name.lower() == "sigmoid":
49
+ activation = nn.Sigmoid()
50
+ elif activation_name.lower() == "tanh":
51
+ activation = nn.Tanh()
52
+ elif activation_name.lower() == "relu":
53
+ activation = nn.ReLU()
54
+ elif activation_name.lower() == "leakyrelu":
55
+ activation = nn.LeakyReLU()
56
+ elif activation_name.lower() == "none":
57
+ activation = None
58
+ elif issubclass(activation_name, nn.Module):
59
+ activation = activation_name()
60
+ else:
61
+ raise NotImplementedError(
62
+ "activation function {} is not implemented".format(activation_name)
63
+ )
64
+
65
+ return activation
66
+
67
+ def kmeans(
68
+ samples,
69
+ num_clusters,
70
+ num_iters = 10,
71
+ ):
72
+ B, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device
73
+ x = samples.cpu().detach().numpy()
74
+
75
+ cluster = KMeans(n_clusters = num_clusters, max_iter = num_iters).fit(x)
76
+
77
+ centers = cluster.cluster_centers_
78
+ tensor_centers = torch.from_numpy(centers).to(device)
79
+
80
+ return tensor_centers
81
+
82
+
83
+ @torch.no_grad()
84
+ def sinkhorn_algorithm(distances, epsilon, sinkhorn_iterations):
85
+ Q = torch.exp(- distances / epsilon)
86
+
87
+ B = Q.shape[0] # number of samples to assign
88
+ K = Q.shape[1] # number of codewords in the codebook (usually set to 256)
89
+
90
+ # make the matrix sum to 1
91
+ sum_Q = Q.sum(-1, keepdim=True).sum(-2, keepdim=True)
92
+ Q /= sum_Q
93
+ # print(Q.sum())
94
+ for it in range(sinkhorn_iterations):
95
+
96
+ # normalize each row: total weight per sample must be 1/B
97
+ Q /= torch.sum(Q, dim=1, keepdim=True)
98
+ Q /= B
99
+
100
+ # normalize each column: total weight per prototype must be 1/K
101
+ Q /= torch.sum(Q, dim=0, keepdim=True)
102
+ Q /= K
103
+
104
+
105
+ Q *= B # each row (one per sample) must sum to 1 so that Q is an assignment
106
+ return Q
index/models/rq.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vq import VectorQuantizer
5
+
6
+
7
+ class ResidualVectorQuantizer(nn.Module):
8
+ """ References:
9
+ SoundStream: An End-to-End Neural Audio Codec
10
+ https://arxiv.org/pdf/2107.03312.pdf
11
+ """
12
+
13
+ def __init__(self, n_e_list, e_dim, sk_epsilons,
14
+ kmeans_init = False, kmeans_iters = 100, sk_iters=100,):
15
+ super().__init__()
16
+ self.n_e_list = n_e_list
17
+ self.e_dim = e_dim
18
+ self.num_quantizers = len(n_e_list)
19
+ self.kmeans_init = kmeans_init
20
+ self.kmeans_iters = kmeans_iters
21
+ self.sk_epsilons = sk_epsilons
22
+ self.sk_iters = sk_iters
23
+ self.vq_layers = nn.ModuleList([VectorQuantizer(n_e, e_dim,
24
+ kmeans_init = self.kmeans_init,
25
+ kmeans_iters = self.kmeans_iters,
26
+ sk_epsilon=sk_epsilon,
27
+ sk_iters=sk_iters)
28
+ for n_e, sk_epsilon in zip(n_e_list,sk_epsilons) ])
29
+
30
+ def get_codebook(self):
31
+ all_codebook = []
32
+ for quantizer in self.vq_layers:
33
+ codebook = quantizer.get_codebook()
34
+ all_codebook.append(codebook)
35
+ return torch.stack(all_codebook)
36
+
37
+ def forward(self, x, use_sk=True):
38
+ all_losses = []
39
+ all_indices = []
40
+
41
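+ # Quantize the residual left by the previous levels; the final quantized vector is the sum of the selected codewords from every level.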
+ x_q = 0
42
+ residual = x
43
+ for quantizer in self.vq_layers:
44
+ x_res, loss, indices = quantizer(residual, use_sk=use_sk)
45
+ residual = residual - x_res
46
+ x_q = x_q + x_res
47
+
48
+ all_losses.append(loss)
49
+ all_indices.append(indices)
50
+
51
+ mean_losses = torch.stack(all_losses).mean()
52
+ all_indices = torch.stack(all_indices, dim=-1)
53
+
54
+ return x_q, mean_losses, all_indices
index/models/rqvae.py ADDED
@@ -0,0 +1,82 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from .layers import MLPLayers
7
+ from .rq import ResidualVectorQuantizer
8
+
9
+
10
+ class RQVAE(nn.Module):
11
+ def __init__(self,
12
+ in_dim=768,
13
+ # num_emb_list=[256,256,256,256],
14
+ num_emb_list=None,
15
+ e_dim=64,
16
+ # layers=[512,256,128],
17
+ layers=None,
18
+ dropout_prob=0.0,
19
+ bn=False,
20
+ loss_type="mse",
21
+ quant_loss_weight=1.0,
22
+ kmeans_init=False,
23
+ kmeans_iters=100,
24
+ # sk_epsilons=[0,0,0.003,0.01]],
25
+ sk_epsilons=None,
26
+ sk_iters=100,
27
+ ):
28
+ super(RQVAE, self).__init__()
29
+
30
+ self.in_dim = in_dim
31
+ self.num_emb_list = num_emb_list
32
+ self.e_dim = e_dim
33
+
34
+ self.layers = layers
35
+ self.dropout_prob = dropout_prob
36
+ self.bn = bn
37
+ self.loss_type = loss_type
38
+ self.quant_loss_weight=quant_loss_weight
39
+ self.kmeans_init = kmeans_init
40
+ self.kmeans_iters = kmeans_iters
41
+ self.sk_epsilons = sk_epsilons
42
+ self.sk_iters = sk_iters
43
+
44
+ self.encode_layer_dims = [self.in_dim] + self.layers + [self.e_dim]
45
+ self.encoder = MLPLayers(layers=self.encode_layer_dims,
46
+ dropout=self.dropout_prob,bn=self.bn)
47
+
48
+ self.rq = ResidualVectorQuantizer(num_emb_list, e_dim,
49
+ kmeans_init = self.kmeans_init,
50
+ kmeans_iters = self.kmeans_iters,
51
+ sk_epsilons=self.sk_epsilons,
52
+ sk_iters=self.sk_iters,)
53
+
54
+ self.decode_layer_dims = self.encode_layer_dims[::-1]
55
+ self.decoder = MLPLayers(layers=self.decode_layer_dims,
56
+ dropout=self.dropout_prob,bn=self.bn)
57
+
58
+ def forward(self, x, use_sk=True):
59
+ x = self.encoder(x)
60
+ x_q, rq_loss, indices = self.rq(x,use_sk=use_sk)
61
+ out = self.decoder(x_q)
62
+
63
+ return out, rq_loss, indices
64
+
65
+ @torch.no_grad()
66
+ def get_indices(self, xs, use_sk=False):
67
+ x_e = self.encoder(xs)
68
+ _, _, indices = self.rq(x_e, use_sk=use_sk)
69
+ return indices
70
+
71
+ def compute_loss(self, out, quant_loss, xs=None):
72
+
73
+ if self.loss_type == 'mse':
74
+ loss_recon = F.mse_loss(out, xs, reduction='mean')
75
+ elif self.loss_type == 'l1':
76
+ loss_recon = F.l1_loss(out, xs, reduction='mean')
77
+ else:
78
+ raise ValueError('incompatible loss type')
79
+
80
+ loss_total = loss_recon + self.quant_loss_weight * quant_loss
81
+
82
+ return loss_total, loss_recon
index/models/vq.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .layers import kmeans, sinkhorn_algorithm
5
+
6
+
7
+ class VectorQuantizer(nn.Module):
8
+
9
+ def __init__(self, n_e, e_dim,
10
+ beta = 0.25, kmeans_init = False, kmeans_iters = 10,
11
+ sk_epsilon=0.01, sk_iters=100):
12
+ super().__init__()
13
+ self.n_e = n_e
14
+ self.e_dim = e_dim
15
+ self.beta = beta
16
+ self.kmeans_init = kmeans_init
17
+ self.kmeans_iters = kmeans_iters
18
+ self.sk_epsilon = sk_epsilon
19
+ self.sk_iters = sk_iters
20
+
21
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
22
+ if not kmeans_init:
23
+ self.initted = True
24
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
25
+ else:
26
+ self.initted = False
27
+ self.embedding.weight.data.zero_()
28
+
29
+ def get_codebook(self):
30
+ return self.embedding.weight
31
+
32
+ def get_codebook_entry(self, indices, shape=None):
33
+ # get quantized latent vectors
34
+ z_q = self.embedding(indices)
35
+ if shape is not None:
36
+ z_q = z_q.view(shape)
37
+
38
+ return z_q
39
+
40
+ def init_emb(self, data):
41
+
42
+ centers = kmeans(
43
+ data,
44
+ self.n_e,
45
+ self.kmeans_iters,
46
+ )
47
+
48
+ self.embedding.weight.data.copy_(centers)
49
+ self.initted = True
50
+
51
+ @staticmethod
52
+ def center_distance_for_constraint(distances):
53
+ # distances: B, K
54
+ max_distance = distances.max()
55
+ min_distance = distances.min()
56
+
57
+ middle = (max_distance + min_distance) / 2
58
+ amplitude = max_distance - middle + 1e-5
59
+ assert amplitude > 0
60
+ centered_distances = (distances - middle) / amplitude
61
+ return centered_distances
62
+
63
+ def forward(self, x, use_sk=True):
64
+ # Flatten input
65
+ latent = x.view(-1, self.e_dim)
66
+
67
+ if not self.initted and self.training:
68
+ self.init_emb(latent)
69
+
70
+ # Squared Euclidean distance between each latent vector and every codebook embedding (||z||^2 + ||e||^2 - 2 z.e)
71
+ d = torch.sum(latent**2, dim=1, keepdim=True) + \
72
+ torch.sum(self.embedding.weight**2, dim=1, keepdim=True).t()- \
73
+ 2 * torch.matmul(latent, self.embedding.weight.t())
74
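+ # Either pick the nearest codeword greedily, or (when Sinkhorn is enabled) solve a constrained assignment that spreads items more uniformly over the codebook.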
+ if not use_sk or self.sk_epsilon <= 0:
75
+ indices = torch.argmin(d, dim=-1)
76
+ # print("=======",self.sk_epsilon)
77
+ else:
78
+ # print("++++++++",self.sk_epsilon)
79
+ d = self.center_distance_for_constraint(d)
80
+ d = d.double()
81
+ Q = sinkhorn_algorithm(d,self.sk_epsilon,self.sk_iters)
82
+ # print(Q.sum(0)[:10])
83
+ if torch.isnan(Q).any() or torch.isinf(Q).any():
84
+ print(f"Sinkhorn Algorithm returns nan/inf values.")
85
+ indices = torch.argmax(Q, dim=-1)
86
+
87
+ # indices = torch.argmin(d, dim=-1)
88
+
89
+ x_q = self.embedding(indices).view(x.shape)
90
+
91
+ # compute loss for embedding
92
+ commitment_loss = F.mse_loss(x_q.detach(), x)
93
+ codebook_loss = F.mse_loss(x_q, x.detach())
94
+ loss = codebook_loss + self.beta * commitment_loss
95
+
96
+ # preserve gradients
97
+ x_q = x + (x_q - x).detach()
98
+
99
+ indices = indices.view(x.shape[:-1])
100
+
101
+ return x_q, loss, indices
102
+
103
+
index/run.sh ADDED
@@ -0,0 +1,18 @@
1
+ DATA_PATH=/datain/v-yinju/rqvae-zzx/data/instruments-emb-llama.npy
2
+ CKPT_DIR=your_ckpt_save_directory # E.g., /datain/v-yinju/rqvae-zzx/model
3
+ mkdir -p $CKPT_DIR
4
+ python -u main.py \
5
+ --num_emb_list 256 256 256 256 \
6
+ --sk_epsilons 0.0 0.0 0.0 0.003 \
7
+ --lr 1e-3 \
8
+ --device cuda:0 \
9
+ --batch_size 1024 \
10
+ --data_path $DATA_PATH \
11
+ --ckpt_dir $CKPT_DIR
12
+
13
+ # Infer item index
14
+ # python generate_indices.py \
15
+ # --ckpt_path your_rqvae_model_path \ E.g., /datain/v-yinju/rqvae-zzx/model/20241127/best_collision_model.pth
16
+ # --data_path $DATA_PATH \
17
+ # --save_path your_index_save_path \ E.g., /datain/v-yinju/rqvae-zzx/model/20241127/indices.json
18
+ # --device cuda:0
index/trainer.py ADDED
@@ -0,0 +1,209 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ import torch
5
+ from time import time
6
+ from torch import optim
7
+ from tqdm import tqdm
8
+
9
+ from utils import ensure_dir,set_color,get_local_time
10
+ import os
11
+
12
+ class Trainer(object):
13
+
14
+ def __init__(self, args, model):
15
+ self.args = args
16
+ self.model = model
17
+ self.logger = logging.getLogger()
18
+
19
+ self.lr = args.lr
20
+ self.learner = args.learner
21
+ self.weight_decay = args.weight_decay
22
+ self.epochs = args.epochs
23
+ self.eval_step = min(args.eval_step, self.epochs)
24
+ self.device = args.device
25
+ self.device = torch.device(self.device)
26
+ self.ckpt_dir = args.ckpt_dir
27
+ saved_model_dir = "{}".format(get_local_time())
28
+ self.ckpt_dir = os.path.join(self.ckpt_dir,saved_model_dir)
29
+ ensure_dir(self.ckpt_dir)
30
+
31
+ self.best_loss = np.inf
32
+ self.best_collision_rate = np.inf
33
+ self.best_loss_ckpt = "best_loss_model.pth"
34
+ self.best_collision_ckpt = "best_collision_model.pth"
35
+ self.optimizer = self._build_optimizer()
36
+ self.model = self.model.to(self.device)
37
+
38
+ def _build_optimizer(self):
39
+
40
+ params = self.model.parameters()
41
+ learner = self.learner
42
+ learning_rate = self.lr
43
+ weight_decay = self.weight_decay
44
+
45
+ if learner.lower() == "adam":
46
+ optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay)
47
+ elif learner.lower() == "sgd":
48
+ optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay)
49
+ elif learner.lower() == "adagrad":
50
+ optimizer = optim.Adagrad(
51
+ params, lr=learning_rate, weight_decay=weight_decay
52
+ )
53
+ for state in optimizer.state.values():
54
+ for k, v in state.items():
55
+ if torch.is_tensor(v):
56
+ state[k] = v.to(self.device)
57
+ elif learner.lower() == "rmsprop":
58
+ optimizer = optim.RMSprop(
59
+ params, lr=learning_rate, weight_decay=weight_decay
60
+ )
61
+ elif learner.lower() == 'adamw':
62
+ optimizer = optim.AdamW(
63
+ params, lr=learning_rate, weight_decay=weight_decay
64
+ )
65
+ else:
66
+ self.logger.warning(
67
+ "Received unrecognized optimizer, set default Adam optimizer"
68
+ )
69
+ optimizer = optim.Adam(params, lr=learning_rate)
70
+ return optimizer
71
+ def _check_nan(self, loss):
72
+ if torch.isnan(loss):
73
+ raise ValueError("Training loss is nan")
74
+
75
+ def _train_epoch(self, train_data, epoch_idx):
76
+
77
+ self.model.train()
78
+
79
+ total_loss = 0
80
+ total_recon_loss = 0
81
+ iter_data = tqdm(
82
+ train_data,
83
+ total=len(train_data),
84
+ ncols=100,
85
+ desc=set_color(f"Train {epoch_idx}","pink"),
86
+ )
87
+
88
+ for batch_idx, data in enumerate(iter_data):
89
+ data = data.to(self.device)
90
+ self.optimizer.zero_grad()
91
+ out, rq_loss, indices = self.model(data)
92
+ loss, loss_recon = self.model.compute_loss(out, rq_loss, xs=data)
93
+ self._check_nan(loss)
94
+ loss.backward()
95
+ self.optimizer.step()
96
+ total_loss += loss.item()
97
+ total_recon_loss += loss_recon.item()
98
+
99
+ return total_loss, total_recon_loss
100
+
101
+ @torch.no_grad()
102
+ def _valid_epoch(self, valid_data):
103
+
104
+ self.model.eval()
105
+
106
+ iter_data =tqdm(
107
+ valid_data,
108
+ total=len(valid_data),
109
+ ncols=100,
110
+ desc=set_color(f"Evaluate ", "pink"),
111
+ )
112
+ indices_set = set()
113
+ num_sample = 0
114
+ for batch_idx, data in enumerate(iter_data):
115
+ num_sample += len(data)
116
+ data = data.to(self.device)
117
+ indices = self.model.get_indices(data)
118
+ indices = indices.view(-1,indices.shape[-1]).cpu().numpy()
119
+ for index in indices:
120
+ code = "-".join([str(int(_)) for _ in index])
121
+ indices_set.add(code)
122
+
123
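+ # Collision rate: share of items that did not receive a distinct index (duplicates beyond the first occurrence of each code).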
+ collision_rate = (num_sample - len(indices_set))/num_sample
124
+
125
+ return collision_rate
126
+
127
+ def _save_checkpoint(self, epoch, collision_rate=1, ckpt_file=None):
128
+
129
+ ckpt_path = os.path.join(self.ckpt_dir,ckpt_file) if ckpt_file \
130
+ else os.path.join(self.ckpt_dir, 'epoch_%d_collision_%.4f_model.pth' % (epoch, collision_rate))
131
+ state = {
132
+ "args": self.args,
133
+ "epoch": epoch,
134
+ "best_loss": self.best_loss,
135
+ "best_collision_rate": self.best_collision_rate,
136
+ "state_dict": self.model.state_dict(),
137
+ "optimizer": self.optimizer.state_dict(),
138
+ }
139
+ torch.save(state, ckpt_path, pickle_protocol=4)
140
+
141
+ self.logger.info(
142
+ set_color("Saving current", "blue") + f": {ckpt_path}"
143
+ )
144
+
145
+ def _generate_train_loss_output(self, epoch_idx, s_time, e_time, loss, recon_loss):
146
+ train_loss_output = (
147
+ set_color("epoch %d training", "green")
148
+ + " ["
149
+ + set_color("time", "blue")
150
+ + ": %.2fs, "
151
+ ) % (epoch_idx, e_time - s_time)
152
+ train_loss_output += set_color("train loss", "blue") + ": %.4f" % loss
153
+ train_loss_output +=", "
154
+ train_loss_output += set_color("reconstruction loss", "blue") + ": %.4f" % recon_loss
155
+ return train_loss_output + "]"
156
+
157
+
158
+ def fit(self, data):
159
+
160
+ cur_eval_step = 0
161
+
162
+ for epoch_idx in range(self.epochs):
163
+ # train
164
+ training_start_time = time()
165
+ train_loss, train_recon_loss = self._train_epoch(data, epoch_idx)
166
+ training_end_time = time()
167
+ train_loss_output = self._generate_train_loss_output(
168
+ epoch_idx, training_start_time, training_end_time, train_loss, train_recon_loss
169
+ )
170
+ self.logger.info(train_loss_output)
171
+
172
+ if train_loss < self.best_loss:
173
+ self.best_loss = train_loss
174
+ # self._save_checkpoint(epoch=epoch_idx,ckpt_file=self.best_loss_ckpt)
175
+
176
+ # eval
177
+ if (epoch_idx + 1) % self.eval_step == 0:
178
+ valid_start_time = time()
179
+ collision_rate = self._valid_epoch(data)
180
+
181
+ if collision_rate < self.best_collision_rate:
182
+ self.best_collision_rate = collision_rate
183
+ cur_eval_step = 0
184
+ self._save_checkpoint(epoch_idx, collision_rate=collision_rate,
185
+ ckpt_file=self.best_collision_ckpt)
186
+ else:
187
+ cur_eval_step += 1
188
+
189
+
190
+ valid_end_time = time()
191
+ valid_score_output = (
192
+ set_color("epoch %d evaluating", "green")
193
+ + " ["
194
+ + set_color("time", "blue")
195
+ + ": %.2fs, "
196
+ + set_color("collision_rate", "blue")
197
+ + ": %f]"
198
+ ) % (epoch_idx, valid_end_time - valid_start_time, collision_rate)
199
+
200
+ self.logger.info(valid_score_output)
201
+ if epoch_idx>1000:
202
+ self._save_checkpoint(epoch_idx, collision_rate=collision_rate)
203
+
204
+
205
+ return self.best_loss, self.best_collision_rate
206
+
207
+
208
+
209
+
index/utils.py ADDED
@@ -0,0 +1,36 @@
1
+
2
+ import datetime
3
+ import os
4
+
5
+
6
+ def ensure_dir(dir_path):
7
+
8
+ os.makedirs(dir_path, exist_ok=True)
9
+
10
+ def set_color(log, color, highlight=True):
11
+ color_set = ["black", "red", "green", "yellow", "blue", "pink", "cyan", "white"]
12
+ try:
13
+ index = color_set.index(color)
14
+ except:
15
+ index = len(color_set) - 1
16
+ prev_log = "\033["
17
+ if highlight:
18
+ prev_log += "1;3"
19
+ else:
20
+ prev_log += "0;3"
21
+ prev_log += str(index) + "m"
22
+ return prev_log + log + "\033[0m"
23
+
24
+ def get_local_time():
25
+ r"""Get current time
26
+
27
+ Returns:
28
+ str: current time
29
+ """
30
+ cur = datetime.datetime.now()
31
+ cur = cur.strftime("%b-%d-%Y_%H-%M-%S")
32
+
33
+ return cur
34
+
35
+
36
+
instruments_eval.sh ADDED
@@ -0,0 +1,17 @@
1
+ DATASET=Instruments
2
+ BASE=/datain/v-yinju/llama-7b
3
+ DATA_PATH=/datain/v-yinju/rqvae-zzx/data
4
+ CKPT_PATH=/datain/v-yinju/RQVAE_Bench/llama
5
+ RESULTS_FILE=$CKPT_PATH/result.json
6
+ INDEX=/datain/v-yinju/RQVAE_Bench/rqvae/Nov-27-2024_23-08-08/indices.json
7
+
8
+ torchrun --nproc_per_node=8 test_ddp.py \
9
+ --base_model $BASE \
10
+ --ckpt_path $CKPT_PATH \
11
+ --dataset $DATASET \
12
+ --data_path $DATA_PATH \
13
+ --results_file $RESULTS_FILE \
14
+ --test_batch_size 1 \
15
+ --num_beams 10 \
16
+ --test_prompt_ids all \
17
+ --index_file $INDEX
instruments_train.sh ADDED
@@ -0,0 +1,34 @@
1
+ export WANDB_MODE=disabled
2
+ export CUDA_LAUNCH_BLOCKING=1
3
+
4
+ DATASET=Instruments
5
+ BASE_MODEL=/datain/v-yinju/llama-7b
6
+ DATA_PATH=/datain/v-yinju/rqvae-zzx/data
7
+ INDEX=your_index_save_path
8
+ OUTPUT_DIR=your_ckpt_save_dir
9
+
10
+ mkdir -p $OUTPUT_DIR
11
+
12
+ torchrun --nproc_per_node=8 lora_finetune.py \
13
+ --base_model $BASE_MODEL \
14
+ --output_dir $OUTPUT_DIR \
15
+ --dataset $DATASET \
16
+ --data_path $DATA_PATH \
17
+ --per_device_batch_size 6 \
18
+ --gradient_accumulation_steps 2 \
19
+ --learning_rate 5e-5 \
20
+ --epochs 4 \
21
+ --weight_decay 0.01 \
22
+ --save_and_eval_strategy epoch \
23
+ --fp16 \
24
+ --deepspeed ./config/ds_z2_fp16.json \
25
+ --dataloader_num_workers 4 \
26
+ --only_train_response \
27
+ --tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
28
+ --train_prompt_sample_num 1,1,1,1,1,1 \
29
+ --train_data_sample_num 0,0,0,0,0,0 \
30
+ --index_file $INDEX
31
+
32
+ cd convert
33
+ nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
34
+ cd ..
lora_finetune.py ADDED
@@ -0,0 +1,164 @@
1
+ import argparse
2
+ import os
3
+ import sys
4
+ from typing import List
5
+
6
+ import torch
7
+ import transformers
8
+
9
+
10
+ from peft import (
11
+ TaskType,
12
+ LoraConfig,
13
+ get_peft_model,
14
+ get_peft_model_state_dict,
15
+ set_peft_model_state_dict,
16
+ )
17
+ from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
18
+
19
+ from utils import *
20
+ from collator import Collator
21
+
22
+ def train(args):
23
+
24
+ set_seed(args.seed)
25
+ ensure_dir(args.output_dir)
26
+
27
+ device_map = "auto"
28
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
29
+ ddp = world_size != 1
30
+ local_rank = int(os.environ.get("LOCAL_RANK") or 0)
31
+ if local_rank == 0:
32
+ print(vars(args))
33
+
34
+ if ddp:
35
+ device_map = {"": local_rank}
36
+
37
+ config = LlamaConfig.from_pretrained(args.base_model)
38
+ tokenizer = LlamaTokenizer.from_pretrained(
39
+ args.base_model,
40
+ model_max_length=args.model_max_length,
41
+ padding_side="right",
42
+ )
43
+ tokenizer.pad_token_id = 0
44
+
45
+ train_data, valid_data = load_datasets(args)
46
+ add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
47
+ config.vocab_size = len(tokenizer)
48
+ if local_rank == 0:
49
+ print("added {} new tokens.".format(add_num))
50
+ print("data num:", len(train_data))
51
+ tokenizer.save_pretrained(args.output_dir)
52
+ config.save_pretrained(args.output_dir)
53
+
54
+ collator = Collator(args, tokenizer)
55
+
56
+ model = LlamaForCausalLM.from_pretrained(
57
+ args.base_model,
58
+ torch_dtype=torch.float16,
59
+ device_map=device_map,
60
+ )
61
+ model.resize_token_embeddings(len(tokenizer))
62
+
63
+ config = LoraConfig(
64
+ r=args.lora_r,
65
+ lora_alpha=args.lora_alpha,
66
+ target_modules=args.lora_target_modules.split(","),
67
+ modules_to_save=args.lora_modules_to_save.split(","),
68
+ lora_dropout=args.lora_dropout,
69
+ bias="none",
70
+ inference_mode=False,
71
+ task_type=TaskType.CAUSAL_LM,
72
+ )
73
+ model = get_peft_model(model, config)
74
+
75
+ if args.resume_from_checkpoint:
76
+ checkpoint_name = os.path.join(
77
+ args.resume_from_checkpoint, "adapter_model.bin"
78
+ ) # only LoRA model - LoRA config above has to fit
79
+ args.resume_from_checkpoint = False # So the trainer won't try loading its state
80
+ # The two files above have a different name depending on how they were saved, but are actually the same.
81
+ if os.path.exists(checkpoint_name):
82
+ if local_rank == 0:
83
+ print(f"Restarting from {checkpoint_name}")
84
+ adapters_weights = torch.load(checkpoint_name)
85
+ model = set_peft_model_state_dict(model, adapters_weights)
86
+ else:
87
+ if local_rank == 0:
88
+ print(f"Checkpoint {checkpoint_name} not found")
89
+
90
+ for n, p in model.named_parameters():
91
+ if "original_module" in n and any(module_name in n for module_name in config.modules_to_save):
92
+ p.requires_grad = False
93
+
94
+ if local_rank == 0:
95
+ model.print_trainable_parameters()
96
+
97
+
98
+ if not ddp and torch.cuda.device_count() > 1:
99
+ model.is_parallelizable = True
100
+ model.model_parallel = True
101
+
102
+ trainer = transformers.Trainer(
103
+ model=model,
104
+ train_dataset=train_data,
105
+ eval_dataset=valid_data,
106
+ args=transformers.TrainingArguments(
107
+ seed=args.seed,
108
+ per_device_train_batch_size=args.per_device_batch_size,
109
+ per_device_eval_batch_size=args.per_device_batch_size,
110
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
111
+ warmup_ratio=args.warmup_ratio,
112
+ num_train_epochs=args.epochs,
113
+ learning_rate=args.learning_rate,
114
+ weight_decay=args.weight_decay,
115
+ lr_scheduler_type=args.lr_scheduler_type,
116
+ fp16=args.fp16,
117
+ bf16=args.bf16,
118
+ logging_steps=args.logging_step,
119
+ optim=args.optim,
120
+ gradient_checkpointing=True,
121
+ evaluation_strategy=args.save_and_eval_strategy,
122
+ save_strategy=args.save_and_eval_strategy,
123
+ eval_steps=args.save_and_eval_steps,
124
+ save_steps=args.save_and_eval_steps,
125
+ output_dir=args.output_dir,
126
+ save_total_limit=5,
127
+ load_best_model_at_end=True,
128
+ deepspeed=args.deepspeed,
129
+ ddp_find_unused_parameters=False if ddp else None,
130
+ report_to=None,
131
+ eval_delay=1 if args.save_and_eval_strategy=="epoch" else 2000,
132
+ dataloader_num_workers = args.dataloader_num_workers,
133
+ dataloader_prefetch_factor = args.dataloader_prefetch_factor
134
+ ),
135
+ tokenizer=tokenizer,
136
+ data_collator=collator,
137
+ )
138
+ model.config.use_cache = False
139
+
140
+ # old_state_dict = model.state_dict
141
+ # model.state_dict = (
142
+ # lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
143
+ # ).__get__(model, type(model))
144
+
145
+ if torch.__version__ >= "2" and sys.platform != "win32":
146
+ model = torch.compile(model)
147
+
148
+ trainer.train(
149
+ resume_from_checkpoint=args.resume_from_checkpoint,
150
+ )
151
+
152
+ trainer.save_state()
153
+ trainer.save_model(output_dir=args.output_dir)
154
+
155
+
156
+ if __name__ == "__main__":
157
+ parser = argparse.ArgumentParser(description='LLMRec')
158
+ parser = parse_global_args(parser)
159
+ parser = parse_train_args(parser)
160
+ parser = parse_dataset_args(parser)
161
+
162
+ args = parser.parse_args()
163
+
164
+ train(args)
prompt.py ADDED
@@ -0,0 +1,663 @@
1
+
2
+
3
+ sft_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request." \
4
+ "\n\n### Instruction:\n{instruction}\n\n### Response:{response}"
5
+
6
+
7
+
8
+
9
+
10
+
11
+
12
+ all_prompt = {}
13
+
14
+ # =====================================================
15
+ # Task 1 -- Sequential Recommendation -- 17 Prompts
16
+ # =====================================================
17
+
18
+ seqrec_prompt = []
19
+
20
+ #####——0
21
+ prompt = {}
22
+ prompt["instruction"] = "The user has interacted with items {inters} in chronological order. Can you predict the next possible item that the user may expect?"
23
+ prompt["response"] = "{item}"
24
+ seqrec_prompt.append(prompt)
25
+
26
+ #####——1
27
+ prompt = {}
28
+ prompt["instruction"] = "I find the user's historical interactive items: {inters}, and I want to know what next item the user needs. Can you help me decide?"
29
+ prompt["response"] = "{item}"
30
+ seqrec_prompt.append(prompt)
31
+
32
+ #####——2
33
+ prompt = {}
34
+ prompt["instruction"] = "Here are the user's historical interactions: {inters}, try to recommend another item to the user. Note that the historical interactions are arranged in chronological order."
35
+ prompt["response"] = "{item}"
36
+ seqrec_prompt.append(prompt)
37
+
38
+ #####——3
39
+ prompt = {}
40
+ prompt["instruction"] = "Based on the items that the user has interacted with: {inters}, can you determine what item would be recommended to him next?"
41
+ prompt["response"] = "{item}"
42
+ seqrec_prompt.append(prompt)
43
+
44
+ #####——4
45
+ prompt = {}
46
+ prompt["instruction"] = "The user has interacted with the following items in order: {inters}. What else do you think the user need?"
47
+ prompt["response"] = "{item}"
48
+ seqrec_prompt.append(prompt)
49
+
50
+ #####——5
51
+ prompt = {}
52
+ prompt["instruction"] = "Here is the item interaction history of the user: {inters}, what to recommend to the user next?"
53
+ prompt["response"] = "{item}"
54
+ seqrec_prompt.append(prompt)
55
+
56
+ #####——6
57
+ prompt = {}
58
+ prompt["instruction"] = "Which item would the user be likely to interact with next after interacting with items {inters}?"
59
+ prompt["response"] = "{item}"
60
+ seqrec_prompt.append(prompt)
61
+
62
+ #####——7
63
+ prompt = {}
64
+ prompt["instruction"] = "By analyzing the user's historical interactions with items {inters}, what is the next expected interaction item?"
65
+ prompt["response"] = "{item}"
66
+ seqrec_prompt.append(prompt)
67
+
68
+ #####——8
69
+ prompt = {}
70
+ prompt["instruction"] = "After interacting with items {inters}, what is the next item that could be recommended for the user?"
71
+ prompt["response"] = "{item}"
72
+ seqrec_prompt.append(prompt)
73
+
74
+ #####——9
75
+ prompt = {}
76
+ prompt["instruction"] = "Given the user's historical interactive items arranged in chronological order: {inters}, can you recommend a suitable item for the user?"
77
+ prompt["response"] = "{item}"
78
+ seqrec_prompt.append(prompt)
79
+
80
+ #####——10
81
+ prompt = {}
82
+ prompt["instruction"] = "Considering the user has interacted with items {inters}. What is the next recommendation for the user?"
83
+ prompt["response"] = "{item}"
84
+ seqrec_prompt.append(prompt)
85
+
86
+ #####——11
87
+ prompt = {}
88
+ prompt["instruction"] = "What is the top recommended item for the user who has previously interacted with items {inters} in order?"
89
+ prompt["response"] = "{item}"
90
+ seqrec_prompt.append(prompt)
91
+
92
+ #####——12
93
+ prompt = {}
94
+ prompt["instruction"] = "The user has interacted with the following items in the past in order: {inters}. Please predict the next item that the user most desires based on the given interaction records."
95
+ prompt["response"] = "{item}"
96
+ seqrec_prompt.append(prompt)
97
+
98
+ # prompt = {}
99
+ # prompt["instruction"] = "The user has interacted with the following items in the past in order: {inters}. Please predict the next item that the user is most likely to interact with based on the given interaction record. Note that his most recently interacted item is {}."
100
+ # prompt["response"] = "{item}"
101
+ # prompt["task"] = "sequential"
102
+ # prompt["id"] = "1-13"
103
+ #
104
+ # seqrec_prompt.append(prompt)
105
+
106
+ #####——13
107
+ prompt = {}
108
+ prompt["instruction"] = "Using the user's historical interactions as input data, suggest the next item that the user is highly likely to enjoy. The historical interactions are provided as follows: {inters}."
109
+ prompt["response"] = "{item}"
110
+ seqrec_prompt.append(prompt)
111
+
112
+ #####——14
113
+ prompt = {}
114
+ prompt["instruction"] = "You can access the user's historical item interaction records: {inters}. Now your task is to recommend the next potential item to him, considering his past interactions."
115
+ prompt["response"] = "{item}"
116
+ seqrec_prompt.append(prompt)
117
+
118
+ #####——15
119
+ prompt = {}
120
+ prompt["instruction"] = "You have observed that the user has interacted with the following items: {inters}, please recommend a next item that you think would be suitable for the user."
121
+ prompt["response"] = "{item}"
122
+ seqrec_prompt.append(prompt)
123
+
124
+ #####——16
125
+ prompt = {}
126
+ prompt["instruction"] = "You have obtained the ordered list of user historical interaction items, which is as follows: {inters}. Using this history as a reference, please select the next item to recommend to the user."
127
+ prompt["response"] = "{item}"
128
+ seqrec_prompt.append(prompt)
129
+
130
+ all_prompt["seqrec"] = seqrec_prompt
131
+
132
+
133
+
134
+ # ========================================================
135
+ # Task 2 -- Item2Index -- 19 Prompts
136
+ # ========================================================
137
+ # Remove periods when inputting
138
+
139
+ item2index_prompt = []
140
+
141
+ # ========================================================
142
+ # Title2Index
143
+
144
+ #####——0
145
+ prompt = {}
146
+ prompt["instruction"] = "Which item has the title: \"{title}\"?"
147
+ prompt["response"] = "{item}"
148
+ item2index_prompt.append(prompt)
149
+
150
+ #####——1
151
+ prompt = {}
152
+ prompt["instruction"] = "Which item is assigned the title: \"{title}\"?"
153
+ prompt["response"] = "{item}"
154
+ item2index_prompt.append(prompt)
155
+
156
+ #####——2
157
+ prompt = {}
158
+ prompt["instruction"] = "An item is called \"{title}\", could you please let me know which item it is?"
159
+ prompt["response"] = "{item}"
160
+ item2index_prompt.append(prompt)
161
+
162
+ #####——3
163
+ prompt = {}
164
+ prompt["instruction"] = "Which item is called \"{title}\"?"
165
+ prompt["response"] = "{item}"
166
+ item2index_prompt.append(prompt)
167
+
168
+ #####——4
169
+ prompt = {}
170
+ prompt["instruction"] = "One of the items is named \"{title}\", can you tell me which item this is?"
171
+ prompt["response"] = "{item}"
172
+ item2index_prompt.append(prompt)
173
+
174
+ #####——5
175
+ prompt = {}
176
+ prompt["instruction"] = "What is the item that goes by the title \"{title}\"?"
177
+ prompt["response"] = "{item}"
178
+ item2index_prompt.append(prompt)
179
+
180
+ # prompt = {}
181
+ # prompt["instruction"] = "Which item is referred to as \"{title}\"?"
182
+ # prompt["response"] = "{item}"
183
+ # item2index_prompt.append(prompt)
184
+
185
+ # ========================================================
186
+ # Description2Index
187
+
188
+ #####——6
189
+ prompt = {}
190
+ prompt["instruction"] = "An item can be described as follows: \"{description}\". Which item is it describing?"
191
+ prompt["response"] = "{item}"
192
+ item2index_prompt.append(prompt)
193
+
194
+ #####——7
195
+ prompt = {}
196
+ prompt["instruction"] = "Can you tell me what item is described as \"{description}\"?"
197
+ prompt["response"] = "{item}"
198
+ item2index_prompt.append(prompt)
199
+
200
+ #####——8
201
+ prompt = {}
202
+ prompt["instruction"] = "Can you provide the item that corresponds to the following description: \"{description}\"?"
203
+ prompt["response"] = "{item}"
204
+ item2index_prompt.append(prompt)
205
+
206
+
207
+ # prompt = {}
208
+ # prompt["instruction"] = "What is the item described as follows: \"{description}\"?"
209
+ # prompt["response"] = "{item}"
210
+ # item2index_prompt.append(prompt)
211
+
212
+ #####——9
213
+ prompt = {}
214
+ prompt["instruction"] = "Which item has the following characteristics: \"{description}\"?"
215
+ prompt["response"] = "{item}"
216
+ item2index_prompt.append(prompt)
217
+
218
+ #####——10
219
+ prompt = {}
220
+ prompt["instruction"] = "Which item is characterized by the following description: \"{description}\"?"
221
+ prompt["response"] = "{item}"
222
+ item2index_prompt.append(prompt)
223
+
224
+ #####——11
225
+ prompt = {}
226
+ prompt["instruction"] = "I am curious to know which item can be described as follows: \"{description}\". Can you tell me?"
227
+ prompt["response"] = "{item}"
228
+ item2index_prompt.append(prompt)
229
+
230
+ # ========================================================
231
+ # Title and Description to index
232
+
233
+ #####——12
234
+ prompt = {}
235
+ prompt["instruction"] = "An item is called \"{title}\" and described as \"{description}\", can you tell me which item it is?"
236
+ prompt["response"] = "{item}"
237
+ item2index_prompt.append(prompt)
238
+
239
+ #####——13
240
+ prompt = {}
241
+ prompt["instruction"] = "Could you please identify what item is called \"{title}\" and described as \"{description}\"?"
242
+ prompt["response"] = "{item}"
243
+ item2index_prompt.append(prompt)
244
+
245
+ #####——14
246
+ prompt = {}
247
+ prompt["instruction"] = "Which item is called \"{title}\" and has the characteristics described below: \"{description}\"?"
248
+ prompt["response"] = "{item}"
249
+ item2index_prompt.append(prompt)
250
+
251
+ #####——15
252
+ prompt = {}
253
+ prompt["instruction"] = "Please show me which item is named \"{title}\" and its corresponding description is: \"{description}\"."
254
+ prompt["response"] = "{item}"
255
+ item2index_prompt.append(prompt)
256
+
257
+
258
+ # prompt = {}
259
+ # prompt["instruction"] = "Here is an item called \"{title}\" and described as \"{description}\". Which item is it?"
260
+ # prompt["response"] = "{item}"
261
+ # item2index_prompt.append(prompt)
262
+
263
+ #####——16
264
+ prompt = {}
265
+ prompt["instruction"] = "Determine which item this is by its title and description. The title is: \"{title}\", and the description is: \"{description}\"."
266
+ prompt["response"] = "{item}"
267
+ item2index_prompt.append(prompt)
268
+
269
+ #####——17
270
+ prompt = {}
271
+ prompt["instruction"] = "Based on the title: \"{title}\", and the description: \"{description}\", answer which item is this?"
272
+ prompt["response"] = "{item}"
273
+ item2index_prompt.append(prompt)
274
+
275
+ #####——18
276
+ prompt = {}
277
+ prompt["instruction"] = "Can you identify the item from the provided title: \"{title}\", and description: \"{description}\"?"
278
+ prompt["response"] = "{item}"
279
+ item2index_prompt.append(prompt)
280
+
281
+ all_prompt["item2index"] = item2index_prompt
282
+
283
+
284
+ # ========================================================
285
+ # Task 3 -- Index2Item -- 17 Prompts
286
+ # ========================================================
287
+ # Remove periods when inputting
288
+
289
+ index2item_prompt = []
290
+
291
+ # ========================================================
292
+ # Index2Title
293
+
294
+ #####——0
295
+ prompt = {}
296
+ prompt["instruction"] = "What is the title of item {item}?"
297
+ prompt["response"] = "{title}"
298
+ index2item_prompt.append(prompt)
299
+
300
+ #####——1
301
+ prompt = {}
302
+ prompt["instruction"] = "What title is assigned to item {item}?"
303
+ prompt["response"] = "{title}"
304
+ index2item_prompt.append(prompt)
305
+
306
+ #####——2
307
+ prompt = {}
308
+ prompt["instruction"] = "Could you please tell me what item {item} is called?"
309
+ prompt["response"] = "{title}"
310
+ index2item_prompt.append(prompt)
311
+
312
+ #####——3
313
+ prompt = {}
314
+ prompt["instruction"] = "Can you provide the title of item {item}?"
315
+ prompt["response"] = "{title}"
316
+ index2item_prompt.append(prompt)
317
+
318
+ #####——4
319
+ prompt = {}
320
+ prompt["instruction"] = "What item {item} is referred to as?"
321
+ prompt["response"] = "{title}"
322
+ index2item_prompt.append(prompt)
323
+
324
+ #####——5
325
+ prompt = {}
326
+ prompt["instruction"] = "Would you mind informing me about the title of item {item}?"
327
+ prompt["response"] = "{title}"
328
+ index2item_prompt.append(prompt)
329
+
330
+ # ========================================================
331
+ # Index2Description
332
+
333
+ #####——6
334
+ prompt = {}
335
+ prompt["instruction"] = "Please provide a description of item {item}."
336
+ prompt["response"] = "{description}"
337
+ index2item_prompt.append(prompt)
338
+
339
+ #####——7
340
+ prompt = {}
341
+ prompt["instruction"] = "Briefly describe item {item}."
342
+ prompt["response"] = "{description}"
343
+ index2item_prompt.append(prompt)
344
+
345
+ #####——8
346
+ prompt = {}
347
+ prompt["instruction"] = "Can you share with me the description corresponding to item {item}?"
348
+ prompt["response"] = "{description}"
349
+ index2item_prompt.append(prompt)
350
+
351
+ #####——9
352
+ prompt = {}
353
+ prompt["instruction"] = "What is the description of item {item}?"
354
+ prompt["response"] = "{description}"
355
+ index2item_prompt.append(prompt)
356
+
357
+ #####——10
358
+ prompt = {}
359
+ prompt["instruction"] = "How to describe the characteristics of item {item}?"
360
+ prompt["response"] = "{description}"
361
+ index2item_prompt.append(prompt)
362
+
363
+ #####——11
364
+ prompt = {}
365
+ prompt["instruction"] = "Could you please tell me what item {item} looks like?"
366
+ prompt["response"] = "{description}"
367
+ index2item_prompt.append(prompt)
368
+
369
+
370
+ # ========================================================
371
+ # index to Title and Description
372
+
373
+ #####——12
374
+ prompt = {}
375
+ prompt["instruction"] = "What is the title and description of item {item}?"
376
+ prompt["response"] = "{title}\n\n{description}"
377
+ index2item_prompt.append(prompt)
378
+
379
+ #####——13
380
+ prompt = {}
381
+ prompt["instruction"] = "Can you provide the corresponding title and description for item {item}?"
382
+ prompt["response"] = "{title}\n\n{description}"
383
+ index2item_prompt.append(prompt)
384
+
385
+ #####——14
386
+ prompt = {}
387
+ prompt["instruction"] = "Please tell me what item {item} is called, along with a brief description of it."
388
+ prompt["response"] = "{title}\n\n{description}"
389
+ index2item_prompt.append(prompt)
390
+
391
+ #####——15
392
+ prompt = {}
393
+ prompt["instruction"] = "Would you mind informing me about the title of the item {item} and how to describe its characteristics?"
394
+ prompt["response"] = "{title}\n\n{description}"
395
+ index2item_prompt.append(prompt)
396
+
397
+ #####——16
398
+ prompt = {}
399
+ prompt["instruction"] = "I need to know the title and description of item {item}. Could you help me with that?"
400
+ prompt["response"] = "{title}\n\n{description}"
401
+ index2item_prompt.append(prompt)
402
+
403
+ all_prompt["index2item"] = index2item_prompt
404
+
405
+
406
+
407
+
408
+
409
+ # ========================================================
410
+ # Task 4 -- FusionSequentialRec -- Prompt
411
+ # ========================================================
412
+
413
+
414
+ fusionseqrec_prompt = []
415
+
416
+ #####——0
417
+ prompt = {}
418
+ prompt["instruction"] = "The user has sequentially interacted with items {inters}. Can you recommend the next item for him? Tell me the title of the item?"
419
+ prompt["response"] = "{title}"
420
+ fusionseqrec_prompt.append(prompt)
421
+
422
+ #####——1
423
+ prompt = {}
424
+ prompt["instruction"] = "Based on the user's historical interactions: {inters}, try to predict the title of the item that the user may need next."
425
+ prompt["response"] = "{title}"
426
+ fusionseqrec_prompt.append(prompt)
427
+
428
+ #####——2
429
+ prompt = {}
430
+ prompt["instruction"] = "Utilizing the user's past ordered interactions, which include items {inters}, please recommend the next item you think is suitable for the user and provide its title."
431
+ prompt["response"] = "{title}"
432
+ fusionseqrec_prompt.append(prompt)
433
+
434
+
435
+ #####——3
436
+ prompt = {}
437
+ prompt["instruction"] = "After interacting with items {inters}, what is the most probable item for the user to interact with next? Kindly provide the item's title."
438
+ prompt["response"] = "{title}"
439
+ fusionseqrec_prompt.append(prompt)
440
+
441
+
442
+
443
+
444
+
445
+ #####——4
446
+ prompt = {}
447
+ prompt["instruction"] = "Please review the user's historical interactions: {inters}, and describe what kind of item he still needs."
448
+ prompt["response"] = "{description}"
449
+ fusionseqrec_prompt.append(prompt)
450
+
451
+ #####——5
452
+ prompt = {}
453
+ prompt["instruction"] = "Here is the item interaction history of the user: {inters}, please tell me what features he expects from his next item."
454
+ prompt["response"] = "{description}"
455
+ fusionseqrec_prompt.append(prompt)
456
+
457
+ #####——6
458
+ prompt = {}
459
+ prompt["instruction"] = "By analyzing the user's historical interactions with items {inters}, can you infer what the user's next interactive item will look like?"
460
+ prompt["response"] = "{description}"
461
+ fusionseqrec_prompt.append(prompt)
462
+
463
+ #####——7
464
+ prompt = {}
465
+ prompt["instruction"] = "Access the user's historical item interaction records: {inters}. Your objective is to describe the next potential item for him, taking into account his past interactions."
466
+ prompt["response"] = "{description}"
467
+ fusionseqrec_prompt.append(prompt)
468
+
469
+
470
+
471
+
472
+
473
+
474
+ #####——8
475
+ prompt = {}
476
+ prompt["instruction"] = "Given the title sequence of user historical interactive items: {inter_titles}, can you recommend a suitable next item for the user?"
477
+ prompt["response"] = "{item}"
478
+ fusionseqrec_prompt.append(prompt)
479
+
480
+ #####——9
481
+ prompt = {}
482
+ prompt["instruction"] = "I possess a user's past interaction history, denoted by the title sequence of interactive items: {inter_titles}, and I am interested in knowing the user's next most desired item. Can you help me?"
483
+ prompt["response"] = "{item}"
484
+ fusionseqrec_prompt.append(prompt)
485
+
486
+ #####——10
487
+ prompt = {}
488
+ prompt["instruction"] = "Considering the title sequence of user history interaction items: {inter_titles}. What is the next recommendation for the user?"
489
+ prompt["response"] = "{item}"
490
+ fusionseqrec_prompt.append(prompt)
491
+
492
+ #####——11
493
+ prompt = {}
494
+ prompt["instruction"] = "You have obtained the ordered title list of user historical interaction items, as follows: {inter_titles}. Based on this historical context, kindly choose the subsequent item for user recommendation."
495
+ prompt["response"] = "{item}"
496
+ fusionseqrec_prompt.append(prompt)
497
+
498
+
499
+ all_prompt["fusionseqrec"] = fusionseqrec_prompt
500
+
501
+
502
+
503
+
504
+
505
+
506
+
507
+ # ========================================================
508
+ # Task 5 -- ItemSearch -- Prompt
509
+ # ========================================================
510
+
511
+
512
+ itemsearch_prompt = []
513
+
514
+ #####——0
515
+ prompt = {}
516
+ prompt["instruction"] = "Here is the historical interactions of a user: {inters}. And his personalized preferences are as follows: \"{explicit_preference}\". Your task is to recommend an item that is consistent with the user's preference."
517
+ prompt["response"] = "{item}"
518
+ itemsearch_prompt.append(prompt)
519
+
520
+ #####——1
521
+ prompt = {}
522
+ prompt["instruction"] = "The user has interacted with a list of items, which are as follows: {inters}. Based on these interacted items, the user current intent is as follows \"{user_related_intention}\", and your task is to generate an item that matches the user's current intent."
523
+ prompt["response"] = "{item}"
524
+ itemsearch_prompt.append(prompt)
525
+
526
+ #####——2
527
+ prompt = {}
528
+ prompt["instruction"] = "As a recommender system, you are assisting a user who has recently interacted with the following items: {inters}. The user expresses a desire to obtain another item with the following characteristics: \"{item_related_intention}\". Please recommend an item that meets these criteria."
529
+ prompt["response"] = "{item}"
530
+ itemsearch_prompt.append(prompt)
531
+
532
+ #####——3
533
+ prompt = {}
534
+ prompt["instruction"] = "Using the user's current query: \"{query}\" and his historical interactions: {inters}, you can estimate the user's preferences \"{explicit_preference}\". Please respond to the user's query by selecting an item that best matches his preference and query."
535
+ prompt["response"] = "{item}"
536
+ itemsearch_prompt.append(prompt)
537
+
538
+ #####——4
539
+ prompt = {}
540
+ prompt["instruction"] = "The user needs a new item and searches for: \"{query}\". In addition, he has previously interacted with: {inters}. You can obtain his preference by analyzing his historical interactions: \"{explicit_preference}\". Can you recommend an item that best matches the search query and preferences?"
541
+ prompt["response"] = "{item}"
542
+ itemsearch_prompt.append(prompt)
543
+
544
+ #####——5
545
+ prompt = {}
546
+ prompt["instruction"] = "Based on the user's historical interactions with the following items: {inters}. You can infer his preference by observing the historical interactions: \"{explicit_preference}\". Now the user wants a new item and searches for: \"{query}\". Please select a suitable item that matches his preference and search intent."
547
+ prompt["response"] = "{item}"
548
+ itemsearch_prompt.append(prompt)
549
+
550
+
551
+
552
+
553
+
554
+ #####——6
555
+ prompt = {}
556
+ prompt["instruction"] = "Suppose you are a search engine, now a user searches that: \"{query}\", can you select an item to respond to the user's query?"
557
+ prompt["response"] = "{item}"
558
+ itemsearch_prompt.append(prompt)
559
+
560
+ #####——7
561
+ prompt = {}
562
+ prompt["instruction"] = "As a search engine, your task is to answer the user's query by generating a related item. The user's query is provided as \"{query}\". Please provide your generated item as your answer."
563
+ prompt["response"] = "{item}"
564
+ itemsearch_prompt.append(prompt)
565
+
566
+ #####——8
567
+ prompt = {}
568
+ prompt["instruction"] = "As a recommender system, your task is to recommend an item that is related to the user's request, which is specified as follows: \"{query}\". Please provide your recommendation."
569
+ prompt["response"] = "{item}"
570
+ itemsearch_prompt.append(prompt)
571
+
572
+ #####——9
573
+ prompt = {}
574
+ prompt["instruction"] = "You meet a user's query: \"{query}\". Please respond to this user by selecting an appropriate item."
575
+ prompt["response"] = "{item}"
576
+ itemsearch_prompt.append(prompt)
577
+
578
+
579
+ #####——10
580
+ prompt = {}
581
+ prompt["instruction"] = "Your task is to recommend the best item that matches the user's query. Here is the search query of the user: \"{query}\", tell me the item you recommend."
582
+ prompt["response"] = "{item}"
583
+ itemsearch_prompt.append(prompt)
584
+
585
+ all_prompt["itemsearch"] = itemsearch_prompt
586
+
587
+
588
+
589
+
590
+
591
+ # ========================================================
592
+ # Task 6 -- PreferenceObtain -- Prompt
593
+ # ========================================================
594
+
595
+ preferenceobtain_prompt = []
596
+
597
+ #####——0
598
+ prompt = {}
599
+ prompt["instruction"] = "The user has interacted with items {inters} in chronological order. Please estimate his preferences."
600
+ prompt["response"] = "{explicit_preference}"
601
+ preferenceobtain_prompt.append(prompt)
602
+
603
+ #####——1
604
+ prompt = {}
605
+ prompt["instruction"] = "Based on the items that the user has interacted with: {inters}, can you infer what preferences he has?"
606
+ prompt["response"] = "{explicit_preference}"
607
+ preferenceobtain_prompt.append(prompt)
608
+
609
+ #####——3
610
+ prompt = {}
611
+ prompt["instruction"] = "Can you provide a summary of the user's preferences based on his historical interactions: {inters}?"
612
+ prompt["response"] = "{explicit_preference}"
613
+ preferenceobtain_prompt.append(prompt)
614
+
615
+ #####——4
616
+ prompt = {}
617
+ prompt["instruction"] = "After interacting with items {inters} in order, what preferences do you think the user has?"
618
+ prompt["response"] = "{explicit_preference}"
619
+ preferenceobtain_prompt.append(prompt)
620
+
621
+ #####——5
622
+ prompt = {}
623
+ prompt["instruction"] = "Here is the item interaction history of the user: {inters}, could you please infer the user's preferences."
624
+ prompt["response"] = "{explicit_preference}"
625
+ preferenceobtain_prompt.append(prompt)
626
+
627
+ #####——6
628
+ prompt = {}
629
+ prompt["instruction"] = "Based on the user's historical interaction records: {inters}, what are your speculations about his preferences?"
630
+ prompt["response"] = "{explicit_preference}"
631
+ preferenceobtain_prompt.append(prompt)
632
+
633
+ #####——7
634
+ prompt = {}
635
+ prompt["instruction"] = "Given the user's historical interactive items arranged in chronological order: {inters}, what can be inferred about the preferences of the user?"
636
+ prompt["response"] = "{explicit_preference}"
637
+ preferenceobtain_prompt.append(prompt)
638
+
639
+ #####——8
640
+ prompt = {}
641
+ prompt["instruction"] = "Can you speculate on the user's preferences based on his historical item interaction records: {inters}?"
642
+ prompt["response"] = "{explicit_preference}"
643
+ preferenceobtain_prompt.append(prompt)
644
+
645
+ #####——9
646
+ prompt = {}
647
+ prompt["instruction"] = "What is the preferences of a user who has previously interacted with items {inters} sequentially?"
648
+ prompt["response"] = "{explicit_preference}"
649
+ preferenceobtain_prompt.append(prompt)
650
+
651
+ #####——10
652
+ prompt = {}
653
+ prompt["instruction"] = "Using the user's historical interactions as input data, summarize the user's preferences. The historical interactions are provided as follows: {inters}."
654
+ prompt["response"] = "{explicit_preference}"
655
+ preferenceobtain_prompt.append(prompt)
656
+
657
+ #####——11
658
+ prompt = {}
659
+ prompt["instruction"] = "Utilizing the ordered list of the user's historical interaction items as a reference, please make an informed estimation of the user's preferences. The historical interactions are as follows: {inters}."
660
+ prompt["response"] = "{explicit_preference}"
661
+ preferenceobtain_prompt.append(prompt)
662
+
663
+ all_prompt["preferenceobtain"] = preferenceobtain_prompt
run.sh ADDED
@@ -0,0 +1,98 @@
1
+ export WANDB_MODE=disabled
2
+ export CUDA_LAUNCH_BLOCKING=1
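+ # The three blocks below fine-tune on Games, Arts and Instruments with identical settings;
+ # after each run, convert/convert.sh is launched in the background on the saved checkpoint.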
3
+
4
+ DATASET=Games
5
+ BASE_MODEL=huggyllama/llama-7b
6
+ DATA_PATH=./data
7
+ OUTPUT_DIR=./ckpt/$DATASET/
8
+
9
+ torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
10
+ --base_model $BASE_MODEL \
11
+ --output_dir $OUTPUT_DIR \
12
+ --dataset $DATASET \
13
+ --data_path $DATA_PATH \
14
+ --per_device_batch_size 8 \
15
+ --gradient_accumulation_steps 2 \
16
+ --learning_rate 5e-5 \
17
+ --epochs 4 \
18
+ --weight_decay 0.01 \
19
+ --save_and_eval_strategy epoch \
20
+ --deepspeed ./config/ds_z3_bf16.json \
21
+ --bf16 \
22
+ --only_train_response \
23
+ --tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
24
+ --train_prompt_sample_num 1,1,1,1,1,1 \
25
+ --train_data_sample_num 0,0,0,100000,0,0 \
26
+ --index_file .index.json
27
+
28
+
29
+ cd convert
30
+ nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
31
+ cd ..
32
+
33
+
34
+
35
+
36
+
37
+
38
+ DATASET=Arts
39
+ BASE_MODEL=huggyllama/llama-7b
40
+ DATA_PATH=./data
41
+ OUTPUT_DIR=./ckpt/$DATASET/
42
+
43
+ torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
44
+ --base_model $BASE_MODEL \
45
+ --output_dir $OUTPUT_DIR \
46
+ --dataset $DATASET \
47
+ --data_path $DATA_PATH \
48
+ --per_device_batch_size 8 \
49
+ --gradient_accumulation_steps 2 \
50
+ --learning_rate 5e-5 \
51
+ --epochs 4 \
52
+ --weight_decay 0.01 \
53
+ --save_and_eval_strategy epoch \
54
+ --deepspeed ./config/ds_z3_bf16.json \
55
+ --bf16 \
56
+ --only_train_response \
57
+ --tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
58
+ --train_prompt_sample_num 1,1,1,1,1,1 \
59
+ --train_data_sample_num 0,0,0,30000,0,0 \
60
+ --index_file .index.json
61
+
62
+
63
+ cd convert
64
+ nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
65
+ cd ..
66
+
67
+
68
+
69
+
70
+
71
+ DATASET=Instruments
72
+ BASE_MODEL=huggyllama/llama-7b
73
+ DATA_PATH=./data
74
+ OUTPUT_DIR=./ckpt/$DATASET/
75
+
76
+ torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
77
+ --base_model $BASE_MODEL \
78
+ --output_dir $OUTPUT_DIR \
79
+ --dataset $DATASET \
80
+ --data_path $DATA_PATH \
81
+ --per_device_batch_size 8 \
82
+ --gradient_accumulation_steps 2 \
83
+ --learning_rate 5e-5 \
84
+ --epochs 4 \
85
+ --weight_decay 0.01 \
86
+ --save_and_eval_strategy epoch \
87
+ --deepspeed ./config/ds_z3_bf16.json \
88
+ --bf16 \
89
+ --only_train_response \
90
+ --tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
91
+ --train_prompt_sample_num 1,1,1,1,1,1 \
92
+ --train_data_sample_num 0,0,0,20000,0,0 \
93
+ --index_file .index.json
94
+
95
+
96
+ cd convert
97
+ nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
98
+ cd ..
run_test.sh ADDED
@@ -0,0 +1,17 @@
1
+
2
+
3
+ DATASET=Games
4
+ DATA_PATH=./data
5
+ CKPT_PATH=./ckpt/$DATASET/
6
+ RESULTS_FILE=./results/$DATASET/xxx.json
7
+
8
+ python test.py \
9
+ --gpu_id 0 \
10
+ --ckpt_path $CKPT_PATH \
11
+ --dataset $DATASET \
12
+ --data_path $DATA_PATH \
13
+ --results_file $RESULTS_FILE \
14
+ --test_batch_size 1 \
15
+ --num_beams 20 \
16
+ --test_prompt_ids all \
17
+ --index_file .index.json
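+ # Note: also pass --lora if $CKPT_PATH holds a LoRA adapter rather than merged weights
+ # (see parse_test_args in utils.py for the remaining options).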
test.py ADDED
@@ -0,0 +1,175 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+ from typing import List
6
+
7
+ import torch
8
+ import transformers
9
+ from peft import PeftModel
10
+ from torch.utils.data import DataLoader
11
+ from tqdm import tqdm
12
+ from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
13
+
14
+ from utils import *
15
+ from collator import TestCollator
16
+ from prompt import all_prompt
17
+ from evaluate import get_topk_results, get_metrics_results
18
+
19
+
20
+ def test(args):
21
+
22
+ set_seed(args.seed)
23
+ print(vars(args))
24
+
25
+ device_map = {"": args.gpu_id}
26
+ device = torch.device("cuda",args.gpu_id)
27
+
28
+
29
+ tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path)
30
+ if args.lora:
31
+ model = LlamaForCausalLM.from_pretrained(
32
+ args.base_model,
33
+ torch_dtype=torch.bfloat16,
34
+ low_cpu_mem_usage=True,
35
+ device_map=device_map,
36
+ )
37
+ model.resize_token_embeddings(len(tokenizer))
38
+ model = PeftModel.from_pretrained(
39
+ model,
40
+ args.ckpt_path,
41
+ torch_dtype=torch.bfloat16,
42
+ device_map=device_map,
43
+ )
44
+ else:
45
+ model = LlamaForCausalLM.from_pretrained(
46
+ args.ckpt_path,
47
+ torch_dtype=torch.bfloat16,
48
+ low_cpu_mem_usage=True,
49
+ device_map=device_map,
50
+ )
51
+ # assert model.config.vocab_size == len(tokenizer)
52
+
53
+ if args.test_prompt_ids == "all":
54
+ if args.test_task.lower() == "seqrec":
55
+ prompt_ids = range(len(all_prompt["seqrec"]))
56
+ elif args.test_task.lower() == "itemsearch":
57
+ prompt_ids = range(len(all_prompt["itemsearch"]))
58
+ elif args.test_task.lower() == "fusionseqrec":
59
+ prompt_ids = range(len(all_prompt["fusionseqrec"]))
60
+ else:
61
+ prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")]
62
+
63
+ test_data = load_test_dataset(args)
64
+ collator = TestCollator(args, tokenizer)
65
+ all_items = test_data.get_all_items()
66
+
67
+
68
+ prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(tokenizer)
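+ # prefix_allowed_tokens_fn constrains beam search so that only token sequences forming
+ # valid item indices can be generated.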
69
+
70
+ test_loader = DataLoader(test_data, batch_size=args.test_batch_size, collate_fn=collator,
71
+ shuffle=True, num_workers=4, pin_memory=True)
72
+
73
+
74
+ print("data num:", len(test_data))
75
+
76
+ model.eval()
77
+
78
+ metrics = args.metrics.split(",")
79
+ all_prompt_results = []
80
+ with torch.no_grad():
81
+ for prompt_id in prompt_ids:
82
+
83
+ test_loader.dataset.set_prompt(prompt_id)
84
+ metrics_results = {}
85
+ total = 0
86
+
87
+ for step, batch in enumerate(tqdm(test_loader)):
88
+ inputs = batch[0].to(device)
89
+ targets = batch[1]
90
+ total += len(targets)
91
+
92
+ output = model.generate(
93
+ input_ids=inputs["input_ids"],
94
+ attention_mask=inputs["attention_mask"],
95
+ max_new_tokens=10,
96
+ # max_length=10,
97
+ prefix_allowed_tokens_fn=prefix_allowed_tokens,
98
+ num_beams=args.num_beams,
99
+ num_return_sequences=args.num_beams,
100
+ output_scores=True,
101
+ return_dict_in_generate=True,
102
+ early_stopping=True,
103
+ )
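+ # each input yields num_beams candidates ranked by sequence score; they form the top-k list for the metrics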
104
+ output_ids = output["sequences"]
105
+ scores = output["sequences_scores"]
106
+
107
+ output = tokenizer.batch_decode(
108
+ output_ids, skip_special_tokens=True
109
+ )
110
+ # print(output)
111
+ topk_res = get_topk_results(output,scores,targets,args.num_beams,
112
+ all_items=all_items if args.filter_items else None)
113
+
114
+ batch_metrics_res = get_metrics_results(topk_res, metrics)
115
+ # print(batch_metrics_res)
116
+
117
+ for m, res in batch_metrics_res.items():
118
+ if m not in metrics_results:
119
+ metrics_results[m] = res
120
+ else:
121
+ metrics_results[m] += res
122
+
123
+ if (step+1)%10 == 0:
124
+ temp={}
125
+ for m in metrics_results:
126
+ temp[m] = metrics_results[m] / total
127
+ print(temp)
128
+
129
+ for m in metrics_results:
130
+ metrics_results[m] = metrics_results[m] / total
131
+
132
+ all_prompt_results.append(metrics_results)
133
+ print("======================================================")
134
+ print("Prompt {} results: ".format(prompt_id), metrics_results)
135
+ print("======================================================")
136
+ print("")
137
+
138
+ mean_results = {}
139
+ min_results = {}
140
+ max_results = {}
141
+
142
+ for m in metrics:
143
+ all_res = [_[m] for _ in all_prompt_results]
144
+ mean_results[m] = sum(all_res)/len(all_res)
145
+ min_results[m] = min(all_res)
146
+ max_results[m] = max(all_res)
147
+
148
+ print("======================================================")
149
+ print("Mean results: ", mean_results)
150
+ print("Min results: ", min_results)
151
+ print("Max results: ", max_results)
152
+ print("======================================================")
153
+
154
+
155
+ save_data={}
156
+ save_data["test_prompt_ids"] = args.test_prompt_ids
157
+ save_data["mean_results"] = mean_results
158
+ save_data["min_results"] = min_results
159
+ save_data["max_results"] = max_results
160
+ save_data["all_prompt_results"] = all_prompt_results
161
+
162
+ with open(args.results_file, "w") as f:
163
+ json.dump(save_data, f, indent=4)
164
+
165
+
166
+
167
+ if __name__ == "__main__":
168
+ parser = argparse.ArgumentParser(description="LLMRec_test")
169
+ parser = parse_global_args(parser)
170
+ parser = parse_dataset_args(parser)
171
+ parser = parse_test_args(parser)
172
+
173
+ args = parser.parse_args()
174
+
175
+ test(args)
test_ddp.py ADDED
@@ -0,0 +1,238 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+
6
+ import torch
7
+ import transformers
8
+ import torch.distributed as dist
9
+ from torch.utils.data.distributed import DistributedSampler
10
+ from torch.nn.parallel import DistributedDataParallel
11
+ from peft import PeftModel
12
+ from torch.utils.data import DataLoader
13
+ from tqdm import tqdm
14
+ from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
15
+
16
+ from utils import *
17
+ from collator import TestCollator
18
+ from prompt import all_prompt
19
+ from evaluate import get_topk_results, get_metrics_results
20
+
21
+
22
+ def test_ddp(args):
23
+
24
+ set_seed(args.seed)
25
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
26
+ local_rank = int(os.environ.get("LOCAL_RANK") or 0)
27
+ torch.cuda.set_device(local_rank)
28
+ if local_rank == 0:
29
+ print(vars(args))
30
+
31
+ dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)
32
+
33
+ device_map = {"": local_rank}
34
+ device = torch.device("cuda",local_rank)
35
+
36
+ tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path)
37
+ args.lora = True  # force the LoRA loading branch below, regardless of the --lora flag
38
+ if args.lora:
39
+ model = LlamaForCausalLM.from_pretrained(
40
+ args.base_model,
41
+ torch_dtype=torch.float16,
42
+ low_cpu_mem_usage=True,
43
+ device_map=device_map,
44
+ )
45
+ model.resize_token_embeddings(len(tokenizer))
46
+ model = PeftModel.from_pretrained(
47
+ model,
48
+ args.ckpt_path,
49
+ torch_dtype=torch.float16,
50
+ device_map=device_map,
51
+ )
52
+ else:
53
+ model = LlamaForCausalLM.from_pretrained(
54
+ args.ckpt_path,
55
+ torch_dtype=torch.float16,
56
+ low_cpu_mem_usage=True,
57
+ device_map=device_map,
58
+ )
59
+ # assert model.config.vocab_size == len(tokenizer)
60
+ model = DistributedDataParallel(model, device_ids=[local_rank])
61
+
62
+ if args.test_prompt_ids == "all":
63
+ if args.test_task.lower() == "seqrec":
64
+ prompt_ids = range(len(all_prompt["seqrec"]))
65
+ elif args.test_task.lower() == "itemsearch":
66
+ prompt_ids = range(len(all_prompt["itemsearch"]))
67
+ elif args.test_task.lower() == "fusionseqrec":
68
+ prompt_ids = range(len(all_prompt["fusionseqrec"]))
69
+ else:
70
+ prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")]
71
+
72
+ test_data = load_test_dataset(args)
73
+ ddp_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=local_rank, drop_last=True)
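+ # drop_last=True keeps the shards equally sized across ranks, so a few trailing test examples
+ # may be skipped when the dataset size is not divisible by the world size.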
74
+
75
76
+ collator = TestCollator(args, tokenizer)
77
+ all_items = test_data.get_all_items()
78
+
79
+
80
+ prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(tokenizer)
81
+
82
+
83
+ test_loader = DataLoader(test_data, batch_size=args.test_batch_size, collate_fn=collator,
84
+ sampler=ddp_sampler, num_workers=2, pin_memory=True)
85
+
86
+ if local_rank == 0:
87
+ print("data num:", len(test_data))
88
+
89
+ model.eval()
90
+
91
+ metrics = args.metrics.split(",")
92
+ all_prompt_results = []
93
+ with torch.no_grad():
94
+
95
+ for prompt_id in prompt_ids:
96
+
97
+ if local_rank == 0:
98
+ print("Start prompt: ",prompt_id)
99
+
100
+ test_loader.dataset.set_prompt(prompt_id)
101
+ metrics_results = {}
102
+ total = 0
103
+
104
+ for step, batch in enumerate(tqdm(test_loader)):
105
+ inputs = batch[0].to(device)
106
+ targets = batch[1]
107
+ bs = len(targets)
108
+ num_beams = args.num_beams
109
+ while True:
110
+ try:
111
+ output = model.module.generate(
112
+ input_ids=inputs["input_ids"],
113
+ attention_mask=inputs["attention_mask"],
114
+ max_new_tokens=10,
115
+ prefix_allowed_tokens_fn=prefix_allowed_tokens,
116
+ num_beams=num_beams,
117
+ num_return_sequences=num_beams,
118
+ output_scores=True,
119
+ return_dict_in_generate=True,
120
+ early_stopping=True,
121
+ )
122
+ break
123
+ except torch.cuda.OutOfMemoryError as e:
124
+ print("Out of memory!")
125
+ num_beams = num_beams -1
126
+ print("Beam:", num_beams)
127
+ except Exception:
128
+ raise RuntimeError
129
+
130
+ output_ids = output["sequences"]
131
+ scores = output["sequences_scores"]
132
+
133
+ output = tokenizer.batch_decode(
134
+ output_ids, skip_special_tokens=True
135
+ )
136
+
137
+ topk_res = get_topk_results(output, scores, targets, num_beams,
138
+ all_items=all_items if args.filter_items else None)
139
+
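+ # gather per-rank batch sizes and top-k results so that rank 0 can accumulate metrics over the full test set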
140
+ bs_gather_list = [None for _ in range(world_size)]
141
+ dist.all_gather_object(obj=bs, object_list=bs_gather_list)
142
+ total += sum(bs_gather_list)
143
+ res_gather_list = [None for _ in range(world_size)]
144
+ dist.all_gather_object(obj=topk_res, object_list=res_gather_list)
145
+
146
+
147
+ if local_rank == 0:
148
+ all_device_topk_res = []
149
+ for ga_res in res_gather_list:
150
+ all_device_topk_res += ga_res
151
+ batch_metrics_res = get_metrics_results(all_device_topk_res, metrics)
152
+ for m, res in batch_metrics_res.items():
153
+ if m not in metrics_results:
154
+ metrics_results[m] = res
155
+ else:
156
+ metrics_results[m] += res
157
+
158
+ if (step + 1) % 50 == 0:
159
+ temp = {}
160
+ for m in metrics_results:
161
+ temp[m] = metrics_results[m] / total
162
+ print(temp)
163
+
164
+ dist.barrier()
165
+
166
+ if local_rank == 0:
167
+ for m in metrics_results:
168
+ metrics_results[m] = metrics_results[m] / total
169
+
170
+ all_prompt_results.append(metrics_results)
171
+ print("======================================================")
172
+ print("Prompt {} results: ".format(prompt_id), metrics_results)
173
+ print("======================================================")
174
+ print("")
175
+
176
+ dist.barrier()
177
+
178
+ dist.barrier()
179
+
180
+ if local_rank == 0:
181
+ mean_results = {}
182
+ min_results = {}
183
+ max_results = {}
184
+
185
+ for m in metrics:
186
+ all_res = [_[m] for _ in all_prompt_results]
187
+ mean_results[m] = sum(all_res)/len(all_res)
188
+ min_results[m] = min(all_res)
189
+ max_results[m] = max(all_res)
190
+
191
+ print("======================================================")
192
+ print("Mean results: ", mean_results)
193
+ print("Min results: ", min_results)
194
+ print("Max results: ", max_results)
195
+ print("======================================================")
196
+
197
+
198
+ save_data={}
199
+ save_data["test_prompt_ids"] = args.test_prompt_ids
200
+ save_data["mean_results"] = mean_results
201
+ save_data["min_results"] = min_results
202
+ save_data["max_results"] = max_results
203
+ save_data["all_prompt_results"] = all_prompt_results
204
+
205
+ with open(args.results_file, "w") as f:
206
+ json.dump(save_data, f, indent=4)
207
+ print("Save file: ", args.results_file)
208
+
209
+ import smtplib
210
+ from email.mime.text import MIMEText
211
+ mail_host = 'smtp.qq.com'
212
+ mail_code = 'ouzplpngooqndjcb'
213
+ sender = '[email protected]'
214
+ receiver = '[email protected]'
215
+
216
+ task = '[v67: evaluate lcrec]'
217
+ message = MIMEText('Task {task} Finished'.format(task = task), 'plain', 'utf-8')
218
+ message['Subject'] = 'Auto Email'
219
+ message['From'] = sender
220
+ message['To'] = receiver
221
+
222
+ server = smtplib.SMTP_SSL("smtp.qq.com", 465)
223
+ server.login(sender, mail_code)
224
+ server.sendmail(sender, receiver, message.as_string())
225
+
226
+ server.quit()
227
+
228
+
229
+
230
+ if __name__ == "__main__":
231
+ parser = argparse.ArgumentParser(description="LLMRec_test")
232
+ parser = parse_global_args(parser)
233
+ parser = parse_dataset_args(parser)
234
+ parser = parse_test_args(parser)
235
+
236
+ args = parser.parse_args()
237
+
238
+ test_ddp(args)
test_ddp.sh ADDED
@@ -0,0 +1,14 @@
1
+ DATASET=Instruments
2
+ DATA_PATH=$datain/v-yinju/rqvae-zzx/data
3
+ CKPT_PATH=$datain/v-yinju/rq-llama
4
+ RESULTS_FILE=$CKPT_PATH/result.json
5
+
6
+ torchrun --nproc_per_node=8 --master_port=4324 test_ddp.py \
7
+ --ckpt_path $CKPT_PATH \
8
+ --dataset $DATASET \
9
+ --data_path $DATA_PATH \
10
+ --results_file $RESULTS_FILE \
11
+ --test_batch_size 1 \
12
+ --num_beams 20 \
13
+ --test_prompt_ids all \
14
+ --index_file .index.json
utils.py ADDED
@@ -0,0 +1,196 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import random
5
+ import datetime
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import ConcatDataset
10
+ from data import SeqRecDataset, ItemFeatDataset, ItemSearchDataset, FusionSeqRecDataset, SeqRecTestDataset, PreferenceObtainDataset
11
+
12
+
13
+ def parse_global_args(parser):
14
+
15
+
16
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
17
+
18
+ parser.add_argument("--base_model", type=str,
19
+ default="./llama-7b/",
20
+ help="basic model path")
21
+ parser.add_argument("--output_dir", type=str,
22
+ default="./ckpt/",
23
+ help="The output directory")
24
+
25
+
26
+ return parser
27
+
28
+ def parse_dataset_args(parser):
29
+ parser.add_argument("--data_path", type=str, default="",
30
+ help="data directory")
31
+ parser.add_argument("--tasks", type=str, default="seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain",
32
+ help="Downstream tasks, separate by comma")
33
+ parser.add_argument("--dataset", type=str, default="Games", help="Dataset name")
34
+ parser.add_argument("--index_file", type=str, default=".index.json", help="the item indices file")
35
+ parser.add_argument("--dataloader_num_workers", type=int, default=0, help="dataloader num_workers")
36
+ parser.add_argument("--dataloader_prefetch_factor", type=int, default=2, help="dataloader prefetch_factor")
37
+
38
+
39
+ # arguments related to sequential task
40
+ parser.add_argument("--max_his_len", type=int, default=20,
41
+ help="the max number of items in history sequence, -1 means no limit")
42
+ parser.add_argument("--add_prefix", action="store_true", default=False,
43
+ help="whether add sequential prefix in history")
44
+ parser.add_argument("--his_sep", type=str, default=", ", help="The separator used for history")
45
+ parser.add_argument("--only_train_response", action="store_true", default=False,
46
+ help="whether only train on responses")
47
+
48
+ parser.add_argument("--train_prompt_sample_num", type=str, default="1,1,1,1,1,1",
49
+ help="the number of sampling prompts for each task")
50
+ parser.add_argument("--train_data_sample_num", type=str, default="0,0,0,100000,0,0",
51
+ help="the number of sampling prompts for each task")
52
+
53
+ parser.add_argument("--valid_prompt_id", type=int, default=0,
54
+ help="The prompt used for validation")
55
+ parser.add_argument("--sample_valid", action="store_true", default=True,
56
+ help="use sampled prompt for validation")
57
+ parser.add_argument("--valid_prompt_sample_num", type=int, default=2,
58
+ help="the number of sampling validation sequential recommendation prompts")
59
+
60
+ return parser
61
+
62
+ def parse_train_args(parser):
63
+
64
+ parser.add_argument("--optim", type=str, default="adamw_torch", help='The name of the optimizer')
65
+ parser.add_argument("--epochs", type=int, default=4)
66
+ parser.add_argument("--learning_rate", type=float, default=2e-5)
67
+ parser.add_argument("--per_device_batch_size", type=int, default=8)
68
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
69
+ parser.add_argument("--logging_step", type=int, default=10)
70
+ parser.add_argument("--model_max_length", type=int, default=2048)
71
+ parser.add_argument("--weight_decay", type=float, default=0.01)
72
+
73
+ parser.add_argument("--lora_r", type=int, default=8)
74
+ parser.add_argument("--lora_alpha", type=int, default=32)
75
+ parser.add_argument("--lora_dropout", type=float, default=0.05)
76
+ parser.add_argument("--lora_target_modules", type=str,
77
+ default="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj", help="separate by comma")
78
+ parser.add_argument("--lora_modules_to_save", type=str,
79
+ default="embed_tokens,lm_head", help="separate by comma")
80
+
81
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="either training checkpoint or final adapter")
82
+
83
+ parser.add_argument("--warmup_ratio", type=float, default=0.01)
84
+ parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
85
+ parser.add_argument("--save_and_eval_strategy", type=str, default="epoch")
86
+ parser.add_argument("--save_and_eval_steps", type=int, default=1000)
87
+ parser.add_argument("--fp16", action="store_true", default=False)
88
+ parser.add_argument("--bf16", action="store_true", default=False)
89
+ parser.add_argument("--deepspeed", type=str, default="./config/ds_z3_bf16.json")
90
+
91
+ return parser
92
+
93
+ def parse_test_args(parser):
94
+
95
+ parser.add_argument("--ckpt_path", type=str,
96
+ default="",
97
+ help="The checkpoint path")
98
+ parser.add_argument("--lora", action="store_true", default=False)
99
+ parser.add_argument("--filter_items", action="store_true", default=False,
100
+ help="whether filter illegal items")
101
+
102
+ parser.add_argument("--results_file", type=str,
103
+ default="./results/test-ddp.json",
104
+ help="result output path")
105
+
106
+ parser.add_argument("--test_batch_size", type=int, default=1)
107
+ parser.add_argument("--num_beams", type=int, default=20)
108
+ parser.add_argument("--sample_num", type=int, default=-1,
109
+ help="test sample number, -1 represents using all test data")
110
+ parser.add_argument("--gpu_id", type=int, default=0,
111
+ help="GPU ID when testing with single GPU")
112
+ parser.add_argument("--test_prompt_ids", type=str, default="0",
113
+ help="test prompt ids, separate by comma. 'all' represents using all")
114
+ parser.add_argument("--metrics", type=str, default="hit@1,hit@5,hit@10,ndcg@5,ndcg@10",
115
+ help="test metrics, separate by comma")
116
+ parser.add_argument("--test_task", type=str, default="SeqRec")
117
+
118
+
119
+ return parser
120
+
121
+
122
+ def get_local_time():
123
+ cur = datetime.datetime.now()
124
+ cur = cur.strftime("%b-%d-%Y_%H-%M-%S")
125
+
126
+ return cur
127
+
128
+
129
+ def set_seed(seed):
130
+ random.seed(seed)
131
+ np.random.seed(seed)
132
+ torch.manual_seed(seed)
133
+ torch.cuda.manual_seed_all(seed)
134
+ torch.backends.cudnn.benchmark = False
135
+ torch.backends.cudnn.deterministic = True
136
+ torch.backends.cudnn.enabled = False
137
+
138
+ def ensure_dir(dir_path):
139
+
140
+ os.makedirs(dir_path, exist_ok=True)
141
+
142
+
143
+ def load_datasets(args):
144
+
145
+ tasks = args.tasks.split(",")
146
+
147
+ train_prompt_sample_num = [int(_) for _ in args.train_prompt_sample_num.split(",")]
148
+ assert len(tasks) == len(train_prompt_sample_num), "prompt sample number does not match task number"
149
+ train_data_sample_num = [int(_) for _ in args.train_data_sample_num.split(",")]
150
+ assert len(tasks) == len(train_data_sample_num), "data sample number does not match task number"
151
+
152
+ train_datasets = []
153
+ for task, prompt_sample_num,data_sample_num in zip(tasks,train_prompt_sample_num,train_data_sample_num):
154
+ if task.lower() == "seqrec":
155
+ dataset = SeqRecDataset(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
156
+
157
+ elif task.lower() == "item2index" or task.lower() == "index2item":
158
+ dataset = ItemFeatDataset(args, task=task.lower(), prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
159
+
160
+ elif task.lower() == "fusionseqrec":
161
+ dataset = FusionSeqRecDataset(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
162
+
163
+ elif task.lower() == "itemsearch":
164
+ dataset = ItemSearchDataset(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
165
+
166
+ elif task.lower() == "preferenceobtain":
167
+ dataset = PreferenceObtainDataset(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
168
+
169
+ else:
170
+ raise NotImplementedError
171
+ train_datasets.append(dataset)
172
+
173
+ train_data = ConcatDataset(train_datasets)
174
+
175
+ valid_data = SeqRecDataset(args,"valid",args.valid_prompt_sample_num)
176
+
177
+ return train_data, valid_data
178
+
179
+ def load_test_dataset(args):
180
+
181
+ if args.test_task.lower() == "seqrec":
182
+ test_data = SeqRecDataset(args, mode="test", sample_num=args.sample_num)
183
+ # test_data = SeqRecTestDataset(args, sample_num=args.sample_num)
184
+ elif args.test_task.lower() == "itemsearch":
185
+ test_data = ItemSearchDataset(args, mode="test", sample_num=args.sample_num)
186
+ elif args.test_task.lower() == "fusionseqrec":
187
+ test_data = FusionSeqRecDataset(args, mode="test", sample_num=args.sample_num)
188
+ else:
189
+ raise NotImplementedError
190
+
191
+ return test_data
192
+
193
+ def load_json(file):
194
+ with open(file, 'r') as f:
195
+ data = json.load(f)
196
+ return data