
March 16, 2026

jepa-rs

Joint Embedding Predictive Architecture in Rust



An alpha-stage Rust implementation of JEPA (Joint Embedding Predictive Architecture), the self-supervised learning framework from Yann LeCun and Meta AI for learning world models that predict in representation space rather than pixel space.

jepa-rs provides modular, backend-agnostic building blocks for I-JEPA (images), V-JEPA (video), C-JEPA (causal object-centric), and hierarchical world models, built on top of the burn deep learning framework. It includes a CLI, an interactive TUI dashboard, a browser demo crate for local experimentation, safetensors checkpoint loading, ONNX metadata inspection, and a pretrained model registry for Facebook Research models.

                    ┌──────────────┐
                    │   Context    │──── Encoder ────┐
                    │   (visible)  │                 │
   Image/Video ─────┤              │         ┌───────▼───────┐
                    │   Target     │         │   Predictor   │──── predicted repr
                    │   (masked)   │──┐      └───────────────┘          │
                    └──────────────┘  │                                 │
                                      │      ┌───────────────┐          │
                                      └──────│ Target Encoder│── target repr
                                        EMA  │   (frozen)    │          │
                                             └───────────────┘          │
                                             ┌───────────────┐          │
                                             │  Energy Loss  │◄─────────┘
                                             └───────────────┘

Why jepa-rs?

| | jepa-rs | Python (PyTorch) |
|---|---|---|
| Runtime | Native binary, no Python or CUDA dependency | Requires Python and PyTorch (plus CUDA for GPU) |
| Inference | Safetensors checkpoint loading, ONNX metadata | PyTorch runtime |
| Memory | Rust ownership, no GC pauses | Python GC + PyTorch allocator |
| Backend | Any burn backend (CPU, GPU, WebGPU, WASM) | CUDA-centric |
| Type safety | Compile-time tensor rank checks | Runtime shape errors |
| Deployment | Single native binary | Docker + Python environment |

Pretrained Models

jepa-rs supports loading official Facebook Research pretrained JEPA models:

| Model | Architecture | Params | Resolution | Dataset | Weights |
|---|---|---|---|---|---|
| I-JEPA ViT-H/14 | ViT-Huge, patch 14 | 632M | 224x224 | ImageNet-1K | Download, HuggingFace |
| I-JEPA ViT-H/16-448 | ViT-Huge, patch 16 | 632M | 448x448 | ImageNet-1K | Download, HuggingFace |
| I-JEPA ViT-H/14 | ViT-Huge, patch 14 | 632M | 224x224 | ImageNet-22K | Download |
| I-JEPA ViT-G/16 | ViT-Giant, patch 16 | 1.0B | 224x224 | ImageNet-22K | Download |
| V-JEPA ViT-L/16 | ViT-Large, patch 16 | 304M | 224x224 | VideoMix2M | Download |
| V-JEPA ViT-H/16 | ViT-Huge, patch 16 | 632M | 224x224 | VideoMix2M | Download |

Quick Start

Installation

# Cargo.toml
[dependencies]
jepa-core   = "0.1.0"
jepa-vision = "0.1.0"
jepa-compat = "0.1.0"  # For ONNX + checkpoint loading

CLI

The jepa binary provides a unified CLI for the workspace:

# Install the CLI from crates.io
cargo install jepa

# Or install from the local workspace checkout
cargo install --path crates/jepa

# Launch the interactive TUI dashboard
jepa

# List pretrained models in the registry
jepa models

# Inspect a safetensors checkpoint
jepa inspect model.safetensors

# Analyze checkpoint with key remapping
jepa checkpoint model.safetensors --keymap ijepa --verbose

# Launch a training run
jepa train --preset vit-base-16 --steps 10 --batch-size 1 --lr 1e-3

# Train from a normal image directory tree with deterministic resize/crop/normalize
jepa train --preset vit-base-16 --steps 100 --batch-size 4 \
  --dataset-dir ./images/train --resize 256 --crop-size 224 --shuffle

# Train from a safetensors image tensor dataset [N, C, H, W]
jepa train --preset vit-base-16 --steps 100 --batch-size 1 \
  --dataset train.safetensors --dataset-key images

# Encode inputs through a safetensors checkpoint
jepa encode --model model.safetensors --preset vit-base-16

# Or through an ONNX model
jepa encode --model model.onnx --height 224 --width 224

The CLI train command runs strict masked-image optimization (tokens are filtered before encoder self-attention) with AdamW and an EMA target encoder. Each run uses exactly one input source:

  • --dataset-dir <PATH> for a recursive image-folder dataset (jpg, jpeg, png, webp) with decode, RGB conversion, shorter-side resize, center crop, CHW tensor conversion, and normalization
  • --dataset <FILE> --dataset-key <KEY> for a safetensors image tensor shaped [N, C, H, W]
  • no dataset flags for the synthetic random-tensor fallback
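To make the --dataset-dir discovery step concrete, here is a std-only sketch of a recursive image-folder walk. It is not the jepa crate's loader: the case-insensitive extension match, the sorted output, and the `discover_images` name are assumptions made for illustration.

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

// Extensions listed in this README; matching them case-insensitively is an assumption.
const IMAGE_EXTENSIONS: [&str; 4] = ["jpg", "jpeg", "png", "webp"];

fn is_image(path: &Path) -> bool {
    path.extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| IMAGE_EXTENSIONS.iter().any(|known| known.eq_ignore_ascii_case(ext)))
        .unwrap_or(false)
}

/// Walks `root` and all nested subdirectories, returning image paths in sorted order.
fn discover_images(root: &Path) -> io::Result<Vec<PathBuf>> {
    let mut found = Vec::new();
    let mut pending = vec![root.to_path_buf()];
    while let Some(dir) = pending.pop() {
        for entry in fs::read_dir(&dir)? {
            let path = entry?.path();
            if path.is_dir() {
                pending.push(path); // visit nested folders on a later iteration
            } else if is_image(&path) {
                found.push(path);
            }
        }
    }
    found.sort(); // stable order regardless of filesystem iteration order
    Ok(found)
}
```

Sorting the paths gives a stable order to shuffle from, which is one simple way to keep runs repeatable; decoding and preprocessing would then operate on the returned paths.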

When --crop-size is omitted, image-folder preprocessing uses the preset's image size; when --mean and --std are omitted, it uses the ImageNet RGB normalization statistics. Dataset loading is currently single-threaded. jepa encode runs the real encoder weights for .safetensors and .onnx inputs; other file extensions fall back to the preset demo path.
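The resize, crop, and normalization arithmetic can be sketched in a few lines. This is a hedged illustration rather than the crate's code: the round-to-nearest rule and the function names are assumptions, while the mean and std values are the standard ImageNet RGB statistics.

```rust
// Standard ImageNet RGB statistics, the defaults when --mean/--std are omitted.
const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Scales (width, height) so the shorter side equals `target`, preserving aspect ratio.
/// Rounding to the nearest pixel is an assumption of this sketch.
fn shorter_side_resize(width: u32, height: u32, target: u32) -> (u32, u32) {
    let scale = f64::from(target) / f64::from(width.min(height));
    let w = (f64::from(width) * scale).round() as u32;
    let h = (f64::from(height) * scale).round() as u32;
    (w, h)
}

/// Top-left corner (x, y) of a centered `crop` x `crop` window.
fn center_crop_origin(width: u32, height: u32, crop: u32) -> (u32, u32) {
    assert!(width >= crop && height >= crop, "crop is larger than the resized image");
    ((width - crop) / 2, (height - crop) / 2)
}

/// Maps one 8-bit RGB pixel to normalized values: (x / 255 - mean) / std, per channel.
fn normalize_pixel(rgb: [u8; 3]) -> [f32; 3] {
    std::array::from_fn(|c| (f32::from(rgb[c]) / 255.0 - IMAGENET_MEAN[c]) / IMAGENET_STD[c])
}
```

For a 640x480 photo with --resize 256 --crop-size 224, the shorter side (480) becomes 256, the width becomes 341, and the 224x224 crop starts at (58, 16).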

Runnable Examples

The jepa crate ships runnable examples under crates/jepa/examples/ that exercise the real training command rather than a mocked CLI path:

# Create a tiny recursive image-folder dataset under target/example-data/jepa/
cargo run -p jepa --example prepare_demo_image_folder

# Train for 2 steps on that generated image-folder dataset
cargo run -p jepa --example train_image_folder_demo

# Train for 2 steps with the synthetic fallback path
cargo run -p jepa --example train_synthetic_demo

The image-folder example deliberately uses a very small generated dataset (6 PNG files across nested subdirectories). That is enough for a meaningful smoke demo of recursive dataset discovery, decode, resize, crop, normalize, batching, masking, optimizer updates, and EMA without checking a large image corpus into git. It is not large enough to demonstrate real representation learning quality; it is an execution demo, not a benchmark dataset.

The TUI's Training tab includes these demos as a guided demo runner. Launch jepa, switch to tab 3, choose a demo with j/k, and press Enter to run it. The panel streams real run logs, step metrics, loss and energy charts, and a short interpretation of the run.

The TUI Inference tab (tab 4) adds a separate guided walkthrough for encoder inference. It runs deterministic demo image patterns through a preset ViT, streams phase changes, per-sample latency, and embedding statistics, and explains what the representation telemetry means. The walkthrough is intentionally a pipeline demo, not a pretrained semantic benchmark.

If you want to run the CLI directly after generating the demo dataset:

cargo run -p jepa -- train --preset vit-small-16 --steps 2 --batch-size 2 \
  --dataset-dir target/example-data/jepa/demo-image-folder \
  --resize 256 --crop-size 224 --shuffle --dataset-limit 6

Loading SafeTensors Checkpoints

use jepa_compat::safetensors::load_checkpoint;
use jepa_compat::keymap::ijepa_vit_keymap;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mappings = ijepa_vit_keymap();
    let checkpoint = load_checkpoint("model.safetensors", &mappings)?;

    println!("Loaded {} tensors", checkpoint.len());
    for key in checkpoint.keys() {
        println!("  {}: {:?}", key, checkpoint.get(key).unwrap().shape);
    }
    Ok(())
}
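The loader above consumes the safetensors format, whose layout is simple: an 8-byte little-endian header length, a UTF-8 JSON header describing each tensor's dtype, shape, and data offsets, then the raw tensor bytes. The std-only sketch below reads just the header; it illustrates the file format and is not jepa_compat's implementation. The `read_safetensors_header` name and the 100 MiB header cap are choices made for this sketch.

```rust
use std::io::{self, Read};

/// Reads the JSON header of a safetensors stream and returns it as text, together
/// with the byte offset at which tensor data begins.
fn read_safetensors_header<R: Read>(mut reader: R) -> io::Result<(String, u64)> {
    let mut len_bytes = [0u8; 8];
    reader.read_exact(&mut len_bytes)?;
    let header_len = u64::from_le_bytes(len_bytes);
    // Refuse absurd lengths before allocating (an arbitrary limit for this sketch).
    if header_len > 100 * 1024 * 1024 {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large"));
    }
    let mut header = vec![0u8; header_len as usize];
    reader.read_exact(&mut header)?;
    let text = String::from_utf8(header)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
    Ok((text, 8 + header_len))
}
```

A real loader would parse the JSON (for example with serde_json) and slice each tensor out of the data section using its data_offsets, which are relative to the returned offset.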

Building JEPA Models from Scratch

use burn::prelude::*;
use burn_ndarray::NdArray;
use jepa_core::masking::{BlockMasking, MaskingStrategy};
use jepa_core::types::InputShape;
use jepa_vision::image::IJepaConfig;
use jepa_vision::vit::VitConfig;
use rand_chacha::rand_core::SeedableRng; // brings ChaCha8Rng::seed_from_u64 into scope

type B = NdArray<f32>;

fn main() {
    let device = burn_ndarray::NdArrayDevice::Cpu;

    // Configure I-JEPA with ViT-Huge/14 (matches Facebook pretrained)
    let config = IJepaConfig {
        encoder: VitConfig::vit_huge_patch14(),
        predictor: jepa_vision::image::TransformerPredictorConfig {
            encoder_embed_dim: 1280,
            predictor_embed_dim: 384,
            num_layers: 12,
            num_heads: 12,
            max_target_len: 256,
        },
    };
    let model = config.init::<B>(&device);

    // Generate masks (I-JEPA block masking)
    let shape = InputShape::Image { height: 16, width: 16 }; // 224/14 = 16
    let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
    let masking = BlockMasking {
        num_targets: 4,
        target_scale: (0.15, 0.2),
        target_aspect_ratio: (0.75, 1.5),
    };
    let mask = masking.generate_mask(&shape, &mut rng);

    println!("Context tokens: {}, Target tokens: {}",
             mask.context_indices.len(), mask.target_indices.len());
}
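The BlockMasking parameters can be read as follows: each target block draws a scale from target_scale and an aspect ratio from target_aspect_ratio, derives a height and width from them, and is placed at a random position on the patch grid. The dependency-free sketch below is an assumption-laden illustration, not jepa-core's code: I-JEPA itself samples a separate large context block and removes target overlap from it, whereas here the context is simply every uncovered patch, and the small SplitMix64 generator stands in for ChaCha8Rng.

```rust
/// Tiny deterministic PRNG (SplitMix64) standing in for ChaCha8Rng in this sketch.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
    /// Uniform f64 in [lo, hi).
    fn uniform(&mut self, lo: f64, hi: f64) -> f64 {
        let unit = (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64;
        lo + (hi - lo) * unit
    }
    /// Uniform integer in [0, n).
    fn below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}

/// Samples `num_targets` rectangular blocks on a `height` x `width` patch grid and
/// returns (context_indices, target_indices) as sorted row-major patch indices.
fn block_mask(
    height: usize,
    width: usize,
    num_targets: usize,
    target_scale: (f64, f64),
    aspect_ratio: (f64, f64),
    rng: &mut SplitMix64,
) -> (Vec<usize>, Vec<usize>) {
    let mut is_target = vec![false; height * width];
    for _ in 0..num_targets {
        // Block area as a fraction of the grid, split into h x w by the aspect ratio.
        let area = rng.uniform(target_scale.0, target_scale.1) * (height * width) as f64;
        let ratio = rng.uniform(aspect_ratio.0, aspect_ratio.1);
        let h = ((area * ratio).sqrt().round() as usize).clamp(1, height);
        let w = ((area / ratio).sqrt().round() as usize).clamp(1, width);
        let top = rng.below(height - h + 1);
        let left = rng.below(width - w + 1);
        for row in top..top + h {
            for col in left..left + w {
                is_target[row * width + col] = true;
            }
        }
    }
    // Simplification: the context is every patch not covered by a target block.
    let (targets, context): (Vec<usize>, Vec<usize>) =
        (0..height * width).partition(|&i| is_target[i]);
    (context, targets)
}
```

With the README's settings on a 16x16 grid, each block covers roughly 38 to 51 patches before overlaps, so four blocks leave a substantial context.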

Browse Available Models

use jepa_compat::registry::{list_models, find_model};

fn main() {
    for model in list_models() {
        println!("{}: {} ({}, {})",
            model.name,
            model.param_count_human(),
            model.architecture,
            model.pretrained_on);
    }

    // Search for a specific model
    if let Some(m) = find_model("vit-h/14") {
        println!("\nFound: {} with {} patches",
            m.name, m.num_patches());
    }
}
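The patch counts reported by num_patches() follow from simple grid arithmetic. A minimal sketch, assuming square inputs and non-overlapping patches; for video it also assumes 16-frame clips cut into 2-frame tubelets, which reflects the usual V-JEPA setup rather than anything this README states. The function names are illustrative.

```rust
/// Patch grid (rows, cols) for a square image split into non-overlapping square patches.
fn patch_grid(resolution: usize, patch: usize) -> (usize, usize) {
    assert_eq!(resolution % patch, 0, "resolution must be divisible by the patch size");
    (resolution / patch, resolution / patch)
}

/// Token count for a clip of `frames` frames cut into `tubelet` x `patch` x `patch` tubelets.
fn tubelet_count(frames: usize, tubelet: usize, resolution: usize, patch: usize) -> usize {
    assert_eq!(frames % tubelet, 0, "frame count must be divisible by the tubelet depth");
    let (rows, cols) = patch_grid(resolution, patch);
    (frames / tubelet) * rows * cols
}
```

For example, ViT-H/14 at 224x224 gives a 16x16 grid (256 tokens), ViT-H/16-448 gives 28x28 (784 tokens), and a 16-frame 224x224 clip with 2x16x16 tubelets gives 8 x 14 x 14 = 1568 tokens.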

Architecture

jepa-rs/
├── jepa-core        Core traits, tensor wrappers, masking, energy, EMA
│   ├── Encoder          Trait for context/target encoders
│   ├── Predictor        Trait for latent predictors
│   ├── EnergyFn         L2, Cosine, SmoothL1 energy functions
│   ├── MaskingStrategy  Block, MultiBlock, Spatiotemporal, Object masking
│   ├── CollapseReg      VICReg, BarlowTwins collapse prevention
│   └── EMA              Exponential moving average with cosine schedule
│
├── jepa-vision      Vision transformers and JEPA models
│   ├── VitEncoder       ViT-S/B/L/H/G with 2D RoPE
│   ├── IJepa            I-JEPA pipeline (image)
│   ├── VJepa            V-JEPA pipeline (video, 3D tubelets)
│   ├── SlotAttention    Slot attention encoder for object-centric representations
│   └── Predictor        Transformer-based cross-attention predictor
│
├── jepa-world       World models and planning
│   ├── ActionPredictor  Action-conditioned latent prediction
│   ├── ObjectDynamics   Transformer dynamics predictor for object slots
│   ├── Planner          Random shooting planner with cost functions
│   ├── HierarchicalJepa Multi-level H-JEPA
│   └── ShortTermMemory  Sliding-window memory for temporal context
│
├── jepa-train       Training orchestration
│   ├── TrainConfig      Learning rate schedules, EMA config
│   ├── JepaComponents   Generic forward step orchestration
│   ├── CausalJepa       C-JEPA training loop (frozen encoder, object masking)
│   └── CheckpointMeta   Save/resume metadata
│
├── jepa-compat      Model compatibility and interop
│   ├── ModelRegistry    Pretrained model catalog (Facebook Research)
│   ├── SafeTensors      Load .safetensors checkpoints
│   ├── KeyMap           PyTorch → burn key remapping
│   └── OnnxModelInfo    ONNX metadata inspection and initializer loading
│
├── jepa             CLI and interactive TUI dashboard
│   ├── CLI              models, inspect, checkpoint, train, encode commands
│   └── TUI              Dashboard, Models, Training, Inference, Checkpoint, About tabs
│
└── jepa-web         Browser demo crate
    ├── WASM API         JS-callable training and inference helpers
    ├── Demo UI          HTML, JS, and CSS assets for local browser demos
    └── Backend status   CPU-backed path today; WebGPU scaffolding remains internal

All tensor-bearing APIs are generic over B: Backend, allowing transparent execution on CPU (NdArray), GPU (WGPU), or WebAssembly backends.
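The CollapseReg entry in jepa-core names VICReg as one collapse-prevention option. As a reference point, here is a std-only sketch of VICReg's variance and covariance terms over a batch of embeddings (rows are samples, columns are dimensions), using the VICReg paper's target standard deviation of 1 and epsilon of 1e-4. It is not jepa-core's implementation; the invariance term and the loss weights are omitted, and the function names are illustrative.

```rust
const EPS: f64 = 1e-4; // epsilon used in the VICReg paper

/// Per-dimension mean over the batch.
fn column_means(batch: &[Vec<f64>]) -> Vec<f64> {
    let n = batch.len() as f64;
    (0..batch[0].len())
        .map(|j| batch.iter().map(|z| z[j]).sum::<f64>() / n)
        .collect()
}

/// Variance term: mean over dimensions of max(0, 1 - sqrt(var_j + eps)).
fn vicreg_variance(batch: &[Vec<f64>]) -> f64 {
    assert!(batch.len() >= 2, "need at least two samples");
    let means = column_means(batch);
    let denom = (batch.len() - 1) as f64; // unbiased variance
    let hinge_sum: f64 = means
        .iter()
        .enumerate()
        .map(|(j, m)| {
            let var = batch.iter().map(|z| (z[j] - m).powi(2)).sum::<f64>() / denom;
            (1.0 - (var + EPS).sqrt()).max(0.0)
        })
        .sum();
    hinge_sum / means.len() as f64
}

/// Covariance term: sum of squared off-diagonal covariance entries, divided by dimension d.
fn vicreg_covariance(batch: &[Vec<f64>]) -> f64 {
    assert!(batch.len() >= 2, "need at least two samples");
    let means = column_means(batch);
    let (d, denom) = (means.len(), (batch.len() - 1) as f64);
    let mut total = 0.0;
    for a in 0..d {
        for b in 0..d {
            if a != b {
                let cov = batch
                    .iter()
                    .map(|z| (z[a] - means[a]) * (z[b] - means[b]))
                    .sum::<f64>()
                    / denom;
                total += cov * cov;
            }
        }
    }
    total / d as f64
}
```

A fully collapsed batch (every row identical) scores close to 1 on the variance term, which is exactly the signal the regularizer pushes against.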

ONNX Support

jepa-rs provides ONNX metadata inspection and initializer loading through jepa-compat. This allows inspecting model structure, input/output specs, and importing weight initializers from .onnx files.

Current scope: metadata inspection and weight import are production-ready. Tract-based ONNX graph execution exists (OnnxSession, OnnxEncoder) but is not yet production-grade — it is functional for prototyping and testing.

Examples

| Example | Description | Run command |
|---|---|---|
| jepa | Interactive TUI dashboard | cargo run -p jepa |
| jepa models | Browse pretrained model registry | cargo run -p jepa -- models |
| jepa train | Launch a training run | cargo run -p jepa -- train --preset vit-base-16 |
| prepare_demo_image_folder | Generate a tiny recursive dataset for --dataset-dir demos | cargo run -p jepa --example prepare_demo_image_folder |
| train_image_folder_demo | Run the real jepa train image-folder path on generated images | cargo run -p jepa --example train_image_folder_demo |
| train_synthetic_demo | Run the real jepa train synthetic fallback path | cargo run -p jepa --example train_synthetic_demo |
| ijepa_demo | Full I-JEPA forward pass pipeline | cargo run -p jepa-vision --example ijepa_demo |
| ijepa_train_loop | Training loop with metrics | cargo run -p jepa-vision --example ijepa_train_loop |
| world_model_planning | World model with random shooting | cargo run -p jepa-world --example world_model_planning |
| model_registry | Browse pretrained models (library) | cargo run -p jepa-compat --example model_registry |

The browser demo lives in crates/jepa-web. Its exported WASM path currently uses the deterministic CPU backend for reliable local demos and tests; see the architecture and production-gap docs below before treating it as a WebGPU-ready surface.

Build & Test

# Build everything
cargo build --workspace

# Run all tests
cargo test --workspace

# Lint
cargo clippy --workspace --all-targets -- -D warnings

# Format check
cargo fmt -- --check

# Generate docs
cargo doc --workspace --no-deps --open

# Run differential parity tests
scripts/run_parity_suite.sh

# Target a single crate
cargo test -p jepa-core
cargo test -p jepa-vision
cargo test -p jepa-compat
cargo test -p jepa-web

Extended quality gates

# Code coverage (requires cargo-llvm-cov)
cargo llvm-cov --workspace --all-features --fail-under-lines 80

# Fuzz testing (requires cargo-fuzz)
(cd fuzz && cargo fuzz run masking -- -runs=1000)

# Benchmark smoke test
cargo bench --workspace --no-run

Project Docs

Project Status

As of 2026-03-16, this project is alpha. It is suitable for research, local demos, and extension work; it is not yet suitable for unqualified production deployment of every advertised surface. In particular, strict video parity is still pending, ONNX graph execution remains a prototype path, and the browser demo uses the CPU-backed WASM path today.

What works

  • Complete I-JEPA, V-JEPA, and C-JEPA architectures with strict masked-encoder paths
  • CLI with 6 commands (models, inspect, checkpoint, train, encode, tui)
  • Interactive TUI dashboard with 6 tabs (Dashboard, Models, Training, Inference, Checkpoint, About)
  • Browser demo crate with deterministic CPU-backed WASM training and inference for local experimentation
  • SafeTensors checkpoint loading with automatic key remapping
  • ONNX metadata inspection and initializer loading
  • Pretrained model registry with download URLs
  • Differential parity tests against 3 checked-in strict image fixtures
  • Comprehensive test suite (500+ tests), property-based testing, fuzz targets
  • All standard ViT configs: ViT-S/16, ViT-B/16, ViT-L/16, ViT-H/14, ViT-H/16, ViT-G/16

Known limitations

  • The generic trainer slices tokens after encoder forward; strict pre-attention masking is available via IJepa::forward_step_strict and VJepa::forward_step_strict
  • ONNX graph execution (tract-based OnnxSession and OnnxEncoder) is a prototype; only metadata inspection and initializer loading are production-ready
  • The jepa-web crate keeps burn-wgpu scaffolding in-tree, but the exported browser demo currently runs on the burn-ndarray CPU backend only
  • Differential parity runs in CI for strict image fixtures; broader video parity is pending
  • First-time crates.io release must be published in dependency order because the workspace crates depend on each other by version

JEPA Variants: What We Implement

The JEPA family has grown across several papers. Here is exactly what jepa-rs implements and how each component maps to a specific paper and reference codebase.

I-JEPA (Image)

Paper: Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture (Assran et al., CVPR 2023)
Reference code: facebookresearch/ijepa (archived)
jepa-rs struct: IJepa<B> in jepa-vision (crates/jepa-vision/src/image.rs)
What it does: Self-supervised image representation learning. A ViT context encoder sees only visible patches; a lightweight predictor predicts the representations of masked target patches. The target encoder is an EMA copy of the context encoder.
Masking: BlockMasking, contiguous rectangular blocks on the 2D patch grid.
Faithful path: IJepa::forward_step_strict filters tokens before encoder self-attention (matches the paper).
Approximate path: JepaComponents::forward_step in jepa-train encodes the full input, then slices (post-encoder masking; simpler, but not faithful to the paper).
Parity status: 3 checked-in strict image fixtures verified in CI.
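The difference between the faithful and approximate paths can be shown with a toy example. Below, a scalar stands in for a patch embedding and a mean-mixing function stands in for self-attention; both are simplifications for illustration, not jepa-vision code. With post-encoder slicing, the context representations depend on the masked tokens, so target information leaks into the context; with pre-attention filtering they do not.

```rust
/// Toy stand-in for a self-attention encoder: each output token is its input plus the
/// mean of every token in the sequence it receives, so information mixes across tokens.
fn toy_encoder(tokens: &[f32]) -> Vec<f32> {
    let mean = tokens.iter().sum::<f32>() / tokens.len() as f32;
    tokens.iter().map(|t| t + mean).collect()
}

/// Selects the tokens at `indices`, in order.
fn gather(tokens: &[f32], indices: &[usize]) -> Vec<f32> {
    indices.iter().map(|&i| tokens[i]).collect()
}

/// Strict (pre-attention) masking: drop masked tokens first, then encode.
fn strict_context(tokens: &[f32], context: &[usize]) -> Vec<f32> {
    toy_encoder(&gather(tokens, context))
}

/// Approximate (post-encoder) masking: encode every token, then keep context positions.
fn approximate_context(tokens: &[f32], context: &[usize]) -> Vec<f32> {
    gather(&toy_encoder(tokens), context)
}
```

Changing only the masked tokens leaves strict_context unchanged but alters approximate_context, which is why the strict path is the one used for parity.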

V-JEPA (Video)

Paper: Revisiting Feature Prediction for Learning Visual Representations from Video (Bardes et al., 2024)
Reference code: facebookresearch/jepa
jepa-rs struct: VJepa<B> in jepa-vision (crates/jepa-vision/src/video.rs)
What it does: Extends I-JEPA to video. A ViT encoder processes 3D tubelets (space + time) with 3D RoPE.
Masking: SpatiotemporalMasking, contiguous 3D regions in the spatiotemporal grid.
Faithful path: VJepa::forward_step_strict (pre-attention masking).
Parity status: Implemented, but strict video parity is not yet proven (pending).

V-JEPA 2 features

Paper: V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning (Assran et al., 2025)
Reference code: facebookresearch/vjepa2
jepa-rs support: Not a separate struct. The VJepa<B> struct can be configured with V-JEPA 2 features.
What we take from V-JEPA 2: A cosine momentum schedule for the EMA, CosineMomentumSchedule in jepa-core (Ema::with_cosine_schedule), where momentum ramps from a base value (e.g. 0.996) to 1.0 over training. Also the MultiBlockMasking strategy and a ViT-Giant/16 preset.
What we don't implement: The full V-JEPA 2 training recipe, attentive probing, or the planning/action heads from the paper.
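A minimal sketch of an EMA target update driven by a cosine momentum ramp. The formula m(t) = 1 - (1 - base) * (cos(pi * t / T) + 1) / 2 is the common BYOL-style schedule that moves from base at step 0 to 1.0 at step T; whether CosineMomentumSchedule uses exactly this form is an assumption, and the function names are illustrative.

```rust
use std::f64::consts::PI;

/// Cosine momentum ramp from `base` at step 0 to 1.0 at `total_steps`.
fn cosine_momentum(base: f64, step: u64, total_steps: u64) -> f64 {
    assert!(total_steps > 0, "total_steps must be positive");
    let progress = step.min(total_steps) as f64 / total_steps as f64;
    1.0 - (1.0 - base) * ((PI * progress).cos() + 1.0) / 2.0
}

/// EMA update, element-wise: target <- m * target + (1 - m) * online.
fn ema_update(target: &mut [f64], online: &[f64], momentum: f64) {
    assert_eq!(target.len(), online.len());
    for (t, o) in target.iter_mut().zip(online) {
        *t = momentum * *t + (1.0 - momentum) * o;
    }
}
```

With base = 0.996 the momentum is 0.996 at the start, 0.998 halfway through, and 1.0 at the end, so the target encoder changes less and less as training proceeds.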

Hierarchical JEPA (H-JEPA) — experimental

Paper: Inspired by A Path Towards Autonomous Machine Intelligence (LeCun, 2022). The original JEPA position paper describes hierarchical prediction as a long-term goal; no standalone H-JEPA paper exists yet.
jepa-rs struct: HierarchicalJepa<B> in jepa-world (crates/jepa-world/src/hierarchy.rs)
What it does: Stacks multiple JEPA levels at different temporal strides (e.g. stride 2, 6, 24). Each level has its own encoder and predictor. This is experimental; no reference implementation exists.
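One way to read "different temporal strides" is that each level sees every stride-th latent of the sequence, so a stride-24 level works on a much coarser timeline than a stride-2 level. The helper below is hypothetical and only illustrates that indexing; it does not reflect HierarchicalJepa's API.

```rust
/// For each temporal stride, the indices of the latents that hierarchy level would see.
fn level_timesteps(num_frames: usize, strides: &[usize]) -> Vec<Vec<usize>> {
    strides
        .iter()
        .map(|&stride| {
            assert!(stride > 0, "stride must be positive");
            (0..num_frames).step_by(stride).collect()
        })
        .collect()
}
```

For 48 frames and strides [2, 6, 24], the levels see 24, 8, and 2 timesteps respectively.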

Action-Conditioned World Model — experimental

Paper: Draws from both the LeCun position paper and V-JEPA 2 (planning component).
jepa-rs structs: Action<B>, the ActionConditionedPredictor<B> trait, and RandomShootingPlanner in jepa-world (crates/jepa-world/src/action.rs, crates/jepa-world/src/planner.rs)
What it does: Predicts next-state representations given the current state and an action, and supports random-shooting planning over those predictions. This is experimental.
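A random-shooting planner samples candidate action sequences, rolls each one out through the latent predictor, scores the final state with a cost function, and keeps the cheapest. The sketch below does exactly that with a hypothetical 2D latent state, additive toy dynamics, and squared distance to a goal as the cost; none of these names come from jepa-world.

```rust
/// Tiny deterministic PRNG (SplitMix64) so the sketch needs no external crates.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
    /// Uniform f64 in [lo, hi).
    fn uniform(&mut self, lo: f64, hi: f64) -> f64 {
        let unit = (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64;
        lo + (hi - lo) * unit
    }
}

type Latent = [f64; 2];

/// Hypothetical latent dynamics for this sketch: the action is added to the state.
fn predict(state: Latent, action: Latent) -> Latent {
    [state[0] + action[0], state[1] + action[1]]
}

/// Cost of a state: squared Euclidean distance to the goal latent.
fn cost(state: Latent, goal: Latent) -> f64 {
    (state[0] - goal[0]).powi(2) + (state[1] - goal[1]).powi(2)
}

/// Random shooting: sample action sequences, roll each out, keep the cheapest plan.
fn random_shooting(
    start: Latent,
    goal: Latent,
    horizon: usize,
    num_candidates: usize,
    rng: &mut SplitMix64,
) -> (Vec<Latent>, f64) {
    let mut best_plan = Vec::new();
    let mut best_cost = f64::INFINITY;
    for _ in 0..num_candidates {
        let plan: Vec<Latent> = (0..horizon)
            .map(|_| [rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)])
            .collect();
        let final_state = plan.iter().fold(start, |s, &a| predict(s, a));
        let c = cost(final_state, goal);
        if c < best_cost {
            best_cost = c;
            best_plan = plan;
        }
    }
    (best_plan, best_cost)
}
```

In model-predictive control you would execute only the first action of the plan, observe the new state, and plan again. The Cross-Entropy Method extends this idea by refitting the sampling distribution to the best candidates over several iterations.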

C-JEPA (Causal)

Paper: Causal-JEPA: Learning World Models through Object-Level Latent Interventions (Nam et al., 2025)
Reference code: galilai-group/cjepa
jepa-rs structs: ObjectMasking in jepa-core, SlotAttention<B> / SlotEncoder<B> in jepa-vision, CausalJepaComponents in jepa-train, ObjectDynamicsPredictor<B> in jepa-world
What it does: Object-centric JEPA. It masks whole objects (not patches) and uses a frozen encoder with slot attention, identity-anchored masked tokens, and a joint history + future MSE loss. This enables causal reasoning and efficient CEM planning in object-representation space (~98% token reduction vs patch-based models).
Masking: ObjectMasking randomly partitions N object slots into context/target subsets.
Training: Frozen encoder (no EMA); only slot attention and the predictor are trained. See CausalJepaComponents::forward_step in jepa-train.
Planning: ObjectDynamicsPredictor + RandomShootingPlanner in jepa-world for sampling-based MPC in object space.
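ObjectMasking's context/target split can be sketched as a shuffle of slot indices followed by a cut. The number of masked slots is passed in explicitly here because the README does not say how jepa-core chooses it; the `object_mask` name and the SplitMix64 generator are illustrative assumptions.

```rust
/// Tiny deterministic PRNG (SplitMix64) so the sketch needs no external crates.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

/// Randomly splits slots 0..num_slots into (context, target) index sets, masking
/// `num_targets` whole slots: a Fisher-Yates shuffle, then a cut of the shuffled order.
fn object_mask(num_slots: usize, num_targets: usize, rng: &mut SplitMix64) -> (Vec<usize>, Vec<usize>) {
    assert!(num_targets <= num_slots, "cannot mask more slots than exist");
    let mut order: Vec<usize> = (0..num_slots).collect();
    for i in (1..num_slots).rev() {
        // Uniform index in 0..=i (the tiny modulo bias is ignored in this sketch).
        let j = (rng.next_u64() % (i as u64 + 1)) as usize;
        order.swap(i, j);
    }
    let mut targets = order.split_off(num_slots - num_targets);
    let mut context = order;
    context.sort_unstable();
    targets.sort_unstable();
    (context, targets)
}
```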

What about EB-JEPA?

EB-JEPA (Terver et al., 2026) is a separate lightweight Python library for energy-based JEPA. jepa-rs is not an implementation of EB-JEPA. We reference it for comparison only. The energy functions in jepa-core (L2, Cosine, SmoothL1) are standard loss formulations, not the EB-JEPA energy framework.
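For reference, the three standard energy formulations named above look like this on flat representation vectors. This is a hedged sketch: the mean reduction, the 1 - cos form of the cosine energy, and the PyTorch-style SmoothL1 with a beta threshold are assumptions about how jepa-core defines them, and the function names are illustrative.

```rust
/// Mean squared (L2) energy between predicted and target representations.
fn l2_energy(pred: &[f32], target: &[f32]) -> f32 {
    assert_eq!(pred.len(), target.len());
    pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum::<f32>() / pred.len() as f32
}

/// Cosine energy: 1 - cos(pred, target); 0 when aligned, 1 when orthogonal, 2 when opposite.
fn cosine_energy(pred: &[f32], target: &[f32]) -> f32 {
    assert_eq!(pred.len(), target.len());
    let dot: f32 = pred.iter().zip(target).map(|(p, t)| p * t).sum();
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    1.0 - dot / (norm(pred) * norm(target)).max(1e-8)
}

/// Smooth L1 energy with threshold `beta`: quadratic for small errors, linear beyond.
fn smooth_l1_energy(pred: &[f32], target: &[f32], beta: f32) -> f32 {
    assert_eq!(pred.len(), target.len());
    let per_elem = |d: f32| if d.abs() < beta { 0.5 * d * d / beta } else { d.abs() - 0.5 * beta };
    pred.iter().zip(target).map(|(p, t)| per_elem(p - t)).sum::<f32>() / pred.len() as f32
}
```

Smooth L1 behaves like L2 near zero and like L1 for large errors, which makes it less sensitive to occasional outlier predictions.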

Quick summary

| Variant | Paper | jepa-rs struct | Status |
|---|---|---|---|
| I-JEPA | Assran et al. 2023 | IJepa<B> | Strict path implemented, parity verified |
| V-JEPA | Bardes et al. 2024 | VJepa<B> | Strict path implemented, parity pending |
| V-JEPA 2 | Assran et al. 2025 | VJepa<B> + cosine EMA schedule | Select features only |
| H-JEPA | LeCun 2022 (position paper) | HierarchicalJepa<B> | Experimental, no reference impl |
| C-JEPA | Nam et al. 2025 | ObjectMasking, SlotAttention, CausalJepaComponents, ObjectDynamicsPredictor | Core support implemented |
| World model | LeCun 2022 + V-JEPA 2 | ActionConditionedPredictor, RandomShootingPlanner | Experimental |
| EB-JEPA | Terver et al. 2026 | Not implemented | Referenced for comparison only |

References

Papers

| Paper | Focus |
|---|---|
| A Path Towards Autonomous Machine Intelligence | JEPA position paper: hierarchical world models (LeCun, 2022) |
| I-JEPA | Self-supervised image learning with masked prediction in latent space (Assran et al., CVPR 2023) |
| V-JEPA | Extension to video with spatiotemporal masking (Bardes et al., 2024) |
| V-JEPA 2 | Video understanding, prediction, and planning (Assran et al., 2025) |
| C-JEPA | Object-centric world models via latent interventions (Nam et al., 2025) |
| Slot Attention | Object-centric learning with slot attention (Locatello et al., NeurIPS 2020) |
| EB-JEPA | Lightweight energy-based JEPA library, referenced for comparison (Terver et al., 2026) |

Official reference implementations

| Repo | Models | Relationship to jepa-rs |
|---|---|---|
| facebookresearch/ijepa | I-JEPA (archived) | Primary reference for IJepa<B> and key remapping |
| facebookresearch/jepa | V-JEPA | Primary reference for VJepa<B> |
| facebookresearch/vjepa2 | V-JEPA 2 | Reference for cosine EMA schedule, ViT-G config |
| galilai-group/cjepa | C-JEPA | Primary reference for ObjectMasking, SlotAttention, CausalJepaComponents |
| facebookresearch/eb_jepa | EB-JEPA tutorial | Not implemented; comparison only |

Contributing

See CONTRIBUTING.md for guidelines.

License

MIT License. See LICENSE for details.


Built with burn and tract