
June 12, 2026


gridfm-graphkit


This library is brought to you by the GridFM team to train, fine-tune and interact with a foundation model for the electric power grid.


Installation

Create and activate a virtual environment. Use Python 3.10, 3.11 or 3.12; 3.12 is recommended.

python -m venv venv
source venv/bin/activate
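The activated environment should run one of the supported interpreter versions. The check below is a sketch of how to confirm that. The supported set comes from the note above. SUPPORTED and check_python_version are hypothetical names and are not part of gridfm-graphkit.

```python
import sys

# Interpreter versions listed as supported in the installation notes.
SUPPORTED = {(3, 10), (3, 11), (3, 12)}


def check_python_version(version_info=None):
    """Return (major, minor) if supported, otherwise raise RuntimeError."""
    if version_info is None:
        version_info = sys.version_info  # the running interpreter
    major_minor = (version_info[0], version_info[1])
    if major_minor not in SUPPORTED:
        raise RuntimeError(
            f"Python {major_minor[0]}.{major_minor[1]} is not supported; "
            "use 3.10, 3.11 or 3.12 (3.12 is recommended)"
        )
    return major_minor
```

Calling check_python_version() with no argument inside the activated venv checks the interpreter that the venv actually uses.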

Install gridfm-graphkit from PyPI

pip install gridfm-graphkit

torch-scatter is a required dependency. It cannot be bundled in pyproject.toml because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.

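Get the PyTorch version and CUDA tag for torch-scatter. This produces a tag such as 2.5.1+cu124, or 2.5.1+cpu for CPU-only PyTorch. Any existing suffix is stripped first, so a CPU wheel never ends up as +cpu+cpu.

TORCH_CUDA_VERSION=$(python -c "import torch; v = torch.__version__.split('+')[0]; cu = torch.version.cuda; print(v + ('+cpu' if cu is None else '+cu' + cu.replace('.', '')))")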

Install the correct torch-scatter wheel

pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
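The torch-scatter wheel index URL is keyed on two parts: the PyTorch version, then either a cuXYZ tag (CUDA version without dots) or cpu. The sketch below expresses that tag construction as plain functions, so the logic can be checked without torch installed. The function names are hypothetical. The sketch also assumes the URL pattern used in the pip command above. Whether a page exists for a given tag is up to the PyG wheel server.

```python
def torch_cuda_tag(torch_version, cuda_version):
    """Return a tag like '2.5.1+cu124' or '2.5.1+cpu'.

    torch_version: value of torch.__version__ (may already carry a '+...' suffix)
    cuda_version:  value of torch.version.cuda, e.g. '12.4', or None for CPU builds
    """
    base = torch_version.split("+")[0]  # drop any existing local suffix
    if cuda_version is None:
        return base + "+cpu"
    return base + "+cu" + cuda_version.replace(".", "")


def wheel_index_url(torch_version, cuda_version):
    """Build the find-links URL passed to pip via -f."""
    return f"https://data.pyg.org/whl/torch-{torch_cuda_tag(torch_version, cuda_version)}.html"
```

With torch installed, wheel_index_url(torch.__version__, torch.version.cuda) gives the URL for the current environment.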

For documentation generation and unit testing, install with the optional dev and test extras:

pip install "gridfm-graphkit[dev,test]"

CLI commands

The gridfm_graphkit command-line interface trains, fine-tunes, evaluates, and runs inference on GridFM models. Runs are driven by YAML configs and tracked with MLflow.

gridfm_graphkit <command> [OPTIONS]

Available commands:

  • train - Train a new model from scratch
  • finetune - Fine-tune an existing pre-trained model
  • evaluate - Evaluate model performance on a dataset
  • predict - Run inference and save predictions
  • benchmark - Measure train dataloader throughput

Training Models

gridfm_graphkit train --config path/to/config.yaml

Arguments

Argument | Type | Description | Default
--config | str | Required. Path to the training configuration YAML file. | None
--exp_name | str | MLflow experiment name. | timestamp
--run_name | str | MLflow run name. | run
--log_dir | str | MLflow tracking/logging directory. | mlruns
--data_path | str | Root dataset directory. | data
--compile [MODE] | str | Enable torch.compile with the given mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. | None
--bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False
--tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False
--dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None
--plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | []
--num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None
--dataset_wrapper_cache_dir | str | Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved after the first population. | None
--profiler | str | Enable a Lightning profiler (simple, advanced, pytorch). | None
--compute_dc_ac_metrics | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False
--mp_context | str | DataLoader multiprocessing start method (spawn, fork, forkserver). Defaults to PyTorch's automatic choice. On Linux, spawn is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | None
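The --compile [MODE] row describes a flag whose value is optional. In Python's standard argparse module, a common way to get that behaviour is nargs="?" combined with const. The sketch below applies this pattern to a few flags from the table to illustrate the documented semantics. It assumes nothing about how gridfm_graphkit builds its real parser, and make_parser is a hypothetical name.

```python
import argparse

# Modes listed for --compile in the table above.
COMPILE_MODES = ["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]


def make_parser():
    """Hypothetical parser reproducing a few documented flags (not gridfm_graphkit's own)."""
    parser = argparse.ArgumentParser(prog="train-sketch")
    parser.add_argument("--config", required=True)
    # nargs="?" makes the value optional:
    #   flag absent       -> default (None)
    #   "--compile"       -> const ("default")
    #   "--compile MODE"  -> MODE, checked against choices
    parser.add_argument(
        "--compile", nargs="?", const="default", default=None, choices=COMPILE_MODES
    )
    # store_true flags are False unless given, matching the Default column.
    parser.add_argument("--bfloat16", action="store_true")
    parser.add_argument("--tf32", action="store_true")
    parser.add_argument("--num_workers", type=int, default=None)
    return parser
```

A bare --compile followed by another option, for example --compile --tf32, still resolves to the default mode. argparse does not consume the next option as the value.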

Examples

Standard Training:

gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
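Training runs can also be launched from Python, for example inside a sweep script. The sketch below uses a hypothetical helper, build_command, to assemble the command line from the flags documented above. It reproduces the standard training example. The sketch assumes boolean options map to bare flags and other values are passed as strings. It deliberately skips list-valued options such as --plugins, because the exact multi-value syntax is not spelled out here.

```python
import shlex


def build_command(command, config, **options):
    """Hypothetical helper: assemble a gridfm_graphkit call as an argv list.

    Keyword names become flags unchanged (data_path -> --data_path).
    True adds a bare flag, None/False omit the option, other values are
    passed as strings. List-valued options such as --plugins are not handled.
    """
    argv = ["gridfm_graphkit", command, "--config", config]
    for name, value in options.items():
        if value is None or value is False:
            continue
        argv.append(f"--{name}")
        if value is not True:
            argv.append(str(value))
    return argv


# The standard training example from above, printed as a shell-ready line.
cmd = build_command("train", "examples/config/case30_ieee_base.yaml", data_path="examples/data")
print(shlex.join(cmd))
# To actually launch it: subprocess.run(cmd, check=True)
```

Passing an argv list to subprocess.run avoids shell quoting problems. The identity checks (is None, is False) keep a legitimate value such as num_workers=0.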

Fine-Tuning Models

gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pt

Arguments

Argument | Type | Description | Default
--config | str | Required. Fine-tuning configuration file. | None
--model_path | str | Required. Path to a pre-trained model state dict. | None
--exp_name | str | MLflow experiment name. | timestamp
--run_name | str | MLflow run name. | run
--log_dir | str | MLflow logging directory. | mlruns
--data_path | str | Root dataset directory. | data
--compile [MODE] | str | Enable torch.compile with the given mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. | None
--bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False
--tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False
--dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None
--plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | []
--num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None
--dataset_wrapper_cache_dir | str | Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved after the first population. | None
--profiler | str | Enable a Lightning profiler (simple, advanced, pytorch). | None
--compute_dc_ac_metrics | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False
--mp_context | str | DataLoader multiprocessing start method (spawn, fork, forkserver). Defaults to PyTorch's automatic choice. On Linux, spawn is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | None
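The --mp_context behaviour described in the tables can be sketched with the standard multiprocessing module. The sketch works as follows:

- No value keeps the automatic choice.
- Only spawn, fork and forkserver are accepted.
- On Linux, anything other than spawn emits a warning.

This is an illustration under those documented rules, not the library's code, and resolve_mp_context is a hypothetical name.

```python
import multiprocessing
import sys
import warnings

START_METHODS = ("spawn", "fork", "forkserver")


def resolve_mp_context(name=None, platform=None):
    """Hypothetical helper mirroring the documented --mp_context behaviour.

    None -> return None so the DataLoader keeps its automatic choice.
    On Linux, any method other than "spawn" emits a RuntimeWarning.
    """
    if name is None:
        return None
    if name not in START_METHODS:
        raise ValueError(f"unknown start method {name!r}; expected one of {START_METHODS}")
    platform = sys.platform if platform is None else platform
    if platform.startswith("linux") and name != "spawn":
        warnings.warn(
            f"start method {name!r} on Linux: 'spawn' is recommended "
            "(fork is unsafe once CUDA is initialized)",
            RuntimeWarning,
        )
    # Raises ValueError if the method is unavailable on this OS (e.g. fork on Windows).
    return multiprocessing.get_context(name)
```

The returned context could then be handed to the multiprocessing_context argument of torch.utils.data.DataLoader, which accepts either a context object or a start-method string.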

Evaluating Models

gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pt

Arguments

Argument | Type | Description | Default
--config | str | Required. Path to evaluation config. | None
--model_path | str | Path to the trained model state dict. | None
--normalizer_stats | str | Path to normalizer_stats.pt from a training run. Restores fit_on_train normalizers from saved statistics instead of re-fitting on the current split. | None
--exp_name | str | MLflow experiment name. | timestamp
--run_name | str | MLflow run name. | run
--log_dir | str | MLflow logging directory. | mlruns
--data_path | str | Dataset directory. | data
--compile [MODE] | str | Enable torch.compile with the given mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. | None
--bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False
--tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False
--dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None
--plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | []
--num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None
--dataset_wrapper_cache_dir | str | Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved after the first population. | None
--profiler | str | Enable a Lightning profiler (simple, advanced, pytorch). | None
--compute_dc_ac_metrics | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False
--save_output | flag | Save predictions as <grid_name>_predictions.parquet under the MLflow artifacts (.../artifacts/test). | False
--mp_context | str | DataLoader multiprocessing start method (spawn, fork, forkserver). Defaults to PyTorch's automatic choice. On Linux, spawn is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | None

Example with saved normalizer stats

When evaluating a model on a dataset, you can pass the normalizer statistics from the original training run to ensure the same normalization parameters are used:

gridfm_graphkit evaluate \
  --config examples/config/HGNS_PF_datakit_case118.yaml \
  --model_path mlruns/<experiment_id>/<run_id>/artifacts/model/best_model_state_dict.pt \
  --normalizer_stats mlruns/<experiment_id>/<run_id>/artifacts/stats/normalizer_stats.pt \
  --data_path data
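The artifact paths in this example follow the layout mlruns/<experiment_id>/<run_id>/artifacts/... . Assuming that layout, the sketch below locates both files for a given run and assembles the same evaluate call. It fails early if either file is missing. find_run_artifacts and evaluate_command are hypothetical helpers, not library functions.

```python
from pathlib import Path


def find_run_artifacts(log_dir, experiment_id, run_id):
    """Return (model_path, stats_path) for one MLflow run, assuming the layout
    <log_dir>/<experiment_id>/<run_id>/artifacts/{model,stats}/... shown above."""
    artifacts = Path(log_dir) / experiment_id / run_id / "artifacts"
    model = artifacts / "model" / "best_model_state_dict.pt"
    stats = artifacts / "stats" / "normalizer_stats.pt"
    missing = [str(p) for p in (model, stats) if not p.is_file()]
    if missing:
        raise FileNotFoundError("missing artifacts: " + ", ".join(missing))
    return model, stats


def evaluate_command(config, log_dir, experiment_id, run_id, data_path="data"):
    """Build the evaluate call from the example as an argv list."""
    model, stats = find_run_artifacts(log_dir, experiment_id, run_id)
    return [
        "gridfm_graphkit", "evaluate",
        "--config", config,
        "--model_path", str(model),
        "--normalizer_stats", str(stats),
        "--data_path", data_path,
    ]
```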

Note: The --normalizer_stats flag only affects normalizers with fit_strategy = "fit_on_train" (e.g. HeteroDataMVANormalizer). Per-sample normalizers (HeteroDataPerSampleMVANormalizer) always recompute their statistics from the current dataset regardless of this flag.
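The difference between the two normalizer kinds can be shown with a toy example on scalar lists.

- A fit-on-train normalizer computes its mean and standard deviation once, on training data. Those saved statistics are what can be restored later.
- A per-sample normalizer recomputes them from each input.

This sketch only illustrates that distinction. HeteroDataMVANormalizer and HeteroDataPerSampleMVANormalizer operate on heterogeneous graph data, and their internals are not described here. The class and function names below are hypothetical.

```python
from statistics import fmean, pstdev


class FitOnTrainNormalizer:
    """Toy normalizer: statistics come from training data and are then frozen."""

    def fit(self, train_values):
        self.mean = fmean(train_values)
        self.std = pstdev(train_values) or 1.0  # guard against zero spread
        return self

    def transform(self, values):
        return [(v - self.mean) / self.std for v in values]

    def state(self):
        # Analogous in spirit to saving statistics after training.
        return {"mean": self.mean, "std": self.std}

    @classmethod
    def from_state(cls, state):
        # Restore saved statistics instead of re-fitting on new data.
        obj = cls()
        obj.mean, obj.std = state["mean"], state["std"]
        return obj


def per_sample_normalize(values):
    """Toy per-sample normalizer: statistics are recomputed from the input itself."""
    mean = fmean(values)
    std = pstdev(values) or 1.0
    return [(v - mean) / std for v in values]
```

Fitted on [1.0, 3.0] (mean 2, std 1), the first normalizer maps [2.0, 4.0] to [0.0, 2.0]. The per-sample version maps the same input to [-1.0, 1.0]. For a per-sample normalizer, saved statistics therefore have nothing to restore.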


Running Predictions

gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pt

Arguments

Argument | Type | Description | Default
--config | str | Required. Path to prediction config file. | None
--model_path | str | Path to trained model state dict. Optional; may be defined in config. | None
--normalizer_stats | str | Path to normalizer_stats.pt from a training run. Restores fit_on_train normalizers from saved statistics. | None
--exp_name | str | MLflow experiment name. | timestamp
--run_name | str | MLflow run name. | run
--log_dir | str | MLflow logging directory. | mlruns
--data_path | str | Dataset directory. | data
--dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None
--plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | []
--num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None
--dataset_wrapper_cache_dir | str | Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved after the first population. | None
--output_path | str | Directory where predictions are saved as <grid_name>_predictions.parquet. | data
--compile [MODE] | str | Enable torch.compile with the given mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. | None
--bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False
--tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False
--profiler | str | Enable a Lightning profiler (simple, advanced, pytorch). | None
--mp_context | str | DataLoader multiprocessing start method (spawn, fork, forkserver). Defaults to PyTorch's automatic choice. On Linux, spawn is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | None
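Predictions are written as <grid_name>_predictions.parquet under --output_path. The sketch below collects those files into a dictionary keyed by grid name, relying only on that naming pattern. prediction_files is a hypothetical helper.

```python
from pathlib import Path

SUFFIX = "_predictions.parquet"


def prediction_files(output_path="data"):
    """Map grid name -> predictions file found under output_path."""
    found = {}
    for path in sorted(Path(output_path).glob("*" + SUFFIX)):
        found[path.name[: -len(SUFFIX)]] = path
    return found
```

If pandas and a parquet engine (pyarrow or fastparquet) are installed, each file can then be loaded with pandas.read_parquet(path).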

Benchmarking Dataloader Throughput

gridfm_graphkit benchmark --config path/to/config.yaml

Arguments

Argument | Type | Description | Default
--config | str | Required. Path to configuration YAML file. | None
--data_path | str | Root dataset directory. | data
--epochs | int | Number of epochs to iterate through the train dataloader. | 3
--dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None
--dataset_wrapper_cache_dir | str | Directory for dataset wrapper disk cache. | None
--num_workers | int | Override data.workers from YAML. | None
--plugins | list[str] | Python packages to import for plugin registration. | []
--mp_context | str | DataLoader multiprocessing start method (spawn, fork, forkserver). Defaults to PyTorch's automatic choice. On Linux, spawn is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | None
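At its core, a dataloader throughput benchmark times repeated full passes over the loader. The sketch below illustrates that idea using only the standard library. It is not the code behind gridfm_graphkit benchmark, and benchmark_iterable is a hypothetical name. It counts samples with len(batch). Batch types without len() would need that line adapted.

```python
import time


def benchmark_iterable(make_loader, epochs=3):
    """Return samples/second for each full pass over a freshly created loader.

    make_loader: zero-argument callable returning an iterable of batches.
    Each batch must support len(), taken here as its number of samples.
    """
    rates = []
    for _ in range(epochs):
        start = time.perf_counter()
        n_samples = 0
        for batch in make_loader():
            n_samples += len(batch)
        elapsed = time.perf_counter() - start
        rates.append(n_samples / elapsed if elapsed > 0 else float("inf"))
    return rates


# Toy stand-in for a dataloader: 10 batches of 4 samples each.
rates = benchmark_iterable(lambda: [[0] * 4 for _ in range(10)], epochs=2)
print([f"{r:.0f} samples/s" for r in rates])
```

Reporting one rate per epoch, rather than a single average, keeps warm-up effects visible. Worker start-up or a cache being populated on the first pass often makes the first epoch slower.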

Use the built-in help for full command details:

gridfm_graphkit --help
gridfm_graphkit <command> --help

Running Tests

Unit and Integration Tests

Install the test dependencies first (if not already done):

pip install -e ".[dev,test]"

Run the full unit test suite:

pytest ./tests

Run the base set integration tests:

pytest ./integrationtests/test_base_set.py

Running Base Set Tests on an LSF Cluster (GPU)

To submit the base set integration tests as an interactive LSF job with GPU access, use bsub. Adjust the paths to match your environment:

bsub -gpu "num=1" \
     -n 16 \
     -R "rusage[mem=32GB] span[hosts=1]" \
     -Is \
     -J gridfm_base_set_tests \
     /bin/bash -c "
       cd /path/to/gridfm-graphkit && \
       export PATH=/path/to/cuda/bin:\$PATH \
               CUDA_HOME=/path/to/cuda \
               LD_LIBRARY_PATH=/path/to/cuda/lib64:\$LD_LIBRARY_PATH && \
       source /path/to/venv/bin/activate && \
       pytest ./integrationtests/test_base_set.py
     "

Key bsub options used above:

Option | Description
-gpu "num=1" | Request 1 GPU
-n 16 | Request 16 CPU slots
-R "rusage[mem=32GB] span[hosts=1]" | Reserve 32 GB of memory and place all slots on a single host
-Is | Run as an interactive shell session
-J <job_name> | Assign a name to the job
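When the same job is submitted repeatedly with different paths, the bsub call can be generated from a few parameters. The sketch below rebuilds the interactive command shown above as an argument list. Because no outer shell parses the list, $PATH needs no escaping and is expanded by the inner bash. bsub_command and its parameters are hypothetical. The options are only the ones documented above.

```python
import shlex


def bsub_command(repo, venv, cuda_home, job_name="gridfm_base_set_tests",
                 gpus=1, slots=16, mem="32GB",
                 test_path="./integrationtests/test_base_set.py"):
    """Build the interactive bsub call shown above as an argv list."""
    q = shlex.quote
    inner = " && ".join([
        f"cd {q(repo)}",
        f"export PATH={q(cuda_home + '/bin')}:$PATH CUDA_HOME={q(cuda_home)} "
        f"LD_LIBRARY_PATH={q(cuda_home + '/lib64')}:$LD_LIBRARY_PATH",
        f"source {q(venv + '/bin/activate')}",
        f"pytest {q(test_path)}",
    ])
    return [
        "bsub", "-gpu", f"num={gpus}", "-n", str(slots),
        "-R", f"rusage[mem={mem}] span[hosts=1]",
        "-Is", "-J", job_name,
        "/bin/bash", "-c", inner,
    ]
```

shlex.join(bsub_command(...)) prints a copy-pasteable line, and subprocess.run(bsub_command(...), check=True) submits it directly.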

Concrete example (adapt paths to your cluster setup):

bsub -gpu "num=1" \
     -n 16 \
     -R "rusage[mem=32GB] span[hosts=1]" \
     -Is \
     -J hpo_trial_190 \
     /bin/bash -c "
       cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && \
       export PATH=/opt/share/cuda-12.8.1/bin:\$PATH \
               CUDA_HOME=/opt/share/cuda-12.8.1 \
               LD_LIBRARY_PATH=/opt/share/cuda-12.8.1/lib64:\$LD_LIBRARY_PATH && \
       source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && \
       pytest ./integrationtests/test_base_set.py
     "