This repository was archived by the owner on Apr 24, 2025. It is now read-only.

Sarah/feature/reduce ops #74

Merged: 8 commits, Jun 25, 2024
75 changes: 75 additions & 0 deletions include/intel_npu_acceleration_library/nn_factory.h
@@ -292,6 +292,81 @@ class ModelFactory : public intel_npu_acceleration_library::OVInferenceModel {
return concat.get();
}

/**
* @brief Create a new reduce max operation
*
* @param input operation's input node
* @param reduction_axes the axis positions to be reduced
* @param keep_dims if true, keep the reduced axes with size 1 in the output shape
* @return ov::op::Op*
*/
ov::op::Op* reduce_max(ov::op::Op* input, ov::op::Op* reduction_axes, bool keep_dims) {
auto reduce_max =
std::make_shared<ov::opset1::ReduceMax>(input->output(0), reduction_axes->output(0), keep_dims);
operations.push_back(reduce_max);
return reduce_max.get();
}

/**
* @brief Create a new reduce mean operation
*
* @param input operation's input node
* @param reduction_axes the axis positions to be reduced
* @param keep_dims if true, keep the reduced axes with size 1 in the output shape
* @return ov::op::Op*
*/
ov::op::Op* reduce_mean(ov::op::Op* input, ov::op::Op* reduction_axes, bool keep_dims) {
auto reduce_mean =
std::make_shared<ov::opset1::ReduceMean>(input->output(0), reduction_axes->output(0), keep_dims);
operations.push_back(reduce_mean);
return reduce_mean.get();
}

/**
* @brief Create a new reduce min operation
*
* @param input operation's input node
* @param reduction_axes the axis positions to be reduced
* @param keep_dims if true, keep the reduced axes with size 1 in the output shape
* @return ov::op::Op*
*/
ov::op::Op* reduce_min(ov::op::Op* input, ov::op::Op* reduction_axes, bool keep_dims) {
auto reduce_min =
std::make_shared<ov::opset1::ReduceMin>(input->output(0), reduction_axes->output(0), keep_dims);
operations.push_back(reduce_min);
return reduce_min.get();
}

/**
* @brief Create a new reduce product operation
*
* @param input operation's input node
* @param reduction_axes the axis positions to be reduced
* @param keep_dims if true, keep the reduced axes with size 1 in the output shape
* @return ov::op::Op*
*/
ov::op::Op* reduce_prod(ov::op::Op* input, ov::op::Op* reduction_axes, bool keep_dims) {
auto reduce_prod =
std::make_shared<ov::opset1::ReduceProd>(input->output(0), reduction_axes->output(0), keep_dims);
operations.push_back(reduce_prod);
return reduce_prod.get();
}

/**
* @brief Create a new reduce sum operation
*
* @param input operation's input node
* @param reduction_axes the axis positions to be reduced
* @param keep_dims if true, keep the reduced axes with size 1 in the output shape
* @return ov::op::Op*
*/
ov::op::Op* reduce_sum(ov::op::Op* input, ov::op::Op* reduction_axes, bool keep_dims) {
auto reduce_sum =
std::make_shared<ov::opset1::ReduceSum>(input->output(0), reduction_axes->output(0), keep_dims);
operations.push_back(reduce_sum);
return reduce_sum.get();
}

/**
* @brief Create a new absolute activation operation
*
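All five factory methods above share one pattern: wrap the matching ov::opset1 reduction node over (input, reduction_axes, keep_dims) and register it in the operations list. As a sketch of what the keep_dims flag means (standard reduction semantics, illustrated with NumPy rather than the library API):

# keep_dims semantics, illustrated with NumPy (not the library API)
import numpy as np

x = np.random.rand(16, 128).astype(np.float16)
print(np.sum(x, axis=-1, keepdims=True).shape)   # (16, 1): reduced axis kept with size 1
print(np.sum(x, axis=-1, keepdims=False).shape)  # (16,): reduced axis dropped
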
115 changes: 115 additions & 0 deletions intel_npu_acceleration_library/backend/factory.py
@@ -428,6 +428,121 @@ def concat(
axis = np.int64(axis)
return backend_lib.concat(self._mm, input_node_1, input_node_2, axis)

@return_tensor
def reduce_max(
self,
input_node: ctypes._Pointer,
reduction_axes: Optional[Union[int, Sequence[int]]] = None,
keep_dims: Optional[bool] = False,
) -> ctypes._Pointer:
"""Generate a reduce max layer.

Args:
input_node (ctypes._Pointer): layer input node
reduction_axes (Optional[Union[int, Sequence[int]]]): the axis positions to be reduced
keep_dims (Optional[bool]): if True, keep the reduced axes with size 1 in the output shape. Defaults to False

Returns:
ctypes._Pointer: output node
"""
if reduction_axes is None:
shape_size = backend_lib.op_shape_size(input_node)
reduction_axes = list(range(shape_size - 1, -1, -1))
axis_node = self.constant(reduction_axes).node # type: ignore
return backend_lib.reduce_max(self._mm, input_node, axis_node, keep_dims)

@return_tensor
def reduce_mean(
self,
input_node: ctypes._Pointer,
reduction_axes: Optional[Union[int, Sequence[int]]] = None,
keep_dims: Optional[bool] = False,
) -> ctypes._Pointer:
"""Generate a reduce mean layer.

Args:
input_node (ctypes._Pointer): layer input node
reduction_axes (Optional[Union[int, Sequence[int]]]): the axis positions to be reduced
keep_dims (Optional[bool]): if True, keep the reduced axes with size 1 in the output shape. Defaults to False

Returns:
ctypes._Pointer: output node
"""
if reduction_axes is None:
shape_size = backend_lib.op_shape_size(input_node)
reduction_axes = list(range(shape_size - 1, -1, -1))
axis_node = self.constant(reduction_axes).node # type: ignore
return backend_lib.reduce_mean(self._mm, input_node, axis_node, keep_dims)

@return_tensor
def reduce_min(
self,
input_node: ctypes._Pointer,
reduction_axes: Optional[Union[int, Sequence[int]]] = None,
keep_dims: Optional[bool] = False,
) -> ctypes._Pointer:
"""Generate a reduce min layer.

Args:
input_node (ctypes._Pointer): layer input node
reduction_axes (Optional[Union[int, Sequence[int]]]): the axis positions to be reduced
keep_dims (Optional[bool]): if True, keep the reduced axes with size 1 in the output shape. Defaults to False

Returns:
ctypes._Pointer: output node
"""
if reduction_axes is None:
shape_size = backend_lib.op_shape_size(input_node)
reduction_axes = list(range(shape_size - 1, -1, -1))
axis_node = self.constant(reduction_axes).node # type: ignore
return backend_lib.reduce_min(self._mm, input_node, axis_node, keep_dims)

@return_tensor
def reduce_prod(
self,
input_node: ctypes._Pointer,
reduction_axes: Optional[Union[int, Sequence[int]]] = None,
keep_dims: Optional[bool] = False,
) -> ctypes._Pointer:
"""Generate a reduce product layer.

Args:
input_node (ctypes._Pointer): layer input node
reduction_axes (Optional[Union[int, Sequence[int]]]): the axis positions to be reduced
keep_dims (Optional[bool]): if True, keep the reduced axes with size 1 in the output shape. Defaults to False

Returns:
ctypes._Pointer: output node
"""
if reduction_axes is None:
shape_size = backend_lib.op_shape_size(input_node)
reduction_axes = list(range(shape_size - 1, -1, -1))
axis_node = self.constant(reduction_axes).node # type: ignore
return backend_lib.reduce_prod(self._mm, input_node, axis_node, keep_dims)

@return_tensor
def reduce_sum(
self,
input_node: ctypes._Pointer,
reduction_axes: Optional[Union[int, Sequence[int]]] = None,
keep_dims: Optional[bool] = False,
) -> ctypes._Pointer:
"""Generate a reduce sum layer.

Args:
input_node (ctypes._Pointer): layer input node
reduction_axes (Optional[Union[int, Sequence[int]]]): the axis positions to be reduced
keep_dims (Optional[bool]): if True, keep the reduced axes with size 1 in the output shape. Defaults to False

Returns:
ctypes._Pointer: output node
"""
if reduction_axes is None:
shape_size = backend_lib.op_shape_size(input_node)
reduction_axes = list(range(shape_size - 1, -1, -1))
axis_node = self.constant(reduction_axes).node # type: ignore
return backend_lib.reduce_sum(self._mm, input_node, axis_node, keep_dims)

@return_tensor
def normL2(
self, input_node: ctypes._Pointer, axis: int, eps: Optional[float] = 1e-12
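A minimal usage sketch for these factory methods, mirroring the pattern in test/python/test_op.py at the end of this diff (assuming NNFactory is imported from intel_npu_acceleration_library.backend, as the tests do); note that when reduction_axes is None the layer reduces over every axis of the input:

# Minimal sketch: build, compile, and run a reduce_sum graph.
import numpy as np
import torch
from intel_npu_acceleration_library.backend import NNFactory

model = NNFactory()
par = model.parameter([16, 128], np.float16)  # graph input
out = torch.sum(par, dim=-1)                  # dispatched to reduce_sum via the torch overrides
model.compile()
result = model.run(torch.rand(16, 128).to(torch.float16).numpy())
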
25 changes: 25 additions & 0 deletions intel_npu_acceleration_library/backend/ops.py
@@ -105,6 +105,31 @@ def get_supported_ops() -> List[SupportedOp]:
inputs=2,
parameters=[ctypes.c_int64],
),
SupportedOp(
name="reduce_max",
inputs=2,
parameters=[ctypes.c_bool],
),
SupportedOp(
name="reduce_mean",
inputs=2,
parameters=[ctypes.c_bool],
),
SupportedOp(
name="reduce_min",
inputs=2,
parameters=[ctypes.c_bool],
),
SupportedOp(
name="reduce_prod",
inputs=2,
parameters=[ctypes.c_bool],
),
SupportedOp(
name="reduce_sum",
inputs=2,
parameters=[ctypes.c_bool],
),
SupportedOp(name="adaptive_avg_pool", inputs=2),
SupportedOp(name="adaptive_max_pool", inputs=2),
]
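Each SupportedOp entry above is bound to the matching C symbol exported from src/bindings.cpp later in this diff: the two inputs are the data and reduction-axes node pointers, and the single ctypes.c_bool parameter is keep_dims. A hypothetical sketch of the prototype this implies (pointer types approximated as c_void_p; not the library's actual binding code):

import ctypes

# Hypothetical prototype: reduce_sum(ModelFactory*, Op* input, Op* axes, bool keep_dims) -> Op*
REDUCE_PROTOTYPE = ctypes.CFUNCTYPE(
    ctypes.c_void_p,  # returned ov::op::Op*
    ctypes.c_void_p,  # intel_npu_acceleration_library::ModelFactory*
    ctypes.c_void_p,  # input node
    ctypes.c_void_p,  # reduction-axes node
    ctypes.c_bool,    # keep_dims
)
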
75 changes: 75 additions & 0 deletions intel_npu_acceleration_library/nn/functional.py
@@ -631,6 +631,81 @@ def cat(input: Sequence[Tensor], dim: int, out: Optional[Tensor] = None) -> Tensor:
return tensor


@implements(torch.max)
def max(x, dim: Optional[int] = None, keep_dims: Optional[bool] = False) -> Tensor:
"""Return the reduced max tensor.

Args:
x (Tensor): The input tensor.
dim (Optional[int]): The dim to reduce. Defaults to None, which reduces over all dimensions.
keep_dims (Optional[bool]): If True, keep the reduced axes with size 1 in the output shape. Defaults to False.

Returns:
Tensor: The reduced max tensor.
"""
return generate_op(x, "reduce_max", reduction_axes=dim, keep_dims=keep_dims)


@implements(torch.mean)
def mean(x, dim: Optional[int] = None, keep_dims: Optional[bool] = False) -> Tensor:
"""Return the reduced mean tensor.

Args:
x (Tensor): The input tensor.
dim (Optional[int]): The dim to reduce. Defaults to None, which reduces over all dimensions.
keep_dims (Optional[bool]): If True, keep the reduced axes with size 1 in the output shape. Defaults to False.

Returns:
Tensor: The reduced mean tensor.
"""
return generate_op(x, "reduce_mean", reduction_axes=dim, keep_dims=keep_dims)


@implements(torch.min)
def min(x, dim: Optional[int] = None, keep_dims: Optional[bool] = False) -> Tensor:
"""Return the reduced min tensor.

Args:
x (Tensor): The input tensor.
dim (Optional[int]): The dim to reduce. Defaults to None, which reduces over all dimensions.
keep_dims (Optional[bool]): If True, keep the reduced axes with size 1 in the output shape. Defaults to False.

Returns:
Tensor: The reduced min tensor.
"""
return generate_op(x, "reduce_min", reduction_axes=dim, keep_dims=keep_dims)


@implements(torch.prod)
def prod(x, dim: Optional[int] = None, keep_dims: Optional[bool] = False) -> Tensor:
"""Return the reduced product tensor.

Args:
x (Tensor): The input tensor.
dim (Optional[int]): The dim to reduce. Defaults to None, which reduces over all dimensions.
keep_dims (Optional[bool]): If True, keep the reduced axes with size 1 in the output shape. Defaults to False.

Returns:
Tensor: The reduced product tensor.
"""
return generate_op(x, "reduce_prod", reduction_axes=dim, keep_dims=keep_dims)


@implements(torch.sum)
def sum(x, dim: Optional[int] = None, keep_dims: Optional[bool] = False) -> Tensor:
"""Return the reduced sum tensor.

Args:
x (Tensor): The input tensor.
dim (Optional[int]): The dim to reduce. Defaults to None, which reduces over all dimensions.
keep_dims (Optional[bool]): If True, keep the reduced axes with size 1 in the output shape. Defaults to False.

Returns:
Tensor: The reduced sum tensor.
"""
return generate_op(x, "reduce_sum", reduction_axes=dim, keep_dims=keep_dims)


# Functional activations


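One behavioral note: with a dim argument, eager torch.max and torch.min return a (values, indices) pair, while the overrides above return only the reduced-values tensor; the test at the end of this diff unpacks the eager reference accordingly. A short sketch of the eager side for contrast:

# Eager PyTorch: torch.max with dim returns values and indices.
import torch

x = torch.rand(4, 8)
values, indices = torch.max(x, dim=1)
# Inside an NNFactory graph, torch.max(node, dim=1) yields only the reduced tensor.
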
30 changes: 30 additions & 0 deletions src/bindings.cpp
@@ -394,6 +394,36 @@ intel_npu_acceleration_library_DLL_API ov::op::Op* concat(intel_npu_acceleration
return factory->concat(x1, x2, axis);
}

intel_npu_acceleration_library_DLL_API ov::op::Op* reduce_max(intel_npu_acceleration_library::ModelFactory* factory,
ov::op::Op* input, ov::op::Op* reduction_axes,
bool keep_dims) {
return factory->reduce_max(input, reduction_axes, keep_dims);
}

intel_npu_acceleration_library_DLL_API ov::op::Op* reduce_mean(intel_npu_acceleration_library::ModelFactory* factory,
ov::op::Op* input, ov::op::Op* reduction_axes,
bool keep_dims) {
return factory->reduce_mean(input, reduction_axes, keep_dims);
}

intel_npu_acceleration_library_DLL_API ov::op::Op* reduce_min(intel_npu_acceleration_library::ModelFactory* factory,
ov::op::Op* input, ov::op::Op* reduction_axes,
bool keep_dims) {
return factory->reduce_min(input, reduction_axes, keep_dims);
}

intel_npu_acceleration_library_DLL_API ov::op::Op* reduce_prod(intel_npu_acceleration_library::ModelFactory* factory,
ov::op::Op* input, ov::op::Op* reduction_axes,
bool keep_dims) {
return factory->reduce_prod(input, reduction_axes, keep_dims);
}

intel_npu_acceleration_library_DLL_API ov::op::Op* reduce_sum(intel_npu_acceleration_library::ModelFactory* factory,
ov::op::Op* input, ov::op::Op* reduction_axes,
bool keep_dims) {
return factory->reduce_sum(input, reduction_axes, keep_dims);
}

intel_npu_acceleration_library_DLL_API ov::op::Op* convert_to_fp16(
intel_npu_acceleration_library::ModelFactory* factory, ov::op::Op* in0) {
return factory->convert_to(in0, ov::element::Type_t::f16);
32 changes: 32 additions & 0 deletions test/python/test_op.py
@@ -194,6 +194,38 @@ def test_concatenation(batch, hidden_dim, tensors, axis):
assert 1 - r2_score(reference, result) < 0.01


@pytest.mark.parametrize("batch", [16, 128])
@pytest.mark.parametrize("hidden_dim", [128, 256])
@pytest.mark.parametrize("axis", [0, 1, -1, -2, None])
@pytest.mark.parametrize(
"op", [torch.max, torch.mean, torch.min, torch.prod, torch.sum]
)
def test_reduce_operations(batch, hidden_dim, axis, op):

x = torch.rand((batch, hidden_dim)).to(torch.float16)
if axis is None:
reference = op(x)
else:
if op in [torch.max, torch.min]:
reference, _ = op(x, dim=axis)
else:
reference = op(x, dim=axis)
reference = reference.numpy()

model = NNFactory()
par = model.parameter(x.shape, np.float16)
out = op(par) if axis is None else op(par, dim=axis)
model.compile()

assert out.shape == list(reference.shape)

result = model.run(x.numpy())
if not out.shape:
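# 0-d output: pad reference and result with a constant so r2_score has two points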
assert 1 - r2_score([reference, 1], [result, 1]) < 0.01
else:
assert 1 - r2_score(reference, result) < 0.01


@pytest.mark.parametrize("channel", [16, 128])
@pytest.mark.parametrize("xydim", [4, 16])
@pytest.mark.parametrize(