RustTensor Library Architecture
May 13, 2025
This document provides an overview of the Rust Tensor Library's architecture, explaining the key components, their relationships, and design decisions.
Table of Contents
- High-Level Architecture
- Core Components
- Backend System
- Automatic Differentiation
- Hooks System
- Operations System
- Operator Overloading
- Optimizers
- Feature Flags
- Design Decisions
High-Level Architecture
The RustTensor Library is designed around a modular architecture with several key components:
graph TD
User[User Code] --> Tensor[Tensor<B: Backend>]
Tensor --> Backend[Backend Trait]
Tensor --> Ops[Operations]
Tensor --> Graph[Computation Graph]
Backend --> CPU[CPU Backend]
Backend --> CUDA[CUDA Backend]
Ops --> CPU_Ops[CPU Implementations]
Ops --> CUDA_Ops[CUDA Implementations]
Graph --> Autograd[Automatic Differentiation]
Autograd --> Optim[Optimizers]
classDef core fill:#f9f,stroke:#333,stroke-width:2px;
classDef backend fill:#bbf,stroke:#333,stroke-width:1px;
classDef ops fill:#bfb,stroke:#333,stroke-width:1px;
classDef user fill:#fbb,stroke:#333,stroke-width:1px;
class Tensor,Graph,Autograd core;
class Backend,CPU,CUDA backend;
class Ops,CPU_Ops,CUDA_Ops ops;
class User,Optim user;
Core Components
Tensor
The Tensor<B: Backend> struct is the central type in the library:
- Generic over Backend: Uses a type parameter B that implements the Backend trait
- Data Storage: Contains data in backend-specific storage format
- Gradient Tracking: Optional gradient storage for automatic differentiation
- Computation Graph: Tracks operations for backward pass
- Device Information: Tracks where the tensor data resides (CPU/CUDA)
pub struct Tensor<B: Backend> {
    data: Rc<RefCell<TensorData<B>>>,
    _marker: PhantomData<B>,
}

pub struct TensorData<B: Backend> {
    pub id: usize,
    pub data: B::Storage,
    pub grad: Option<B::Storage>,
    pub requires_grad: bool,
    pub op: Option<Op<B>>,
    pub device: Device,
    pub hooks: Vec<Box<dyn Hook<B>>>,
}
Type Aliases
For convenience, the library provides type aliases for common backend configurations:
pub type CpuTensor = Tensor<CpuBackend>;
pub type CudaTensor = Tensor<CudaBackend>;
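To make the handle-plus-shared-data layout more concrete, here is a minimal sketch using only the standard library. ToyBackend, ToyTensor, Tensor::new, set_requires_grad, requires_grad, and handle_count are hypothetical names introduced for illustration, and the Backend trait and TensorData are cut down to the fields the demo needs. It is not the library's actual implementation.

```rust
use std::cell::RefCell;
use std::marker::PhantomData;
use std::rc::Rc;

// Hypothetical, cut-down stand-in for the library's Backend trait.
pub trait Backend: 'static + Sized {
    type Storage: 'static;
}

// Hypothetical backend whose storage is a plain Vec<f32>.
pub struct ToyBackend;

impl Backend for ToyBackend {
    type Storage = Vec<f32>;
}

// Reduced TensorData: only the fields this demo needs.
pub struct TensorData<B: Backend> {
    pub id: usize,
    pub data: B::Storage,
    pub grad: Option<B::Storage>,
    pub requires_grad: bool,
}

// A tensor is a cheap handle to shared, interior-mutable data.
pub struct Tensor<B: Backend> {
    data: Rc<RefCell<TensorData<B>>>,
    _marker: PhantomData<B>,
}

impl<B: Backend> Clone for Tensor<B> {
    // Implemented by hand so that B itself does not need to be Clone.
    fn clone(&self) -> Self {
        Tensor { data: Rc::clone(&self.data), _marker: PhantomData }
    }
}

impl<B: Backend> Tensor<B> {
    pub fn new(id: usize, data: B::Storage) -> Self {
        let inner = TensorData { id, data, grad: None, requires_grad: false };
        Tensor { data: Rc::new(RefCell::new(inner)), _marker: PhantomData }
    }

    pub fn set_requires_grad(&self, flag: bool) {
        self.data.borrow_mut().requires_grad = flag;
    }

    pub fn requires_grad(&self) -> bool {
        self.data.borrow().requires_grad
    }

    // Number of handles currently pointing at the same TensorData.
    pub fn handle_count(&self) -> usize {
        Rc::strong_count(&self.data)
    }
}

// Alias in the same spirit as CpuTensor and CudaTensor.
pub type ToyTensor = Tensor<ToyBackend>;
```

In this sketch, cloning a tensor copies only the handle, so a flag set through one clone is visible through every other clone.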
Backend System
The Backend trait defines the interface that all backends must implement:
pub trait Backend: 'static + Sized {
    type Storage: 'static;

    // Creation methods
    fn zeros(shape: &[usize]) -> Result<Self::Storage, Error>;
    fn ones(shape: &[usize]) -> Result<Self::Storage, Error>;
    fn from_array(array: Array) -> Result<Self::Storage, Error>;

    // Data access methods
    fn shape(tensor: &Self::Storage) -> &[usize];
    fn copy_to_host(tensor: &Self::Storage) -> Result<Vec<f32>, Error>;

    // Operations
    fn add(a: &Self::Storage, b: &Self::Storage) -> Result<Self::Storage, Error>;
    fn sub(a: &Self::Storage, b: &Self::Storage) -> Result<Self::Storage, Error>;
    fn mul(a: &Self::Storage, b: &Self::Storage) -> Result<Self::Storage, Error>;
    // ... many more operations
}
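As an illustration of what implementing this interface might involve, the sketch below defines a reduced version of the trait (five of the methods shown above) and a toy backend over a flat Vec<f32>. VecBackend, VecStorage, filled, ones_plus_ones, and the ShapeMismatch variant are hypothetical names, not part of the library.

```rust
#[derive(Debug, PartialEq)]
pub enum Error {
    // Hypothetical variant; the library's real error type is not shown here.
    ShapeMismatch { left: Vec<usize>, right: Vec<usize> },
}

// Hypothetical storage: flat values plus the shape that describes them.
#[derive(Debug, Clone, PartialEq)]
pub struct VecStorage {
    values: Vec<f32>,
    shape: Vec<usize>,
}

// Reduced version of the Backend trait.
pub trait Backend: 'static + Sized {
    type Storage: 'static;
    fn zeros(shape: &[usize]) -> Result<Self::Storage, Error>;
    fn ones(shape: &[usize]) -> Result<Self::Storage, Error>;
    fn shape(tensor: &Self::Storage) -> &[usize];
    fn copy_to_host(tensor: &Self::Storage) -> Result<Vec<f32>, Error>;
    fn add(a: &Self::Storage, b: &Self::Storage) -> Result<Self::Storage, Error>;
}

pub struct VecBackend;

// Helper: storage of the given shape with every element set to `value`.
fn filled(shape: &[usize], value: f32) -> VecStorage {
    let len: usize = shape.iter().product();
    VecStorage { values: vec![value; len], shape: shape.to_vec() }
}

impl Backend for VecBackend {
    type Storage = VecStorage;

    fn zeros(shape: &[usize]) -> Result<VecStorage, Error> {
        Ok(filled(shape, 0.0))
    }

    fn ones(shape: &[usize]) -> Result<VecStorage, Error> {
        Ok(filled(shape, 1.0))
    }

    fn shape(tensor: &VecStorage) -> &[usize] {
        &tensor.shape
    }

    fn copy_to_host(tensor: &VecStorage) -> Result<Vec<f32>, Error> {
        Ok(tensor.values.clone())
    }

    fn add(a: &VecStorage, b: &VecStorage) -> Result<VecStorage, Error> {
        // This toy backend requires identical shapes (no broadcasting).
        if a.shape != b.shape {
            return Err(Error::ShapeMismatch {
                left: a.shape.clone(),
                right: b.shape.clone(),
            });
        }
        let values = a.values.iter().zip(&b.values).map(|(x, y)| x + y).collect();
        Ok(VecStorage { values, shape: a.shape.clone() })
    }
}

// Backend-agnostic code names only the trait, never a concrete backend.
pub fn ones_plus_ones<B: Backend>(shape: &[usize]) -> Result<Vec<f32>, Error> {
    let a = B::ones(shape)?;
    let b = B::ones(shape)?;
    let sum = B::add(&a, &b)?;
    B::copy_to_host(&sum)
}
```

Because ones_plus_ones is generic over B, the same function body would run unchanged against any other type that implements the reduced trait.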
CPU Backend
The CpuBackend implements the Backend trait using ndarray for storage and computations:
- Storage Type: Array (a wrapper around ndarray::Array)
- Performance: Can use system BLAS libraries via the cpu_openblas feature
- Implementation: Uses vectorized operations where possible
CUDA Backend
The CudaBackend implements the Backend trait using CUDA for GPU acceleration:
- Storage Type: Custom CUDA memory management
- Kernels: Custom CUDA kernels for operations
- cuBLAS Integration: Uses cuBLAS for matrix operations
- Context Management: Thread-safe CUDA context handling
Automatic Differentiation
The library uses dynamic automatic differentiation (autograd):
- Computation Graph: Built dynamically during forward pass
- Reverse-Mode Differentiation: Computes gradients efficiently
- Operation Tracking: Each operation records its inputs and a closure for backward pass
- Gradient Accumulation: Handles multiple paths to the same node
graph LR
A(Tensor A) --> C(Tensor C)
B(Tensor B) --> C
C --> D(Tensor D)
D -.backward.-> C
C -.backward.-> A
C -.backward.-> B
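The ideas above (a graph built during the forward pass, a backward closure per operation, and gradient accumulation) can be sketched on scalars with a simple tape. This is a toy model under stated assumptions: Tape, leaf, value, add, mul, and backward are hypothetical names, nodes are identified by their index on the tape, and the library's real graph operates on tensors rather than f32 values.

```rust
// Backward closure: receives the node's gradient and adds contributions
// into the gradient slots of the node's inputs.
type BackwardFn = Box<dyn Fn(f32, &mut [f32])>;

pub struct Tape {
    values: Vec<f32>,
    // None for leaves, Some(closure) for nodes produced by an operation.
    backward_fns: Vec<Option<BackwardFn>>,
}

impl Tape {
    pub fn new() -> Self {
        Tape { values: Vec::new(), backward_fns: Vec::new() }
    }

    // Appends a node and returns its index.
    fn push(&mut self, value: f32, backward: Option<BackwardFn>) -> usize {
        self.values.push(value);
        self.backward_fns.push(backward);
        self.values.len() - 1
    }

    pub fn leaf(&mut self, value: f32) -> usize {
        self.push(value, None)
    }

    pub fn value(&self, node: usize) -> f32 {
        self.values[node]
    }

    pub fn add(&mut self, a: usize, b: usize) -> usize {
        let value = self.values[a] + self.values[b];
        let backward: BackwardFn = Box::new(move |g: f32, grads: &mut [f32]| {
            grads[a] += g;
            grads[b] += g;
        });
        self.push(value, Some(backward))
    }

    pub fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.values[a], self.values[b]);
        let backward: BackwardFn = Box::new(move |g: f32, grads: &mut [f32]| {
            grads[a] += g * vb;
            grads[b] += g * va;
        });
        self.push(va * vb, Some(backward))
    }

    // Reverse-mode pass: returns d(output)/d(node) for every node.
    pub fn backward(&self, output: usize) -> Vec<f32> {
        let mut grads = vec![0.0; self.values.len()];
        grads[output] = 1.0;
        // Inputs always have smaller indices than the node they feed, so
        // walking indices in reverse visits nodes in reverse topological order.
        for node in (0..=output).rev() {
            if let Some(backward) = &self.backward_fns[node] {
                let g = grads[node];
                backward(g, &mut grads[..]);
            }
        }
        grads
    }
}
```

The += in each closure is what implements accumulation: if a value feeds the output through two paths, both contributions are summed into the same gradient slot.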
Hooks System
The hooks system provides a mechanism to monitor and modify tensor operations during both forward and backward passes:
- Forward Hooks: Executed after a tensor operation is performed
- Backward Hooks: Executed when gradients flow through a tensor during the backward pass
- Hook Registration: Hooks are registered with tensors and return a unique ID for later removal
- Hook Interface: All hooks implement the Hook trait
pub trait Hook<B: Backend>: Debug {
    fn forward(&self, tensor: &Tensor<B>, input: &[&Tensor<B>], output: &Tensor<B>) -> Result<(), Error>;
    fn backward(&self, tensor: &Tensor<B>, grad_input: &[Option<Tensor<B>>], grad_output: &Tensor<B>) -> Result<(), Error>;
}
Hook Types
- FnHook: A convenience wrapper around closures for quick hook creation
- Custom Hook Types: Users can implement the Hook trait for custom behavior
Use Cases
- Debugging: Monitoring intermediate values during forward and backward passes
- Gradient Clipping: Modifying gradients to prevent exploding gradients
- Feature Visualization: Capturing activations in neural networks
- Custom Regularization: Applying custom regularization during the backward pass
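A simplified sketch of closure-based hooks, ID-based registration, and the gradient clipping use case follows. It deliberately departs from the Hook trait shown above: GradHook works on a mutable f32 slice instead of tensors and returns nothing, and GradHook, HookRegistry, register, remove, run_backward, and clip_hook are hypothetical names. FnHook here only mirrors the idea of the library's closure wrapper, not its signature.

```rust
// Hypothetical simplified hook: may inspect or rewrite a gradient in place.
pub trait GradHook {
    fn backward(&self, grad: &mut [f32]);
}

// Closure wrapper, similar in spirit to the FnHook convenience type.
pub struct FnHook<F: Fn(&mut [f32])>(pub F);

impl<F: Fn(&mut [f32])> GradHook for FnHook<F> {
    fn backward(&self, grad: &mut [f32]) {
        (self.0)(grad)
    }
}

// Stores hooks together with the ID handed out at registration time.
pub struct HookRegistry {
    next_id: usize,
    hooks: Vec<(usize, Box<dyn GradHook>)>,
}

impl HookRegistry {
    pub fn new() -> Self {
        HookRegistry { next_id: 0, hooks: Vec::new() }
    }

    // Registers a hook and returns a unique ID for later removal.
    pub fn register(&mut self, hook: Box<dyn GradHook>) -> usize {
        let id = self.next_id;
        self.next_id += 1;
        self.hooks.push((id, hook));
        id
    }

    // Removes the hook with the given ID; returns false if it was not found.
    pub fn remove(&mut self, id: usize) -> bool {
        let before = self.hooks.len();
        self.hooks.retain(|(hook_id, _)| *hook_id != id);
        self.hooks.len() != before
    }

    // Runs every registered hook, in registration order.
    pub fn run_backward(&self, grad: &mut [f32]) {
        for (_, hook) in &self.hooks {
            hook.backward(grad);
        }
    }
}

// Gradient clipping: clamps every element to [-limit, limit].
// `limit` is assumed to be a non-negative, non-NaN value.
pub fn clip_hook(limit: f32) -> Box<dyn GradHook> {
    Box::new(FnHook(move |grad: &mut [f32]| {
        for g in grad.iter_mut() {
            *g = g.clamp(-limit, limit);
        }
    }))
}
```

Returning the ID from register keeps removal cheap to express: the caller does not need to hold a reference to the hook itself.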
Operator Overloading
The library provides operator overloading for tensors, allowing for more expressive and readable code:
- Arithmetic Operators: +, -, *, / for element-wise operations
- Comparison Operators: ==, != for tensor equality checks
Operator overloading is implemented for references to tensors to avoid unnecessary moves and clones:
// Implementation for &Tensor + &Tensor
impl<'a, 'b, B: Backend> Add<&'b Tensor<B>> for &'a Tensor<B> {
    type Output = Result<Tensor<B>, Error>;

    fn add(self, other: &'b Tensor<B>) -> Self::Output {
        ops::add(self, other)
    }
}

// Implementation for &Tensor * f32 (scalar multiplication)
impl<'a, B: Backend> Mul<f32> for &'a Tensor<B> {
    type Output = Result<Tensor<B>, Error>;

    fn mul(self, scalar: f32) -> Self::Output {
        ops::mul_scalar(self, scalar)
    }
}
Error Handling with Operators
Rust's operator traits let each implementation choose its own Output type, so the library's operators return Result<Tensor<B>, Error> rather than a bare tensor. Callers must handle that result, typically with the ? operator, before using the value:
// Using operator overloading with error handling
let c = (&a + &b)?; // Note the ? operator to handle potential errors
// Chaining operations
let result = (&a + &b)?.relu()?;
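The same pattern can be tried in isolation with a small stand-in type. In the sketch below, Vector, the ShapeMismatch variant, and sum_then_double are hypothetical; only the shape of the Add and Mul implementations (references on the left, Result as Output) follows the library code shown above.

```rust
use std::ops::{Add, Mul};

#[derive(Debug, PartialEq)]
pub enum Error {
    // Hypothetical variant used when operand lengths differ.
    ShapeMismatch { left: usize, right: usize },
}

// Hypothetical stand-in for Tensor<B>: a flat vector of values.
#[derive(Debug, Clone, PartialEq)]
pub struct Vector(pub Vec<f32>);

// &Vector + &Vector, returning a Result instead of a bare value.
impl<'a, 'b> Add<&'b Vector> for &'a Vector {
    type Output = Result<Vector, Error>;

    fn add(self, other: &'b Vector) -> Self::Output {
        if self.0.len() != other.0.len() {
            return Err(Error::ShapeMismatch {
                left: self.0.len(),
                right: other.0.len(),
            });
        }
        Ok(Vector(self.0.iter().zip(&other.0).map(|(x, y)| x + y).collect()))
    }
}

// &Vector * f32 (scalar multiplication).
impl<'a> Mul<f32> for &'a Vector {
    type Output = Result<Vector, Error>;

    fn mul(self, scalar: f32) -> Self::Output {
        Ok(Vector(self.0.iter().map(|x| x * scalar).collect()))
    }
}

// Computes (a + b) * 2, unwrapping the intermediate Result with `?`.
pub fn sum_then_double(a: &Vector, b: &Vector) -> Result<Vector, Error> {
    let sum = (a + b)?;
    &sum * 2.0
}
```

One consequence of this design is that operators cannot be chained directly (a + b + c does not type-check, because the first result is a Result); each step needs its own ? or an explicit match.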
Benefits of Operator Overloading
- Readability: Mathematical expressions look more natural
- Expressiveness: Complex operations can be written concisely
- Familiarity: Similar to Python's tensor libraries (PyTorch, NumPy)
Implementation Details
- All operators delegate to the corresponding functions in the ops module
- Operators are implemented for references to avoid unnecessary cloning
- Both tensor-tensor and tensor-scalar operations are supported
- Broadcasting is handled automatically by the underlying operations
Operations System
Operations are implemented in a backend-agnostic way:
- Public API: Functions in the ops module that work with any backend
- Backend-Specific Implementation: Each operation delegates to the backend
- Gradient Registration: Operations register their backward pass during forward computation
- Broadcasting: Automatic shape broadcasting for compatible operations
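The document does not spell out the broadcasting rules. Assuming they follow the familiar NumPy convention (shapes are aligned from the right, and a dimension of size 1 stretches to match the other operand), the result shape could be computed along these lines. broadcast_shape is a hypothetical helper, not a function from the library.

```rust
// Returns the broadcast result shape, or None if the shapes are incompatible.
// Assumes NumPy-style rules; the library's actual rules may differ.
pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = Vec::with_capacity(rank);
    for i in 0..rank {
        // Read dimensions from the right; missing leading dimensions count as 1.
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        let dim = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
        out.push(dim);
    }
    // Dimensions were collected right-to-left, so restore the usual order.
    out.reverse();
    Some(out)
}
```

Under these assumed rules, shapes [3, 1] and [4] combine to [3, 4], while [2, 3] and [4] are rejected.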
Operation Categories
- Element-wise Operations: Add, Subtract, Multiply, Divide, etc.
- Matrix Operations: MatMul, Transpose
- Activation Functions: ReLU, Sigmoid, Tanh, etc.
- Reduction Operations: Sum, Mean, Max, etc.
- Shape Operations: View, Expand, Squeeze, etc.
- Neural Network Operations: Conv2D, MaxPool2D, etc.
Optimizers
The library provides several optimizers for neural network training:
- SGD: Basic stochastic gradient descent
- MomentumSGD: SGD with momentum
- Adam: Adaptive moment estimation
- Adagrad: Adaptive gradient algorithm
All optimizers follow a common interface:
pub trait Optimizer<B: Backend> {
    fn step(&mut self) -> Result<(), Error>;
    fn zero_grad(&mut self) -> Result<(), Error>;
}
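Since step and zero_grad take no arguments, an optimizer presumably holds handles to the parameters it updates. The sketch below shows one way plain SGD could satisfy such an interface. It drops the backend type parameter, and Param, SharedParam, shared_param, Sgd, and the MissingGrad variant are hypothetical names chosen for illustration.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq)]
pub enum Error {
    // Hypothetical variant: step() was called on a parameter without a gradient.
    MissingGrad,
}

// Hypothetical parameter: values plus an optional gradient of the same length.
pub struct Param {
    pub value: Vec<f32>,
    pub grad: Option<Vec<f32>>,
}

pub type SharedParam = Rc<RefCell<Param>>;

pub fn shared_param(value: Vec<f32>, grad: Option<Vec<f32>>) -> SharedParam {
    Rc::new(RefCell::new(Param { value, grad }))
}

// Same two methods as the Optimizer trait above, without the backend parameter.
pub trait Optimizer {
    fn step(&mut self) -> Result<(), Error>;
    fn zero_grad(&mut self) -> Result<(), Error>;
}

pub struct Sgd {
    params: Vec<SharedParam>,
    lr: f32,
}

impl Sgd {
    pub fn new(params: Vec<SharedParam>, lr: f32) -> Self {
        Sgd { params, lr }
    }
}

impl Optimizer for Sgd {
    // Applies value <- value - lr * grad to every parameter.
    fn step(&mut self) -> Result<(), Error> {
        for param in &self.params {
            let mut guard = param.borrow_mut();
            let Param { value, grad } = &mut *guard;
            let grad = grad.as_ref().ok_or(Error::MissingGrad)?;
            for (v, g) in value.iter_mut().zip(grad) {
                *v -= self.lr * g;
            }
        }
        Ok(())
    }

    // Clears gradients so the next backward pass starts from scratch.
    fn zero_grad(&mut self) -> Result<(), Error> {
        for param in &self.params {
            param.borrow_mut().grad = None;
        }
        Ok(())
    }
}
```

A training loop built on this sketch would call zero_grad, run the forward and backward passes to fill in the gradients, and then call step.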
Feature Flags
The library uses feature flags to control optional functionality:
- cuda: Enables CUDA GPU support
- serialization: Enables tensor serialization
- mnist: Enables MNIST dataset loading utilities
- debug_logs: Enables detailed diagnostic logging
- cpu_openblas: Enables OpenBLAS acceleration for CPU operations
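How these flags gate code inside the crate is not shown here. A common Cargo pattern, assumed in the sketch below, is to select items with the cfg attribute and to query a flag with the cfg! macro. The feature names cuda and debug_logs come from the list above; default_device and debug_logs_enabled are hypothetical functions.

```rust
// Compiled only when the crate is built with the `cuda` feature enabled.
#[cfg(feature = "cuda")]
pub fn default_device() -> &'static str {
    "cuda"
}

// Fallback compiled when the `cuda` feature is disabled.
#[cfg(not(feature = "cuda"))]
pub fn default_device() -> &'static str {
    "cpu"
}

// cfg! expands to a compile-time boolean, so both branches still type-check.
pub fn debug_logs_enabled() -> bool {
    cfg!(feature = "debug_logs")
}
```

Built without any features, default_device in this sketch resolves to the CPU variant. With Cargo, a feature is typically switched on from the command line (cargo build --features cuda) or in the dependent crate's Cargo.toml.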
Design Decisions
1. Generic Backend System
The library uses a trait-based approach for backends, allowing:
- Code reuse between backends
- Easy addition of new backends
- Backend-agnostic user code
2. Dynamic Computation Graph
Unlike static frameworks (e.g., TensorFlow 1.x), the library builds the computation graph dynamically:
- More flexible for research and experimentation
- Easier debugging
- More Pythonic/intuitive API
3. Reference Counting and Interior Mutability
The library uses Rc<RefCell<TensorData>> for tensor data:
- Allows multiple tensors to share the same data
- Enables in-place operations when possible
- Supports efficient view operations without copying data
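The sharing behavior described here can be sketched with a view type over a reference-counted buffer. View, from_vec, slice, fill, to_vec, and shares_storage_with are hypothetical names, and the offset-and-length window is an assumption made for illustration; the library's actual view representation is not described in this document.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// A window (offset, len) onto a buffer that may be shared with other views.
pub struct View {
    storage: Rc<RefCell<Vec<f32>>>,
    offset: usize,
    len: usize,
}

impl View {
    pub fn from_vec(values: Vec<f32>) -> Self {
        let len = values.len();
        View { storage: Rc::new(RefCell::new(values)), offset: 0, len }
    }

    // Returns a sub-view over the same storage, or None if it is out of range.
    pub fn slice(&self, offset: usize, len: usize) -> Option<View> {
        let end = offset.checked_add(len)?;
        if end > self.len {
            return None;
        }
        Some(View {
            storage: Rc::clone(&self.storage),
            offset: self.offset + offset,
            len,
        })
    }

    // In-place write through the shared storage; no data is copied.
    pub fn fill(&self, value: f32) {
        let mut storage = self.storage.borrow_mut();
        for x in &mut storage[self.offset..self.offset + self.len] {
            *x = value;
        }
    }

    // Copies this view's window out of the shared storage.
    pub fn to_vec(&self) -> Vec<f32> {
        self.storage.borrow()[self.offset..self.offset + self.len].to_vec()
    }

    pub fn shares_storage_with(&self, other: &View) -> bool {
        Rc::ptr_eq(&self.storage, &other.storage)
    }
}
```

Writing through a sub-view is visible through the original view because both point at the same buffer; only to_vec makes a copy.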
4. Safety Considerations
- Internal mutation is controlled via RefCell
- Graph consistency is maintained via careful operation tracking
- Mutable operations are marked with warning comments
- Error handling uses Result throughout the codebase
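RefCell moves Rust's aliasing rule (one mutable borrow or many shared borrows) from compile time to run time. The short sketch below, using only the standard library, shows a conflicting mutable borrow being detected with try_borrow_mut rather than causing a panic. borrow_check_demo is a hypothetical function written for this illustration.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Returns (conflict detected while a mutable borrow is live,
//          borrow succeeds once that borrow has been released).
pub fn borrow_check_demo() -> (bool, bool) {
    let shared = Rc::new(RefCell::new(vec![1.0_f32, 2.0]));
    let alias = Rc::clone(&shared);

    // First mutable borrow is held by `guard`.
    let guard = shared.borrow_mut();
    // A second mutable borrow through the alias is refused at run time.
    let conflict = alias.try_borrow_mut().is_err();
    drop(guard);

    // After the guard is dropped, borrowing works again.
    let ok_after_release = alias.try_borrow_mut().is_ok();
    (conflict, ok_after_release)
}
```

Calling borrow_mut instead of try_borrow_mut in the conflicting position would panic, which is why code built on RefCell usually keeps borrows short-lived.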
5. Performance Optimizations
- Custom CUDA kernels for critical operations
- cuBLAS integration for matrix operations
- Optional OpenBLAS integration for CPU
- Efficient memory management and reuse
Implementation Details
Memory Management
- CPU Backend: Uses Rust's memory management via ndarray
- CUDA Backend: Custom CUDA memory allocation and deallocation
- View Operations: Create new tensors that share underlying storage
- In-place Operations: Modify tensor data in-place when possible
Error Handling
- Result Type: Operations return Result<T, Error> for robust error handling
- Error Types: Specific error variants for different failure modes
- Error Propagation: Uses the ? operator throughout the codebase
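As a sketch of how these three points fit together, the example below defines an error enum with distinct variants and propagates failures with the ? operator. The variant names (ShapeMismatch, EmptyTensor) and the helper functions add, mean, and mean_of_sum are hypothetical; the library's real Error type is not shown in this document.

```rust
use std::fmt;

// Hypothetical error type with one variant per failure mode.
#[derive(Debug, PartialEq)]
pub enum Error {
    ShapeMismatch { expected: usize, actual: usize },
    EmptyTensor,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::ShapeMismatch { expected, actual } => {
                write!(f, "shape mismatch: expected {} elements, got {}", expected, actual)
            }
            Error::EmptyTensor => write!(f, "tensor has no elements"),
        }
    }
}

impl std::error::Error for Error {}

// Element-wise addition that fails if the lengths differ.
fn add(a: &[f32], b: &[f32]) -> Result<Vec<f32>, Error> {
    if a.len() != b.len() {
        return Err(Error::ShapeMismatch { expected: a.len(), actual: b.len() });
    }
    Ok(a.iter().zip(b).map(|(x, y)| x + y).collect())
}

// Mean that fails on empty input instead of dividing by zero.
fn mean(values: &[f32]) -> Result<f32, Error> {
    if values.is_empty() {
        return Err(Error::EmptyTensor);
    }
    Ok(values.iter().sum::<f32>() / values.len() as f32)
}

// `?` returns early with the first error; otherwise evaluation continues.
pub fn mean_of_sum(a: &[f32], b: &[f32]) -> Result<f32, Error> {
    let sum = add(a, b)?;
    mean(&sum)
}
```

Distinct variants let callers match on the failure mode, while the Display implementation gives a readable message when the error is simply reported.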
Thread Safety
- CPU Backend: Thread-safe via Rust's ownership system
- CUDA Backend: Thread-safe via CUDA context management
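- Note: The library is not designed for concurrent modification of the same tensor from multiple threads. With the Rc<RefCell<...>> layout shown earlier, a Tensor handle is neither Send nor Sync, so the compiler rejects attempts to move or share one across threads.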
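Because the Tensor struct shown earlier wraps an Rc, which the standard library marks as not Send, a tensor handle cannot be passed to std::thread::spawn. One possible workaround, sketched below, is to hand another thread a plain host copy such as the Vec<f32> returned by copy_to_host. sum_on_worker is a hypothetical function; the library may offer other mechanisms that this document does not describe.

```rust
use std::thread;

// Takes an owned host copy (Vec<f32> is Send) and reduces it on a worker thread.
pub fn sum_on_worker(host_copy: Vec<f32>) -> f32 {
    let worker = thread::spawn(move || host_copy.iter().sum::<f32>());
    // join() returns Err only if the worker panicked.
    worker.join().expect("worker thread panicked")
}
```

The cost of this approach is the copy itself, so it suits occasional hand-offs (logging, metrics, checkpoints) better than per-step work.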