Instruction

February 9, 2026

SGS-GNN: A Supervised Graph Sparsifier for Graph Neural Networks

SGS-GNN is a novel supervised graph sparsification algorithm that learns the sampling probability distribution of edges and samples sparse subgraphs of a user-specified size to reduce the memory required by GNNs for inference tasks on large graphs.
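To make "samples sparse subgraphs of a user-specified size" concrete, here is a minimal standard-library sketch of one common way to draw k edges without replacement from per-edge probabilities: the Gumbel top-k trick. It is an illustration only. The repository's gumbel_softmax_sampling (which, per the pipeline diagrams, also takes a temperature and a degree bias) may work differently, and `sample_edges_gumbel_topk` is a hypothetical name.

```python
import math
import random


def sample_edges_gumbel_topk(edge_probs, k, seed=None):
    """Return k distinct edge indices, drawn without replacement in proportion to edge_probs.

    Gumbel top-k: add independent Gumbel(0, 1) noise to each log-probability
    and keep the k largest perturbed scores. Edges with probability 0 are never chosen.
    """
    if not 0 <= k <= len(edge_probs):
        raise ValueError("k must be between 0 and the number of edges")
    rng = random.Random(seed)
    scores = []
    for p in edge_probs:
        u = max(rng.random(), 1e-12)          # avoid log(0)
        gumbel = -math.log(-math.log(u))      # Gumbel(0, 1) sample
        scores.append(math.log(p) + gumbel if p > 0 else float("-inf"))
    top = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]
    return sorted(top)


if __name__ == "__main__":
    edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
    probs = [0.5, 0.1, 0.0, 0.4]
    keep = sample_edges_gumbel_topk(probs, k=2, seed=0)
    print([edges[i] for i in keep])  # a 2-edge sparse subgraph; (2, 3) can never appear
```

The sparse edge list is then just the original edges at the returned indices, so the subgraph size is exactly k regardless of the input graph's size.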

Installation

To install from scratch, use the following versions of the required packages:

Python version: 3.11
PyTorch version: 2.0.1
CUDA: 11.7
cuDNN: 8.6
PyTorch Geometric: 2.3.1

Alternatively, create the environment directly from the provided files: Conda packages are listed in environment.yml and pip packages in requirements.txt. Install them with:

conda env create -f environment.yml
pip install -r requirements.txt
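After installing, it can help to confirm that Python actually sees the pinned versions. The following is a minimal sketch using only the standard library's importlib.metadata. The distribution names torch and torch_geometric are assumptions about how the packages were installed, and CUDA/cuDNN are not checked because they are not pip distributions (inside Python, torch.version.cuda reports the CUDA version PyTorch was built with).

```python
import sys
from importlib import metadata

# Pins taken from this README. CUDA and cuDNN are not pip distributions,
# so they are not checked here.
PINNED = {"torch": "2.0.1", "torch_geometric": "2.3.1"}


def version_mismatches(pinned, lookup=metadata.version):
    """Return human-readable mismatch messages; an empty list means every pin matches."""
    problems = []
    for dist, wanted in pinned.items():
        try:
            found = lookup(dist)
        except metadata.PackageNotFoundError:
            problems.append(f"{dist}: not installed (want {wanted})")
            continue
        # Drop a local build tag such as "+cu117" before comparing.
        if found.split("+", 1)[0] != wanted:
            problems.append(f"{dist}: found {found}, want {wanted}")
    return problems


if __name__ == "__main__":
    if sys.version_info[:2] != (3, 11):
        print(f"python: found {sys.version_info[0]}.{sys.version_info[1]}, want 3.11")
    for message in version_mismatches(PINNED):
        print(message)
```

The `lookup` parameter exists only so the comparison can be exercised without the real packages installed.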

Run

python main.py --dataset SmallCora --mode learned --runs 3 --epochs 250 --save_csv True --edge_mlp_type GCN --GNN GCN --log True --sparse_edge_mlp True --conditional True --reg1 True --reg2 True --hybrid_checkpoint --pipeline hybrid

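Demo run

The script below runs all three pipelines (two_pass, straight_through, hybrid) on Reddit for 10 epochs with a shared set of arguments. It writes each run's output to pipeline_<name>.log in the repository root (the parent of the script's directory) and then prints the [stats] lines from that log.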

#!/usr/bin/env bash
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"

COMMON_ARGS=(
  --dataset Reddit
  --mode learned
  --runs 1
  --epochs 10
  --save_csv True
  --edge_mlp_type GCN
  --GNN GCN
  --log False
  --sparse_edge_mlp True
  --conditional True
  --reg1 True
  --reg2 True
  --stats True
  --hybrid_checkpoint True
)

run_pipeline () {
  local pipeline="$1"
  local log_file="${ROOT_DIR}/pipeline_${pipeline}.log"
  echo "=== Running pipeline: ${pipeline} ==="
  (cd "${ROOT_DIR}" && python main.py "${COMMON_ARGS[@]}" --pipeline "${pipeline}") | tee "${log_file}"
  echo "--- Stats (${pipeline}) ---"
  # grep exits non-zero when nothing matches; "|| true" keeps set -e from aborting.
  grep -n '\[stats\]' "${log_file}" || true
  echo ""
}

run_pipeline "two_pass"
run_pipeline "straight_through"
run_pipeline "hybrid"

Diagrams (the hybrid pipeline is the SGS-GNN version described in the paper)

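The Mermaid diagrams below show the three pipelines. Solid arrows denote the forward pass and dashed arrows denote backward (gradient) flow. The pipelines differ mainly in how gradients reach EdgeProbMLP: through a straight-through estimator, through index_select on probabilities that were detached before sampling (hybrid), or through a second, gradient-enabled pass over the sampled edges only (two-pass).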

Straight‑Through

flowchart LR
  X["batch.x"] --> E["EdgeProbMLP (grad)"]
  EI["batch.edge_index"] --> E
  E --> P["edge_probs_full"]

  P --> GS["gumbel_softmax_sampling (straight-through)"]
  EI --> GS
  GS --> SEI["sampled_edge_index"]
  GS --> SEW["sampled_edge_weight (straight-through)"]

  X --> G["GNN"]
  SEI --> G
  SEW --> G

  G --> L["loss"]
  SEW --> R1["reg1 BCE"]
  SEI --> R1
  G --> R2["reg2 consistency"]
  SEW --> R2
  R1 --> L
  R2 --> L

  L -.-> G
  L -.-> E
  L -.-> GS
  GS -.-> P
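A minimal PyTorch sketch of the straight-through idea shown above: the forward pass uses a hard 0/1 selection of k edges, while the backward pass uses the gradient of a Gumbel-softmax relaxation, so gradient reaches every edge probability (including unsampled ones). The top-k scheme, `tau`, and the name `straight_through_topk` are assumptions, not the repository's gumbel_softmax_sampling; PyTorch's own `torch.nn.functional.gumbel_softmax(..., hard=True)` applies the same trick to a single one-hot draw.

```python
import torch


def straight_through_topk(edge_probs: torch.Tensor, k: int, tau: float = 1.0):
    """Select k edges: hard mask in the forward pass, soft gradients in the backward pass."""
    logits = torch.log(edge_probs.clamp_min(1e-12))
    # Gumbel(0, 1) noise; the clamp keeps both logs finite.
    u = torch.rand_like(logits).clamp(1e-10, 1.0 - 1e-10)
    gumbel = -torch.log(-torch.log(u))
    soft = torch.softmax((logits + gumbel) / tau, dim=0)   # relaxed, differentiable
    idx = torch.topk(soft, k).indices                      # hard choice of k edges
    hard = torch.zeros_like(soft).scatter(0, idx, 1.0)     # 0/1 mask, no gradient path
    # Straight-through: the value equals `hard`, the gradient is that of `soft`.
    st = hard - soft.detach() + soft
    return idx, st.index_select(0, idx)                    # sampled_edge_weight, all ~1.0


if __name__ == "__main__":
    torch.manual_seed(0)
    raw = torch.randn(6, requires_grad=True)   # stand-in for EdgeProbMLP outputs
    idx, w = straight_through_topk(torch.sigmoid(raw), k=3)
    w.sum().backward()
    print(idx.tolist(), raw.grad)               # every entry of raw.grad is nonzero
```

Because the softmax couples all edges, the backward pass touches the full edge set, which is why this variant keeps the entire relaxation in the autograd graph.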

Hybrid

flowchart LR
  X["batch.x"] --> E["EdgeProbMLP (grad, optional checkpoint)"]
  EI["batch.edge_index"] --> E
  E --> P["edge_probs_full"]

  P -->|"detach"| GS["gumbel_softmax_sampling"]
  EI --> GS
  GS --> SEI["sampled_edge_index"]

  P --> IDX["index_select by sampled_edge_index"]
  IDX --> SEW["edge_probs_sampled"]

  X --> G["GNN"]
  SEI --> G
  SEW --> G

  G --> L["loss"]
  SEW --> R1["reg1 BCE"]
  SEI --> R1
  G --> R2["reg2 consistency"]
  SEW --> R2
  R1 --> L
  R2 --> L

  L -.-> G
  L -.-> E
  L -.-> IDX
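The hybrid pipeline can be sketched as follows: sampling runs on a detached copy of edge_probs_full, and the edge weights are gathered back from the original, differentiable tensor with index_select. This is a minimal sketch assuming a Gumbel top-k sampler; `hybrid_sample` is a hypothetical name and the temperature and degree bias from the full diagram are omitted.

```python
import torch


def hybrid_sample(edge_probs_full: torch.Tensor, k: int):
    """Sample on detached probabilities, then gather differentiable weights."""
    probs = edge_probs_full.detach()                  # sampler stays outside the autograd graph
    logits = torch.log(probs.clamp_min(1e-12))
    u = torch.rand_like(logits).clamp(1e-10, 1.0 - 1e-10)
    idx = torch.topk(logits - torch.log(-torch.log(u)), k).indices  # Gumbel top-k
    # index_select keeps the gradient path, but only for the k sampled entries.
    edge_probs_sampled = edge_probs_full.index_select(0, idx)
    return idx, edge_probs_sampled


if __name__ == "__main__":
    torch.manual_seed(0)
    raw = torch.randn(8, requires_grad=True)    # stand-in for EdgeProbMLP outputs
    idx, w = hybrid_sample(torch.sigmoid(raw), k=3)
    w.sum().backward()
    print(idx.tolist(), raw.grad)                # zero everywhere except at idx
```

Unlike the straight-through variant, gradients reach only the sampled entries of edge_probs_full; the MLP itself still runs once over all edges with gradients enabled, which is where the diagram's optional checkpoint would apply.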

Two‑Pass

flowchart LR
  X["batch.x"] --> E1["EdgeProbMLP pass1 (no grad)"]
  EI["batch.edge_index"] --> E1
  E1 --> P["edge_probs_full (detached)"]

  P --> GS["gumbel_softmax_sampling"]
  EI --> GS
  GS --> SEI["sampled_edge_index"]

  X --> E2["EdgeProbMLP pass2 (grad, sampled only)"]
  SEI --> E2
  E2 --> SEW["edge_probs_sampled"]

  X --> G["GNN"]
  SEI --> G
  SEW --> G

  G --> L["loss"]
  SEW --> R1["reg1 BCE"]
  SEI --> R1
  G --> R2["reg2 consistency"]
  SEW --> R2
  R1 --> L
  R2 --> L

  L -.-> G
  L -.-> E2
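The two-pass pipeline can be sketched like this: pass 1 scores every edge under torch.no_grad(), sampling picks k edges, and pass 2 re-runs the scorer with gradients on those k edges only. `EdgeScorer` is a hypothetical stand-in for EdgeProbMLP (concatenated endpoint features, a small MLP and a sigmoid), and the Gumbel top-k sampler is an assumption.

```python
import torch
import torch.nn as nn


class EdgeScorer(nn.Module):
    """Hypothetical stand-in for EdgeProbMLP: one probability per edge from endpoint features."""

    def __init__(self, in_dim: int, hidden: int = 16):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(2 * in_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, x, edge_index):
        src, dst = edge_index
        pair = torch.cat([x[src], x[dst]], dim=-1)
        return torch.sigmoid(self.mlp(pair)).squeeze(-1)


def two_pass_sample(model, x, edge_index, k):
    # Pass 1: score every edge with autograd disabled, so no activations are kept.
    with torch.no_grad():
        probs_full = model(x, edge_index)
    logits = torch.log(probs_full.clamp_min(1e-12))
    u = torch.rand_like(logits).clamp(1e-10, 1.0 - 1e-10)
    idx = torch.topk(logits - torch.log(-torch.log(u)), k).indices  # Gumbel top-k
    sampled_edge_index = edge_index[:, idx]
    # Pass 2: recompute scores with gradients, for the k sampled edges only.
    edge_probs_sampled = model(x, sampled_edge_index)
    return sampled_edge_index, edge_probs_sampled
```

The trade-off is memory for compute: the autograd graph covers only k edges instead of the full edge set, at the cost of evaluating the scorer a second time on the sampled edges.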

Full Diagram

The full Mermaid diagrams below add separate subgraphs for the random baseline branch, the regularizers, the conditional update gate, and the optimizer steps. Solid arrows denote the forward pass and dashed arrows denote backward (gradient) flow.

Straight‑Through

flowchart LR
  subgraph Forward_Learned["Forward (learned path)"]
    X["batch.x"] --> E["EdgeProbMLP (grad)"]
    EI["batch.edge_index"] --> E
    E --> P["edge_probs_full"]

    P --> GS["gumbel_softmax_sampling (straight-through, temp, degree bias)"]
    EI --> GS
    GS --> SEI["sampled_edge_index"]
    GS --> SEW["sampled_edge_weight (straight-through)"]

    X --> G["GNN (learned edges)"]
    SEI --> G
    SEW --> G
    G --> Lout["learned_out"]
  end

  subgraph Random_Baseline["Random baseline (conditional)"]
    RP["F.softmax(batch.prob)"] --> RSEL["random_edge_sample"]
    RSEL --> RSEI["random_sampled_edge_index"]
    X --> RG["GNN (random edges)"]
    RSEI --> RG
    RG --> Rout["random_out"]
  end

  subgraph Regularizers["Regularizers (learned path)"]
    SEW --> R1["reg1 BCE"]
    SEI --> R1
    Lout --> R2["reg2 consistency"]
    SEW --> R2
  end

  subgraph Gate["Conditional update gate"]
    Lout --> F1L["F1(learned_out)"]
    Rout --> F1R["F1(random_out)"]
    F1L --> CMP["compare F1"]
    F1R --> CMP
    CMP -->|learned wins| Llearn["loss(learned_out + regs)"]
    CMP -->|random wins| Lrand["loss(random_out)"]
  end

  R1 --> Llearn
  R2 --> Llearn

  subgraph Optimizers["Optimizers"]
    Llearn --> Oe["optimizer_edge_prob.step()"]
    Llearn --> Og["optimizer_gnn.step()"]
    Lrand --> Ogr["optimizer_gnn.step()"]
  end

  Llearn -.-> G
  Llearn -.-> E
  Llearn -.-> GS
  Lrand -.-> RG
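All three full diagrams share the same conditional update gate: the F1 scores of the learned and random branches decide which loss is used and which optimizers step. A minimal sketch follows, assuming macro-averaged F1 over predicted class labels and that ties go to the learned branch; the diagram specifies neither, and `macro_f1` and `conditional_update` are hypothetical names.

```python
def macro_f1(y_true, y_pred):
    """Macro-averaged F1 over every label that appears in y_true or y_pred."""
    labels = sorted(set(y_true) | set(y_pred))
    if not labels:
        return 0.0
    scores = []
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


def conditional_update(y_true, learned_pred, random_pred):
    """Return the branch that supplies the loss and the optimizers that step.

    Assumption: ties favor the learned branch.
    """
    if macro_f1(y_true, learned_pred) >= macro_f1(y_true, random_pred):
        # Learned wins: loss(learned_out + regs) updates both the edge scorer and the GNN.
        return "learned", ("optimizer_edge_prob", "optimizer_gnn")
    # Random wins: loss(random_out) updates only the GNN.
    return "random", ("optimizer_gnn",)
```

In a training loop, only the optimizers named in the returned tuple would call step(), which matches the Optimizers subgraph: the edge-probability optimizer never steps on the random branch's loss.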

Hybrid

flowchart LR
  subgraph Forward_Learned["Forward (learned path)"]
    X["batch.x"] --> E["EdgeProbMLP (grad, optional checkpoint)"]
    EI["batch.edge_index"] --> E
    E --> P["edge_probs_full"]

    P -->|"detach"| GS["gumbel_softmax_sampling (temp, degree bias)"]
    EI --> GS
    GS --> SEI["sampled_edge_index"]

    P --> IDX["index_select by sampled_edge_index"]
    IDX --> SEW["edge_probs_sampled"]

    X --> G["GNN (learned edges)"]
    SEI --> G
    SEW --> G
    G --> Lout["learned_out"]
  end

  subgraph Random_Baseline["Random baseline (conditional)"]
    RP["F.softmax(batch.prob)"] --> RSEL["random_edge_sample"]
    RSEL --> RSEI["random_sampled_edge_index"]
    X --> RG["GNN (random edges)"]
    RSEI --> RG
    RG --> Rout["random_out"]
  end

  subgraph Regularizers["Regularizers (learned path)"]
    SEW --> R1["reg1 BCE"]
    SEI --> R1
    Lout --> R2["reg2 consistency"]
    SEW --> R2
  end

  subgraph Gate["Conditional update gate"]
    Lout --> F1L["F1(learned_out)"]
    Rout --> F1R["F1(random_out)"]
    F1L --> CMP["compare F1"]
    F1R --> CMP
    CMP -->|learned wins| Llearn["loss(learned_out + regs)"]
    CMP -->|random wins| Lrand["loss(random_out)"]
  end

  R1 --> Llearn
  R2 --> Llearn

  subgraph Optimizers["Optimizers"]
    Llearn --> Oe["optimizer_edge_prob.step()"]
    Llearn --> Og["optimizer_gnn.step()"]
    Lrand --> Ogr["optimizer_gnn.step()"]
  end

  Llearn -.-> G
  Llearn -.-> E
  Llearn -.-> IDX
  Lrand -.-> RG
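The random baseline branch turns batch.prob into a distribution with F.softmax and samples edges from it. A sketch with torch.multinomial, assuming batch.prob holds one unnormalized score per edge and that the baseline draws k edges without replacement (both assumptions; `sample_random_baseline` is a hypothetical name, not the repository's random_edge_sample):

```python
import torch
import torch.nn.functional as F


def sample_random_baseline(edge_index, prior_scores, k, generator=None):
    """Draw k distinct edges (columns of edge_index) with probabilities softmax(prior_scores)."""
    dist = F.softmax(prior_scores, dim=0)       # one probability per edge
    idx = torch.multinomial(dist, k, replacement=False, generator=generator)
    return edge_index[:, idx]                   # random_sampled_edge_index, shape (2, k)


if __name__ == "__main__":
    ei = torch.tensor([[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0]])
    g = torch.Generator().manual_seed(0)
    print(sample_random_baseline(ei, torch.zeros(6), 4, generator=g))  # uniform prior
```

With all-zero scores the prior is uniform; a score of -inf gives an edge probability 0, so it is never drawn.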

Two‑Pass

flowchart LR
  subgraph Forward_Learned["Forward (learned path)"]
    X["batch.x"] --> E1["EdgeProbMLP pass1 (no grad)"]
    EI["batch.edge_index"] --> E1
    E1 --> P["edge_probs_full (detached)"]

    P --> GS["gumbel_softmax_sampling (temp, degree bias)"]
    EI --> GS
    GS --> SEI["sampled_edge_index"]

    X --> E2["EdgeProbMLP pass2 (grad, sampled only)"]
    SEI --> E2
    E2 --> SEW["edge_probs_sampled"]

    X --> G["GNN (learned edges)"]
    SEI --> G
    SEW --> G
    G --> Lout["learned_out"]
  end

  subgraph Random_Baseline["Random baseline (conditional)"]
    RP["F.softmax(batch.prob)"] --> RSEL["random_edge_sample"]
    RSEL --> RSEI["random_sampled_edge_index"]
    X --> RG["GNN (random edges)"]
    RSEI --> RG
    RG --> Rout["random_out"]
  end

  subgraph Regularizers["Regularizers (learned path)"]
    SEW --> R1["reg1 BCE"]
    SEI --> R1
    Lout --> R2["reg2 consistency"]
    SEW --> R2
  end

  subgraph Gate["Conditional update gate"]
    Lout --> F1L["F1(learned_out)"]
    Rout --> F1R["F1(random_out)"]
    F1L --> CMP["compare F1"]
    F1R --> CMP
    CMP -->|learned wins| Llearn["loss(learned_out + regs)"]
    CMP -->|random wins| Lrand["loss(random_out)"]
  end

  R1 --> Llearn
  R2 --> Llearn

  subgraph Optimizers["Optimizers"]
    Llearn --> Oe["optimizer_edge_prob.step()"]
    Llearn --> Og["optimizer_gnn.step()"]
    Lrand --> Ogr["optimizer_gnn.step()"]
  end

  Llearn -.-> G
  Llearn -.-> E2
  Lrand -.-> RG
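The regularizer subgraph feeds the sampled edge weights and sampled edge indices into a BCE term (reg1), but the README does not define its target. One plausible reading, sketched here purely as an assumption, is an edge-homophily target: 1 when both endpoints of a sampled edge carry the same training label and 0 otherwise, using only edges whose endpoint labels are both visible. `reg1_bce_sketch`, the homophily target, and the train-mask handling are all hypothetical.

```python
import torch
import torch.nn.functional as F


def reg1_bce_sketch(edge_probs_sampled, sampled_edge_index, y, train_mask):
    """Hypothetical reg1: BCE pushing sampled-edge probabilities toward an edge-homophily target."""
    src, dst = sampled_edge_index
    known = train_mask[src] & train_mask[dst]      # both endpoint labels are visible
    if not known.any():
        # No supervised edges in this sample: zero loss that still has a gradient path.
        return edge_probs_sampled.sum() * 0.0
    target = (y[src] == y[dst]).float()            # 1.0 if the endpoints share a label
    return F.binary_cross_entropy(edge_probs_sampled[known], target[known])


if __name__ == "__main__":
    p = torch.tensor([0.9, 0.2, 0.5], requires_grad=True)
    sei = torch.tensor([[0, 1, 2], [1, 2, 3]])
    y = torch.tensor([0, 0, 1, 1])
    print(reg1_bce_sketch(p, sei, y, torch.ones(4, dtype=torch.bool)))
```

Restricting the target to training nodes avoids leaking validation or test labels into the edge scorer.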