Skip to content
This repository was archived by the owner on Apr 24, 2025. It is now read-only.

Commit ed0993b

Browse files
Add torch.nn.functional.conv2d (#70)
* Refactor convolution code * Support torch.nn.functional.conv2d * Support for same and valid padding * View supports also args
1 parent b67bd8b commit ed0993b

File tree

9 files changed

+165
-55
lines changed

9 files changed

+165
-55
lines changed

intel_npu_acceleration_library/backend/bindings.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ def init_network_factory(lib: ctypes.CDLL):
161161
lib.linear.restype = handler
162162

163163
lib.convolution.argtypes = [
164+
handler,
165+
handler,
164166
handler,
165167
handler,
166168
ctypes.c_int,
@@ -172,10 +174,6 @@ def init_network_factory(lib: ctypes.CDLL):
172174
ctypes.c_int,
173175
c_u32_array,
174176
ctypes.c_int,
175-
c_u32_array,
176-
ctypes.c_int,
177-
ctypes.c_bool,
178-
ctypes.c_char_p,
179177
ctypes.c_char_p,
180178
]
181179
lib.convolution.restype = handler

intel_npu_acceleration_library/backend/convolution.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,34 +38,21 @@ def __init__(
3838
"""
3939
super().__init__(profile, device)
4040
input = self.parameter(input_shape)
41-
42-
# Get the number of spatial dimensions
43-
n_spatial_dims = len(input_shape) - 2
44-
45-
if isinstance(strides, int):
46-
strides = [strides] * n_spatial_dims
47-
48-
if isinstance(padding, int):
49-
padding_begins = [padding] * n_spatial_dims
50-
padding_ends = [padding] * n_spatial_dims
41+
weights = self.parameter(weights_shape)
42+
if bias is not None:
43+
bias_node = self.parameter((1, weights_shape[0], 1, 1))
5144
else:
52-
padding_begins = list(padding)
53-
padding_ends = list(padding)
54-
55-
if isinstance(dilation, int):
56-
dilation = [dilation] * n_spatial_dims
45+
bias_node = None
5746

5847
conv = self.convolution(
5948
input,
60-
weights_shape,
61-
bias=bias,
49+
weights,
50+
bias=bias_node,
6251
strides=strides,
63-
padding_begins=padding_begins,
64-
padding_ends=padding_ends,
52+
padding=padding,
6553
dilation=dilation,
6654
groups=groups,
6755
act_dtype=np.float16,
68-
wt_dtype=np.float16,
6956
)
7057

7158
self.compile(conv)

intel_npu_acceleration_library/backend/factory.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ def constant(
188188
data = np.array([data], dtype=np.float32)
189189
elif isinstance(data, torch.Tensor):
190190
data = data.detach().numpy()
191+
elif data is None:
192+
return ctypes.cast(ctypes.c_void_p(0), ctypes.POINTER(ctypes.c_char))
191193

192194
dst = data.ctypes.data_as(ctypes.c_void_p)
193195
shape_ptr = np.array(data.shape, dtype=np.uint32)
@@ -199,44 +201,59 @@ def constant(
199201
def convolution(
200202
self,
201203
input_node: ctypes._Pointer,
202-
weights_shape: Sequence[int],
203-
bias: bool,
204-
strides: Sequence[int] = (1, 1),
205-
padding_begins: Sequence[int] = (0, 0),
206-
padding_ends: Sequence[int] = (0, 0),
207-
dilation: Sequence[int] = (1, 1),
204+
weights_node: ctypes._Pointer,
205+
bias: Optional[ctypes._Pointer] = None,
206+
strides: Union[int, Sequence[int]] = 1,
207+
padding: Union[int, Sequence[int]] = 0,
208+
dilation: Union[int, Sequence[int]] = 1,
208209
groups: int = 1,
209210
act_dtype: npt.DTypeLike = np.float16,
210-
wt_dtype: npt.DTypeLike = np.float16,
211+
n_spatial_dims: int = 2,
211212
) -> ctypes._Pointer:
212213
"""Generate a convolution layer.
213214
214215
Args:
215216
input_node (ctypes._Pointer): layer input node
216-
weights_shape (Sequence[int]): weights shape
217+
weights_node (ctypes._Pointer): weights node
218+
bias (Optional[ctypes._Pointer}): bias node
217219
strides (Sequence[int]): strides
218-
padding_begins (Sequence[int]): padding
219-
padding_ends (Sequence[int]): padding
220+
padding (Sequence[int]): padding
220221
dilation (Sequence[int]): dilation
221222
groups (int): groups
222-
bias (bool): enable/disable bias
223223
act_dtype (npt.DTypeLike, optional): activation dtype. Defaults to np.float16.
224-
wt_dtype (npt.DTypeLike, optional): weight dtype. Defaults to np.float16.
224+
n_spatial_dims (int): number of spatial dimensions
225225
226226
Returns:
227227
ctypes._Pointer: output node
228228
"""
229-
weights_shape_ptr = np.array(weights_shape, dtype=np.uint32)
229+
if isinstance(strides, int):
230+
strides = [strides] * n_spatial_dims
231+
232+
if isinstance(padding, int):
233+
padding_begins = [padding] * n_spatial_dims
234+
padding_ends = [padding] * n_spatial_dims
235+
else:
236+
padding_begins = list(padding)
237+
padding_ends = list(padding)
238+
239+
if isinstance(dilation, int):
240+
dilation = [dilation] * n_spatial_dims
241+
230242
strides_ptr = np.array(strides, dtype=np.uint32)
231243
padding_begins_ptr = np.array(padding_begins, dtype=np.uint32)
232244
padding_ends_ptr = np.array(padding_ends, dtype=np.uint32)
233245
dilation_ptr = np.array(dilation, dtype=np.uint32)
234246

247+
if bias is not None:
248+
bias_node = bias
249+
else:
250+
bias_node = ctypes.cast(ctypes.c_void_p(0), ctypes.POINTER(ctypes.c_char))
251+
235252
return backend_lib.convolution(
236253
self._mm,
237254
input_node,
238-
weights_shape_ptr.size,
239-
weights_shape_ptr,
255+
weights_node,
256+
bias_node,
240257
strides_ptr.size,
241258
strides_ptr,
242259
padding_begins_ptr.size,
@@ -246,9 +263,7 @@ def convolution(
246263
dilation_ptr.size,
247264
dilation_ptr,
248265
groups,
249-
bias,
250266
self.get_backend_dtype(act_dtype),
251-
self.get_backend_dtype(wt_dtype),
252267
)
253268

254269
@return_tensor

intel_npu_acceleration_library/backend/tensor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,16 +335,19 @@ def reshape(self, *shape: Union[int, Sequence[int]]) -> "Tensor":
335335
shape = shape[0] # type: ignore
336336
return generate_op([self], "reshape", shape)
337337

338-
def view(self, shape: Sequence[int]) -> "Tensor":
338+
def view(self, *shape: Union[Sequence[int], int]) -> "Tensor":
339339
"""
340340
Return the transpose of the tensor.
341341
342342
Args:
343-
shape (Sequence[int]): The new shape of the tensor.
343+
shape (Union[Sequence[int], int]): The new shape of the tensor.
344344
345345
Returns:
346346
Tensor: The transposed tensor.
347347
"""
348+
if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
349+
shape = shape[0] # type: ignore
350+
348351
return self.reshape(*shape)
349352

350353
def flatten(self, start_dim=0, end_dim=-1) -> "Tensor":

intel_npu_acceleration_library/nn/conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def __init__(
189189
self.backend_cls = partial(
190190
Convolution,
191191
weights_shape=weights.shape,
192-
bias=bias is not None,
192+
bias=bias,
193193
strides=strides,
194194
padding=padding,
195195
dilation=dilation,

intel_npu_acceleration_library/nn/functional.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,3 +928,57 @@ def batch_norm(
928928
result = result + bias.view(1, -1, 1, 1)
929929

930930
return result
931+
932+
933+
@implements(torch.nn.functional.conv2d)
934+
def conv2d(
935+
input: Tensor,
936+
weight: Union[Tensor, torch.Tensor],
937+
bias: Optional[Union[Tensor, torch.Tensor]] = None,
938+
stride: int = 1,
939+
padding: Union[int, str] = 0,
940+
dilation: int = 1,
941+
groups: int = 1,
942+
) -> Tensor:
943+
"""Generate a convolution layer.
944+
945+
Args:
946+
input (Tensor): layer input node
947+
weight (Union[Tensor, torch.Tensor]): weight
948+
bias (Union[Tensor, torch.Tensor]): bias
949+
stride (int): stride
950+
padding (Union[int, str]): padding
951+
dilation (int): dilation
952+
groups (int): groups
953+
954+
Raises:
955+
ValueError: Padding mode not supported
956+
957+
Returns:
958+
Tensor: output node
959+
"""
960+
if isinstance(padding, str):
961+
if padding == "valid":
962+
padding = 0
963+
elif padding == "same":
964+
padding = weight.shape[2] // 2
965+
else:
966+
raise ValueError(f"Padding mode {padding} not supported")
967+
968+
if bias is not None:
969+
bias = bias.view((1, weight.shape[0], 1, 1))
970+
971+
if groups > 1:
972+
new_shape = [groups, weight.shape[0] // groups] + list(weight.shape[1:])
973+
weight = weight.view(new_shape)
974+
975+
conv = generate_op(
976+
[input, weight, bias],
977+
"convolution",
978+
strides=stride,
979+
padding=padding,
980+
dilation=dilation,
981+
groups=groups,
982+
)
983+
984+
return conv

src/bindings.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -425,24 +425,25 @@ intel_npu_acceleration_library_DLL_API ov::op::Op* linear(intel_npu_acceleration
425425
return mm;
426426
}
427427

428-
intel_npu_acceleration_library_DLL_API ov::op::Op* convolution(
429-
intel_npu_acceleration_library::ModelFactory* factory, ov::op::Op* in0, size_t weight_shape_size,
430-
unsigned int* weight_shape_data, size_t strides_size, unsigned int* strides_data, size_t pad_begins_size,
431-
unsigned int* pad_begins_data, size_t pad_ends_size, unsigned int* pad_ends_data, size_t dilations_size,
432-
unsigned int* dilations_data, size_t groups, bool bias, char* act_dtype, char* wt_dtype) {
428+
intel_npu_acceleration_library_DLL_API ov::op::Op* convolution(intel_npu_acceleration_library::ModelFactory* factory,
429+
ov::op::Op* in0, ov::op::Op* weights, ov::op::Op* bias,
430+
size_t strides_size, unsigned int* strides_data,
431+
size_t pad_begins_size, unsigned int* pad_begins_data,
432+
size_t pad_ends_size, unsigned int* pad_ends_data,
433+
size_t dilations_size, unsigned int* dilations_data,
434+
size_t groups, char* act_dtype) {
433435
ov::element::Type_t act_ov_dtype = intel_npu_acceleration_library::dtype_from_string(std::string(act_dtype));
434-
ov::element::Type_t wt_ov_dtype = intel_npu_acceleration_library::dtype_from_string(std::string(wt_dtype));
435436

436437
// Create vectors from the input data
437-
std::vector<size_t> weight_shape(weight_shape_data, weight_shape_data + weight_shape_size);
438438
std::vector<size_t> strides(strides_data, strides_data + strides_size);
439439
std::vector<size_t> pad_begins(pad_begins_data, pad_begins_data + pad_begins_size);
440440
std::vector<size_t> pad_ends(pad_ends_data, pad_ends_data + pad_ends_size);
441441
std::vector<size_t> dilations(dilations_data, dilations_data + dilations_size);
442442

443-
bool quantized = wt_ov_dtype == ov::element::Type_t::i8 || wt_ov_dtype == ov::element::Type_t::i4;
443+
auto weight_shape = weights->get_output_shape(0);
444+
auto wt_ov_dtype = static_cast<ov::element::Type_t>(weights->get_output_element_type(0));
444445

445-
auto weights = factory->parameter(weight_shape, wt_ov_dtype);
446+
bool quantized = wt_ov_dtype == ov::element::Type_t::i8 || wt_ov_dtype == ov::element::Type_t::i4;
446447

447448
if (quantized) {
448449
weights = factory->convert_to(weights, act_ov_dtype);
@@ -459,7 +460,6 @@ intel_npu_acceleration_library_DLL_API ov::op::Op* convolution(
459460
}
460461

461462
if (bias) {
462-
auto bias = factory->parameter({1, weight_shape[0], 1, 1}, act_ov_dtype);
463463
return factory->eltwise_add(mm, bias);
464464
}
465465
return mm;

test/python/test_op.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,52 @@ def test_batch_norm(shape, mean, variance, weight, bias):
322322
result = model.run(x.numpy())
323323

324324
assert 1 - r2_score(reference.flatten(), result.flatten()) < 0.01
325+
326+
327+
@pytest.mark.parametrize("in_channels", [32, 128, 256])
328+
@pytest.mark.parametrize("out_channels", [32, 128, 256])
329+
@pytest.mark.parametrize("kernels", [1, 3])
330+
@pytest.mark.parametrize("dim", [16, 32])
331+
@pytest.mark.parametrize("bias", [True, False])
332+
@pytest.mark.parametrize("dtype", [torch.float16])
333+
@pytest.mark.parametrize("stride", [1, 2])
334+
@pytest.mark.parametrize("padding", [0, 1, "same", "valid"])
335+
@pytest.mark.parametrize("groups", [1, -1])
336+
def test_conv(
337+
in_channels, out_channels, kernels, dim, bias, dtype, stride, padding, groups
338+
):
339+
torch.manual_seed(42)
340+
341+
if groups != 1 and in_channels != out_channels:
342+
pytest.skip("DW convolutions require in_channels == out_channels")
343+
344+
if padding == "same" and stride > 1:
345+
pytest.skip("padding='same' is not supported for strided convolutions")
346+
347+
if groups == -1:
348+
groups = in_channels
349+
350+
x = torch.rand((1, in_channels, dim, dim)).to(torch.float16)
351+
352+
weight = torch.rand((out_channels, in_channels // groups, kernels, kernels)).to(
353+
torch.float16
354+
)
355+
bias = torch.rand((out_channels,)).to(torch.float16) if bias else None
356+
357+
reference = (
358+
torch.nn.functional.conv2d(x, weight, bias, stride, padding, groups=groups)
359+
.detach()
360+
.numpy()
361+
)
362+
363+
model = NNFactory()
364+
par = model.parameter(x.shape, np.float16)
365+
366+
out = torch.nn.functional.conv2d(par, weight, bias, stride, padding, groups=groups)
367+
model.compile(out)
368+
369+
assert out.shape == list(reference.shape)
370+
371+
result = model.run(x.numpy())
372+
373+
assert 1 - r2_score(reference.flatten(), result.flatten()) < 0.01

test/python/test_tensor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@ def test_model_creation():
9999

100100
assert ff.dim() == 3
101101

102-
model.compile(ff)
102+
gg = ff.view(1, -1, 1, 1)
103+
104+
assert gg.shape == [1, 32 * 128 * 64, 1, 1]
105+
106+
model.compile(gg)
103107

104108

105109
def test_slice():

0 commit comments

Comments
 (0)