Code for StepTool

May 24, 2025

0. Environment Setup

  1. Create a new Conda environment:

    conda create -n steptool python=3.10
    
  2. Activate the environment:

    conda activate steptool
    
  3. Install PyTorch and the other required dependencies via pip:

    pip install torch torchvision torchaudio
    pip install -r requirements.txt

Note: Ensure that the version of GCC/G++ is >= 9.0.0.
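The GCC requirement can be checked before installing anything. The Python sketch below runs `gcc --version` and `g++ --version` and compares the first dotted version number in the banner against 9.0.0. The helper names are hypothetical, and the sketch assumes the compilers are on `PATH` and print a GNU-style banner.

```python
import re
import shutil
import subprocess

MIN_VERSION = (9, 0, 0)


def parse_version(banner: str) -> tuple:
    """Extract the first X.Y.Z version number from a compiler banner line."""
    match = re.search(r"(\d+)\.(\d+)\.(\d+)", banner)
    if match is None:
        raise ValueError(f"no version number found in: {banner!r}")
    return tuple(int(part) for part in match.groups())


def compiler_ok(compiler: str, minimum: tuple = MIN_VERSION) -> bool:
    """Return True if `compiler` is on PATH and reports at least `minimum`."""
    if shutil.which(compiler) is None:
        return False
    banner = subprocess.run(
        [compiler, "--version"], capture_output=True, text=True, check=True
    ).stdout.splitlines()[0]
    return parse_version(banner) >= minimum


if __name__ == "__main__":
    for compiler in ("gcc", "g++"):
        status = "OK" if compiler_ok(compiler) else "missing or older than 9.0.0"
        print(compiler, status)
```

On macOS, `gcc` is often an alias for Apple clang, whose banner reports the clang version instead, so this check is only meaningful for a real GNU toolchain.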

1. Step-grained Data Construction

Step-grained rewards can be assigned using various methods, including automated rule-based systems, human annotations, or advanced models such as GPT-4.

Below is a reference prompt for GPT-4 to perform step-grained reward annotation:

Query:
{query}

Intermediate Steps:
{mid_steps}
Final Answer:
{final_answer}

Given the above query, all intermediate steps and the final answer, you need to evaluate the entire task-solving process according to the following rules:
(1) **Successful Tool Calling:** For each intermediate step, determine if a tool was called successfully and give a score of 0 (no) or 1 (yes).
(2) **Contribution to Final Answer:** For each intermediate step, rate its contribution to the final answer on a scale from 0 to 5.
(3) **Final Answer Status:** Determine if the final answer is 'Solved', 'Unsure', or 'Unsolved'.

Now provide your evaluation in JSON format with the parameters of 'succeed_tool_calling', 'contribution_to_final_answer' and 'final_answer_status' to the function `evaluate_process_reward`.
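When this prompt is sent through an OpenAI-style function-calling API, the annotator's reply arrives as a JSON string of arguments for `evaluate_process_reward`. The Python sketch below shows one way to declare that function and validate the reply. The parameter names come from the prompt; the JSON-schema types and ranges are our reading of rules (1) to (3), and `parse_evaluation` is a hypothetical helper, not part of the repository.

```python
import json

# Hypothetical tool definition for `evaluate_process_reward`.
# Parameter names come from the prompt; types and ranges are assumptions.
EVALUATE_PROCESS_REWARD_TOOL = {
    "type": "function",
    "function": {
        "name": "evaluate_process_reward",
        "description": "Score every intermediate step and the final answer.",
        "parameters": {
            "type": "object",
            "properties": {
                "succeed_tool_calling": {
                    "type": "array",
                    "items": {"type": "integer", "enum": [0, 1]},
                },
                "contribution_to_final_answer": {
                    "type": "array",
                    "items": {"type": "integer", "minimum": 0, "maximum": 5},
                },
                "final_answer_status": {
                    "type": "string",
                    "enum": ["Solved", "Unsure", "Unsolved"],
                },
            },
            "required": [
                "succeed_tool_calling",
                "contribution_to_final_answer",
                "final_answer_status",
            ],
        },
    },
}


def parse_evaluation(arguments_json: str, num_steps: int) -> dict:
    """Decode the function-call arguments and check them against rules (1)-(3)."""
    args = json.loads(arguments_json)
    calls = args["succeed_tool_calling"]
    contribs = args["contribution_to_final_answer"]
    status = args["final_answer_status"]
    if len(calls) != num_steps or len(contribs) != num_steps:
        raise ValueError("expected exactly one score per intermediate step")
    if any(c not in (0, 1) for c in calls):
        raise ValueError("succeed_tool_calling scores must be 0 or 1")
    if any(not 0 <= c <= 5 for c in contribs):
        raise ValueError("contribution_to_final_answer scores must lie in [0, 5]")
    if status not in ("Solved", "Unsure", "Unsolved"):
        raise ValueError(f"unexpected final_answer_status: {status!r}")
    return {
        "succeed_tool_calling": calls,
        "contribution_to_final_answer": contribs,
        "final_answer_status": status,
    }
```

With the OpenAI chat-completions API, `[EVALUATE_PROCESS_REWARD_TOOL]` would go in the `tools` argument, and the string in `message.tool_calls[0].function.arguments` of the reply would be passed to `parse_evaluation`.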

We provide a sample training data file, data_train/${MODEL_TYPE}/step_grained_for_ppo_example.csv, for use in the subsequent training phase.
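PPO consumes scalar rewards, so each annotation has to be reduced to numbers before training. The sketch below shows one hypothetical mapping: an intermediate step earns `alpha` times its tool-calling score plus its contribution score rescaled to [0, 1], and the final answer earns 1.0, 0.5, or 0.0 for 'Solved', 'Unsure', or 'Unsolved'. The weights, `alpha`, and the `step_rewards` helper are illustrative assumptions; the actual reward formula and CSV columns are defined by the repository's data and code.

```python
# Hypothetical mapping from one annotation to per-step rewards.
# These weights are illustrative assumptions, not StepTool's formula.
STATUS_REWARD = {"Solved": 1.0, "Unsure": 0.5, "Unsolved": 0.0}


def step_rewards(succeed_tool_calling, contribution_to_final_answer,
                 final_answer_status, alpha=1.0):
    """Return one reward per intermediate step plus one for the final answer.

    Intermediate step t: alpha * succeed[t] + contribution[t] / 5
    Final answer:        STATUS_REWARD[final_answer_status]
    """
    if len(succeed_tool_calling) != len(contribution_to_final_answer):
        raise ValueError("score lists must have the same length")
    rewards = [
        alpha * succ + contrib / 5.0
        for succ, contrib in zip(succeed_tool_calling, contribution_to_final_answer)
    ]
    rewards.append(STATUS_REWARD[final_answer_status])
    return rewards


if __name__ == "__main__":
    print(step_rewards([1, 0, 0], [5, 0, 2], "Solved"))
```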

The complete training dataset can be downloaded from this Dropbox link.

2. Step-grained Training with PPO

The step-grained training is implemented in src/steptool/step_ppo.py and src/steptool/step_ppotrainer.py.
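Step-grained rewards give every step its own learning signal instead of relying on a single end-of-trajectory reward. As a generic illustration (not code taken from `step_ppotrainer.py`), the sketch below computes discounted returns G_t = r_t + γ·G_{t+1} over a trajectory of per-step rewards, using the `gamma` value that appears in the PPO configuration. It assumes the rewards are already one scalar per step.

```python
from typing import List


def discounted_returns(rewards: List[float], gamma: float = 0.99) -> List[float]:
    """Compute G_t = r_t + gamma * G_{t+1}, scanning from the last step backwards."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


if __name__ == "__main__":
    # Final-reward-only trajectory vs. a trajectory with step-grained rewards.
    print(discounted_returns([0.0, 0.0, 1.0]))
    print(discounted_returns([0.5, 0.2, 1.0]))
```

With only a final reward, earlier steps receive credit solely through discounting; with step-grained rewards, each step's return also reflects its own score.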

  1. Configuration

Modify the configuration file config/${MODEL_TYPE}/StepTool_ppo.json as needed. MODEL_TYPE must be one of toolllama, qwen2, or llama3-1. Here’s an example configuration:

{
    "peft_kwargs": {
        "r": 8,
        "lora_alpha": 16,
        "bias": "none",
        "task_type": "CAUSAL_LM"
    },
    "ppo_kwargs": {
        "learning_rate": 1e-5,
        "log_with": "wandb",
        "remove_unused_columns": false,
        "batch_size": 8,
        "mini_batch_size": 2,
        "gradient_accumulation_steps": 4,
        "kl_penalty": "kl",
        "init_kl_coef": 0.3,
        "target_kl": 6,
        "target": 6,
        "horizon": 10000,
        "gamma": 0.99
    }
}
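Several fields in this configuration interact. The `ppo_kwargs` names appear to match the `PPOConfig` of older TRL releases, where `batch_size` must be divisible by `mini_batch_size * gradient_accumulation_steps`, and `init_kl_coef`, `target`, and `horizon` drive an adaptive KL controller (`target_kl` is a separate early-stopping threshold). The Python sketch below checks the batch arithmetic, computes the LoRA scaling factor `lora_alpha / r`, and reproduces the adaptive KL update as we understand it. `check_config` and `AdaptiveKL` are hypothetical names, and whether StepTool's trainer enables the adaptive controller depends on its code.

```python
import json

# Subset of the example configuration above, used by the checks below.
EXAMPLE_CONFIG = json.loads("""
{
  "peft_kwargs": {"r": 8, "lora_alpha": 16},
  "ppo_kwargs": {"batch_size": 8, "mini_batch_size": 2,
                 "gradient_accumulation_steps": 4,
                 "init_kl_coef": 0.3, "target": 6, "horizon": 10000}
}
""")


def check_config(cfg: dict) -> dict:
    """Derive quantities implied by the config (legacy TRL semantics assumed)."""
    ppo = cfg["ppo_kwargs"]
    backward_batch = ppo["mini_batch_size"] * ppo["gradient_accumulation_steps"]
    if ppo["batch_size"] % backward_batch != 0:
        raise ValueError(
            "batch_size must be divisible by mini_batch_size * gradient_accumulation_steps"
        )
    lora = cfg["peft_kwargs"]
    return {
        # Optimizer steps per pass (PPO epoch) over one rollout batch.
        "optimizer_steps_per_ppo_epoch": ppo["batch_size"] // backward_batch,
        # LoRA updates are scaled by lora_alpha / r.
        "lora_scaling": lora["lora_alpha"] / lora["r"],
    }


class AdaptiveKL:
    """Adaptive KL coefficient in the style of Ziegler et al. (2019)."""

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> float:
        # Proportional error, clipped to [-0.2, 0.2] before scaling the coefficient.
        error = min(max(current_kl / self.target - 1.0, -0.2), 0.2)
        self.value *= 1.0 + error * n_steps / self.horizon
        return self.value


if __name__ == "__main__":
    print(check_config(EXAMPLE_CONFIG))
    ppo = EXAMPLE_CONFIG["ppo_kwargs"]
    kl = AdaptiveKL(ppo["init_kl_coef"], ppo["target"], ppo["horizon"])
    # An observed KL of 12 exceeds the target of 6, so the coefficient grows slightly.
    print(kl.update(current_kl=12.0, n_steps=ppo["batch_size"]))
```

Because the error is clipped and divided by `horizon`, the coefficient drifts slowly: with `horizon` set to 10000, a single batch of 8 changes it by at most 0.016%.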

  2. Run the scripts

bash scripts/steptool_train/train_toolllama.sh
bash scripts/steptool_train/train_qwen2.sh
bash scripts/steptool_train/train_llama3-1.sh

Example Command (from scripts/steptool_train/train_toolllama.sh):

export PYTHONPATH=./
export TRAIN_PATH="data_train"
export TRAIN_SET="step_grained_for_ppo_example"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

export MODEL_TYPE="toolllama"
# base model that has already undergone supervised fine-tuning (SFT)
export MODEL_PATH="ToolBench/ToolLLaMA-2-7b-v2"

python src/steptool/step_ppo.py \
    --model_path ${MODEL_PATH} \
    --model_type ${MODEL_TYPE} \
    --config_path config/${MODEL_TYPE}/StepTool_ppo.json \
    --data_file ${TRAIN_PATH}/${MODEL_TYPE}/${TRAIN_SET}.csv \
    --max_context_len 4096 \
    --max_response_len 1024 \
    --epochs 5
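For reference, the flags in this command can be mirrored with `argparse`. The sketch below is a hypothetical parser reconstructed only from the command above: the defaults are simply the values used there, and the real `step_ppo.py` may define additional options or different defaults. It also shows how `--data_file` is assembled from `TRAIN_PATH`, `MODEL_TYPE`, and `TRAIN_SET`.

```python
import argparse

MODEL_TYPES = ("toolllama", "qwen2", "llama3-1")


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags used in train_toolllama.sh."""
    parser = argparse.ArgumentParser(description="Step-grained PPO training (sketch)")
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--model_type", required=True, choices=MODEL_TYPES)
    parser.add_argument("--config_path", required=True)
    parser.add_argument("--data_file", required=True)
    # Defaults below are the values from the example command, not the script's own.
    parser.add_argument("--max_context_len", type=int, default=4096)
    parser.add_argument("--max_response_len", type=int, default=1024)
    parser.add_argument("--epochs", type=int, default=5)
    return parser


def data_file_path(train_path: str, model_type: str, train_set: str) -> str:
    """Reproduce ${TRAIN_PATH}/${MODEL_TYPE}/${TRAIN_SET}.csv."""
    return f"{train_path}/{model_type}/{train_set}.csv"
```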

Note: the qwen2 and llama3-1 models must undergo supervised fine-tuning (SFT) before step-grained PPO training:

bash scripts/sft/train_qwen2.sh
bash scripts/sft/train_llama3-1.sh

A sample SFT training dataset is available at data_train/${MODEL_TYPE}/gpt4_dfs_G123_for_sft_example.json.

Train Baselines (RFT, PPO, ETO, ArCHer)

RFT

bash scripts/baseline-rft/train_rft.sh

PPO (Final Reward)

bash scripts/baseline-ppo/train_ppo.sh

ETO (DPO)

bash scripts/baseline-eto/train_dpo.sh

ArCHer

bash scripts/baseline-archer/build_data.sh
bash scripts/baseline-archer/train_archer.sh

Evaluation on StableToolBench

1. Build the API server

To set up the API server, follow the StableToolBench instructions.

First, download the cache from Hugging Face or Tsinghua Cloud.

After downloading, unzip the archive into the stabletoolbench/server folder and make sure that server contains both the tool_response_cache and tools folders. The resulting server folder should look like this:

├── /server/
│  ├── /tools/
│  │  └── ...
│  ├── /tool_response_cache/
│  │  └── ...
│  ├── config.yml
│  ├── main.py
│  ├── utils.py

Next, specify your configuration in server/config.yml:

api_key: 
api_base: 
model: gpt-4-turbo-preview
temperature: 0
toolbench_url: http://8.130.32.149:8080/rapidapi
rapidapi_key: 
tools_folder: "./tools"
cache_folder: "./tool_response_cache"
is_save: true
port: 8081

To run the server:

cd server
python main.py

The server will run at http://localhost:{port}/virtual. To use it, you also need a ToolBench key, which you can apply for via this form.
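Once the server is running, tool calls are sent to it as JSON POST requests. The Python sketch below builds such a request with the standard library. The payload field names (`category`, `tool_name`, `api_name`, `tool_input`, `strip`, `toolbench_key`) are an assumption based on StableToolBench's `server/main.py`, so verify them against your copy; the accepted values of `strip` are likewise defined in the server code. The helper names are hypothetical.

```python
import json
import urllib.request


def build_virtual_request(port, category, tool_name, api_name, tool_input,
                          strip, toolbench_key):
    """Build a POST request for the virtual API server.

    Field names are assumed from StableToolBench's server/main.py; verify them.
    """
    payload = {
        "category": category,
        "tool_name": tool_name,
        "api_name": api_name,
        "tool_input": tool_input,
        "strip": strip,
        "toolbench_key": toolbench_key,
    }
    return urllib.request.Request(
        f"http://localhost:{port}/virtual",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def call_virtual(request, timeout=60.0):
    """Send the request and decode the server's JSON response."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

The port should match the `port` value in server/config.yml (8081 in the example configuration).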

2. Run the model using vLLM

We recommend setting up a separate Conda environment for vLLM by following its installation guide.

To launch a vLLM server for the ToolLLaMA-2-7b-v2 model, use the following command:

python -m vllm.entrypoints.openai.api_server --model ToolBench/ToolLLaMA-2-7b-v2 --served-model-name toolllama --max-model-len=8192 --dtype=bfloat16 --host 127.0.0.1 --port 8083 --rope-scaling '{"type": "linear", "factor": 2.0}'
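With linear RoPE scaling, position indices are divided by the factor, so LLaMA-2's 4096-token pretraining context is stretched to 4096 × 2.0 = 8192 tokens, which matches `--max-model-len=8192`. The running server exposes vLLM's OpenAI-compatible API, and the Python sketch below queries its `/v1/completions` endpoint with the host, port, and served model name from the command. The helper names are hypothetical, and the sketch sends raw text: the prompt format ToolLLaMA expects is built by the evaluation scripts, not here.

```python
import json
import urllib.request

# Values taken from the launch command.
VLLM_BASE_URL = "http://127.0.0.1:8083/v1"
SERVED_MODEL_NAME = "toolllama"


def effective_context(native_context: int, rope_factor: float) -> int:
    """Context length reached by linear RoPE scaling (position interpolation)."""
    return int(native_context * rope_factor)


def build_completion_request(prompt: str, max_tokens: int = 512,
                             temperature: float = 0.0) -> urllib.request.Request:
    """Build a POST to vLLM's OpenAI-compatible /v1/completions endpoint."""
    payload = {
        "model": SERVED_MODEL_NAME,  # must match --served-model-name
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        f"{VLLM_BASE_URL}/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def complete(prompt: str, **kwargs) -> str:
    """Send a completion request and return the generated text."""
    request = build_completion_request(prompt, **kwargs)
    with urllib.request.urlopen(request, timeout=120) as response:
        body = json.loads(response.read().decode("utf-8"))
    return body["choices"][0]["text"]
```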

Note: If you're using a LoRA version of the model, make sure to merge the LoRA weights with the base model before running it in vLLM.
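The merge can be done with PEFT's `merge_and_unload()`. The Python sketch below assumes the LoRA checkpoint was saved with PEFT (an adapter config plus adapter weights) and that `torch`, `transformers`, and `peft` are installed; `merge_lora` and the paths in the usage line are hypothetical. The heavy imports sit inside the function so the helper can be defined without loading them.

```python
def merge_lora(base_model_path: str, lora_path: str, output_dir: str) -> None:
    """Fold LoRA adapter weights into the base model and save a standalone copy."""
    # Imported lazily so this module can be loaded without the ML stack.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)
    model = PeftModel.from_pretrained(base, lora_path)
    merged = model.merge_and_unload()  # adds the LoRA deltas into the base weights
    merged.save_pretrained(output_dir)
    # Save the tokenizer alongside so vLLM can load the directory directly.
    AutoTokenizer.from_pretrained(base_model_path).save_pretrained(output_dir)
```

For example, `merge_lora("ToolBench/ToolLLaMA-2-7b-v2", "path/to/lora_checkpoint", "merged_model")` would produce a `merged_model` directory that can be passed to vLLM's `--model` flag.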

3. Run the Evaluation Scripts

To evaluate the model on StableToolBench, first configure stabletoolbench/config.yml:

api_key:
api_base:
toolbench_key:
tool_root_dir: server/tools

Then, run inference on the solvable_test_queries with:

bash scripts_eval/toolllama/inference_toolllama_vllm.sh
bash scripts_eval/qwen2/inference_qwen2_vllm.sh
bash scripts_eval/llama3-1/inference_llama3-1_vllm.sh
bash scripts_eval/baseline-rft/inference_rft_vllm.sh
bash scripts_eval/baseline-ppo/inference_ppo_vllm.sh
bash scripts_eval/baseline-eto/inference_eto_vllm.sh
bash scripts_eval/baseline-archer/inference_archer_vllm.sh

Finally, evaluate the pass_rate and win_rate metrics:

bash scripts_eval/toolllama/run_convert_answer.sh
bash scripts_eval/toolllama/run_pass_rate.sh
bash scripts_eval/toolllama/run_preference.sh

bash scripts_eval/qwen2/run_convert_answer.sh
bash scripts_eval/qwen2/run_pass_rate.sh
bash scripts_eval/qwen2/run_preference.sh

bash scripts_eval/llama3-1/run_convert_answer.sh
bash scripts_eval/llama3-1/run_pass_rate.sh
bash scripts_eval/llama3-1/run_preference.sh

bash scripts_eval/baseline-rft/run_convert_answer.sh
bash scripts_eval/baseline-rft/run_pass_rate.sh

bash scripts_eval/baseline-ppo/run_convert_answer.sh
bash scripts_eval/baseline-ppo/run_pass_rate.sh

bash scripts_eval/baseline-eto/run_convert_answer.sh
bash scripts_eval/baseline-eto/run_pass_rate.sh

bash scripts_eval/baseline-archer/run_convert_answer.sh
bash scripts_eval/baseline-archer/run_pass_rate.sh
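To see how per-query judgments roll up into the pass-rate figures reported in the results table, the Python sketch below assumes each query receives a 'Solved', 'Unsure', or 'Unsolved' label in every evaluation run, counts 'Unsure' as half a pass, and summarizes repeated runs as mean ± population standard deviation. The half-credit rule and the use of population standard deviation are assumptions; check the StableToolBench pass-rate script for the exact rule.

```python
from statistics import mean, pstdev

# Assumed weights: 'Unsure' counts as half a pass.
LABEL_SCORE = {"Solved": 1.0, "Unsure": 0.5, "Unsolved": 0.0}


def pass_rate(labels):
    """Pass rate (in %) of one evaluation run over a query group."""
    return 100.0 * sum(LABEL_SCORE[label] for label in labels) / len(labels)


def summarize(runs):
    """Mean and population standard deviation of pass rates across runs."""
    rates = [pass_rate(labels) for labels in runs]
    return mean(rates), pstdev(rates)


if __name__ == "__main__":
    runs = [
        ["Solved", "Unsure", "Unsolved", "Solved"],
        ["Solved", "Solved", "Unsolved", "Unsolved"],
    ]
    avg, std = summarize(runs)
    print(f"{avg:.1f}±{std:.1f}")
```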

Main Experimental Results in the Paper (Updated May 2025)

All results were re-evaluated in May 2025 with the gpt-4-turbo-2024-04-09 evaluator to ensure stable, consistent scoring.

| Backbone | Strategy | Method | I1 Pass | I1 Recall | I2 Pass | I2 Recall | I3 Pass | I3 Recall | ToolLens Test Pass | ToolLens Test Recall |
|---|---|---|---|---|---|---|---|---|---|---|
| ToolLLaMA-2-7b-v2 | CoT | SFT | 50.6±1.6 | 0.7952 | 47.1±0.8 | 0.8081 | 40.4±0.8 | 0.6833 | 40.2±0.9 | 0.6769 |
| ToolLLaMA-2-7b-v2 | CoT | RFT | 50.2±1.2 | 0.8061 | 45.9±1.8 | 0.8197 | 38.5±1.2 | 0.7536 | 39.5±0.7 | 0.7323 |
| ToolLLaMA-2-7b-v2 | CoT | PPO (Final Reward) | 50.9±1.0 | 0.8030 | 46.6±2.0 | 0.8185 | 40.2±0.0 | 0.6869 | 39.2±0.1 | 0.6817 |
| ToolLLaMA-2-7b-v2 | CoT | ETO (DPO) | 50.3±0.9 | 0.7874 | 45.9±0.3 | 0.7859 | 38.8±1.0 | 0.7150 | 40.2±0.5 | 0.7176 |
| ToolLLaMA-2-7b-v2 | CoT | ArCHer | 51.8±0.8 | 0.8005 | 47.5±0.6 | 0.8039 | 35.5±2.8 | 0.6907 | 42.8±0.6 | 0.6693 |
| ToolLLaMA-2-7b-v2 | CoT | StepTool | 61.1±0.7 | 0.8743 | 56.6±2.2 | 0.8992 | 45.9±1.8 | 0.7724 | 47.3±0.4 | 0.7500 |
| ToolLLaMA-2-7b-v2 | DFSDT | SFT | 58.7±1.0 | 0.8419 | 54.3±1.1 | 0.8665 | 54.1±1.3 | 0.7331 | 46.6±1.0 | 0.7092 |
| ToolLLaMA-2-7b-v2 | DFSDT | RFT | 55.0±1.6 | 0.8490 | 49.5±0.8 | 0.8531 | 58.5±2.4 | 0.7465 | 41.1±0.5 | 0.7598 |
| ToolLLaMA-2-7b-v2 | DFSDT | PPO (Final Reward) | 59.6±1.1 | 0.8360 | 54.0±1.6 | 0.8709 | 39.9±0.8 | 0.7344 | 42.4±0.5 | 0.7004 |
| ToolLLaMA-2-7b-v2 | DFSDT | ETO (DPO) | 57.1±1.2 | 0.8412 | 54.5±2.1 | 0.8747 | 44.0±2.8 | 0.7298 | 47.9±0.8 | 0.7487 |
| ToolLLaMA-2-7b-v2 | DFSDT | ArCHer | 60.0±1.5 | 0.8491 | 54.5±0.8 | 0.8724 | 53.3±2.0 | 0.7284 | 45.6±0.3 | 0.7207 |
| ToolLLaMA-2-7b-v2 | DFSDT | StepTool | 64.1±1.4 | 0.8797 | 60.3±0.9 | 0.9004 | 64.8±2.3 | 0.7831 | 53.2±1.2 | 0.7819 |

More Cases