cifar10_make_original.lua

October 25, 2016 · View on GitHub

require 'xlua' require 'sys'

-- Directory holding the Torch-serialized CIFAR-10 batch files
-- (data_batch_<i>.t7 and test_batch.t7); every file is read in 'ascii' format.
local batches_folder = '/opt/rocks/cifar.torch/cifar-10-batches-t7'

-- Number of training batch files (data_batch_1.t7 .. data_batch_N.t7) to merge.
local num_train_batches = 5

-- Per-batch image / label tensors, concatenated into single tensors after the loop.
local data = {}
local labels = {}

for i = 1, num_train_batches do
  -- paths.concat already produces the full path to the batch file, so it is
  -- handed to torch.load directly instead of being joined with
  -- batches_folder a second time.
  local batch_path = paths.concat(batches_folder, 'data_batch_' .. i .. '.t7')
  local part = torch.load(batch_path, 'ascii')
  -- View each batch's images as 3 x 32 x 32 x N so batches stack along dim 4.
  -- NOTE(review): assumes part.data holds 3*32*32 values per image in an
  -- order compatible with this view — confirm against the batch producer.
  table.insert(data, part.data:view(3, 32, 32, -1))
  table.insert(labels, part.labels:squeeze())
end

-- Merge all batches: images along the sample dimension (4), labels into one tensor.
data = torch.ByteTensor.cat(data, 4)
labels = torch.ByteTensor.cat(labels)

-- Load the held-out test batch (same 'ascii' serialization as the training
-- batches). Declared local so the script does not leak test_part /
-- test_labels / test_data into the global environment.
local test_part = torch.load(paths.concat(batches_folder, 'test_batch.t7'), 'ascii')
local test_labels = test_part.labels
local test_data = test_part.data

-- Assemble the final dataset table with trainData / testData splits.
-- Images are reordered from 3 x 32 x 32 x N to N x 3 x 32 x 32 and copied
-- with clone(); labels are shifted by +1 in place (presumably 0..9 -> 1..10
-- for 1-based class indices — confirm). Each split's size() reports the
-- current label count when called.
local dataset
do
  -- In-place shift: `labels` itself becomes trainData.labels.
  labels:add(1)
  local train_images = data:permute(4, 1, 2, 3):clone()

  local test_images = test_data:view(3, 32, 32, -1):permute(4, 1, 2, 3):clone()
  local test_targets = test_labels:squeeze():add(1)

  dataset = {
    trainData = {
      data = train_images,
      labels = labels,
      size = function()
        return labels:numel()
      end,
    },
    testData = {
      data = test_images,
      labels = test_targets,
      size = function()
        return test_labels:numel()
      end,
    },
  }
end

-- Print the assembled table, then the largest label of each split
-- (presumably a sanity check on the +1 shift), and serialize the dataset
-- to cifar10_original.t7 with torch.save's default format.
print(dataset)
for _, split_name in ipairs({ 'trainData', 'testData' }) do
  print(dataset[split_name].labels:max())
end
torch.save('cifar10_original.t7', dataset)