diff --git a/include/intel_npu_acceleration_library/nn_factory.h b/include/intel_npu_acceleration_library/nn_factory.h
index bb5f755..e0b9f89 100644
--- a/include/intel_npu_acceleration_library/nn_factory.h
+++ b/include/intel_npu_acceleration_library/nn_factory.h
@@ -702,6 +702,21 @@ class ModelFactory : public intel_npu_acceleration_library::OVInferenceModel {
         return sdpa.get();
     }
 
+    /**
+     * @brief Create a new L2 normalization operation
+     *
+     * @param data operation's input node
+     * @param axes node indicating axes along which reduction is calculated
+     * @param eps epsilon used to avoid division by zero (combined with the norm via ov::op::EpsMode::MAX)
+     * @return ov::op::Op*
+     */
+    ov::op::Op* normL2(ov::op::Op* data, ov::op::Op* axes, float eps) {
+        auto normL2 =
+                std::make_shared<ov::opset1::NormalizeL2>(data->output(0), axes->output(0), eps, ov::op::EpsMode::MAX);
+        operations.push_back(normL2);
+        return normL2.get();
+    }
+
     /**
      * @brief Compile the model
      *
diff --git a/intel_npu_acceleration_library/backend/factory.py b/intel_npu_acceleration_library/backend/factory.py
index a848e87..5a0d3f1 100644
--- a/intel_npu_acceleration_library/backend/factory.py
+++ b/intel_npu_acceleration_library/backend/factory.py
@@ -376,6 +376,26 @@ def slice(
             end_mask_ptr,
         )
 
+    @return_tensor
+    def normL2(
+        self, input_node: ctypes._Pointer, axis: int, eps: Optional[float] = 1e-12
+    ) -> ctypes._Pointer:
+        """Generate an L2 normalization layer.
+
+        Args:
+            input_node (ctypes._Pointer): layer input node
+            axis (int): axis along which the L2 norm is computed
+            eps (Optional[float], optional): epsilon that guards against division by zero. Defaults to 1e-12.
+
+        Returns:
+            ctypes._Pointer: output node
+        """
+        if axis < 0:
+            # abs() matches negative-axis semantics only for axis=-1 on 2D inputs
+            axis = abs(axis)
+        axis_node = self.constant(axis).node  # type: ignore
+        return backend_lib.normL2(self._mm, input_node, axis_node, eps)
+
     def get_output_tensor_shape(self):
         """Get output tensor shape.
 
diff --git a/intel_npu_acceleration_library/backend/ops.py b/intel_npu_acceleration_library/backend/ops.py
index d0c7ecd..8c24f6e 100644
--- a/intel_npu_acceleration_library/backend/ops.py
+++ b/intel_npu_acceleration_library/backend/ops.py
@@ -79,6 +79,11 @@ def get_supported_ops() -> List[SupportedOp]:
             inputs=4,
             parameters=[ctypes.c_bool],
         ),
+        SupportedOp(
+            name="normL2",
+            inputs=2,
+            parameters=[ctypes.c_float],
+        ),
         SupportedOp(
             name="gather",
             inputs=3,
diff --git a/src/bindings.cpp b/src/bindings.cpp
index 702c0dc..41693f9 100644
--- a/src/bindings.cpp
+++ b/src/bindings.cpp
@@ -465,4 +465,9 @@ intel_npu_acceleration_library_DLL_API ov::op::Op* scaled_dot_product_attention(
         ov::op::Op* attn_mask, bool is_causal) {
     return factory->scaled_dot_product_attention(query, key, value, attn_mask, is_causal);
 }
+
+intel_npu_acceleration_library_DLL_API ov::op::Op* normL2(intel_npu_acceleration_library::ModelFactory* factory,
+                                                          ov::op::Op* data, ov::op::Op* axes, float eps) {
+    return factory->normL2(data, axes, eps);
+}
 }
\ No newline at end of file
diff --git a/test/python/test_layers.py b/test/python/test_layers.py
index cc87025..cdb9484 100644
--- a/test/python/test_layers.py
+++ b/test/python/test_layers.py
@@ -8,7 +8,6 @@
 import numpy as np
 import pytest
 import torch
-import itertools
 
 
 class MLP_PT(torch.nn.Module):
@@ -313,3 +312,24 @@ def test_constant(batch, hidden_dim):
     assert np.isfinite(out).all(), "NPU output contains NaN or Inf"
 
     assert 1 - r2_score(reference, out) < 0.001
+
+
+@pytest.mark.parametrize("batch", [16, 128])
+@pytest.mark.parametrize("hidden_dim", [256, 512])
+@pytest.mark.parametrize("axis", [0, 1])
+def test_normalisation(batch, hidden_dim, axis):
+
+    X = torch.rand((batch, hidden_dim)).to(torch.float16) - 0.5
+
+    model = NNFactory()
+    input = model.parameter(X.shape)
+    output = model.normL2(input, axis)
+    model.compile(output)
+    out = model.run(X.numpy())
+
+    reference = torch.nn.functional.normalize(X, p=2.0, dim=axis).numpy()
+    assert out.shape == reference.shape, "Output shape mismatch"
+    assert np.isfinite(reference).all(), "PyTorch reference contains NaN or Inf"
+    assert np.isfinite(out).all(), "NPU output contains NaN or Inf"
+
+    assert 1 - r2_score(reference, out) < 0.001
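
Usage sketch (not part of the patch): the new normL2 layer mirrors torch.nn.functional.normalize(x, p=2.0, dim=axis) on the NPU. A minimal example of driving the new Python entry point, assuming NNFactory is importable from intel_npu_acceleration_library.backend as in the test above; the fp16 tolerance is an illustrative assumption:

    import numpy as np
    import torch
    from intel_npu_acceleration_library.backend import NNFactory

    # Build a one-op network that L2-normalizes a (16, 256) fp16 tensor along axis 1.
    X = torch.rand((16, 256)).to(torch.float16) - 0.5
    model = NNFactory()
    x = model.parameter(X.shape)
    y = model.normL2(x, 1)  # eps defaults to 1e-12
    model.compile(y)
    out = model.run(X.numpy())

    # Compare with the PyTorch reference; atol is a loose fp16 bound, not part of the patch.
    reference = torch.nn.functional.normalize(X, p=2.0, dim=1).numpy()
    assert np.allclose(out, reference, atol=1e-2)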