Add files using upload-large-folder tool
Browse files- .gitattributes +1 -0
- README.md +124 -0
- asset/model.jpg +3 -0
- collator.py +77 -0
- config/ds_z2_bf16.json +28 -0
- config/ds_z2_fp16.json +34 -0
- config/ds_z3_bf16.json +31 -0
- config/ds_z3_bf16_save16bit.json +31 -0
- config/ds_z3_fp16.json +37 -0
- config/ds_z3_fp16_save16bit.json +37 -0
- convert/convert.py +16 -0
- convert/convert.sh +18 -0
- convert/convert_fp16.py +23 -0
- convert/make_delta.py +46 -0
- convert/merge_delta.py +167 -0
- convert/zero_to_fp32.py +600 -0
- data.py +844 -0
- data_process/amazon18_data_process.py +299 -0
- data_process/amazon18_recbole_data_process.py +226 -0
- data_process/amazon_text_emb.py +161 -0
- data_process/get_llm_output.py +374 -0
- data_process/utils.py +238 -0
- evaluate.py +66 -0
- finetune.py +121 -0
- index/datasets.py +21 -0
- index/generate_indices.py +155 -0
- index/main.py +87 -0
- index/models/layers.py +106 -0
- index/models/rq.py +54 -0
- index/models/rqvae.py +82 -0
- index/models/vq.py +103 -0
- index/run.sh +18 -0
- index/trainer.py +209 -0
- index/utils.py +36 -0
- instruments_eval.sh +17 -0
- instruments_train.sh +34 -0
- lora_finetune.py +164 -0
- prompt.py +663 -0
- run.sh +98 -0
- run_test.sh +17 -0
- test.py +175 -0
- test_ddp.py +238 -0
- test_ddp.sh +14 -0
- utils.py +196 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
asset/model.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LC-Rec
|
2 |
+
|
3 |
+
This is the official PyTorch implementation for the paper:
|
4 |
+
|
5 |
+
> [Adapting Large Language Models by Integrating Collaborative Semantics for Recommendation](https://arxiv.org/abs/2311.09049)
|
6 |
+
|
7 |
+
## Overview
|
8 |
+
|
9 |
+
We propose **LC-Rec**, a new approach to integrate **L**anguage and **C**ollaborative semantics for improving LLMs in **Rec**ommender systems. To tackle the large gap between the language semantics modeled by LLMs and collaborative semantics implied by recommender systems, we make two major contributions in two aspects. For item indexing, we design a learning-based vector quantization method with uniform semantic mapping, which can assign meaningful and non-conflicting IDs (called item indices) for items. For alignment tuning, we propose a series of specially designed tuning tasks to enhance the integration of collaborative semantics in LLMs. Our fine-tuning tasks enforce LLMs to deeply integrate language and collaborative semantics (characterized by the learned item indices), so as to achieve an effective adaptation to recommender systems.
|
10 |
+
|
11 |
+

|
12 |
+
|
13 |
+
## Requirements
|
14 |
+
|
15 |
+
```
|
16 |
+
torch==1.13.1+cu117
|
17 |
+
accelerate
|
18 |
+
bitsandbytes
|
19 |
+
deepspeed
|
20 |
+
evaluate
|
21 |
+
peft
|
22 |
+
sentencepiece
|
23 |
+
tqdm
|
24 |
+
transformers
|
25 |
+
```
|
26 |
+
|
27 |
+
## Model Checkpoint
|
28 |
+
|
29 |
+
The delta weights on the three datasets can be downloaded from huggingface hub ([Instruments](https://huggingface.co/bwzheng0324/lc-rec-instruments-delta), [Arts](https://huggingface.co/bwzheng0324/lc-rec-arts-delta), [Games](https://huggingface.co/bwzheng0324/lc-rec-games-delta)). After downloading, you can add our deltas to the original LLaMA weights to obtain LC-Rec weights:
|
30 |
+
|
31 |
+
1. Get the original [LLaMA](https://huggingface.co/huggyllama/llama-7b) weights.
|
32 |
+
2. Use the following scripts to get LC-Rec weights by applying our delta.
|
33 |
+
|
34 |
+
```shell
|
35 |
+
python -m convert/merge_delta.py \
|
36 |
+
--base-model-path /path/to/llama-7b \
|
37 |
+
--target-model-path /path/output/lc-rec \
|
38 |
+
--delta-path bwzheng0324/lc-rec-games-delta
|
39 |
+
```
|
40 |
+
|
41 |
+
## Dataset
|
42 |
+
|
43 |
+
We use three datasets in our paper, all of which have been uploaded to [Google Drive](https://drive.google.com/drive/folders/1RcJ2M1l5zWPHYuGd9l5Gibcs5w5aI3y6?usp=sharing)
|
44 |
+
|
45 |
+
## Train
|
46 |
+
|
47 |
+
The detailed scripts for all three datasets are in `run.sh`:
|
48 |
+
|
49 |
+
```shell
|
50 |
+
DATASET=Games
|
51 |
+
BASE_MODEL=huggyllama/llama-7b
|
52 |
+
DATA_PATH=./data
|
53 |
+
OUTPUT_DIR=./ckpt/$DATASET/
|
54 |
+
|
55 |
+
torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
|
56 |
+
--base_model $BASE_MODEL \
|
57 |
+
--output_dir $OUTPUT_DIR \
|
58 |
+
--dataset $DATASET \
|
59 |
+
--data_path $DATA_PATH \
|
60 |
+
--per_device_batch_size 8 \
|
61 |
+
--gradient_accumulation_steps 2 \
|
62 |
+
--learning_rate 5e-5 \
|
63 |
+
--epochs 4 \
|
64 |
+
--weight_decay 0.01 \
|
65 |
+
--save_and_eval_strategy epoch \
|
66 |
+
--deepspeed ./config/ds_z3_bf16.json \
|
67 |
+
--bf16 \
|
68 |
+
--only_train_response \
|
69 |
+
--tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
|
70 |
+
--train_prompt_sample_num 1,1,1,1,1,1 \
|
71 |
+
--train_data_sample_num 0,0,0,100000,0,0 \
|
72 |
+
--index_file .index.json
|
73 |
+
|
74 |
+
|
75 |
+
cd convert
|
76 |
+
nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
|
77 |
+
cd ..
|
78 |
+
```
|
79 |
+
|
80 |
+
## Test
|
81 |
+
|
82 |
+
Test with a single GPU:
|
83 |
+
|
84 |
+
```shell
|
85 |
+
DATASET=Games
|
86 |
+
DATA_PATH=./data
|
87 |
+
OUTPUT_DIR=./ckpt/$DATASET/
|
88 |
+
RESULTS_FILE=./results/$DATASET/result.json
|
89 |
+
|
90 |
+
python test.py \
|
91 |
+
--gpu_id 0 \
|
92 |
+
--ckpt_path $CKPT_PATH \
|
93 |
+
--dataset $DATASET \
|
94 |
+
--data_path $DATA_PATH \
|
95 |
+
--results_file $RESULTS_FILE \
|
96 |
+
--test_batch_size 1 \
|
97 |
+
--num_beams 20 \
|
98 |
+
--test_prompt_ids all \
|
99 |
+
--index_file .index.json
|
100 |
+
```
|
101 |
+
|
102 |
+
Test with multiple GPUs:
|
103 |
+
|
104 |
+
```shell
|
105 |
+
DATASET=Games
|
106 |
+
DATA_PATH=./data
|
107 |
+
OUTPUT_DIR=./ckpt/$DATASET/
|
108 |
+
RESULTS_FILE=./results/$DATASET/result.json
|
109 |
+
|
110 |
+
torchrun --nproc_per_node=8 --master_port=4324 test_ddp.py \
|
111 |
+
--ckpt_path $CKPT_PATH \
|
112 |
+
--dataset $DATASET \
|
113 |
+
--data_path $DATA_PATH \
|
114 |
+
--results_file $RESULTS_FILE \
|
115 |
+
--test_batch_size 1 \
|
116 |
+
--num_beams 20 \
|
117 |
+
--test_prompt_ids all \
|
118 |
+
--index_file .index.json
|
119 |
+
```
|
120 |
+
|
121 |
+
## Acknowledgement
|
122 |
+
|
123 |
+
The implementation is based on [HuggingFace](https://github.com/huggingface/transformers).
|
124 |
+
|
asset/model.jpg
ADDED
![]() |
Git LFS Details
|
collator.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import copy
|
3 |
+
import argparse
|
4 |
+
from dataclasses import dataclass
|
5 |
+
|
6 |
+
import transformers
|
7 |
+
import math
|
8 |
+
from torch.utils.data import Sampler
|
9 |
+
import torch.distributed as dist
|
10 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, T5Tokenizer, T5Config, T5ForConditionalGeneration
|
11 |
+
|
12 |
+
|
13 |
+
class Collator(object):
|
14 |
+
|
15 |
+
def __init__(self, args, tokenizer):
|
16 |
+
self.args = args
|
17 |
+
self.only_train_response = args.only_train_response
|
18 |
+
self.tokenizer = tokenizer
|
19 |
+
if self.tokenizer.pad_token_id is None:
|
20 |
+
self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
|
21 |
+
# print(self.tokenizer.model_max_length)
|
22 |
+
|
23 |
+
def __call__(self, batch):
|
24 |
+
|
25 |
+
input_texts = [d["input_ids"] for d in batch]
|
26 |
+
full_texts = [d["labels"] + self.tokenizer.eos_token for d in batch]
|
27 |
+
|
28 |
+
inputs = self.tokenizer(
|
29 |
+
text = full_texts,
|
30 |
+
text_target = input_texts,
|
31 |
+
return_tensors="pt",
|
32 |
+
padding="longest",
|
33 |
+
max_length=self.tokenizer.model_max_length,
|
34 |
+
truncation=True,
|
35 |
+
return_attention_mask=True,
|
36 |
+
)
|
37 |
+
labels = copy.deepcopy(inputs["input_ids"])
|
38 |
+
if self.only_train_response:
|
39 |
+
# ignore padding
|
40 |
+
labels[labels == self.tokenizer.pad_token_id] = -100
|
41 |
+
# ignore input text
|
42 |
+
labels[torch.where(inputs["labels"] != self.tokenizer.pad_token_id)] = -100
|
43 |
+
|
44 |
+
inputs["labels"] = labels
|
45 |
+
|
46 |
+
|
47 |
+
return inputs
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
class TestCollator(object):
|
52 |
+
|
53 |
+
def __init__(self, args, tokenizer):
|
54 |
+
self.args = args
|
55 |
+
self.tokenizer = tokenizer
|
56 |
+
if self.tokenizer.pad_token_id is None:
|
57 |
+
self.tokenizer.pad_token_id = 0
|
58 |
+
|
59 |
+
if isinstance(self.tokenizer, LlamaTokenizer):
|
60 |
+
# Allow batched inference
|
61 |
+
self.tokenizer.padding_side = "left"
|
62 |
+
|
63 |
+
def __call__(self, batch):
|
64 |
+
|
65 |
+
input_texts = [d["input_ids"] for d in batch]
|
66 |
+
targets = [d["labels"] for d in batch]
|
67 |
+
inputs = self.tokenizer(
|
68 |
+
text=input_texts,
|
69 |
+
return_tensors="pt",
|
70 |
+
padding="longest",
|
71 |
+
max_length=self.tokenizer.model_max_length,
|
72 |
+
truncation=True,
|
73 |
+
return_attention_mask=True,
|
74 |
+
)
|
75 |
+
|
76 |
+
return (inputs, targets)
|
77 |
+
|
config/ds_z2_bf16.json
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bf16": {
|
3 |
+
"enabled": "auto"
|
4 |
+
},
|
5 |
+
"zero_optimization": {
|
6 |
+
"stage": 2,
|
7 |
+
"allgather_partitions": true,
|
8 |
+
"allgather_bucket_size": 5e8,
|
9 |
+
"overlap_comm": true,
|
10 |
+
"reduce_scatter": true,
|
11 |
+
"reduce_bucket_size": 5e8,
|
12 |
+
"contiguous_gradients": true
|
13 |
+
},
|
14 |
+
"gradient_accumulation_steps": "auto",
|
15 |
+
"gradient_clipping": "auto",
|
16 |
+
"steps_per_print": 2000,
|
17 |
+
"train_batch_size": "auto",
|
18 |
+
"train_micro_batch_size_per_gpu": "auto",
|
19 |
+
"wall_clock_breakdown": false,
|
20 |
+
"flops_profiler": {
|
21 |
+
"enabled": true,
|
22 |
+
"profile_step": 10,
|
23 |
+
"module_depth": -1,
|
24 |
+
"top_modules": 3,
|
25 |
+
"detailed": true,
|
26 |
+
"output_file": "flops_profiler.out"
|
27 |
+
}
|
28 |
+
}
|
config/ds_z2_fp16.json
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"fp16": {
|
3 |
+
"enabled": "auto",
|
4 |
+
"auto_cast": false,
|
5 |
+
"loss_scale": 0,
|
6 |
+
"initial_scale_power": 16,
|
7 |
+
"loss_scale_window": 1000,
|
8 |
+
"hysteresis": 2,
|
9 |
+
"min_loss_scale": 1
|
10 |
+
},
|
11 |
+
"zero_optimization": {
|
12 |
+
"stage": 2,
|
13 |
+
"allgather_partitions": true,
|
14 |
+
"allgather_bucket_size": 5e8,
|
15 |
+
"overlap_comm": true,
|
16 |
+
"reduce_scatter": true,
|
17 |
+
"reduce_bucket_size": 5e8,
|
18 |
+
"contiguous_gradients": true
|
19 |
+
},
|
20 |
+
"gradient_accumulation_steps": "auto",
|
21 |
+
"gradient_clipping": "auto",
|
22 |
+
"steps_per_print": 2000,
|
23 |
+
"train_batch_size": "auto",
|
24 |
+
"train_micro_batch_size_per_gpu": "auto",
|
25 |
+
"wall_clock_breakdown": false,
|
26 |
+
"flops_profiler": {
|
27 |
+
"enabled": true,
|
28 |
+
"profile_step": 10,
|
29 |
+
"module_depth": -1,
|
30 |
+
"top_modules": 3,
|
31 |
+
"detailed": true,
|
32 |
+
"output_file": "flops_profiler.out"
|
33 |
+
}
|
34 |
+
}
|
config/ds_z3_bf16.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bf16": {
|
3 |
+
"enabled": "auto"
|
4 |
+
},
|
5 |
+
"zero_optimization": {
|
6 |
+
"stage": 3,
|
7 |
+
"overlap_comm": true,
|
8 |
+
"contiguous_gradients": true,
|
9 |
+
"sub_group_size": 1e9,
|
10 |
+
"reduce_bucket_size": "auto",
|
11 |
+
"stage3_prefetch_bucket_size": "auto",
|
12 |
+
"stage3_param_persistence_threshold": "auto",
|
13 |
+
"stage3_max_live_parameters": 1e9,
|
14 |
+
"stage3_max_reuse_distance": 1e9,
|
15 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
16 |
+
},
|
17 |
+
"gradient_accumulation_steps": "auto",
|
18 |
+
"gradient_clipping": "auto",
|
19 |
+
"steps_per_print": 2000,
|
20 |
+
"train_batch_size": "auto",
|
21 |
+
"train_micro_batch_size_per_gpu": "auto",
|
22 |
+
"wall_clock_breakdown": false,
|
23 |
+
"flops_profiler": {
|
24 |
+
"enabled": true,
|
25 |
+
"profile_step": 10,
|
26 |
+
"module_depth": -1,
|
27 |
+
"top_modules": 3,
|
28 |
+
"detailed": true,
|
29 |
+
"output_file": "flops_profiler.out"
|
30 |
+
}
|
31 |
+
}
|
config/ds_z3_bf16_save16bit.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bf16": {
|
3 |
+
"enabled": "auto"
|
4 |
+
},
|
5 |
+
"zero_optimization": {
|
6 |
+
"stage": 3,
|
7 |
+
"overlap_comm": true,
|
8 |
+
"contiguous_gradients": true,
|
9 |
+
"sub_group_size": 1e9,
|
10 |
+
"reduce_bucket_size": "auto",
|
11 |
+
"stage3_prefetch_bucket_size": "auto",
|
12 |
+
"stage3_param_persistence_threshold": "auto",
|
13 |
+
"stage3_max_live_parameters": 1e9,
|
14 |
+
"stage3_max_reuse_distance": 1e9,
|
15 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
16 |
+
},
|
17 |
+
"gradient_accumulation_steps": "auto",
|
18 |
+
"gradient_clipping": "auto",
|
19 |
+
"steps_per_print": 2000,
|
20 |
+
"train_batch_size": "auto",
|
21 |
+
"train_micro_batch_size_per_gpu": "auto",
|
22 |
+
"wall_clock_breakdown": false,
|
23 |
+
"flops_profiler": {
|
24 |
+
"enabled": true,
|
25 |
+
"profile_step": 10,
|
26 |
+
"module_depth": -1,
|
27 |
+
"top_modules": 3,
|
28 |
+
"detailed": true,
|
29 |
+
"output_file": "flops_profiler.out"
|
30 |
+
}
|
31 |
+
}
|
config/ds_z3_fp16.json
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"fp16": {
|
3 |
+
"enabled": "auto",
|
4 |
+
"auto_cast": false,
|
5 |
+
"loss_scale": 0,
|
6 |
+
"initial_scale_power": 16,
|
7 |
+
"loss_scale_window": 1000,
|
8 |
+
"hysteresis": 2,
|
9 |
+
"min_loss_scale": 1
|
10 |
+
},
|
11 |
+
"zero_optimization": {
|
12 |
+
"stage": 3,
|
13 |
+
"overlap_comm": true,
|
14 |
+
"contiguous_gradients": true,
|
15 |
+
"sub_group_size": 1e9,
|
16 |
+
"reduce_bucket_size": "auto",
|
17 |
+
"stage3_prefetch_bucket_size": "auto",
|
18 |
+
"stage3_param_persistence_threshold": "auto",
|
19 |
+
"stage3_max_live_parameters": 1e9,
|
20 |
+
"stage3_max_reuse_distance": 1e9,
|
21 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
22 |
+
},
|
23 |
+
"gradient_accumulation_steps": "auto",
|
24 |
+
"gradient_clipping": "auto",
|
25 |
+
"steps_per_print": 2000,
|
26 |
+
"train_batch_size": "auto",
|
27 |
+
"train_micro_batch_size_per_gpu": "auto",
|
28 |
+
"wall_clock_breakdown": false,
|
29 |
+
"flops_profiler": {
|
30 |
+
"enabled": true,
|
31 |
+
"profile_step": 10,
|
32 |
+
"module_depth": -1,
|
33 |
+
"top_modules": 3,
|
34 |
+
"detailed": true,
|
35 |
+
"output_file": "flops_profiler.out"
|
36 |
+
}
|
37 |
+
}
|
config/ds_z3_fp16_save16bit.json
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"fp16": {
|
3 |
+
"enabled": "auto",
|
4 |
+
"auto_cast": false,
|
5 |
+
"loss_scale": 0,
|
6 |
+
"initial_scale_power": 16,
|
7 |
+
"loss_scale_window": 1000,
|
8 |
+
"hysteresis": 2,
|
9 |
+
"min_loss_scale": 1
|
10 |
+
},
|
11 |
+
"zero_optimization": {
|
12 |
+
"stage": 3,
|
13 |
+
"overlap_comm": true,
|
14 |
+
"contiguous_gradients": true,
|
15 |
+
"sub_group_size": 1e9,
|
16 |
+
"reduce_bucket_size": "auto",
|
17 |
+
"stage3_prefetch_bucket_size": "auto",
|
18 |
+
"stage3_param_persistence_threshold": "auto",
|
19 |
+
"stage3_max_live_parameters": 1e9,
|
20 |
+
"stage3_max_reuse_distance": 1e9,
|
21 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
22 |
+
},
|
23 |
+
"gradient_accumulation_steps": "auto",
|
24 |
+
"gradient_clipping": "auto",
|
25 |
+
"steps_per_print": 2000,
|
26 |
+
"train_batch_size": "auto",
|
27 |
+
"train_micro_batch_size_per_gpu": "auto",
|
28 |
+
"wall_clock_breakdown": false,
|
29 |
+
"flops_profiler": {
|
30 |
+
"enabled": true,
|
31 |
+
"profile_step": 10,
|
32 |
+
"module_depth": -1,
|
33 |
+
"top_modules": 3,
|
34 |
+
"detailed": true,
|
35 |
+
"output_file": "flops_profiler.out"
|
36 |
+
}
|
37 |
+
}
|
convert/convert.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import transformers
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
parser = argparse.ArgumentParser()
|
7 |
+
parser.add_argument("--source", "-s", type=str, default="", help="source path of models")
|
8 |
+
parser.add_argument("--target", "-t", type=str, default="", help="target path of models")
|
9 |
+
|
10 |
+
args, _ = parser.parse_known_args()
|
11 |
+
|
12 |
+
assert os.path.exists(args.source)
|
13 |
+
assert args.target != ""
|
14 |
+
|
15 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(args.source)
|
16 |
+
model.save_pretrained(args.target, state_dict=model.state_dict())
|
convert/convert.sh
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model=$1
|
2 |
+
|
3 |
+
set -x
|
4 |
+
|
5 |
+
for step in `ls ${model} | grep checkpoint | awk -F'-' '{ print $2 }'`
|
6 |
+
do
|
7 |
+
mkdir ${model}/tmp-checkpoint-${step}
|
8 |
+
mkdir ${model}/final-checkpoint-${step}
|
9 |
+
python ./zero_to_fp32.py ${model}/checkpoint-${step}/ ${model}/tmp-checkpoint-${step}/pytorch_model.bin
|
10 |
+
cp ${model}/*.json ${model}/tmp-checkpoint-${step}
|
11 |
+
python ./convert.py -s ${model}/tmp-checkpoint-${step} -t ${model}/final-checkpoint-${step}
|
12 |
+
cp ${model}/checkpoint-${step}/*.json ${model}/final-checkpoint-${step}
|
13 |
+
cp ${model}/*.json ${model}/final-checkpoint-${step}
|
14 |
+
cp ${model}/tokenizer* ${model}/final-checkpoint-${step}
|
15 |
+
cp ${model}/train* ${model}/final-checkpoint-${step}
|
16 |
+
#rm -rf ${model}/tmp-checkpoint-${step} ${model}/checkpoint-${step} ${model}/global_step${step}
|
17 |
+
#mv ${model}/final-checkpoint-${step} ${model}/checkpoint-${step}
|
18 |
+
done
|
convert/convert_fp16.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def convert_fp16(in_checkpoint, out_checkpoint):
|
9 |
+
tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False)
|
10 |
+
model = AutoModelForCausalLM.from_pretrained(
|
11 |
+
in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
12 |
+
)
|
13 |
+
model.save_pretrained(out_checkpoint)
|
14 |
+
tokenizer.save_pretrained(out_checkpoint)
|
15 |
+
|
16 |
+
|
17 |
+
if __name__ == "__main__":
|
18 |
+
parser = argparse.ArgumentParser()
|
19 |
+
parser.add_argument("--in-checkpoint", type=str, help="Path to the model")
|
20 |
+
parser.add_argument("--out-checkpoint", type=str, help="Path to the output model")
|
21 |
+
args = parser.parse_args()
|
22 |
+
|
23 |
+
convert_fp16(args.in_checkpoint, args.out_checkpoint)
|
convert/make_delta.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from tqdm import tqdm
|
6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
7 |
+
|
8 |
+
|
9 |
+
def make_delta(base_model_path, target_model_path, delta_path):
|
10 |
+
print(f"Loading the base model from {base_model_path}")
|
11 |
+
base = AutoModelForCausalLM.from_pretrained(
|
12 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
13 |
+
)
|
14 |
+
|
15 |
+
print(f"Loading the target model from {target_model_path}")
|
16 |
+
target = AutoModelForCausalLM.from_pretrained(
|
17 |
+
target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
18 |
+
)
|
19 |
+
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False)
|
20 |
+
|
21 |
+
print("Calculating the delta")
|
22 |
+
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
|
23 |
+
assert name in base.state_dict()
|
24 |
+
if param.shape == base.state_dict()[name].shape:
|
25 |
+
param.data -= base.state_dict()[name]
|
26 |
+
else:
|
27 |
+
print(name)
|
28 |
+
|
29 |
+
print(f"Saving the delta to {delta_path}")
|
30 |
+
if args.hub_repo_id:
|
31 |
+
kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id}
|
32 |
+
else:
|
33 |
+
kwargs = {}
|
34 |
+
target.save_pretrained(delta_path, **kwargs)
|
35 |
+
target_tokenizer.save_pretrained(delta_path, **kwargs)
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == "__main__":
|
39 |
+
parser = argparse.ArgumentParser()
|
40 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
41 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
42 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
43 |
+
parser.add_argument("--hub-repo-id", type=str)
|
44 |
+
args = parser.parse_args()
|
45 |
+
|
46 |
+
make_delta(args.base_model_path, args.target_model_path, args.delta_path)
|
convert/merge_delta.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import argparse
|
3 |
+
import gc
|
4 |
+
import glob
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import shutil
|
8 |
+
import tempfile
|
9 |
+
|
10 |
+
from huggingface_hub import snapshot_download
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
from tqdm import tqdm
|
14 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
15 |
+
|
16 |
+
|
17 |
+
GB = 1 << 30
|
18 |
+
|
19 |
+
|
20 |
+
def split_files(model_path, tmp_path, split_size):
|
21 |
+
if not os.path.exists(model_path):
|
22 |
+
model_path = snapshot_download(repo_id=model_path)
|
23 |
+
if not os.path.exists(tmp_path):
|
24 |
+
os.makedirs(tmp_path)
|
25 |
+
|
26 |
+
file_pattern = os.path.join(model_path, "pytorch_model-*.bin")
|
27 |
+
files = glob.glob(file_pattern)
|
28 |
+
|
29 |
+
part = 0
|
30 |
+
try:
|
31 |
+
for file_path in tqdm(files):
|
32 |
+
state_dict = torch.load(file_path)
|
33 |
+
new_state_dict = {}
|
34 |
+
|
35 |
+
current_size = 0
|
36 |
+
for name, param in state_dict.items():
|
37 |
+
param_size = param.numel() * param.element_size()
|
38 |
+
|
39 |
+
if current_size + param_size > split_size:
|
40 |
+
new_file_name = f"pytorch_model-{part}.bin"
|
41 |
+
new_file_path = os.path.join(tmp_path, new_file_name)
|
42 |
+
torch.save(new_state_dict, new_file_path)
|
43 |
+
current_size = 0
|
44 |
+
new_state_dict = None
|
45 |
+
gc.collect()
|
46 |
+
new_state_dict = {}
|
47 |
+
part += 1
|
48 |
+
|
49 |
+
new_state_dict[name] = param
|
50 |
+
current_size += param_size
|
51 |
+
|
52 |
+
new_file_name = f"pytorch_model-{part}.bin"
|
53 |
+
new_file_path = os.path.join(tmp_path, new_file_name)
|
54 |
+
torch.save(new_state_dict, new_file_path)
|
55 |
+
new_state_dict = None
|
56 |
+
gc.collect()
|
57 |
+
new_state_dict = {}
|
58 |
+
part += 1
|
59 |
+
except Exception as e:
|
60 |
+
print(f"An error occurred during split_files: {e}")
|
61 |
+
shutil.rmtree(tmp_path)
|
62 |
+
raise
|
63 |
+
|
64 |
+
|
65 |
+
def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path):
|
66 |
+
delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
|
67 |
+
delta_config = AutoConfig.from_pretrained(delta_path)
|
68 |
+
|
69 |
+
if os.path.exists(target_model_path):
|
70 |
+
shutil.rmtree(target_model_path)
|
71 |
+
os.makedirs(target_model_path)
|
72 |
+
|
73 |
+
split_size = 4 * GB
|
74 |
+
|
75 |
+
with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path:
|
76 |
+
print(f"Split files for the base model to {tmp_base_path}")
|
77 |
+
split_files(base_model_path, tmp_base_path, split_size)
|
78 |
+
print(f"Split files for the delta weights to {tmp_delta_path}")
|
79 |
+
split_files(delta_path, tmp_delta_path, split_size)
|
80 |
+
|
81 |
+
base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin")
|
82 |
+
base_files = glob.glob(base_pattern)
|
83 |
+
base_state_dict = torch.load(base_files[0])
|
84 |
+
delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin")
|
85 |
+
delta_files = glob.glob(delta_pattern)
|
86 |
+
# delta_state_dict = torch.load(delta_files[0])
|
87 |
+
|
88 |
+
print("Applying the delta")
|
89 |
+
weight_map = {}
|
90 |
+
total_size = 0
|
91 |
+
|
92 |
+
for i, delta_file in tqdm(enumerate(delta_files)):
|
93 |
+
state_dict = torch.load(delta_file)
|
94 |
+
file_name = f"pytorch_model-{i}.bin"
|
95 |
+
for name, param in state_dict.items():
|
96 |
+
if name not in base_state_dict:
|
97 |
+
for base_file in base_files:
|
98 |
+
base_state_dict = torch.load(base_file)
|
99 |
+
gc.collect()
|
100 |
+
if name in base_state_dict:
|
101 |
+
break
|
102 |
+
if state_dict[name].shape == base_state_dict[name].shape:
|
103 |
+
state_dict[name] += base_state_dict[name]
|
104 |
+
else:
|
105 |
+
print(name)
|
106 |
+
weight_map[name] = file_name
|
107 |
+
total_size += param.numel() * param.element_size()
|
108 |
+
gc.collect()
|
109 |
+
torch.save(state_dict, os.path.join(target_model_path, file_name))
|
110 |
+
|
111 |
+
with open(
|
112 |
+
os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w"
|
113 |
+
) as f:
|
114 |
+
json.dump(
|
115 |
+
{"weight_map": weight_map, "metadata": {"total_size": total_size}}, f
|
116 |
+
)
|
117 |
+
|
118 |
+
print(f"Saving the target model to {target_model_path}")
|
119 |
+
delta_tokenizer.save_pretrained(target_model_path)
|
120 |
+
delta_config.save_pretrained(target_model_path)
|
121 |
+
|
122 |
+
|
123 |
+
def apply_delta(base_model_path, target_model_path, delta_path):
|
124 |
+
print(f"Loading the delta weights from {delta_path}")
|
125 |
+
delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
|
126 |
+
delta = AutoModelForCausalLM.from_pretrained(
|
127 |
+
delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
128 |
+
)
|
129 |
+
|
130 |
+
print(f"Loading the base model from {base_model_path}")
|
131 |
+
base = AutoModelForCausalLM.from_pretrained(
|
132 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
133 |
+
)
|
134 |
+
|
135 |
+
print("Applying the delta")
|
136 |
+
for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
|
137 |
+
assert name in base.state_dict()
|
138 |
+
if param.shape == base.state_dict()[name].shape:
|
139 |
+
param.data += base.state_dict()[name]
|
140 |
+
else:
|
141 |
+
print(name)
|
142 |
+
|
143 |
+
|
144 |
+
print(f"Saving the target model to {target_model_path}")
|
145 |
+
delta.save_pretrained(target_model_path)
|
146 |
+
delta_tokenizer.save_pretrained(target_model_path)
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == "__main__":
|
150 |
+
parser = argparse.ArgumentParser()
|
151 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
152 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
153 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
154 |
+
parser.add_argument(
|
155 |
+
"--low-cpu-mem",
|
156 |
+
action="store_true",
|
157 |
+
help="Lower the cpu memory usage. This will split large files and use "
|
158 |
+
"disk as swap to reduce the memory usage below 10GB.",
|
159 |
+
)
|
160 |
+
args = parser.parse_args()
|
161 |
+
|
162 |
+
if args.low_cpu_mem:
|
163 |
+
apply_delta_low_cpu_mem(
|
164 |
+
args.base_model_path, args.target_model_path, args.delta_path
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
|
convert/zero_to_fp32.py
ADDED
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright (c) Microsoft Corporation.
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
# DeepSpeed Team
|
7 |
+
|
8 |
+
# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
|
9 |
+
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
|
10 |
+
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
|
11 |
+
# application.
|
12 |
+
#
|
13 |
+
# example: python zero_to_fp32.py . pytorch_model.bin
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import torch
|
17 |
+
import glob
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import re
|
21 |
+
from collections import OrderedDict
|
22 |
+
from dataclasses import dataclass
|
23 |
+
from tqdm import tqdm
|
24 |
+
|
25 |
+
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
|
26 |
+
# DeepSpeed data structures it has to be available in the current python environment.
|
27 |
+
from deepspeed.utils import logger
|
28 |
+
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
|
29 |
+
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
|
30 |
+
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
|
31 |
+
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class zero_model_state:
|
35 |
+
buffers: dict()
|
36 |
+
param_shapes: dict()
|
37 |
+
shared_params: list
|
38 |
+
ds_version: int
|
39 |
+
frozen_param_shapes: dict()
|
40 |
+
frozen_param_fragments: dict()
|
41 |
+
|
42 |
+
|
43 |
+
debug = 0
|
44 |
+
|
45 |
+
# load to cpu
|
46 |
+
device = torch.device('cpu')
|
47 |
+
|
48 |
+
|
49 |
+
def atoi(text):
|
50 |
+
return int(text) if text.isdigit() else text
|
51 |
+
|
52 |
+
|
53 |
+
def natural_keys(text):
|
54 |
+
'''
|
55 |
+
alist.sort(key=natural_keys) sorts in human order
|
56 |
+
http://nedbatchelder.com/blog/200712/human_sorting.html
|
57 |
+
(See Toothy's implementation in the comments)
|
58 |
+
'''
|
59 |
+
return [atoi(c) for c in re.split(r'(\d+)', text)]
|
60 |
+
|
61 |
+
|
62 |
+
def get_model_state_file(checkpoint_dir, zero_stage):
|
63 |
+
if not os.path.isdir(checkpoint_dir):
|
64 |
+
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
|
65 |
+
|
66 |
+
# there should be only one file
|
67 |
+
if zero_stage == 2:
|
68 |
+
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
|
69 |
+
elif zero_stage == 3:
|
70 |
+
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
|
71 |
+
|
72 |
+
if not os.path.exists(file):
|
73 |
+
raise FileNotFoundError(f"can't find model states file at '{file}'")
|
74 |
+
|
75 |
+
return file
|
76 |
+
|
77 |
+
|
78 |
+
def get_checkpoint_files(checkpoint_dir, glob_pattern):
|
79 |
+
# XXX: need to test that this simple glob rule works for multi-node setup too
|
80 |
+
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
|
81 |
+
|
82 |
+
if len(ckpt_files) == 0:
|
83 |
+
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
|
84 |
+
|
85 |
+
return ckpt_files
|
86 |
+
|
87 |
+
|
88 |
+
def get_optim_files(checkpoint_dir):
|
89 |
+
return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
|
90 |
+
|
91 |
+
|
92 |
+
def get_model_state_files(checkpoint_dir):
|
93 |
+
return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
|
94 |
+
|
95 |
+
|
96 |
+
def parse_model_states(files):
|
97 |
+
zero_model_states = []
|
98 |
+
for file in files:
|
99 |
+
state_dict = torch.load(file, map_location=device)
|
100 |
+
|
101 |
+
if BUFFER_NAMES not in state_dict:
|
102 |
+
raise ValueError(f"{file} is not a model state checkpoint")
|
103 |
+
buffer_names = state_dict[BUFFER_NAMES]
|
104 |
+
if debug:
|
105 |
+
print("Found buffers:", buffer_names)
|
106 |
+
|
107 |
+
# recover just the buffers while restoring them to fp32 if they were saved in fp16
|
108 |
+
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
|
109 |
+
param_shapes = state_dict[PARAM_SHAPES]
|
110 |
+
|
111 |
+
# collect parameters that are included in param_shapes
|
112 |
+
param_names = []
|
113 |
+
for s in param_shapes:
|
114 |
+
for name in s.keys():
|
115 |
+
param_names.append(name)
|
116 |
+
|
117 |
+
# update with frozen parameters
|
118 |
+
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
|
119 |
+
if frozen_param_shapes is not None:
|
120 |
+
if debug:
|
121 |
+
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
|
122 |
+
param_names += list(frozen_param_shapes.keys())
|
123 |
+
|
124 |
+
# record shared parameters so that they can be recovered based on partners
|
125 |
+
# this is because such parameters holding reference only are not saved by optimizer
|
126 |
+
shared_params = []
|
127 |
+
for param in state_dict["module"]:
|
128 |
+
if param not in [*param_names, *buffer_names]:
|
129 |
+
for share_param in state_dict["module"]:
|
130 |
+
if (state_dict["module"][share_param].data_ptr() == state_dict["module"][param].data_ptr()
|
131 |
+
and share_param != param):
|
132 |
+
shared_params.append([param, share_param])
|
133 |
+
break
|
134 |
+
|
135 |
+
ds_version = state_dict.get(DS_VERSION, None)
|
136 |
+
|
137 |
+
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
|
138 |
+
|
139 |
+
z_model_state = zero_model_state(buffers=buffers,
|
140 |
+
param_shapes=param_shapes,
|
141 |
+
shared_params=shared_params,
|
142 |
+
ds_version=ds_version,
|
143 |
+
frozen_param_shapes=frozen_param_shapes,
|
144 |
+
frozen_param_fragments=frozen_param_fragments)
|
145 |
+
zero_model_states.append(z_model_state)
|
146 |
+
|
147 |
+
return zero_model_states
|
148 |
+
|
149 |
+
|
150 |
+
def parse_optim_states(files, ds_checkpoint_dir):
|
151 |
+
|
152 |
+
total_files = len(files)
|
153 |
+
state_dicts = []
|
154 |
+
for i, f in enumerate(tqdm(files)):
|
155 |
+
state_dicts.append(torch.load(f, map_location=device))
|
156 |
+
if i == 0:
|
157 |
+
if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
|
158 |
+
raise ValueError(f"{files[0]} is not a zero checkpoint")
|
159 |
+
zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
|
160 |
+
world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
|
161 |
+
|
162 |
+
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
|
163 |
+
# parameters can be different from data parallelism for non-expert parameters. So we can just
|
164 |
+
# use the max of the partition_count to get the dp world_size.
|
165 |
+
|
166 |
+
if type(world_size) is list:
|
167 |
+
world_size = max(world_size)
|
168 |
+
|
169 |
+
if world_size != total_files:
|
170 |
+
raise ValueError(
|
171 |
+
f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
|
172 |
+
"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
|
173 |
+
)
|
174 |
+
|
175 |
+
# the groups are named differently in each stage
|
176 |
+
if zero_stage == 2:
|
177 |
+
fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
|
178 |
+
elif zero_stage == 3:
|
179 |
+
fp32_groups_key = FP32_FLAT_GROUPS
|
180 |
+
else:
|
181 |
+
raise ValueError(f"unknown zero stage {zero_stage}")
|
182 |
+
|
183 |
+
key_list = list(state_dicts[-1][OPTIMIZER_STATE_DICT].keys())
|
184 |
+
for key in key_list:
|
185 |
+
if zero_stage == 2:
|
186 |
+
if key != fp32_groups_key:
|
187 |
+
del state_dicts[-1][OPTIMIZER_STATE_DICT][key]
|
188 |
+
elif zero_stage == 3:
|
189 |
+
if key == fp32_groups_key:
|
190 |
+
value = torch.cat(state_dicts[-1][OPTIMIZER_STATE_DICT][fp32_groups_key], 0)
|
191 |
+
del state_dicts[-1][OPTIMIZER_STATE_DICT][key]
|
192 |
+
if key == fp32_groups_key:
|
193 |
+
state_dicts[-1][OPTIMIZER_STATE_DICT][key] = value
|
194 |
+
|
195 |
+
print('zero_stage:', zero_stage)
|
196 |
+
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
|
197 |
+
# if zero_stage == 2:
|
198 |
+
# # fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
|
199 |
+
# elif zero_stage == 3:
|
200 |
+
# # if there is more than one param group, there will be multiple flattened tensors - one
|
201 |
+
# # flattened tensor per group - for simplicity merge them into a single tensor
|
202 |
+
# #
|
203 |
+
# # XXX: could make the script more memory efficient for when there are multiple groups - it
|
204 |
+
# # will require matching the sub-lists of param_shapes for each param group flattened tensor
|
205 |
+
|
206 |
+
# print('start!')
|
207 |
+
# # fp32_flat_groups = [
|
208 |
+
# # torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
|
209 |
+
# # ]
|
210 |
+
|
211 |
+
return zero_stage, world_size, fp32_flat_groups
|
212 |
+
|
213 |
+
|
214 |
+
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
|
215 |
+
"""
|
216 |
+
Returns fp32 state_dict reconstructed from ds checkpoint
|
217 |
+
|
218 |
+
Args:
|
219 |
+
- ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
|
220 |
+
|
221 |
+
"""
|
222 |
+
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
|
223 |
+
|
224 |
+
optim_files = get_optim_files(ds_checkpoint_dir)
|
225 |
+
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
|
226 |
+
print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
|
227 |
+
|
228 |
+
model_files = get_model_state_files(ds_checkpoint_dir)
|
229 |
+
|
230 |
+
zero_model_states = parse_model_states(model_files)
|
231 |
+
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
|
232 |
+
|
233 |
+
if zero_stage == 2:
|
234 |
+
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
|
235 |
+
elif zero_stage == 3:
|
236 |
+
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
|
237 |
+
|
238 |
+
|
239 |
+
def _zero2_merge_frozen_params(state_dict, zero_model_states):
|
240 |
+
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
241 |
+
return
|
242 |
+
|
243 |
+
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
244 |
+
frozen_param_fragments = zero_model_states[0].frozen_param_fragments
|
245 |
+
|
246 |
+
if debug:
|
247 |
+
num_elem = sum(s.numel() for s in frozen_param_shapes.values())
|
248 |
+
print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
249 |
+
|
250 |
+
wanted_params = len(frozen_param_shapes)
|
251 |
+
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
252 |
+
avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
|
253 |
+
print(f'Frozen params: Have {avail_numel} numels to process.')
|
254 |
+
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
255 |
+
|
256 |
+
total_params = 0
|
257 |
+
total_numel = 0
|
258 |
+
for name, shape in frozen_param_shapes.items():
|
259 |
+
total_params += 1
|
260 |
+
unpartitioned_numel = shape.numel()
|
261 |
+
total_numel += unpartitioned_numel
|
262 |
+
|
263 |
+
state_dict[name] = frozen_param_fragments[name]
|
264 |
+
|
265 |
+
if debug:
|
266 |
+
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
267 |
+
|
268 |
+
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
269 |
+
|
270 |
+
|
271 |
+
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
272 |
+
param_shapes = zero_model_states[0].param_shapes
|
273 |
+
|
274 |
+
# Reconstruction protocol:
|
275 |
+
#
|
276 |
+
# XXX: document this
|
277 |
+
|
278 |
+
if debug:
|
279 |
+
for i in range(world_size):
|
280 |
+
for j in range(len(fp32_flat_groups[0])):
|
281 |
+
print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
|
282 |
+
|
283 |
+
# XXX: memory usage doubles here (zero2)
|
284 |
+
num_param_groups = len(fp32_flat_groups[0])
|
285 |
+
merged_single_partition_of_fp32_groups = []
|
286 |
+
for i in range(num_param_groups):
|
287 |
+
merged_partitions = [sd[i] for sd in fp32_flat_groups]
|
288 |
+
full_single_fp32_vector = torch.cat(merged_partitions, 0)
|
289 |
+
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
|
290 |
+
avail_numel = sum(
|
291 |
+
[full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
|
292 |
+
|
293 |
+
if debug:
|
294 |
+
wanted_params = sum([len(shapes) for shapes in param_shapes])
|
295 |
+
wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
|
296 |
+
# not asserting if there is a mismatch due to possible padding
|
297 |
+
print(f"Have {avail_numel} numels to process.")
|
298 |
+
print(f"Need {wanted_numel} numels in {wanted_params} params.")
|
299 |
+
|
300 |
+
# params
|
301 |
+
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
302 |
+
# out-of-core computing solution
|
303 |
+
total_numel = 0
|
304 |
+
total_params = 0
|
305 |
+
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
|
306 |
+
offset = 0
|
307 |
+
avail_numel = full_single_fp32_vector.numel()
|
308 |
+
for name, shape in shapes.items():
|
309 |
+
|
310 |
+
unpartitioned_numel = shape.numel()
|
311 |
+
total_numel += unpartitioned_numel
|
312 |
+
total_params += 1
|
313 |
+
|
314 |
+
if debug:
|
315 |
+
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
316 |
+
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
|
317 |
+
offset += unpartitioned_numel
|
318 |
+
|
319 |
+
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
|
320 |
+
# avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
|
321 |
+
# paddings performed in the code it's almost impossible to predict the exact numbers w/o the
|
322 |
+
# live optimizer object, so we are checking that the numbers are within the right range
|
323 |
+
align_to = 2 * world_size
|
324 |
+
|
325 |
+
def zero2_align(x):
|
326 |
+
return align_to * math.ceil(x / align_to)
|
327 |
+
|
328 |
+
if debug:
|
329 |
+
print(f"original offset={offset}, avail_numel={avail_numel}")
|
330 |
+
|
331 |
+
offset = zero2_align(offset)
|
332 |
+
avail_numel = zero2_align(avail_numel)
|
333 |
+
|
334 |
+
if debug:
|
335 |
+
print(f"aligned offset={offset}, avail_numel={avail_numel}")
|
336 |
+
|
337 |
+
# Sanity check
|
338 |
+
if offset != avail_numel:
|
339 |
+
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
340 |
+
|
341 |
+
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
|
342 |
+
|
343 |
+
|
344 |
+
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
|
345 |
+
state_dict = OrderedDict()
|
346 |
+
|
347 |
+
# buffers
|
348 |
+
buffers = zero_model_states[0].buffers
|
349 |
+
state_dict.update(buffers)
|
350 |
+
if debug:
|
351 |
+
print(f"added {len(buffers)} buffers")
|
352 |
+
|
353 |
+
_zero2_merge_frozen_params(state_dict, zero_model_states)
|
354 |
+
|
355 |
+
_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
356 |
+
|
357 |
+
# recover shared parameters
|
358 |
+
for pair in zero_model_states[0].shared_params:
|
359 |
+
state_dict[pair[0]] = state_dict[pair[1]]
|
360 |
+
|
361 |
+
return state_dict
|
362 |
+
|
363 |
+
|
364 |
+
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
|
365 |
+
remainder = unpartitioned_numel % world_size
|
366 |
+
padding_numel = (world_size - remainder) if remainder else 0
|
367 |
+
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
|
368 |
+
return partitioned_numel, padding_numel
|
369 |
+
|
370 |
+
|
371 |
+
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
|
372 |
+
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
373 |
+
return
|
374 |
+
|
375 |
+
if debug:
|
376 |
+
for i in range(world_size):
|
377 |
+
num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
|
378 |
+
print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
379 |
+
|
380 |
+
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
381 |
+
wanted_params = len(frozen_param_shapes)
|
382 |
+
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
383 |
+
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
|
384 |
+
print(f'Frozen params: Have {avail_numel} numels to process.')
|
385 |
+
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
386 |
+
|
387 |
+
total_params = 0
|
388 |
+
total_numel = 0
|
389 |
+
for name, shape in tqdm(zero_model_states[0].frozen_param_shapes.items()):
|
390 |
+
total_params += 1
|
391 |
+
unpartitioned_numel = shape.numel()
|
392 |
+
total_numel += unpartitioned_numel
|
393 |
+
|
394 |
+
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
|
395 |
+
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
|
396 |
+
|
397 |
+
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
398 |
+
|
399 |
+
if debug:
|
400 |
+
print(
|
401 |
+
f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
402 |
+
)
|
403 |
+
|
404 |
+
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
405 |
+
|
406 |
+
|
407 |
+
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
408 |
+
param_shapes = zero_model_states[0].param_shapes
|
409 |
+
avail_numel = fp32_flat_groups[0].numel() * world_size
|
410 |
+
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
|
411 |
+
# param, re-consolidating each param, while dealing with padding if any
|
412 |
+
|
413 |
+
# merge list of dicts, preserving order
|
414 |
+
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
|
415 |
+
|
416 |
+
if debug:
|
417 |
+
for i in range(world_size):
|
418 |
+
print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
|
419 |
+
|
420 |
+
wanted_params = len(param_shapes)
|
421 |
+
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
|
422 |
+
# not asserting if there is a mismatch due to possible padding
|
423 |
+
avail_numel = fp32_flat_groups[0].numel() * world_size
|
424 |
+
print(f"Trainable params: Have {avail_numel} numels to process.")
|
425 |
+
print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
|
426 |
+
|
427 |
+
# params
|
428 |
+
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
429 |
+
# out-of-core computing solution
|
430 |
+
offset = 0
|
431 |
+
total_numel = 0
|
432 |
+
total_params = 0
|
433 |
+
for name, shape in tqdm(param_shapes.items()):
|
434 |
+
|
435 |
+
unpartitioned_numel = shape.numel()
|
436 |
+
total_numel += unpartitioned_numel
|
437 |
+
total_params += 1
|
438 |
+
|
439 |
+
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
440 |
+
|
441 |
+
if debug:
|
442 |
+
print(
|
443 |
+
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
444 |
+
)
|
445 |
+
|
446 |
+
# XXX: memory usage doubles here
|
447 |
+
state_dict[name] = torch.cat(
|
448 |
+
tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
|
449 |
+
0).narrow(0, 0, unpartitioned_numel).view(shape)
|
450 |
+
offset += partitioned_numel
|
451 |
+
|
452 |
+
offset *= world_size
|
453 |
+
|
454 |
+
# Sanity check
|
455 |
+
if offset != avail_numel:
|
456 |
+
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
457 |
+
|
458 |
+
print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
|
459 |
+
|
460 |
+
|
461 |
+
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
|
462 |
+
state_dict = OrderedDict()
|
463 |
+
|
464 |
+
# buffers
|
465 |
+
buffers = zero_model_states[0].buffers
|
466 |
+
state_dict.update(buffers)
|
467 |
+
if debug:
|
468 |
+
print(f"added {len(buffers)} buffers")
|
469 |
+
|
470 |
+
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
|
471 |
+
|
472 |
+
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
473 |
+
|
474 |
+
# recover shared parameters
|
475 |
+
for pair in zero_model_states[0].shared_params:
|
476 |
+
state_dict[pair[0]] = state_dict[pair[1]]
|
477 |
+
|
478 |
+
return state_dict
|
479 |
+
|
480 |
+
|
481 |
+
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
|
482 |
+
"""
|
483 |
+
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
|
484 |
+
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
|
485 |
+
via a model hub.
|
486 |
+
|
487 |
+
Args:
|
488 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder
|
489 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
|
490 |
+
|
491 |
+
Returns:
|
492 |
+
- pytorch ``state_dict``
|
493 |
+
|
494 |
+
Note: this approach may not work if your application doesn't have sufficient free CPU memory and
|
495 |
+
you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
|
496 |
+
the checkpoint.
|
497 |
+
|
498 |
+
A typical usage might be ::
|
499 |
+
|
500 |
+
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
501 |
+
# do the training and checkpoint saving
|
502 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
|
503 |
+
model = model.cpu() # move to cpu
|
504 |
+
model.load_state_dict(state_dict)
|
505 |
+
# submit to model hub or save the model to share with others
|
506 |
+
|
507 |
+
In this example the ``model`` will no longer be usable in the deepspeed context of the same
|
508 |
+
application. i.e. you will need to re-initialize the deepspeed engine, since
|
509 |
+
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
510 |
+
|
511 |
+
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
|
512 |
+
|
513 |
+
"""
|
514 |
+
if tag is None:
|
515 |
+
latest_path = os.path.join(checkpoint_dir, 'latest')
|
516 |
+
if os.path.isfile(latest_path):
|
517 |
+
with open(latest_path, 'r') as fd:
|
518 |
+
tag = fd.read().strip()
|
519 |
+
else:
|
520 |
+
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
|
521 |
+
|
522 |
+
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
|
523 |
+
|
524 |
+
if not os.path.isdir(ds_checkpoint_dir):
|
525 |
+
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
|
526 |
+
|
527 |
+
return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
|
528 |
+
|
529 |
+
|
530 |
+
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
|
531 |
+
"""
|
532 |
+
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
|
533 |
+
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
|
534 |
+
|
535 |
+
Args:
|
536 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
537 |
+
- ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
|
538 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
539 |
+
"""
|
540 |
+
|
541 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
542 |
+
print(f"Saving fp32 state dict to {output_file}")
|
543 |
+
torch.save(state_dict, output_file)
|
544 |
+
|
545 |
+
|
546 |
+
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
|
547 |
+
"""
|
548 |
+
1. Put the provided model to cpu
|
549 |
+
2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
|
550 |
+
3. Load it into the provided model
|
551 |
+
|
552 |
+
Args:
|
553 |
+
- ``model``: the model object to update
|
554 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
555 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
556 |
+
|
557 |
+
Returns:
|
558 |
+
- ``model`: modified model
|
559 |
+
|
560 |
+
Make sure you have plenty of CPU memory available before you call this function. If you don't
|
561 |
+
have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
|
562 |
+
conveniently placed for you in the checkpoint folder.
|
563 |
+
|
564 |
+
A typical usage might be ::
|
565 |
+
|
566 |
+
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
|
567 |
+
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
|
568 |
+
# submit to model hub or save the model to share with others
|
569 |
+
|
570 |
+
Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
|
571 |
+
of the same application. i.e. you will need to re-initialize the deepspeed engine, since
|
572 |
+
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
573 |
+
|
574 |
+
"""
|
575 |
+
logger.info(f"Extracting fp32 weights")
|
576 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
577 |
+
|
578 |
+
logger.info(f"Overwriting model with fp32 weights")
|
579 |
+
model = model.cpu()
|
580 |
+
model.load_state_dict(state_dict, strict=False)
|
581 |
+
|
582 |
+
return model
|
583 |
+
|
584 |
+
|
585 |
+
if __name__ == "__main__":
|
586 |
+
|
587 |
+
parser = argparse.ArgumentParser()
|
588 |
+
parser.add_argument("checkpoint_dir",
|
589 |
+
type=str,
|
590 |
+
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
|
591 |
+
parser.add_argument(
|
592 |
+
"output_file",
|
593 |
+
type=str,
|
594 |
+
help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
|
595 |
+
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
|
596 |
+
args = parser.parse_args()
|
597 |
+
|
598 |
+
debug = args.debug
|
599 |
+
|
600 |
+
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
|
data.py
ADDED
@@ -0,0 +1,844 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import random
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from tqdm import tqdm
|
9 |
+
from collections import defaultdict
|
10 |
+
import torch.distributed as dist
|
11 |
+
import logging
|
12 |
+
import re
|
13 |
+
import pdb
|
14 |
+
import json
|
15 |
+
from prompt import sft_prompt, all_prompt
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
|
19 |
+
class BaseDataset(Dataset):
|
20 |
+
|
21 |
+
def __init__(self, args):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
self.args = args
|
25 |
+
self.dataset = args.dataset
|
26 |
+
self.data_path = os.path.join(args.data_path, self.dataset)
|
27 |
+
|
28 |
+
self.max_his_len = args.max_his_len
|
29 |
+
self.his_sep = args.his_sep
|
30 |
+
self.index_file = args.index_file
|
31 |
+
self.add_prefix = args.add_prefix
|
32 |
+
|
33 |
+
self.new_tokens = None
|
34 |
+
self.allowed_tokens = None
|
35 |
+
self.all_items = None
|
36 |
+
|
37 |
+
|
38 |
+
def _load_data(self):
|
39 |
+
|
40 |
+
with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
|
41 |
+
self.indices = json.load(f)
|
42 |
+
|
43 |
+
def get_new_tokens(self):
|
44 |
+
|
45 |
+
if self.new_tokens is not None:
|
46 |
+
return self.new_tokens
|
47 |
+
|
48 |
+
self.new_tokens = set()
|
49 |
+
for index in self.indices.values():
|
50 |
+
for token in index:
|
51 |
+
self.new_tokens.add(token)
|
52 |
+
self.new_tokens = sorted(list(self.new_tokens))
|
53 |
+
|
54 |
+
return self.new_tokens
|
55 |
+
|
56 |
+
def get_all_items(self):
|
57 |
+
|
58 |
+
if self.all_items is not None:
|
59 |
+
return self.all_items
|
60 |
+
|
61 |
+
self.all_items = set()
|
62 |
+
for index in self.indices.values():
|
63 |
+
self.all_items.add("".join(index))
|
64 |
+
|
65 |
+
return self.all_items
|
66 |
+
|
67 |
+
def get_prefix_allowed_tokens_fn(self, tokenizer):
|
68 |
+
|
69 |
+
|
70 |
+
if self.allowed_tokens is None:
|
71 |
+
self.allowed_tokens = {}
|
72 |
+
for index in self.indices.values():
|
73 |
+
for i, token in enumerate(index):
|
74 |
+
token_id = tokenizer(token)["input_ids"][1]
|
75 |
+
if i not in self.allowed_tokens.keys():
|
76 |
+
self.allowed_tokens[i] = set()
|
77 |
+
self.allowed_tokens[i].add(token_id)
|
78 |
+
self.allowed_tokens[len(self.allowed_tokens.keys())] = set([tokenizer.eos_token_id])
|
79 |
+
sep = tokenizer("Response:")["input_ids"][1:]
|
80 |
+
|
81 |
+
def prefix_allowed_tokens_fn(batch_id, sentence):
|
82 |
+
sentence = sentence.tolist()
|
83 |
+
reversed_sent = sentence[::-1]
|
84 |
+
for i in range(len(reversed_sent)):
|
85 |
+
if reversed_sent[i:i + len(sep)] == sep[::-1]:
|
86 |
+
# print(list(self.allowed_tokens[i]))
|
87 |
+
return list(self.allowed_tokens[i])
|
88 |
+
|
89 |
+
return prefix_allowed_tokens_fn
|
90 |
+
|
91 |
+
def _process_data(self):
|
92 |
+
|
93 |
+
raise NotImplementedError
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
class SeqRecDataset(BaseDataset):
|
98 |
+
|
99 |
+
def __init__(self, args, mode="train",
|
100 |
+
prompt_sample_num=1, prompt_id=0, sample_num=-1):
|
101 |
+
super().__init__(args)
|
102 |
+
|
103 |
+
self.mode = mode
|
104 |
+
self.prompt_sample_num = prompt_sample_num
|
105 |
+
self.prompt_id = prompt_id
|
106 |
+
self.sample_num = sample_num
|
107 |
+
|
108 |
+
self.prompts = all_prompt["seqrec"]
|
109 |
+
|
110 |
+
|
111 |
+
# load data
|
112 |
+
self._load_data()
|
113 |
+
self._remap_items()
|
114 |
+
|
115 |
+
# load data
|
116 |
+
if self.mode == 'train':
|
117 |
+
self.inter_data = self._process_train_data()
|
118 |
+
elif self.mode == 'valid':
|
119 |
+
self.sample_valid = args.sample_valid
|
120 |
+
self.valid_prompt_id = args.valid_prompt_id
|
121 |
+
self.inter_data = self._process_valid_data()
|
122 |
+
self._construct_valid_text()
|
123 |
+
elif self.mode == 'test':
|
124 |
+
self.inter_data = self._process_test_data()
|
125 |
+
else:
|
126 |
+
raise NotImplementedError
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
def _load_data(self):
|
131 |
+
|
132 |
+
with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
|
133 |
+
self.inters = json.load(f)
|
134 |
+
with open(self.index_file, 'r') as f:
|
135 |
+
self.indices = json.load(f)
|
136 |
+
|
137 |
+
|
138 |
+
def _remap_items(self):
|
139 |
+
|
140 |
+
self.remapped_inters = dict()
|
141 |
+
for uid, items in self.inters.items():
|
142 |
+
new_items = ["".join(self.indices[str(i)]) for i in items]
|
143 |
+
self.remapped_inters[uid] = new_items
|
144 |
+
|
145 |
+
|
146 |
+
def _process_train_data(self):
|
147 |
+
|
148 |
+
inter_data = []
|
149 |
+
for uid in self.remapped_inters:
|
150 |
+
items = self.remapped_inters[uid][:-2]
|
151 |
+
for i in range(1, len(items)):
|
152 |
+
one_data = dict()
|
153 |
+
# one_data["user"] = uid
|
154 |
+
one_data["item"] = items[i]
|
155 |
+
history = items[:i]
|
156 |
+
if self.max_his_len > 0:
|
157 |
+
history = history[-self.max_his_len:]
|
158 |
+
if self.add_prefix:
|
159 |
+
history = [str(k+1) + ". " + item_idx for k, item_idx in enumerate(history)]
|
160 |
+
one_data["inters"] = self.his_sep.join(history)
|
161 |
+
inter_data.append(one_data)
|
162 |
+
|
163 |
+
return inter_data
|
164 |
+
|
165 |
+
def _process_valid_data(self):
|
166 |
+
|
167 |
+
inter_data = []
|
168 |
+
for uid in self.remapped_inters:
|
169 |
+
items = self.remapped_inters[uid]
|
170 |
+
one_data = dict()
|
171 |
+
# one_data["user"] = uid
|
172 |
+
one_data["item"] = items[-2]
|
173 |
+
history = items[:-2]
|
174 |
+
if self.max_his_len > 0:
|
175 |
+
history = history[-self.max_his_len:]
|
176 |
+
if self.add_prefix:
|
177 |
+
history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
|
178 |
+
one_data["inters"] = self.his_sep.join(history)
|
179 |
+
inter_data.append(one_data)
|
180 |
+
|
181 |
+
return inter_data
|
182 |
+
|
183 |
+
def _process_test_data(self):
|
184 |
+
|
185 |
+
inter_data = []
|
186 |
+
for uid in self.remapped_inters:
|
187 |
+
items = self.remapped_inters[uid]
|
188 |
+
one_data = dict()
|
189 |
+
# one_data["user"] = uid
|
190 |
+
one_data["item"] = items[-1]
|
191 |
+
history = items[:-1]
|
192 |
+
if self.max_his_len > 0:
|
193 |
+
history = history[-self.max_his_len:]
|
194 |
+
if self.add_prefix:
|
195 |
+
history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
|
196 |
+
one_data["inters"] = self.his_sep.join(history)
|
197 |
+
inter_data.append(one_data)
|
198 |
+
|
199 |
+
if self.sample_num > 0:
|
200 |
+
all_inter_idx = range(len(inter_data))
|
201 |
+
sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
|
202 |
+
inter_data = np.array(inter_data)[sample_idx].tolist()
|
203 |
+
|
204 |
+
return inter_data
|
205 |
+
|
206 |
+
def set_prompt(self, prompt_id):
|
207 |
+
|
208 |
+
self.prompt_id = prompt_id
|
209 |
+
|
210 |
+
def __len__(self):
|
211 |
+
if self.mode == 'train':
|
212 |
+
return len(self.inter_data) * self.prompt_sample_num
|
213 |
+
elif self.mode == 'valid':
|
214 |
+
return len(self.valid_text_data)
|
215 |
+
elif self.mode == 'test':
|
216 |
+
return len(self.inter_data)
|
217 |
+
else:
|
218 |
+
raise NotImplementedError
|
219 |
+
|
220 |
+
def _construct_valid_text(self):
|
221 |
+
self.valid_text_data = []
|
222 |
+
if self.sample_valid:
|
223 |
+
all_prompt_ids = range(len(self.prompts))
|
224 |
+
for i in range(len(self.inter_data)):
|
225 |
+
d = self.inter_data[i]
|
226 |
+
prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
|
227 |
+
for prompt_id in prompt_ids:
|
228 |
+
prompt = self.prompts[prompt_id]
|
229 |
+
input, output = self._get_text_data(d, prompt)
|
230 |
+
self.valid_text_data.append({"input_ids": input, "labels": output})
|
231 |
+
else:
|
232 |
+
self.prompt_sample_num = 1
|
233 |
+
prompt = self.prompts[self.valid_prompt_id]
|
234 |
+
for i in range(len(self.inter_data)):
|
235 |
+
d = self.inter_data[i]
|
236 |
+
input, output = self._get_text_data(d, prompt)
|
237 |
+
self.valid_text_data.append({"input_ids": input, "labels": output})
|
238 |
+
|
239 |
+
def _get_text_data(self, data, prompt):
|
240 |
+
|
241 |
+
instruction = prompt["instruction"].format(**data)
|
242 |
+
response = prompt["response"].format(**data)
|
243 |
+
|
244 |
+
input = sft_prompt.format(instruction = instruction, response = "")
|
245 |
+
output = sft_prompt.format(instruction = instruction, response = response)
|
246 |
+
|
247 |
+
if self.mode == 'test':
|
248 |
+
return input, response
|
249 |
+
|
250 |
+
return input, output
|
251 |
+
|
252 |
+
def __getitem__(self, index):
|
253 |
+
|
254 |
+
if self.mode == 'valid':
|
255 |
+
return self.valid_text_data[index]
|
256 |
+
|
257 |
+
idx = index // self.prompt_sample_num
|
258 |
+
d = self.inter_data[idx]
|
259 |
+
# print(index, idx)
|
260 |
+
|
261 |
+
if self.mode == 'train':
|
262 |
+
prompt_id = random.randint(0, len(self.prompts) - 1)
|
263 |
+
elif self.mode == 'test':
|
264 |
+
prompt_id = self.prompt_id
|
265 |
+
|
266 |
+
prompt = self.prompts[prompt_id]
|
267 |
+
|
268 |
+
input, output = self._get_text_data(d, prompt)
|
269 |
+
|
270 |
+
# print({"input": input, "output": output})
|
271 |
+
|
272 |
+
return dict(input_ids=input, labels=output)
|
273 |
+
|
274 |
+
|
275 |
+
class FusionSeqRecDataset(BaseDataset):
|
276 |
+
|
277 |
+
def __init__(self, args, mode="train",
|
278 |
+
prompt_sample_num=1, prompt_id=0, sample_num=-1):
|
279 |
+
super().__init__(args)
|
280 |
+
|
281 |
+
self.mode = mode
|
282 |
+
self.prompt_sample_num = prompt_sample_num
|
283 |
+
self.prompt_id = prompt_id
|
284 |
+
self.sample_num = sample_num
|
285 |
+
|
286 |
+
self.prompts = all_prompt["fusionseqrec"]
|
287 |
+
|
288 |
+
# load data
|
289 |
+
self._load_data()
|
290 |
+
# self._remap_items()
|
291 |
+
|
292 |
+
# load data
|
293 |
+
if self.mode == 'train':
|
294 |
+
self.inter_data = self._process_train_data()
|
295 |
+
elif self.mode == 'valid':
|
296 |
+
self.sample_valid = args.sample_valid
|
297 |
+
self.valid_prompt_id = args.valid_prompt_id
|
298 |
+
self.inter_data = self._process_valid_data()
|
299 |
+
self._construct_valid_text()
|
300 |
+
elif self.mode == 'test':
|
301 |
+
self.inter_data = self._process_test_data()
|
302 |
+
else:
|
303 |
+
raise NotImplementedError
|
304 |
+
|
305 |
+
|
306 |
+
def _load_data(self):
|
307 |
+
|
308 |
+
with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
|
309 |
+
self.inters = json.load(f)
|
310 |
+
with open(self.index_file, 'r') as f:
|
311 |
+
self.indices = json.load(f)
|
312 |
+
with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
|
313 |
+
self.item_feat = json.load(f)
|
314 |
+
|
315 |
+
def _process_train_data(self):
|
316 |
+
|
317 |
+
inter_data = []
|
318 |
+
for uid in self.inters:
|
319 |
+
items = self.inters[uid][:-2]
|
320 |
+
for i in range(1, len(items)):
|
321 |
+
one_data = dict()
|
322 |
+
# one_data["user"] = uid
|
323 |
+
one_data["item"] = "".join(self.indices[str(items[i])])
|
324 |
+
one_data["title"] = self.item_feat[str(items[i])]["title"].strip().strip(".!?,;:`")
|
325 |
+
one_data["description"] = self.item_feat[str(items[i])]["description"]
|
326 |
+
history = items[:i]
|
327 |
+
if self.max_his_len > 0:
|
328 |
+
history = history[-self.max_his_len:]
|
329 |
+
inters = ["".join(self.indices[str(j)]) for j in history]
|
330 |
+
inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]
|
331 |
+
|
332 |
+
|
333 |
+
if self.add_prefix:
|
334 |
+
inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
|
335 |
+
inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]
|
336 |
+
|
337 |
+
one_data["inters"] = self.his_sep.join(inters)
|
338 |
+
one_data["inter_titles"] = self.his_sep.join(inter_titles)
|
339 |
+
inter_data.append(one_data)
|
340 |
+
|
341 |
+
if self.sample_num > 0:
|
342 |
+
all_inter_idx = range(len(inter_data))
|
343 |
+
sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
|
344 |
+
inter_data = np.array(inter_data)[sample_idx].tolist()
|
345 |
+
|
346 |
+
return inter_data
|
347 |
+
|
348 |
+
def _process_valid_data(self):
|
349 |
+
|
350 |
+
inter_data = []
|
351 |
+
for uid in self.inters:
|
352 |
+
items = self.inters[uid]
|
353 |
+
one_data = dict()
|
354 |
+
one_data["item"] = "".join(self.indices[str(items[-2])])
|
355 |
+
one_data["title"] = self.item_feat[str(items[-2])]["title"].strip().strip(".!?,;:`")
|
356 |
+
one_data["description"] = self.item_feat[str(items[-2])]["description"]
|
357 |
+
|
358 |
+
|
359 |
+
history = items[:-2]
|
360 |
+
if self.max_his_len > 0:
|
361 |
+
history = history[-self.max_his_len:]
|
362 |
+
inters = ["".join(self.indices[str(j)]) for j in history]
|
363 |
+
inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]
|
364 |
+
|
365 |
+
if self.add_prefix:
|
366 |
+
inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
|
367 |
+
inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]
|
368 |
+
|
369 |
+
one_data["inters"] = self.his_sep.join(inters)
|
370 |
+
one_data["inter_titles"] = self.his_sep.join(inter_titles)
|
371 |
+
inter_data.append(one_data)
|
372 |
+
|
373 |
+
if self.sample_num > 0:
|
374 |
+
all_inter_idx = range(len(inter_data))
|
375 |
+
sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
|
376 |
+
inter_data = np.array(inter_data)[sample_idx].tolist()
|
377 |
+
|
378 |
+
return inter_data
|
379 |
+
|
380 |
+
def _process_test_data(self):
|
381 |
+
|
382 |
+
inter_data = []
|
383 |
+
for uid in self.inters:
|
384 |
+
items = self.inters[uid]
|
385 |
+
one_data = dict()
|
386 |
+
one_data["item"] = "".join(self.indices[str(items[-1])])
|
387 |
+
one_data["title"] = self.item_feat[str(items[-1])]["title"].strip().strip(".!?,;:`")
|
388 |
+
one_data["description"] = self.item_feat[str(items[-1])]["description"]
|
389 |
+
|
390 |
+
history = items[:-1]
|
391 |
+
if self.max_his_len > 0:
|
392 |
+
history = history[-self.max_his_len:]
|
393 |
+
inters = ["".join(self.indices[str(j)]) for j in history]
|
394 |
+
inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]
|
395 |
+
|
396 |
+
if self.add_prefix:
|
397 |
+
inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
|
398 |
+
inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]
|
399 |
+
|
400 |
+
one_data["inters"] = self.his_sep.join(inters)
|
401 |
+
one_data["inter_titles"] = self.his_sep.join(inter_titles)
|
402 |
+
inter_data.append(one_data)
|
403 |
+
|
404 |
+
if self.sample_num > 0:
|
405 |
+
all_inter_idx = range(len(inter_data))
|
406 |
+
sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
|
407 |
+
inter_data = np.array(inter_data)[sample_idx].tolist()
|
408 |
+
|
409 |
+
return inter_data
|
410 |
+
|
411 |
+
def set_prompt(self, prompt_id):
|
412 |
+
|
413 |
+
self.prompt_id = prompt_id
|
414 |
+
|
415 |
+
def __len__(self):
|
416 |
+
if self.mode == 'train':
|
417 |
+
return len(self.inter_data) * self.prompt_sample_num
|
418 |
+
elif self.mode == 'valid':
|
419 |
+
return len(self.valid_text_data)
|
420 |
+
elif self.mode == 'test':
|
421 |
+
return len(self.inter_data)
|
422 |
+
else:
|
423 |
+
raise NotImplementedError
|
424 |
+
|
425 |
+
def _construct_valid_text(self):
|
426 |
+
self.valid_text_data = []
|
427 |
+
if self.sample_valid:
|
428 |
+
all_prompt_ids = range(len(self.prompts))
|
429 |
+
for i in range(len(self.inter_data)):
|
430 |
+
d = self.inter_data[i]
|
431 |
+
prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
|
432 |
+
for prompt_id in prompt_ids:
|
433 |
+
prompt = self.prompts[prompt_id]
|
434 |
+
input, output = self._get_text_data(d, prompt)
|
435 |
+
self.valid_text_data.append({"input_ids": input, "labels": output})
|
436 |
+
else:
|
437 |
+
self.prompt_sample_num = 1
|
438 |
+
prompt = self.prompts[self.valid_prompt_id]
|
439 |
+
for i in range(len(self.inter_data)):
|
440 |
+
d = self.inter_data[i]
|
441 |
+
input, output = self._get_text_data(d, prompt)
|
442 |
+
self.valid_text_data.append({"input_ids": input, "labels": output})
|
443 |
+
|
444 |
+
def _get_text_data(self, data, prompt):
|
445 |
+
|
446 |
+
instruction = prompt["instruction"].format(**data)
|
447 |
+
response = prompt["response"].format(**data)
|
448 |
+
|
449 |
+
input = sft_prompt.format(instruction=instruction, response="")
|
450 |
+
output = sft_prompt.format(instruction=instruction, response=response)
|
451 |
+
|
452 |
+
if self.mode == 'test':
|
453 |
+
return input, response
|
454 |
+
|
455 |
+
return input, output
|
456 |
+
|
457 |
+
def __getitem__(self, index):
|
458 |
+
|
459 |
+
if self.mode == 'valid':
|
460 |
+
return self.valid_text_data[index]
|
461 |
+
|
462 |
+
idx = index // self.prompt_sample_num
|
463 |
+
d = self.inter_data[idx]
|
464 |
+
|
465 |
+
if self.mode == 'train':
|
466 |
+
prompt_id = random.randint(0, len(self.prompts) - 1)
|
467 |
+
elif self.mode == 'test':
|
468 |
+
prompt_id = self.prompt_id
|
469 |
+
|
470 |
+
prompt = self.prompts[prompt_id]
|
471 |
+
|
472 |
+
input, output = self._get_text_data(d, prompt)
|
473 |
+
|
474 |
+
|
475 |
+
return dict(input_ids=input, labels=output)
|
476 |
+
|
477 |
+
|
478 |
+
class ItemFeatDataset(BaseDataset):
|
479 |
+
|
480 |
+
def __init__(self, args, task="item2index", prompt_sample_num=1, sample_num=-1):
|
481 |
+
super().__init__(args)
|
482 |
+
|
483 |
+
self.task = task.lower()
|
484 |
+
self.prompt_sample_num = prompt_sample_num
|
485 |
+
self.sample_num = sample_num
|
486 |
+
|
487 |
+
self.prompts = all_prompt[self.task]
|
488 |
+
|
489 |
+
# load data
|
490 |
+
self._load_data()
|
491 |
+
self.feat_data = self._process_data()
|
492 |
+
|
493 |
+
|
494 |
+
|
495 |
+
def _load_data(self):
|
496 |
+
|
497 |
+
with open(self.index_file, 'r') as f:
|
498 |
+
self.indices = json.load(f)
|
499 |
+
with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
|
500 |
+
self.item_feat = json.load(f)
|
501 |
+
|
502 |
+
|
503 |
+
def _process_data(self):
|
504 |
+
|
505 |
+
feat_data = []
|
506 |
+
for iid in self.item_feat:
|
507 |
+
feat = self.item_feat[iid]
|
508 |
+
index = "".join(self.indices[iid])
|
509 |
+
feat["item"] = index
|
510 |
+
feat["title"] = feat["title"].strip().strip(".!?,;:`")
|
511 |
+
feat_data.append(feat)
|
512 |
+
|
513 |
+
if self.sample_num > 0:
|
514 |
+
all_idx = range(len(feat_data))
|
515 |
+
sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
|
516 |
+
|
517 |
+
feat_data = np.array(feat_data)[sample_idx].tolist()
|
518 |
+
|
519 |
+
return feat_data
|
520 |
+
|
521 |
+
|
522 |
+
def __len__(self):
|
523 |
+
return len(self.feat_data) * self.prompt_sample_num
|
524 |
+
|
525 |
+
def _get_text_data(self, data, prompt):
|
526 |
+
|
527 |
+
instruction = prompt["instruction"].format(**data)
|
528 |
+
response = prompt["response"].format(**data)
|
529 |
+
|
530 |
+
input = sft_prompt.format(instruction = instruction, response = "")
|
531 |
+
output = sft_prompt.format(instruction = instruction, response = response)
|
532 |
+
|
533 |
+
return input, output
|
534 |
+
|
535 |
+
def __getitem__(self, index):
|
536 |
+
|
537 |
+
idx = index // self.prompt_sample_num
|
538 |
+
d = self.feat_data[idx]
|
539 |
+
|
540 |
+
prompt_id = random.randint(0, len(self.prompts) - 1)
|
541 |
+
|
542 |
+
prompt = self.prompts[prompt_id]
|
543 |
+
|
544 |
+
input, output = self._get_text_data(d, prompt)
|
545 |
+
|
546 |
+
return dict(input_ids=input, labels=output)
|
547 |
+
|
548 |
+
|
549 |
+
class ItemSearchDataset(BaseDataset):
|
550 |
+
|
551 |
+
def __init__(self, args, mode="train",
|
552 |
+
prompt_sample_num=1, prompt_id=0, sample_num=-1):
|
553 |
+
super().__init__(args)
|
554 |
+
|
555 |
+
self.mode = mode
|
556 |
+
self.prompt_sample_num = prompt_sample_num
|
557 |
+
self.prompt_id = prompt_id
|
558 |
+
self.sample_num = sample_num
|
559 |
+
|
560 |
+
self.prompts = all_prompt["itemsearch"]
|
561 |
+
|
562 |
+
# load data
|
563 |
+
self._load_data()
|
564 |
+
self.search_data = self._process_data()
|
565 |
+
|
566 |
+
|
567 |
+
|
568 |
+
def _load_data(self):
|
569 |
+
|
570 |
+
with open(self.index_file, 'r') as f:
|
571 |
+
self.indices = json.load(f)
|
572 |
+
with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
|
573 |
+
self.user_info = json.load(f)
|
574 |
+
|
575 |
+
|
576 |
+
def _process_data(self):
|
577 |
+
|
578 |
+
search_data = []
|
579 |
+
user_explicit_preference = self.user_info["user_explicit_preference"]
|
580 |
+
user_vague_intention = self.user_info["user_vague_intention"]
|
581 |
+
if self.mode == 'train':
|
582 |
+
user_vague_intention = user_vague_intention["train"]
|
583 |
+
elif self.mode == 'test':
|
584 |
+
user_vague_intention = user_vague_intention["test"]
|
585 |
+
else:
|
586 |
+
raise NotImplementedError
|
587 |
+
|
588 |
+
for uid in user_explicit_preference.keys():
|
589 |
+
one_data = {}
|
590 |
+
user_ep = user_explicit_preference[uid]
|
591 |
+
user_vi = user_vague_intention[uid]["querys"]
|
592 |
+
one_data["explicit_preferences"] = user_ep
|
593 |
+
one_data["user_related_intention"] = user_vi[0]
|
594 |
+
one_data["item_related_intention"] = user_vi[1]
|
595 |
+
|
596 |
+
iid = user_vague_intention[uid]["item"]
|
597 |
+
inters = user_vague_intention[uid]["inters"]
|
598 |
+
|
599 |
+
index = "".join(self.indices[str(iid)])
|
600 |
+
one_data["item"] = index
|
601 |
+
|
602 |
+
if self.max_his_len > 0:
|
603 |
+
inters = inters[-self.max_his_len:]
|
604 |
+
inters = ["".join(self.indices[str(i)]) for i in inters]
|
605 |
+
if self.add_prefix:
|
606 |
+
inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
|
607 |
+
|
608 |
+
one_data["inters"] = self.his_sep.join(inters)
|
609 |
+
|
610 |
+
search_data.append(one_data)
|
611 |
+
|
612 |
+
if self.sample_num > 0:
|
613 |
+
all_idx = range(len(search_data))
|
614 |
+
sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
|
615 |
+
|
616 |
+
search_data = np.array(search_data)[sample_idx].tolist()
|
617 |
+
|
618 |
+
return search_data
|
619 |
+
|
620 |
+
def set_prompt(self, prompt_id):
|
621 |
+
self.prompt_id = prompt_id
|
622 |
+
|
623 |
+
def __len__(self):
|
624 |
+
if self.mode == 'train':
|
625 |
+
return len(self.search_data) * self.prompt_sample_num
|
626 |
+
elif self.mode == 'test':
|
627 |
+
return len(self.search_data)
|
628 |
+
else:
|
629 |
+
return len(self.search_data)
|
630 |
+
|
631 |
+
|
632 |
+
def _get_text_data(self, data, prompt):
|
633 |
+
|
634 |
+
instruction = prompt["instruction"].format(**data)
|
635 |
+
response = prompt["response"].format(**data)
|
636 |
+
|
637 |
+
input = sft_prompt.format(instruction = instruction, response = "")
|
638 |
+
output = sft_prompt.format(instruction = instruction, response = response)
|
639 |
+
|
640 |
+
if self.mode == 'test':
|
641 |
+
return input, response
|
642 |
+
|
643 |
+
return input, output
|
644 |
+
|
645 |
+
def __getitem__(self, index):
|
646 |
+
|
647 |
+
idx = index // self.prompt_sample_num
|
648 |
+
|
649 |
+
d = self.search_data[idx]
|
650 |
+
if self.mode == 'train':
|
651 |
+
prompt_id = random.randint(0, len(self.prompts) - 1)
|
652 |
+
elif self.mode == 'test':
|
653 |
+
prompt_id = self.prompt_id
|
654 |
+
|
655 |
+
prompt = self.prompts[prompt_id]
|
656 |
+
|
657 |
+
d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
|
658 |
+
all_querys = [d["user_related_intention"], d["item_related_intention"]]
|
659 |
+
d["query"] = random.choice(all_querys)
|
660 |
+
|
661 |
+
input, output = self._get_text_data(d, prompt)
|
662 |
+
|
663 |
+
return dict(input_ids=input, labels=output)
|
664 |
+
|
665 |
+
|
666 |
+
|
667 |
+
class PreferenceObtainDataset(BaseDataset):
|
668 |
+
|
669 |
+
def __init__(self, args, prompt_sample_num=1, sample_num=-1):
|
670 |
+
super().__init__(args)
|
671 |
+
|
672 |
+
self.prompt_sample_num = prompt_sample_num
|
673 |
+
self.sample_num = sample_num
|
674 |
+
|
675 |
+
self.prompts = all_prompt["preferenceobtain"]
|
676 |
+
|
677 |
+
# load data
|
678 |
+
self._load_data()
|
679 |
+
self._remap_items()
|
680 |
+
|
681 |
+
self.preference_data = self._process_data()
|
682 |
+
|
683 |
+
|
684 |
+
|
685 |
+
def _load_data(self):
|
686 |
+
|
687 |
+
with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
|
688 |
+
self.user_info = json.load(f)
|
689 |
+
with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
|
690 |
+
self.inters = json.load(f)
|
691 |
+
with open(self.index_file, 'r') as f:
|
692 |
+
self.indices = json.load(f)
|
693 |
+
|
694 |
+
|
695 |
+
def _remap_items(self):
|
696 |
+
|
697 |
+
self.remapped_inters = dict()
|
698 |
+
for uid, items in self.inters.items():
|
699 |
+
new_items = ["".join(self.indices[str(i)]) for i in items]
|
700 |
+
self.remapped_inters[uid] = new_items
|
701 |
+
|
702 |
+
def _process_data(self):
|
703 |
+
|
704 |
+
preference_data = []
|
705 |
+
user_explicit_preference = self.user_info["user_explicit_preference"]
|
706 |
+
|
707 |
+
for uid in user_explicit_preference.keys():
|
708 |
+
one_data = {}
|
709 |
+
inters = self.remapped_inters[uid][:-3]
|
710 |
+
user_ep = user_explicit_preference[uid]
|
711 |
+
|
712 |
+
if self.max_his_len > 0:
|
713 |
+
inters = inters[-self.max_his_len:]
|
714 |
+
if self.add_prefix:
|
715 |
+
inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
|
716 |
+
|
717 |
+
one_data["explicit_preferences"] = user_ep
|
718 |
+
one_data["inters"] = self.his_sep.join(inters)
|
719 |
+
|
720 |
+
preference_data.append(one_data)
|
721 |
+
|
722 |
+
if self.sample_num > 0:
|
723 |
+
all_idx = range(len(preference_data))
|
724 |
+
sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
|
725 |
+
|
726 |
+
preference_data = np.array(preference_data)[sample_idx].tolist()
|
727 |
+
|
728 |
+
return preference_data
|
729 |
+
|
730 |
+
def set_prompt(self, prompt_id):
|
731 |
+
self.prompt_id = prompt_id
|
732 |
+
|
733 |
+
def __len__(self):
|
734 |
+
return len(self.preference_data) * self.prompt_sample_num
|
735 |
+
|
736 |
+
|
737 |
+
def _get_text_data(self, data, prompt):
|
738 |
+
|
739 |
+
instruction = prompt["instruction"].format(**data)
|
740 |
+
response = prompt["response"].format(**data)
|
741 |
+
|
742 |
+
input = sft_prompt.format(instruction = instruction, response = "")
|
743 |
+
output = sft_prompt.format(instruction = instruction, response = response)
|
744 |
+
|
745 |
+
return input, output
|
746 |
+
|
747 |
+
def __getitem__(self, index):
|
748 |
+
|
749 |
+
idx = index // self.prompt_sample_num
|
750 |
+
|
751 |
+
d = self.preference_data[idx]
|
752 |
+
prompt_id = random.randint(0, len(self.prompts) - 1)
|
753 |
+
|
754 |
+
prompt = self.prompts[prompt_id]
|
755 |
+
|
756 |
+
d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
|
757 |
+
|
758 |
+
input, output = self._get_text_data(d, prompt)
|
759 |
+
|
760 |
+
return dict(input_ids=input, labels=output)
|
761 |
+
|
762 |
+
|
763 |
+
|
764 |
+
|
765 |
+
|
766 |
+
class SeqRecTestDataset(BaseDataset):
|
767 |
+
|
768 |
+
def __init__(self, args, prompt_id=0, sample_num=-1):
|
769 |
+
super().__init__(args)
|
770 |
+
|
771 |
+
self.prompt_id = prompt_id
|
772 |
+
self.sample_num = sample_num
|
773 |
+
|
774 |
+
self.prompt = all_prompt["seqrec"][self.prompt_id]
|
775 |
+
|
776 |
+
# load data
|
777 |
+
self._load_data()
|
778 |
+
self._remap_items()
|
779 |
+
|
780 |
+
self.inter_data = self._process_test_data()
|
781 |
+
|
782 |
+
def _load_data(self):
|
783 |
+
|
784 |
+
with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
|
785 |
+
self.inters = json.load(f)
|
786 |
+
with open(self.index_file, 'r') as f:
|
787 |
+
self.indices = json.load(f)
|
788 |
+
|
789 |
+
|
790 |
+
def _remap_items(self):
|
791 |
+
|
792 |
+
self.remapped_inters = dict()
|
793 |
+
for uid, items in self.inters.items():
|
794 |
+
new_items = ["".join(self.indices[str(i)]) for i in items]
|
795 |
+
self.remapped_inters[uid] = new_items
|
796 |
+
|
797 |
+
def _process_test_data(self):
|
798 |
+
|
799 |
+
inter_data = []
|
800 |
+
for uid in self.remapped_inters:
|
801 |
+
items = self.remapped_inters[uid]
|
802 |
+
one_data = dict()
|
803 |
+
# one_data["user"] = uid
|
804 |
+
one_data["item"] = items[-1]
|
805 |
+
history = items[:-1]
|
806 |
+
if self.max_his_len > 0:
|
807 |
+
history = history[-self.max_his_len:]
|
808 |
+
if self.add_prefix:
|
809 |
+
history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
|
810 |
+
one_data["inters"] = self.his_sep.join(history)
|
811 |
+
inter_data.append(one_data)
|
812 |
+
|
813 |
+
if self.sample_num > 0:
|
814 |
+
all_inter_idx = range(len(inter_data))
|
815 |
+
sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
|
816 |
+
|
817 |
+
inter_data = np.array(inter_data)[sample_idx].tolist()
|
818 |
+
|
819 |
+
return inter_data
|
820 |
+
|
821 |
+
def set_prompt(self, prompt_id):
|
822 |
+
self.prompt_id = prompt_id
|
823 |
+
|
824 |
+
self.prompt = all_prompt["seqrec"][self.prompt_id]
|
825 |
+
|
826 |
+
def __len__(self):
|
827 |
+
|
828 |
+
return len(self.inter_data)
|
829 |
+
|
830 |
+
def _get_text_data(self, data, prompt):
|
831 |
+
|
832 |
+
instruction = prompt["instruction"].format(**data)
|
833 |
+
response = prompt["response"].format(**data)
|
834 |
+
|
835 |
+
input = sft_prompt.format(instruction=instruction, response="")
|
836 |
+
|
837 |
+
return input, response
|
838 |
+
|
839 |
+
def __getitem__(self, index):
|
840 |
+
|
841 |
+
d = self.inter_data[index]
|
842 |
+
input, target = self._get_text_data(d, self.prompt)
|
843 |
+
|
844 |
+
return dict(input_ids=input, labels=target)
|
data_process/amazon18_data_process.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import collections
|
3 |
+
import gzip
|
4 |
+
import html
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
import numpy as np
|
12 |
+
from utils import check_path, clean_text, amazon18_dataset2fullname, write_json_file, write_remap_index
|
13 |
+
|
14 |
+
def load_ratings(file):
|
15 |
+
users, items, inters = set(), set(), set()
|
16 |
+
with open(file, 'r') as fp:
|
17 |
+
for line in tqdm(fp, desc='Load ratings'):
|
18 |
+
try:
|
19 |
+
item, user, rating, time = line.strip().split(',')
|
20 |
+
users.add(user)
|
21 |
+
items.add(item)
|
22 |
+
inters.add((user, item, float(rating), int(time)))
|
23 |
+
except ValueError:
|
24 |
+
print(line)
|
25 |
+
return users, items, inters
|
26 |
+
|
27 |
+
|
28 |
+
def load_meta_items(file):
|
29 |
+
items = {}
|
30 |
+
with gzip.open(file, "r") as fp:
|
31 |
+
for line in tqdm(fp, desc="Load metas"):
|
32 |
+
data = json.loads(line)
|
33 |
+
item = data["asin"]
|
34 |
+
title = clean_text(data["title"])
|
35 |
+
|
36 |
+
descriptions = data["description"]
|
37 |
+
descriptions = clean_text(descriptions)
|
38 |
+
|
39 |
+
brand = data["brand"].replace("by\n", "").strip()
|
40 |
+
|
41 |
+
categories = data["category"]
|
42 |
+
new_categories = []
|
43 |
+
for category in categories:
|
44 |
+
if "</span>" in category:
|
45 |
+
break
|
46 |
+
new_categories.append(category.strip())
|
47 |
+
categories = ",".join(new_categories).strip()
|
48 |
+
|
49 |
+
items[item] = {"title": title, "description": descriptions, "brand": brand, "categories": categories}
|
50 |
+
# print(items[item])
|
51 |
+
return items
|
52 |
+
|
53 |
+
|
54 |
+
def load_review_data(args, user2id, item2id):
|
55 |
+
|
56 |
+
dataset_full_name = amazon18_dataset2fullname[args.dataset]
|
57 |
+
review_file_path = os.path.join(args.input_path, 'Review', dataset_full_name + '.json.gz')
|
58 |
+
|
59 |
+
reviews = {}
|
60 |
+
|
61 |
+
with gzip.open(review_file_path, "r") as fp:
|
62 |
+
|
63 |
+
for line in tqdm(fp,desc='Load reviews'):
|
64 |
+
inter = json.loads(line)
|
65 |
+
try:
|
66 |
+
user = inter['reviewerID']
|
67 |
+
item = inter['asin']
|
68 |
+
if user in user2id and item in item2id:
|
69 |
+
uid = user2id[user]
|
70 |
+
iid = item2id[item]
|
71 |
+
else:
|
72 |
+
continue
|
73 |
+
if 'reviewText' in inter:
|
74 |
+
review = clean_text(inter['reviewText'])
|
75 |
+
else:
|
76 |
+
review = ''
|
77 |
+
if 'summary' in inter:
|
78 |
+
summary = clean_text(inter['summary'])
|
79 |
+
else:
|
80 |
+
summary = ''
|
81 |
+
reviews[str((uid,iid))]={"review":review, "summary":summary}
|
82 |
+
|
83 |
+
except ValueError:
|
84 |
+
print(line)
|
85 |
+
|
86 |
+
return reviews
|
87 |
+
|
88 |
+
|
89 |
+
def get_user2count(inters):
|
90 |
+
user2count = collections.defaultdict(int)
|
91 |
+
for unit in inters:
|
92 |
+
user2count[unit[0]] += 1
|
93 |
+
return user2count
|
94 |
+
|
95 |
+
|
96 |
+
def get_item2count(inters):
|
97 |
+
item2count = collections.defaultdict(int)
|
98 |
+
for unit in inters:
|
99 |
+
item2count[unit[1]] += 1
|
100 |
+
return item2count
|
101 |
+
|
102 |
+
|
103 |
+
def generate_candidates(unit2count, threshold):
|
104 |
+
cans = set()
|
105 |
+
for unit, count in unit2count.items():
|
106 |
+
if count >= threshold:
|
107 |
+
cans.add(unit)
|
108 |
+
return cans, len(unit2count) - len(cans)
|
109 |
+
|
110 |
+
|
111 |
+
def filter_inters(inters, can_items=None,
|
112 |
+
user_k_core_threshold=0, item_k_core_threshold=0):
|
113 |
+
new_inters = []
|
114 |
+
|
115 |
+
# filter by meta items
|
116 |
+
if can_items:
|
117 |
+
print('\nFiltering by meta items: ')
|
118 |
+
for unit in inters:
|
119 |
+
if unit[1] in can_items.keys():
|
120 |
+
new_inters.append(unit)
|
121 |
+
inters, new_inters = new_inters, []
|
122 |
+
print(' The number of inters: ', len(inters))
|
123 |
+
|
124 |
+
# filter by k-core
|
125 |
+
if user_k_core_threshold or item_k_core_threshold:
|
126 |
+
print('\nFiltering by k-core:')
|
127 |
+
idx = 0
|
128 |
+
user2count = get_user2count(inters)
|
129 |
+
item2count = get_item2count(inters)
|
130 |
+
|
131 |
+
while True:
|
132 |
+
new_user2count = collections.defaultdict(int)
|
133 |
+
new_item2count = collections.defaultdict(int)
|
134 |
+
users, n_filtered_users = generate_candidates( # users is set
|
135 |
+
user2count, user_k_core_threshold)
|
136 |
+
items, n_filtered_items = generate_candidates(
|
137 |
+
item2count, item_k_core_threshold)
|
138 |
+
if n_filtered_users == 0 and n_filtered_items == 0:
|
139 |
+
break
|
140 |
+
for unit in inters:
|
141 |
+
if unit[0] in users and unit[1] in items:
|
142 |
+
new_inters.append(unit)
|
143 |
+
new_user2count[unit[0]] += 1
|
144 |
+
new_item2count[unit[1]] += 1
|
145 |
+
idx += 1
|
146 |
+
inters, new_inters = new_inters, []
|
147 |
+
user2count, item2count = new_user2count, new_item2count
|
148 |
+
print(' Epoch %d The number of inters: %d, users: %d, items: %d'
|
149 |
+
% (idx, len(inters), len(user2count), len(item2count)))
|
150 |
+
return inters
|
151 |
+
|
152 |
+
|
153 |
+
def make_inters_in_order(inters):
|
154 |
+
user2inters, new_inters = collections.defaultdict(list), list()
|
155 |
+
for inter in inters:
|
156 |
+
user, item, rating, timestamp = inter
|
157 |
+
user2inters[user].append((user, item, rating, timestamp))
|
158 |
+
for user in user2inters:
|
159 |
+
user_inters = user2inters[user]
|
160 |
+
user_inters.sort(key=lambda d: d[3])
|
161 |
+
interacted_item = set()
|
162 |
+
for inter in user_inters:
|
163 |
+
if inter[1] in interacted_item: # 过滤重复交互
|
164 |
+
continue
|
165 |
+
interacted_item.add(inter[1])
|
166 |
+
new_inters.append(inter)
|
167 |
+
return new_inters
|
168 |
+
|
169 |
+
|
170 |
+
def preprocess_rating(args):
|
171 |
+
dataset_full_name = amazon18_dataset2fullname[args.dataset]
|
172 |
+
|
173 |
+
print('Process rating data: ')
|
174 |
+
print(' Dataset: ', args.dataset)
|
175 |
+
|
176 |
+
# load ratings
|
177 |
+
rating_file_path = os.path.join(args.input_path, 'Ratings', dataset_full_name + '.csv')
|
178 |
+
rating_users, rating_items, rating_inters = load_ratings(rating_file_path)
|
179 |
+
|
180 |
+
# load item IDs with meta data
|
181 |
+
meta_file_path = os.path.join(args.input_path, 'Metadata', f'meta_{dataset_full_name}.json.gz')
|
182 |
+
meta_items = load_meta_items(meta_file_path)
|
183 |
+
|
184 |
+
# 1. Filter items w/o meta data;
|
185 |
+
# 2. K-core filtering;
|
186 |
+
print('The number of raw inters: ', len(rating_inters))
|
187 |
+
|
188 |
+
rating_inters = make_inters_in_order(rating_inters)
|
189 |
+
|
190 |
+
rating_inters = filter_inters(rating_inters, can_items=meta_items,
|
191 |
+
user_k_core_threshold=args.user_k,
|
192 |
+
item_k_core_threshold=args.item_k)
|
193 |
+
|
194 |
+
# sort interactions chronologically for each user
|
195 |
+
rating_inters = make_inters_in_order(rating_inters)
|
196 |
+
print('\n')
|
197 |
+
|
198 |
+
# return: list of (user_ID, item_ID, rating, timestamp)
|
199 |
+
return rating_inters, meta_items
|
200 |
+
|
201 |
+
def convert_inters2dict(inters):
|
202 |
+
user2items = collections.defaultdict(list)
|
203 |
+
user2index, item2index = dict(), dict()
|
204 |
+
for inter in inters:
|
205 |
+
user, item, rating, timestamp = inter
|
206 |
+
if user not in user2index:
|
207 |
+
user2index[user] = len(user2index)
|
208 |
+
if item not in item2index:
|
209 |
+
item2index[item] = len(item2index)
|
210 |
+
user2items[user2index[user]].append(item2index[item])
|
211 |
+
return user2items, user2index, item2index
|
212 |
+
|
213 |
+
def generate_data(args, rating_inters):
|
214 |
+
print('Split dataset: ')
|
215 |
+
print(' Dataset: ', args.dataset)
|
216 |
+
|
217 |
+
# generate train valid temp
|
218 |
+
user2items, user2index, item2index = convert_inters2dict(rating_inters)
|
219 |
+
train_inters, valid_inters, test_inters = dict(), dict(), dict()
|
220 |
+
for u_index in range(len(user2index)):
|
221 |
+
inters = user2items[u_index]
|
222 |
+
# leave one out
|
223 |
+
train_inters[u_index] = [str(i_index) for i_index in inters[:-2]]
|
224 |
+
valid_inters[u_index] = [str(inters[-2])]
|
225 |
+
test_inters[u_index] = [str(inters[-1])]
|
226 |
+
assert len(user2items[u_index]) == len(train_inters[u_index]) + \
|
227 |
+
len(valid_inters[u_index]) + len(test_inters[u_index])
|
228 |
+
return user2items, train_inters, valid_inters, test_inters, user2index, item2index
|
229 |
+
|
230 |
+
def convert_to_atomic_files(args, train_data, valid_data, test_data):
|
231 |
+
print('Convert dataset: ')
|
232 |
+
print(' Dataset: ', args.dataset)
|
233 |
+
uid_list = list(train_data.keys())
|
234 |
+
uid_list.sort(key=lambda t: int(t))
|
235 |
+
|
236 |
+
with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.train.inter'), 'w') as file:
|
237 |
+
file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n')
|
238 |
+
for uid in uid_list:
|
239 |
+
item_seq = train_data[uid]
|
240 |
+
seq_len = len(item_seq)
|
241 |
+
for target_idx in range(1, seq_len):
|
242 |
+
target_item = item_seq[-target_idx]
|
243 |
+
seq = item_seq[:-target_idx][-50:]
|
244 |
+
file.write(f'{uid}\t{" ".join(seq)}\t{target_item}\n')
|
245 |
+
|
246 |
+
with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.valid.inter'), 'w') as file:
|
247 |
+
file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n')
|
248 |
+
for uid in uid_list:
|
249 |
+
item_seq = train_data[uid][-50:]
|
250 |
+
target_item = valid_data[uid][0]
|
251 |
+
file.write(f'{uid}\t{" ".join(item_seq)}\t{target_item}\n')
|
252 |
+
|
253 |
+
with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.test.inter'), 'w') as file:
|
254 |
+
file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n')
|
255 |
+
for uid in uid_list:
|
256 |
+
item_seq = (train_data[uid] + valid_data[uid])[-50:]
|
257 |
+
target_item = test_data[uid][0]
|
258 |
+
file.write(f'{uid}\t{" ".join(item_seq)}\t{target_item}\n')
|
259 |
+
|
260 |
+
def parse_args():
|
261 |
+
parser = argparse.ArgumentParser()
|
262 |
+
parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games')
|
263 |
+
parser.add_argument('--user_k', type=int, default=5, help='user k-core filtering')
|
264 |
+
parser.add_argument('--item_k', type=int, default=5, help='item k-core filtering')
|
265 |
+
parser.add_argument('--input_path', type=str, default='')
|
266 |
+
parser.add_argument('--output_path', type=str, default='')
|
267 |
+
return parser.parse_args()
|
268 |
+
|
269 |
+
|
270 |
+
if __name__ == '__main__':
|
271 |
+
args = parse_args()
|
272 |
+
|
273 |
+
# load interactions from raw rating file
|
274 |
+
rating_inters, meta_items = preprocess_rating(args)
|
275 |
+
|
276 |
+
|
277 |
+
# split train/valid/temp
|
278 |
+
all_inters,train_inters, valid_inters, test_inters, user2index, item2index = generate_data(args, rating_inters)
|
279 |
+
|
280 |
+
check_path(os.path.join(args.output_path, args.dataset))
|
281 |
+
|
282 |
+
write_json_file(all_inters, os.path.join(args.output_path, args.dataset, f'{args.dataset}.inter.json'))
|
283 |
+
convert_to_atomic_files(args, train_inters, valid_inters, test_inters)
|
284 |
+
|
285 |
+
item2feature = collections.defaultdict(dict)
|
286 |
+
for item, item_id in item2index.items():
|
287 |
+
item2feature[item_id] = meta_items[item]
|
288 |
+
|
289 |
+
# reviews = load_review_data(args, user2index, item2index)
|
290 |
+
|
291 |
+
print("user:",len(user2index))
|
292 |
+
print("item:",len(item2index))
|
293 |
+
|
294 |
+
write_json_file(item2feature, os.path.join(args.output_path, args.dataset, f'{args.dataset}.item.json'))
|
295 |
+
# write_json_file(reviews, os.path.join(args.output_path, args.dataset, f'{args.dataset}.review.json'))
|
296 |
+
|
297 |
+
|
298 |
+
write_remap_index(user2index, os.path.join(args.output_path, args.dataset, f'{args.dataset}.user2id'))
|
299 |
+
write_remap_index(item2index, os.path.join(args.output_path, args.dataset, f'{args.dataset}.item2id'))
|
data_process/amazon18_recbole_data_process.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import collections
|
3 |
+
import gzip
|
4 |
+
import html
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
import numpy as np
|
12 |
+
from utils import check_path, clean_text, amazon18_dataset2fullname,write_json_file,write_remap_index
|
13 |
+
|
14 |
+
def load_ratings(file):
|
15 |
+
users, items, inters = set(), set(), set()
|
16 |
+
with open(file, 'r') as fp:
|
17 |
+
for line in tqdm(fp, desc='Load ratings'):
|
18 |
+
try:
|
19 |
+
item, user, rating, time = line.strip().split(',')
|
20 |
+
users.add(user)
|
21 |
+
items.add(item)
|
22 |
+
inters.add((user, item, float(rating), int(time)))
|
23 |
+
except ValueError:
|
24 |
+
print(line)
|
25 |
+
return users, items, inters
|
26 |
+
|
27 |
+
|
28 |
+
def load_meta_items(file):
|
29 |
+
items = {}
|
30 |
+
# re_tag = re.compile('</?\w+[^>]*>')
|
31 |
+
with gzip.open(file, "r") as fp:
|
32 |
+
for line in tqdm(fp, desc="Load metas"):
|
33 |
+
data = json.loads(line)
|
34 |
+
item = data["asin"]
|
35 |
+
title = clean_text(data["title"])
|
36 |
+
|
37 |
+
descriptions = data["description"]
|
38 |
+
descriptions = clean_text(descriptions)
|
39 |
+
# new_descriptions = []
|
40 |
+
# for description in descriptions:
|
41 |
+
# description = re.sub(re_tag, '', description)
|
42 |
+
# new_descriptions.append(description.strip())
|
43 |
+
# descriptions = " ".join(new_descriptions).strip()
|
44 |
+
|
45 |
+
brand = data["brand"].replace("by\n", "").strip()
|
46 |
+
|
47 |
+
categories = data["category"]
|
48 |
+
new_categories = []
|
49 |
+
for category in categories:
|
50 |
+
if "</span>" in category:
|
51 |
+
break
|
52 |
+
new_categories.append(category.strip())
|
53 |
+
categories = ",".join(new_categories[1:]).strip()
|
54 |
+
|
55 |
+
items[item] = {"title": title, "description": descriptions, "brand": brand, "categories": categories}
|
56 |
+
# print(items[item])
|
57 |
+
return items
|
58 |
+
|
59 |
+
|
60 |
+
def get_user2count(inters):
|
61 |
+
user2count = collections.defaultdict(int)
|
62 |
+
for unit in inters:
|
63 |
+
user2count[unit[0]] += 1
|
64 |
+
return user2count
|
65 |
+
|
66 |
+
|
67 |
+
def get_item2count(inters):
|
68 |
+
item2count = collections.defaultdict(int)
|
69 |
+
for unit in inters:
|
70 |
+
item2count[unit[1]] += 1
|
71 |
+
return item2count
|
72 |
+
|
73 |
+
|
74 |
+
def generate_candidates(unit2count, threshold):
|
75 |
+
cans = set()
|
76 |
+
for unit, count in unit2count.items():
|
77 |
+
if count >= threshold:
|
78 |
+
cans.add(unit)
|
79 |
+
return cans, len(unit2count) - len(cans)
|
80 |
+
|
81 |
+
|
82 |
+
def filter_inters(inters, can_items=None,
|
83 |
+
user_k_core_threshold=0, item_k_core_threshold=0):
|
84 |
+
new_inters = []
|
85 |
+
|
86 |
+
# filter by meta items
|
87 |
+
if can_items:
|
88 |
+
print('\nFiltering by meta items: ')
|
89 |
+
for unit in inters:
|
90 |
+
if unit[1] in can_items.keys():
|
91 |
+
new_inters.append(unit)
|
92 |
+
inters, new_inters = new_inters, []
|
93 |
+
print(' The number of inters: ', len(inters))
|
94 |
+
|
95 |
+
# filter by k-core
|
96 |
+
if user_k_core_threshold or item_k_core_threshold:
|
97 |
+
print('\nFiltering by k-core:')
|
98 |
+
idx = 0
|
99 |
+
user2count = get_user2count(inters)
|
100 |
+
item2count = get_item2count(inters)
|
101 |
+
|
102 |
+
while True:
|
103 |
+
new_user2count = collections.defaultdict(int)
|
104 |
+
new_item2count = collections.defaultdict(int)
|
105 |
+
users, n_filtered_users = generate_candidates( # users is set
|
106 |
+
user2count, user_k_core_threshold)
|
107 |
+
items, n_filtered_items = generate_candidates(
|
108 |
+
item2count, item_k_core_threshold)
|
109 |
+
if n_filtered_users == 0 and n_filtered_items == 0:
|
110 |
+
break
|
111 |
+
for unit in inters:
|
112 |
+
if unit[0] in users and unit[1] in items:
|
113 |
+
new_inters.append(unit)
|
114 |
+
new_user2count[unit[0]] += 1
|
115 |
+
new_item2count[unit[1]] += 1
|
116 |
+
idx += 1
|
117 |
+
inters, new_inters = new_inters, []
|
118 |
+
user2count, item2count = new_user2count, new_item2count
|
119 |
+
print(' Epoch %d The number of inters: %d, users: %d, items: %d'
|
120 |
+
% (idx, len(inters), len(user2count), len(item2count)))
|
121 |
+
return inters
|
122 |
+
|
123 |
+
|
124 |
+
def make_inters_in_order(inters):
|
125 |
+
user2inters, new_inters = collections.defaultdict(list), list()
|
126 |
+
for inter in inters:
|
127 |
+
user, item, rating, timestamp = inter
|
128 |
+
user2inters[user].append((user, item, rating, timestamp))
|
129 |
+
for user in user2inters:
|
130 |
+
user_inters = user2inters[user]
|
131 |
+
user_inters.sort(key=lambda d: d[3])
|
132 |
+
interacted_item = set()
|
133 |
+
for inter in user_inters:
|
134 |
+
if inter[1] in interacted_item: # 过滤重复交互
|
135 |
+
continue
|
136 |
+
interacted_item.add(inter[1])
|
137 |
+
new_inters.append(inter)
|
138 |
+
return new_inters
|
139 |
+
|
140 |
+
|
141 |
+
def preprocess_rating(args):
|
142 |
+
dataset_full_name = amazon18_dataset2fullname[args.dataset]
|
143 |
+
|
144 |
+
print('Process rating data: ')
|
145 |
+
print(' Dataset: ', args.dataset)
|
146 |
+
|
147 |
+
# load ratings
|
148 |
+
rating_file_path = os.path.join(args.input_path, 'Ratings', dataset_full_name + '.csv')
|
149 |
+
rating_users, rating_items, rating_inters = load_ratings(rating_file_path)
|
150 |
+
|
151 |
+
# load item IDs with meta data
|
152 |
+
meta_file_path = os.path.join(args.input_path, 'Metadata', f'meta_{dataset_full_name}.json.gz')
|
153 |
+
meta_items = load_meta_items(meta_file_path)
|
154 |
+
|
155 |
+
# 1. Filter items w/o meta data;
|
156 |
+
# 2. K-core filtering;
|
157 |
+
print('The number of raw inters: ', len(rating_inters))
|
158 |
+
|
159 |
+
rating_inters = make_inters_in_order(rating_inters)
|
160 |
+
|
161 |
+
rating_inters = filter_inters(rating_inters, can_items=meta_items,
|
162 |
+
user_k_core_threshold=args.user_k,
|
163 |
+
item_k_core_threshold=args.item_k)
|
164 |
+
|
165 |
+
# sort interactions chronologically for each user
|
166 |
+
rating_inters = make_inters_in_order(rating_inters)
|
167 |
+
print('\n')
|
168 |
+
|
169 |
+
# return: list of (user_ID, item_ID, rating, timestamp)
|
170 |
+
return rating_inters, meta_items
|
171 |
+
|
172 |
+
def save_inter(args, inters):
|
173 |
+
print('Convert dataset: ')
|
174 |
+
print(' Dataset: ', args.dataset)
|
175 |
+
|
176 |
+
with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.inter'), 'w') as file:
|
177 |
+
file.write('user_id:token\titem_id:token\trating:float\ttimestamp:float\n')
|
178 |
+
for inter in inters:
|
179 |
+
user, item, rating, timestamp = inter
|
180 |
+
file.write(f'{user}\t{item}\t{rating}\t{timestamp}\n')
|
181 |
+
|
182 |
+
|
183 |
+
def save_feat(args, feat, all_items):
|
184 |
+
iid_list = list(feat.keys())
|
185 |
+
num_item = 0
|
186 |
+
with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.item'), 'w') as file:
|
187 |
+
# "title": title, "description": descriptions, "brand": brand, "categories": categories
|
188 |
+
file.write('item_id:token\ttitle:token_seq\tbrand:token\tcategories:token_seq\n')
|
189 |
+
for iid in iid_list:
|
190 |
+
if iid in all_items:
|
191 |
+
num_item += 1
|
192 |
+
title, brand, categories = feat[iid]["title"], feat[iid]["brand"], feat[iid]["categories"]
|
193 |
+
file.write(f'{iid}\t{title}\t{brand}\t{categories}\n')
|
194 |
+
print("num_item: ", num_item)
|
195 |
+
|
196 |
+
|
197 |
+
def parse_args():
|
198 |
+
parser = argparse.ArgumentParser()
|
199 |
+
parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games')
|
200 |
+
parser.add_argument('--user_k', type=int, default=5, help='user k-core filtering')
|
201 |
+
parser.add_argument('--item_k', type=int, default=5, help='item k-core filtering')
|
202 |
+
parser.add_argument('--input_path', type=str, default='')
|
203 |
+
parser.add_argument('--output_path', type=str, default='')
|
204 |
+
return parser.parse_args()
|
205 |
+
|
206 |
+
|
207 |
+
if __name__ == '__main__':
|
208 |
+
args = parse_args()
|
209 |
+
|
210 |
+
# load interactions from raw rating file
|
211 |
+
rating_inters, meta_items = preprocess_rating(args)
|
212 |
+
|
213 |
+
check_path(os.path.join(args.output_path, args.dataset))
|
214 |
+
|
215 |
+
|
216 |
+
all_items = set()
|
217 |
+
for inter in rating_inters:
|
218 |
+
user, item, rating, timestamp = inter
|
219 |
+
all_items.add(item)
|
220 |
+
|
221 |
+
print("total item: ", len(list(all_items)))
|
222 |
+
|
223 |
+
save_inter(args,rating_inters)
|
224 |
+
save_feat(args,meta_items, all_items)
|
225 |
+
|
226 |
+
|
data_process/amazon_text_emb.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import collections
|
3 |
+
import gzip
|
4 |
+
import html
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
import numpy as np
|
12 |
+
from utils import *
|
13 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, AutoTokenizer, AutoModel
|
14 |
+
|
15 |
+
|
16 |
+
def load_data(args):
|
17 |
+
|
18 |
+
item2feature_path = args.data_path
|
19 |
+
item2feature = load_json(item2feature_path)
|
20 |
+
|
21 |
+
return item2feature
|
22 |
+
|
23 |
+
def generate_text(item2feature, features):
|
24 |
+
item_text_list = []
|
25 |
+
|
26 |
+
for item in item2feature:
|
27 |
+
data = item2feature[item]
|
28 |
+
text = []
|
29 |
+
for meta_key in features:
|
30 |
+
if meta_key in data:
|
31 |
+
meta_value = clean_text(data[meta_key])
|
32 |
+
text.append(meta_value.strip())
|
33 |
+
|
34 |
+
item_text_list.append([int(item), text])
|
35 |
+
|
36 |
+
return item_text_list
|
37 |
+
|
38 |
+
def preprocess_text(args):
|
39 |
+
print('Process text data ...')
|
40 |
+
# print('Dataset: ', args.dataset)
|
41 |
+
|
42 |
+
item2feature = load_data(args)
|
43 |
+
# load item text and clean
|
44 |
+
item_text_list = generate_text(item2feature, ['title'])
|
45 |
+
# item_text_list = generate_text(item2feature, ['title'])
|
46 |
+
# return: list of (item_ID, cleaned_item_text)
|
47 |
+
return item_text_list
|
48 |
+
|
49 |
+
def generate_item_embedding(args, item_text_list, tokenizer, model, word_drop_ratio=-1, save_path = ''):
|
50 |
+
print('Generate text embedding ...')
|
51 |
+
# print(' Dataset: ', args.dataset)
|
52 |
+
|
53 |
+
items, texts = zip(*item_text_list)
|
54 |
+
order_texts = [[0]] * len(items)
|
55 |
+
for item, text in zip(items, texts):
|
56 |
+
order_texts[item] = text
|
57 |
+
for text in order_texts:
|
58 |
+
assert text != [0]
|
59 |
+
|
60 |
+
embeddings = []
|
61 |
+
emb_result = []
|
62 |
+
start, batch_size = 0, 1
|
63 |
+
with torch.no_grad():
|
64 |
+
while start < len(order_texts):
|
65 |
+
if (start+1)%100==0:
|
66 |
+
print("==>",start+1)
|
67 |
+
field_texts = order_texts[start: start + batch_size]
|
68 |
+
# print(field_texts)
|
69 |
+
field_texts = zip(*field_texts)
|
70 |
+
|
71 |
+
field_embeddings = []
|
72 |
+
for sentences in field_texts:
|
73 |
+
sentences = list(sentences)
|
74 |
+
# print(sentences)
|
75 |
+
if word_drop_ratio > 0:
|
76 |
+
print(f'Word drop with p={word_drop_ratio}')
|
77 |
+
new_sentences = []
|
78 |
+
for sent in sentences:
|
79 |
+
new_sent = []
|
80 |
+
sent = sent.split(' ')
|
81 |
+
for wd in sent:
|
82 |
+
rd = random.random()
|
83 |
+
if rd > word_drop_ratio:
|
84 |
+
new_sent.append(wd)
|
85 |
+
new_sent = ' '.join(new_sent)
|
86 |
+
new_sentences.append(new_sent)
|
87 |
+
sentences = new_sentences
|
88 |
+
encoded_sentences = tokenizer(sentences, max_length=args.max_sent_len,
|
89 |
+
truncation=True, return_tensors='pt',padding="longest").to(args.device)
|
90 |
+
outputs = model(input_ids=encoded_sentences.input_ids,
|
91 |
+
attention_mask=encoded_sentences.attention_mask)
|
92 |
+
|
93 |
+
masked_output = outputs.last_hidden_state * encoded_sentences['attention_mask'].unsqueeze(-1)
|
94 |
+
mean_output = masked_output.sum(dim=1) / encoded_sentences['attention_mask'].sum(dim=-1, keepdim=True)
|
95 |
+
mean_output = mean_output.detach().cpu()
|
96 |
+
emb_result.append(mean_output.numpy().tolist())
|
97 |
+
field_embeddings.append(mean_output)
|
98 |
+
|
99 |
+
field_mean_embedding = torch.stack(field_embeddings, dim=0).mean(dim=0)
|
100 |
+
embeddings.append(field_mean_embedding)
|
101 |
+
start += batch_size
|
102 |
+
|
103 |
+
embeddings = torch.cat(embeddings, dim=0).numpy()
|
104 |
+
print('Embeddings shape: ', embeddings.shape)
|
105 |
+
|
106 |
+
all_results = {
|
107 |
+
'text':[],
|
108 |
+
'node_type':[],
|
109 |
+
'emb':[]
|
110 |
+
}
|
111 |
+
|
112 |
+
all_results['text'] = [t[0] for t in texts]
|
113 |
+
all_results['node_type'] = [1] * len(all_results['text'])
|
114 |
+
for emb in emb_result:
|
115 |
+
str_emb = ''
|
116 |
+
for x in emb:
|
117 |
+
str_emb = str_emb + str(x) + ' '
|
118 |
+
all_results['emb'].append(str_emb[:-1])
|
119 |
+
|
120 |
+
import pandas as pd
|
121 |
+
df = pd.DataFrame(all_results)
|
122 |
+
# header = 0: w/o column name; index = False: w/o index column
|
123 |
+
df.to_csv(args.save_path, sep = '\t', header = 0, index = False)
|
124 |
+
|
125 |
+
# file = os.path.join(args.root, args.dataset + '.emb-' + args.plm_name + "-td" + ".npy")
|
126 |
+
# np.save(file, embeddings)
|
127 |
+
|
128 |
+
|
129 |
+
def parse_args():
|
130 |
+
parser = argparse.ArgumentParser()
|
131 |
+
parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games')
|
132 |
+
parser.add_argument('--root', type=str, default="")
|
133 |
+
parser.add_argument('--gpu_id', type=int, default=0, help='ID of running GPU')
|
134 |
+
parser.add_argument('--plm_name', type=str, default='llama')
|
135 |
+
parser.add_argument('--plm_checkpoint', type=str,
|
136 |
+
default='')
|
137 |
+
parser.add_argument('--max_sent_len', type=int, default=2048)
|
138 |
+
parser.add_argument('--word_drop_ratio', type=float, default=-1, help='word drop ratio, do not drop by default')
|
139 |
+
parser.add_argument('--data_path', type=str, default='')
|
140 |
+
parser.add_argument('--save_path', type=str, default='')
|
141 |
+
return parser.parse_args()
|
142 |
+
|
143 |
+
|
144 |
+
if __name__ == '__main__':
|
145 |
+
args = parse_args()
|
146 |
+
|
147 |
+
args.root = os.path.join(args.root, args.dataset)
|
148 |
+
|
149 |
+
device = set_device(args.gpu_id)
|
150 |
+
args.device = device
|
151 |
+
|
152 |
+
item_text_list = preprocess_text(args)
|
153 |
+
|
154 |
+
plm_tokenizer, plm_model = load_plm(args.plm_checkpoint)
|
155 |
+
if plm_tokenizer.pad_token_id is None:
|
156 |
+
plm_tokenizer.pad_token_id = 0
|
157 |
+
plm_model = plm_model.to(device)
|
158 |
+
|
159 |
+
generate_item_embedding(args, item_text_list, plm_tokenizer,
|
160 |
+
plm_model, word_drop_ratio = args.word_drop_ratio,
|
161 |
+
save_path = args.save_path)
|
data_process/get_llm_output.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import os.path as osp
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
from logging import getLogger
|
9 |
+
import openai
|
10 |
+
from utils import get_res_batch, load_json, intention_prompt, preference_prompt_1, preference_prompt_2, amazon18_dataset2fullname, write_json_file
|
11 |
+
import json
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
def get_intention_train(args, inters, item2feature, reviews, api_info):
|
16 |
+
|
17 |
+
intention_train_output_file = os.path.join(args.root,"intention_train.json")
|
18 |
+
|
19 |
+
|
20 |
+
# Suggest modifying the prompt based on different datasets
|
21 |
+
prompt = intention_prompt
|
22 |
+
dataset_full_name = amazon18_dataset2fullname[args.dataset]
|
23 |
+
dataset_full_name = dataset_full_name.replace("_", " ").lower()
|
24 |
+
print(dataset_full_name)
|
25 |
+
|
26 |
+
prompt_list = []
|
27 |
+
|
28 |
+
inter_data = []
|
29 |
+
|
30 |
+
for (user,item_list) in inters.items():
|
31 |
+
user = int(user)
|
32 |
+
item = int(item_list[-3])
|
33 |
+
history = item_list[:-3]
|
34 |
+
|
35 |
+
inter_data.append((user,item,history))
|
36 |
+
|
37 |
+
review = reviews[str((user, item))]["review"]
|
38 |
+
item_title = item2feature[str(item)]["title"]
|
39 |
+
input_prompt = prompt.format(item_title=item_title,dataset_full_name=dataset_full_name,review=review)
|
40 |
+
prompt_list.append(input_prompt)
|
41 |
+
|
42 |
+
st = 0
|
43 |
+
with open(intention_train_output_file, mode='a') as f:
|
44 |
+
|
45 |
+
while st < len(prompt_list):
|
46 |
+
# while st < 3:
|
47 |
+
print(st)
|
48 |
+
# if st < 25631:
|
49 |
+
# st += args.batchsize
|
50 |
+
# continue
|
51 |
+
|
52 |
+
|
53 |
+
res = get_res_batch(args.model_name, prompt_list[st:st+args.batchsize], args.max_tokens, api_info)
|
54 |
+
|
55 |
+
for i, answer in enumerate(res):
|
56 |
+
user, item, history = inter_data[st+i]
|
57 |
+
# print(answer)
|
58 |
+
# print("=============")
|
59 |
+
|
60 |
+
if answer == '':
|
61 |
+
print("answer null error")
|
62 |
+
answer = "I enjoy high-quality item."
|
63 |
+
|
64 |
+
if answer.strip().count('\n') != 1:
|
65 |
+
if 'haracteristics:' in answer:
|
66 |
+
answer = answer.strip().split("The item's characteristics:")
|
67 |
+
else:
|
68 |
+
answer = answer.strip().split("The item's characteristic:")
|
69 |
+
else:
|
70 |
+
answer = answer.strip().split('\n')
|
71 |
+
|
72 |
+
if '' in answer:
|
73 |
+
answer.remove('')
|
74 |
+
|
75 |
+
if len(answer) == 1:
|
76 |
+
print(answer)
|
77 |
+
user_preference = item_character = answer[0]
|
78 |
+
elif len(answer) >= 3:
|
79 |
+
print(answer)
|
80 |
+
answer = answer[-1]
|
81 |
+
user_preference = item_character = answer
|
82 |
+
else:
|
83 |
+
user_preference, item_character = answer
|
84 |
+
|
85 |
+
if ':' in user_preference:
|
86 |
+
idx = user_preference.index(':')
|
87 |
+
user_preference = user_preference[idx+1:]
|
88 |
+
user_preference = user_preference.strip().replace('}','')
|
89 |
+
user_preference = user_preference.replace('\n','')
|
90 |
+
|
91 |
+
if ':' in item_character:
|
92 |
+
idx = item_character.index(':')
|
93 |
+
item_character = item_character[idx+1:]
|
94 |
+
item_character = item_character.strip().replace('}','')
|
95 |
+
item_character = item_character.replace('\n','')
|
96 |
+
|
97 |
+
|
98 |
+
dict = {"user":user, "item":item, "inters": history,
|
99 |
+
"user_related_intention":user_preference, "item_related_intention": item_character}
|
100 |
+
|
101 |
+
json.dump(dict, f)
|
102 |
+
f.write("\n")
|
103 |
+
|
104 |
+
st += args.batchsize
|
105 |
+
|
106 |
+
return intention_train_output_file
|
107 |
+
|
108 |
+
|
109 |
+
def get_intention_test(args, inters, item2feature, reviews, api_info):
|
110 |
+
|
111 |
+
intention_test_output_file = os.path.join(args.root,"intention_test.json")
|
112 |
+
|
113 |
+
# Suggest modifying the prompt based on different datasets
|
114 |
+
prompt = intention_prompt
|
115 |
+
dataset_full_name = amazon18_dataset2fullname[args.dataset]
|
116 |
+
dataset_full_name = dataset_full_name.replace("_", " ").lower()
|
117 |
+
print(dataset_full_name)
|
118 |
+
|
119 |
+
prompt_list = []
|
120 |
+
|
121 |
+
inter_data = []
|
122 |
+
|
123 |
+
for (user,item_list) in inters.items():
|
124 |
+
user = int(user)
|
125 |
+
item = int(item_list[-1])
|
126 |
+
history = item_list[:-1]
|
127 |
+
|
128 |
+
inter_data.append((user,item,history))
|
129 |
+
|
130 |
+
review = reviews[str((user, item))]["review"]
|
131 |
+
item_title = item2feature[str(item)]["title"]
|
132 |
+
input_prompt = prompt.format(item_title=item_title,dataset_full_name=dataset_full_name,review=review)
|
133 |
+
prompt_list.append(input_prompt)
|
134 |
+
|
135 |
+
st = 0
|
136 |
+
with open(intention_test_output_file, mode='a') as f:
|
137 |
+
|
138 |
+
while st < len(prompt_list):
|
139 |
+
# while st < 3:
|
140 |
+
print(st)
|
141 |
+
# if st < 4623:
|
142 |
+
# st += args.batchsize
|
143 |
+
# continue
|
144 |
+
|
145 |
+
res = get_res_batch(args.model_name, prompt_list[st:st+args.batchsize], args.max_tokens, api_info)
|
146 |
+
|
147 |
+
for i, answer in enumerate(res):
|
148 |
+
user, item, history = inter_data[st+i]
|
149 |
+
|
150 |
+
if answer == '':
|
151 |
+
print("answer null error")
|
152 |
+
answer = "I enjoy high-quality item."
|
153 |
+
|
154 |
+
if answer.strip().count('\n') != 1:
|
155 |
+
if 'haracteristics:' in answer:
|
156 |
+
answer = answer.strip().split("The item's characteristics:")
|
157 |
+
else:
|
158 |
+
answer = answer.strip().split("The item's characteristic:")
|
159 |
+
else:
|
160 |
+
answer = answer.strip().split('\n')
|
161 |
+
|
162 |
+
if '' in answer:
|
163 |
+
answer.remove('')
|
164 |
+
|
165 |
+
if len(answer) == 1:
|
166 |
+
print(answer)
|
167 |
+
user_preference = item_character = answer[0]
|
168 |
+
elif len(answer) >= 3:
|
169 |
+
print(answer)
|
170 |
+
answer = answer[-1]
|
171 |
+
user_preference = item_character = answer
|
172 |
+
else:
|
173 |
+
user_preference, item_character = answer
|
174 |
+
|
175 |
+
if ':' in user_preference:
|
176 |
+
idx = user_preference.index(':')
|
177 |
+
user_preference = user_preference[idx+1:]
|
178 |
+
user_preference = user_preference.strip().replace('}','')
|
179 |
+
user_preference = user_preference.replace('\n','')
|
180 |
+
|
181 |
+
if ':' in item_character:
|
182 |
+
idx = item_character.index(':')
|
183 |
+
item_character = item_character[idx+1:]
|
184 |
+
item_character = item_character.strip().replace('}','')
|
185 |
+
item_character = item_character.replace('\n','')
|
186 |
+
|
187 |
+
|
188 |
+
dict = {"user":user, "item":item, "inters": history,
|
189 |
+
"user_related_intention":user_preference, "item_related_intention": item_character}
|
190 |
+
|
191 |
+
json.dump(dict, f)
|
192 |
+
f.write("\n")
|
193 |
+
|
194 |
+
st += args.batchsize
|
195 |
+
|
196 |
+
return intention_test_output_file
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
def get_user_preference(args, inters, item2feature, reviews, api_info):
|
202 |
+
|
203 |
+
preference_output_file = os.path.join(args.root,"user_preference.json")
|
204 |
+
|
205 |
+
|
206 |
+
# Suggest modifying the prompt based on different datasets
|
207 |
+
prompt_1 = preference_prompt_1
|
208 |
+
prompt_2 = preference_prompt_2
|
209 |
+
|
210 |
+
|
211 |
+
dataset_full_name = amazon18_dataset2fullname[args.dataset]
|
212 |
+
dataset_full_name = dataset_full_name.replace("_", " ").lower()
|
213 |
+
print(dataset_full_name)
|
214 |
+
|
215 |
+
prompt_list_1 = []
|
216 |
+
prompt_list_2 = []
|
217 |
+
|
218 |
+
users = []
|
219 |
+
|
220 |
+
for (user,item_list) in inters.items():
|
221 |
+
users.append(user)
|
222 |
+
history = item_list[:-3]
|
223 |
+
item_titles = []
|
224 |
+
for j, item in enumerate(history):
|
225 |
+
item_titles.append(str(j+1) + '.' + item2feature[str(item)]["title"])
|
226 |
+
if len(item_titles) > args.max_his_len:
|
227 |
+
item_titles = item_titles[-args.max_his_len:]
|
228 |
+
item_titles = ", ".join(item_titles)
|
229 |
+
|
230 |
+
input_prompt_1 = prompt_1.format(dataset_full_name=dataset_full_name, item_titles=item_titles)
|
231 |
+
input_prompt_2 = prompt_2.format(dataset_full_name=dataset_full_name, item_titles=item_titles)
|
232 |
+
|
233 |
+
prompt_list_1.append(input_prompt_1)
|
234 |
+
prompt_list_2.append(input_prompt_2)
|
235 |
+
|
236 |
+
|
237 |
+
st = 0
|
238 |
+
with open(preference_output_file, mode='a') as f:
|
239 |
+
|
240 |
+
while st < len(prompt_list_1):
|
241 |
+
# while st < 3:
|
242 |
+
print(st)
|
243 |
+
# if st < 22895:
|
244 |
+
# st += args.batchsize
|
245 |
+
# continue
|
246 |
+
|
247 |
+
res_1 = get_res_batch(args.model_name, prompt_list_1[st:st + args.batchsize], args.max_tokens, api_info)
|
248 |
+
res_2 = get_res_batch(args.model_name, prompt_list_2[st:st + args.batchsize], args.max_tokens, api_info)
|
249 |
+
for i, answers in enumerate(zip(res_1, res_2)):
|
250 |
+
|
251 |
+
user = users[st + i]
|
252 |
+
|
253 |
+
answer_1, answer_2 = answers
|
254 |
+
# print(answers)
|
255 |
+
# print("=============")
|
256 |
+
|
257 |
+
if answer_1 == '':
|
258 |
+
print("answer null error")
|
259 |
+
answer_1 = "I enjoy high-quality item."
|
260 |
+
|
261 |
+
if answer_2 == '':
|
262 |
+
print("answer null error")
|
263 |
+
answer_2 = "I enjoy high-quality item."
|
264 |
+
|
265 |
+
if answer_2.strip().count('\n') != 1:
|
266 |
+
if 'references:' in answer_2:
|
267 |
+
answer_2 = answer_2.strip().split("Short-term preferences:")
|
268 |
+
else:
|
269 |
+
answer_2 = answer_2.strip().split("Short-term preference:")
|
270 |
+
else:
|
271 |
+
answer_2 = answer_2.strip().split('\n')
|
272 |
+
|
273 |
+
if '' in answer_2:
|
274 |
+
answer_2.remove('')
|
275 |
+
|
276 |
+
if len(answer_2) == 1:
|
277 |
+
print(answer_2)
|
278 |
+
long_preference = short_preference = answer_2[0]
|
279 |
+
elif len(answer_2) >= 3:
|
280 |
+
print(answer_2)
|
281 |
+
answer_2 = answer_2[-1]
|
282 |
+
long_preference = short_preference = answer_2
|
283 |
+
else:
|
284 |
+
long_preference, short_preference = answer_2
|
285 |
+
|
286 |
+
if ':' in long_preference:
|
287 |
+
idx = long_preference.index(':')
|
288 |
+
long_preference = long_preference[idx+1:]
|
289 |
+
long_preference = long_preference.strip().replace('}','')
|
290 |
+
long_preference = long_preference.replace('\n','')
|
291 |
+
|
292 |
+
if ':' in short_preference:
|
293 |
+
idx = short_preference.index(':')
|
294 |
+
short_preference = short_preference[idx+1:]
|
295 |
+
short_preference = short_preference.strip().replace('}','')
|
296 |
+
short_preference = short_preference.replace('\n','')
|
297 |
+
|
298 |
+
dict = {"user":user,"user_preference":[answer_1, long_preference, short_preference]}
|
299 |
+
# print(dict)
|
300 |
+
json.dump(dict, f)
|
301 |
+
f.write("\n")
|
302 |
+
|
303 |
+
st += args.batchsize
|
304 |
+
|
305 |
+
return preference_output_file
|
306 |
+
|
307 |
+
def parse_args():
|
308 |
+
parser = argparse.ArgumentParser()
|
309 |
+
parser.add_argument('--dataset', type=str, default='Instruments', help='Instruments / Arts / Games')
|
310 |
+
parser.add_argument('--root', type=str, default='')
|
311 |
+
parser.add_argument('--api_info', type=str, default='./api_info.json')
|
312 |
+
parser.add_argument('--model_name', type=str, default='text-davinci-003')
|
313 |
+
parser.add_argument('--max_tokens', type=int, default=512)
|
314 |
+
parser.add_argument('--batchsize', type=int, default=16)
|
315 |
+
parser.add_argument('--max_his_len', type=int, default=20)
|
316 |
+
return parser.parse_args()
|
317 |
+
|
318 |
+
if __name__ == "__main__":
|
319 |
+
args = parse_args()
|
320 |
+
|
321 |
+
args.root = os.path.join(args.root, args.dataset)
|
322 |
+
|
323 |
+
api_info = load_json(args.api_info)
|
324 |
+
openai.api_key = api_info["api_key_list"].pop()
|
325 |
+
|
326 |
+
|
327 |
+
inter_path = os.path.join(args.root, f'{args.dataset}.inter.json')
|
328 |
+
inters = load_json(inter_path)
|
329 |
+
|
330 |
+
|
331 |
+
item2feature_path = os.path.join(args.root, f'{args.dataset}.item.json')
|
332 |
+
item2feature = load_json(item2feature_path)
|
333 |
+
|
334 |
+
reviews_path = os.path.join(args.root, f'{args.dataset}.review.json')
|
335 |
+
reviews = load_json(reviews_path)
|
336 |
+
|
337 |
+
intention_train_output_file = get_intention_train(args, inters, item2feature, reviews, api_info)
|
338 |
+
intention_test_output_file = get_intention_test(args, inters, item2feature, reviews ,api_info)
|
339 |
+
preference_output_file = get_user_preference(args, inters, item2feature, reviews, api_info)
|
340 |
+
|
341 |
+
intention_train = {}
|
342 |
+
intention_test = {}
|
343 |
+
user_preference = {}
|
344 |
+
|
345 |
+
with open(intention_train_output_file, "r") as f:
|
346 |
+
for line in f:
|
347 |
+
# print(line)
|
348 |
+
content = json.loads(line)
|
349 |
+
if content["user"] not in intention_train:
|
350 |
+
intention_train[content["user"]] = {"item":content["item"],
|
351 |
+
"inters":content["inters"],
|
352 |
+
"querys":[ content["user_related_intention"], content["item_related_intention"] ]}
|
353 |
+
|
354 |
+
|
355 |
+
with open(intention_test_output_file, "r") as f:
|
356 |
+
for line in f:
|
357 |
+
content = json.loads(line)
|
358 |
+
if content["user"] not in intention_train:
|
359 |
+
intention_test[content["user"]] = {"item":content["item"],
|
360 |
+
"inters":content["inters"],
|
361 |
+
"querys":[ content["user_related_intention"], content["item_related_intention"] ]}
|
362 |
+
|
363 |
+
|
364 |
+
with open(preference_output_file, "r") as f:
|
365 |
+
for line in f:
|
366 |
+
content = json.loads(line)
|
367 |
+
user_preference[content["user"]] = content["user_preference"]
|
368 |
+
|
369 |
+
user_dict = {
|
370 |
+
"user_explicit_preference": user_preference,
|
371 |
+
"user_vague_intention": {"train": intention_train, "test": intention_test},
|
372 |
+
}
|
373 |
+
|
374 |
+
write_json_file(user_dict, os.path.join(args.root, f'{args.dataset}.user.json'))
|
data_process/utils.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
import re
|
6 |
+
import time
|
7 |
+
|
8 |
+
import torch
|
9 |
+
# import gensim
|
10 |
+
from transformers import AutoModel, AutoTokenizer
|
11 |
+
import collections
|
12 |
+
import openai
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
def get_res_batch(model_name, prompt_list, max_tokens, api_info):
|
17 |
+
|
18 |
+
while True:
|
19 |
+
try:
|
20 |
+
res = openai.Completion.create(
|
21 |
+
model=model_name,
|
22 |
+
prompt=prompt_list,
|
23 |
+
temperature=0.4,
|
24 |
+
max_tokens=max_tokens,
|
25 |
+
top_p=1,
|
26 |
+
frequency_penalty=0,
|
27 |
+
presence_penalty=0
|
28 |
+
)
|
29 |
+
output_list = []
|
30 |
+
for choice in res['choices']:
|
31 |
+
output = choice['text'].strip()
|
32 |
+
output_list.append(output)
|
33 |
+
|
34 |
+
return output_list
|
35 |
+
|
36 |
+
except openai.error.AuthenticationError as e:
|
37 |
+
print(e)
|
38 |
+
openai.api_key = api_info["api_key_list"].pop()
|
39 |
+
time.sleep(10)
|
40 |
+
except openai.error.RateLimitError as e:
|
41 |
+
print(e)
|
42 |
+
if str(e) == "You exceeded your current quota, please check your plan and billing details.":
|
43 |
+
openai.api_key = api_info["api_key_list"].pop()
|
44 |
+
time.sleep(10)
|
45 |
+
else:
|
46 |
+
print('\nopenai.error.RateLimitError\nRetrying...')
|
47 |
+
time.sleep(10)
|
48 |
+
except openai.error.ServiceUnavailableError as e:
|
49 |
+
print(e)
|
50 |
+
print('\nopenai.error.ServiceUnavailableError\nRetrying...')
|
51 |
+
time.sleep(10)
|
52 |
+
except openai.error.Timeout:
|
53 |
+
print('\nopenai.error.Timeout\nRetrying...')
|
54 |
+
time.sleep(10)
|
55 |
+
except openai.error.APIError as e:
|
56 |
+
print(e)
|
57 |
+
print('\nopenai.error.APIError\nRetrying...')
|
58 |
+
time.sleep(10)
|
59 |
+
except openai.error.APIConnectionError as e:
|
60 |
+
print(e)
|
61 |
+
print('\nopenai.error.APIConnectionError\nRetrying...')
|
62 |
+
time.sleep(10)
|
63 |
+
except Exception as e:
|
64 |
+
print(e)
|
65 |
+
return None
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
def check_path(path):
|
71 |
+
if not os.path.exists(path):
|
72 |
+
os.makedirs(path)
|
73 |
+
|
74 |
+
|
75 |
+
def set_device(gpu_id):
|
76 |
+
if gpu_id == -1:
|
77 |
+
return torch.device('cpu')
|
78 |
+
else:
|
79 |
+
return torch.device(
|
80 |
+
'cuda:' + str(gpu_id) if torch.cuda.is_available() else 'cpu')
|
81 |
+
|
82 |
+
def load_plm(model_path='bert-base-uncased'):
|
83 |
+
|
84 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path,)
|
85 |
+
|
86 |
+
print("Load Model:", model_path)
|
87 |
+
|
88 |
+
model = AutoModel.from_pretrained(model_path,low_cpu_mem_usage=True,)
|
89 |
+
return tokenizer, model
|
90 |
+
|
91 |
+
def load_json(file):
|
92 |
+
with open(file, 'r') as f:
|
93 |
+
data = json.load(f)
|
94 |
+
return data
|
95 |
+
|
96 |
+
def clean_text(raw_text):
|
97 |
+
if isinstance(raw_text, list):
|
98 |
+
new_raw_text=[]
|
99 |
+
for raw in raw_text:
|
100 |
+
raw = html.unescape(raw)
|
101 |
+
raw = re.sub(r'</?\w+[^>]*>', '', raw)
|
102 |
+
raw = re.sub(r'["\n\r]*', '', raw)
|
103 |
+
new_raw_text.append(raw.strip())
|
104 |
+
cleaned_text = ' '.join(new_raw_text)
|
105 |
+
else:
|
106 |
+
if isinstance(raw_text, dict):
|
107 |
+
cleaned_text = str(raw_text)[1:-1].strip()
|
108 |
+
else:
|
109 |
+
cleaned_text = raw_text.strip()
|
110 |
+
cleaned_text = html.unescape(cleaned_text)
|
111 |
+
cleaned_text = re.sub(r'</?\w+[^>]*>', '', cleaned_text)
|
112 |
+
cleaned_text = re.sub(r'["\n\r]*', '', cleaned_text)
|
113 |
+
index = -1
|
114 |
+
while -index < len(cleaned_text) and cleaned_text[index] == '.':
|
115 |
+
index -= 1
|
116 |
+
index += 1
|
117 |
+
if index == 0:
|
118 |
+
cleaned_text = cleaned_text + '.'
|
119 |
+
else:
|
120 |
+
cleaned_text = cleaned_text[:index] + '.'
|
121 |
+
if len(cleaned_text) >= 2000:
|
122 |
+
cleaned_text = ''
|
123 |
+
return cleaned_text
|
124 |
+
|
125 |
+
def load_pickle(filename):
|
126 |
+
with open(filename, "rb") as f:
|
127 |
+
return pickle.load(f)
|
128 |
+
|
129 |
+
|
130 |
+
def make_inters_in_order(inters):
|
131 |
+
user2inters, new_inters = collections.defaultdict(list), list()
|
132 |
+
for inter in inters:
|
133 |
+
user, item, rating, timestamp = inter
|
134 |
+
user2inters[user].append((user, item, rating, timestamp))
|
135 |
+
for user in user2inters:
|
136 |
+
user_inters = user2inters[user]
|
137 |
+
user_inters.sort(key=lambda d: d[3])
|
138 |
+
for inter in user_inters:
|
139 |
+
new_inters.append(inter)
|
140 |
+
return new_inters
|
141 |
+
|
142 |
+
def write_json_file(dic, file):
|
143 |
+
print('Writing json file: ',file)
|
144 |
+
with open(file, 'w') as fp:
|
145 |
+
json.dump(dic, fp, indent=4)
|
146 |
+
|
147 |
+
def write_remap_index(unit2index, file):
|
148 |
+
print('Writing remap file: ',file)
|
149 |
+
with open(file, 'w') as fp:
|
150 |
+
for unit in unit2index:
|
151 |
+
fp.write(unit + '\t' + str(unit2index[unit]) + '\n')
|
152 |
+
|
153 |
+
|
154 |
+
intention_prompt = "After purchasing a {dataset_full_name} item named \"{item_title}\", the user left a comment expressing his opinion and personal preferences. The user's comment is as follows: \n\"{review}\" " \
|
155 |
+
"\nAs we all know, user comments often contain information about both their personal preferences and the characteristics of the item they interacted with. From this comment, you can infer both the user's personal preferences and the characteristics of the item. " \
|
156 |
+
"Please describe your inferred user preferences and item characteristics in the first person and in the following format:\n\nMy preferences: []\nThe item's characteristics: []\n\n" \
|
157 |
+
"Note that your inference of the personalized preferences should not include any information about the title of the item."
|
158 |
+
|
159 |
+
|
160 |
+
preference_prompt_1 = "Suppose the user has bought a variety of {dataset_full_name} items, they are: \n{item_titles}. \nAs we all know, these historically purchased items serve as a reflection of the user's personalized preferences. " \
|
161 |
+
"Please analyze the user's personalized preferences based on the items he has bought and provide a brief third-person summary of the user's preferences, highlighting the key factors that influence his choice of items. Avoid listing specific items and do not list multiple examples. " \
|
162 |
+
"Your analysis should be brief and in the third person."
|
163 |
+
|
164 |
+
preference_prompt_2 = "Given a chronological list of {dataset_full_name} items that a user has purchased, we can analyze his long-term and short-term preferences. Long-term preferences are inherent characteristics of the user, which are reflected in all the items he has interacted with over time. Short-term preferences are the user's recent preferences, which are reflected in some of the items he has bought more recently. " \
|
165 |
+
"To determine the user's long-term preferences, please analyze the contents of all the items he has bought. Look for common features that appear frequently across the user's shopping records. To determine the user's short-term preferences, focus on the items he has bought most recently. Identify any new or different features that have emerged in the user's shopping records. " \
|
166 |
+
"Here is a chronological list of items that the user has bought: \n{item_titles}. \nPlease provide separate analyses for the user's long-term and short-term preferences. Your answer should be concise and general, without listing specific items. Your answer should be in the third person and in the following format:\n\nLong-term preferences: []\nShort-term preferences: []\n\n"
|
167 |
+
|
168 |
+
|
169 |
+
# remove 'Magazine', 'Gift', 'Music', 'Kindle'
|
170 |
+
amazon18_dataset_list = [
|
171 |
+
'Appliances', 'Beauty',
|
172 |
+
'Fashion', 'Software', 'Luxury', 'Scientific', 'Pantry',
|
173 |
+
'Instruments', 'Arts', 'Games', 'Office', 'Garden',
|
174 |
+
'Food', 'Cell', 'CDs', 'Automotive', 'Toys',
|
175 |
+
'Pet', 'Tools', 'Kindle', 'Sports', 'Movies',
|
176 |
+
'Electronics', 'Home', 'Clothing', 'Books'
|
177 |
+
]
|
178 |
+
|
179 |
+
amazon18_dataset2fullname = {
|
180 |
+
'Beauty': 'All_Beauty',
|
181 |
+
'Fashion': 'AMAZON_FASHION',
|
182 |
+
'Appliances': 'Appliances',
|
183 |
+
'Arts': 'Arts_Crafts_and_Sewing',
|
184 |
+
'Automotive': 'Automotive',
|
185 |
+
'Books': 'Books',
|
186 |
+
'CDs': 'CDs_and_Vinyl',
|
187 |
+
'Cell': 'Cell_Phones_and_Accessories',
|
188 |
+
'Clothing': 'Clothing_Shoes_and_Jewelry',
|
189 |
+
'Music': 'Digital_Music',
|
190 |
+
'Electronics': 'Electronics',
|
191 |
+
'Gift': 'Gift_Cards',
|
192 |
+
'Food': 'Grocery_and_Gourmet_Food',
|
193 |
+
'Home': 'Home_and_Kitchen',
|
194 |
+
'Scientific': 'Industrial_and_Scientific',
|
195 |
+
'Kindle': 'Kindle_Store',
|
196 |
+
'Luxury': 'Luxury_Beauty',
|
197 |
+
'Magazine': 'Magazine_Subscriptions',
|
198 |
+
'Movies': 'Movies_and_TV',
|
199 |
+
'Instruments': 'Musical_Instruments',
|
200 |
+
'Office': 'Office_Products',
|
201 |
+
'Garden': 'Patio_Lawn_and_Garden',
|
202 |
+
'Pet': 'Pet_Supplies',
|
203 |
+
'Pantry': 'Prime_Pantry',
|
204 |
+
'Software': 'Software',
|
205 |
+
'Sports': 'Sports_and_Outdoors',
|
206 |
+
'Tools': 'Tools_and_Home_Improvement',
|
207 |
+
'Toys': 'Toys_and_Games',
|
208 |
+
'Games': 'Video_Games'
|
209 |
+
}
|
210 |
+
|
211 |
+
amazon14_dataset_list = [
|
212 |
+
'Beauty','Toys','Sports'
|
213 |
+
]
|
214 |
+
|
215 |
+
amazon14_dataset2fullname = {
|
216 |
+
'Beauty': 'Beauty',
|
217 |
+
'Sports': 'Sports_and_Outdoors',
|
218 |
+
'Toys': 'Toys_and_Games',
|
219 |
+
}
|
220 |
+
|
221 |
+
# c1. c2. c3. c4.
|
222 |
+
amazon_text_feature1 = ['title', 'category', 'brand']
|
223 |
+
|
224 |
+
# re-order
|
225 |
+
amazon_text_feature1_ro1 = ['brand', 'main_cat', 'category', 'title']
|
226 |
+
|
227 |
+
# remove
|
228 |
+
amazon_text_feature1_re1 = ['title']
|
229 |
+
|
230 |
+
amazon_text_feature2 = ['title']
|
231 |
+
|
232 |
+
amazon_text_feature3 = ['description']
|
233 |
+
|
234 |
+
amazon_text_feature4 = ['description', 'main_cat', 'category', 'brand']
|
235 |
+
|
236 |
+
amazon_text_feature5 = ['title', 'description']
|
237 |
+
|
238 |
+
|
evaluate.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
def get_topk_results(predictions, scores, targets, k, all_items=None):
|
4 |
+
results = []
|
5 |
+
B = len(targets)
|
6 |
+
predictions = [_.split("Response:")[-1] for _ in predictions]
|
7 |
+
predictions = [_.strip().replace(" ","") for _ in predictions]
|
8 |
+
|
9 |
+
if all_items is not None:
|
10 |
+
for i, seq in enumerate(predictions):
|
11 |
+
if seq not in all_items:
|
12 |
+
scores[i] = -1000
|
13 |
+
|
14 |
+
for b in range(B):
|
15 |
+
batch_seqs = predictions[b * k: (b + 1) * k]
|
16 |
+
batch_scores = scores[b * k: (b + 1) * k]
|
17 |
+
|
18 |
+
pairs = [(a, b) for a, b in zip(batch_seqs, batch_scores)]
|
19 |
+
sorted_pairs = sorted(pairs, key=lambda x: x[1], reverse=True)
|
20 |
+
target_item = targets[b]
|
21 |
+
one_results = []
|
22 |
+
for sorted_pred in sorted_pairs:
|
23 |
+
if sorted_pred[0] == target_item:
|
24 |
+
one_results.append(1)
|
25 |
+
else:
|
26 |
+
one_results.append(0)
|
27 |
+
|
28 |
+
results.append(one_results)
|
29 |
+
|
30 |
+
return results
|
31 |
+
|
32 |
+
def get_metrics_results(topk_results, metrics):
|
33 |
+
res = {}
|
34 |
+
for m in metrics:
|
35 |
+
if m.lower().startswith("hit"):
|
36 |
+
k = int(m.split("@")[1])
|
37 |
+
res[m] = hit_k(topk_results, k)
|
38 |
+
elif m.lower().startswith("ndcg"):
|
39 |
+
k = int(m.split("@")[1])
|
40 |
+
res[m] = ndcg_k(topk_results, k)
|
41 |
+
else:
|
42 |
+
raise NotImplementedError
|
43 |
+
|
44 |
+
return res
|
45 |
+
|
46 |
+
|
47 |
+
def ndcg_k(topk_results, k):
|
48 |
+
|
49 |
+
ndcg = 0.0
|
50 |
+
for row in topk_results:
|
51 |
+
res = row[:k]
|
52 |
+
one_ndcg = 0.0
|
53 |
+
for i in range(len(res)):
|
54 |
+
one_ndcg += res[i] / math.log(i + 2, 2)
|
55 |
+
ndcg += one_ndcg
|
56 |
+
return ndcg
|
57 |
+
|
58 |
+
|
59 |
+
def hit_k(topk_results, k):
|
60 |
+
hit = 0.0
|
61 |
+
for row in topk_results:
|
62 |
+
res = row[:k]
|
63 |
+
if sum(res) > 0:
|
64 |
+
hit += 1
|
65 |
+
return hit
|
66 |
+
|
finetune.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import sys
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import transformers
|
9 |
+
|
10 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
|
11 |
+
|
12 |
+
from utils import *
|
13 |
+
from collator import Collator
|
14 |
+
|
15 |
+
def train(args):
|
16 |
+
|
17 |
+
set_seed(args.seed)
|
18 |
+
ensure_dir(args.output_dir)
|
19 |
+
|
20 |
+
device_map = "auto"
|
21 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
22 |
+
ddp = world_size != 1
|
23 |
+
local_rank = int(os.environ.get("LOCAL_RANK") or 0)
|
24 |
+
if local_rank == 0:
|
25 |
+
print(vars(args))
|
26 |
+
|
27 |
+
if ddp:
|
28 |
+
device_map = {"": local_rank}
|
29 |
+
|
30 |
+
config = LlamaConfig.from_pretrained(args.base_model)
|
31 |
+
tokenizer = LlamaTokenizer.from_pretrained(
|
32 |
+
args.base_model,
|
33 |
+
model_max_length = args.model_max_length,
|
34 |
+
padding_side="right",
|
35 |
+
)
|
36 |
+
tokenizer.pad_token_id = 0
|
37 |
+
gradient_checkpointing = True
|
38 |
+
|
39 |
+
train_data, valid_data = load_datasets(args)
|
40 |
+
add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
|
41 |
+
config.vocab_size = len(tokenizer)
|
42 |
+
if local_rank == 0:
|
43 |
+
print("add {} new token.".format(add_num))
|
44 |
+
print("data num:", len(train_data))
|
45 |
+
tokenizer.save_pretrained(args.output_dir)
|
46 |
+
config.save_pretrained(args.output_dir)
|
47 |
+
|
48 |
+
collator = Collator(args, tokenizer)
|
49 |
+
|
50 |
+
|
51 |
+
model = LlamaForCausalLM.from_pretrained(
|
52 |
+
args.base_model,
|
53 |
+
# torch_dtype=torch.float16,
|
54 |
+
device_map=device_map,
|
55 |
+
)
|
56 |
+
model.resize_token_embeddings(len(tokenizer))
|
57 |
+
|
58 |
+
|
59 |
+
if not ddp and torch.cuda.device_count() > 1:
|
60 |
+
model.is_parallelizable = True
|
61 |
+
model.model_parallel = True
|
62 |
+
|
63 |
+
|
64 |
+
trainer = transformers.Trainer(
|
65 |
+
model=model,
|
66 |
+
train_dataset=train_data,
|
67 |
+
eval_dataset=valid_data,
|
68 |
+
args=transformers.TrainingArguments(
|
69 |
+
seed=args.seed,
|
70 |
+
per_device_train_batch_size=args.per_device_batch_size,
|
71 |
+
per_device_eval_batch_size=args.per_device_batch_size,
|
72 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
73 |
+
warmup_ratio=args.warmup_ratio,
|
74 |
+
num_train_epochs=args.epochs,
|
75 |
+
learning_rate=args.learning_rate,
|
76 |
+
weight_decay=args.weight_decay,
|
77 |
+
lr_scheduler_type=args.lr_scheduler_type,
|
78 |
+
fp16=args.fp16,
|
79 |
+
bf16=args.bf16,
|
80 |
+
logging_steps=args.logging_step,
|
81 |
+
optim=args.optim,
|
82 |
+
gradient_checkpointing=gradient_checkpointing,
|
83 |
+
evaluation_strategy=args.save_and_eval_strategy,
|
84 |
+
save_strategy=args.save_and_eval_strategy,
|
85 |
+
eval_steps=args.save_and_eval_steps,
|
86 |
+
save_steps=args.save_and_eval_steps,
|
87 |
+
output_dir=args.output_dir,
|
88 |
+
save_total_limit=5,
|
89 |
+
load_best_model_at_end=True,
|
90 |
+
deepspeed=args.deepspeed,
|
91 |
+
ddp_find_unused_parameters=False if ddp else None,
|
92 |
+
report_to=None,
|
93 |
+
eval_delay= 1 if args.save_and_eval_strategy=="epoch" else 2000,
|
94 |
+
dataloader_num_workers = args.dataloader_num_workers,
|
95 |
+
dataloader_prefetch_factor = args.dataloader_prefetch_factor
|
96 |
+
),
|
97 |
+
tokenizer=tokenizer,
|
98 |
+
data_collator=collator,
|
99 |
+
)
|
100 |
+
model.config.use_cache = False
|
101 |
+
|
102 |
+
|
103 |
+
trainer.train(
|
104 |
+
resume_from_checkpoint=args.resume_from_checkpoint,
|
105 |
+
)
|
106 |
+
|
107 |
+
trainer.save_state()
|
108 |
+
trainer.save_model(output_dir=args.output_dir)
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
parser = argparse.ArgumentParser(description='LLMRec')
|
115 |
+
parser = parse_global_args(parser)
|
116 |
+
parser = parse_train_args(parser)
|
117 |
+
parser = parse_dataset_args(parser)
|
118 |
+
|
119 |
+
args = parser.parse_args()
|
120 |
+
|
121 |
+
train(args)
|
index/datasets.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.utils.data as data
|
4 |
+
|
5 |
+
|
6 |
+
class EmbDataset(data.Dataset):
|
7 |
+
|
8 |
+
def __init__(self,data_path):
|
9 |
+
|
10 |
+
self.data_path = data_path
|
11 |
+
# self.embeddings = np.fromfile(data_path, dtype=np.float32).reshape(16859,-1)
|
12 |
+
self.embeddings = np.load(data_path)
|
13 |
+
self.dim = self.embeddings.shape[-1]
|
14 |
+
|
15 |
+
def __getitem__(self, index):
|
16 |
+
emb = self.embeddings[index]
|
17 |
+
tensor_emb=torch.FloatTensor(emb)
|
18 |
+
return tensor_emb
|
19 |
+
|
20 |
+
def __len__(self):
|
21 |
+
return len(self.embeddings)
|
index/generate_indices.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from time import time
|
8 |
+
from torch import optim
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
|
13 |
+
from datasets import EmbDataset
|
14 |
+
from models.rqvae import RQVAE
|
15 |
+
|
16 |
+
import os
|
17 |
+
|
18 |
+
def check_collision(all_indices_str):
|
19 |
+
tot_item = len(all_indices_str)
|
20 |
+
tot_indice = len(set(all_indices_str.tolist()))
|
21 |
+
return tot_item==tot_indice
|
22 |
+
|
23 |
+
def get_indices_count(all_indices_str):
|
24 |
+
indices_count = collections.defaultdict(int)
|
25 |
+
for index in all_indices_str:
|
26 |
+
indices_count[index] += 1
|
27 |
+
return indices_count
|
28 |
+
|
29 |
+
def get_collision_item(all_indices_str):
|
30 |
+
index2id = {}
|
31 |
+
for i, index in enumerate(all_indices_str):
|
32 |
+
if index not in index2id:
|
33 |
+
index2id[index] = []
|
34 |
+
index2id[index].append(i)
|
35 |
+
|
36 |
+
collision_item_groups = []
|
37 |
+
|
38 |
+
for index in index2id:
|
39 |
+
if len(index2id[index]) > 1:
|
40 |
+
collision_item_groups.append(index2id[index])
|
41 |
+
|
42 |
+
return collision_item_groups
|
43 |
+
|
44 |
+
def parse_args():
|
45 |
+
parser = argparse.ArgumentParser(description = "Index")
|
46 |
+
parser.add_argument("--ckpt_path", type = str, default = "", help = "")
|
47 |
+
parser.add_argument("--data_path", type = str, default = "", help = "")
|
48 |
+
parser.add_argument("--save_path", type = str, default = "", help = "")
|
49 |
+
parser.add_argument("--device", type = str, default = "cuda:0", help = "gpu or cpu")
|
50 |
+
return parser.parse_args()
|
51 |
+
|
52 |
+
infer_args = parse_args()
|
53 |
+
print(infer_args)
|
54 |
+
|
55 |
+
# dataset = "Games"
|
56 |
+
# ckpt_path = "/zhengbowen/rqvae_ckpt/xxxx"
|
57 |
+
# output_dir = f"/zhengbowen/data/{dataset}/"
|
58 |
+
# output_file = f"{dataset}.index.json"
|
59 |
+
# output_file = os.path.join(output_dir,output_file)
|
60 |
+
# device = torch.device("cuda:1")
|
61 |
+
|
62 |
+
device = torch.device(infer_args.device)
|
63 |
+
|
64 |
+
ckpt = torch.load(infer_args.ckpt_path, map_location = torch.device('cpu'))
|
65 |
+
args = ckpt["args"]
|
66 |
+
state_dict = ckpt["state_dict"]
|
67 |
+
|
68 |
+
data = EmbDataset(infer_args.data_path)
|
69 |
+
|
70 |
+
model = RQVAE(in_dim=data.dim,
|
71 |
+
num_emb_list=args.num_emb_list,
|
72 |
+
e_dim=args.e_dim,
|
73 |
+
layers=args.layers,
|
74 |
+
dropout_prob=args.dropout_prob,
|
75 |
+
bn=args.bn,
|
76 |
+
loss_type=args.loss_type,
|
77 |
+
quant_loss_weight=args.quant_loss_weight,
|
78 |
+
kmeans_init=args.kmeans_init,
|
79 |
+
kmeans_iters=args.kmeans_iters,
|
80 |
+
sk_epsilons=args.sk_epsilons,
|
81 |
+
sk_iters=args.sk_iters,
|
82 |
+
)
|
83 |
+
|
84 |
+
model.load_state_dict(state_dict)
|
85 |
+
model = model.to(device)
|
86 |
+
model.eval()
|
87 |
+
print(model)
|
88 |
+
|
89 |
+
data_loader = DataLoader(data,num_workers=args.num_workers,
|
90 |
+
batch_size=64, shuffle=False,
|
91 |
+
pin_memory=True)
|
92 |
+
|
93 |
+
all_indices = []
|
94 |
+
all_indices_str = []
|
95 |
+
prefix = ["<a_{}>","<b_{}>","<c_{}>","<d_{}>","<e_{}>"]
|
96 |
+
|
97 |
+
for d in tqdm(data_loader):
|
98 |
+
d = d.to(device)
|
99 |
+
indices = model.get_indices(d,use_sk=False)
|
100 |
+
indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
|
101 |
+
for index in indices:
|
102 |
+
code = []
|
103 |
+
for i, ind in enumerate(index):
|
104 |
+
code.append(prefix[i].format(int(ind)))
|
105 |
+
|
106 |
+
all_indices.append(code)
|
107 |
+
all_indices_str.append(str(code))
|
108 |
+
# break
|
109 |
+
|
110 |
+
all_indices = np.array(all_indices)
|
111 |
+
all_indices_str = np.array(all_indices_str)
|
112 |
+
|
113 |
+
for vq in model.rq.vq_layers[:-1]:
|
114 |
+
vq.sk_epsilon=0.0
|
115 |
+
|
116 |
+
if model.rq.vq_layers[-1].sk_epsilon == 0.0:
|
117 |
+
model.rq.vq_layers[-1].sk_epsilon = 0.003
|
118 |
+
|
119 |
+
tt = 0
|
120 |
+
#There are often duplicate items in the dataset, and we no longer differentiate them
|
121 |
+
while True:
|
122 |
+
if tt >= 10 or check_collision(all_indices_str):
|
123 |
+
break
|
124 |
+
|
125 |
+
collision_item_groups = get_collision_item(all_indices_str)
|
126 |
+
# print(collision_item_groups)
|
127 |
+
print(len(collision_item_groups))
|
128 |
+
for collision_items in collision_item_groups:
|
129 |
+
d = data[collision_items].to(device)
|
130 |
+
|
131 |
+
indices = model.get_indices(d, use_sk=True)
|
132 |
+
indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
|
133 |
+
for item, index in zip(collision_items, indices):
|
134 |
+
code = []
|
135 |
+
for i, ind in enumerate(index):
|
136 |
+
code.append(prefix[i].format(int(ind)))
|
137 |
+
|
138 |
+
all_indices[item] = code
|
139 |
+
all_indices_str[item] = str(code)
|
140 |
+
tt += 1
|
141 |
+
|
142 |
+
|
143 |
+
print("All indices number: ",len(all_indices))
|
144 |
+
print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values()))
|
145 |
+
|
146 |
+
tot_item = len(all_indices_str)
|
147 |
+
tot_indice = len(set(all_indices_str.tolist()))
|
148 |
+
print("Collision Rate",(tot_item-tot_indice)/tot_item)
|
149 |
+
|
150 |
+
all_indices_dict = {}
|
151 |
+
for item, indices in enumerate(all_indices.tolist()):
|
152 |
+
all_indices_dict[item] = list(indices)
|
153 |
+
|
154 |
+
with open(infer_args.save_path, 'w') as fp:
|
155 |
+
json.dump(all_indices_dict, fp)
|
index/main.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from time import time
|
6 |
+
import logging
|
7 |
+
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
|
10 |
+
from datasets import EmbDataset
|
11 |
+
from models.rqvae import RQVAE
|
12 |
+
from trainer import Trainer
|
13 |
+
|
14 |
+
def parse_args():
|
15 |
+
parser = argparse.ArgumentParser(description="Index")
|
16 |
+
|
17 |
+
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
|
18 |
+
parser.add_argument('--epochs', type=int, default=5000, help='number of epochs')
|
19 |
+
parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
|
20 |
+
parser.add_argument('--num_workers', type=int, default=4, )
|
21 |
+
parser.add_argument('--eval_step', type=int, default=50, help='eval step')
|
22 |
+
parser.add_argument('--learner', type=str, default="AdamW", help='optimizer')
|
23 |
+
parser.add_argument("--data_path", type=str,
|
24 |
+
default="../data/Games/Games.emb-llama-td.npy",
|
25 |
+
help="Input data path.")
|
26 |
+
|
27 |
+
parser.add_argument('--weight_decay', type=float, default=1e-4, help='l2 regularization weight')
|
28 |
+
parser.add_argument("--dropout_prob", type=float, default=0.0, help="dropout ratio")
|
29 |
+
parser.add_argument("--bn", type=bool, default=False, help="use bn or not")
|
30 |
+
parser.add_argument("--loss_type", type=str, default="mse", help="loss_type")
|
31 |
+
parser.add_argument("--kmeans_init", type=bool, default=True, help="use kmeans_init or not")
|
32 |
+
parser.add_argument("--kmeans_iters", type=int, default=100, help="max kmeans iters")
|
33 |
+
parser.add_argument('--sk_epsilons', type=float, nargs='+', default=[0.0, 0.0, 0.0], help="sinkhorn epsilons")
|
34 |
+
parser.add_argument("--sk_iters", type=int, default=50, help="max sinkhorn iters")
|
35 |
+
|
36 |
+
parser.add_argument("--device", type=str, default="cuda:1", help="gpu or cpu")
|
37 |
+
|
38 |
+
parser.add_argument('--num_emb_list', type=int, nargs='+', default=[256,256,256], help='emb num of every vq')
|
39 |
+
parser.add_argument('--e_dim', type=int, default=32, help='vq codebook embedding size')
|
40 |
+
parser.add_argument('--quant_loss_weight', type=float, default=1.0, help='vq quantion loss weight')
|
41 |
+
parser.add_argument('--layers', type=int, nargs='+', default=[2048,1024,512,256,128,64], help='hidden sizes of every layer')
|
42 |
+
|
43 |
+
parser.add_argument("--ckpt_dir", type=str, default="", help="output directory for model")
|
44 |
+
|
45 |
+
return parser.parse_args()
|
46 |
+
|
47 |
+
|
48 |
+
if __name__ == '__main__':
|
49 |
+
"""fix the random seed"""
|
50 |
+
seed = 2023
|
51 |
+
random.seed(seed)
|
52 |
+
np.random.seed(seed)
|
53 |
+
torch.manual_seed(seed)
|
54 |
+
torch.cuda.manual_seed_all(seed)
|
55 |
+
torch.backends.cudnn.deterministic = True
|
56 |
+
torch.backends.cudnn.benchmark = False
|
57 |
+
|
58 |
+
args = parse_args()
|
59 |
+
print(args)
|
60 |
+
|
61 |
+
logging.basicConfig(level=logging.DEBUG)
|
62 |
+
|
63 |
+
"""build dataset"""
|
64 |
+
data = EmbDataset(args.data_path)
|
65 |
+
model = RQVAE(in_dim=data.dim,
|
66 |
+
num_emb_list=args.num_emb_list,
|
67 |
+
e_dim=args.e_dim,
|
68 |
+
layers=args.layers,
|
69 |
+
dropout_prob=args.dropout_prob,
|
70 |
+
bn=args.bn,
|
71 |
+
loss_type=args.loss_type,
|
72 |
+
quant_loss_weight=args.quant_loss_weight,
|
73 |
+
kmeans_init=args.kmeans_init,
|
74 |
+
kmeans_iters=args.kmeans_iters,
|
75 |
+
sk_epsilons=args.sk_epsilons,
|
76 |
+
sk_iters=args.sk_iters,
|
77 |
+
)
|
78 |
+
print(model)
|
79 |
+
data_loader = DataLoader(data,num_workers=args.num_workers,
|
80 |
+
batch_size=args.batch_size, shuffle=True,
|
81 |
+
pin_memory=True)
|
82 |
+
trainer = Trainer(args,model)
|
83 |
+
best_loss, best_collision_rate = trainer.fit(data_loader)
|
84 |
+
|
85 |
+
print("Best Loss",best_loss)
|
86 |
+
print("Best Collision Rate", best_collision_rate)
|
87 |
+
|
index/models/layers.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn.init import xavier_normal_
|
4 |
+
from sklearn.cluster import KMeans
|
5 |
+
|
6 |
+
|
7 |
+
class MLPLayers(nn.Module):
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self, layers, dropout=0.0, activation="relu", bn=False
|
11 |
+
):
|
12 |
+
super(MLPLayers, self).__init__()
|
13 |
+
self.layers = layers
|
14 |
+
self.dropout = dropout
|
15 |
+
self.activation = activation
|
16 |
+
self.use_bn = bn
|
17 |
+
|
18 |
+
mlp_modules = []
|
19 |
+
for idx, (input_size, output_size) in enumerate(
|
20 |
+
zip(self.layers[:-1], self.layers[1:])
|
21 |
+
):
|
22 |
+
mlp_modules.append(nn.Dropout(p=self.dropout))
|
23 |
+
mlp_modules.append(nn.Linear(input_size, output_size))
|
24 |
+
if self.use_bn:
|
25 |
+
mlp_modules.append(nn.BatchNorm1d(num_features=output_size))
|
26 |
+
activation_func = activation_layer(self.activation, output_size)
|
27 |
+
if activation_func is not None and idx != (len(self.layers)-2):
|
28 |
+
mlp_modules.append(activation_func)
|
29 |
+
|
30 |
+
self.mlp_layers = nn.Sequential(*mlp_modules)
|
31 |
+
self.apply(self.init_weights)
|
32 |
+
|
33 |
+
def init_weights(self, module):
|
34 |
+
# We just initialize the module with normal distribution as the paper said
|
35 |
+
if isinstance(module, nn.Linear):
|
36 |
+
xavier_normal_(module.weight.data)
|
37 |
+
if module.bias is not None:
|
38 |
+
module.bias.data.fill_(0.0)
|
39 |
+
|
40 |
+
def forward(self, input_feature):
|
41 |
+
return self.mlp_layers(input_feature)
|
42 |
+
|
43 |
+
def activation_layer(activation_name="relu", emb_dim=None):
|
44 |
+
|
45 |
+
if activation_name is None:
|
46 |
+
activation = None
|
47 |
+
elif isinstance(activation_name, str):
|
48 |
+
if activation_name.lower() == "sigmoid":
|
49 |
+
activation = nn.Sigmoid()
|
50 |
+
elif activation_name.lower() == "tanh":
|
51 |
+
activation = nn.Tanh()
|
52 |
+
elif activation_name.lower() == "relu":
|
53 |
+
activation = nn.ReLU()
|
54 |
+
elif activation_name.lower() == "leakyrelu":
|
55 |
+
activation = nn.LeakyReLU()
|
56 |
+
elif activation_name.lower() == "none":
|
57 |
+
activation = None
|
58 |
+
elif issubclass(activation_name, nn.Module):
|
59 |
+
activation = activation_name()
|
60 |
+
else:
|
61 |
+
raise NotImplementedError(
|
62 |
+
"activation function {} is not implemented".format(activation_name)
|
63 |
+
)
|
64 |
+
|
65 |
+
return activation
|
66 |
+
|
67 |
+
def kmeans(
|
68 |
+
samples,
|
69 |
+
num_clusters,
|
70 |
+
num_iters = 10,
|
71 |
+
):
|
72 |
+
B, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device
|
73 |
+
x = samples.cpu().detach().numpy()
|
74 |
+
|
75 |
+
cluster = KMeans(n_clusters = num_clusters, max_iter = num_iters).fit(x)
|
76 |
+
|
77 |
+
centers = cluster.cluster_centers_
|
78 |
+
tensor_centers = torch.from_numpy(centers).to(device)
|
79 |
+
|
80 |
+
return tensor_centers
|
81 |
+
|
82 |
+
|
83 |
+
@torch.no_grad()
|
84 |
+
def sinkhorn_algorithm(distances, epsilon, sinkhorn_iterations):
|
85 |
+
Q = torch.exp(- distances / epsilon)
|
86 |
+
|
87 |
+
B = Q.shape[0] # number of samples to assign
|
88 |
+
K = Q.shape[1] # how many centroids per block (usually set to 256)
|
89 |
+
|
90 |
+
# make the matrix sums to 1
|
91 |
+
sum_Q = Q.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
92 |
+
Q /= sum_Q
|
93 |
+
# print(Q.sum())
|
94 |
+
for it in range(sinkhorn_iterations):
|
95 |
+
|
96 |
+
# normalize each column: total weight per sample must be 1/B
|
97 |
+
Q /= torch.sum(Q, dim=1, keepdim=True)
|
98 |
+
Q /= B
|
99 |
+
|
100 |
+
# normalize each row: total weight per prototype must be 1/K
|
101 |
+
Q /= torch.sum(Q, dim=0, keepdim=True)
|
102 |
+
Q /= K
|
103 |
+
|
104 |
+
|
105 |
+
Q *= B # the colomns must sum to 1 so that Q is an assignment
|
106 |
+
return Q
|
index/models/rq.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .vq import VectorQuantizer
|
5 |
+
|
6 |
+
|
7 |
+
class ResidualVectorQuantizer(nn.Module):
|
8 |
+
""" References:
|
9 |
+
SoundStream: An End-to-End Neural Audio Codec
|
10 |
+
https://arxiv.org/pdf/2107.03312.pdf
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, n_e_list, e_dim, sk_epsilons,
|
14 |
+
kmeans_init = False, kmeans_iters = 100, sk_iters=100,):
|
15 |
+
super().__init__()
|
16 |
+
self.n_e_list = n_e_list
|
17 |
+
self.e_dim = e_dim
|
18 |
+
self.num_quantizers = len(n_e_list)
|
19 |
+
self.kmeans_init = kmeans_init
|
20 |
+
self.kmeans_iters = kmeans_iters
|
21 |
+
self.sk_epsilons = sk_epsilons
|
22 |
+
self.sk_iters = sk_iters
|
23 |
+
self.vq_layers = nn.ModuleList([VectorQuantizer(n_e, e_dim,
|
24 |
+
kmeans_init = self.kmeans_init,
|
25 |
+
kmeans_iters = self.kmeans_iters,
|
26 |
+
sk_epsilon=sk_epsilon,
|
27 |
+
sk_iters=sk_iters)
|
28 |
+
for n_e, sk_epsilon in zip(n_e_list,sk_epsilons) ])
|
29 |
+
|
30 |
+
def get_codebook(self):
|
31 |
+
all_codebook = []
|
32 |
+
for quantizer in self.vq_layers:
|
33 |
+
codebook = quantizer.get_codebook()
|
34 |
+
all_codebook.append(codebook)
|
35 |
+
return torch.stack(all_codebook)
|
36 |
+
|
37 |
+
def forward(self, x, use_sk=True):
|
38 |
+
all_losses = []
|
39 |
+
all_indices = []
|
40 |
+
|
41 |
+
x_q = 0
|
42 |
+
residual = x
|
43 |
+
for quantizer in self.vq_layers:
|
44 |
+
x_res, loss, indices = quantizer(residual, use_sk=use_sk)
|
45 |
+
residual = residual - x_res
|
46 |
+
x_q = x_q + x_res
|
47 |
+
|
48 |
+
all_losses.append(loss)
|
49 |
+
all_indices.append(indices)
|
50 |
+
|
51 |
+
mean_losses = torch.stack(all_losses).mean()
|
52 |
+
all_indices = torch.stack(all_indices, dim=-1)
|
53 |
+
|
54 |
+
return x_q, mean_losses, all_indices
|
index/models/rqvae.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from .layers import MLPLayers
|
7 |
+
from .rq import ResidualVectorQuantizer
|
8 |
+
|
9 |
+
|
10 |
+
class RQVAE(nn.Module):
|
11 |
+
def __init__(self,
|
12 |
+
in_dim=768,
|
13 |
+
# num_emb_list=[256,256,256,256],
|
14 |
+
num_emb_list=None,
|
15 |
+
e_dim=64,
|
16 |
+
# layers=[512,256,128],
|
17 |
+
layers=None,
|
18 |
+
dropout_prob=0.0,
|
19 |
+
bn=False,
|
20 |
+
loss_type="mse",
|
21 |
+
quant_loss_weight=1.0,
|
22 |
+
kmeans_init=False,
|
23 |
+
kmeans_iters=100,
|
24 |
+
# sk_epsilons=[0,0,0.003,0.01]],
|
25 |
+
sk_epsilons=None,
|
26 |
+
sk_iters=100,
|
27 |
+
):
|
28 |
+
super(RQVAE, self).__init__()
|
29 |
+
|
30 |
+
self.in_dim = in_dim
|
31 |
+
self.num_emb_list = num_emb_list
|
32 |
+
self.e_dim = e_dim
|
33 |
+
|
34 |
+
self.layers = layers
|
35 |
+
self.dropout_prob = dropout_prob
|
36 |
+
self.bn = bn
|
37 |
+
self.loss_type = loss_type
|
38 |
+
self.quant_loss_weight=quant_loss_weight
|
39 |
+
self.kmeans_init = kmeans_init
|
40 |
+
self.kmeans_iters = kmeans_iters
|
41 |
+
self.sk_epsilons = sk_epsilons
|
42 |
+
self.sk_iters = sk_iters
|
43 |
+
|
44 |
+
self.encode_layer_dims = [self.in_dim] + self.layers + [self.e_dim]
|
45 |
+
self.encoder = MLPLayers(layers=self.encode_layer_dims,
|
46 |
+
dropout=self.dropout_prob,bn=self.bn)
|
47 |
+
|
48 |
+
self.rq = ResidualVectorQuantizer(num_emb_list, e_dim,
|
49 |
+
kmeans_init = self.kmeans_init,
|
50 |
+
kmeans_iters = self.kmeans_iters,
|
51 |
+
sk_epsilons=self.sk_epsilons,
|
52 |
+
sk_iters=self.sk_iters,)
|
53 |
+
|
54 |
+
self.decode_layer_dims = self.encode_layer_dims[::-1]
|
55 |
+
self.decoder = MLPLayers(layers=self.decode_layer_dims,
|
56 |
+
dropout=self.dropout_prob,bn=self.bn)
|
57 |
+
|
58 |
+
def forward(self, x, use_sk=True):
|
59 |
+
x = self.encoder(x)
|
60 |
+
x_q, rq_loss, indices = self.rq(x,use_sk=use_sk)
|
61 |
+
out = self.decoder(x_q)
|
62 |
+
|
63 |
+
return out, rq_loss, indices
|
64 |
+
|
65 |
+
@torch.no_grad()
|
66 |
+
def get_indices(self, xs, use_sk=False):
|
67 |
+
x_e = self.encoder(xs)
|
68 |
+
_, _, indices = self.rq(x_e, use_sk=use_sk)
|
69 |
+
return indices
|
70 |
+
|
71 |
+
def compute_loss(self, out, quant_loss, xs=None):
|
72 |
+
|
73 |
+
if self.loss_type == 'mse':
|
74 |
+
loss_recon = F.mse_loss(out, xs, reduction='mean')
|
75 |
+
elif self.loss_type == 'l1':
|
76 |
+
loss_recon = F.l1_loss(out, xs, reduction='mean')
|
77 |
+
else:
|
78 |
+
raise ValueError('incompatible loss type')
|
79 |
+
|
80 |
+
loss_total = loss_recon + self.quant_loss_weight * quant_loss
|
81 |
+
|
82 |
+
return loss_total, loss_recon
|
index/models/vq.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from .layers import kmeans, sinkhorn_algorithm
|
5 |
+
|
6 |
+
|
7 |
+
class VectorQuantizer(nn.Module):
|
8 |
+
|
9 |
+
def __init__(self, n_e, e_dim,
|
10 |
+
beta = 0.25, kmeans_init = False, kmeans_iters = 10,
|
11 |
+
sk_epsilon=0.01, sk_iters=100):
|
12 |
+
super().__init__()
|
13 |
+
self.n_e = n_e
|
14 |
+
self.e_dim = e_dim
|
15 |
+
self.beta = beta
|
16 |
+
self.kmeans_init = kmeans_init
|
17 |
+
self.kmeans_iters = kmeans_iters
|
18 |
+
self.sk_epsilon = sk_epsilon
|
19 |
+
self.sk_iters = sk_iters
|
20 |
+
|
21 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
22 |
+
if not kmeans_init:
|
23 |
+
self.initted = True
|
24 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
25 |
+
else:
|
26 |
+
self.initted = False
|
27 |
+
self.embedding.weight.data.zero_()
|
28 |
+
|
29 |
+
def get_codebook(self):
|
30 |
+
return self.embedding.weight
|
31 |
+
|
32 |
+
def get_codebook_entry(self, indices, shape=None):
|
33 |
+
# get quantized latent vectors
|
34 |
+
z_q = self.embedding(indices)
|
35 |
+
if shape is not None:
|
36 |
+
z_q = z_q.view(shape)
|
37 |
+
|
38 |
+
return z_q
|
39 |
+
|
40 |
+
def init_emb(self, data):
|
41 |
+
|
42 |
+
centers = kmeans(
|
43 |
+
data,
|
44 |
+
self.n_e,
|
45 |
+
self.kmeans_iters,
|
46 |
+
)
|
47 |
+
|
48 |
+
self.embedding.weight.data.copy_(centers)
|
49 |
+
self.initted = True
|
50 |
+
|
51 |
+
@staticmethod
|
52 |
+
def center_distance_for_constraint(distances):
|
53 |
+
# distances: B, K
|
54 |
+
max_distance = distances.max()
|
55 |
+
min_distance = distances.min()
|
56 |
+
|
57 |
+
middle = (max_distance + min_distance) / 2
|
58 |
+
amplitude = max_distance - middle + 1e-5
|
59 |
+
assert amplitude > 0
|
60 |
+
centered_distances = (distances - middle) / amplitude
|
61 |
+
return centered_distances
|
62 |
+
|
63 |
+
def forward(self, x, use_sk=True):
|
64 |
+
# Flatten input
|
65 |
+
latent = x.view(-1, self.e_dim)
|
66 |
+
|
67 |
+
if not self.initted and self.training:
|
68 |
+
self.init_emb(latent)
|
69 |
+
|
70 |
+
# Calculate the L2 Norm between latent and Embedded weights
|
71 |
+
d = torch.sum(latent**2, dim=1, keepdim=True) + \
|
72 |
+
torch.sum(self.embedding.weight**2, dim=1, keepdim=True).t()- \
|
73 |
+
2 * torch.matmul(latent, self.embedding.weight.t())
|
74 |
+
if not use_sk or self.sk_epsilon <= 0:
|
75 |
+
indices = torch.argmin(d, dim=-1)
|
76 |
+
# print("=======",self.sk_epsilon)
|
77 |
+
else:
|
78 |
+
# print("++++++++",self.sk_epsilon)
|
79 |
+
d = self.center_distance_for_constraint(d)
|
80 |
+
d = d.double()
|
81 |
+
Q = sinkhorn_algorithm(d,self.sk_epsilon,self.sk_iters)
|
82 |
+
# print(Q.sum(0)[:10])
|
83 |
+
if torch.isnan(Q).any() or torch.isinf(Q).any():
|
84 |
+
print(f"Sinkhorn Algorithm returns nan/inf values.")
|
85 |
+
indices = torch.argmax(Q, dim=-1)
|
86 |
+
|
87 |
+
# indices = torch.argmin(d, dim=-1)
|
88 |
+
|
89 |
+
x_q = self.embedding(indices).view(x.shape)
|
90 |
+
|
91 |
+
# compute loss for embedding
|
92 |
+
commitment_loss = F.mse_loss(x_q.detach(), x)
|
93 |
+
codebook_loss = F.mse_loss(x_q, x.detach())
|
94 |
+
loss = codebook_loss + self.beta * commitment_loss
|
95 |
+
|
96 |
+
# preserve gradients
|
97 |
+
x_q = x + (x_q - x).detach()
|
98 |
+
|
99 |
+
indices = indices.view(x.shape[:-1])
|
100 |
+
|
101 |
+
return x_q, loss, indices
|
102 |
+
|
103 |
+
|
index/run.sh
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DATA_PATH=/datain/v-yinju/rqvae-zzx/data/instruments-emb-llama.npy
|
2 |
+
CKPT_DIR=your_ckpt_save_directory # E.g., /datain/v-yinju/rqvae-zzx/model
|
3 |
+
mkdir -p $CKPT_DIR
|
4 |
+
python -u main.py \
|
5 |
+
--num_emb_list 256 256 256 256 \
|
6 |
+
--sk_epsilons 0.0 0.0 0.0 0.003 \
|
7 |
+
--lr 1e-3 \
|
8 |
+
--device cuda:0 \
|
9 |
+
--batch_size 1024 \
|
10 |
+
--data_path $DATA_PATH \
|
11 |
+
--ckpt_dir $CKPT_DIR
|
12 |
+
|
13 |
+
# Infer item index
|
14 |
+
# python generate_indices.py \
|
15 |
+
# --ckpt_path your_rqvae_model_path \ E.g., /datain/v-yinju/rqvae-zzx/model/20241127/best_collision_model.pth
|
16 |
+
# --data_path $DATA_PATH \
|
17 |
+
# --save_path your_index_save_path \ E.g., /datain/v-yinju/rqvae-zzx/model/20241127/indices.json
|
18 |
+
# --device cuda:0
|
index/trainer.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from time import time
|
6 |
+
from torch import optim
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from utils import ensure_dir,set_color,get_local_time
|
10 |
+
import os
|
11 |
+
|
12 |
+
class Trainer(object):
|
13 |
+
|
14 |
+
def __init__(self, args, model):
|
15 |
+
self.args = args
|
16 |
+
self.model = model
|
17 |
+
self.logger = logging.getLogger()
|
18 |
+
|
19 |
+
self.lr = args.lr
|
20 |
+
self.learner = args.learner
|
21 |
+
self.weight_decay = args.weight_decay
|
22 |
+
self.epochs = args.epochs
|
23 |
+
self.eval_step = min(args.eval_step, self.epochs)
|
24 |
+
self.device = args.device
|
25 |
+
self.device = torch.device(self.device)
|
26 |
+
self.ckpt_dir = args.ckpt_dir
|
27 |
+
saved_model_dir = "{}".format(get_local_time())
|
28 |
+
self.ckpt_dir = os.path.join(self.ckpt_dir,saved_model_dir)
|
29 |
+
ensure_dir(self.ckpt_dir)
|
30 |
+
|
31 |
+
self.best_loss = np.inf
|
32 |
+
self.best_collision_rate = np.inf
|
33 |
+
self.best_loss_ckpt = "best_loss_model.pth"
|
34 |
+
self.best_collision_ckpt = "best_collision_model.pth"
|
35 |
+
self.optimizer = self._build_optimizer()
|
36 |
+
self.model = self.model.to(self.device)
|
37 |
+
|
38 |
+
def _build_optimizer(self):
|
39 |
+
|
40 |
+
params = self.model.parameters()
|
41 |
+
learner = self.learner
|
42 |
+
learning_rate = self.lr
|
43 |
+
weight_decay = self.weight_decay
|
44 |
+
|
45 |
+
if learner.lower() == "adam":
|
46 |
+
optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay)
|
47 |
+
elif learner.lower() == "sgd":
|
48 |
+
optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay)
|
49 |
+
elif learner.lower() == "adagrad":
|
50 |
+
optimizer = optim.Adagrad(
|
51 |
+
params, lr=learning_rate, weight_decay=weight_decay
|
52 |
+
)
|
53 |
+
for state in optimizer.state.values():
|
54 |
+
for k, v in state.items():
|
55 |
+
if torch.is_tensor(v):
|
56 |
+
state[k] = v.to(self.device)
|
57 |
+
elif learner.lower() == "rmsprop":
|
58 |
+
optimizer = optim.RMSprop(
|
59 |
+
params, lr=learning_rate, weight_decay=weight_decay
|
60 |
+
)
|
61 |
+
elif learner.lower() == 'adamw':
|
62 |
+
optimizer = optim.AdamW(
|
63 |
+
params, lr=learning_rate, weight_decay=weight_decay
|
64 |
+
)
|
65 |
+
else:
|
66 |
+
self.logger.warning(
|
67 |
+
"Received unrecognized optimizer, set default Adam optimizer"
|
68 |
+
)
|
69 |
+
optimizer = optim.Adam(params, lr=learning_rate)
|
70 |
+
return optimizer
|
71 |
+
def _check_nan(self, loss):
|
72 |
+
if torch.isnan(loss):
|
73 |
+
raise ValueError("Training loss is nan")
|
74 |
+
|
75 |
+
def _train_epoch(self, train_data, epoch_idx):
|
76 |
+
|
77 |
+
self.model.train()
|
78 |
+
|
79 |
+
total_loss = 0
|
80 |
+
total_recon_loss = 0
|
81 |
+
iter_data = tqdm(
|
82 |
+
train_data,
|
83 |
+
total=len(train_data),
|
84 |
+
ncols=100,
|
85 |
+
desc=set_color(f"Train {epoch_idx}","pink"),
|
86 |
+
)
|
87 |
+
|
88 |
+
for batch_idx, data in enumerate(iter_data):
|
89 |
+
data = data.to(self.device)
|
90 |
+
self.optimizer.zero_grad()
|
91 |
+
out, rq_loss, indices = self.model(data)
|
92 |
+
loss, loss_recon = self.model.compute_loss(out, rq_loss, xs=data)
|
93 |
+
self._check_nan(loss)
|
94 |
+
loss.backward()
|
95 |
+
self.optimizer.step()
|
96 |
+
total_loss += loss.item()
|
97 |
+
total_recon_loss += loss_recon.item()
|
98 |
+
|
99 |
+
return total_loss, total_recon_loss
|
100 |
+
|
101 |
+
@torch.no_grad()
|
102 |
+
def _valid_epoch(self, valid_data):
|
103 |
+
|
104 |
+
self.model.eval()
|
105 |
+
|
106 |
+
iter_data =tqdm(
|
107 |
+
valid_data,
|
108 |
+
total=len(valid_data),
|
109 |
+
ncols=100,
|
110 |
+
desc=set_color(f"Evaluate ", "pink"),
|
111 |
+
)
|
112 |
+
indices_set = set()
|
113 |
+
num_sample = 0
|
114 |
+
for batch_idx, data in enumerate(iter_data):
|
115 |
+
num_sample += len(data)
|
116 |
+
data = data.to(self.device)
|
117 |
+
indices = self.model.get_indices(data)
|
118 |
+
indices = indices.view(-1,indices.shape[-1]).cpu().numpy()
|
119 |
+
for index in indices:
|
120 |
+
code = "-".join([str(int(_)) for _ in index])
|
121 |
+
indices_set.add(code)
|
122 |
+
|
123 |
+
collision_rate = (num_sample - len(indices_set))/num_sample
|
124 |
+
|
125 |
+
return collision_rate
|
126 |
+
|
127 |
+
def _save_checkpoint(self, epoch, collision_rate=1, ckpt_file=None):
|
128 |
+
|
129 |
+
ckpt_path = os.path.join(self.ckpt_dir,ckpt_file) if ckpt_file \
|
130 |
+
else os.path.join(self.ckpt_dir, 'epoch_%d_collision_%.4f_model.pth' % (epoch, collision_rate))
|
131 |
+
state = {
|
132 |
+
"args": self.args,
|
133 |
+
"epoch": epoch,
|
134 |
+
"best_loss": self.best_loss,
|
135 |
+
"best_collision_rate": self.best_collision_rate,
|
136 |
+
"state_dict": self.model.state_dict(),
|
137 |
+
"optimizer": self.optimizer.state_dict(),
|
138 |
+
}
|
139 |
+
torch.save(state, ckpt_path, pickle_protocol=4)
|
140 |
+
|
141 |
+
self.logger.info(
|
142 |
+
set_color("Saving current", "blue") + f": {ckpt_path}"
|
143 |
+
)
|
144 |
+
|
145 |
+
def _generate_train_loss_output(self, epoch_idx, s_time, e_time, loss, recon_loss):
|
146 |
+
train_loss_output = (
|
147 |
+
set_color("epoch %d training", "green")
|
148 |
+
+ " ["
|
149 |
+
+ set_color("time", "blue")
|
150 |
+
+ ": %.2fs, "
|
151 |
+
) % (epoch_idx, e_time - s_time)
|
152 |
+
train_loss_output += set_color("train loss", "blue") + ": %.4f" % loss
|
153 |
+
train_loss_output +=", "
|
154 |
+
train_loss_output += set_color("reconstruction loss", "blue") + ": %.4f" % recon_loss
|
155 |
+
return train_loss_output + "]"
|
156 |
+
|
157 |
+
|
158 |
+
def fit(self, data):
|
159 |
+
|
160 |
+
cur_eval_step = 0
|
161 |
+
|
162 |
+
for epoch_idx in range(self.epochs):
|
163 |
+
# train
|
164 |
+
training_start_time = time()
|
165 |
+
train_loss, train_recon_loss = self._train_epoch(data, epoch_idx)
|
166 |
+
training_end_time = time()
|
167 |
+
train_loss_output = self._generate_train_loss_output(
|
168 |
+
epoch_idx, training_start_time, training_end_time, train_loss, train_recon_loss
|
169 |
+
)
|
170 |
+
self.logger.info(train_loss_output)
|
171 |
+
|
172 |
+
if train_loss < self.best_loss:
|
173 |
+
self.best_loss = train_loss
|
174 |
+
# self._save_checkpoint(epoch=epoch_idx,ckpt_file=self.best_loss_ckpt)
|
175 |
+
|
176 |
+
# eval
|
177 |
+
if (epoch_idx + 1) % self.eval_step == 0:
|
178 |
+
valid_start_time = time()
|
179 |
+
collision_rate = self._valid_epoch(data)
|
180 |
+
|
181 |
+
if collision_rate < self.best_collision_rate:
|
182 |
+
self.best_collision_rate = collision_rate
|
183 |
+
cur_eval_step = 0
|
184 |
+
self._save_checkpoint(epoch_idx, collision_rate=collision_rate,
|
185 |
+
ckpt_file=self.best_collision_ckpt)
|
186 |
+
else:
|
187 |
+
cur_eval_step += 1
|
188 |
+
|
189 |
+
|
190 |
+
valid_end_time = time()
|
191 |
+
valid_score_output = (
|
192 |
+
set_color("epoch %d evaluating", "green")
|
193 |
+
+ " ["
|
194 |
+
+ set_color("time", "blue")
|
195 |
+
+ ": %.2fs, "
|
196 |
+
+ set_color("collision_rate", "blue")
|
197 |
+
+ ": %f]"
|
198 |
+
) % (epoch_idx, valid_end_time - valid_start_time, collision_rate)
|
199 |
+
|
200 |
+
self.logger.info(valid_score_output)
|
201 |
+
if epoch_idx>1000:
|
202 |
+
self._save_checkpoint(epoch_idx, collision_rate=collision_rate)
|
203 |
+
|
204 |
+
|
205 |
+
return self.best_loss, self.best_collision_rate
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
|
index/utils.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import datetime
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
def ensure_dir(dir_path):
|
7 |
+
|
8 |
+
os.makedirs(dir_path, exist_ok=True)
|
9 |
+
|
10 |
+
def set_color(log, color, highlight=True):
|
11 |
+
color_set = ["black", "red", "green", "yellow", "blue", "pink", "cyan", "white"]
|
12 |
+
try:
|
13 |
+
index = color_set.index(color)
|
14 |
+
except:
|
15 |
+
index = len(color_set) - 1
|
16 |
+
prev_log = "\033["
|
17 |
+
if highlight:
|
18 |
+
prev_log += "1;3"
|
19 |
+
else:
|
20 |
+
prev_log += "0;3"
|
21 |
+
prev_log += str(index) + "m"
|
22 |
+
return prev_log + log + "\033[0m"
|
23 |
+
|
24 |
+
def get_local_time():
|
25 |
+
r"""Get current time
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
str: current time
|
29 |
+
"""
|
30 |
+
cur = datetime.datetime.now()
|
31 |
+
cur = cur.strftime("%b-%d-%Y_%H-%M-%S")
|
32 |
+
|
33 |
+
return cur
|
34 |
+
|
35 |
+
|
36 |
+
|
instruments_eval.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DATASET=Instruments
|
2 |
+
BASE=/datain/v-yinju/llama-7b
|
3 |
+
DATA_PATH=/datain/v-yinju/rqvae-zzx/data
|
4 |
+
CKPT_PATH=/datain/v-yinju/RQVAE_Bench/llama
|
5 |
+
RESULTS_FILE=$CKPT_PATH/result.json
|
6 |
+
INDEX=/datain/v-yinju/RQVAE_Bench/rqvae/Nov-27-2024_23-08-08/indices.json
|
7 |
+
|
8 |
+
torchrun --nproc_per_node=8 test_ddp.py \
|
9 |
+
--base_model $BASE \
|
10 |
+
--ckpt_path $CKPT_PATH \
|
11 |
+
--dataset $DATASET \
|
12 |
+
--data_path $DATA_PATH \
|
13 |
+
--results_file $RESULTS_FILE \
|
14 |
+
--test_batch_size 1 \
|
15 |
+
--num_beams 10 \
|
16 |
+
--test_prompt_ids all \
|
17 |
+
--index_file $INDEX
|
instruments_train.sh
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export WANDB_MODE=disabled
|
2 |
+
export CUDA_LAUNCH_BLOCKING=1
|
3 |
+
|
4 |
+
DATASET=Instruments
|
5 |
+
BASE_MODEL=/datain/v-yinju/llama-7b
|
6 |
+
DATA_PATH=/datain/v-yinju/rqvae-zzx/data
|
7 |
+
INDEX=your_index_save_path
|
8 |
+
OUTPUT_DIR=your_ckpt_save_dir
|
9 |
+
|
10 |
+
mkdir -p $OUTPUT_DIR
|
11 |
+
|
12 |
+
torchrun --nproc_per_node=8 lora_finetune.py \
|
13 |
+
--base_model $BASE_MODEL \
|
14 |
+
--output_dir $OUTPUT_DIR \
|
15 |
+
--dataset $DATASET \
|
16 |
+
--data_path $DATA_PATH \
|
17 |
+
--per_device_batch_size 6 \
|
18 |
+
--gradient_accumulation_steps 2 \
|
19 |
+
--learning_rate 5e-5 \
|
20 |
+
--epochs 4 \
|
21 |
+
--weight_decay 0.01 \
|
22 |
+
--save_and_eval_strategy epoch \
|
23 |
+
--fp16 \
|
24 |
+
--deepspeed ./config/ds_z2_fp16.json \
|
25 |
+
--dataloader_num_workers 4 \
|
26 |
+
--only_train_response \
|
27 |
+
--tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
|
28 |
+
--train_prompt_sample_num 1,1,1,1,1,1 \
|
29 |
+
--train_data_sample_num 0,0,0,0,0,0 \
|
30 |
+
--index_file $INDEX
|
31 |
+
|
32 |
+
cd convert
|
33 |
+
nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
|
34 |
+
cd ..
|
lora_finetune.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import transformers
|
8 |
+
|
9 |
+
|
10 |
+
from peft import (
|
11 |
+
TaskType,
|
12 |
+
LoraConfig,
|
13 |
+
get_peft_model,
|
14 |
+
get_peft_model_state_dict,
|
15 |
+
set_peft_model_state_dict,
|
16 |
+
)
|
17 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
|
18 |
+
|
19 |
+
from utils import *
|
20 |
+
from collator import Collator
|
21 |
+
|
22 |
+
def train(args):
|
23 |
+
|
24 |
+
set_seed(args.seed)
|
25 |
+
ensure_dir(args.output_dir)
|
26 |
+
|
27 |
+
device_map = "auto"
|
28 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
29 |
+
ddp = world_size != 1
|
30 |
+
local_rank = int(os.environ.get("LOCAL_RANK") or 0)
|
31 |
+
if local_rank == 0:
|
32 |
+
print(vars(args))
|
33 |
+
|
34 |
+
if ddp:
|
35 |
+
device_map = {"": local_rank}
|
36 |
+
|
37 |
+
config = LlamaConfig.from_pretrained(args.base_model)
|
38 |
+
tokenizer = LlamaTokenizer.from_pretrained(
|
39 |
+
args.base_model,
|
40 |
+
model_max_length=args.model_max_length,
|
41 |
+
padding_side="right",
|
42 |
+
)
|
43 |
+
tokenizer.pad_token_id = 0
|
44 |
+
|
45 |
+
train_data, valid_data = load_datasets(args)
|
46 |
+
add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
|
47 |
+
config.vocab_size = len(tokenizer)
|
48 |
+
if local_rank == 0:
|
49 |
+
print("add {} new token.".format(add_num))
|
50 |
+
print("data num:", len(train_data))
|
51 |
+
tokenizer.save_pretrained(args.output_dir)
|
52 |
+
config.save_pretrained(args.output_dir)
|
53 |
+
|
54 |
+
collator = Collator(args, tokenizer)
|
55 |
+
|
56 |
+
model = LlamaForCausalLM.from_pretrained(
|
57 |
+
args.base_model,
|
58 |
+
torch_dtype=torch.float16,
|
59 |
+
device_map=device_map,
|
60 |
+
)
|
61 |
+
model.resize_token_embeddings(len(tokenizer))
|
62 |
+
|
63 |
+
config = LoraConfig(
|
64 |
+
r=args.lora_r,
|
65 |
+
lora_alpha=args.lora_alpha,
|
66 |
+
target_modules=args.lora_target_modules.split(","),
|
67 |
+
modules_to_save=args.lora_modules_to_save.split(","),
|
68 |
+
lora_dropout=args.lora_dropout,
|
69 |
+
bias="none",
|
70 |
+
inference_mode=False,
|
71 |
+
task_type=TaskType.CAUSAL_LM,
|
72 |
+
)
|
73 |
+
model = get_peft_model(model, config)
|
74 |
+
|
75 |
+
if args.resume_from_checkpoint:
|
76 |
+
checkpoint_name = os.path.join(
|
77 |
+
args.resume_from_checkpoint, "adapter_model.bin"
|
78 |
+
) # only LoRA model - LoRA config above has to fit
|
79 |
+
args.resume_from_checkpoint = False # So the trainer won't try loading its state
|
80 |
+
# The two files above have a different name depending on how they were saved, but are actually the same.
|
81 |
+
if os.path.exists(checkpoint_name):
|
82 |
+
if local_rank == 0:
|
83 |
+
print(f"Restarting from {checkpoint_name}")
|
84 |
+
adapters_weights = torch.load(checkpoint_name)
|
85 |
+
model = set_peft_model_state_dict(model, adapters_weights)
|
86 |
+
else:
|
87 |
+
if local_rank == 0:
|
88 |
+
print(f"Checkpoint {checkpoint_name} not found")
|
89 |
+
|
90 |
+
for n, p in model.named_parameters():
|
91 |
+
if "original_module" in n and any(module_name in n for module_name in config.modules_to_save):
|
92 |
+
p.requires_grad = False
|
93 |
+
|
94 |
+
if local_rank == 0:
|
95 |
+
model.print_trainable_parameters()
|
96 |
+
|
97 |
+
|
98 |
+
if not ddp and torch.cuda.device_count() > 1:
|
99 |
+
model.is_parallelizable = True
|
100 |
+
model.model_parallel = True
|
101 |
+
|
102 |
+
trainer = transformers.Trainer(
|
103 |
+
model=model,
|
104 |
+
train_dataset=train_data,
|
105 |
+
eval_dataset=valid_data,
|
106 |
+
args=transformers.TrainingArguments(
|
107 |
+
seed=args.seed,
|
108 |
+
per_device_train_batch_size=args.per_device_batch_size,
|
109 |
+
per_device_eval_batch_size=args.per_device_batch_size,
|
110 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
111 |
+
warmup_ratio=args.warmup_ratio,
|
112 |
+
num_train_epochs=args.epochs,
|
113 |
+
learning_rate=args.learning_rate,
|
114 |
+
weight_decay=args.weight_decay,
|
115 |
+
lr_scheduler_type=args.lr_scheduler_type,
|
116 |
+
fp16=args.fp16,
|
117 |
+
bf16=args.bf16,
|
118 |
+
logging_steps=args.logging_step,
|
119 |
+
optim=args.optim,
|
120 |
+
gradient_checkpointing=True,
|
121 |
+
evaluation_strategy=args.save_and_eval_strategy,
|
122 |
+
save_strategy=args.save_and_eval_strategy,
|
123 |
+
eval_steps=args.save_and_eval_steps,
|
124 |
+
save_steps=args.save_and_eval_steps,
|
125 |
+
output_dir=args.output_dir,
|
126 |
+
save_total_limit=5,
|
127 |
+
load_best_model_at_end=True,
|
128 |
+
deepspeed=args.deepspeed,
|
129 |
+
ddp_find_unused_parameters=False if ddp else None,
|
130 |
+
report_to=None,
|
131 |
+
eval_delay=1 if args.save_and_eval_strategy=="epoch" else 2000,
|
132 |
+
dataloader_num_workers = args.dataloader_num_workers,
|
133 |
+
dataloader_prefetch_factor = args.dataloader_prefetch_factor
|
134 |
+
),
|
135 |
+
tokenizer=tokenizer,
|
136 |
+
data_collator=collator,
|
137 |
+
)
|
138 |
+
model.config.use_cache = False
|
139 |
+
|
140 |
+
# old_state_dict = model.state_dict
|
141 |
+
# model.state_dict = (
|
142 |
+
# lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
143 |
+
# ).__get__(model, type(model))
|
144 |
+
|
145 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
146 |
+
model = torch.compile(model)
|
147 |
+
|
148 |
+
trainer.train(
|
149 |
+
resume_from_checkpoint=args.resume_from_checkpoint,
|
150 |
+
)
|
151 |
+
|
152 |
+
trainer.save_state()
|
153 |
+
trainer.save_model(output_dir=args.output_dir)
|
154 |
+
|
155 |
+
|
156 |
+
if __name__ == "__main__":
|
157 |
+
parser = argparse.ArgumentParser(description='LLMRec')
|
158 |
+
parser = parse_global_args(parser)
|
159 |
+
parser = parse_train_args(parser)
|
160 |
+
parser = parse_dataset_args(parser)
|
161 |
+
|
162 |
+
args = parser.parse_args()
|
163 |
+
|
164 |
+
train(args)
|
prompt.py
ADDED
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
sft_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request." \
|
4 |
+
"\n\n### Instruction:\n{instruction}\n\n### Response:{response}"
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
all_prompt = {}
|
13 |
+
|
14 |
+
# =====================================================
|
15 |
+
# Task 1 -- Sequential Recommendation -- 17 Prompt
|
16 |
+
# =====================================================
|
17 |
+
|
18 |
+
seqrec_prompt = []
|
19 |
+
|
20 |
+
#####——0
|
21 |
+
prompt = {}
|
22 |
+
prompt["instruction"] = "The user has interacted with items {inters} in chronological order. Can you predict the next possible item that the user may expect?"
|
23 |
+
prompt["response"] = "{item}"
|
24 |
+
seqrec_prompt.append(prompt)
|
25 |
+
|
26 |
+
#####——1
|
27 |
+
prompt = {}
|
28 |
+
prompt["instruction"] = "I find the user's historical interactive items: {inters}, and I want to know what next item the user needs. Can you help me decide?"
|
29 |
+
prompt["response"] = "{item}"
|
30 |
+
seqrec_prompt.append(prompt)
|
31 |
+
|
32 |
+
#####——2
|
33 |
+
prompt = {}
|
34 |
+
prompt["instruction"] = "Here are the user's historical interactions: {inters}, try to recommend another item to the user. Note that the historical interactions are arranged in chronological order."
|
35 |
+
prompt["response"] = "{item}"
|
36 |
+
seqrec_prompt.append(prompt)
|
37 |
+
|
38 |
+
#####——3
|
39 |
+
prompt = {}
|
40 |
+
prompt["instruction"] = "Based on the items that the user has interacted with: {inters}, can you determine what item would be recommended to him next?"
|
41 |
+
prompt["response"] = "{item}"
|
42 |
+
seqrec_prompt.append(prompt)
|
43 |
+
|
44 |
+
#####——4
|
45 |
+
prompt = {}
|
46 |
+
prompt["instruction"] = "The user has interacted with the following items in order: {inters}. What else do you think the user need?"
|
47 |
+
prompt["response"] = "{item}"
|
48 |
+
seqrec_prompt.append(prompt)
|
49 |
+
|
50 |
+
#####——5
|
51 |
+
prompt = {}
|
52 |
+
prompt["instruction"] = "Here is the item interaction history of the user: {inters}, what to recommend to the user next?"
|
53 |
+
prompt["response"] = "{item}"
|
54 |
+
seqrec_prompt.append(prompt)
|
55 |
+
|
56 |
+
#####——6
|
57 |
+
prompt = {}
|
58 |
+
prompt["instruction"] = "Which item would the user be likely to interact with next after interacting with items {inters}?"
|
59 |
+
prompt["response"] = "{item}"
|
60 |
+
seqrec_prompt.append(prompt)
|
61 |
+
|
62 |
+
#####——7
|
63 |
+
prompt = {}
|
64 |
+
prompt["instruction"] = "By analyzing the user's historical interactions with items {inters}, what is the next expected interaction item?"
|
65 |
+
prompt["response"] = "{item}"
|
66 |
+
seqrec_prompt.append(prompt)
|
67 |
+
|
68 |
+
#####——8
|
69 |
+
prompt = {}
|
70 |
+
prompt["instruction"] = "After interacting with items {inters}, what is the next item that could be recommended for the user?"
|
71 |
+
prompt["response"] = "{item}"
|
72 |
+
seqrec_prompt.append(prompt)
|
73 |
+
|
74 |
+
#####——9
|
75 |
+
prompt = {}
|
76 |
+
prompt["instruction"] = "Given the user's historical interactive items arranged in chronological order: {inters}, can you recommend a suitable item for the user?"
|
77 |
+
prompt["response"] = "{item}"
|
78 |
+
seqrec_prompt.append(prompt)
|
79 |
+
|
80 |
+
#####——10
|
81 |
+
prompt = {}
|
82 |
+
prompt["instruction"] = "Considering the user has interacted with items {inters}. What is the next recommendation for the user?"
|
83 |
+
prompt["response"] = "{item}"
|
84 |
+
seqrec_prompt.append(prompt)
|
85 |
+
|
86 |
+
#####——11
|
87 |
+
prompt = {}
|
88 |
+
prompt["instruction"] = "What is the top recommended item for the user who has previously interacted with items {inters} in order?"
|
89 |
+
prompt["response"] = "{item}"
|
90 |
+
seqrec_prompt.append(prompt)
|
91 |
+
|
92 |
+
#####——12
|
93 |
+
prompt = {}
|
94 |
+
prompt["instruction"] = "The user has interacted with the following items in the past in order: {inters}. Please predict the next item that the user most desires based on the given interaction records."
|
95 |
+
prompt["response"] = "{item}"
|
96 |
+
seqrec_prompt.append(prompt)
|
97 |
+
|
98 |
+
# prompt = {}
|
99 |
+
# prompt["instruction"] = "The user has interacted with the following items in the past in order: {inters}. Please predict the next item that the user is most likely to interact with based on the given interaction record. Note that his most recently interacted item is {}."
|
100 |
+
# prompt["response"] = "{item}"
|
101 |
+
# prompt["task"] = "sequential"
|
102 |
+
# prompt["id"] = "1-13"
|
103 |
+
#
|
104 |
+
# seqrec_prompt.append(prompt)
|
105 |
+
|
106 |
+
#####——13
|
107 |
+
prompt = {}
|
108 |
+
prompt["instruction"] = "Using the user's historical interactions as input data, suggest the next item that the user is highly likely to enjoy. The historical interactions are provided as follows: {inters}."
|
109 |
+
prompt["response"] = "{item}"
|
110 |
+
seqrec_prompt.append(prompt)
|
111 |
+
|
112 |
+
#####——14
|
113 |
+
prompt = {}
|
114 |
+
prompt["instruction"] = "You can access the user's historical item interaction records: {inters}. Now your task is to recommend the next potential item to him, considering his past interactions."
|
115 |
+
prompt["response"] = "{item}"
|
116 |
+
seqrec_prompt.append(prompt)
|
117 |
+
|
118 |
+
#####——15
|
119 |
+
prompt = {}
|
120 |
+
prompt["instruction"] = "You have observed that the user has interacted with the following items: {inters}, please recommend a next item that you think would be suitable for the user."
|
121 |
+
prompt["response"] = "{item}"
|
122 |
+
seqrec_prompt.append(prompt)
|
123 |
+
|
124 |
+
#####——16
|
125 |
+
prompt = {}
|
126 |
+
prompt["instruction"] = "You have obtained the ordered list of user historical interaction items, which is as follows: {inters}. Using this history as a reference, please select the next item to recommend to the user."
|
127 |
+
prompt["response"] = "{item}"
|
128 |
+
seqrec_prompt.append(prompt)
|
129 |
+
|
130 |
+
all_prompt["seqrec"] = seqrec_prompt
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
# ========================================================
|
135 |
+
# Task 2 -- Item2Index -- 19 Prompt
|
136 |
+
# ========================================================
|
137 |
+
# Remove periods when inputting
|
138 |
+
|
139 |
+
item2index_prompt = []
|
140 |
+
|
141 |
+
# ========================================================
|
142 |
+
# Title2Index
|
143 |
+
|
144 |
+
#####——0
|
145 |
+
prompt = {}
|
146 |
+
prompt["instruction"] = "Which item has the title: \"{title}\"?"
|
147 |
+
prompt["response"] = "{item}"
|
148 |
+
item2index_prompt.append(prompt)
|
149 |
+
|
150 |
+
#####——1
|
151 |
+
prompt = {}
|
152 |
+
prompt["instruction"] = "Which item is assigned the title: \"{title}\"?"
|
153 |
+
prompt["response"] = "{item}"
|
154 |
+
item2index_prompt.append(prompt)
|
155 |
+
|
156 |
+
#####——2
|
157 |
+
prompt = {}
|
158 |
+
prompt["instruction"] = "An item is called \"{title}\", could you please let me know which item it is?"
|
159 |
+
prompt["response"] = "{item}"
|
160 |
+
item2index_prompt.append(prompt)
|
161 |
+
|
162 |
+
#####——3
|
163 |
+
prompt = {}
|
164 |
+
prompt["instruction"] = "Which item is called \"{title}\"?"
|
165 |
+
prompt["response"] = "{item}"
|
166 |
+
item2index_prompt.append(prompt)
|
167 |
+
|
168 |
+
#####——4
|
169 |
+
prompt = {}
|
170 |
+
prompt["instruction"] = "One of the items is named \"{title}\", can you tell me which item this is?"
|
171 |
+
prompt["response"] = "{item}"
|
172 |
+
item2index_prompt.append(prompt)
|
173 |
+
|
174 |
+
#####——5
|
175 |
+
prompt = {}
|
176 |
+
prompt["instruction"] = "What is the item that goes by the title \"{title}\"?"
|
177 |
+
prompt["response"] = "{item}"
|
178 |
+
item2index_prompt.append(prompt)
|
179 |
+
|
180 |
+
# prompt = {}
|
181 |
+
# prompt["instruction"] = "Which item is referred to as \"{title}\"?"
|
182 |
+
# prompt["response"] = "{item}"
|
183 |
+
# item2index_prompt.append(prompt)
|
184 |
+
|
185 |
+
# ========================================================
|
186 |
+
# Description2Index
|
187 |
+
|
188 |
+
#####——6
|
189 |
+
prompt = {}
|
190 |
+
prompt["instruction"] = "An item can be described as follows: \"{description}\". Which item is it describing?"
|
191 |
+
prompt["response"] = "{item}"
|
192 |
+
item2index_prompt.append(prompt)
|
193 |
+
|
194 |
+
#####——7
|
195 |
+
prompt = {}
|
196 |
+
prompt["instruction"] = "Can you tell me what item is described as \"{description}\"?"
|
197 |
+
prompt["response"] = "{item}"
|
198 |
+
item2index_prompt.append(prompt)
|
199 |
+
|
200 |
+
#####——8
|
201 |
+
prompt = {}
|
202 |
+
prompt["instruction"] = "Can you provide the item that corresponds to the following description: \"{description}\"?"
|
203 |
+
prompt["response"] = "{item}"
|
204 |
+
item2index_prompt.append(prompt)
|
205 |
+
|
206 |
+
|
207 |
+
# prompt = {}
|
208 |
+
# prompt["instruction"] = "What is the item described as follows: \"{description}\"?"
|
209 |
+
# prompt["response"] = "{item}"
|
210 |
+
# item2index_prompt.append(prompt)
|
211 |
+
|
212 |
+
#####——9
|
213 |
+
prompt = {}
|
214 |
+
prompt["instruction"] = "Which item has the following characteristics: \"{description}\"?"
|
215 |
+
prompt["response"] = "{item}"
|
216 |
+
item2index_prompt.append(prompt)
|
217 |
+
|
218 |
+
#####——10
|
219 |
+
prompt = {}
|
220 |
+
prompt["instruction"] = "Which item is characterized by the following description: \"{description}\"?"
|
221 |
+
prompt["response"] = "{item}"
|
222 |
+
item2index_prompt.append(prompt)
|
223 |
+
|
224 |
+
#####——11
|
225 |
+
prompt = {}
|
226 |
+
prompt["instruction"] = "I am curious to know which item can be described as follows: \"{description}\". Can you tell me?"
|
227 |
+
prompt["response"] = "{item}"
|
228 |
+
item2index_prompt.append(prompt)
|
229 |
+
|
230 |
+
# ========================================================
|
231 |
+
# Title and Description to index
|
232 |
+
|
233 |
+
#####——12
|
234 |
+
prompt = {}
|
235 |
+
prompt["instruction"] = "An item is called \"{title}\" and described as \"{description}\", can you tell me which item it is?"
|
236 |
+
prompt["response"] = "{item}"
|
237 |
+
item2index_prompt.append(prompt)
|
238 |
+
|
239 |
+
#####——13
|
240 |
+
prompt = {}
|
241 |
+
prompt["instruction"] = "Could you please identify what item is called \"{title}\" and described as \"{description}\"?"
|
242 |
+
prompt["response"] = "{item}"
|
243 |
+
item2index_prompt.append(prompt)
|
244 |
+
|
245 |
+
#####——14
|
246 |
+
prompt = {}
|
247 |
+
prompt["instruction"] = "Which item is called \"{title}\" and has the characteristics described below: \"{description}\"?"
|
248 |
+
prompt["response"] = "{item}"
|
249 |
+
item2index_prompt.append(prompt)
|
250 |
+
|
251 |
+
#####——15
|
252 |
+
prompt = {}
|
253 |
+
prompt["instruction"] = "Please show me which item is named \"{title}\" and its corresponding description is: \"{description}\"."
|
254 |
+
prompt["response"] = "{item}"
|
255 |
+
item2index_prompt.append(prompt)
|
256 |
+
|
257 |
+
|
258 |
+
# prompt = {}
|
259 |
+
# prompt["instruction"] = "Here is an item called \"{title}\" and described as \"{description}\". Which item is it?"
|
260 |
+
# prompt["response"] = "{item}"
|
261 |
+
# item2index_prompt.append(prompt)
|
262 |
+
|
263 |
+
#####——16
|
264 |
+
prompt = {}
|
265 |
+
prompt["instruction"] = "Determine which item this is by its title and description. The title is: \"{title}\", and the description is: \"{description}\"."
|
266 |
+
prompt["response"] = "{item}"
|
267 |
+
item2index_prompt.append(prompt)
|
268 |
+
|
269 |
+
#####——17
|
270 |
+
prompt = {}
|
271 |
+
prompt["instruction"] = "Based on the title: \"{title}\", and the description: \"{description}\", answer which item is this?"
|
272 |
+
prompt["response"] = "{item}"
|
273 |
+
item2index_prompt.append(prompt)
|
274 |
+
|
275 |
+
#####——18
|
276 |
+
prompt = {}
|
277 |
+
prompt["instruction"] = "Can you identify the item from the provided title: \"{title}\", and description: \"{description}\"?"
|
278 |
+
prompt["response"] = "{item}"
|
279 |
+
item2index_prompt.append(prompt)
|
280 |
+
|
281 |
+
all_prompt["item2index"] = item2index_prompt
|
282 |
+
|
283 |
+
|
284 |
+
# ========================================================
|
285 |
+
# Task 3 -- Index2Item --17 Prompt
|
286 |
+
# ========================================================
|
287 |
+
# Remove periods when inputting
|
288 |
+
|
289 |
+
index2item_prompt = []
|
290 |
+
|
291 |
+
# ========================================================
|
292 |
+
# Index2Title
|
293 |
+
|
294 |
+
#####——0
|
295 |
+
prompt = {}
|
296 |
+
prompt["instruction"] = "What is the title of item {item}?"
|
297 |
+
prompt["response"] = "{title}"
|
298 |
+
index2item_prompt.append(prompt)
|
299 |
+
|
300 |
+
#####——1
|
301 |
+
prompt = {}
|
302 |
+
prompt["instruction"] = "What title is assigned to item {item}?"
|
303 |
+
prompt["response"] = "{title}"
|
304 |
+
index2item_prompt.append(prompt)
|
305 |
+
|
306 |
+
#####——2
|
307 |
+
prompt = {}
|
308 |
+
prompt["instruction"] = "Could you please tell me what item {item} is called?"
|
309 |
+
prompt["response"] = "{title}"
|
310 |
+
index2item_prompt.append(prompt)
|
311 |
+
|
312 |
+
#####——3
|
313 |
+
prompt = {}
|
314 |
+
prompt["instruction"] = "Can you provide the title of item {item}?"
|
315 |
+
prompt["response"] = "{title}"
|
316 |
+
index2item_prompt.append(prompt)
|
317 |
+
|
318 |
+
#####——4
|
319 |
+
prompt = {}
|
320 |
+
prompt["instruction"] = "What item {item} is referred to as?"
|
321 |
+
prompt["response"] = "{title}"
|
322 |
+
index2item_prompt.append(prompt)
|
323 |
+
|
324 |
+
#####——5
|
325 |
+
prompt = {}
|
326 |
+
prompt["instruction"] = "Would you mind informing me about the title of item {item}?"
|
327 |
+
prompt["response"] = "{title}"
|
328 |
+
index2item_prompt.append(prompt)
|
329 |
+
|
330 |
+
# ========================================================
|
331 |
+
# Index2Description
|
332 |
+
|
333 |
+
#####——6
|
334 |
+
prompt = {}
|
335 |
+
prompt["instruction"] = "Please provide a description of item {item}."
|
336 |
+
prompt["response"] = "{description}"
|
337 |
+
index2item_prompt.append(prompt)
|
338 |
+
|
339 |
+
#####——7
|
340 |
+
prompt = {}
|
341 |
+
prompt["instruction"] = "Briefly describe item {item}."
|
342 |
+
prompt["response"] = "{description}"
|
343 |
+
index2item_prompt.append(prompt)
|
344 |
+
|
345 |
+
#####——8
|
346 |
+
prompt = {}
|
347 |
+
prompt["instruction"] = "Can you share with me the description corresponding to item {item}?"
|
348 |
+
prompt["response"] = "{description}"
|
349 |
+
index2item_prompt.append(prompt)
|
350 |
+
|
351 |
+
#####——9
|
352 |
+
prompt = {}
|
353 |
+
prompt["instruction"] = "What is the description of item {item}?"
|
354 |
+
prompt["response"] = "{description}"
|
355 |
+
index2item_prompt.append(prompt)
|
356 |
+
|
357 |
+
#####——10
|
358 |
+
prompt = {}
|
359 |
+
prompt["instruction"] = "How to describe the characteristics of item {item}?"
|
360 |
+
prompt["response"] = "{description}"
|
361 |
+
index2item_prompt.append(prompt)
|
362 |
+
|
363 |
+
#####——11
|
364 |
+
prompt = {}
|
365 |
+
prompt["instruction"] = "Could you please tell me what item {item} looks like?"
|
366 |
+
prompt["response"] = "{description}"
|
367 |
+
index2item_prompt.append(prompt)
|
368 |
+
|
369 |
+
|
370 |
+
# ========================================================
|
371 |
+
# index to Title and Description
|
372 |
+
|
373 |
+
#####——12
|
374 |
+
prompt = {}
|
375 |
+
prompt["instruction"] = "What is the title and description of item {item}?"
|
376 |
+
prompt["response"] = "{title}\n\n{description}"
|
377 |
+
index2item_prompt.append(prompt)
|
378 |
+
|
379 |
+
#####——13
|
380 |
+
prompt = {}
|
381 |
+
prompt["instruction"] = "Can you provide the corresponding title and description for item {item}?"
|
382 |
+
prompt["response"] = "{title}\n\n{description}"
|
383 |
+
index2item_prompt.append(prompt)
|
384 |
+
|
385 |
+
#####——14
|
386 |
+
prompt = {}
|
387 |
+
prompt["instruction"] = "Please tell me what item {item} is called, along with a brief description of it."
|
388 |
+
prompt["response"] = "{title}\n\n{description}"
|
389 |
+
index2item_prompt.append(prompt)
|
390 |
+
|
391 |
+
#####——15
|
392 |
+
prompt = {}
|
393 |
+
prompt["instruction"] = "Would you mind informing me about the title of the item {item} and how to describe its characteristics?"
|
394 |
+
prompt["response"] = "{title}\n\n{description}"
|
395 |
+
index2item_prompt.append(prompt)
|
396 |
+
|
397 |
+
#####——16
|
398 |
+
prompt = {}
|
399 |
+
prompt["instruction"] = "I need to know the title and description of item {item}. Could you help me with that?"
|
400 |
+
prompt["response"] = "{title}\n\n{description}"
|
401 |
+
index2item_prompt.append(prompt)
|
402 |
+
|
403 |
+
all_prompt["index2item"] = index2item_prompt
|
404 |
+
|
405 |
+
|
406 |
+
|
407 |
+
|
408 |
+
|
409 |
+
# ========================================================
|
410 |
+
# Task 4 -- FusionSequentialRec -- Prompt
|
411 |
+
# ========================================================
|
412 |
+
|
413 |
+
|
414 |
+
fusionseqrec_prompt = []
|
415 |
+
|
416 |
+
#####——0
|
417 |
+
prompt = {}
|
418 |
+
prompt["instruction"] = "The user has sequentially interacted with items {inters}. Can you recommend the next item for him? Tell me the title of the item?"
|
419 |
+
prompt["response"] = "{title}"
|
420 |
+
fusionseqrec_prompt.append(prompt)
|
421 |
+
|
422 |
+
#####——1
|
423 |
+
prompt = {}
|
424 |
+
prompt["instruction"] = "Based on the user's historical interactions: {inters}, try to predict the title of the item that the user may need next."
|
425 |
+
prompt["response"] = "{title}"
|
426 |
+
fusionseqrec_prompt.append(prompt)
|
427 |
+
|
428 |
+
#####——2
|
429 |
+
prompt = {}
|
430 |
+
prompt["instruction"] = "Utilizing the user's past ordered interactions, which include items {inters}, please recommend the next item you think is suitable for the user and provide its title."
|
431 |
+
prompt["response"] = "{title}"
|
432 |
+
fusionseqrec_prompt.append(prompt)
|
433 |
+
|
434 |
+
|
435 |
+
#####——3
|
436 |
+
prompt = {}
|
437 |
+
prompt["instruction"] = "After interacting with items {inters}, what is the most probable item for the user to interact with next? Kindly provide the item's title."
|
438 |
+
prompt["response"] = "{title}"
|
439 |
+
fusionseqrec_prompt.append(prompt)
|
440 |
+
|
441 |
+
|
442 |
+
|
443 |
+
|
444 |
+
|
445 |
+
#####——4
|
446 |
+
prompt = {}
|
447 |
+
prompt["instruction"] = "Please review the user's historical interactions: {inters}, and describe what kind of item he still needs."
|
448 |
+
prompt["response"] = "{description}"
|
449 |
+
fusionseqrec_prompt.append(prompt)
|
450 |
+
|
451 |
+
#####——5
|
452 |
+
prompt = {}
|
453 |
+
prompt["instruction"] = "Here is the item interaction history of the user: {inters}, please tell me what features he expects from his next item."
|
454 |
+
prompt["response"] = "{description}"
|
455 |
+
fusionseqrec_prompt.append(prompt)
|
456 |
+
|
457 |
+
#####——6
|
458 |
+
prompt = {}
|
459 |
+
prompt["instruction"] = "By analyzing the user's historical interactions with items {inters}, can you infer what the user's next interactive item will look like?"
|
460 |
+
prompt["response"] = "{description}"
|
461 |
+
fusionseqrec_prompt.append(prompt)
|
462 |
+
|
463 |
+
#####——7
|
464 |
+
prompt = {}
|
465 |
+
prompt["instruction"] = "Access the user's historical item interaction records: {inters}. Your objective is to describe the next potential item for him, taking into account his past interactions."
|
466 |
+
prompt["response"] = "{description}"
|
467 |
+
fusionseqrec_prompt.append(prompt)
|
468 |
+
|
469 |
+
|
470 |
+
|
471 |
+
|
472 |
+
|
473 |
+
|
474 |
+
#####——8
|
475 |
+
prompt = {}
|
476 |
+
prompt["instruction"] = "Given the title sequence of user historical interactive items: {inter_titles}, can you recommend a suitable next item for the user?"
|
477 |
+
prompt["response"] = "{item}"
|
478 |
+
fusionseqrec_prompt.append(prompt)
|
479 |
+
|
480 |
+
#####——9
|
481 |
+
prompt = {}
|
482 |
+
prompt["instruction"] = "I possess a user's past interaction history, denoted by the title sequence of interactive items: {inter_titles}, and I am interested in knowing the user's next most desired item. Can you help me?"
|
483 |
+
prompt["response"] = "{item}"
|
484 |
+
fusionseqrec_prompt.append(prompt)
|
485 |
+
|
486 |
+
#####——10
|
487 |
+
prompt = {}
|
488 |
+
prompt["instruction"] = "Considering the title sequence of user history interaction items: {inter_titles}. What is the next recommendation for the user?"
|
489 |
+
prompt["response"] = "{item}"
|
490 |
+
fusionseqrec_prompt.append(prompt)
|
491 |
+
|
492 |
+
#####——11
|
493 |
+
prompt = {}
|
494 |
+
prompt["instruction"] = "You have obtained the ordered title list of user historical interaction items, as follows: {inter_titles}. Based on this historical context, kindly choose the subsequent item for user recommendation."
|
495 |
+
prompt["response"] = "{item}"
|
496 |
+
fusionseqrec_prompt.append(prompt)
|
497 |
+
|
498 |
+
|
499 |
+
all_prompt["fusionseqrec"] = fusionseqrec_prompt
|
500 |
+
|
501 |
+
|
502 |
+
|
503 |
+
|
504 |
+
|
505 |
+
|
506 |
+
|
507 |
+
# ========================================================
|
508 |
+
# Task 5 -- ItemSearch -- Prompt
|
509 |
+
# ========================================================
|
510 |
+
|
511 |
+
|
512 |
+
itemsearch_prompt = []
|
513 |
+
|
514 |
+
#####——0
|
515 |
+
prompt = {}
|
516 |
+
prompt["instruction"] = "Here is the historical interactions of a user: {inters}. And his personalized preferences are as follows: \"{explicit_preference}\". Your task is to recommend an item that is consistent with the user's preference."
|
517 |
+
prompt["response"] = "{item}"
|
518 |
+
itemsearch_prompt.append(prompt)
|
519 |
+
|
520 |
+
#####——1
|
521 |
+
prompt = {}
|
522 |
+
prompt["instruction"] = "The user has interacted with a list of items, which are as follows: {inters}. Based on these interacted items, the user current intent is as follows \"{user_related_intention}\", and your task is to generate an item that matches the user's current intent."
|
523 |
+
prompt["response"] = "{item}"
|
524 |
+
itemsearch_prompt.append(prompt)
|
525 |
+
|
526 |
+
#####——2
|
527 |
+
prompt = {}
|
528 |
+
prompt["instruction"] = "As a recommender system, you are assisting a user who has recently interacted with the following items: {inters}. The user expresses a desire to obtain another item with the following characteristics: \"{item_related_intention}\". Please recommend an item that meets these criteria."
|
529 |
+
prompt["response"] = "{item}"
|
530 |
+
itemsearch_prompt.append(prompt)
|
531 |
+
|
532 |
+
#####——3
|
533 |
+
prompt = {}
|
534 |
+
prompt["instruction"] = "Using the user's current query: \"{query}\" and his historical interactions: {inters}, you can estimate the user's preferences \"{explicit_preference}\". Please respond to the user's query by selecting an item that best matches his preference and query."
|
535 |
+
prompt["response"] = "{item}"
|
536 |
+
itemsearch_prompt.append(prompt)
|
537 |
+
|
538 |
+
#####——4
|
539 |
+
prompt = {}
|
540 |
+
prompt["instruction"] = "The user needs a new item and searches for: \"{query}\". In addition, he has previously interacted with: {inters}. You can obtain his preference by analyzing his historical interactions: \"{explicit_preference}\". Can you recommend an item that best matches the search query and preferences?"
|
541 |
+
prompt["response"] = "{item}"
|
542 |
+
itemsearch_prompt.append(prompt)
|
543 |
+
|
544 |
+
#####——5
|
545 |
+
prompt = {}
|
546 |
+
prompt["instruction"] = "Based on the user's historical interactions with the following items: {inters}. You can infer his preference by observing the historical interactions: \"{explicit_preference}\". Now the user wants a new item and searches for: \"{query}\". Please select a suitable item that matches his preference and search intent."
|
547 |
+
prompt["response"] = "{item}"
|
548 |
+
itemsearch_prompt.append(prompt)
|
549 |
+
|
550 |
+
|
551 |
+
|
552 |
+
|
553 |
+
|
554 |
+
#####——6
|
555 |
+
prompt = {}
|
556 |
+
prompt["instruction"] = "Suppose you are a search engine, now a user searches that: \"{query}\", can you select an item to respond to the user's query?"
|
557 |
+
prompt["response"] = "{item}"
|
558 |
+
itemsearch_prompt.append(prompt)
|
559 |
+
|
560 |
+
#####——7
|
561 |
+
prompt = {}
|
562 |
+
prompt["instruction"] = "As a search engine, your task is to answer the user's query by generating a related item. The user's query is provided as \"{query}\". Please provide your generated item as your answer."
|
563 |
+
prompt["response"] = "{item}"
|
564 |
+
itemsearch_prompt.append(prompt)
|
565 |
+
|
566 |
+
#####——8
|
567 |
+
prompt = {}
|
568 |
+
prompt["instruction"] = "As a recommender system, your task is to recommend an item that is related to the user's request, which is specified as follows: \"{query}\". Please provide your recommendation."
|
569 |
+
prompt["response"] = "{item}"
|
570 |
+
itemsearch_prompt.append(prompt)
|
571 |
+
|
572 |
+
#####——9
|
573 |
+
prompt = {}
|
574 |
+
prompt["instruction"] = "You meet a user's query: \"{query}\". Please respond to this user by selecting an appropriate item."
|
575 |
+
prompt["response"] = "{item}"
|
576 |
+
itemsearch_prompt.append(prompt)
|
577 |
+
|
578 |
+
|
579 |
+
#####——10
|
580 |
+
prompt = {}
|
581 |
+
prompt["instruction"] = "Your task is to recommend the best item that matches the user's query. Here is the search query of the user: \"{query}\", tell me the item you recommend."
|
582 |
+
prompt["response"] = "{item}"
|
583 |
+
itemsearch_prompt.append(prompt)
|
584 |
+
|
585 |
+
all_prompt["itemsearch"] = itemsearch_prompt
|
586 |
+
|
587 |
+
|
588 |
+
|
589 |
+
|
590 |
+
|
591 |
+
# ========================================================
|
592 |
+
# Task 6 -- PreferenceObtain -- Prompt
|
593 |
+
# ========================================================
|
594 |
+
|
595 |
+
preferenceobtain_prompt = []
|
596 |
+
|
597 |
+
#####——0
|
598 |
+
prompt = {}
|
599 |
+
prompt["instruction"] = "The user has interacted with items {inters} in chronological order. Please estimate his preferences."
|
600 |
+
prompt["response"] = "{explicit_preference}"
|
601 |
+
preferenceobtain_prompt.append(prompt)
|
602 |
+
|
603 |
+
#####——1
|
604 |
+
prompt = {}
|
605 |
+
prompt["instruction"] = "Based on the items that the user has interacted with: {inters}, can you infer what preferences he has?"
|
606 |
+
prompt["response"] = "{explicit_preference}"
|
607 |
+
preferenceobtain_prompt.append(prompt)
|
608 |
+
|
609 |
+
#####——3
|
610 |
+
prompt = {}
|
611 |
+
prompt["instruction"] = "Can you provide a summary of the user's preferences based on his historical interactions: {inters}?"
|
612 |
+
prompt["response"] = "{explicit_preference}"
|
613 |
+
preferenceobtain_prompt.append(prompt)
|
614 |
+
|
615 |
+
#####——4
|
616 |
+
prompt = {}
|
617 |
+
prompt["instruction"] = "After interacting with items {inters} in order, what preferences do you think the user has?"
|
618 |
+
prompt["response"] = "{explicit_preference}"
|
619 |
+
preferenceobtain_prompt.append(prompt)
|
620 |
+
|
621 |
+
#####——5
|
622 |
+
prompt = {}
|
623 |
+
prompt["instruction"] = "Here is the item interaction history of the user: {inters}, could you please infer the user's preferences."
|
624 |
+
prompt["response"] = "{explicit_preference}"
|
625 |
+
preferenceobtain_prompt.append(prompt)
|
626 |
+
|
627 |
+
#####——6
|
628 |
+
prompt = {}
|
629 |
+
prompt["instruction"] = "Based on the user's historical interaction records: {inters}, what are your speculations about his preferences?"
|
630 |
+
prompt["response"] = "{explicit_preference}"
|
631 |
+
preferenceobtain_prompt.append(prompt)
|
632 |
+
|
633 |
+
#####——7
|
634 |
+
prompt = {}
|
635 |
+
prompt["instruction"] = "Given the user's historical interactive items arranged in chronological order: {inters}, what can be inferred about the preferences of the user?"
|
636 |
+
prompt["response"] = "{explicit_preference}"
|
637 |
+
preferenceobtain_prompt.append(prompt)
|
638 |
+
|
639 |
+
#####——8
|
640 |
+
prompt = {}
|
641 |
+
prompt["instruction"] = "Can you speculate on the user's preferences based on his historical item interaction records: {inters}?"
|
642 |
+
prompt["response"] = "{explicit_preference}"
|
643 |
+
preferenceobtain_prompt.append(prompt)
|
644 |
+
|
645 |
+
#####——9
|
646 |
+
prompt = {}
|
647 |
+
prompt["instruction"] = "What is the preferences of a user who has previously interacted with items {inters} sequentially?"
|
648 |
+
prompt["response"] = "{explicit_preference}"
|
649 |
+
preferenceobtain_prompt.append(prompt)
|
650 |
+
|
651 |
+
#####——10
|
652 |
+
prompt = {}
|
653 |
+
prompt["instruction"] = "Using the user's historical interactions as input data, summarize the user's preferences. The historical interactions are provided as follows: {inters}."
|
654 |
+
prompt["response"] = "{explicit_preference}"
|
655 |
+
preferenceobtain_prompt.append(prompt)
|
656 |
+
|
657 |
+
#####——11
|
658 |
+
prompt = {}
|
659 |
+
prompt["instruction"] = "Utilizing the ordered list of the user's historical interaction items as a reference, please make an informed estimation of the user's preferences. The historical interactions are as follows: {inters}."
|
660 |
+
prompt["response"] = "{explicit_preference}"
|
661 |
+
preferenceobtain_prompt.append(prompt)
|
662 |
+
|
663 |
+
all_prompt["preferenceobtain"] = preferenceobtain_prompt
|
run.sh
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export WANDB_MODE=disabled
|
2 |
+
export CUDA_LAUNCH_BLOCKING=1
|
3 |
+
|
4 |
+
DATASET=Games
|
5 |
+
BASE_MODEL=huggyllama/llama-7b
|
6 |
+
DATA_PATH=./data
|
7 |
+
OUTPUT_DIR=./ckpt/$DATASET/
|
8 |
+
|
9 |
+
torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
|
10 |
+
--base_model $BASE_MODEL \
|
11 |
+
--output_dir $OUTPUT_DIR \
|
12 |
+
--dataset $DATASET \
|
13 |
+
--data_path $DATA_PATH \
|
14 |
+
--per_device_batch_size 8 \
|
15 |
+
--gradient_accumulation_steps 2 \
|
16 |
+
--learning_rate 5e-5 \
|
17 |
+
--epochs 4 \
|
18 |
+
--weight_decay 0.01 \
|
19 |
+
--save_and_eval_strategy epoch \
|
20 |
+
--deepspeed ./config/ds_z3_bf16.json \
|
21 |
+
--bf16 \
|
22 |
+
--only_train_response \
|
23 |
+
--tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
|
24 |
+
--train_prompt_sample_num 1,1,1,1,1,1 \
|
25 |
+
--train_data_sample_num 0,0,0,100000,0,0 \
|
26 |
+
--index_file .index.json
|
27 |
+
|
28 |
+
|
29 |
+
cd convert
|
30 |
+
nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
|
31 |
+
cd ..
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
DATASET=Arts
|
39 |
+
BASE_MODEL=huggyllama/llama-7b
|
40 |
+
DATA_PATH=./data
|
41 |
+
OUTPUT_DIR=./ckpt/$DATASET/
|
42 |
+
|
43 |
+
torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
|
44 |
+
--base_model $BASE_MODEL \
|
45 |
+
--output_dir $OUTPUT_DIR \
|
46 |
+
--dataset $DATASET \
|
47 |
+
--data_path $DATA_PATH \
|
48 |
+
--per_device_batch_size 8 \
|
49 |
+
--gradient_accumulation_steps 2 \
|
50 |
+
--learning_rate 5e-5 \
|
51 |
+
--epochs 4 \
|
52 |
+
--weight_decay 0.01 \
|
53 |
+
--save_and_eval_strategy epoch \
|
54 |
+
--deepspeed ./config/ds_z3_bf16.json \
|
55 |
+
--bf16 \
|
56 |
+
--only_train_response \
|
57 |
+
--tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
|
58 |
+
--train_prompt_sample_num 1,1,1,1,1,1 \
|
59 |
+
--train_data_sample_num 0,0,0,30000,0,0 \
|
60 |
+
--index_file .index.json
|
61 |
+
|
62 |
+
|
63 |
+
cd convert
|
64 |
+
nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
|
65 |
+
cd ..
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
DATASET=Instruments
|
72 |
+
BASE_MODEL=huggyllama/llama-7b
|
73 |
+
DATA_PATH=./data
|
74 |
+
OUTPUT_DIR=./ckpt/$DATASET/
|
75 |
+
|
76 |
+
torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
|
77 |
+
--base_model $BASE_MODEL \
|
78 |
+
--output_dir $OUTPUT_DIR \
|
79 |
+
--dataset $DATASET \
|
80 |
+
--data_path $DATA_PATH \
|
81 |
+
--per_device_batch_size 8 \
|
82 |
+
--gradient_accumulation_steps 2 \
|
83 |
+
--learning_rate 5e-5 \
|
84 |
+
--epochs 4 \
|
85 |
+
--weight_decay 0.01 \
|
86 |
+
--save_and_eval_strategy epoch \
|
87 |
+
--deepspeed ./config/ds_z3_bf16.json \
|
88 |
+
--bf16 \
|
89 |
+
--only_train_response \
|
90 |
+
--tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
|
91 |
+
--train_prompt_sample_num 1,1,1,1,1,1 \
|
92 |
+
--train_data_sample_num 0,0,0,20000,0,0 \
|
93 |
+
--index_file .index.json
|
94 |
+
|
95 |
+
|
96 |
+
cd convert
|
97 |
+
nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
|
98 |
+
cd ..
|
run_test.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
DATASET=Games
|
4 |
+
DATA_PATH=./data
|
5 |
+
OUTPUT_DIR=./ckpt/$DATASET/
|
6 |
+
RESULTS_FILE=./results/$DATASET/xxx.json
|
7 |
+
|
8 |
+
python test.py \
|
9 |
+
--gpu_id 0 \
|
10 |
+
--ckpt_path $CKPT_PATH \
|
11 |
+
--dataset $DATASET \
|
12 |
+
--data_path $DATA_PATH \
|
13 |
+
--results_file $RESULTS_FILE \
|
14 |
+
--test_batch_size 1 \
|
15 |
+
--num_beams 20 \
|
16 |
+
--test_prompt_ids all \
|
17 |
+
--index_file .index.json
|
test.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import transformers
|
9 |
+
from peft import PeftModel
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
from tqdm import tqdm
|
12 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
|
13 |
+
|
14 |
+
from utils import *
|
15 |
+
from collator import TestCollator
|
16 |
+
from prompt import all_prompt
|
17 |
+
from evaluate import get_topk_results, get_metrics_results
|
18 |
+
|
19 |
+
|
20 |
+
def test(args):
|
21 |
+
|
22 |
+
set_seed(args.seed)
|
23 |
+
print(vars(args))
|
24 |
+
|
25 |
+
device_map = {"": args.gpu_id}
|
26 |
+
device = torch.device("cuda",args.gpu_id)
|
27 |
+
|
28 |
+
|
29 |
+
tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path)
|
30 |
+
if args.lora:
|
31 |
+
model = LlamaForCausalLM.from_pretrained(
|
32 |
+
args.base_model,
|
33 |
+
torch_dtype=torch.bfloat16,
|
34 |
+
low_cpu_mem_usage=True,
|
35 |
+
device_map=device_map,
|
36 |
+
)
|
37 |
+
model.resize_token_embeddings(len(tokenizer))
|
38 |
+
model = PeftModel.from_pretrained(
|
39 |
+
model,
|
40 |
+
args.ckpt_path,
|
41 |
+
torch_dtype=torch.bfloat16,
|
42 |
+
device_map=device_map,
|
43 |
+
)
|
44 |
+
else:
|
45 |
+
model = LlamaForCausalLM.from_pretrained(
|
46 |
+
args.ckpt_path,
|
47 |
+
torch_dtype=torch.bfloat16,
|
48 |
+
low_cpu_mem_usage=True,
|
49 |
+
device_map=device_map,
|
50 |
+
)
|
51 |
+
# assert model.config.vocab_size == len(tokenizer)
|
52 |
+
|
53 |
+
if args.test_prompt_ids == "all":
|
54 |
+
if args.test_task.lower() == "seqrec":
|
55 |
+
prompt_ids = range(len(all_prompt["seqrec"]))
|
56 |
+
elif args.test_task.lower() == "itemsearch":
|
57 |
+
prompt_ids = range(len(all_prompt["itemsearch"]))
|
58 |
+
elif args.test_task.lower() == "fusionseqrec":
|
59 |
+
prompt_ids = range(len(all_prompt["fusionseqrec"]))
|
60 |
+
else:
|
61 |
+
prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")]
|
62 |
+
|
63 |
+
test_data = load_test_dataset(args)
|
64 |
+
collator = TestCollator(args, tokenizer)
|
65 |
+
all_items = test_data.get_all_items()
|
66 |
+
|
67 |
+
|
68 |
+
prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(tokenizer)
|
69 |
+
|
70 |
+
test_loader = DataLoader(test_data, batch_size=args.test_batch_size, collate_fn=collator,
|
71 |
+
shuffle=True, num_workers=4, pin_memory=True)
|
72 |
+
|
73 |
+
|
74 |
+
print("data num:", len(test_data))
|
75 |
+
|
76 |
+
model.eval()
|
77 |
+
|
78 |
+
metrics = args.metrics.split(",")
|
79 |
+
all_prompt_results = []
|
80 |
+
with torch.no_grad():
|
81 |
+
for prompt_id in prompt_ids:
|
82 |
+
|
83 |
+
test_loader.dataset.set_prompt(prompt_id)
|
84 |
+
metrics_results = {}
|
85 |
+
total = 0
|
86 |
+
|
87 |
+
for step, batch in enumerate(tqdm(test_loader)):
|
88 |
+
inputs = batch[0].to(device)
|
89 |
+
targets = batch[1]
|
90 |
+
total += len(targets)
|
91 |
+
|
92 |
+
output = model.generate(
|
93 |
+
input_ids=inputs["input_ids"],
|
94 |
+
attention_mask=inputs["attention_mask"],
|
95 |
+
max_new_tokens=10,
|
96 |
+
# max_length=10,
|
97 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens,
|
98 |
+
num_beams=args.num_beams,
|
99 |
+
num_return_sequences=args.num_beams,
|
100 |
+
output_scores=True,
|
101 |
+
return_dict_in_generate=True,
|
102 |
+
early_stopping=True,
|
103 |
+
)
|
104 |
+
output_ids = output["sequences"]
|
105 |
+
scores = output["sequences_scores"]
|
106 |
+
|
107 |
+
output = tokenizer.batch_decode(
|
108 |
+
output_ids, skip_special_tokens=True
|
109 |
+
)
|
110 |
+
# print(output)
|
111 |
+
topk_res = get_topk_results(output,scores,targets,args.num_beams,
|
112 |
+
all_items=all_items if args.filter_items else None)
|
113 |
+
|
114 |
+
batch_metrics_res = get_metrics_results(topk_res, metrics)
|
115 |
+
# print(batch_metrics_res)
|
116 |
+
|
117 |
+
for m, res in batch_metrics_res.items():
|
118 |
+
if m not in metrics_results:
|
119 |
+
metrics_results[m] = res
|
120 |
+
else:
|
121 |
+
metrics_results[m] += res
|
122 |
+
|
123 |
+
if (step+1)%10 == 0:
|
124 |
+
temp={}
|
125 |
+
for m in metrics_results:
|
126 |
+
temp[m] = metrics_results[m] / total
|
127 |
+
print(temp)
|
128 |
+
|
129 |
+
for m in metrics_results:
|
130 |
+
metrics_results[m] = metrics_results[m] / total
|
131 |
+
|
132 |
+
all_prompt_results.append(metrics_results)
|
133 |
+
print("======================================================")
|
134 |
+
print("Prompt {} results: ".format(prompt_id), metrics_results)
|
135 |
+
print("======================================================")
|
136 |
+
print("")
|
137 |
+
|
138 |
+
mean_results = {}
|
139 |
+
min_results = {}
|
140 |
+
max_results = {}
|
141 |
+
|
142 |
+
for m in metrics:
|
143 |
+
all_res = [_[m] for _ in all_prompt_results]
|
144 |
+
mean_results[m] = sum(all_res)/len(all_res)
|
145 |
+
min_results[m] = min(all_res)
|
146 |
+
max_results[m] = max(all_res)
|
147 |
+
|
148 |
+
print("======================================================")
|
149 |
+
print("Mean results: ", mean_results)
|
150 |
+
print("Min results: ", min_results)
|
151 |
+
print("Max results: ", max_results)
|
152 |
+
print("======================================================")
|
153 |
+
|
154 |
+
|
155 |
+
save_data={}
|
156 |
+
save_data["test_prompt_ids"] = args.test_prompt_ids
|
157 |
+
save_data["mean_results"] = mean_results
|
158 |
+
save_data["min_results"] = min_results
|
159 |
+
save_data["max_results"] = max_results
|
160 |
+
save_data["all_prompt_results"] = all_prompt_results
|
161 |
+
|
162 |
+
with open(args.results_file, "w") as f:
|
163 |
+
json.dump(save_data, f, indent=4)
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
+
if __name__ == "__main__":
|
168 |
+
parser = argparse.ArgumentParser(description="LLMRec_test")
|
169 |
+
parser = parse_global_args(parser)
|
170 |
+
parser = parse_dataset_args(parser)
|
171 |
+
parser = parse_test_args(parser)
|
172 |
+
|
173 |
+
args = parser.parse_args()
|
174 |
+
|
175 |
+
test(args)
|
test_ddp.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import transformers
|
8 |
+
import torch.distributed as dist
|
9 |
+
from torch.utils.data.distributed import DistributedSampler
|
10 |
+
from torch.nn.parallel import DistributedDataParallel
|
11 |
+
from peft import PeftModel
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
from tqdm import tqdm
|
14 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
|
15 |
+
|
16 |
+
from utils import *
|
17 |
+
from collator import TestCollator
|
18 |
+
from prompt import all_prompt
|
19 |
+
from evaluate import get_topk_results, get_metrics_results
|
20 |
+
|
21 |
+
|
22 |
+
def test_ddp(args):
|
23 |
+
|
24 |
+
set_seed(args.seed)
|
25 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
26 |
+
local_rank = int(os.environ.get("LOCAL_RANK") or 0)
|
27 |
+
torch.cuda.set_device(local_rank)
|
28 |
+
if local_rank == 0:
|
29 |
+
print(vars(args))
|
30 |
+
|
31 |
+
dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)
|
32 |
+
|
33 |
+
device_map = {"": local_rank}
|
34 |
+
device = torch.device("cuda",local_rank)
|
35 |
+
|
36 |
+
tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path)
|
37 |
+
args.lora=True
|
38 |
+
if args.lora:
|
39 |
+
model = LlamaForCausalLM.from_pretrained(
|
40 |
+
args.base_model,
|
41 |
+
torch_dtype=torch.float16,
|
42 |
+
low_cpu_mem_usage=True,
|
43 |
+
device_map=device_map,
|
44 |
+
)
|
45 |
+
model.resize_token_embeddings(len(tokenizer))
|
46 |
+
model = PeftModel.from_pretrained(
|
47 |
+
model,
|
48 |
+
args.ckpt_path,
|
49 |
+
torch_dtype=torch.float16,
|
50 |
+
device_map=device_map,
|
51 |
+
)
|
52 |
+
else:
|
53 |
+
model = LlamaForCausalLM.from_pretrained(
|
54 |
+
args.ckpt_path,
|
55 |
+
torch_dtype=torch.float16,
|
56 |
+
low_cpu_mem_usage=True,
|
57 |
+
device_map=device_map,
|
58 |
+
)
|
59 |
+
# assert model.config.vocab_size == len(tokenizer)
|
60 |
+
model = DistributedDataParallel(model, device_ids=[local_rank])
|
61 |
+
|
62 |
+
if args.test_prompt_ids == "all":
|
63 |
+
if args.test_task.lower() == "seqrec":
|
64 |
+
prompt_ids = range(len(all_prompt["seqrec"]))
|
65 |
+
elif args.test_task.lower() == "itemsearch":
|
66 |
+
prompt_ids = range(len(all_prompt["itemsearch"]))
|
67 |
+
elif args.test_task.lower() == "fusionseqrec":
|
68 |
+
prompt_ids = range(len(all_prompt["fusionseqrec"]))
|
69 |
+
else:
|
70 |
+
prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")]
|
71 |
+
|
72 |
+
test_data = load_test_dataset(args)
|
73 |
+
ddp_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=local_rank, drop_last=True)
|
74 |
+
|
75 |
+
test_data = load_test_dataset(args)
|
76 |
+
collator = TestCollator(args, tokenizer)
|
77 |
+
all_items = test_data.get_all_items()
|
78 |
+
|
79 |
+
|
80 |
+
prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(tokenizer)
|
81 |
+
|
82 |
+
|
83 |
+
test_loader = DataLoader(test_data, batch_size=args.test_batch_size, collate_fn=collator,
|
84 |
+
sampler=ddp_sampler, num_workers=2, pin_memory=True)
|
85 |
+
|
86 |
+
if local_rank == 0:
|
87 |
+
print("data num:", len(test_data))
|
88 |
+
|
89 |
+
model.eval()
|
90 |
+
|
91 |
+
metrics = args.metrics.split(",")
|
92 |
+
all_prompt_results = []
|
93 |
+
with torch.no_grad():
|
94 |
+
|
95 |
+
for prompt_id in prompt_ids:
|
96 |
+
|
97 |
+
if local_rank == 0:
|
98 |
+
print("Start prompt: ",prompt_id)
|
99 |
+
|
100 |
+
test_loader.dataset.set_prompt(prompt_id)
|
101 |
+
metrics_results = {}
|
102 |
+
total = 0
|
103 |
+
|
104 |
+
for step, batch in enumerate(tqdm(test_loader)):
|
105 |
+
inputs = batch[0].to(device)
|
106 |
+
targets = batch[1]
|
107 |
+
bs = len(targets)
|
108 |
+
num_beams = args.num_beams
|
109 |
+
while True:
|
110 |
+
try:
|
111 |
+
output = model.module.generate(
|
112 |
+
input_ids=inputs["input_ids"],
|
113 |
+
attention_mask=inputs["attention_mask"],
|
114 |
+
max_new_tokens=10,
|
115 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens,
|
116 |
+
num_beams=num_beams,
|
117 |
+
num_return_sequences=num_beams,
|
118 |
+
output_scores=True,
|
119 |
+
return_dict_in_generate=True,
|
120 |
+
early_stopping=True,
|
121 |
+
)
|
122 |
+
break
|
123 |
+
except torch.cuda.OutOfMemoryError as e:
|
124 |
+
print("Out of memory!")
|
125 |
+
num_beams = num_beams -1
|
126 |
+
print("Beam:", num_beams)
|
127 |
+
except Exception:
|
128 |
+
raise RuntimeError
|
129 |
+
|
130 |
+
output_ids = output["sequences"]
|
131 |
+
scores = output["sequences_scores"]
|
132 |
+
|
133 |
+
output = tokenizer.batch_decode(
|
134 |
+
output_ids, skip_special_tokens=True
|
135 |
+
)
|
136 |
+
|
137 |
+
topk_res = get_topk_results(output, scores, targets, num_beams,
|
138 |
+
all_items=all_items if args.filter_items else None)
|
139 |
+
|
140 |
+
bs_gather_list = [None for _ in range(world_size)]
|
141 |
+
dist.all_gather_object(obj=bs, object_list=bs_gather_list)
|
142 |
+
total += sum(bs_gather_list)
|
143 |
+
res_gather_list = [None for _ in range(world_size)]
|
144 |
+
dist.all_gather_object(obj=topk_res, object_list=res_gather_list)
|
145 |
+
|
146 |
+
|
147 |
+
if local_rank == 0:
|
148 |
+
all_device_topk_res = []
|
149 |
+
for ga_res in res_gather_list:
|
150 |
+
all_device_topk_res += ga_res
|
151 |
+
batch_metrics_res = get_metrics_results(all_device_topk_res, metrics)
|
152 |
+
for m, res in batch_metrics_res.items():
|
153 |
+
if m not in metrics_results:
|
154 |
+
metrics_results[m] = res
|
155 |
+
else:
|
156 |
+
metrics_results[m] += res
|
157 |
+
|
158 |
+
if (step + 1) % 50 == 0:
|
159 |
+
temp = {}
|
160 |
+
for m in metrics_results:
|
161 |
+
temp[m] = metrics_results[m] / total
|
162 |
+
print(temp)
|
163 |
+
|
164 |
+
dist.barrier()
|
165 |
+
|
166 |
+
if local_rank == 0:
|
167 |
+
for m in metrics_results:
|
168 |
+
metrics_results[m] = metrics_results[m] / total
|
169 |
+
|
170 |
+
all_prompt_results.append(metrics_results)
|
171 |
+
print("======================================================")
|
172 |
+
print("Prompt {} results: ".format(prompt_id), metrics_results)
|
173 |
+
print("======================================================")
|
174 |
+
print("")
|
175 |
+
|
176 |
+
dist.barrier()
|
177 |
+
|
178 |
+
dist.barrier()
|
179 |
+
|
180 |
+
if local_rank == 0:
|
181 |
+
mean_results = {}
|
182 |
+
min_results = {}
|
183 |
+
max_results = {}
|
184 |
+
|
185 |
+
for m in metrics:
|
186 |
+
all_res = [_[m] for _ in all_prompt_results]
|
187 |
+
mean_results[m] = sum(all_res)/len(all_res)
|
188 |
+
min_results[m] = min(all_res)
|
189 |
+
max_results[m] = max(all_res)
|
190 |
+
|
191 |
+
print("======================================================")
|
192 |
+
print("Mean results: ", mean_results)
|
193 |
+
print("Min results: ", min_results)
|
194 |
+
print("Max results: ", max_results)
|
195 |
+
print("======================================================")
|
196 |
+
|
197 |
+
|
198 |
+
save_data={}
|
199 |
+
save_data["test_prompt_ids"] = args.test_prompt_ids
|
200 |
+
save_data["mean_results"] = mean_results
|
201 |
+
save_data["min_results"] = min_results
|
202 |
+
save_data["max_results"] = max_results
|
203 |
+
save_data["all_prompt_results"] = all_prompt_results
|
204 |
+
|
205 |
+
with open(args.results_file, "w") as f:
|
206 |
+
json.dump(save_data, f, indent=4)
|
207 |
+
print("Save file: ", args.results_file)
|
208 |
+
|
209 |
+
import smtplib
|
210 |
+
from email.mime.text import MIMEText
|
211 |
+
mail_host = 'smtp.qq.com'
|
212 |
+
mail_code = 'ouzplpngooqndjcb'
|
213 |
+
sender = '[email protected]'
|
214 |
+
receiver = '[email protected]'
|
215 |
+
|
216 |
+
task = '[v67: evaluate lcrec]'
|
217 |
+
message = MIMEText('Task {task} Finished'.format(task = task), 'plain', 'utf-8')
|
218 |
+
message['Subject'] = 'Auto Email'
|
219 |
+
message['From'] = sender
|
220 |
+
message['To'] = receiver
|
221 |
+
|
222 |
+
server = smtplib.SMTP_SSL("smtp.qq.com", 465)
|
223 |
+
server.login(sender, mail_code)
|
224 |
+
server.sendmail(sender, receiver, message.as_string())
|
225 |
+
|
226 |
+
server.quit()
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
if __name__ == "__main__":
|
231 |
+
parser = argparse.ArgumentParser(description="LLMRec_test")
|
232 |
+
parser = parse_global_args(parser)
|
233 |
+
parser = parse_dataset_args(parser)
|
234 |
+
parser = parse_test_args(parser)
|
235 |
+
|
236 |
+
args = parser.parse_args()
|
237 |
+
|
238 |
+
test_ddp(args)
|
test_ddp.sh
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DATASET=Instruments
|
2 |
+
DATA_PATH=$datain/v-yinju/rqvae-zzx/data
|
3 |
+
CKPT_PATH=$datain/v-yinju/rq-llama
|
4 |
+
RESULTS_FILE=$CKPT_PATH/result.json
|
5 |
+
|
6 |
+
torchrun --nproc_per_node=8 --master_port=4324 test_ddp.py \
|
7 |
+
--ckpt_path $CKPT_PATH \
|
8 |
+
--dataset $DATASET \
|
9 |
+
--data_path $DATA_PATH \
|
10 |
+
--results_file $RESULTS_FILE \
|
11 |
+
--test_batch_size 1 \
|
12 |
+
--num_beams 20 \
|
13 |
+
--test_prompt_ids all \
|
14 |
+
--index_file .index.json
|
utils.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import datetime
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torch.utils.data import ConcatDataset
|
10 |
+
from data import SeqRecDataset, ItemFeatDataset, ItemSearchDataset, FusionSeqRecDataset, SeqRecTestDataset, PreferenceObtainDataset
|
11 |
+
|
12 |
+
|
13 |
+
def parse_global_args(parser):
|
14 |
+
|
15 |
+
|
16 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
17 |
+
|
18 |
+
parser.add_argument("--base_model", type=str,
|
19 |
+
default="./llama-7b/",
|
20 |
+
help="basic model path")
|
21 |
+
parser.add_argument("--output_dir", type=str,
|
22 |
+
default="./ckpt/",
|
23 |
+
help="The output directory")
|
24 |
+
|
25 |
+
|
26 |
+
return parser
|
27 |
+
|
28 |
+
def parse_dataset_args(parser):
|
29 |
+
parser.add_argument("--data_path", type=str, default="",
|
30 |
+
help="data directory")
|
31 |
+
parser.add_argument("--tasks", type=str, default="seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain",
|
32 |
+
help="Downstream tasks, separate by comma")
|
33 |
+
parser.add_argument("--dataset", type=str, default="Games", help="Dataset name")
|
34 |
+
parser.add_argument("--index_file", type=str, default=".index.json", help="the item indices file")
|
35 |
+
parser.add_argument("--dataloader_num_workers", type=int, default=0, help="dataloader num_workers")
|
36 |
+
parser.add_argument("--dataloader_prefetch_factor", type=int, default=2, help="dataloader prefetch_factor")
|
37 |
+
|
38 |
+
|
39 |
+
# arguments related to sequential task
|
40 |
+
parser.add_argument("--max_his_len", type=int, default=20,
|
41 |
+
help="the max number of items in history sequence, -1 means no limit")
|
42 |
+
parser.add_argument("--add_prefix", action="store_true", default=False,
|
43 |
+
help="whether add sequential prefix in history")
|
44 |
+
parser.add_argument("--his_sep", type=str, default=", ", help="The separator used for history")
|
45 |
+
parser.add_argument("--only_train_response", action="store_true", default=False,
|
46 |
+
help="whether only train on responses")
|
47 |
+
|
48 |
+
parser.add_argument("--train_prompt_sample_num", type=str, default="1,1,1,1,1,1",
|
49 |
+
help="the number of sampling prompts for each task")
|
50 |
+
parser.add_argument("--train_data_sample_num", type=str, default="0,0,0,100000,0,0",
|
51 |
+
help="the number of sampling prompts for each task")
|
52 |
+
|
53 |
+
parser.add_argument("--valid_prompt_id", type=int, default=0,
|
54 |
+
help="The prompt used for validation")
|
55 |
+
parser.add_argument("--sample_valid", action="store_true", default=True,
|
56 |
+
help="use sampled prompt for validation")
|
57 |
+
parser.add_argument("--valid_prompt_sample_num", type=int, default=2,
|
58 |
+
help="the number of sampling validation sequential recommendation prompts")
|
59 |
+
|
60 |
+
return parser
|
61 |
+
|
62 |
+
def parse_train_args(parser):
|
63 |
+
|
64 |
+
parser.add_argument("--optim", type=str, default="adamw_torch", help='The name of the optimizer')
|
65 |
+
parser.add_argument("--epochs", type=int, default=4)
|
66 |
+
parser.add_argument("--learning_rate", type=float, default=2e-5)
|
67 |
+
parser.add_argument("--per_device_batch_size", type=int, default=8)
|
68 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
|
69 |
+
parser.add_argument("--logging_step", type=int, default=10)
|
70 |
+
parser.add_argument("--model_max_length", type=int, default=2048)
|
71 |
+
parser.add_argument("--weight_decay", type=float, default=0.01)
|
72 |
+
|
73 |
+
parser.add_argument("--lora_r", type=int, default=8)
|
74 |
+
parser.add_argument("--lora_alpha", type=int, default=32)
|
75 |
+
parser.add_argument("--lora_dropout", type=float, default=0.05)
|
76 |
+
parser.add_argument("--lora_target_modules", type=str,
|
77 |
+
default="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj", help="separate by comma")
|
78 |
+
parser.add_argument("--lora_modules_to_save", type=str,
|
79 |
+
default="embed_tokens,lm_head", help="separate by comma")
|
80 |
+
|
81 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="either training checkpoint or final adapter")
|
82 |
+
|
83 |
+
parser.add_argument("--warmup_ratio", type=float, default=0.01)
|
84 |
+
parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
|
85 |
+
parser.add_argument("--save_and_eval_strategy", type=str, default="epoch")
|
86 |
+
parser.add_argument("--save_and_eval_steps", type=int, default=1000)
|
87 |
+
parser.add_argument("--fp16", action="store_true", default=False)
|
88 |
+
parser.add_argument("--bf16", action="store_true", default=False)
|
89 |
+
parser.add_argument("--deepspeed", type=str, default="./config/ds_z3_bf16.json")
|
90 |
+
|
91 |
+
return parser
|
92 |
+
|
93 |
+
def parse_test_args(parser):
|
94 |
+
|
95 |
+
parser.add_argument("--ckpt_path", type=str,
|
96 |
+
default="",
|
97 |
+
help="The checkpoint path")
|
98 |
+
parser.add_argument("--lora", action="store_true", default=False)
|
99 |
+
parser.add_argument("--filter_items", action="store_true", default=False,
|
100 |
+
help="whether filter illegal items")
|
101 |
+
|
102 |
+
parser.add_argument("--results_file", type=str,
|
103 |
+
default="./results/test-ddp.json",
|
104 |
+
help="result output path")
|
105 |
+
|
106 |
+
parser.add_argument("--test_batch_size", type=int, default=1)
|
107 |
+
parser.add_argument("--num_beams", type=int, default=20)
|
108 |
+
parser.add_argument("--sample_num", type=int, default=-1,
|
109 |
+
help="test sample number, -1 represents using all test data")
|
110 |
+
parser.add_argument("--gpu_id", type=int, default=0,
|
111 |
+
help="GPU ID when testing with single GPU")
|
112 |
+
parser.add_argument("--test_prompt_ids", type=str, default="0",
|
113 |
+
help="test prompt ids, separate by comma. 'all' represents using all")
|
114 |
+
parser.add_argument("--metrics", type=str, default="hit@1,hit@5,hit@10,ndcg@5,ndcg@10",
|
115 |
+
help="test metrics, separate by comma")
|
116 |
+
parser.add_argument("--test_task", type=str, default="SeqRec")
|
117 |
+
|
118 |
+
|
119 |
+
return parser
|
120 |
+
|
121 |
+
|
122 |
+
def get_local_time():
|
123 |
+
cur = datetime.datetime.now()
|
124 |
+
cur = cur.strftime("%b-%d-%Y_%H-%M-%S")
|
125 |
+
|
126 |
+
return cur
|
127 |
+
|
128 |
+
|
129 |
+
def set_seed(seed):
|
130 |
+
random.seed(seed)
|
131 |
+
np.random.seed(seed)
|
132 |
+
torch.manual_seed(seed)
|
133 |
+
torch.cuda.manual_seed_all(seed)
|
134 |
+
torch.backends.cudnn.benchmark = False
|
135 |
+
torch.backends.cudnn.deterministic = True
|
136 |
+
torch.backends.cudnn.enabled = False
|
137 |
+
|
138 |
+
def ensure_dir(dir_path):
|
139 |
+
|
140 |
+
os.makedirs(dir_path, exist_ok=True)
|
141 |
+
|
142 |
+
|
143 |
+
def load_datasets(args):
|
144 |
+
|
145 |
+
tasks = args.tasks.split(",")
|
146 |
+
|
147 |
+
train_prompt_sample_num = [int(_) for _ in args.train_prompt_sample_num.split(",")]
|
148 |
+
assert len(tasks) == len(train_prompt_sample_num), "prompt sample number does not match task number"
|
149 |
+
train_data_sample_num = [int(_) for _ in args.train_data_sample_num.split(",")]
|
150 |
+
assert len(tasks) == len(train_data_sample_num), "data sample number does not match task number"
|
151 |
+
|
152 |
+
train_datasets = []
|
153 |
+
for task, prompt_sample_num,data_sample_num in zip(tasks,train_prompt_sample_num,train_data_sample_num):
|
154 |
+
if task.lower() == "seqrec":
|
155 |
+
dataset = SeqRecDataset(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
|
156 |
+
|
157 |
+
elif task.lower() == "item2index" or task.lower() == "index2item":
|
158 |
+
dataset = ItemFeatDataset(args, task=task.lower(), prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
|
159 |
+
|
160 |
+
elif task.lower() == "fusionseqrec":
|
161 |
+
dataset = FusionSeqRecDataset(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
|
162 |
+
|
163 |
+
elif task.lower() == "itemsearch":
|
164 |
+
dataset = ItemSearchDataset(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
|
165 |
+
|
166 |
+
elif task.lower() == "preferenceobtain":
|
167 |
+
dataset = PreferenceObtainDataset(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
|
168 |
+
|
169 |
+
else:
|
170 |
+
raise NotImplementedError
|
171 |
+
train_datasets.append(dataset)
|
172 |
+
|
173 |
+
train_data = ConcatDataset(train_datasets)
|
174 |
+
|
175 |
+
valid_data = SeqRecDataset(args,"valid",args.valid_prompt_sample_num)
|
176 |
+
|
177 |
+
return train_data, valid_data
|
178 |
+
|
179 |
+
def load_test_dataset(args):
|
180 |
+
|
181 |
+
if args.test_task.lower() == "seqrec":
|
182 |
+
test_data = SeqRecDataset(args, mode="test", sample_num=args.sample_num)
|
183 |
+
# test_data = SeqRecTestDataset(args, sample_num=args.sample_num)
|
184 |
+
elif args.test_task.lower() == "itemsearch":
|
185 |
+
test_data = ItemSearchDataset(args, mode="test", sample_num=args.sample_num)
|
186 |
+
elif args.test_task.lower() == "fusionseqrec":
|
187 |
+
test_data = FusionSeqRecDataset(args, mode="test", sample_num=args.sample_num)
|
188 |
+
else:
|
189 |
+
raise NotImplementedError
|
190 |
+
|
191 |
+
return test_data
|
192 |
+
|
193 |
+
def load_json(file):
|
194 |
+
with open(file, 'r') as f:
|
195 |
+
data = json.load(f)
|
196 |
+
return data
|