Verify correctness in eval mode (because we have dropout)

May 7, 2019 · View on GitHub

import sys
from collections import OrderedDict

# True when running under Python 2, where a function looked up on a class is
# an unbound method and the raw function has to be unwrapped before it can be
# called with an arbitrary object as ``self``.
PY2 = sys.version_info[0] == 2

# Names of instance attributes that hold a module's bookkeeping state.
# ``_make_functional`` does not copy these onto the stand-in object; the
# stand-in gets its own ``_modules`` and receives parameters on every call.
_internal_attrs = {
    '_backend',
    '_parameters',
    '_buffers',
    '_backward_hooks',
    '_forward_hooks',
    '_forward_pre_hooks',
    '_modules',
}

class Scope(object):
    """Plain attribute container that stands in for a module as ``self``.

    It only provides an ordered ``_modules`` mapping; every other attribute
    is attached by whoever builds the instance.
    """

    def __init__(self):
        # Ordered so that children keep the order in which they are added.
        self._modules = OrderedDict()

def _make_functional(module, params_box, params_offset):
    """Build a callable that runs ``module``'s forward with external parameters.

    A ``Scope`` object is populated with a copy of the module's instance
    attributes (except those named in ``_internal_attrs``) and with functional
    versions of its children. The module class's ``forward`` is then invoked
    with that ``Scope`` as ``self``.

    Because only instance attributes are copied, and they are copied once at
    construction time, later changes to the module are not reflected, and
    methods defined on the module's class are not available on the ``Scope``.

    Args:
        module: Object exposing ``_parameters``, ``__dict__`` and
            ``named_children()``.
        params_box: One-element list; ``params_box[0]`` is read on every call
            and must be a sliceable sequence of parameters at that point.
        params_offset: Index in ``params_box[0]`` of this module's first own
            parameter.

    Returns:
        A tuple ``(next_offset, fmodule)``: ``next_offset`` is the index just
        past the last parameter consumed by this module and its descendants,
        and ``fmodule(*args, **kwargs)`` runs the forward pass.
    """
    self = Scope()
    num_params = len(module._parameters)
    param_names = list(module._parameters.keys())
    # On Python 2 the class attribute is an unbound method; unwrap it so it
    # can be called with a Scope instead of a real module instance.
    forward = type(module).forward.__func__ if PY2 else type(module).forward
    for name, attr in module.__dict__.items():
        if name in _internal_attrs:
            continue
        setattr(self, name, attr)

    # Children take their parameters right after this module's own ones,
    # each child continuing where the previous one stopped.
    child_params_offset = params_offset + num_params
    for name, child in module.named_children():
        child_params_offset, fchild = _make_functional(
            child, params_box, child_params_offset)
        self._modules[name] = fchild
        setattr(self, name, fchild)

    def fmodule(*args, **kwargs):
        # Bind this module's slice of the current parameter list by name
        # before delegating to the original forward implementation.
        own_params = params_box[0][params_offset:params_offset + num_params]
        for name, param in zip(param_names, own_params):
            setattr(self, name, param)
        return forward(self, *args, **kwargs)

    return child_params_offset, fmodule

def make_functional(module):
    """Return a function that evaluates ``module`` with caller-supplied parameters.

    Args:
        module: The module to convert; it is handed to ``_make_functional``.

    Returns:
        A callable ``fmodule(*args, params=..., **kwargs)``. ``params`` is a
        required keyword argument: an iterable of parameters ordered as the
        module's own parameters followed by those of each child, recursively.
        All other arguments are forwarded to the module's forward pass.

    Raises:
        KeyError: When the returned callable is invoked without ``params``.
    """
    # Shared mutable cell: the per-module closures read the current parameter
    # list from it each time they are called.
    params_box = [None]
    _, fmodule_internal = _make_functional(module, params_box, 0)

    def fmodule(*args, **kwargs):
        # The parameters are sliced by offset further down, so materialise
        # them first; this lets callers pass any iterable, including the
        # generator returned by ``module.parameters()``.
        params_box[0] = list(kwargs.pop('params'))
        return fmodule_internal(*args, **kwargs)

    return fmodule

################################################################################

import torch
from torch import nn
from torch.nn import functional as F

class Net(nn.Module):
    """Small convolutional network used as the demo model.

    Two conv / max-pool / ReLU stages followed by two linear layers, with
    dropout after the convolutional stack and after the first linear layer.
    The flatten step hard-codes 320 features per sample, which is what a
    1x28x28 input produces (20 channels of 4x4 after the second pooling).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Dropout2d())
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return log-softmax scores of shape ``(N, 10)`` for input ``x``."""
        x = self.layers(x)
        # Flatten to 320 features per row for the linear layers.
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Functional dropout has to be told the mode explicitly.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net()

# The functional wrapper copies the module's instance attributes when it is
# built, so one wrapper is created per mode.
model.eval()
eval_fmodel = make_functional(model)
model.train()
train_fmodel = make_functional(model)

# Verify correctness in eval mode (because we have dropout)
model.eval()
params = list(model.parameters())
x = torch.randn(10, 1, 28, 28)
print(model(x).sum())
# Compare against the wrapper built in eval mode; the two sums should match.
print(eval_fmodel(x, params=params).sum())