Graphite

April 28, 2026

Reference implementation of:

Graphite: GRAPH-Induced feaTure Extraction for Point Cloud Registration
Mahdi Saleh, Shervin Dehghani, Nassir Navab, Benjamin Busam, Federico Tombari
International Conference on 3D Vision (3DV), IEEE, 2020.

A graph neural network for learning local point-cloud descriptors and keypoints. Originally written for PyTorch 1.1 / CUDA 9.0 / Python 3.6; this fork brings it up to PyTorch 2.x / CUDA 11.8 / Python 3.10 and reproduces the paper's ModelNet40 protocol plus the 3DMatch geometric-registration benchmark.

Note on training stages. The original paper describes three stages: SynPrim pretraining → COCO pretraining → ModelNet finetuning. Neither SynPrim nor the COCO patch generator ships with the public repo, and the authors note that the SynPrim stage adds little. This fork therefore trains directly on ModelNet from scratch.

src/
├── train_modelnet.py        train on ModelNet (40 or 10)
├── evaluate_modelnet_rt.py  ModelNet rigid-registration eval (DCP-style)
├── evaluate_match3d.py      3DMatch geometric-registration eval
├── params.py                paths + training_configs
├── _compat.py               Open3D / NumPy / SciPy compatibility shims
├── wandb_logger.py          optional W&B logging
├── model.py                 Net + NetAE
├── pretrain.py              training loop
├── modelnetdataset.py       ModelNet pair generation
├── match3Ddataset.py        3DMatch fragment loader (random-keypoint fallback)
├── test_rt.py / test_match.py
└── …
scripts/
├── download_modelnet40.sh   one-shot ModelNet40 downloader
└── download_match3d.sh      one-shot 3DMatch downloader
data/{modelnet,match3d}/     datasets land here
models/                      trained checkpoints
logs/                        per-tag training logs

1. Install

conda create -y -n graphite python=3.10
conda activate graphite

pip install --no-cache-dir 'numpy<2' \
    torch==2.4.1 torchvision==0.19.1 \
    --index-url https://download.pytorch.org/whl/cu118

pip install --no-cache-dir torch-scatter torch-sparse torch-cluster \
    -f https://data.pyg.org/whl/torch-2.4.0+cu118.html
pip install --no-cache-dir torch-geometric==2.5.3

pip install --no-cache-dir 'open3d==0.18.0' 'opencv-python==4.10.0.84' \
    pyquaternion h5py

Sanity check:

python -c "import torch, torch_geometric, open3d; print(torch.__version__, torch.cuda.is_available())"

The pinned versions are dictated by the torch-{scatter,sparse,cluster} wheels at data.pyg.org, which only exist for specific torch / CUDA combos. Nothing is installed system-wide.
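The find-links URL in the second pip command has to track the installed torch version. A small helper can derive it. This is a sketch that assumes the URL pattern shown above, where index pages are keyed by the torch major.minor release with a .0 patch (which is why torch 2.4.1 uses the torch-2.4.0 page). The name pyg_wheel_index is hypothetical and not part of this repo.

```python
def pyg_wheel_index(torch_version: str, cuda_tag: str = "cu118") -> str:
    """Build the data.pyg.org find-links URL for a torch version and CUDA tag.

    Hypothetical helper; assumes pages are named torch-<major>.<minor>.0+<tag>.html.
    """
    # Drop a local build suffix such as "+cu118" that torch.__version__ may carry.
    base = torch_version.split("+")[0]
    major, minor = base.split(".")[:2]
    # Patch releases share one index page, e.g. 2.4.1 -> torch-2.4.0.
    return f"https://data.pyg.org/whl/torch-{major}.{minor}.0+{cuda_tag}.html"


if __name__ == "__main__":
    print(pyg_wheel_index("2.4.1+cu118"))
    # https://data.pyg.org/whl/torch-2.4.0+cu118.html
```

When moving to a different torch / CUDA pair, passing torch.__version__ and the matching CUDA tag gives the page to use with -f.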

2. Dataset

./scripts/download_modelnet40.sh

This drops ModelNet40 (airplane/, bathtub/, …, xbox/) into data/modelnet/ and writes object_names.txt. Total ~450 MB.

ModelNet10 (smaller / faster smoke runs)

mkdir -p data/modelnet && cd data/modelnet
curl -L -o ModelNet10.zip \
    "http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip"
unzip -q ModelNet10.zip && mv ModelNet10/* . && rm -rf ModelNet10 ModelNet10.zip __MACOSX

printf 'bathtub\nbed\nchair\ndesk\ndresser\nmonitor\nnight_stand\nsofa\ntable\ntoilet\n' \
    > object_names.txt
cd ../..   # back to the repository root
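Before the first training run it can help to confirm that every category in object_names.txt has a directory with meshes. The sketch below assumes the standard ModelNet layout (<class>/train/*.off and <class>/test/*.off); check_modelnet_root is a hypothetical helper, not a script shipped with the repo.

```python
from pathlib import Path


def check_modelnet_root(root) -> dict:
    """Count .off meshes per split for every class listed in object_names.txt.

    Assumes the standard ModelNet layout: <root>/<class>/{train,test}/*.off.
    """
    root = Path(root)
    names = (root / "object_names.txt").read_text().split()
    counts = {}
    for name in names:
        class_dir = root / name
        if not class_dir.is_dir():
            raise FileNotFoundError(f"missing class directory: {class_dir}")
        # Globbing a missing train/ or test/ folder simply yields no files.
        counts[name] = {
            split: len(list((class_dir / split).glob("*.off")))
            for split in ("train", "test")
        }
    return counts
```

Run from the repository root as check_modelnet_root("data/modelnet") (or point it at MODELNET_DIR). A missing class directory raises FileNotFoundError; a zero count suggests the archive was unpacked one level too deep.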

3. Train

conda activate graphite
python src/train_modelnet.py            # paper-protocol defaults

Defaults follow Table 1 of the paper: data_mode=noise, mb_size=12, lr=1e-4, epochs=50, and the DCP unseen-categories split (first 20 categories train, last 20 test) — auto-selected when object_names.txt contains 40 categories. On ModelNet10 the split falls back to using all 10 categories for both train and test (no unseen-category generalization possible with only 10 classes).
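The split rule above can be restated in a few lines. This hypothetical select_split helper covers only the auto, dcp and all modes (first11 and custom are not modelled) and assumes the class order is the order of object_names.txt.

```python
def select_split(class_names, split="auto"):
    """Return (train_classes, test_classes) for the 'auto', 'dcp' and 'all' modes.

    Hypothetical restatement of the rule in the README, not the repo's code.
    """
    if split == "auto":
        # DCP unseen-category split only makes sense with all 40 classes.
        split = "dcp" if len(class_names) == 40 else "all"
    if split == "dcp":
        if len(class_names) != 40:
            raise ValueError("the DCP split expects the 40 ModelNet40 categories")
        return list(class_names[:20]), list(class_names[20:])
    if split == "all":
        return list(class_names), list(class_names)
    raise ValueError(f"unsupported split in this sketch: {split!r}")
```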

The first run lazily preprocesses each .off mesh into patch graphs and caches the result under data/modelnet/processed/; later runs skip preprocessing. After every epoch the script runs the ModelNet rigid-registration eval (test_rt.test) and saves a checkpoint to models/trained_<tag>.model.

Quick smoke test (caps meshes per class):

GRAPHITE_MAX_TRAIN_PER_CLASS=3 GRAPHITE_MAX_TEST_PER_CLASS=2 \
    python src/train_modelnet.py --epoch-size 1
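Both caps are plain integers where 0 means no limit. A minimal sketch of how such a cap could be read and applied, assuming the first N meshes in sorted filename order are kept (the repo may choose differently); both helper names are hypothetical.

```python
import os


def max_per_class(var_name: str, environ=os.environ) -> int:
    """Read a cap such as GRAPHITE_MAX_TRAIN_PER_CLASS; unset, empty or 0 means no cap."""
    value = int(environ.get(var_name, "0") or 0)
    if value < 0:
        raise ValueError(f"{var_name} must be >= 0")
    return value


def cap_meshes(files, cap: int):
    """Keep the first `cap` files of one class in sorted order; cap == 0 keeps all."""
    files = sorted(files)
    return files if cap == 0 else files[:cap]
```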

CLI flags (most useful)

Flag              Default        Notes
--epoch-size      50             Number of epochs
--mb-size         12             Must be a multiple of 3 (triplet sampling)
--learning-rate   1e-4           Optimizer base LR
--alpha           1.0            Weight on the descriptor (triplet) loss; all losses use mean reduction
--m               0.7            Triplet margin (descriptors are unit-norm, so distances lie in [0, 2])
--weight-decay    1e-4           AdamW weight decay; 0 falls back to plain Adam
--grad-clip       1.0            Max gradient norm; <= 0 disables clipping
--lr-schedule     cosine         cosine / step / none
--arch            netae          Architecture: netae (default) or net
--data-mode       noise          noise / patch / patchnoise
--split           auto           dcp (first 20 / last 20), all, first11, custom
--resume PATH     none           Continue from a saved .model
--file-tag        modelnet_run   Used in checkpoint and log filenames
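--mb-size and --m fit together in a triplet objective. The sketch below is a generic triplet-margin loss in NumPy, assuming each mini-batch is laid out as consecutive (anchor, positive, negative) rows, which is one way to satisfy the multiple-of-3 constraint; the repo's model and batch layout may differ. Descriptors are unit-normalised and compared with unsquared Euclidean distance, which bounds distances to [0, 2].

```python
import numpy as np


def triplet_loss(desc: np.ndarray, margin: float = 0.7) -> float:
    """Mean triplet-margin loss over rows laid out as (anchor, positive, negative).

    Assumed batch layout, for illustration only.
    """
    if desc.shape[0] % 3 != 0:
        raise ValueError("mini-batch size must be a multiple of 3")
    # Unit-normalise so Euclidean distances lie in [0, 2].
    desc = desc / np.linalg.norm(desc, axis=1, keepdims=True)
    anchor, pos, neg = desc[0::3], desc[1::3], desc[2::3]
    d_pos = np.linalg.norm(anchor - pos, axis=1)
    d_neg = np.linalg.norm(anchor - neg, axis=1)
    # Hinge on the distance gap, averaged over triplets ("mean" reduction).
    return float(np.mean(np.maximum(d_pos - d_neg + margin, 0.0)))
```

With the default mb-size of 12 this gives 4 triplets per step; --alpha would then scale this term relative to the remaining loss terms.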

The triplet sampler is class-aware: negatives are always drawn from a different ModelNet category than the anchor, so a negative can no longer silently share the anchor's class (as it could with a plain random.randint draw). Best-on-validation checkpoints are written alongside each periodic save as <tag>_best.model; the selection metric is configurable with GRAPHITE_BEST_METRIC (default eval_modelnet/icp/rot_RMSE).
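A class-aware negative draw can be sketched as: pick an anchor, then pick the negative only among samples whose category differs. This is a hypothetical illustration using Python's random module with a fixed seed. Positives are left out on the assumption that they come from a transformed copy of the anchor (the pair produced by the dataset) rather than from a separate draw.

```python
import random


def sample_anchor_negative_pairs(labels, n_pairs, seed=0):
    """Return (anchor_index, negative_index) pairs whose class labels differ.

    Hypothetical sketch of class-aware negative sampling.
    """
    rng = random.Random(seed)
    classes = sorted(set(labels))
    if len(classes) < 2:
        raise ValueError("need at least two classes to draw negatives")
    # Index samples by class once so each negative draw is O(1).
    by_class = {c: [i for i, lab in enumerate(labels) if lab == c] for c in classes}
    pairs = []
    for _ in range(n_pairs):
        anchor = rng.randrange(len(labels))
        neg_class = rng.choice([c for c in classes if c != labels[anchor]])
        pairs.append((anchor, rng.choice(by_class[neg_class])))
    return pairs
```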

Optional: Weights & Biases tracking

Training logs every per-epoch loss and ModelNet eval metric to W&B if it's configured. The default project is graphite-registration.

pip install wandb && wandb login
python src/train_modelnet.py

# Override project / entity if you want:
WANDB_PROJECT=my-project python src/train_modelnet.py

# Set WANDB_MODE=disabled to turn logging off; the helper is also a no-op if
# `wandb` isn't installed.
WANDB_MODE=disabled python src/train_modelnet.py
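The no-op behaviour described above can be sketched as follows. NullRun and init_tracking are hypothetical names and this is not the code in wandb_logger.py; the only real API used is wandb.init(project=...), which is imported only when the package is installed and logging is not disabled.

```python
import importlib.util
import os


class NullRun:
    """Stand-in run object whose log() silently does nothing."""

    def log(self, *args, **kwargs):
        pass


def init_tracking(project=None, environ=os.environ):
    """Return a wandb run, or a no-op stand-in when wandb is disabled or missing."""
    if environ.get("WANDB_MODE") == "disabled":
        return NullRun()
    if importlib.util.find_spec("wandb") is None:
        return NullRun()
    import wandb  # imported lazily so the package stays optional
    return wandb.init(project=project or environ.get("WANDB_PROJECT", "graphite-registration"))
```

Because callers always get an object with log(), the training loop can call run.log({...}) unconditionally.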

Useful environment variables

Variable                        Effect
MODELNET_DIR                    ModelNet root (default data/modelnet/)
MATCH3D_DIR                     3DMatch root (default data/match3d/)
GRAPHITE_MODELNET_SPLIT         Override split selection (all / dcp / first11 / custom)
GRAPHITE_TRAIN_CLASSES          Comma-separated training class indices when split=custom
GRAPHITE_TEST_CLASSES           Comma-separated test class indices when split=custom
GRAPHITE_MAX_TRAIN_PER_CLASS    Cap on training meshes per class during preprocessing (0 = no cap)
GRAPHITE_MAX_TEST_PER_CLASS     Cap on test meshes per class during preprocessing (0 = no cap)
GRAPHITE_DATA_ROOT              Override the default data/ directory
GRAPHITE_MODELS_ROOT            Override the default models/ directory
GRAPHITE_LOG_ROOT               Override the default logs/ directory
WANDB_PROJECT / WANDB_MODE      Standard W&B knobs (set WANDB_MODE=disabled to silence it)
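For split=custom, the two index lists map positions in object_names.txt to category names. This sketch assumes zero-based indices and raises on out-of-range values; custom_split_from_env is a hypothetical name, and the README does not state whether the repo counts from 0 or 1.

```python
def custom_split_from_env(class_names, environ):
    """Resolve GRAPHITE_TRAIN_CLASSES / GRAPHITE_TEST_CLASSES into class-name lists.

    Assumes comma-separated, zero-based indices into object_names.txt.
    """
    def parse(var):
        raw = environ.get(var, "")
        # Skip empty tokens so "" or "0,,2" still parse; int() tolerates spaces.
        idx = [int(tok) for tok in raw.split(",") if tok.strip()]
        bad = [i for i in idx if not 0 <= i < len(class_names)]
        if bad:
            raise ValueError(f"{var}: indices out of range: {bad}")
        return [class_names[i] for i in idx]

    return parse("GRAPHITE_TRAIN_CLASSES"), parse("GRAPHITE_TEST_CLASSES")
```

On the command line this corresponds to, for example, GRAPHITE_MODELNET_SPLIT=custom GRAPHITE_TRAIN_CLASSES=0,2 GRAPHITE_TEST_CLASSES=1,3 python src/train_modelnet.py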

Tuning the SVD+RANSAC and ICP at eval time

Variable                        Default   Effect
GRAPHITE_RANSAC_INLIER_THRESH   0.05      Inlier threshold for the RANSAC pose solver, in normalised object units
GRAPHITE_RANSAC_NUM_ITER        1000      Number of RANSAC hypotheses
GRAPHITE_RANSAC_MIN_SAMPLE      4         Correspondences sampled per hypothesis
GRAPHITE_RANSAC_ROT_GATE_RAD    1.4       Reject rotations larger than this (radians); training samples up to ~1.36 rad
GRAPHITE_TRANS_SANITY_MAX       1.0       Drop predictions with an implausibly large translation (above this value)
GRAPHITE_CONFIDENCE_THRESH      0.0       Optional keypoint-confidence filter (e.g. 0.2 to drop low-confidence keypoints)
GRAPHITE_ICP_METHOD             plane     plane for point-to-plane, point for point-to-point ICP
GRAPHITE_EVAL_SEED              0         Seeds the NumPy / Python RNGs before eval; set to "" for non-deterministic runs
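To show how the RANSAC knobs interact, here is a minimal NumPy sketch of RANSAC over putative correspondences with an SVD (Kabsch) pose solver, a rotation gate and a final refit on the inliers. The parameter names mirror the environment variables above, but this is an illustration under those assumptions rather than the repo's solver; confidence filtering, the translation sanity check and ICP refinement are left out.

```python
import numpy as np


def kabsch(src, dst):
    """Least-squares rigid transform (R, t) with dst ~= src @ R.T + t, via SVD."""
    src_c, dst_c = src.mean(axis=0), dst.mean(axis=0)
    H = (src - src_c).T @ (dst - dst_c)
    U, _, Vt = np.linalg.svd(H)
    # Flip the last axis if needed so R is a proper rotation (det = +1).
    D = np.diag([1.0, 1.0, np.sign(np.linalg.det(Vt.T @ U.T))])
    R = Vt.T @ D @ U.T
    return R, dst_c - R @ src_c


def rotation_angle(R):
    """Geodesic rotation angle of R in radians."""
    return float(np.arccos(np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)))


def ransac_svd(src, dst, inlier_thresh=0.05, num_iter=1000, min_sample=4,
               rot_gate=1.4, seed=0):
    """RANSAC over correspondences src[i] <-> dst[i]; returns (R, t, inlier_mask)."""
    rng = np.random.default_rng(seed)
    best_inliers = np.zeros(len(src), dtype=bool)
    for _ in range(num_iter):
        idx = rng.choice(len(src), size=min_sample, replace=False)
        R, t = kabsch(src[idx], dst[idx])
        if rotation_angle(R) > rot_gate:  # reject implausibly large rotations
            continue
        residual = np.linalg.norm(src @ R.T + t - dst, axis=1)
        inliers = residual < inlier_thresh
        if inliers.sum() > best_inliers.sum():
            best_inliers = inliers
    if best_inliers.sum() < min_sample:
        raise RuntimeError("no hypothesis passed the gates")
    # Refit on every inlier of the best hypothesis.
    return kabsch(src[best_inliers], dst[best_inliers]) + (best_inliers,)
```

Raising the inlier threshold tolerates noisier descriptors at the cost of accepting looser poses, while the rotation gate discards hypotheses outside the rotation range seen in training.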

4. Evaluation

Two benchmarks are wired up. Both load a checkpoint produced by train_modelnet.py.

4.1 ModelNet rigid registration (DCP-style)

python src/evaluate_modelnet_rt.py \
    --model models/trained_modelnet_run_noise_mb12_a0.1_m0.5_lr0.00014.model \
    --data-mode noise            # patch / noise / patchnoise
    # --split dcp                # auto-selected on ModelNet40

Reports rotation MSE/RMSE/MAE (degrees) and translation MSE/RMSE/MAE, both for raw descriptor matches and for ICP-refined results; these are the metrics from the DCP paper. Use --split to evaluate on a different category partition than the one auto-selected from object_names.txt, and pass --max-pairs N to evaluate only the first N test pairs. Point --model at the checkpoint your own training run saved; the filename above is only an example.
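The DCP-style numbers can be recomputed from per-pair predictions. This sketch assumes the DCP convention of comparing Euler angles (in degrees) component-wise and averaging over all pairs and axes; angle wrap-around is not handled, and registration_errors is a hypothetical name.

```python
import numpy as np


def registration_errors(euler_pred_deg, euler_gt_deg, t_pred, t_gt):
    """MSE / RMSE / MAE for rotation (Euler angles, degrees) and translation.

    Inputs are (N, 3) arrays; errors are averaged over all pairs and axes.
    """
    def stats(pred, gt):
        err = np.asarray(pred, dtype=float) - np.asarray(gt, dtype=float)
        mse = float(np.mean(err ** 2))
        return {"MSE": mse, "RMSE": float(np.sqrt(mse)), "MAE": float(np.mean(np.abs(err)))}

    return {"rot": stats(euler_pred_deg, euler_gt_deg), "trans": stats(t_pred, t_gt)}
```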

4.2 3DMatch geometric registration benchmark

# Download fragments + gt.log (~300 MB, all 8 scenes)
./scripts/download_match3d.sh
# ...or one scene only:
SCENES="kitchen" ./scripts/download_match3d.sh

python src/evaluate_match3d.py \
    --model models/trained_modelnet_run_noise_mb12_a0.1_m0.5_lr0.00014.model \
    --scenes kitchen home1       # default: all 8 scenes

Outputs per-scene keypoint-match recall and writes data/match3d/eval_files/<scene>/graphite.log with predicted transformations in the standard Choi/Zhou/Koltun .log format. That file is a drop-in input for the MATLAB evaluate.m script in 3dmatch-toolbox if you want registration precision/recall instead of keypoint recall.
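Each entry in the .log format is a header line with two fragment indices and a third integer (in 3DMatch's gt.log, the number of fragments in the scene), followed by the 4x4 transformation, one row per line. A small reader/writer pair for inspecting graphite.log or gt.log could look like this; both function names are hypothetical.

```python
import numpy as np


def write_log(path, entries):
    """Write (i, j, n, T) entries: a header line 'i j n', then T's four rows."""
    with open(path, "w") as f:
        for i, j, n, T in entries:
            f.write(f"{i}\t{j}\t{n}\n")
            for row in np.asarray(T, dtype=float).reshape(4, 4):
                f.write("\t".join(f"{v:.12f}" for v in row) + "\n")


def read_log(path):
    """Parse a .log file back into a list of (i, j, n, 4x4 ndarray) tuples."""
    with open(path) as f:
        lines = [ln.split() for ln in f if ln.strip()]
    entries = []
    # Every entry occupies exactly five non-empty lines.
    for k in range(0, len(lines), 5):
        i, j, n = (int(x) for x in lines[k][:3])
        T = np.array([[float(v) for v in row] for row in lines[k + 1:k + 5]])
        entries.append((i, j, n, T))
    return entries
```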

Pre-computed keypoint files (01_Keypoints/cloud_bin_XKeypoints.txt) are not part of the public 3DMatch download. When missing, the loader falls back to seeded uniform-random keypoints; tune via GRAPHITE_MATCH3D_NUM_KEYPOINTS (default 5000). If you have the official keypoints, drop them into data/match3d/fragments/<scene>/01_Keypoints/ and they're picked up automatically.
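The random-keypoint fallback amounts to sampling point indices without replacement from a seeded generator. A minimal sketch, assuming one fixed seed per fragment and capping the count at the fragment's point count; fallback_keypoints is a hypothetical name and the repo's exact seeding may differ.

```python
import os

import numpy as np


def fallback_keypoints(points, seed=0, environ=os.environ):
    """Pick seeded, uniform-random keypoint indices (no duplicates) for one fragment."""
    k = int(environ.get("GRAPHITE_MATCH3D_NUM_KEYPOINTS", "5000"))
    rng = np.random.default_rng(seed)
    n = len(points)
    # Never ask for more keypoints than the fragment has points.
    return rng.choice(n, size=min(k, n), replace=False)
```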

5. Loading a checkpoint manually

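import sys
sys.path.insert(0, "src")  # run from the repository root so params/utils/model import

import params, utils
from params import training_configs
from model import init_model

# Use the settings the checkpoint was trained with; the example filenames
# above encode them (noise, mb12, a0.1, m0.5, ...).
settings = training_configs(file_tag='modelnet_run', mb_size=12,
                            alpha=0.1, m=0.5, learning_rate=1e-4,
                            data_mode='noise')
model = init_model(utils.get_model_name(False, settings), settings)
model.eval()  # inference mode: disables dropout and batch-norm statistic updates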