Model Backbones
October 13, 2025 · View on GitHub
The networks/ package bundles all learnable components consumed by the
interfaces and trainers—transformer-based generators, latent encoders, and
decoders.
Supported Families
transformers/: DiT in both Linen and NNX forms, Lightning DiT variants, and Lightning DDT (decoupled decoder transformer) models tuned for REPA/RAE workflows. Utility scripts help port weights between PyTorch and NNX.unets/: U-Net architectures (EDM-style) for experimentation and future diffusion releases.encoders/: Pretrained latent autoencoders and feature extractors. Shipped presets include RGB passthrough, Stability VAE, RAE encoders, and DINOv2 backbones with optional register tokens.decoders/: Vision decoders (ViT-based) for reconstructing images from latent representations produced by RAE encoders.
Every module is implemented as an nnx.Module (with bridge helpers when
converting from Linen) so they can be sharded and checkpointed with the rest of
the training stack.
Encoder Contract
Encoders must expose encode and decode methods matching the signatures in
networks/encoders/sd_vae.py. In particular:
encode(x, sample_posterior=True, deterministic=True)returns latent codes and auxiliary stats if needed. Thedeterministicflag controls dropout behavior, not stochastic sampling.decode(z, deterministic=True)reconstructs pixel-space tensors, typically returning uint8 images scaled to[0, 255].
When adding a new encoder, follow the same call contract so trainers, samplers,
and evaluation scripts can swap implementations without code changes. Ensure
the module initializes weights (downloaded checkpoints, etc.) inside its
constructor and registers any RNG streams via nnx.Rngs.