This repository was archived by the owner on Apr 24, 2025. It is now read-only.

Commit 641f730

SarahByrneIntel authored

Adding power and log softmax operations (#80)

* Adding power and log softmax operations
* Add functionality to handle tensor params for power operation
* Fix of function comments

Co-authored-by: SarahByrneIntel <[email protected]>

1 parent af7a21d commit 641f730

File tree

6 files changed: +172 additions, -0 deletions

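Taken together, the six files wire two new graph ops, `power` and `log_softmax`, from the PyTorch-facing API down to the OpenVINO model factory. A minimal usage sketch, mirroring `test/python/test_op.py` below (it assumes the package is installed and that `NNFactory` is importable from `intel_npu_acceleration_library.backend`, as in the repo's tests):

```python
# Minimal sketch: trace torch.pow through NNFactory so it lowers to the
# new NPU "power" op (names mirror test/python/test_op.py below).
import numpy as np
import torch
from intel_npu_acceleration_library.backend import NNFactory

x = torch.rand((16, 128)).to(torch.float16)
exponent = torch.rand((16, 128)).to(torch.float16)

model = NNFactory()
par = model.parameter(x.shape, np.float16)  # graph input placeholder
_ = torch.pow(par, exponent)                # dispatched via @implements(torch.pow)
model.compile()
out = model(x).numpy()                      # executes the compiled graph
```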

include/intel_npu_acceleration_library/nn_factory.h

Lines changed: 27 additions & 0 deletions

```diff
@@ -884,6 +884,33 @@ class ModelFactory : public intel_npu_acceleration_library::OVInferenceModel {
         return normL2.get();
     }
 
+    /**
+     * @brief Create a new power operation
+     *
+     * @param x1 operation's input node
+     * @param x2 operation's input node of the exponent
+     * @param auto_broadcast auto broadcast specification
+     * @return ov::op::Op*
+     */
+    ov::op::Op* power(ov::op::Op* x1, ov::op::Op* x2, ov::op::AutoBroadcastType auto_broadcast) {
+        auto power = std::make_shared<ov::opset1::Power>(x1->output(0), x2->output(0), auto_broadcast);
+        operations.push_back(power);
+        return power.get();
+    }
+
+    /**
+     * @brief Create a new log softmax operation
+     *
+     * @param input operation's input node
+     * @param axis the axis position on which to calculate the LogSoftmax
+     * @return ov::op::Op*
+     */
+    ov::op::Op* log_softmax(ov::op::Op* input, int64_t axis) {
+        auto log_softmax = std::make_shared<ov::opset5::LogSoftmax>(input->output(0), axis);
+        operations.push_back(log_softmax);
+        return log_softmax.get();
+    }
+
     void result(ov::op::Op* op) {
         auto res = std::make_shared<ov::opset8::Result>(op->output(0));
         results.push_back(res);
```
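For reference, the LogSoftmax reduction along `axis`, in the numerically stable form the OpenVINO LogSoftmax-5 spec gives (max subtracted before exponentiation), is:

    LogSoftmax(x)_i = (x_i - max_j x_j) - log Σ_j exp(x_j - max_j x_j)

The Python-side `dim` argument of the override added in `functional.py` below maps directly onto this `axis`.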

intel_npu_acceleration_library/backend/factory.py

Lines changed: 35 additions & 0 deletions

```diff
@@ -563,6 +563,41 @@ def normL2(
         axis_node = self.constant(axis).node  # type: ignore
         return backend_lib.normL2(self._mm, input_node, axis_node, eps)
 
+    @return_tensor
+    def power(
+        self,
+        input_node: ctypes._Pointer,
+        exponent: Union[ctypes._Pointer, torch.Tensor],
+    ) -> ctypes._Pointer:
+        """Generate a power layer.
+
+        Args:
+            input_node (ctypes._Pointer): layer input node
+            exponent (Union[ctypes._Pointer, torch.Tensor]): the exponent value
+
+        Raises:
+            ValueError: Input tensor shapes are not equal
+
+        Returns:
+            ctypes._Pointer: output node
+        """
+        input_shape_size = backend_lib.op_shape_size(input_node)
+        input_shape = [
+            backend_lib.op_shape(input_node, i) for i in range(input_shape_size)
+        ]
+        if isinstance(exponent, ctypes._Pointer):
+            exponent_shape_size = backend_lib.op_shape_size(exponent)
+            exponent_shape = [
+                backend_lib.op_shape(exponent, i) for i in range(exponent_shape_size)
+            ]
+        else:
+            exponent_shape = list(exponent.shape)
+            exponent = self.constant(exponent).node  # type: ignore
+        if exponent_shape != input_shape:
+            raise ValueError("Input tensor shapes are not equal")
+
+        return backend_lib.power(self._mm, input_node, exponent)
+
     @return_tensor
     def avg_pooling(
         self,
```
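One consequence of the shape check above: a tensor exponent must match the input shape exactly; there is no implicit broadcasting on this path. A quick sketch of the failure mode (shapes are hypothetical; the call reaches `NNFactory.power` through the `torch.pow` override added in `functional.py` below):

```python
# Sketch: a mismatched exponent shape trips the ValueError in NNFactory.power.
import numpy as np
import torch
from intel_npu_acceleration_library.backend import NNFactory

model = NNFactory()
par = model.parameter((16, 128), np.float16)
bad_exponent = torch.ones((16, 64)).to(torch.float16)  # last dim differs

try:
    torch.pow(par, bad_exponent)  # routed to NNFactory.power
except ValueError as err:
    print(err)  # "Input tensor shapes are not equal"
```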

intel_npu_acceleration_library/backend/ops.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -132,5 +132,7 @@ def get_supported_ops() -> List[SupportedOp]:
         ),
         SupportedOp(name="adaptive_avg_pool", inputs=2),
         SupportedOp(name="adaptive_max_pool", inputs=2),
+        SupportedOp(name="power", inputs=2),
+        SupportedOp(name="log_softmax", inputs=1, parameters=[ctypes.c_int64]),
     ]
     return supported_ops
```
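Each `SupportedOp` entry is what makes the corresponding C symbol callable from Python: `name` matches the function exported from `src/bindings.cpp`, `inputs` is the number of node arguments, and `parameters` lists extra scalar arguments (here the int64 axis). The functional layer then dispatches by that registered name, e.g. `generate_op([input], "log_softmax", dim)` in `functional.py` below, where the trailing `dim` maps onto the declared `ctypes.c_int64` parameter.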

intel_npu_acceleration_library/nn/functional.py

Lines changed: 43 additions & 0 deletions

```diff
@@ -1109,3 +1109,46 @@ def conv2d(
     )
 
     return conv
+
+
+@implements(torch.pow)
+def pow(input: Tensor, exponent: Union[Tensor, torch.Tensor, float]):
+    """Return the tensor raised to the power of the exponent.
+
+    Args:
+        input (Tensor): The input tensor.
+        exponent (Union[Tensor, torch.Tensor, float]): The exponent value.
+
+    Returns:
+        Tensor: Output tensor.
+    """
+    if isinstance(exponent, float):
+        exponent = torch.full(input.shape, exponent).to(torch.float16)
+    return generate_op([input], "power", exponent=exponent)
+
+
+@implements(torch.nn.functional.log_softmax)
+def log_softmax(
+    input: Tensor,
+    dim: Optional[int] = None,
+    _stacklevel=3,
+    dtype: Optional[torch.dtype] = None,
+) -> Tensor:
+    """Return the log softmax of a tensor along the given dimension.
+
+    Args:
+        input (Tensor): The input tensor.
+        dim (int): The dimension along which log_softmax will be computed. Defaults to -1.
+        _stacklevel (int): The stack level. Defaults to 3.
+        dtype (torch.dtype): The data type. Defaults to None.
+
+    Returns:
+        Tensor: Output tensor.
+    """
+    if dim is None:
+        dim = -1
+    log_smax = generate_op([input], "log_softmax", dim)
+
+    if dtype:
+        log_smax = log_smax.to(dtype)
+    return log_smax
```
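With these overrides in place, plain PyTorch calls are intercepted when their inputs are library tensors. Note the float path: a scalar exponent is materialized as a full fp16 constant of the input's shape via `torch.full`, which is what satisfies the equal-shape check in `factory.py`. A brief sketch (shapes are illustrative):

```python
# Sketch: scalar and tensor exponents lower to the same "power" op, since a
# float is first expanded to a constant tensor of the input's shape.
import numpy as np
import torch
from intel_npu_acceleration_library.backend import NNFactory

model = NNFactory()
t = model.parameter((16, 128), np.float16)
_ = torch.pow(t, 2.0)                           # float -> torch.full constant
_ = torch.nn.functional.log_softmax(t, dim=-1)  # new log_softmax override
model.compile()
```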

src/bindings.cpp

Lines changed: 10 additions & 0 deletions

```diff
@@ -563,4 +563,14 @@ intel_npu_acceleration_library_DLL_API ov::op::Op* normL2(intel_npu_acceleration
                                                           ov::op::Op* data, ov::op::Op* axes, float eps) {
     return factory->normL2(data, axes, eps);
 }
+
+intel_npu_acceleration_library_DLL_API ov::op::Op* power(intel_npu_acceleration_library::ModelFactory* factory,
+                                                         ov::op::Op* x1, ov::op::Op* x2) {
+    return factory->power(x1, x2, ov::op::AutoBroadcastType::NUMPY);
+}
+
+intel_npu_acceleration_library_DLL_API ov::op::Op* log_softmax(intel_npu_acceleration_library::ModelFactory* factory,
+                                                               ov::op::Op* input, int64_t axis) {
+    return factory->log_softmax(input, axis);
+}
 }
```
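The `power` binding pins the broadcast mode to `ov::op::AutoBroadcastType::NUMPY`, so Python callers never select it: `factory.py` above simply calls `backend_lib.power(self._mm, input_node, exponent)`. The `log_softmax` symbol has no hand-written Python wrapper in this commit; it is presumably reached through the wrapper generated from its `SupportedOp` entry, passing the `int64_t` axis as its single extra parameter.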

test/python/test_op.py

Lines changed: 55 additions & 0 deletions

```diff
@@ -452,3 +452,58 @@ def test_multiple_outputs():
 
     assert 1 - r2_score(result0.detach().numpy().flatten(), ref1.flatten()) < 0.01
     assert 1 - r2_score(result1.detach().numpy().flatten(), ref2.flatten()) < 0.01
+
+
+@pytest.mark.parametrize("batch", [16, 128])
+@pytest.mark.parametrize("hidden_dim", [128, 256])
+@pytest.mark.parametrize("exponent", ["tensor", "float"])
+@pytest.mark.parametrize("exponent_type", ["parameter", "constant"])
+def test_power(batch, hidden_dim, exponent, exponent_type):
+
+    x = torch.rand((batch, hidden_dim)).to(torch.float16)
+    if exponent == "tensor":
+        exponent = torch.rand((batch, hidden_dim)).to(torch.float16)
+    else:
+        exponent = torch.rand(1).to(torch.float16).item()
+
+    reference = torch.pow(x, exponent=exponent).numpy()
+
+    model = NNFactory()
+    par = model.parameter(x.shape, np.float16)
+    if isinstance(exponent, torch.Tensor) and exponent_type == "parameter":
+        exponent_par = model.parameter(exponent.shape, np.float16)
+        _ = torch.pow(par, exponent_par)
+        model.compile()
+        out = model(x, exponent).numpy()
+    else:
+        _ = torch.pow(par, exponent=exponent)
+        model.compile()
+        out = model(x).numpy()
+
+    assert out.shape == reference.shape, "Output shape mismatch"
+    assert np.isfinite(reference).all(), "Pytorch Reference contains NaN or Inf"
+    assert np.isfinite(out).all(), "NPU output contains NaN or Inf"
+
+    assert 1 - r2_score(reference, out) < 0.01
+
+
+@pytest.mark.parametrize("batch", [16, 128])
+@pytest.mark.parametrize("hidden_dim", [128, 256])
+@pytest.mark.parametrize("axis", [0, 1, -1, -2])
+def test_logsoftmax(batch, hidden_dim, axis):
+    x = torch.rand((batch, hidden_dim)).to(torch.float16)
+
+    reference = torch.nn.functional.log_softmax(x, dim=axis).numpy()
+
+    model = NNFactory()
+    par = model.parameter(x.shape, np.float16)
+    _ = torch.nn.functional.log_softmax(par, dim=axis)
+    model.compile()
+
+    out = model(x).numpy()
+
+    assert out.shape == reference.shape, "Output shape mismatch"
+    assert np.isfinite(reference).all(), "Pytorch Reference contains NaN or Inf"
+    assert np.isfinite(out).all(), "NPU output contains NaN or Inf"
+
+    assert 1 - r2_score(reference, out) < 0.01
```
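Both tests follow the repo's usual pattern: build the reference with eager PyTorch, trace the same call through `NNFactory`, and require an R² above 0.99 against the device output. They should be runnable with pytest, e.g. `pytest test/python/test_op.py -k "power or logsoftmax"`, on a machine where the library can compile and execute models.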
