C2MIL: Dual-Causal Graph-Based MIL for Survival Analysis

November 2, 2025


Official PyTorch implementation of C2MIL, a dual-causal graph-based multiple instance learning (MIL) model for robust and interpretable survival analysis on whole slide images (WSIs). The arXiv version of the C2MIL paper contains minor revisions that make the work more rigorous; see the arXiv version and the code in this repository for details.

๐Ÿ” Overview

Graph-based MIL is widely used in computational pathology but faces two key challenges:

  1. Semantic Confounding Bias
    Variations in staining, sectioning, and scanning introduce irrelevant features that harm generalization.

  2. Topological Noise
    Not all subgraphs in WSIs are causally relevant to survival outcomes, leading to biased representations.

To tackle these, we propose C2MIL, which synchronizes semantic and topological causalities via a dual structural causal model.
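On the semantic side of a structural causal model, confounding is usually removed with the backdoor adjustment formula, P(Y | do(X)) = Σ_z P(Y | X, z) P(z), which the CAFD module below relies on. The toy sketch that follows only illustrates the formula with a discrete, hypothetical confounder (two staining sites) and made-up probabilities; CAFD itself learns confounders adaptively in feature space, and this is not the repository's implementation.

```python
import math

# Toy numbers (hypothetical): Y = poor outcome, X = a fixed morphological
# pattern, z = staining site (two sites acting as a semantic confounder).
P_Y_GIVEN_X_Z = [0.8, 0.4]   # P(Y=1 | X=1, z) for z = site A, site B
P_Z_GIVEN_X = [0.9, 0.1]     # site A is over-represented among slides showing X
P_Z = [0.5, 0.5]             # marginal site distribution over the whole cohort


def mix_over_confounder(p_y_given_x_z, weights):
    """Return sum_z P(Y | X, z) * w(z) for a discrete confounder z."""
    if not math.isclose(sum(weights), 1.0):
        raise ValueError("confounder weights must sum to 1")
    return sum(p * w for p, w in zip(p_y_given_x_z, weights))


# Plain conditioning weights z by P(z | X), so site effects leak into P(Y | X).
p_observational = mix_over_confounder(P_Y_GIVEN_X_Z, P_Z_GIVEN_X)
# Backdoor adjustment weights z by its marginal P(z), giving P(Y | do(X)).
p_interventional = mix_over_confounder(P_Y_GIVEN_X_Z, P_Z)

print(f"P(Y | X)     = {p_observational:.2f}")   # 0.76
print(f"P(Y | do(X)) = {p_interventional:.2f}")  # 0.60
```

The gap between the two numbers is the part of the association that comes from the site, not from the morphology.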

✨ Key Features

  • Cross-Scale Adaptive Feature Disentangling (CAFD):
    Removes trivial semantic confounders via backdoor adjustment, adaptively learning confounders without prior knowledge.

  • Bernoulli Differentiable Subgraph Sampling:
    Identifies causal subgraphs within WSIs using a straight-through estimator for robust topology learning.

  • Joint Optimization:
    Combines semantic supervision and topological contrastive learning under causal invariance.

  • Generalizable & Interpretable:
    Achieves state-of-the-art survival prediction while providing interpretable attention heatmaps and adaptive clustering.
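Bernoulli subgraph sampling with a straight-through estimator can be sketched as follows, assuming each edge has a scalar logit from some scorer network. The forward pass keeps each edge with probability sigmoid(logit) as a hard 0/1 decision. The backward pass treats the sampling step as the identity, so gradients still reach the logits. The edge list, logits, and function names are hypothetical, and how C2MIL scores edges and feeds the subgraph into its contrastive objective is not shown.

```python
import math
import random


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def bernoulli_ste_forward(edge_logits, rng):
    """Forward pass: sample a hard keep-mask, mask_e ~ Bernoulli(sigmoid(logit_e))."""
    probs = [sigmoid(logit) for logit in edge_logits]
    mask = [1.0 if rng.random() < p else 0.0 for p in probs]
    return probs, mask


def bernoulli_ste_backward(probs, grad_mask):
    """Backward pass (straight-through): pretend d(mask)/d(prob) = 1, then apply
    the sigmoid derivative d(prob)/d(logit) = p * (1 - p)."""
    return [g * p * (1.0 - p) for g, p in zip(grad_mask, probs)]


def keep_sampled_edges(edges, mask):
    """The sampled subgraph: only edges whose mask entry is 1."""
    return [e for e, m in zip(edges, mask) if m == 1.0]


edges = [(0, 1), (1, 2), (2, 3)]
logits = [6.0, -6.0, 0.0]  # hypothetical per-edge scores
probs, mask = bernoulli_ste_forward(logits, random.Random(0))
subgraph = keep_sampled_edges(edges, mask)
grad_logits = bernoulli_ste_backward(probs, grad_mask=[1.0, 1.0, 1.0])
```

In PyTorch the same estimator is commonly written as `mask = hard + probs - probs.detach()`. This returns the hard mask in the forward pass while gradients flow through `probs`.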

📊 Performance

C2MIL achieves the highest C-index of all compared methods on three TCGA cohorts (KIRC, ESCA, BLCA), under both cross-validation (CV) and out-of-distribution (OOD) evaluation.

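| Model | Graph | Causal | KIRC (CV) | ESCA (CV) | BLCA (CV) | KIRC (OOD) | ESCA (OOD) | BLCA (OOD) |
|---|---|---|---|---|---|---|---|---|
| ABMIL | ✗ | ✗ | 0.679 | 0.639 | 0.577 | 0.597 | 0.614 | 0.673 |
| TransMIL | ✗ | ✗ | 0.666 | 0.565 | 0.568 | 0.610 | 0.539 | 0.676 |
| RRTMIL | ✗ | ✗ | 0.678 | 0.620 | 0.566 | 0.584 | 0.589 | 0.679 |
| DeepGraphConv | ✓ | ✗ | 0.667 | 0.612 | 0.572 | 0.509 | 0.598 | 0.613 |
| PatchGCN | ✓ | ✗ | 0.686 | 0.652 | 0.576 | 0.606 | 0.568 | 0.697 |
| ProtoSurv | ✓ | ✗ | 0.698 | 0.619 | 0.593 | 0.610 | 0.598 | 0.695 |
| IBMIL | ✗ | ✓ | 0.697 | 0.589 | 0.553 | 0.616 | 0.571 | 0.654 |
| C2MIL (Ours) | ✓ | ✓ | 0.708 | 0.690 | 0.608 | 0.628 | 0.650 | 0.702 |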
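For readers unfamiliar with the metric in the table, the sketch below computes the C-index, assuming the standard Harrell definition. The evaluation script may instead call a library routine such as scikit-survival's `concordance_index_censored`. Pairs with tied event times are not counted as comparable here, and the example data are made up.

```python
def harrell_c_index(times, events, risks):
    """Harrell's C-index: fraction of comparable pairs whose predicted risks
    are ordered consistently with their survival times.

    A pair (i, j) is comparable when subject i has an observed event
    (events[i] == 1) and times[i] < times[j]. It is concordant when
    risks[i] > risks[j]; tied risks count as half-concordant.
    """
    concordant = 0.0
    comparable = 0
    for i in range(len(times)):
        if not events[i]:
            continue  # a censored subject cannot anchor a comparable pair
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Made-up example: 0.5 matches random ranking, 1.0 a perfect ranking.
times = [5.0, 12.0, 20.0, 30.0]  # follow-up in months
events = [1, 1, 0, 1]            # 1 = death observed, 0 = censored
risks = [2.1, 0.2, 0.3, 0.9]     # model risk scores; higher = worse prognosis
print(round(harrell_c_index(times, events, risks), 3))  # prints 0.6
```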

⚙️ Installation

git clone https://github.com/mimic0127/C2MIL.git
cd C2MIL
conda create -n c2mil python=3.9
conda activate c2mil
pip install -r requirements.txt

Dependencies:

🚀 Usage

0. WSI Preprocessing

Preprocess WSIs using CLAM to obtain tiled patches and the corresponding .h5 files.

1. Feature Extraction

Extract patch-level and thumbnail-level features with a pretrained backbone (e.g., UNI, ViT, CTransPath) using the scripts in the data_process folder.

Patch features:

python patch_fea_sample.py

Thumbnail features:

python thumbnail_svs.py
python thumbnail_fea_pocess.py

2. Graph Construction

Construct a patch-level graph for each WSI by connecting every patch to its k nearest neighbors (KNN) based on patch coordinates:

python to_Graph.py
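The KNN step can be sketched as a brute-force search over patch coordinates. This is a minimal illustration under assumed inputs: hypothetical top-left pixel coordinates of 256 x 256 patches and a hypothetical `knn_edges` helper. The actual `to_Graph.py` may choose k differently, symmetrize edges, or use a spatial index.

```python
def knn_edges(coords, k):
    """Directed edges (i, j) linking each patch i to its k nearest patches j
    by Euclidean distance between patch coordinates (ties broken by index)."""
    edges = []
    for i, (xi, yi) in enumerate(coords):
        # Squared distances are enough for ranking and avoid a sqrt per pair.
        ranked = sorted(
            ((xi - xj) ** 2 + (yi - yj) ** 2, j)
            for j, (xj, yj) in enumerate(coords)
            if j != i
        )
        edges.extend((i, j) for _, j in ranked[:k])
    return edges


# Hypothetical top-left coordinates of four 256 x 256 patches (level-0 pixels).
coords = [(0, 0), (256, 0), (0, 256), (2560, 2560)]
print(knn_edges(coords, k=1))  # [(0, 1), (1, 0), (2, 0), (3, 1)]
```

This brute-force version is O(n²) in the number of patches. For WSIs with tens of thousands of patches, a spatial index such as `scipy.spatial.cKDTree` or `sklearn.neighbors.NearestNeighbors` is the usual choice.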

3. Training

Split the dataset into training, validation, and test subsets:

python fold.py
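One common way to split TCGA data is by patient, so that slides from the same patient never land in both training and test sets. The sketch below shows such a patient-level k-fold split. It is hypothetical and not documented behaviour of `fold.py`; the number of folds, the split ratios, any stratification, and the helper names are assumptions. It also assumes TCGA-style slide names whose first 12 characters are the patient barcode.

```python
import random


def patient_level_folds(slide_ids, k=5, seed=0):
    """Assign slides to k folds so that all slides of one patient share a fold.

    Assumes TCGA-style slide names whose first 12 characters are the patient
    barcode, e.g. "TCGA-A3-3306-01Z-00-DX1" -> "TCGA-A3-3306".
    """
    patients = sorted({s[:12] for s in slide_ids})
    random.Random(seed).shuffle(patients)
    fold_of = {p: n % k for n, p in enumerate(patients)}
    folds = [[] for _ in range(k)]
    for s in slide_ids:
        folds[fold_of[s[:12]]].append(s)
    return folds


def train_val_test(folds, i):
    """Use fold i as test, fold i+1 (cyclically) as validation, the rest as training."""
    k = len(folds)
    val_idx = (i + 1) % k
    train = [s for j, f in enumerate(folds) if j not in (i, val_idx) for s in f]
    return train, folds[val_idx], folds[i]


# Hypothetical slide names: ten patients, one of whom has two slides.
slides = [f"TCGA-XX-{n:04d}-01Z-00-DX1" for n in range(10)]
slides.append("TCGA-XX-0003-01Z-00-DX2")
folds = patient_level_folds(slides, k=5)
train, val, test = train_val_test(folds, i=0)
```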

Train the model:

python train.py

4. Evaluation

Run evaluation and prediction:

python test_prediction.py

📂 Repository Structure

📜 Citation

If you find this repository useful, please ⭐️ star it and cite our paper:

@inproceedings{cen2025c2mil,
  title={C2MIL: Synchronizing Semantic and Topological Causalities in Multiple Instance Learning for Robust and Interpretable Survival Analysis},
  author={Cen, Min and Zhuang, Zhenfeng and Zhang, Yuzhe and Zeng, Min and Magnier, Baptiste and Yu, Lequan and Zhang, Hong and Wang, Liansheng},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  pages={24392--24401},
  year={2025}
}

📝 License

This project is licensed under the MIT License.