Wide_ResNet_Lasagne.py

May 24, 2016 · View on GitHub

import lasagne
from lasagne.nonlinearities import rectify, softmax
from lasagne.layers import InputLayer, DenseLayer, DropoutLayer, batch_norm, BatchNormLayer
from lasagne.layers import ElemwiseSumLayer, NonlinearityLayer, GlobalPoolLayer
from lasagne.layers.dnn import Conv2DDNNLayer as ConvLayer
from lasagne.init import HeNormal

def ResNet_FullPre_Wide(input_var=None, n=3, k=2, dropout_p=0.3,
                        num_classes=10, input_shape=(None, 3, 32, 32)):
    """Build a wide, full pre-activation residual network in Lasagne.

    Adapted from
    https://github.com/Lasagne/Recipes/tree/master/papers/deep_residual_learning.

    Tweaked to be consistent with 'Identity Mappings in Deep Residual
    Networks', Kaiming He et al. 2016 (https://arxiv.org/abs/1603.05027)

    And 'Wide Residual Networks', Sergey Zagoruyko, Nikos Komodakis 2016
    (http://arxiv.org/pdf/1605.07146v1.pdf)

    Layout: a 3x3 stem convolution with 16 filters (followed by batch
    norm and ReLU), three stacks of residual blocks with ``16*k``,
    ``32*k`` and ``64*k`` filters, then BN -> ReLU -> global pooling ->
    softmax dense layer. The first block of the second and third stacks
    uses stride (2, 2).

    Parameters
    ----------
    input_var
        Passed straight to ``InputLayer`` as ``input_var`` (``None`` by
        default).
    n
        Number of residual blocks per stack. The first block of every
        stack is always created, so values below 1 behave like 1.
    k
        Widening factor; multiplies the base widths 16, 32 and 64.
    dropout_p
        Dropout probability used between the two convolutions of every
        residual block.
    num_classes
        Number of units in the final softmax layer.
    input_shape
        Shape handed to ``InputLayer``.

    Returns
    -------
    The final ``DenseLayer`` (softmax output) of the network.
    """
    stem_filters = 16
    stack_filters = (16 * k, 32 * k, 64 * k)

    def residual_block(l, increase_dim=False, first=False, filters=16):
        """Stack BN-ReLU, two 3x3 convolutions with dropout, and a shortcut.

        ``increase_dim`` makes the first convolution and the shortcut use
        stride (2, 2). ``first`` skips the leading BN -> ReLU because the
        incoming layer has already been through batch norm and ReLU.
        """
        first_stride = (2, 2) if increase_dim else (1, 1)

        if first:
            # The stem convolution already ends in BN -> ReLU, so do not
            # apply a second pre-activation on top of it.
            bn_pre_relu = l
        else:
            # BN -> ReLU portion, steps 1 to 2
            bn_pre_relu = NonlinearityLayer(BatchNormLayer(l), rectify)

        # weight -> BN -> ReLU portion, steps 3 to 5
        conv_1 = batch_norm(ConvLayer(
            bn_pre_relu, num_filters=filters, filter_size=(3, 3),
            stride=first_stride, nonlinearity=rectify, pad='same',
            W=HeNormal(gain='relu')))

        dropout = DropoutLayer(conv_1, p=dropout_p)

        # last weight portion, step 6
        conv_2 = ConvLayer(
            dropout, num_filters=filters, filter_size=(3, 3), stride=(1, 1),
            nonlinearity=None, pad='same', W=HeNormal(gain='relu'))

        if increase_dim or first:
            # Projection shortcut (option B in the paper): a bias-free 1x1
            # convolution on the raw block input. Its stride matches the
            # first convolution so both branches line up for the sum.
            shortcut = ConvLayer(
                l, num_filters=filters, filter_size=(1, 1),
                stride=first_stride, nonlinearity=None, pad='same', b=None)
        else:
            # Identity shortcut.
            shortcut = l

        return ElemwiseSumLayer([conv_2, shortcut])

    # Building the network
    l_in = InputLayer(shape=input_shape, input_var=input_var)

    # first layer
    l = batch_norm(ConvLayer(
        l_in, num_filters=stem_filters, filter_size=(3, 3), stride=(1, 1),
        nonlinearity=rectify, pad='same', W=HeNormal(gain='relu')))

    # Three stacks of residual blocks. The first stack opens with a
    # `first` block (projection only, stride 1); the later stacks open
    # with an `increase_dim` block (projection, stride 2).
    for stack_index, filters in enumerate(stack_filters):
        if stack_index == 0:
            l = residual_block(l, first=True, filters=filters)
        else:
            l = residual_block(l, increase_dim=True, filters=filters)
        for _ in range(1, n):
            l = residual_block(l, filters=filters)

    # final BN -> ReLU after the last residual sum
    bn_post_conv = BatchNormLayer(l)
    bn_post_relu = NonlinearityLayer(bn_post_conv, rectify)

    # average pooling
    avg_pool = GlobalPoolLayer(bn_post_relu)

    # fully connected layer
    network = DenseLayer(avg_pool, num_units=num_classes, W=HeNormal(),
                         nonlinearity=softmax)

    return network