Skip to content

Commit 1be6324

Browse files
author
Yinda Zhang
committed
create repository for releasing
0 parents  commit 1be6324

20 files changed

+720
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*.zip
2+
*.t7
3+
*.DS*

BatchIterator.lua

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
require 'image'
2+
require 'utils'
3+
4+
local BatchIterator = torch.class('BatchIterator')
5+
6+
function BatchIterator:__init(config, train_set, test_set)
7+
8+
self.batch_size = config.batch_size or 128
9+
self.pixel_means = config.pixel_means or {0, 0, 0}
10+
self.mr = config.mr
11+
12+
self.train = {}
13+
self.test = {}
14+
15+
self.train.data = train_set
16+
self.test.data = test_set
17+
if #train_set > 0 then
18+
self.train.order = torch.randperm(#self.train.data)
19+
else
20+
self.train.order = torch.Tensor(0);
21+
end
22+
-- self.test.order = torch.randperm(#self.test.data)
23+
self.test.order = torch.range(1,#self.test.data)
24+
self.train.id = 1
25+
self.test.id = 1
26+
27+
self.epoch = 0
28+
29+
end
30+
31+
function BatchIterator:setBatchSize(batch_size)
32+
self.batch_size = batch_size or 128
33+
end
34+
35+
function BatchIterator:nextEntry(set)
36+
local i = self[set].i or 1
37+
self[set].i = i
38+
if i > #self[set].data then
39+
if set == "train" then
40+
self[set].order = torch.randperm(#self[set].data)
41+
end
42+
i = 1
43+
self.epoch = self.epoch + 1
44+
end
45+
46+
local index = self[set].order[i]
47+
self[set].i = self[set].i + 1
48+
return self[set].data[index]
49+
end
50+
51+
function BatchIterator:currentName(set)
52+
local i = self[set].i
53+
local index = self[set].order[i-1]
54+
return self[set].data[index].name
55+
end
56+
57+
function BatchIterator:nextBatch(set, config)
58+
-- print(use_photo_realistic)
59+
-- local use_pr = use_photo_realistic or true
60+
-- print(use_photo_realistic)
61+
62+
local batch = {}
63+
batch.input = {}
64+
batch.output = {}
65+
batch.valid = {}
66+
67+
for i = 1, self.batch_size do
68+
local entry = self:nextEntry(set)
69+
70+
if set == "train" then
71+
72+
while not (file_exists(entry.input_file) and file_exists(entry.input_valid) and file_exists(entry.output_file)) do
73+
entry = self:nextEntry(set)
74+
end
75+
76+
local output = image.load(entry.output_file)
77+
local valid = image.load(entry.input_valid)
78+
79+
-- define your data process here
80+
output = output:add(-0.5):mul(2)
81+
output = output:index(2,torch.range(1,output:size(2),2):long())
82+
output = output:index(3,torch.range(1,output:size(3),2):long())
83+
valid = valid:index(2,torch.range(1,valid:size(2),2):long())
84+
valid = valid:index(3,torch.range(1,valid:size(3),2):long())
85+
-- end
86+
87+
table.insert(batch.output, output)
88+
table.insert(batch.valid, valid)
89+
90+
if config.verbose then
91+
print(string.format("output max: %f, min: %f, size: %d %d", output:max(), output:min(), output:size(2), output:size(3)))
92+
print(string.format("valid max: %f, min: %f, size: %d %d", valid:max(), valid:min(), valid:size(2), valid:size(3)))
93+
end
94+
end
95+
96+
97+
local input = image.load(entry.input_file)
98+
99+
-- process your input here
100+
input = input[{{1,3},{},{}}]
101+
for ch = 1, 3 do
102+
if math.max(unpack(self.pixel_means)) < 1 then
103+
input[{ch, {}, {}}]:add(-self.pixel_means[ch])
104+
else
105+
input[{ch, {}, {}}]:add(-self.pixel_means[ch] / 255)
106+
end
107+
end
108+
input = input:index(2,torch.range(1,input:size(2),2):long())
109+
input = input:index(3,torch.range(1,input:size(3),2):long())
110+
-- end
111+
112+
table.insert(batch.input, input)
113+
if config.verbose then
114+
print(string.format("input max: %f, min: %f, size: %d %d", input:max(), input:min(), input:size(2), input:size(3)))
115+
end
116+
end
117+
118+
-- format img
119+
local ch, h, w = batch.input[1]:size(1), batch.input[1]:size(2), batch.input[1]:size(3)
120+
batch.input = torch.cat(batch.input, 1):view(self.batch_size, ch, h, w)
121+
122+
-- ch, h, w= batch.input[1]:size(1), batch.input[1]:size(2), batch.input[1]:size(3)
123+
-- batch.input = torch.cat(batch.input):view(self.batch_size, ch, h, w)
124+
-- print(string.format("input size: %d %d %d %d", batch.input:size()))
125+
126+
if set == "train" then
127+
ch, h, w = batch.output[1]:size(1), batch.output[1]:size(2), batch.output[1]:size(3)
128+
batch.output = torch.cat(batch.output, 1):view(self.batch_size, ch, h, w)
129+
ch, h, w = batch.valid[1]:size(1), batch.valid[1]:size(2), batch.valid[1]:size(3)
130+
batch.valid = torch.cat(batch.valid, 1):view(self.batch_size, ch, h, w)
131+
end
132+
133+
return batch
134+
end

config.lua

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
require "utils"
2+
--- All parameters goes here
3+
local config = config or {}
4+
5+
function config.parse(arg)
6+
local cmd = torch.CmdLine()
7+
cmd:text()
8+
cmd:text('Multi-Task Classification FCN')
9+
cmd:text()
10+
-- Parameters
11+
12+
-- model configuration
13+
cmd:option('-model', 'model_deep.lua', 'model file')
14+
cmd:option('-input_channel', 3, '# of input channels')
15+
cmd:option('-output_channel', 3, '# of output channels')
16+
17+
-- testing
18+
cmd:option('-test_model', '', 'model used for testing')
19+
cmd:option('-result_path', './result/', 'path to save result')
20+
cmd:option('-max_count', 1000000, 'max number of data to test')
21+
22+
-- data loader
23+
cmd:option('-train_file', '', 'train file, compulsory');
24+
cmd:option('-test_file', './image/test_list.txt', 'test file, compulsory');
25+
26+
-- training and testing
27+
cmd:option('-gpuid', 1, 'gpu id')
28+
cmd:option('-optim_state', {rho=0.95, eps=1e-6, learningRate=1e-3, learningRateMin=1e-7, momentum=0.9}, 'optim state')
29+
cmd:option('-lr_decay', 150000, 'iterations between lr decreses')
30+
cmd:option('-lr_decay_t', 5, 'lr decay times')
31+
cmd:option('-nb_epoch', 20, 'number of epoches')
32+
cmd:option('-batch_size', 10, 'batch size')
33+
cmd:option('-pixel_means', {128, 128, 128}, 'Pixel mean values (RGB order)')
34+
35+
-- resume
36+
cmd:option('-resume_training', false, 'whether resume training')
37+
cmd:option('-saved_model_weights', '', 'path to saved model weights')
38+
cmd:option('-saved_optim_state', '', 'path to saved model weights')
39+
40+
-- finetune
41+
cmd:option('-finetune', false, '')
42+
cmd:option('-finetune_model', '', '')
43+
cmd:option('-finetune_init_lr', 1e-4, '')
44+
45+
-- save/print/log
46+
cmd:option('-snapshot_iters', 10000, 'Iterations between snapshots (used for saving the network)')
47+
cmd:option('-print_iters', 20, 'Iterations between print')
48+
cmd:option('-log_iters', 20, 'Iterations between log')
49+
cmd:option('-log_path','./logs/','Path to be used for logging')
50+
cmd:option('-ps', '', 'prefix: path&name to model and snapshot')
51+
cmd:option('-verbose', false, 'show more message')
52+
53+
-- Parsing the command line
54+
config = cmd:parse(arg or {})
55+
config.colors = {{0, 0, 0}, -- black
56+
{1, 0, 0}, -- red
57+
{0, 1, 0}, -- green
58+
{0, 0, 1}, -- blue
59+
{1, 1, 0}, -- yellow
60+
{1, 0, 1}, -- magenta
61+
{0, 1, 1}, -- cyan
62+
{1, 1, 1} -- white
63+
}
64+
65+
return config
66+
end
67+
68+
return config

image/0000_color.png

341 KB
Loading

image/0001_color.png

358 KB
Loading

image/0008_color.png

319 KB
Loading

image/0013_color.png

340 KB
Loading

image/0014_color.png

255 KB
Loading

image/test_list.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
./image/0000
2+
./image/0001
3+
./image/0008
4+
./image/0013
5+
./image/0014

main_test.lua

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
require 'nn'
2+
require 'cutorch'
3+
require 'cunn'
4+
require 'cudnn'
5+
require 'nngraph'
6+
require 'optim'
7+
require 'image'
8+
9+
require 'BatchIterator'
10+
require 'utils'
11+
-- require 'hdf5'
12+
13+
local config = dofile('config.lua')
14+
config = config.parse(arg)
15+
print(config)
16+
cutorch.setDevice(config.gpuid)
17+
18+
local tmp1 = split(config.test_model, "/")
19+
config.result_path = config.result_path .. "/" .. string.sub(tmp1[#tmp1],1,-4)
20+
21+
22+
config.result_path = config.result_path .. "_nyu_sample_test/"
23+
24+
os.execute("mkdir " .. config.result_path)
25+
26+
-- local model = dofile('model_multi_task.lua')(config.do_normal, config.do_semantic, config.do_boundary, config.do_room)
27+
local model = dofile(config.model)(config)
28+
29+
parameters, gradParameters = model:getParameters()
30+
model:cuda()
31+
parameters, gradParameters = model:getParameters()
32+
parameters:copy(torch.load(config.test_model))
33+
34+
-- dataset
35+
local train_data = {}
36+
local test_data = loadData(config.test_file, config)
37+
local batch_iterator = BatchIterator(config, train_data, test_data)
38+
batch_iterator:setBatchSize(1)
39+
40+
local test_count = 0
41+
42+
while batch_iterator.epoch==0 and test_count<=config.max_count do
43+
local batch = batch_iterator:nextBatch('test', config)
44+
local currName = batch_iterator:currentName('test')
45+
print(currName)
46+
local k = split(currName, "/")
47+
if config.matterport then
48+
saveName = k[#k-2] .. "_" .. k[#k]
49+
else
50+
saveName = k[#k-1] .. "_" .. k[#k]
51+
end
52+
print(string.format("Testing %s", saveName))
53+
54+
55+
local inputs = batch.input
56+
inputs = inputs:contiguous():cuda()
57+
local outputs = model:forward(inputs)
58+
59+
local ch, h, w = 0, 0, 0
60+
local normal_est, normal_mask, normal_gnd, f_normal, df_do_normal, normal_outputs = nil,nil,nil,nil,nil,nil
61+
62+
normal_est = outputs
63+
ch, h, w = normal_est:size(2), normal_est:size(3), normal_est:size(4)
64+
normal_est = normal_est:permute(1, 3, 4, 2):contiguous()
65+
normal_est = normal_est:view(-1, ch)
66+
local normalize_layer = nn.Normalize(2):cuda()
67+
normal_outputs = normalize_layer:forward(normal_est)
68+
normal_outputs = normal_outputs:view(1, h, w, ch)
69+
normal_outputs = normal_outputs:permute(1, 4, 2, 3):contiguous()
70+
normal_outputs = normal_outputs:view( ch, h, w)
71+
normal_outputs = normal_outputs:float()
72+
73+
image.save(string.format("%s%s_normal_est.png", config.result_path, saveName), normal_outputs:add(1):mul(0.5))
74+
75+
76+
test_count = test_count + 1
77+
end
78+
79+
print("Finish!")
80+
81+

0 commit comments

Comments
 (0)