---
license: apache-2.0
---

ToolPlanner
===========================

## Paper Link
[ToolPlanner: A Tool Augmented LLM for Multi Granularity Instructions with Path Planning and Feedback](https://arxiv.org/abs/2409.14826)

****
## Table of Contents
* [Requirement](#requirement)
* [Data](#data)
* [Model](#model)

## Requirement

```
accelerate==0.24.0
datasets==2.13.0
deepspeed==0.9.2
Flask==1.1.2
Flask_Cors==4.0.0
huggingface_hub==0.16.4
jsonlines==3.1.0
nltk==3.7
numpy==1.24.3
openai==0.27.7
pandas==2.0.3
peft==0.6.0.dev0
psutil==5.8.0
pydantic==1.10.8
pygraphviz==1.11
PyYAML==6.0
PyYAML==6.0.1
Requests==2.31.0
scikit_learn==1.0.2
scipy==1.11.4
sentence_transformers==2.2.2
tenacity==8.2.3
termcolor==2.4.0
torch==2.0.1
tqdm==4.65.0
transformers==4.28.1
trl==0.7.3.dev0
```

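The pinned list above names PyYAML twice with different versions, which pip may reject as a conflicting requirement. As a minimal sketch, the Python snippet below reports any package pinned to more than one version so the list can be checked before installing. The helper name `find_conflicting_pins` is hypothetical and is not part of this repository.

```python
from collections import defaultdict


def find_conflicting_pins(requirements_text):
    """Return {package: [versions]} for packages pinned to more than one version."""
    pins = defaultdict(list)
    for raw_line in requirements_text.splitlines():
        line = raw_line.split("#", 1)[0].strip()  # drop comments and whitespace
        if "==" not in line:
            continue
        name, version = line.split("==", 1)
        # pip compares names case-insensitively and treats '-', '_' and '.' alike
        key = name.strip().lower().replace("_", "-").replace(".", "-")
        version = version.strip()
        if version not in pins[key]:
            pins[key].append(version)
    return {name: versions for name, versions in pins.items() if len(versions) > 1}


sample = """
PyYAML==6.0
PyYAML==6.0.1
Requests==2.31.0
"""
conflicts = find_conflicting_pins(sample)
print(conflicts)  # {'pyyaml': ['6.0', '6.0.1']}
```

To check the real file, pass `open("requirements.txt").read()` instead of `sample`.
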
## Data

|path|data description|
|----|-----|
|[/data/category/dataset]|MGToolBench: pairwise_responses|
|[/data/category/answer](./data/category/answer)|MGToolBench: Multi-Level Instruction Split|
|[/data/category/coarse_instruction](./data/category/coarse_instruction)|Self-Instruct Data: multi-granularity instructions|
|[/data/test_sample](./data/test_sample)|Test Sample: test dataset|
|[/data/category/toolenv]|Tool Environment: Tools, APIs, and their documentation.|
|[/data/category/inference]|Output: solution trees path|
|[/data/category/converted_answer](./data/category/converted_answer)|Output: converted_answer path|
|[/data/category/retrieval/G3_category](./data/category/retrieval/G3_category)|Supplementary: Category & Tool & API Name|
|[/data/retrieval/G3_clear](./data/retrieval/G3_clear)|Supplementary: corpus for separate retriever|

## Download Data and Checkpoints

Download the following data files and unzip the archives.

|path|data description|data name|url|
|----|-----|-----|-----|
|[/data/category/dataset]|MGToolBench: pairwise_responses|G3_1107_gensample_Reward_pair.json|https://huggingface.co/datasets/wuqinzhuo/ToolPlanner|
|[/data/category/toolenv]|Tool Environment: Tools, APIs, and their documentation.|toolenv.zip|https://huggingface.co/datasets/wuqinzhuo/ToolPlanner|
|[/data/category/inference]|Output: solution trees path|inference.zip|https://huggingface.co/datasets/wuqinzhuo/ToolPlanner|

Download the following checkpoints into the ToolPlanner root path.

|path|model description|model name|url|
|----|-----|-----|-----|
|[ToolPlanner root path]|Stage1 SFT model|ToolPlanner_Stage1_1020|https://huggingface.co/wuqinzhuo/ToolPlanner_Stage1_1020|
|[ToolPlanner root path]|Stage2 reinforcement learning model|ToolPlanner_Stage2_1107|https://huggingface.co/wuqinzhuo/ToolPlanner_Stage2_1107/|
|[ToolPlanner root path]|Baseline ToolLLaMA|ToolLLaMA-7b|https://github.com/OpenBMB/ToolBench|
|[ToolPlanner root path]|Retrieval model for test, using MGToolBench data|model_1122_G3_tag_trace_multilevel|https://huggingface.co/wuqinzhuo/model_1122_G3_tag_trace_multilevel|
|[ToolPlanner root path]|Retrieval model for test, using ToolBench data|retriever_model_G3_clear|https://huggingface.co/wuqinzhuo/retriever_model_G3_clear|

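The download step for the three data files can be sketched in Python as follows. This is only an illustration under stated assumptions: that the files sit at the top level of the `wuqinzhuo/ToolPlanner` dataset repository under the names in the "data name" column, and that each archive should be extracted directly into the directory in the "path" column. Neither is confirmed by this README, and the helper names `place_file` and `fetch_all` are hypothetical.

```python
import shutil
import zipfile
from pathlib import Path

# File name -> destination directory, copied from the data table.
DATASET_FILES = {
    "G3_1107_gensample_Reward_pair.json": "data/category/dataset",
    "toolenv.zip": "data/category/toolenv",
    "inference.zip": "data/category/inference",
}


def place_file(src, dest_dir):
    """Extract a .zip archive into dest_dir, or copy any other file there."""
    src, dest_dir = Path(src), Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    if zipfile.is_zipfile(src):
        with zipfile.ZipFile(src) as archive:
            archive.extractall(dest_dir)
    else:
        shutil.copy2(src, dest_dir / src.name)
    return dest_dir


def fetch_all(root="."):
    """Download every file from the Hub dataset repo and place it under root."""
    # Imported lazily so place_file can be used without network access.
    from huggingface_hub import hf_hub_download

    for filename, rel_dir in DATASET_FILES.items():
        local_path = hf_hub_download(
            repo_id="wuqinzhuo/ToolPlanner",  # assumed repo layout, see lead-in
            filename=filename,
            repo_type="dataset",
        )
        place_file(local_path, Path(root) / rel_dir)
```

Calling `fetch_all()` from the ToolPlanner root path would populate the three directories. If an archive already contains a top-level folder of the same name, extract it one level higher instead.
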
# Model
## Install
pip install -r requirements.txt


## Train ToolPlanner, Stage 1 SFT
### Script
bash scripts/category/train_model_1020_stage1.sh
### Code
```
export PYTHONPATH=./
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

torchrun --nproc_per_node=8 --master_port=20001 toolbench/train/train_long_seq.py \
    --model_name_or_path ToolLLaMA-7b \
    --data_path data/category/answer/G3_plan_gen_train_1020_G3_3tag_whole_prefixTagTraceAll.json \
    --eval_data_path data/category/answer/G3_plan_gen_eval_1020_G3_3tag_whole_prefixTagTraceAll.json \
    --conv_template tool-llama-single-round \
    --bf16 True \
    --output_dir ToolPlanner_Stage1 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "epoch" \
    --prediction_loss_only \
    --save_strategy "epoch" \
    --save_total_limit 8 \
    --learning_rate 5e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.04 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 8192 \
    --gradient_checkpointing True \
    --lazy_preprocess True \
    --report_to none
```

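The Stage 1 flags above fix the number of sequences per optimizer step: 8 processes, a per-device batch of 2, and 8 gradient accumulation steps. The short Python sketch below works out that product and, as an illustration only, the accumulation steps needed to keep it unchanged on fewer GPUs. Both function names are hypothetical helpers, not part of the training code.

```python
def effective_batch_size(num_gpus, per_device_batch, grad_accum_steps):
    """Sequences contributing to each optimizer step under data parallelism."""
    return num_gpus * per_device_batch * grad_accum_steps


def accum_steps_for(target_batch, num_gpus, per_device_batch):
    """Gradient accumulation steps that reproduce target_batch on a different setup."""
    per_step = num_gpus * per_device_batch
    if target_batch % per_step != 0:
        raise ValueError("target batch is not a multiple of num_gpus * per_device_batch")
    return target_batch // per_step


stage1 = effective_batch_size(num_gpus=8, per_device_batch=2, grad_accum_steps=8)
print(stage1)                         # 128
print(accum_steps_for(stage1, 4, 2))  # 16
```

When running on a different number of GPUs, `--nproc_per_node` and `CUDA_VISIBLE_DEVICES` need to change together with `--gradient_accumulation_steps`.
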
## Train ToolPlanner, Stage 2 Reinforcement Learning
### Script
bash scripts/category/train_model_1107_stage2.sh
### Code
```
export PYTHONPATH=./
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

export MODEL_PATH="ToolPlanner_Stage1_1020"
export SAVE_PATH="ToolPlanner_Stage2"
export DATA_PATH="data/category/dataset/G3_1107_gensample_Reward_pair.json"
export MASTER_ADDR="localhost"
export MASTER_PORT="20010"
export WANDB_DISABLED=true
wandb offline

torchrun --nproc_per_node=8 --master_port=20001 toolbench/train/train_long_seq_RRHF.py \
    --model_name_or_path $MODEL_PATH \
    --data_path $DATA_PATH \
    --bf16 True \
    --output_dir $SAVE_PATH \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 3 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --gradient_checkpointing True \
    --tf32 True \
    --model_max_length 8192 \
    --rrhf_weight 1
```

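The script name and the `--rrhf_weight` flag suggest that Stage 2 uses an RRHF-style objective over the pairwise responses. The Python sketch below illustrates that objective as described in the RRHF paper (a ranking term over length-normalized log-probabilities plus an SFT term on the best-rewarded response). It is an assumption about what `train_long_seq_RRHF.py` computes, not a copy of it, and the actual script may normalize the terms differently.

```python
def rrhf_loss(token_logprobs, rewards, rrhf_weight=1.0):
    """Illustrative RRHF-style loss for one query.

    token_logprobs: one list of per-token log-probabilities per candidate response.
    rewards: one scalar reward per candidate response.
    Returns (total, rank_loss, sft_loss).
    """
    # Length-normalized sequence score for every candidate response.
    scores = [sum(lp) / len(lp) for lp in token_logprobs]

    # Ranking term: penalize a response that scores above a better-rewarded one.
    rank_loss = 0.0
    for i in range(len(scores)):
        for j in range(len(scores)):
            if rewards[i] < rewards[j]:
                rank_loss += max(0.0, scores[i] - scores[j])

    # SFT term: negative log-likelihood of the highest-reward response.
    best = max(range(len(rewards)), key=lambda k: rewards[k])
    sft_loss = -sum(token_logprobs[best])

    return rrhf_weight * rank_loss + sft_loss, rank_loss, sft_loss


# The preferred response (reward 1.0) currently has the lower model score.
total, rank, sft = rrhf_loss([[-1.0, -1.0], [-0.5, -0.5]], rewards=[1.0, 0.0])
print(total, rank, sft)  # 2.5 0.5 2.0
```

In this toy case the ranking term is positive because the model prefers the lower-reward response. Once the order is corrected, only the SFT term remains.
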
## Inference, Generate Solution Tree
### Script
```
bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh <GPU_Id> <model_name> <method_name> <decode_method> <output_path> <test_sample> <retriever_path> <TOOLBENCH_KEY>
```

### ToolBench Key
Apply for a ToolBench Key through the [ToolBench](https://github.com/OpenBMB/ToolBench) repository.


### Decode_Method

|Model|Method|
|----|-----|
|`Full Model`|`Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2`|
|`Separate Retriever`|`Mix_Whole3Tag_MixWhole3TagTrace_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2`|
|`Without Solution Planning`|`Mix_Whole3Tag_MixWhole3TagTrace_MixWhole3Retri_MixWhole3Gen_DFS_woFilter_w2`|
|`Without Tag Extraction`|`Mix_Whole3Tag_MixWhole3TagTrace_MixTagTraceRetri_MixTagTraceGen_DFS_woFilter_w2`|
|`Without Tag & Solution`|`Mix_Whole3Tag_MixWhole3TagTrace_MixRetri_MixGen_DFS_woFilter_w2`|
|`Chain-based Method`|`Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_CoT@5`|


### Example
```
bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 6,7 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_Desc_1122_level_23 data/test_sample/G3_query_100_opendomain.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY

bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 1,3 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_Cate_1122_level_23 data/test_sample/G3_query_100_level_cate.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY
bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 2,4 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_Tool_1122_level_23 data/test_sample/G3_query_100_level_tool.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY
bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 5,4 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_API_1122_level_23 data/test_sample/G3_query_100_level_api.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY
```

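The four example commands above differ only in the GPU ids, the output directory tag, and the test sample file. As a rough sketch, they can be generated from one table in Python. The argument order mirrors the example commands (seven arguments after the script path), and `inference_command` is a hypothetical helper rather than part of this repository.

```python
import shlex

SCRIPT = "scripts/category/inference/inference_cuda_model_method_output_input_tag.sh"
# Method string of the full model, copied from the Decode_Method table.
METHOD = (
    "Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_"
    "MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2"
)
# Output-directory tag -> test sample file, as paired in the example commands.
LEVELS = {
    "Desc": "G3_query_100_opendomain.json",
    "Cate": "G3_query_100_level_cate.json",
    "Tool": "G3_query_100_level_tool.json",
    "API": "G3_query_100_level_api.json",
}


def inference_command(level, gpus, toolbench_key):
    """Return the argument list for one instruction level."""
    return [
        "bash", SCRIPT, gpus, "ToolPlanner_Stage2_1107", METHOD,
        f"data/category/inference/plan_1107_G3_gensample_RRHF_{level}_1122_level_23",
        f"data/test_sample/{LEVELS[level]}",
        "model_1122_G3_tag_trace_multilevel", toolbench_key,
    ]


# Dry run: print the commands instead of executing them.
for level in LEVELS:
    print(shlex.join(inference_command(level, "0,1", "TOOLBENCH_KEY")))
```

To execute a command rather than print it, pass the list to `subprocess.run(cmd, check=True)` from the ToolPlanner root path.
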
## Eval
### Script
Use the generated results to evaluate Match Rate and Pass Rate.
```
bash scripts/category/eval/eval_match_pass_rate.sh api name2 <output_path>
```

### Example
```
bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_Cate_1122_level_23
bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_Tool_1122_level_23
bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_API_1122_level_23
bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_Desc_1122_level_23
```

### Script
Use the generated results to evaluate Win Rate. Before running, change the `generate(prompt, name)` function in `ToolPlanner/toolbench/tooleval/new_eval_win_rate_cut_list.py` to call your own ChatGPT API.
```
bash scripts/category/eval/eval_match_pass_rate.sh api name2 <output_path>
```

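A replacement for `generate(prompt, name)` might look like the sketch below, written against the `openai==0.27.7` package pinned in the requirements. The README does not document what `name` means or what the evaluation script expects back, so this version ignores `name` and returns the reply text. The model name and the extra `create` parameter (used here only so the function can be exercised without network access) are assumptions.

```python
def generate(prompt, name, create=None, model="gpt-3.5-turbo"):
    """Send one prompt to a chat model and return the reply text.

    `name` is accepted only to keep the two-argument call shape; its role in
    the evaluation script is not documented, so it is ignored in this sketch.
    """
    if create is None:
        # openai 0.27.x API; the key is read from the OPENAI_API_KEY environment variable.
        import openai
        create = openai.ChatCompletion.create
    response = create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    return response["choices"][0]["message"]["content"]


# Offline check with a stand-in for the API call.
def fake_create(**kwargs):
    return {"choices": [{"message": {"content": "echo: " + kwargs["messages"][0]["content"]}}]}


print(generate("Which answer is better?", "judge", create=fake_create))
# echo: Which answer is better?
```

Check the call sites in `new_eval_win_rate_cut_list.py` before adopting this shape, since they determine the expected return value.
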
### Example
```
bash scripts/inference/convert_preprocess_win_rate.sh DFS \
    data/category/inference/plan_1107_G3_gensample_RRHF_Cate_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_Cate_1122_level_23.json \
    data/category/inference/plan_1107_G3_gensample_RRHF_Tool_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_Tool_1122_level_23.json \
    data/category/inference/plan_1107_G3_gensample_RRHF_API_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_API_1122_level_23.json \
    data/category/inference/plan_1107_G3_gensample_RRHF_Desc_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_Desc_1122_level_23.json
bash scripts/inference/eval_win_rate_cut_list.sh data/category/converted_answer/plan_1107_G3_gensample_RRHF_Cate_1122_level_23.json
```

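The conversion command in the example above takes the decode method followed by one (inference directory, converted answer file) pair per instruction level. A small Python sketch makes that pairing explicit. `convert_command` is a hypothetical helper, and the paths follow the naming used in the example.

```python
import shlex

LEVELS = ("Cate", "Tool", "API", "Desc")


def convert_command(levels=LEVELS, method="DFS"):
    """Build the conversion call: method, then an (input dir, output json) pair per level."""
    cmd = ["bash", "scripts/inference/convert_preprocess_win_rate.sh", method]
    for level in levels:
        run = f"plan_1107_G3_gensample_RRHF_{level}_1122_level_23"
        cmd += [
            f"data/category/inference/{run}",              # solution trees from inference
            f"data/category/converted_answer/{run}.json",  # converted answers for win-rate eval
        ]
    return cmd


# Dry run: print the command instead of executing it.
print(shlex.join(convert_command()))
```

Passing a shorter `levels` tuple builds the command for a subset of the runs.
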
### Citation
```
@misc{wu2024toolplannertoolaugmentedllm,
      title={ToolPlanner: A Tool Augmented LLM for Multi Granularity Instructions with Path Planning and Feedback},
      author={Qinzhuo Wu and Wei Liu and Jian Luan and Bin Wang},
      year={2024},
      eprint={2409.14826},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2409.14826},
}
```

### License

The dataset of this project is licensed under the [**Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)**](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.

The source code of this project is licensed under the [**Apache 2.0**](http://www.apache.org/licenses/LICENSE-2.0) license.

#### Summary of Terms
- **Attribution**: You must give appropriate credit, provide a link to the license, and indicate if changes were made.
- **NonCommercial**: You may not use the material for commercial purposes.
- **ShareAlike**: If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.


#### License Badge
[CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/)

### Citation (EMNLP 2024)
If you use our benchmark or cite this paper, please use the reference below:

```bibtex
@inproceedings{wu2024toolplanner,
  title={ToolPlanner: A Tool Augmented LLM for Multi Granularity Instructions with Path Planning and Feedback},
  author={Wu, Qinzhuo and Liu, Wei and Luan, Jian and Wang, Bin},
  booktitle={Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing},
  pages={18315--18339},
  year={2024}
}
```