Skip to content

Commit bee7a2b

Browse files
author
Michael 'myrhev' Mathieu
committed
Add SpatialRadialMatching
1 parent 07562f3 commit bee7a2b

File tree

6 files changed

+168
-1
lines changed

6 files changed

+168
-1
lines changed

SpatialPadding.lua

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ function SpatialPadding:__init(pad_l, pad_r, pad_t, pad_b, y_dim, x_dim)
2424
end
2525

2626
function SpatialPadding:updateOutput(input)
27+
if self.output:type() ~= input:type() then
28+
self.output = input.new()
29+
end
2730
self.x_dim = self.x_dim or 3
2831
self.y_dim = self.y_dim or 2
2932
local h = input:size(self.y_dim) + self.pad_t + self.pad_b

SpatialRadialMatching.lua

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
local SpatialRadialMatching, parent = torch.class('nn.SpatialRadialMatching', 'nn.Module')
2+
3+
function SpatialRadialMatching:__init(maxh)
4+
-- If full_output is false, output is computed on elements of the first input
5+
-- for which all the possible corresponding elements exist in the second input
6+
-- In addition, if full_output is set to false, the pixel (1,1) of the first input
7+
-- is supposed to correspond to the pixel (maxh/2, maxw/2) of the second one
8+
parent.__init(self)
9+
self.maxh = maxh
10+
self.gradInput1 = torch.Tensor()
11+
self.gradInput2 = torch.Tensor()
12+
end
13+
14+
function SpatialRadialMatching:updateOutput(input)
15+
-- input is a table of 2 inputs, each one being KxHxW
16+
-- if not full_output, the 1st one is KxH1xW1 where H1 <= H-maxh+1, W1 <= W-maxw+1
17+
self.output:resize(input[1]:size(2), input[1]:size(3), self.maxh)
18+
--if input[3] == nil then
19+
-- input[3] = torch.LongTensor(input[1]:size(2), input[1]:size(3)):fill(1)
20+
--end
21+
--input[1].nn.SpatialRadialMatching_updateOutput(self, input[1], input[2], input[3])
22+
input[1].nn.SpatialRadialMatching_updateOutput(self, input[1], input[2])
23+
return self.output
24+
end
25+
26+
function SpatialRadialMatching:updateGradInput(input, gradOutput)
27+
self.gradInput1:resize(input[1]:size()):zero()
28+
self.gradInput2:resize(input[2]:size()):zero()
29+
--input[1].nn.SpatialRadialMatching_updateGradInput(self,input[1],input[2],gradOutput,input[3])
30+
input[1].nn.SpatialRadialMatching_updateGradInput(self,input[1],input[2],gradOutput)
31+
self.gradInput = {self.gradInput1, self.gradInput2}
32+
return self.gradInput
33+
end

generic/SpatialMatching.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ static int nn_(SpatialMatching_updateOutput)(lua_State *L)
8989
}
9090
*/
9191
} else {
92-
//#pragma omp parallel for private(y1,x1,x2,y2,k,dist)
92+
#pragma omp parallel for private(y1,x1,x2,y2,k,dist)
9393
for (y1 = 0; y1 < iheight; y1++) {
9494
for (x1 = 0; x1 < iwidth; x1++) {
9595
for (y2 = y1; y2 < y1+maxh; y2++) {

generic/SpatialRadialMatching.c

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#ifndef TH_GENERIC_FILE
2+
#define TH_GENERIC_FILE "generic/SpatialRadialMatching.c"
3+
#else
4+
5+
#define square(x) ((x)*(x))
6+
#define max(x,y) (((x)>(y)) ? (x) : (y))
7+
#define min(x,y) (((x)>(y)) ? (y) : (x))
8+
9+
static int nn_(SpatialRadialMatching_updateOutput)(lua_State *L)
10+
{
11+
// get all params
12+
THTensor *input1 = luaT_checkudata(L, 2, torch_(Tensor_id));
13+
THTensor *input2 = luaT_checkudata(L, 3, torch_(Tensor_id));
14+
//THLongTensor *mask= luaT_checkudata(L, 4, luaT_checktypename2id(L, "torch.LongTensor"));
15+
int maxh = luaT_getfieldcheckint(L, 1, "maxh");
16+
THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_(Tensor_id));
17+
18+
// dims
19+
int iwidth = input1->size[2];
20+
int iheight = input1->size[1];
21+
int ichannels = input1->size[0];
22+
23+
// get strides
24+
long *i1s = input1->stride;
25+
long *i2s = input2->stride;
26+
//long *ms = mask ->stride;
27+
long *os = output->stride;
28+
29+
// get pointers
30+
real *input1_p = THTensor_(data)(input1);
31+
real *input2_p = THTensor_(data)(input2);
32+
//long *mask_p = THLongTensor_data(mask);
33+
real *output_p = THTensor_(data)(output);
34+
35+
// compute output
36+
int x1,y1,y2,k;
37+
real dist;
38+
#pragma omp parallel for private(y1,x1,y2,k,dist)
39+
for (y1 = 0; y1 < iheight; y1++) {
40+
for (x1 = 0; x1 < iwidth; x1++) {
41+
//if (mask_p[y1*ms[0] + x1*ms[1]]) {
42+
for (y2 = y1; y2 < y1+maxh; y2++) {
43+
dist = 0.0f;
44+
for (k = 0; k < ichannels; k++)
45+
dist += square( input1_p[k*i1s[0] + y1*i1s[1] + x1*i1s[2]]
46+
- input2_p[k*i2s[0] + y2*i2s[1] + x1*i2s[2]]);
47+
output_p[(y2-y1)*os[2] + y1*os[0] + x1*os[1]] = dist;
48+
}
49+
//}
50+
}
51+
}
52+
53+
// done
54+
return 0;
55+
}
56+
57+
static int nn_(SpatialRadialMatching_updateGradInput)(lua_State *L)
58+
{
59+
// get all params
60+
THTensor* input1 = luaT_checkudata(L, 2, torch_(Tensor_id));
61+
THTensor* input2 = luaT_checkudata(L, 3, torch_(Tensor_id));
62+
THTensor* gradOutput = luaT_checkudata(L, 4, torch_(Tensor_id));
63+
//THLongTensor* mask = luaT_checkudata(L, 5, luaT_checktypename2id(L, "torch.LongTensor"));
64+
THTensor* gradInput1 = luaT_getfieldcheckudata(L, 1, "gradInput1", torch_(Tensor_id));
65+
THTensor* gradInput2 = luaT_getfieldcheckudata(L, 1, "gradInput2", torch_(Tensor_id));
66+
int maxh = luaT_getfieldcheckint(L, 1, "maxh");
67+
68+
// dims
69+
int iwidth = input1->size[2];
70+
int iheight = input1->size[1];
71+
int ichannels = input1->size[0];
72+
73+
// get strides
74+
long* i1s = input1->stride;
75+
long* i2s = input2->stride;
76+
long* gi1s = gradInput1->stride;
77+
long* gi2s = gradInput2->stride;
78+
long* gos = gradOutput->stride;
79+
//long* ms = mask->stride;
80+
81+
// get pointers
82+
real* input1_p = THTensor_(data)(input1);
83+
real* input2_p = THTensor_(data)(input2);
84+
real* gradInput1_p = THTensor_(data)(gradInput1);
85+
real* gradInput2_p = THTensor_(data)(gradInput2);
86+
real* gradOutput_p = THTensor_(data)(gradOutput);
87+
//long* mask_p = THLongTensor_data(mask);
88+
89+
// compute gradients
90+
int x1, y1, y2, k;
91+
real partial_d;
92+
for (y1 = 0; y1 < iheight; y1++) {
93+
for (x1 = 0; x1 < iwidth; x1++) {
94+
// if (mask_p[y1*ms[0] + x1*ms[1]]) {
95+
for (y2 = y1; y2 < y1+maxh; y2++) {
96+
for (k = 0; k < ichannels; k++) {
97+
partial_d = 2.0f*( input1_p[k*i1s[0] + y1*i1s[1] + x1*i1s[2]]
98+
- input2_p[k*i2s[0] + y2*i2s[1] + x1*i2s[2]]);
99+
partial_d *= gradOutput_p[(y2-y1)*gos[2]+y1*gos[0]+x1*gos[1]];
100+
gradInput1_p[k*gi1s[0] + y1*gi1s[1] + x1*gi1s[2]] += partial_d;
101+
gradInput2_p[k*gi2s[0] + y2*gi2s[1] + x1*gi2s[2]] -= partial_d;
102+
}
103+
}
104+
//}
105+
}
106+
}
107+
108+
// done
109+
return 0;
110+
}
111+
112+
static const struct luaL_Reg nn_(SpatialRadialMatching__) [] = {
113+
{"SpatialRadialMatching_updateOutput", nn_(SpatialRadialMatching_updateOutput)},
114+
{"SpatialRadialMatching_updateGradInput", nn_(SpatialRadialMatching_updateGradInput)},
115+
{NULL, NULL}
116+
};
117+
118+
static void nn_(SpatialRadialMatching_init)(lua_State *L)
119+
{
120+
luaT_pushmetaclass(L, torch_(Tensor_id));
121+
luaT_registeratname(L, nn_(SpatialRadialMatching__), "nn");
122+
lua_pop(L,1);
123+
}
124+
125+
#endif

init.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ static const void* torch_DoubleTensor_id = NULL;
5757
#include "generic/SpatialMatching.c"
5858
#include "THGenerateFloatTypes.h"
5959

60+
#include "generic/SpatialRadialMatching.c"
61+
#include "THGenerateFloatTypes.h"
62+
6063
#include "generic/DataSetLabelMe.c"
6164
#include "THGenerateFloatTypes.h"
6265

@@ -81,6 +84,7 @@ DLL_EXPORT int luaopen_libnnx(lua_State *L)
8184
nn_FloatSpatialClassNLLCriterion_init(L);
8285
nn_FloatSpatialGraph_init(L);
8386
nn_FloatSpatialMatching_init(L);
87+
nn_FloatSpatialRadialMatching_init(L);
8488
nn_FloatDataSetLabelMe_init(L);
8589

8690
nn_DoubleSpatialLinear_init(L);
@@ -99,6 +103,7 @@ DLL_EXPORT int luaopen_libnnx(lua_State *L)
99103
nn_DoubleSpatialClassNLLCriterion_init(L);
100104
nn_DoubleSpatialGraph_init(L);
101105
nn_DoubleSpatialMatching_init(L);
106+
nn_DoubleSpatialRadialMatching_init(L);
102107
nn_DoubleDataSetLabelMe_init(L);
103108

104109
return 1;

init.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ torch.include('nnx', 'SpatialFovea.lua')
7373
torch.include('nnx', 'SpatialPyramid.lua')
7474
torch.include('nnx', 'SpatialGraph.lua')
7575
torch.include('nnx', 'SpatialMatching.lua')
76+
torch.include('nnx', 'SpatialRadialMatching.lua')
7677
torch.include('nnx', 'SpatialMaxSampling.lua')
7778
torch.include('nnx', 'SpatialColorTransform.lua')
7879
torch.include('nnx', 'SpatialConvolutionSparse.lua')

0 commit comments

Comments
 (0)