MMOA-RAG
March 2, 2025
This repository contains the code for MMOA-RAG, a system for jointly optimizing the multiple modules of a RAG pipeline: the Query Rewriter, Retriever, Selector, and Generator. The code is organized into components for deploying, training, and evaluating the RAG system.
Paper: Improving Retrieval-Augmented Generation through Multi-Agent Reinforcement Learning
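The data flow through the four modules named above can be sketched as follows. This is a minimal illustration, not the repository's implementation: every function below is a hypothetical toy stand-in (in MMOA-RAG each module is a model call), chosen only to show what each module consumes and produces.

```python
from typing import Dict, List

# Hypothetical toy stand-ins for the four modules; in MMOA-RAG each is a model call.


def rewrite_query(question: str) -> List[str]:
    # Query Rewriter (toy rule): split a compound question on " and ".
    return [part.strip() for part in question.split(" and ") if part.strip()]


def retrieve(query: str, corpus: Dict[int, str], k: int = 2) -> List[int]:
    # Retriever (toy scoring): rank documents by word overlap with the query.
    words = set(query.lower().split())
    ranked = sorted(corpus, key=lambda i: -len(words & set(corpus[i].lower().split())))
    return ranked[:k]


def select(question: str, doc_ids: List[int], corpus: Dict[int, str]) -> List[int]:
    # Selector (toy rule): keep candidates sharing at least one word with the question.
    words = set(question.lower().split())
    return [i for i in doc_ids if words & set(corpus[i].lower().split())]


def generate(question: str, doc_ids: List[int], corpus: Dict[int, str]) -> str:
    # Generator (toy): "answer" by concatenating the selected documents.
    return " ".join(corpus[i] for i in doc_ids)


def rag_pipeline(question: str, corpus: Dict[int, str]) -> str:
    candidates: List[int] = []
    for sub_query in rewrite_query(question):
        for doc_id in retrieve(sub_query, corpus):
            if doc_id not in candidates:  # deduplicate across sub-queries
                candidates.append(doc_id)
    return generate(question, select(question, candidates, corpus), corpus)
```

Because each module consumes the previous module's output, a change in one (for example, the Query Rewriter emitting different sub-queries) shifts the inputs of every later module.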
Table of Contents
- Computational Resource Requirements
- Deploying the Retrieval Model
- Getting the SFT and MAPPO Training Data
- Warm Start for RAG System
- Multi-Agent Optimization for RAG System
- Evaluation
- Others
Computational Resource Requirements
We trained MMOA-RAG on two servers, each equipped with 8 A800 GPUs (80 GB of memory per GPU). One server was dedicated to serving the retrieval model, while the other ran MARL training.
Why is a separate machine needed to deploy the retrieval model? MARL training updates the Query Rewriter, so the queries sent to the retriever change as training progresses and the Top-k documents cannot be precomputed; they must be retrieved in real time during each rollout. This demands low-latency retrieval, so we deployed the retrieval model on a separate machine using Faiss with GPU acceleration.
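Retrieving the Top-k documents is a nearest-neighbour search over document embeddings. The sketch below performs exact inner-product search in pure Python only to show what that search computes; it assumes embeddings are already available and is not the repository's Faiss-based code.

```python
import heapq
from typing import List, Sequence, Tuple


def inner_product(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def top_k(query: Sequence[float], docs: Sequence[Sequence[float]], k: int) -> List[Tuple[int, float]]:
    # Exhaustive search: score every document, keep the k highest (best first).
    scored = ((doc_id, inner_product(query, emb)) for doc_id, emb in enumerate(docs))
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


def batch_top_k(queries: Sequence[Sequence[float]], docs: Sequence[Sequence[float]],
                k: int) -> List[List[Tuple[int, float]]]:
    # A rollout issues many queries at once; each is searched independently here.
    return [top_k(q, docs, k) for q in queries]
```

Exhaustive search costs O(N·d) per query for N documents of dimension d, which is why running retrieval on GPUs matters at corpus scale.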
Deploying the Retrieval Model
The retrieval model is deployed on a dedicated machine because the multi-module optimization includes training the Query Rewriter, which requires real-time retrieval.
To deploy the retrieval model, execute the following:
- Ensure that ./flask_server.py is properly configured.
- Start the retrieval model API on one server by running:
bash run_server.sh
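Once the server is up, the training machine queries the retrieval API over HTTP. The sketch below builds such a request with the standard library. The host, port, endpoint path (/search), and JSON fields (queries, top_k) are hypothetical placeholders; check ./flask_server.py for the actual route and payload format.

```python
import json
import urllib.request
from typing import Any, List

# Hypothetical address of the retrieval server started by run_server.sh.
RETRIEVAL_URL = "http://retrieval-host:8000/search"


def build_search_request(queries: List[str], top_k: int,
                         url: str = RETRIEVAL_URL) -> urllib.request.Request:
    # Hypothetical payload: a batch of queries and the number of documents per query.
    payload = json.dumps({"queries": queries, "top_k": top_k}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def search(queries: List[str], top_k: int, url: str = RETRIEVAL_URL,
           timeout: float = 30.0) -> Any:
    # Send the request and decode the JSON response (its schema is defined by the server).
    req = build_search_request(queries, top_k, url)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Sending a whole batch of rewritten queries per request amortizes HTTP overhead across a rollout.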
Getting the SFT and MAPPO Training Data
To generate the training data for the SFT and MAPPO stages, follow these steps:
Run the following script to obtain the SFT training data:
python qr_s_g_sft_data_alpaca.py
Run the following script to get the MAPPO training data for each dataset:
python get_ppo_data_alpaca.py
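The script names suggest the data follows the Alpaca format that LLaMA-Factory accepts, in which each example is a JSON object with instruction, input, and output fields. The sketch below builds one such record for each trainable module (the Query Rewriter, Selector, and Generator, which appear to be the qr, s, and g in the script names). The prompt wording and the helpers make_record, module_records, and save_dataset are hypothetical, not the repository's templates.

```python
import json
from typing import Dict, List


def make_record(instruction: str, input_text: str, output: str) -> Dict[str, str]:
    # One Alpaca-style example: instruction, optional input, target output.
    return {"instruction": instruction, "input": input_text, "output": output}


def module_records(question: str, sub_queries: List[str], docs: List[str],
                   selected: List[int], answer: str) -> List[Dict[str, str]]:
    # Hypothetical prompts; the repository's real templates will differ.
    listed = "\n".join(f"Document{i}: {d}" for i, d in enumerate(docs))
    chosen = "\n".join(docs[i] for i in selected)
    return [
        # Query Rewriter: question -> sub-queries
        make_record("Rewrite the question into sub-questions, one per line.",
                    question, "\n".join(sub_queries)),
        # Selector: question + candidates -> IDs of useful documents
        make_record("List the IDs of the documents that help answer the question.",
                    f"Question: {question}\n{listed}",
                    ", ".join(f"Document{i}" for i in selected)),
        # Generator: question + selected documents -> answer
        make_record("Answer the question using the documents.",
                    f"Question: {question}\n{chosen}", answer),
    ]


def save_dataset(records: List[Dict[str, str]], path: str) -> None:
    # Write the records as a single JSON list.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)
```

In LLaMA-Factory, custom dataset files are made available for training by registering them in data/dataset_info.json.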
We implemented MAPPO for jointly optimizing the multiple modules of the RAG system on top of LLaMA-Factory. The core code is in:
./LLaMA-Factory/src/llamafactory/train/ppo/trainer_qr_s_g.py
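In MAPPO each module acts as an agent trained with PPO. One common cooperative setup, shown here as an assumption rather than a description of trainer_qr_s_g.py, gives every agent the same terminal reward (for example, a score of the final answer) and computes each agent's advantages with generalized advantage estimation (GAE). In the LLM setting each step is a generated token; the helper names and numbers are illustrative.

```python
from typing import Dict, List


def gae(rewards: List[float], values: List[float], gamma: float = 1.0,
        lam: float = 0.95) -> List[float]:
    # Generalized advantage estimation; the state after the last step is terminal (value 0).
    advantages = [0.0] * len(rewards)
    next_value, running = 0.0, 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]  # TD error
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    return advantages


def shared_reward_advantages(shared_reward: float, agent_values: Dict[str, List[float]],
                             gamma: float = 1.0, lam: float = 0.95) -> Dict[str, List[float]]:
    # Every agent receives the same team reward, placed on its final step.
    result: Dict[str, List[float]] = {}
    for agent, values in agent_values.items():
        if not values:
            result[agent] = []
            continue
        rewards = [0.0] * (len(values) - 1) + [shared_reward]
        result[agent] = gae(rewards, values, gamma, lam)
    return result
```

Broadcasting one shared reward makes the agents cooperative: each module is credited for the quality of the final answer rather than for an isolated local objective.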
Warm Start for RAG System
To warm-start the modules of the RAG system with SFT, run:
bash LLaMA-Factory/run_sft.sh
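During SFT on Alpaca-style data, the loss is commonly computed only on the response tokens. The sketch below shows that masking on made-up token IDs: prompt positions receive the label -100, which PyTorch's cross-entropy loss ignores by default. Whether run_sft.sh masks prompts this way depends on its LLaMA-Factory settings; build_sft_example is a hypothetical helper.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_sft_example(prompt_ids: List[int], response_ids: List[int], eos_id: int,
                      max_len: int) -> Tuple[List[int], List[int]]:
    # Concatenate prompt and response, appending EOS so the model learns to stop.
    input_ids = prompt_ids + response_ids + [eos_id]
    # Only response tokens (and EOS) contribute to the loss.
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id]
    # Truncate both sequences identically so positions stay aligned.
    return input_ids[:max_len], labels[:max_len]
```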
Multi-Agent Optimization for RAG System
To jointly train the modules of the RAG system with MAPPO, run the following command on the second server (the one not hosting the retrieval model):
bash LLaMA-Factory/run_mappo.sh
Evaluation
Evaluate the performance of the RAG system by executing:
CUDA_VISIBLE_DEVICES=0 python evaluate_qr_s_g.py
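Open-domain QA outputs are commonly scored with Exact Match (EM) and token-level F1 against the gold answers. The sketch below uses the usual SQuAD-style normalization (lowercasing, removing punctuation and the articles a/an/the, collapsing whitespace); that evaluate_qr_s_g.py reports exactly these metrics is an assumption, so check the script for its actual metrics.

```python
import re
import string
from collections import Counter
from typing import List

PUNCTUATION = set(string.punctuation)


def normalize_answer(s: str) -> str:
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCTUATION)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(prediction: str, gold_answers: List[str]) -> float:
    # 1.0 if the normalized prediction equals any normalized gold answer.
    pred = normalize_answer(prediction)
    return float(any(pred == normalize_answer(g) for g in gold_answers))


def _f1(prediction: str, gold: str) -> float:
    pred_toks = normalize_answer(prediction).split()
    gold_toks = normalize_answer(gold).split()
    num_same = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)


def f1_score(prediction: str, gold_answers: List[str]) -> float:
    # Best token-level F1 over all gold answers.
    return max(_f1(prediction, g) for g in gold_answers)
```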
Others
Create the necessary directories:
- ./data for storing datasets. For example, ./data/ambigqa is used to save the AmbigQA dataset.
- ./models for saving checkpoints of the retrieval model and LLMs.
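A small sketch for creating these directories with pathlib. Only ambigqa comes from the text above; the datasets parameter and the helper names planned_dirs and create_dirs are hypothetical conveniences for adding further dataset folders.

```python
from pathlib import Path
from typing import Iterable, List


def planned_dirs(root: str = ".", datasets: Iterable[str] = ("ambigqa",)) -> List[Path]:
    # data/<dataset> for each dataset, plus models/ for checkpoints.
    base = Path(root)
    return [base / "data" / name for name in datasets] + [base / "models"]


def create_dirs(root: str = ".", datasets: Iterable[str] = ("ambigqa",)) -> List[Path]:
    dirs = planned_dirs(root, datasets)
    for d in dirs:
        d.mkdir(parents=True, exist_ok=True)  # no error if the directory already exists
    return dirs
```

Calling create_dirs() from the repository root creates ./data/ambigqa and ./models.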