PyTorch to MXNet

July 31, 2018

This cheatsheet serves as a quick reference for PyTorch users.

PyTorch Tensor and MXNet NDArray

Tensor operation

The following lists PyTorch tensor functions whose names differ from their MXNet NDArray equivalents.

Function | PyTorch | MXNet Gluon
Element-wise inverse cosine | x.acos() or torch.acos(x) | nd.arccos(x)
Batch matrix product and accumulation | torch.addbmm(M, batch1, batch2) | nd.linalg_gemm(batch1, batch2, M) (leading n-2 dimensions are reduced)
Element-wise division of t1 by t2, multiplied by v and added to t | torch.addcdiv(t, v, t1, t2) | t + v*(t1/t2)
Matrix product and accumulation | torch.addmm(M, mat1, mat2) | nd.linalg_gemm(mat1, mat2, M)
Outer product of two vectors added to a matrix | m.addr(vec1, vec2) | Not available
Applies a function element-wise | x.apply_(callable) | Not available, but there is nd.Custom(x, op_type='op')
Element-wise inverse sine | x.asin() or torch.asin(x) | nd.arcsin(x)
Element-wise inverse tangent | x.atan() or torch.atan(x) | nd.arctan(x)
Arctangent of two tensors | x.atan2(y) or torch.atan2(x, y) | Not available
Batch matrix product | x.bmm(y) or torch.bmm(x, y) | nd.linalg_gemm2(x, y)
Draws a sample from a Bernoulli distribution | x.bernoulli() | Not available
Fills a tensor with numbers drawn from a Cauchy distribution | x.cauchy_() | Not available
Splits a tensor along a given dimension | x.chunk(num_of_chunk) | nd.split(x, num_outputs=num_of_chunk)
Limits the values of a tensor to between min and max | x.clamp(min, max) | nd.clip(x, min, max)
Returns a copy of the tensor | x.clone() | x.copy()
Cross product | x.cross(y) | Not available
Cumulative product along an axis | x.cumprod(1) | Not available
Cumulative sum along an axis | x.cumsum(1) | Not available
Address of the first element | x.data_ptr() | Not available
Creates a diagonal tensor | x.diag() | Not available
Computes the norm of a tensor | x.dist() | nd.norm(x) (only calculates the L2 norm)
Computes the Gauss error function | x.erf() | Not available
Broadcasts/expands a tensor to a new shape | x.expand(3,4) | x.broadcast_to([3, 4])
Fills a tensor with samples drawn from an exponential distribution | x.exponential_() | nd.random_exponential()
Element-wise mod | x.fmod(3) | nd.modulo(x, 3)
Fractional portion of a tensor | x.frac() | x - nd.trunc(x)
Gathers values along an axis specified by dim | torch.gather(x, 1, torch.LongTensor([[0,0],[1,0]])) | nd.gather_nd(x, nd.array([[[0,0],[1,1]],[[0,0],[1,0]]]))
Solves least squares and least norm problems | B.gels(A) | Not available
Draws from a geometric distribution | x.geometric_(p) | Not available
Device context of a tensor | print(x) will print which device x is on | x.context
Repeats a tensor | x.repeat(4,2) | x.tile((4,2))
Data type of a tensor | x.type() | x.dtype
Scatter | torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23) | nd.scatter_nd(nd.array([1.23,1.23]), nd.array([[0,1],[2,3]]), (2,4))
Returns the shape of a tensor | x.size() | x.shape
Number of elements in a tensor | x.numel() | x.size
Returns this tensor as a NumPy ndarray | x.numpy() | x.asnumpy()
Eigendecomposition for a symmetric matrix | e, v = a.symeig() | v, e = nd.linalg.syevd(a)
Transpose | x.t() | x.T
Sample uniformly | x.uniform_() | nd.sample_uniform()
Inserts a new dimension | x.unsqueeze() | nd.expand_dims(x)
Reshape | x.view(16) | x.reshape((16,))
View as a specified tensor | x.view_as(y) | x.reshape_like(y)
Returns a copy of the tensor after casting to a specified type | x.type(type) | x.astype(dtype)
Copies the value of one tensor to another | dst.copy_(src) | src.copyto(dst)
Returns a zero tensor with the specified shape | x = torch.zeros(2,3) | x = nd.zeros((2,3))
Returns a one tensor with the specified shape | x = torch.ones(2,3) | x = nd.ones((2,3))
Returns a tensor filled with the scalar value 1, with the same size as the input | y = torch.ones_like(x) | y = nd.ones_like(x)
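
The gather row is the least obvious mapping in the table, because torch.gather takes one index per output element along a single dimension while nd.gather_nd takes a full coordinate for each output element. The sketch below uses plain Python lists rather than either library to show how the index layouts relate for the 2-D, dim=1 case. The helper names (gather_dim1, to_gather_nd_indices, gather_nd_2d) are hypothetical and exist only for this illustration.

```python
def gather_dim1(x, index):
    """PyTorch-style gather along dim 1: out[i][j] = x[i][index[i][j]]."""
    return [[x[i][index[i][j]] for j in range(len(index[i]))]
            for i in range(len(index))]


def to_gather_nd_indices(index):
    """Build the (2, rows, cols) coordinate layout used by gather_nd.

    indices[0] holds the row coordinate of every output element and
    indices[1] holds the column coordinate.
    """
    rows = [[i for _ in row] for i, row in enumerate(index)]
    cols = [list(row) for row in index]
    return [rows, cols]


def gather_nd_2d(x, indices):
    """gather_nd-style lookup on 2-D data: out[i][j] = x[rows[i][j]][cols[i][j]]."""
    rows, cols = indices
    return [[x[r][c] for r, c in zip(row_r, row_c)]
            for row_r, row_c in zip(rows, cols)]


x = [[1, 2], [3, 4]]
index = [[0, 0], [1, 0]]            # the index used in the table row
nd_indices = to_gather_nd_indices(index)
```

With the index from the table, nd_indices comes out as the nested list passed to nd.gather_nd in the same row, and both lookups return the same values.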

Functional

GPU

Like a PyTorch Tensor, an MXNet NDArray can be copied to and operated on a GPU. This is done by specifying a context.

Function | PyTorch | MXNet Gluon
Copy to GPU | y = torch.FloatTensor(1).cuda() | y = mx.nd.ones((1,), ctx=mx.gpu(0))
Convert to numpy array | x = y.cpu().numpy() | x = y.asnumpy()
Context scope | with torch.cuda.device(1): y = torch.cuda.FloatTensor(1) | with mx.gpu(1): y = mx.nd.ones((3,5))

Cross-device

Just like Tensor, MXNet NDArray can be copied across multiple GPUs.

Function | PyTorch | MXNet Gluon
Copy from GPU 0 to GPU 1 | x = torch.cuda.FloatTensor(1); y = x.cuda(1) | x = mx.nd.ones((1,), ctx=mx.gpu(0)); y = x.as_in_context(mx.gpu(1))
Copy Tensor/NDArray on different GPUs | y.copy_(x) | x.copyto(y)

Autograd

Variable wrapper vs autograd scope

The autograd packages of PyTorch and MXNet enable automatic differentiation of Tensor and NDArray.

Function: Recording computation

PyTorch:

    x = Variable(torch.FloatTensor(1), requires_grad=True)
    y = x * 2
    y.backward()

MXNet Gluon:

    x = mx.nd.ones((1,))
    x.attach_grad()
    with mx.autograd.record():
        y = x * 2
    y.backward()

Scope override (pause, train_mode, predict_mode)

Some operators (Dropout, BatchNorm, etc.) behave differently during training and prediction. In MXNet this is controlled with the train_mode and predict_mode scopes. The pause scope is for code that does not need gradients to be calculated.

Function: Scope override

PyTorch: Not available

MXNet Gluon:

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.ones((1,))
    with autograd.train_mode():
        y = mx.nd.Dropout(x)
        with autograd.predict_mode():
            z = mx.nd.Dropout(y)

    w = mx.nd.ones((1,))
    w.attach_grad()
    with autograd.record():
        y = x * w
        y.backward()
        with autograd.pause():
            w += w.grad
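
To make the scope behavior concrete, here is a minimal plain-Python sketch of a mode flag that nested scopes can override and restore. It does not use MXNet: the functions train_mode, predict_mode and dropout below are hypothetical stand-ins written for this illustration, and the dropout uses the common inverted-dropout scaling as an assumption.

```python
import random
from contextlib import contextmanager

_training = False  # prediction mode by default, as assumed in this sketch


@contextmanager
def train_mode():
    """Switch to training mode inside the scope, then restore the old mode."""
    global _training
    previous, _training = _training, True
    try:
        yield
    finally:
        _training = previous


@contextmanager
def predict_mode():
    """Switch to prediction mode inside the scope, then restore the old mode."""
    global _training
    previous, _training = _training, False
    try:
        yield
    finally:
        _training = previous


def dropout(values, p=0.5, rng=random):
    """Identity in prediction mode; random zeroing with rescaling in training mode."""
    if not _training:
        return list(values)
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in values]


x = [1.0] * 8
with train_mode():
    y = dropout(x)          # some entries zeroed, the rest scaled to 2.0
    with predict_mode():
        z = dropout(x)      # inner scope overrides the outer one: unchanged
```

The inner scope wins while it is active, and leaving it restores whatever mode the outer scope had set.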

Batch-end synchronization is needed

MXNet uses lazy evaluation to achieve superior performance. The Python thread only pushes operations to the backend engine and then returns immediately. During training, a synchronization point is needed at the end of each batch, e.g. asnumpy(), wait_to_read(), or metric.update(...).

Function: Batch-end synchronization

PyTorch: Not available

MXNet Gluon:

    for (data, label) in train_data:
        with autograd.record():
            output = net(data)
            L = loss(output, label)
            L.backward()
        trainer.step(data.shape[0])
        metric.update([label], [output])
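
The push-then-synchronize pattern described above can be illustrated without MXNet. In this sketch a single-worker thread pool stands in for the backend engine, submit plays the role of pushing an operation, and result() plays the role of a blocking call such as wait_to_read() or asnumpy(). The names engine and slow_square are hypothetical.

```python
import time
from concurrent.futures import ThreadPoolExecutor

engine = ThreadPoolExecutor(max_workers=1)  # stand-in for the backend engine


def slow_square(v):
    time.sleep(0.05)  # pretend this is an expensive operator
    return v * v


start = time.perf_counter()
pending = engine.submit(slow_square, 7)    # returns right away, work runs in the background
push_time = time.perf_counter() - start

result = pending.result()                  # blocks until the work has finished
total_time = time.perf_counter() - start
engine.shutdown()
```

Without the blocking call, timing the loop would only measure how quickly operations are queued, not how long they take to run.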

PyTorch module and Gluon blocks

For new block definition, Gluon needs name_scope

name_scope makes Gluon give each parameter an appropriate name that indicates which model it belongs to.

Function: New block definition

PyTorch:

    class Net(torch.nn.Module):
        def __init__(self, D_in, D_out):
            super(Net, self).__init__()
            self.linear = torch.nn.Linear(D_in, D_out)
        def forward(self, x):
            return self.linear(x)

MXNet Gluon:

    class Net(mx.gluon.Block):
        def __init__(self, D_in, D_out):
            super(Net, self).__init__()
            with self.name_scope():
                self.dense = mx.gluon.nn.Dense(D_out, in_units=D_in)
        def forward(self, x):
            return self.dense(x)

Parameter and Initializer

When creating new layers in PyTorch, you do not need to specify a parameter initializer; different layers have different default initializers. When you create new layers in Gluon, you can specify an initializer or leave it unset. Parameters are initialized when net.initialize(<init method>) is called: every parameter is initialized with that init method, except in layers that specify their own initializer.

Function | PyTorch | MXNet Gluon
Get all parameters | net.parameters() | net.collect_params()
Initialize network | Not available | net.initialize(mx.init.Xavier())
Specify layer initializer | layer = torch.nn.Linear(20, 10); torch.nn.init.normal_(layer.weight, 0, 0.01) | layer = mx.gluon.nn.Dense(10, weight_initializer=mx.init.Normal(0.01))
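
The precedence rule described above (a layer-specific initializer wins over the one passed to net.initialize) can be sketched in plain Python. This is an illustration of the rule only, not Gluon code: InitParam, constant, normal and initialize are hypothetical names.

```python
import random


class InitParam:
    """A parameter that stays unallocated until initialize() runs."""

    def __init__(self, name, size, init=None):
        self.name, self.size, self.init = name, size, init
        self.data = None


def constant(value):
    """Initializer that fills a parameter with one value."""
    return lambda size: [value] * size


def normal(sigma, seed=0):
    """Initializer that draws from a zero-mean Gaussian."""
    rng = random.Random(seed)
    return lambda size: [rng.gauss(0.0, sigma) for _ in range(size)]


def initialize(params, default_init):
    """Use each parameter's own initializer if it has one, else the default."""
    for p in params:
        chosen = p.init if p.init is not None else default_init
        p.data = chosen(p.size)


params = [
    InitParam("dense0_weight", 4, init=normal(0.01)),  # layer-specific initializer
    InitParam("dense1_weight", 3),                     # falls back to the default
]
initialize(params, default_init=constant(0.5))
```

Only dense1_weight receives the default values; dense0_weight keeps the Gaussian initializer it was created with.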

Usage of existing blocks looks alike

Function | PyTorch | MXNet Gluon
Usage of existing blocks | y = net(x) | y = net(x)

HybridBlock can be hybridized, and allows partial-shape info

HybridBlock supports forwarding with both Symbol and NDArray. Once hybridized, a HybridBlock creates a symbolic graph representing the forward computation and caches it. Most of the built-in blocks (Dense, Conv2D, MaxPool2D, BatchNorm, etc.) are HybridBlocks.

Instead of explicitly declaring the number of inputs to a layer, we can simply state the number of outputs. The shape will be inferred on the fly once the network is provided with some input.

Function: Partial-shape, hybridized

PyTorch: Not available

MXNet Gluon:

    net = mx.gluon.nn.HybridSequential()
    with net.name_scope():
        net.add(mx.gluon.nn.Dense(10))
    net.hybridize()
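
Inferring the input size on the fly amounts to delaying weight allocation until the first forward call. The plain-Python sketch below shows the idea under that assumption; LazyDense is a hypothetical class for illustration, not Gluon's Dense, and it fills the weights with a fixed 0.1 so the output is easy to check.

```python
class LazyDense:
    """Fully connected layer that learns its input size from the first input."""

    def __init__(self, units, in_units=None):
        self.units = units
        self.in_units = in_units   # may be left unknown at construction time
        self.weight = None         # allocated lazily in forward()

    def forward(self, x):
        if self.weight is None:
            if self.in_units is None:
                self.in_units = len(x)          # infer the missing dimension
            self.weight = [[0.1] * self.in_units for _ in range(self.units)]
        return [sum(w * v for w, v in zip(row, x)) for row in self.weight]


layer = LazyDense(10)                 # only the number of outputs is given
out = layer.forward([1.0, 2.0, 3.0])  # input size 3 is discovered here
```

After the first call the layer behaves like one that was declared with in_units=3 from the start.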

SymbolBlock

SymbolBlock can construct a block from a symbol. This is useful for using pre-trained models as feature extractors.

Function: SymbolBlock

PyTorch: Not available

MXNet Gluon:

    alexnet = mx.gluon.model_zoo.vision.alexnet(pretrained=True, prefix='model_')
    inputs = mx.sym.var('data')
    out = alexnet(inputs)
    internals = out.get_internals()
    outputs = [internals['model_dense0_relu_fwd_output']]
    feat_model = mx.gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params())

PyTorch optimizer vs Gluon Trainer

For Gluon, zero_grad is not necessary most of the time

zero_grad in the optimizer (PyTorch) or Trainer (Gluon) clears the gradients of all parameters. In Gluon, there is no need to clear the gradients every batch if grad_req = 'write' (the default).

Function: Clear the gradients

PyTorch:

    optm = torch.optim.SGD(model.parameters(), lr=0.1)
    optm.zero_grad()
    loss_fn(model(input), target).backward()
    optm.step()

MXNet Gluon:

    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
    with autograd.record():
        loss = loss_fn(net(data), label)
    loss.backward()
    trainer.step(batch_size)
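
The reason zero_grad can be skipped comes down to whether a new gradient overwrites or accumulates into the stored one. The plain-Python sketch below contrasts the two policies; GradParam and receive_grad are hypothetical names for this illustration, while the 'write' and 'add' strings mirror the grad_req values mentioned above.

```python
class GradParam:
    """Holds one gradient value and applies a grad_req policy to updates."""

    def __init__(self, grad_req="write"):
        self.grad_req = grad_req
        self.grad = 0.0

    def receive_grad(self, g):
        """Called once per backward pass with the freshly computed gradient."""
        if self.grad_req == "write":
            self.grad = g            # overwrite: nothing stale to clear
        elif self.grad_req == "add":
            self.grad += g           # accumulate: must be cleared explicitly
        else:
            raise ValueError("unsupported grad_req: %r" % self.grad_req)

    def zero_grad(self):
        self.grad = 0.0


w_write = GradParam("write")
w_add = GradParam("add")
for g in (1.0, 2.0, 3.0):            # three backward passes, no zero_grad in between
    w_write.receive_grad(g)
    w_add.receive_grad(g)
```

With 'write' the parameter ends up holding only the last gradient, whereas with 'add' it holds the sum of all three until zero_grad is called.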

Multi-GPU training

Function: Data parallelism

PyTorch:

    net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
    output = net(data)

MXNet Gluon:

    ctx = [mx.gpu(i) for i in range(3)]
    data = gluon.utils.split_and_load(data, ctx)
    label = gluon.utils.split_and_load(label, ctx)
    with autograd.record():
        losses = [loss(net(X), Y) for X, Y in zip(data, label)]
    for l in losses:
        l.backward()
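
The splitting half of split_and_load can be pictured as slicing the batch along its first axis, one slice per device. The plain-Python sketch below shows that step on a list; split_batch is a hypothetical helper, and its handling of uneven batches (the last slice takes the remainder) is an assumption of this sketch rather than a statement about the library.

```python
def split_batch(batch, num_slice, even_split=True):
    """Split a batch into num_slice contiguous slices along the first axis."""
    size = len(batch)
    if size < num_slice:
        raise ValueError("batch is smaller than the number of slices")
    if even_split and size % num_slice != 0:
        raise ValueError("batch size %d is not divisible by %d" % (size, num_slice))
    step = size // num_slice
    slices = [batch[i * step:(i + 1) * step] for i in range(num_slice - 1)]
    slices.append(batch[(num_slice - 1) * step:])   # last slice takes any remainder
    return slices


even = split_batch(list(range(6)), 3)
uneven = split_batch(list(range(7)), 3, even_split=False)
```

Each slice would then be copied to its own device, and the per-slice losses are computed and back-propagated independently as in the Gluon column above.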

Distributed training

Function | PyTorch | MXNet Gluon
Distributed data parallelism | torch.distributed.init_process_group(...); model = torch.nn.parallel.DistributedDataParallel(model, ...) | store = mx.kv.create('dist'); trainer = gluon.Trainer(net.collect_params(), ..., kvstore=store)

Monitoring

MXNet has pre-defined metrics

Gluon provides several predefined metrics that can evaluate the performance of a learned model online.

Function: Metric

PyTorch: Not available

MXNet Gluon:

    metric = mx.metric.Accuracy()
    with autograd.record():
        output = net(data)
        L = loss(output, label)
        L.backward()
    trainer.step(batch_size)
    metric.update(label, output)
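
An online metric keeps running totals that are updated batch by batch and read out on demand. The plain-Python sketch below shows that pattern for accuracy; the RunningAccuracy class is a hypothetical illustration whose update/get/reset method names are chosen to resemble the metric used above, not a copy of the library implementation.

```python
class RunningAccuracy:
    """Accumulates classification accuracy across batches."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.num_correct = 0
        self.num_seen = 0

    def update(self, labels, scores):
        """labels: class ids; scores: one list of per-class scores per sample."""
        for label, row in zip(labels, scores):
            predicted = max(range(len(row)), key=row.__getitem__)  # argmax
            self.num_correct += int(predicted == label)
            self.num_seen += 1

    def get(self):
        """Return (name, value); value is NaN before any update."""
        value = self.num_correct / self.num_seen if self.num_seen else float("nan")
        return "accuracy", value


metric = RunningAccuracy()
metric.update([0, 1], [[0.9, 0.1], [0.2, 0.8]])   # both predictions correct
metric.update([1, 0], [[0.6, 0.4], [0.7, 0.3]])   # one of two correct
name, value = metric.get()
```

Because only two counters are stored, the metric can be updated after every batch without keeping past predictions in memory.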

Data visualization

tensorboardX (PyTorch) and dmlc-tensorboard (Gluon) can be used to visualize your network and plot quantitative metrics about the execution of your graph.

Function: Visualization

PyTorch:

    writer = tensorboardX.SummaryWriter()
    ...
    for name, param in model.named_parameters():
        grad = param.clone().cpu().data.numpy()
        writer.add_histogram(name, grad, n_iter)
    ...
    writer.close()

MXNet Gluon:

    summary_writer = tensorboard.FileWriter('./logs/')
    ...
    for name, param in net.collect_params().items():
        grad = param.grad().asnumpy().flatten()
        s = tensorboard.summary.histogram(name, grad)
        summary_writer.add_summary(s)
    ...
    summary_writer.close()

I/O and deploy

Data loading

Dataset and DataLoader are the basic components for loading data.

Class | PyTorch | MXNet Gluon
Dataset holding arrays | torch.utils.data.TensorDataset(data_tensor, label_tensor) | gluon.data.ArrayDataset(data_array, label_array)
Data loader | torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, drop_last=False) | gluon.data.DataLoader(dataset, batch_size=None, shuffle=False, sampler=None, last_batch='keep', batch_sampler=None, batchify_fn=None, num_workers=0)
Sequentially applied sampler | torch.utils.data.sampler.SequentialSampler(data_source) | gluon.data.SequentialSampler(length)
Random order sampler | torch.utils.data.sampler.RandomSampler(data_source) | gluon.data.RandomSampler(length)
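
A sampler decides the order in which dataset indices are visited, and the loader groups that order into batches. The plain-Python sketch below shows both steps; sequential_sampler, random_sampler and batch_indices are hypothetical helpers, and only the 'keep' and 'discard' policies for an incomplete last batch are modeled here as an assumption of the sketch.

```python
import random


def sequential_sampler(length):
    """Visit indices 0..length-1 in order."""
    return list(range(length))


def random_sampler(length, seed=None):
    """Visit every index exactly once, in shuffled order."""
    order = list(range(length))
    random.Random(seed).shuffle(order)
    return order


def batch_indices(order, batch_size, last_batch="keep"):
    """Group an index order into batches, keeping or dropping a short last batch."""
    if last_batch not in ("keep", "discard"):
        raise ValueError("this sketch only handles 'keep' and 'discard'")
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if last_batch == "discard" and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


kept = batch_indices(sequential_sampler(5), 2, last_batch="keep")
dropped = batch_indices(sequential_sampler(5), 2, last_batch="discard")
shuffled = random_sampler(5, seed=1)
```

Dropping the short batch here corresponds to the drop_last=True behavior in the PyTorch signature above, while keeping it corresponds to last_batch='keep' in the Gluon one.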

Some commonly used datasets for computer vision are provided in the mx.gluon.data.vision package.

Class | PyTorch | MXNet Gluon
MNIST handwritten digits dataset | torchvision.datasets.MNIST | mx.gluon.data.vision.MNIST
CIFAR10 dataset | torchvision.datasets.CIFAR10 | mx.gluon.data.vision.CIFAR10
CIFAR100 dataset | torchvision.datasets.CIFAR100 | mx.gluon.data.vision.CIFAR100
A generic data loader where the images are arranged in folders | torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>) | mx.gluon.data.vision.ImageFolderDataset(root, flag, transform=None)

Serialization

In Gluon, serialization and deserialization are achieved by calling save_parameters and load_parameters.

Class | PyTorch | MXNet Gluon
Save model parameters | torch.save(the_model.state_dict(), filename) | model.save_parameters(filename)
Load parameters | the_model.load_state_dict(torch.load(PATH)) | model.load_parameters(filename, ctx, allow_missing=False, ignore_extra=False)
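
The allow_missing and ignore_extra flags control what happens when the saved names and the model's names do not match exactly. The plain-Python sketch below models one reasonable reading of those flags with a dictionary of parameters stored as JSON; the two functions are hypothetical stand-ins that borrow the names from the table, and the file format is an assumption made only for this illustration.

```python
import json
import os
import tempfile


def save_parameters(params, filename):
    """Write a name -> values mapping to disk."""
    with open(filename, "w") as f:
        json.dump(params, f)


def load_parameters(params, filename, allow_missing=False, ignore_extra=False):
    """Load saved values into params, checking for name mismatches."""
    with open(filename) as f:
        loaded = json.load(f)
    missing = set(params) - set(loaded)   # in the model but not in the file
    extra = set(loaded) - set(params)     # in the file but not in the model
    if missing and not allow_missing:
        raise KeyError("missing from file: %s" % sorted(missing))
    if extra and not ignore_extra:
        raise KeyError("not present in model: %s" % sorted(extra))
    for name in params:
        if name in loaded:
            params[name] = loaded[name]


trained = {"dense0_weight": [0.1, 0.2], "dense0_bias": [0.0]}
fresh = {"dense0_weight": [0.0, 0.0], "dense0_bias": [1.0], "dense1_weight": [9.0]}
path = os.path.join(tempfile.mkdtemp(), "net.params")
save_parameters(trained, path)
load_parameters(fresh, path, allow_missing=True)   # dense1_weight keeps its value
```

With the strict defaults the same load would fail, which is useful for catching a mismatch between the checkpoint and the network definition early.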