---
license: apache-2.0
---

ToolPlanner
===========================

## Paper Link
[ToolPlanner: A Tool Augmented LLM for Multi Granularity Instructions with Path Planning and Feedback](https://arxiv.org/abs/2409.14826)

****
## Contents
* [Requirement](#requirement)
* [Data](#data)
* [Model](#model)


## Requirement

```
accelerate==0.24.0
datasets==2.13.0
deepspeed==0.9.2
Flask==1.1.2
Flask_Cors==4.0.0
huggingface_hub==0.16.4
jsonlines==3.1.0
nltk==3.7
numpy==1.24.3
openai==0.27.7
pandas==2.0.3
peft==0.6.0.dev0
psutil==5.8.0
pydantic==1.10.8
pygraphviz==1.11
PyYAML==6.0.1
Requests==2.31.0
scikit_learn==1.0.2
scipy==1.11.4
sentence_transformers==2.2.2
tenacity==8.2.3
termcolor==2.4.0
torch==2.0.1
tqdm==4.65.0
transformers==4.28.1
trl==0.7.3.dev0
```

## Data

|Path|Description|
|----|-----|
|[/data/category/dataset]|MGToolBench: pairwise_responses|
|[/data/category/answer](./data/category/answer)|MGToolBench: Multi-Level Instruction Split|
|[/data/category/coarse_instruction](./data/category/coarse_instruction)|Self-Instruct Data: multi-granularity instructions|
|[/data/test_sample](./data/test_sample)|Test Sample: test dataset|
|[/data/category/toolenv]|Tool Environment: tools, APIs, and their documentation|
|[/data/category/inference]|Output: solution tree path|
|[/data/category/converted_answer](./data/category/converted_answer)|Output: converted_answer path|
|[/data/category/retrieval/G3_category](./data/category/retrieval/G3_category)|Supplementary: Category, Tool, and API names|
|[/data/retrieval/G3_clear](./data/retrieval/G3_clear)|Supplementary: corpus for the separate retriever|

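The splits above are stored as JSON files. As a quick sanity check after downloading, a minimal sketch like the following can be used (the file path is just one example taken from the tables; the sketch makes no assumption about the schema beyond the file being JSON):

```python
import json

# Example path from the tables above; point it at whichever split you want to inspect.
path = "data/category/dataset/G3_1107_gensample_Reward_pair.json"

with open(path, "r", encoding="utf-8") as f:
    data = json.load(f)

# Report how many entries the file holds and which fields the first entry exposes.
if isinstance(data, list):
    print(f"{len(data)} entries")
    if data and isinstance(data[0], dict):
        print("fields of first entry:", sorted(data[0].keys()))
else:
    print("top-level keys:", sorted(data.keys()))
```
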
## Download Data and Checkpoints

Download the following data and unzip the archives.

|Path|Data description|Data name|URL|
|----|-----|-----|-----|
|[/data/category/dataset]|MGToolBench: pairwise_responses|G3_1107_gensample_Reward_pair.json|https://huggingface.co/datasets/wuqinzhuo/ToolPlanner|
|[/data/category/toolenv]|Tool Environment: tools, APIs, and their documentation|toolenv.zip|https://huggingface.co/datasets/wuqinzhuo/ToolPlanner|
|[/data/category/inference]|Output: solution tree path|inference.zip|https://huggingface.co/datasets/wuqinzhuo/ToolPlanner|

Download the following checkpoints and place them under the paths below.

|Path|Model description|Model name|URL|
|----|-----|-----|-----|
|[ToolPlanner root path]|Stage 1 SFT model|ToolPlanner_Stage1_1020|https://huggingface.co/wuqinzhuo/ToolPlanner_Stage1_1020|
|[ToolPlanner root path]|Stage 2 RL model|ToolPlanner_Stage2_1107|https://huggingface.co/wuqinzhuo/ToolPlanner_Stage2_1107/|
|[ToolPlanner root path]|Baseline ToolLLaMA|ToolLLaMA-7b|https://github.com/OpenBMB/ToolBench|
|[ToolPlanner root path]|Retrieval model for test, using MGToolBench data|model_1122_G3_tag_trace_multilevel|https://huggingface.co/wuqinzhuo/model_1122_G3_tag_trace_multilevel|
|[ToolPlanner root path]|Retrieval model for test, using ToolBench data|retriever_model_G3_clear|https://huggingface.co/wuqinzhuo/retriever_model_G3_clear|

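As an alternative to downloading through the web UI, a minimal sketch using the `huggingface_hub` package pinned in the requirements could look like the following. The repo IDs and file names are taken from the tables above; the target directories are only illustrative, and the exact file layout inside each repo should be checked on the Hub.

```python
from huggingface_hub import hf_hub_download, snapshot_download

# Data files listed in the first table, hosted in the ToolPlanner dataset repo.
for filename in ["G3_1107_gensample_Reward_pair.json", "toolenv.zip", "inference.zip"]:
    hf_hub_download(
        repo_id="wuqinzhuo/ToolPlanner",
        repo_type="dataset",
        filename=filename,
        local_dir="downloads",  # unzip the archives into data/category/ afterwards
    )

# Checkpoints listed in the second table, placed under the ToolPlanner root path.
for repo_id in [
    "wuqinzhuo/ToolPlanner_Stage1_1020",
    "wuqinzhuo/ToolPlanner_Stage2_1107",
    "wuqinzhuo/model_1122_G3_tag_trace_multilevel",
    "wuqinzhuo/retriever_model_G3_clear",
]:
    snapshot_download(repo_id=repo_id, local_dir=repo_id.split("/")[-1])
```
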
# Model
## Install
```
pip install -r requirements.txt
```

## Train ToolPlanner, Stage 1 SFT
### Script
```
bash scripts/category/train_model_1020_stage1.sh
```
### Code
```
export PYTHONPATH=./
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

torchrun --nproc_per_node=8 --master_port=20001 toolbench/train/train_long_seq.py \
    --model_name_or_path ToolLLaMA-7b \
    --data_path data/category/answer/G3_plan_gen_train_1020_G3_3tag_whole_prefixTagTraceAll.json \
    --eval_data_path data/category/answer/G3_plan_gen_eval_1020_G3_3tag_whole_prefixTagTraceAll.json \
    --conv_template tool-llama-single-round \
    --bf16 True \
    --output_dir ToolPlanner_Stage1 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "epoch" \
    --prediction_loss_only \
    --save_strategy "epoch" \
    --save_total_limit 8 \
    --learning_rate 5e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.04 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 8192 \
    --gradient_checkpointing True \
    --lazy_preprocess True \
    --report_to none
```

## Train ToolPlanner, Stage 2 Reinforcement Learning
### Script
```
bash scripts/category/train_model_1107_stage2.sh
```
### Code
```
export PYTHONPATH=./
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

export MODEL_PATH="ToolPlanner_Stage1_1020"
export SAVE_PATH="ToolPlanner_Stage2"
export DATA_PATH="data/category/dataset/G3_1107_gensample_Reward_pair.json"
export MASTER_ADDR="localhost"
export MASTER_PORT="20010"
export WANDB_DISABLED=true
wandb offline

torchrun --nproc_per_node=8 --master_port=20001 toolbench/train/train_long_seq_RRHF.py \
    --model_name_or_path $MODEL_PATH \
    --data_path $DATA_PATH \
    --bf16 True \
    --output_dir $SAVE_PATH \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 3 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --gradient_checkpointing True \
    --tf32 True --model_max_length 8192 --rrhf_weight 1
```

## Inference, Generate Solution Tree
### Script
```
bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh <GPU_Id> <model_name> <decode_method> <output_path> <test_sample> <retriever_path> <TOOLBENCH_KEY>
```

### ToolBench Key
Go to [ToolBench](https://github.com/OpenBMB/ToolBench) to apply for a [ToolBench Key](https://github.com/OpenBMB/ToolBench).

### Decode_Method

|Model variant|Decode method|
|----|-----|
|`Full Model`|`Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2`|
|`Separate Retriever`|`Mix_Whole3Tag_MixWhole3TagTrace_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2`|
|`Without Solution Planning`|`Mix_Whole3Tag_MixWhole3TagTrace_MixWhole3Retri_MixWhole3Gen_DFS_woFilter_w2`|
|`Without Tag Extraction`|`Mix_Whole3Tag_MixWhole3TagTrace_MixTagTraceRetri_MixTagTraceGen_DFS_woFilter_w2`|
|`Without Tag & Solution`|`Mix_Whole3Tag_MixWhole3TagTrace_MixRetri_MixGen_DFS_woFilter_w2`|
|`Chain-based Method`|`Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_CoT@5`|

### Example
```
bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 6,7 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_Desc_1122_level_23 data/test_sample/G3_query_100_opendomain.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY

bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 1,3 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_Cate_1122_level_23 data/test_sample/G3_query_100_level_cate.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY
bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 2,4 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_Tool_1122_level_23 data/test_sample/G3_query_100_level_tool.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY
bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 5,4 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_API_1122_level_23 data/test_sample/G3_query_100_level_api.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY
```

## Eval
### Script
Use the generated results to evaluate the Match Rate and Pass Rate.
```
bash scripts/category/eval/eval_match_pass_rate.sh api name2 <output_path>
```

### Example
```
bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_Cate_1122_level_23
bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_Tool_1122_level_23
bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_API_1122_level_23
bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_Desc_1122_level_23
```

### Script
Use the generated results to evaluate the Win Rate. First, change the `generate(prompt, name)` function in `ToolPlanner/toolbench/tooleval/new_eval_win_rate_cut_list.py` to call your own ChatGPT API, then convert the solution trees and score them:
```
bash scripts/inference/convert_preprocess_win_rate.sh DFS <inference_path> <converted_answer_path> [<inference_path> <converted_answer_path> ...]
bash scripts/inference/eval_win_rate_cut_list.sh <converted_answer_path>
```

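For reference, a minimal sketch of such a replacement `generate(prompt, name)` is shown below, written against the `openai==0.27.7` API pinned in the requirements. The model name is only an assumption, and the exact signature expected (including how `name` is used) should be checked against `new_eval_win_rate_cut_list.py`.

```python
import openai

openai.api_key = "YOUR_OPENAI_API_KEY"  # or read it from the OPENAI_API_KEY environment variable

def generate(prompt, name):
    """Send the evaluation prompt to a ChatGPT model and return the reply text.

    `name` is kept for compatibility with the caller in new_eval_win_rate_cut_list.py;
    whether and how it is used depends on that script.
    """
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",  # assumed model; swap in whichever ChatGPT model you have access to
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    return response["choices"][0]["message"]["content"]
```
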
### Example
```
bash scripts/inference/convert_preprocess_win_rate.sh DFS data/category/inference/plan_1107_G3_gensample_RRHF_Cate_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_Cate_1122_level_23.json data/category/inference/plan_1107_G3_gensample_RRHF_Tool_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_Tool_1122_level_23.json data/category/inference/plan_1107_G3_gensample_RRHF_API_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_API_1122_level_23.json data/category/inference/plan_1107_G3_gensample_RRHF_Desc_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_Desc_1122_level_23.json
bash scripts/inference/eval_win_rate_cut_list.sh data/category/converted_answer/plan_1107_G3_gensample_RRHF_Cate_1122_level_23.json
```

## License

The dataset of this project is licensed under the [**Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)**](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.

The source code of this project is licensed under the [**Apache 2.0**](http://www.apache.org/licenses/LICENSE-2.0) license.

### Summary of Terms
- **Attribution**: You must give appropriate credit, provide a link to the license, and indicate if changes were made.
- **NonCommercial**: You may not use the material for commercial purposes.
- **ShareAlike**: If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.

### License Badge
[![License: CC BY-NC-SA 4.0](https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc-sa/4.0/)

## Citation
If you'd like to use our benchmark or cite this paper, please kindly use the reference below:

```bibtex
@inproceedings{wu2024toolplanner,
  title={ToolPlanner: A Tool Augmented LLM for Multi Granularity Instructions with Path Planning and Feedback},
  author={Wu, Qinzhuo and Liu, Wei and Luan, Jian and Wang, Bin},
  booktitle={Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing},
  pages={18315--18339},
  year={2024}
}

@misc{wu2024toolplannertoolaugmentedllm,
  title={ToolPlanner: A Tool Augmented LLM for Multi Granularity Instructions with Path Planning and Feedback},
  author={Qinzhuo Wu and Wei Liu and Jian Luan and Bin Wang},
  year={2024},
  eprint={2409.14826},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2409.14826}
}
```