A graph-transformer for whole slide image classification
June 17, 2024
This work is published in IEEE Transactions on Medical Imaging (https://doi.org/10.1109/TMI.2022.3176598).
Introduction
This repository contains a PyTorch implementation of a deep-learning-based graph-transformer for whole slide image (WSI) classification. We propose a Graph-Transformer (GT) network that fuses a graph representation of a WSI with a transformer to generate WSI-level predictions in a computationally efficient fashion.
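To make the idea of "a graph representation fused with a transformer" concrete, here is a minimal NumPy sketch of a forward pass: one graph convolution mixes each patch embedding with its spatial neighbours, then single-head self-attention with a class token produces a WSI-level prediction. This is an illustration under simplifying assumptions, not the GT architecture implemented in this repository. The layer sizes, the single GCN layer, the residual connection, the absence of any graph pooling, and the `params` dictionary layout are all hypothetical choices made for the example.

```python
import numpy as np


def normalize_adjacency(adj):
    """GCN-style symmetric normalization D^-1/2 (A + I) D^-1/2."""
    a_hat = adj + np.eye(adj.shape[0])          # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)     # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def graph_transformer_forward(features, adj, params):
    """Hypothetical GCN + self-attention forward pass returning class logits."""
    # 1) Graph convolution: each node aggregates its neighbours' patch features.
    h = np.maximum(normalize_adjacency(adj) @ features @ params["w_gcn"], 0.0)
    # 2) Prepend a class token and run single-head self-attention over all nodes.
    tokens = np.vstack([params["cls"], h])      # (N + 1, D)
    q = tokens @ params["w_q"]
    k = tokens @ params["w_k"]
    v = tokens @ params["w_v"]
    attn = softmax(q @ k.T / np.sqrt(q.shape[1]), axis=-1)
    out = tokens + attn @ v                     # residual connection
    # 3) Classify the whole slide from the class-token output.
    return out[0] @ params["w_out"]             # (C,)


# Toy example: 5 patches in a chain graph, 8-dim features, 3 classes
# (e.g. normal, LUAD, LSCC). Weights are random, not trained.
rng = np.random.default_rng(0)
n_nodes, f_dim, d_dim, n_classes = 5, 8, 4, 3
features = rng.normal(size=(n_nodes, f_dim))
adj = np.zeros((n_nodes, n_nodes))
for i in range(n_nodes - 1):
    adj[i, i + 1] = adj[i + 1, i] = 1.0
params = {
    "w_gcn": rng.normal(size=(f_dim, d_dim)),
    "cls": rng.normal(size=(1, d_dim)),
    "w_q": rng.normal(size=(d_dim, d_dim)),
    "w_k": rng.normal(size=(d_dim, d_dim)),
    "w_v": rng.normal(size=(d_dim, d_dim)),
    "w_out": rng.normal(size=(d_dim, n_classes)),
}
logits = graph_transformer_forward(features, adj, params)
probs = softmax(logits)
```

Because attention runs over graph nodes rather than raw pixels, the cost scales with the number of nodes in the slide graph, which is one reason a graph-plus-transformer design can stay computationally manageable on gigapixel WSIs.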
To demonstrate the applicability of our approach, we selected 3,024 hematoxylin and eosin (H&E) WSIs with lung tumors or normal histology from the Clinical Proteomic Tumor Analysis Consortium (CPTAC), the National Lung Screening Trial (NLST), and The Cancer Genome Atlas (TCGA), and developed a model to distinguish adenocarcinoma (LUAD) and squamous cell carcinoma (LSCC) from normal histology. To understand how our model processes WSI data and to visualize regions that are highly associated with the class label, we also propose a novel class activation mapping technique for graphs, called GraphCAM.
Usage
1. Graph Construction
(a) Tiling Patch
python src/tile_WSI.py -s 512 -e 0 -j 32 -B 50 -M 20 -o <full_path_to_output_folder> "full_path_to_input_slides/*/*.svs"
Mandatory parameters:
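As a rough illustration of what tiling with a background filter involves, here is a minimal sketch. It assumes that `-s 512` is the tile size in pixels, `-e 0` means no overlap, and `-B 50` caps the background at 50% of a tile, and it treats pixels brighter than a fixed grayscale threshold as background. The function `tile_image` and its `white_threshold` parameter are hypothetical. It works on an in-memory RGB array rather than a multi-resolution .svs file, so it does not reproduce what src/tile_WSI.py actually does (for example, magnification selection via `-M`).

```python
import numpy as np


def tile_image(image, tile_size=512, max_background=0.5, white_threshold=220):
    """Yield (row, col, tile) for non-overlapping tiles with enough tissue.

    A pixel counts as background if its mean RGB value exceeds white_threshold
    (an assumed criterion); tiles whose background fraction is larger than
    max_background are skipped. Partial tiles at the right/bottom edge are dropped.
    """
    height, width = image.shape[:2]
    for top in range(0, height - tile_size + 1, tile_size):
        for left in range(0, width - tile_size + 1, tile_size):
            tile = image[top:top + tile_size, left:left + tile_size]
            gray = tile.mean(axis=2)                  # average the RGB channels
            background = np.mean(gray > white_threshold)
            if background <= max_background:
                yield top // tile_size, left // tile_size, tile


# Toy "slide": white background with one full tissue tile and one half-tissue tile.
slide = np.full((1024, 1536, 3), 255, dtype=np.uint8)
slide[:512, :512] = 100          # tile (0, 0): all tissue
slide[512:, 512:768] = 100       # tile (1, 1): exactly 50% tissue
kept = [(row, col) for row, col, _ in tile_image(slide)]
```

Here `kept` is `[(0, 0), (1, 1)]`: the half-tissue tile passes because its background fraction equals the 50% limit.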
(b) Training Patch Feature Extractor
Go to './feature_extractor' and configure 'config.yaml' before training. The feature extractor is trained with contrastive learning on patches cropped at a single magnification (20X), and the trained model is saved in './feature_extractor/runs'. Before training, list the paths to all patches in the 'all_patches.csv' file.
python run.py
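The expected layout of 'all_patches.csv' is not specified here. The helper below is a sketch that assumes one patch path per row, no header, and patches stored as .jpeg files somewhere under a common root folder; check 'config.yaml' and the data loader in './feature_extractor' before relying on it. The name `write_patch_list` and the `pattern` default are hypothetical.

```python
import csv
import tempfile
from pathlib import Path


def write_patch_list(patch_root, csv_path, pattern="*.jpeg"):
    """Collect all patch files under patch_root and write one path per CSV row."""
    paths = sorted(str(p) for p in Path(patch_root).rglob(pattern))
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        for p in paths:
            writer.writerow([p])
    return paths


# Toy example with three empty patch files in two slide folders.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["slide_a/0_0.jpeg", "slide_a/0_1.jpeg", "slide_b/3_4.jpeg"]:
        path = Path(tmp, name)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()
    listed = write_patch_list(tmp, Path(tmp, "all_patches.csv"))
    n_rows = len(Path(tmp, "all_patches.csv").read_text().splitlines())
```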
Alternatively, you can use the pretrained feature extractor feature_extractor/model.pth. The pretrained models can be downloaded.
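The text above only says that the feature extractor is trained with contrastive learning. As an illustration of what such an objective looks like, here is a NumPy sketch of the NT-Xent loss used by SimCLR; whether run.py uses exactly this loss, and with which temperature, is an assumption to verify in the code. Each patch is augmented twice, and the loss pulls the two embeddings of the same patch together while pushing apart embeddings of different patches in the batch.

```python
import numpy as np


def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent loss for two augmented views; row i of z1 pairs with row i of z2."""
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # cosine similarity
    sim = z @ z.T / temperature
    np.fill_diagonal(sim, -np.inf)                      # ignore self-similarity
    # Index of the positive partner for each of the 2n embeddings.
    positives = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # Row-wise log-softmax, computed stably.
    row_max = sim.max(axis=1, keepdims=True)
    log_prob = sim - row_max - np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    return -log_prob[np.arange(2 * n), positives].mean()


# Aligned views give a lower loss than views paired with the wrong patch.
views = np.eye(4)
aligned = nt_xent_loss(views, views)
shuffled = nt_xent_loss(views, views[[1, 2, 3, 0]])
```

With four orthogonal embeddings and temperature 0.5, the aligned case evaluates to log(e^2 + 6) - 2, since each row has one positive with similarity 2 and six negatives with similarity 0.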
(c) Constructing Graph
Go to './feature_extractor' and build graphs from patches:
python build_graphs.py --weights "path_to_pretrained_feature_extractor" --dataset "path_to_patches" --output "../graphs"
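One common way to turn patches into a graph is to use each patch embedding as a node feature and connect patches that are spatial neighbours on the tiling grid. The sketch below assumes 8-neighbourhood connectivity and integer (row, col) grid coordinates per patch (for example, parsed from tile file names); both are assumptions, and build_graphs.py may define edges differently. `build_patch_graph` is a hypothetical name, and the random features stand in for the trained extractor's output.

```python
import numpy as np


def build_patch_graph(coords, features):
    """Return (node_features, adjacency) with edges between 8-neighbour patches."""
    coords = np.asarray(coords)
    # Chebyshev distance between every pair of grid positions.
    diff = np.abs(coords[:, None, :] - coords[None, :, :]).max(axis=2)
    adj = (diff <= 1).astype(np.float32)
    np.fill_diagonal(adj, 0.0)            # no self-loops in the stored graph
    return np.asarray(features, dtype=np.float32), adj


# Four patches: three touching tiles and one isolated tile.
coords = [(0, 0), (0, 1), (1, 1), (5, 5)]
features = np.random.default_rng(0).normal(size=(4, 8))   # stand-in embeddings
node_features, adj = build_patch_graph(coords, features)
```

The dense adjacency is fine for a toy example; for slides with tens of thousands of patches, a sparse representation would be the more practical choice.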
2. Training Graph-Transformer
Run the following script to train the model; the trained model and logging files are stored under "graph_transformer/saved_models" and "graph_transformer/runs", respectively.
bash scripts/train.sh
To evaluate the model, run:
bash scripts/test.sh
Split the data into training, validation, and test sets and store each split in a text file, with one tab-separated sample and label per line:
sample1 \t label1
sample2 \t label2
LUAD/C3N-00293-23 \t luad
...
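A sketch of producing and reading split files in this format follows. The split fractions, the seed, and the file name `train_set.txt` are hypothetical; use the file names expected by scripts/train.sh and scripts/test.sh. This splits at the slide level; if a patient contributes several slides, splitting by patient instead avoids leaking the same patient into both training and test sets.

```python
import random
import tempfile
from pathlib import Path


def split_dataset(samples, val_frac=0.1, test_frac=0.2, seed=0):
    """Shuffle (sample, label) pairs and split them into train/val/test lists."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    n_test = int(len(shuffled) * test_frac)
    n_val = int(len(shuffled) * val_frac)
    test = shuffled[:n_test]
    val = shuffled[n_test:n_test + n_val]
    train = shuffled[n_test + n_val:]
    return train, val, test


def write_split(pairs, path):
    """Write one 'sample<TAB>label' line per pair."""
    Path(path).write_text("".join(f"{sample}\t{label}\n" for sample, label in pairs))


def read_split(path):
    """Parse a split file back into (sample, label) tuples, skipping blank lines."""
    pairs = []
    for line in Path(path).read_text().splitlines():
        if line.strip():
            sample, label = line.split("\t")
            pairs.append((sample.strip(), label.strip()))
    return pairs


samples = [(f"LUAD/slide-{i:02d}", "luad") for i in range(5)] + \
          [(f"LSCC/slide-{i:02d}", "lscc") for i in range(5)]
train, val, test = split_dataset(samples)
with tempfile.TemporaryDirectory() as tmp:
    write_split(train, Path(tmp, "train_set.txt"))
    train_roundtrip = read_split(Path(tmp, "train_set.txt"))
```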
3. GraphCAM
To generate GraphCAM of the model on the WSI:
1. bash scripts/get_graphcam.sh
To visualize the GraphCAM:
2. bash scripts/vis_graphcam.sh
Note: currently, GraphCAM can only be generated for one WSI at a time.
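GraphCAM itself is computed by the scripts above. As a simplified illustration of the general idea behind class activation maps on graphs, the sketch below computes Grad-CAM-style relevance scores per node for a deliberately simple model (a linear classifier on mean-pooled node features, where the gradients can be written in closed form) and then places the scores back onto the patch grid as a heatmap. This is not the GraphCAM method; the model, the function names, and the grid coordinates are hypothetical.

```python
import numpy as np


def gradcam_node_scores(node_features, class_weights):
    """Grad-CAM-style node relevance for logit y_c = w_c . mean_i(F_i).

    For this model dy_c/dF_i = w_c / N for every node, so the Grad-CAM channel
    weights (gradients averaged over nodes) are alpha = w_c / N.
    """
    n = node_features.shape[0]
    alpha = class_weights / n
    cam = np.maximum(node_features @ alpha, 0.0)       # keep positive evidence only
    return cam / cam.max() if cam.max() > 0 else cam   # scale to [0, 1]


def scores_to_heatmap(coords, scores, grid_shape):
    """Place each node's score at its (row, col) tile position."""
    heatmap = np.zeros(grid_shape)
    for (row, col), score in zip(coords, scores):
        heatmap[row, col] = score
    return heatmap


node_features = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]])
class_weights = np.array([1.0, -1.0])
scores = gradcam_node_scores(node_features, class_weights)
heatmap = scores_to_heatmap([(0, 0), (0, 1), (1, 0)], scores, (2, 2))
```

Here the normalized node scores are `[0.5, 0.0, 1.0]`, so the tile at grid position (1, 0) is highlighted most strongly.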
More GraphCAM examples:
Figure: GraphCAMs generated on WSIs across the runs of 5-fold cross-validation. Our method highlights the same set of WSI regions across the cross-validation folds, indicating that it consistently highlights salient regions of interest.