-- cifar10_make_original.lua
-- October 25, 2016 · View on GitHub
require 'xlua' require 'sys'
-- Folder holding the CIFAR-10 batches written with torch.save in ASCII mode:
-- data_batch_1.t7 .. data_batch_<NUM_TRAIN_BATCHES>.t7 plus test_batch.t7.
local batches_folder = '/opt/rocks/cifar.torch/cifar-10-batches-t7'
local NUM_TRAIN_BATCHES = 5

-- Load every training batch, collecting one image tensor and one label
-- tensor per batch; they are concatenated after the loop.
local data = {}
local labels = {}
for i = 1, NUM_TRAIN_BATCHES do
  -- Only the bare file name is built here so the folder is joined exactly once.
  local name = 'data_batch_' .. i .. '.t7'
  local part = torch.load(paths.concat(batches_folder, name), 'ascii')
  -- Reshape to (3, 32, 32, n) with the sample index last.
  -- NOTE(review): assumes part.data stores one 3*32*32 image per column
  -- (i.e. 3072 x n) and that the leading dims are channel/row/column —
  -- confirm against the converter that produced these files.
  data[#data + 1] = part.data:view(3, 32, 32, -1)
  labels[#labels + 1] = part.labels:squeeze()
end

-- Stack the batches along dim 4 (the sample dimension of the views above).
-- Labels are concatenated along their default (last) dimension; presumably
-- they are 1-D after squeeze — verify if the batch format changes.
data = torch.ByteTensor.cat(data, 4)
labels = torch.ByteTensor.cat(labels)
-- Load the held-out test batch from the same folder.
-- These are declared local so they do not leak into the global environment,
-- and so the testData.size closure built below captures test_labels as an
-- upvalue instead of resolving a global at call time (a global that may not
-- exist when the saved dataset is later reloaded and size() is called).
local test_part = torch.load(paths.concat(batches_folder, 'test_batch.t7'), 'ascii')
local test_labels = test_part.labels
local test_data = test_part.data
--- Reorder a sample-last (3, 32, 32, N) tensor into a sample-first
-- (N, 3, 32, 32) layout and return a contiguous copy of it.
local function to_sample_first(images)
  return images:permute(4, 1, 2, 3):clone()
end

-- Shift every label by +1. Tensor:add mutates the receiver and returns it, so
-- `labels` and the squeezed view of `test_labels` are changed in place.
-- Presumably this maps 0-based class ids to Torch's 1-based convention.
local train_labels = labels:add(1)
local test_labels_shifted = test_labels:squeeze():add(1)

local dataset = {
  trainData = {
    data = to_sample_first(data),
    labels = train_labels,
    size = function() return labels:numel() end,
  },
  testData = {
    data = to_sample_first(test_data:view(3, 32, 32, -1)),
    labels = test_labels_shifted,
    size = function() return test_labels:numel() end,
  },
}

-- Sanity output: dump the table, then the largest label of each split
-- (train first, then test), before writing the dataset to disk.
print(dataset)
for _, split in ipairs({ 'trainData', 'testData' }) do
  print(dataset[split].labels:max())
end
torch.save('cifar10_original.t7', dataset)