Graph World Model (GWM), ICML 2025

September 20, 2025

🌐 Project Page | 📜 arXiv

📌 Applicable Scenarios

GWM covers six scenarios: multi-modal generation and matching, recommendation systems, graph prediction, multi-agent collaboration, retrieval-augmented generation, and planning and optimization. It represents different entities and their interactions as graph nodes and edges, enabling unified modeling across these tasks.

🧠 Method

GWM models world states as graphs with multi-modal nodes and defines actions at node, edge, and graph levels, enabling state updates through intended and unintended actions.
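The graph-structured state described above can be pictured with a small data structure. The following is only an illustrative sketch, not the GWM implementation: the names `Node`, `GraphState`, `add_node`, `add_edge`, and `neighbors` are hypothetical, and the choice of text and image fields as the node modalities is an assumption.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class Node:
    """Hypothetical multi-modal node; either modality may be absent."""
    node_id: int
    text: Optional[str] = None
    image_path: Optional[str] = None


@dataclass
class GraphState:
    """Hypothetical world state: nodes plus undirected edges between node ids."""
    nodes: Dict[int, Node] = field(default_factory=dict)
    edges: List[Tuple[int, int]] = field(default_factory=list)

    def add_node(self, node: Node) -> None:
        # Node-level update: insert or overwrite a single entity.
        self.nodes[node.node_id] = node

    def add_edge(self, src: int, dst: int) -> None:
        # Edge-level update: record an interaction between two existing entities.
        if src not in self.nodes or dst not in self.nodes:
            raise KeyError("both endpoints must exist before adding an edge")
        self.edges.append((src, dst))

    def neighbors(self, node_id: int) -> List[int]:
        # Collect the ids of all nodes that share an edge with node_id.
        result = []
        for a, b in self.edges:
            if a == node_id:
                result.append(b)
            elif b == node_id:
                result.append(a)
        return result
```

In this sketch, a node-level action changes one entry of `nodes`, an edge-level action changes `edges`, and a graph-level action would read or rewrite the `GraphState` as a whole.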

News

[2025.01.22] 🌟 Graph World Model has been accepted to ICML 2025.

📌 Preliminary

Environment Setup

# create a new environment
conda create -n gwm python=3.10
conda activate gwm

# install pytorch; change the index URL to match your own CUDA version
pip3 install torch --index-url https://download.pytorch.org/whl/cu118

# install related libraries
pip install -r requirements.txt

# install flash-attn
pip install flash-attn --no-build-isolation

# install pyg; the wheel index URL below must match your installed torch and CUDA versions
pip install torch_geometric
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu118.html
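The PyG wheel index in the last command is tied to a specific torch and CUDA combination (`torch-2.1.0+cu118`), so it should agree with the torch build installed earlier. As a rough aid, the sketch below builds the index URL from a version string. The helper name `pyg_wheel_url` is hypothetical, and PyG may not publish an index page for every torch patch release, so confirm that the resulting page exists before using it.

```python
def pyg_wheel_url(torch_version: str, cuda_tag: str = "cu118") -> str:
    """Build a PyG wheel index URL for a torch version and CUDA tag.

    torch_version may carry a local suffix such as "2.1.0+cu118";
    everything after "+" is dropped and replaced by cuda_tag.
    """
    base_version = torch_version.split("+", 1)[0]
    return f"https://data.pyg.org/whl/torch-{base_version}+{cuda_tag}.html"
```

With torch installed, `pyg_wheel_url(torch.__version__)` gives the value to pass to `pip install ... -f`.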

Dataset Preparation

Download multi_modal_data from the provided links and place the datasets in the multi_modal_data/ folder as shown below:

$CODE_DIR
├── model
├── multi_modal_data
│   ├── agent
│   ├── goodreads
│   ├── multimodal_paper
│   ├── optimization
│   ├── rag
│   ├── recommendation
│   │   ├── baby
│   │   ├── clothing
│   │   └── sports
│   └── traditional_graph
│       ├── cora
│       ├── HIV
│       └── pubmed
└── gwm_e
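Before training, it can help to confirm that every dataset directory from the layout above is in place. The sketch below is a minimal check using only the standard library. The function name `missing_dataset_dirs` and the constant `EXPECTED_LEAF_DIRS` are hypothetical, and the directory list is copied from the tree shown here rather than from the repository's code.

```python
from pathlib import Path
from typing import List

# Leaf directories under multi_modal_data/, taken from the layout in this README.
EXPECTED_LEAF_DIRS = [
    "agent",
    "goodreads",
    "multimodal_paper",
    "optimization",
    "rag",
    "recommendation/baby",
    "recommendation/clothing",
    "recommendation/sports",
    "traditional_graph/cora",
    "traditional_graph/HIV",
    "traditional_graph/pubmed",
]


def missing_dataset_dirs(code_dir: str) -> List[str]:
    """Return the expected dataset directories that do not exist under code_dir."""
    root = Path(code_dir) / "multi_modal_data"
    return [rel for rel in EXPECTED_LEAF_DIRS if not (root / rel).is_dir()]
```

Calling `missing_dataset_dirs(".")` from `$CODE_DIR` returns an empty list when the layout is complete.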

Each leaf directory of multi_modal_data contains three files: a training file (train_node_data.jsonl, train_edge_data.jsonl, or train_graph_data.jsonl, depending on the task level), the matching test file (test_node_data.jsonl, test_edge_data.jsonl, or test_graph_data.jsonl), and multi_hop_graph_embedding.pt. The .pt file stores the embeddings, and the .jsonl files hold the training and testing samples. Each row is one data sample in dictionary format:

{
 "id": [59],
 "conversations": [
   {
     "from": "human", 
     "value": "What is the correct answer to this question: A 23-year-old man presented with a 1-month history of double vision and right eyelid drooping that worsened at the end of the day..."
   },
   {
     "from": "gpt",
     "value": "E"
   }
 ],
 "graph": 1
}

"id" lists the rows of the embedding file that belong to the sample, "conversations" holds the prompt and its label, and "graph" flags whether graph tokens are used for the sample.
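To make the sample format concrete, here is a minimal sketch that parses JSONL rows and pulls out the prompt and label. It relies only on the fields shown in the example above (`conversations`, `from`, `value`); the helper names `parse_samples` and `prompt_and_label` are hypothetical and are not part of the GWM codebase.

```python
import json
from typing import Dict, Iterable, Iterator, Tuple


def parse_samples(lines: Iterable[str]) -> Iterator[Dict]:
    """Yield one dictionary per non-empty JSONL line."""
    for line in lines:
        line = line.strip()
        if line:
            yield json.loads(line)


def prompt_and_label(sample: Dict) -> Tuple[str, str]:
    """Return the first "human" turn as the prompt and the first "gpt" turn as the label."""
    turns = sample["conversations"]
    prompt = next(t["value"] for t in turns if t["from"] == "human")
    label = next(t["value"] for t in turns if t["from"] == "gpt")
    return prompt, label
```

To read a file, pass an open handle, for example `parse_samples(open(path, encoding="utf-8"))`. If the .pt file loads to a tensor with one row per embedding (an assumption about its layout), indexing it with `sample["id"]` would select the rows for that sample.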

⭐Experiments

Training and Evaluation

Run experiments and save the checkpoint.

deepspeed --num_gpus=3 gwm_e/train.py \
  --tune_mm_mlp_adapter True \
  --deepspeed ../scripts/zero2.json \
  --mm_use_graph_start_end False \
  --mm_use_graph_patch_token False \
  --bf16 True \
  --num_train_epochs 1 \
  --per_device_train_batch_size 10 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 1 \
  --evaluation_strategy "no" \
  --save_strategy "epoch" \
  --learning_rate 3e-4 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --tf32 True \
  --gradient_checkpointing True \
  --lazy_preprocess True \
  --report_to wandb
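When adapting the command to a different number of GPUs, the global batch size changes with it. Under data-parallel training, the number of samples per optimizer step is the product of the GPU count, the per-device batch size, and the gradient accumulation steps. The helper below is a hypothetical illustration of that arithmetic, not part of the repository.

```python
def effective_batch_size(num_gpus: int, per_device_batch_size: int, grad_accum_steps: int) -> int:
    """Samples consumed per optimizer step under data-parallel training."""
    return num_gpus * per_device_batch_size * grad_accum_steps


# Values from the command above: 3 GPUs, batch size 10 per device, 1 accumulation step.
print(effective_batch_size(3, 10, 1))  # 30
```

For example, with a single GPU, raising `--gradient_accumulation_steps` to 3 keeps the same 30 samples per step.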

πŸ“ Acknowledgement

The implementation of GWM is built upon LLaGA and LLaVA.

We sincerely appreciate the efforts of these teams for their contributions to open-source research and development.

Citation

@inproceedings{fenggraph,
  title={Graph World Model},
  author={Feng, Tao and Wu, Yexin and Lin, Guanyu and You, Jiaxuan},
  booktitle={Forty-second International Conference on Machine Learning},
  year={2025}
}