-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathCCNNMetric.lua
executable file
·457 lines (317 loc) · 12.5 KB
/
CCNNMetric.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
-- Module table; functions are attached below and the table is returned
-- at the end of the file (standard Lua module idiom).
local cnnMetric = {}
---------------------------------------------------------------------------------
------------------- User interface functions ------------------------------------
---------------------------------------------------------------------------------
-- Function returns embedding net given network name.
-- Replaces the duplicated if/elseif chain with a single table-driven spec:
-- the valid-name check and the per-network parameters now live in one place.
function cnnMetric.getEmbeddNet(metricName)
  -- Architecture spec per network: conv layer count, feature maps, kernel size.
  local archSpec = {
    ['fst-mb']       = { nbConvLayers = 5,  nbFeatureMap = 64,  kernel = 3 },
    ['fst-mb-4x']    = { nbConvLayers = 5,  nbFeatureMap = 256, kernel = 3 },
    ['acrt-mb']      = { nbConvLayers = 5,  nbFeatureMap = 112, kernel = 3 },
    ['fst-kitti']    = { nbConvLayers = 4,  nbFeatureMap = 64,  kernel = 3 },
    ['fst-kitti-4x'] = { nbConvLayers = 4,  nbFeatureMap = 256, kernel = 3 },
    ['acrt-kitti']   = { nbConvLayers = 4,  nbFeatureMap = 112, kernel = 3 },
    ['fst-xxl']      = { nbConvLayers = 12, nbFeatureMap = 64,  kernel = 3 },
  }
  local spec = archSpec[metricName]
  assert( spec, 'wrong network name!' )
  local embedNet = cnnMetric.embeddNet( spec.nbConvLayers, spec.nbFeatureMap, spec.kernel )
  --if ( metricName == 'acrt-mb' or metricName == 'acrt-kitti' ) then
  -- embedNet:add( nn.ReLU() ) -- add nonlinearity to last layer of embed net for accurate architecture
  --end
  return embedNet
end
-- Function returns head net given network name.
-- Table-driven replacement for the duplicated if/elseif chain: each entry
-- names the head type ('cos' or 'fc') and its constructor parameters.
function cnnMetric.getHeadNet(metricName)
  local headSpec = {
    ['fst-mb']       = { head = 'cos', nbFeatureMap = 64 },
    ['fst-mb-4x']    = { head = 'cos', nbFeatureMap = 256 },
    ['fst-kitti']    = { head = 'cos', nbFeatureMap = 64 },
    ['fst-kitti-4x'] = { head = 'cos', nbFeatureMap = 256 },
    ['fst-xxl']      = { head = 'cos', nbFeatureMap = 112 },
    ['acrt-mb']      = { head = 'fc',  nbFeatureMap = 112, nbFcLayers = 3, nbFcUnits = 384 },
    ['acrt-kitti']   = { head = 'fc',  nbFeatureMap = 112, nbFcLayers = 4, nbFcUnits = 384 },
  }
  local spec = headSpec[metricName]
  assert( spec, 'wrong network name!' )
  if spec.head == 'cos' then
    return cnnMetric.cosHead(spec.nbFeatureMap)
  else
    return cnnMetric.fcHead(spec.nbFeatureMap, spec.nbFcLayers, spec.nbFcUnits)
  end
end
-- Split a trained siamese net back into its embedding net and head net.
-- (the returned nets share parameter storage with the original net)
function cnnMetric.parseSiamese(siamNet)
  -- first tower of the ParallelTable holds the embedding net
  local embedNet = siamNet.modules[1].modules[1]:clone('weight','bias', 'gradWeight','gradBias');
  -- drop the trailing Squeeze and Transpose that setupSiamese appended
  -- (after the first remove(8), the former 9th module shifts into slot 8)
  embedNet:remove(8)
  embedNet:remove(8)
  local headNet = siamNet.modules[3]:clone('weight','bias', 'gradWeight','gradBias');
  return embedNet, headNet
end
-- Function sets up siamese net, given headNet and embedNet.
-- The assembled network shares parameter/gradient storage with the inputs.
-- @param embedNet  per-patch embedding network
-- @param headNet   network that scores a pair of embeddings
-- @param width     image width in pixels
-- @param disp_max  maximum disparity considered
function cnnMetric.setupSiamese(embedNet, headNet, width, disp_max)
  local siamNet = nn.Sequential()
  local hpatch = cnnMetric.getHPatch(embedNet)
  ---------------------- Find active elements of similarity matrix --------------------
  -- Active pairs are entries of the (width-2*hpatch)^2 similarity matrix
  -- whose disparity lies in (0, disp_max]: a band just below the diagonal.
  local active_pairs
  do
    -- BUGFIX: 'mask' was missing 'local' and leaked into the global
    -- environment on every call.
    local mask = torch.ones(width-2*hpatch, width-2*hpatch)*2
    mask = torch.triu(torch.tril(mask,-1),-disp_max)
    active_pairs = mask:nonzero()
  end
  ---------------------- Make two embedding towers ------------------------------------
  -- Both towers share parameters and gradient storage with embedNet.
  local twoEmbedNet
  do
    twoEmbedNet = nn.ParallelTable()
    local embedNet0 = embedNet:clone('weight','bias', 'gradWeight','gradBias');
    -- nb_features x 1 x width-hpatch*2 ==> width-hpatch*2 x nb_features
    embedNet0:add(nn.Squeeze(2))
    embedNet0:add(nn.Transpose({1,2}))
    -- second tower is a clone of the first one (shared storage)
    local embedNet1 = embedNet0:clone('weight','bias', 'gradWeight','gradBias');
    twoEmbedNet:add(embedNet0)
    twoEmbedNet:add(embedNet1)
  end
  siamNet:add(twoEmbedNet)
  ----------------------------------------------------------------------------------------
  -- headNetMulti applies headNet only to the active pairs of the similarity matrix
  siamNet:add(nn.headNetMulti(active_pairs, headNet))
  return siamNet
end
--function nnMetric.getTestNet(net)
-- for i = 1,#net.modules do
-- if torch.typename(net.modules[i]) == 'nn.SpatialConvolution' or
-- torch.typename(net.modules[i]) == 'cudnn.SpatialConvolution' then
-- local nInputPlane = net.modules[i].nInputPlane
-- local nOutputPlane = net.modules[i].nInputPlane
-- end
-- end
-- net:replace(function(module)
-- if torch.typename(module) == 'nn.SpatialConvolution' then
-- local nInputPlane = module.nInputPlane
-- local nOutputPlane = module.nOutputPlane
-- local weight = module.weight;
-- local bias = module.weight;
-- local substitute = nn.SpatialConvolution1_fw(nInputPlane, nOutputPlane))
-- substitute.weight:copy(weight)
-- substitute.bias:copy(bias)
-- return substitute
-- else
-- return module
-- end
-- end)
---- return net;
---- net_te_all = {}
---- for i, v in ipairs(net_te.modules) do table.insert(net_te_all, v) end
---- for i, v in ipairs(net_te2.modules) do table.insert(net_te_all, v) end
---- local finput = torch.CudaTensor()
---- local i_tr = 1
---- local i_te = 1
---- while i_tr <= net_tr:size() do
---- local module_tr = net_tr:get(i_tr)
---- local module_te = net_te_all[i_te]
---- local skip = {['nn.Reshape']=1, ['nn.Dropout']=1}
---- while skip[torch.typename(module_tr)] do
---- i_tr = i_tr + 1
---- module_tr = net_tr:get(i_tr)
---- end
---- if module_tr.weight then
---- -- print(('tie weights of %s and %s'):format(torch.typename(module_te), torch.typename(module_tr)))
---- assert(module_te.weight:nElement() == module_tr.weight:nElement())
---- assert(module_te.bias:nElement() == module_tr.bias:nElement())
---- module_te.weight = torch.CudaTensor(module_tr.weight:storage(), 1, module_te.weight:size())
---- module_te.bias = torch.CudaTensor(module_tr.bias:storage(), 1, module_te.bias:size())
---- end
---- i_tr = i_tr + 1
---- i_te = i_te + 1
---- end
--end
-- Enable 1-pixel zero padding on every spatial-convolution module of the net
-- (mutates the net in place and also returns it for chaining).
function cnnMetric.padBoundary(net)
  local convTypes = {
    ['cudnn.SpatialConvolution'] = true,
    ['nn.SpatialConvolution']    = true,
    ['cunn.SpatialConvolution']  = true,
  }
  for idx = 1, #net.modules do
    local layer = net.modules[idx]
    if convTypes[torch.typename(layer)] then
      layer.padW = 1
      layer.padH = 1
    end
  end
  return net;
end
-- True if the net contains at least one trainable module
-- (a Linear or SpatialConvolution from nn / cunn / cudnn).
function cnnMetric.isParametric(net)
  local parametricTypes = {
    ['cudnn.Linear']             = true,
    ['cudnn.SpatialConvolution'] = true,
    ['nn.SpatialConvolution']    = true,
    ['nn.Linear']                = true,
    ['cunn.SpatialConvolution']  = true,
    ['cunn.Linear']              = true,
  }
  for idx = 1, #net.modules do
    if parametricTypes[torch.typename(net:get(idx))] then
      return true;
    end
  end
  return false
end
-- Half-patch size of the net: each spatial convolution with kernel width kW
-- grows the receptive field by kW-1; the result is the receptive-field radius.
function cnnMetric.getHPatch(net)
  local receptiveField = 1
  for idx = 1, #net.modules do
    local layer = net:get(idx)
    local typeName = torch.typename(layer)
    if typeName == 'cudnn.SpatialConvolution' or
       typeName == 'nn.SpatialConvolution' or
       typeName == 'cunn.SpatialConvolution' then
      receptiveField = receptiveField + layer.kW - 1
    end
  end
  return (receptiveField-1)/2
end
----------------------------------------------------------------
------------------------------ embeding net --------------------
----------------------------------------------------------------
-- Given tensor 1 x hpatch*2 x width embedding net produces feature
-- tensor of size nbFeatureMap x 1 x width-hpatch*2.
-- @param nbConvLayers number of convolutional layers
-- @param nbFeatureMap feature maps per layer
-- @param kernel       (square) convolution kernel size
function cnnMetric.embeddNet( nbConvLayers, nbFeatureMap, kernel )
  local fNet = nn.Sequential();
  for nConvLayer = 1, nbConvLayers do
    -- First layer consumes the single-channel (gray) image;
    -- later layers consume the previous layer's feature maps.
    local nInputPlane
    if( nConvLayer == 1 ) then
      nInputPlane = 1
    else
      nInputPlane = nbFeatureMap
    end
    local nOutputPlane = nbFeatureMap; -- number of feature maps in layer
    local kW = kernel; -- kernel width and height
    local kH = kernel;
    local dW = 1; -- step of convolution
    local dH = 1;
    -- BUGFIX: padW/padH were missing 'local' and leaked into the global
    -- environment.
    local padW = 0;
    local padH = 0;
    fNet:add(nn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH));
    -- ReLU after every convolutional layer except the last one
    if( nConvLayer < nbConvLayers ) then
      fNet:add(nn.ReLU());
    end
  end
  -- (unused patchSize/hpatch locals removed; receptive field is recoverable
  -- via cnnMetric.getHPatch)
  return fNet
end
----------------------------------------------------------------
------------------------------ heads ---------------------------
----------------------------------------------------------------
-- Given two tensors of size nb_pairs x nb_features, head network computes
-- distance tensor of size nb_pairs.
-- Fully connected head: join, then a stack of Linear+ReLU layers, then a
-- single sigmoid output unit.
function cnnMetric.fcHead(nbFeatureMap, nbFcLayers, nbFcUnits)
  local head = nn.Sequential()
  head:add( nn.JoinTable(2) )
  head:add( nn.ReLU() ) -- add nonlinearity to last layer of embed net for accurate architecture
  -- first FC layer takes the two concatenated feature vectors;
  -- subsequent layers take nbFcUnits
  local inputSize = nbFeatureMap*2
  for _ = 1, nbFcLayers do
    head:add( nn.Linear(inputSize, nbFcUnits) )
    head:add( nn.ReLU(true) )
    inputSize = nbFcUnits
  end
  head:add( nn.Linear(nbFcUnits, 1) )
  head:add( nn.Sigmoid(true) )
  return head
end
-- cosine head: L2-normalize both inputs, take their dot product (cosine
-- similarity), and map the result from [-1, 1] to [0, 1].
-- NOTE(review): nbFeatureMap is accepted but never used inside the function —
-- presumably kept for signature symmetry with fcHead; confirm before removing.
function cnnMetric.cosHead(nbFeatureMap)
  local normalizePair = nn.ParallelTable()
  normalizePair:add(nn.Normalize(2))
  normalizePair:add(nn.Normalize(2))
  local head = nn.Sequential()
  head:add(normalizePair)
  head:add(nn.DotProduct())
  -- convert range to (0 1)
  head:add(nn.AddConstant(1))
  head:add(nn.MulConstant(0.5))
  return head
end
-- Export the module table.
return cnnMetric