Skip to content

Commit 6623164

Browse files
author
devsh
committed
Merge remote-tracking branch 'remotes/origin/sync-subgroup-fix'
2 parents ef2ee17 + f1c5a77 commit 6623164

File tree

9 files changed

+114
-21
lines changed

9 files changed

+114
-21
lines changed

include/nbl/builtin/hlsl/spirv_intrinsics/subgroup_ballot.hlsl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,19 @@ namespace hlsl
1515
{
1616
namespace spirv
1717
{
18+
// Subgroup mask built-ins (Eq, Ge, Gt, Le, Lt) exposed as SPIR-V built-in inputs of type uint32_t4.
// Every declaration carries the GroupNonUniformBallot capability attribute next to its built-in
// attribute, so the capability is attached to each of these variables individually.
[[vk::ext_capability(spv::CapabilityGroupNonUniformBallot)]]
1819
[[vk::ext_builtin_input(spv::BuiltInSubgroupEqMask)]]
1920
static const uint32_t4 BuiltInSubgroupEqMask;
21+
[[vk::ext_capability(spv::CapabilityGroupNonUniformBallot)]]
2022
[[vk::ext_builtin_input(spv::BuiltInSubgroupGeMask)]]
2123
static const uint32_t4 BuiltInSubgroupGeMask;
24+
[[vk::ext_capability(spv::CapabilityGroupNonUniformBallot)]]
2225
[[vk::ext_builtin_input(spv::BuiltInSubgroupGtMask)]]
2326
static const uint32_t4 BuiltInSubgroupGtMask;
27+
[[vk::ext_capability(spv::CapabilityGroupNonUniformBallot)]]
2428
[[vk::ext_builtin_input(spv::BuiltInSubgroupLeMask)]]
2529
static const uint32_t4 BuiltInSubgroupLeMask;
30+
[[vk::ext_capability(spv::CapabilityGroupNonUniformBallot)]]
2631
[[vk::ext_builtin_input(spv::BuiltInSubgroupLtMask)]]
2732
static const uint32_t4 BuiltInSubgroupLtMask;
2833

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PARAMS_INCLUDED_
5+
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PARAMS_INCLUDED_
6+
7+
#ifdef __HLSL_VERSION
8+
#include "nbl/builtin/hlsl/device_capabilities_traits.hlsl"
9+
#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
10+
#endif
11+
#include "nbl/builtin/hlsl/concepts.hlsl"
12+
13+
namespace nbl
14+
{
15+
namespace hlsl
16+
{
17+
namespace subgroup2
18+
{
19+
20+
#ifdef __HLSL_VERSION
21+
// Compile-time parameter bundle for the subgroup2 arithmetic templates (reduction / scans),
// which read `binop_t`, `ItemsPerInvocation` and `UseNativeIntrinsics` from it.
//   Config               - subgroup configuration type; must satisfy is_configuration_v
//   BinOp                - binary operation type; BinOp::type_t must be a scalar
//   _ItemsPerInvocation  - component count of `type_t` (defaults to 1)
//   device_capabilities  - type forwarded to device_capabilities_traits (defaults to void)
template<typename Config, class BinOp, int32_t _ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config> && is_scalar_v<typename BinOp::type_t>)
struct ArithmeticParams
{
    using config_t = Config;
    using binop_t = BinOp;
    using scalar_t = typename BinOp::type_t;
    using type_t = vector<scalar_t, _ItemsPerInvocation>;
    using device_traits = device_capabilities_traits<device_capabilities>;

    NBL_CONSTEXPR_STATIC_INLINE int32_t ItemsPerInvocation = _ItemsPerInvocation;
    // Native subgroup intrinsics are used whenever the device traits report shaderSubgroupArithmetic.
    // TODO: additionally gate this on some heuristic for when the native path is actually faster
    // TODO add a IHV enum to device_capabilities_traits to check !is_nvidia
    NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_traits::shaderSubgroupArithmetic;
};
34+
#endif
35+
36+
#ifndef __HLSL_VERSION
37+
#include <sstream>
38+
#include <string>
39+
// Host-side (non-HLSL) counterpart of subgroup2::ArithmeticParams: it stores the runtime values
// and emits the matching HLSL template instantiation as text.
// The members hold indeterminate values until init() has been called.
// NOTE(review): <sstream> and <string> are #included directly above this struct while the
// nbl::hlsl::subgroup2 namespaces are still open; standard headers should be included at
// global scope (i.e. next to the other includes at the top of the header).
struct SArithmeticParams
{
    // Stores the log2 of the subgroup size and the items-per-invocation count.
    void init(const uint16_t _SubgroupSizeLog2, const uint16_t _ItemsPerInvocation)
    {
        SubgroupSizeLog2 = _SubgroupSizeLog2;
        ItemsPerInvocation = _ItemsPerInvocation;
    }

    // Returns the HLSL spelling of the ArithmeticParams instantiation, terminated by ';'.
    // `Binop` and `device_capabilities` are emitted as literal identifiers:
    // alias should provide Binop and device_capabilities template parameters
    std::string getParamTemplateStructString() const
    {
        std::ostringstream os;
        os << "nbl::hlsl::subgroup2::ArithmeticParams<nbl::hlsl::subgroup2::Configuration<" << SubgroupSizeLog2 << ">, Binop," << ItemsPerInvocation << ", device_capabilities>;";
        return os.str();
    }

    uint32_t SubgroupSizeLog2;
    uint32_t ItemsPerInvocation;
};
58+
#endif
59+
60+
}
61+
}
62+
}
63+
64+
#endif

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,8 @@
44
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_
55
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_
66

7-
87
#include "nbl/builtin/hlsl/device_capabilities_traits.hlsl"
9-
10-
#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
118
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl"
12-
#include "nbl/builtin/hlsl/concepts.hlsl"
13-
149

1510
namespace nbl
1611
{
@@ -19,20 +14,6 @@ namespace hlsl
1914
namespace subgroup2
2015
{
2116

22-
template<typename Config, class BinOp, int32_t _ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config> && is_scalar_v<typename BinOp::type_t>)
23-
struct ArithmeticParams
24-
{
25-
using config_t = Config;
26-
using binop_t = BinOp;
27-
using scalar_t = typename BinOp::type_t;
28-
using type_t = vector<scalar_t, _ItemsPerInvocation>;
29-
using device_traits = device_capabilities_traits<device_capabilities>;
30-
31-
NBL_CONSTEXPR_STATIC_INLINE int32_t ItemsPerInvocation = _ItemsPerInvocation;
32-
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic /*&& /*some heuristic for when its faster*/;
33-
// TODO add a IHV enum to device_capabilities_traits to check !is_nvidia
34-
};
35-
3617
template<typename Params>
3718
struct reduction : impl::reduction<Params,typename Params::binop_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
3819
template<typename Params>

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,10 @@ struct inclusive_scan<Params, BinOp, 1, false>
158158

159159
static scalar_t __call(scalar_t value)
160160
{
161+
// sync up each subgroup invocation so it runs in lockstep
162+
// not ideal: this path might never write to shared memory, but the barrier's memory semantics still need a storage class
163+
spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
164+
161165
binop_t op;
162166
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
163167

@@ -185,6 +189,10 @@ struct exclusive_scan<Params, BinOp, 1, false>
185189

186190
// Exclusive scan for one item per invocation on the non-native path:
// takes the neighbouring invocation's value via subgroupShuffleUp(value, 1) and runs the
// inclusive scan on the result.
scalar_t operator()(scalar_t value)
187191
{
192+
// Subgroup-scope memory barrier, intended to sync up the subgroup invocations so they run in lockstep.
// NOTE(review): a memory barrier orders memory accesses; confirm that it (rather than a control
// barrier) gives the intended execution sync on the target drivers.
193+
// not ideal: this path might never write to shared memory, but the barrier's memory semantics still need a storage class
194+
spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
195+
188196
// the selector is bool(invocation ID): presumably invocation 0 receives binop_t::identity and every
// other invocation receives the shuffled-up value -- confirm against hlsl::mix's bool overload
scalar_t left = hlsl::mix(binop_t::identity, glsl::subgroupShuffleUp<scalar_t>(value,1), bool(glsl::gl_SubgroupInvocationID()));
189197
return inclusive_scan<Params, BinOp, 1, false>::__call(left);
190198
}
@@ -203,8 +211,11 @@ struct reduction<Params, BinOp, 1, false>
203211

204212
scalar_t operator()(scalar_t value)
205213
{
206-
binop_t op;
214+
// sync up each subgroup invocation so it runs in lockstep
215+
// not ideal: this path might never write to shared memory, but the barrier's memory semantics still need a storage class
216+
spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
207217

218+
binop_t op;
208219
const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
209220
[unroll]
210221
for (uint32_t i = 0; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ struct SItemsPerInvoc
181181
};
182182
}
183183

184+
#include <sstream>
185+
#include <string>
184186
struct SArithmeticConfiguration
185187
{
186188
void init(const uint16_t _WorkgroupSizeLog2, const uint16_t _SubgroupSizeLog2, const uint16_t _ItemsPerInvocation)
@@ -203,6 +205,13 @@ struct SArithmeticConfiguration
203205
#undef DEFINE_ASSIGN
204206
}
205207

208+
std::string getConfigTemplateStructString()
209+
{
210+
std::ostringstream os;
211+
os << "nbl::hlsl::workgroup2::ArithmeticConfiguration<" << WorkgroupSizeLog2 << "," << SubgroupSizeLog2 << "," << ItemsPerInvocation_0 << ">;";
212+
return os.str();
213+
}
214+
206215
#define DEFINE_ASSIGN(TYPE,ID,...) TYPE ID;
207216
#include "impl/arithmetic_config_def.hlsl"
208217
#undef DEFINE_ASSIGN
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_BASIC_INCLUDED_
5+
#define _NBL_BUILTIN_HLSL_WORKGROUP2_BASIC_INCLUDED_
6+
7+
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
8+
9+
namespace nbl
10+
{
11+
namespace hlsl
12+
{
13+
namespace workgroup2
14+
{
15+
// empty
16+
}
17+
}
18+
}
19+
20+
#endif

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl"
88
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
99
#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
10+
#include "nbl/builtin/hlsl/subgroup2/arithmetic_params.hlsl"
1011
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl"
1112
#include "nbl/builtin/hlsl/mpl.hlsl"
1213
#include "nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl"

src/nbl/builtin/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/arithmetic_portabili
332332
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/fft.hlsl")
333333
#subgroup2
334334
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup2/ballot.hlsl")
335+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup2/arithmetic_params.hlsl")
335336
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup2/arithmetic_portability.hlsl")
336337
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup2/arithmetic_portability_impl.hlsl")
337338
#shared header between C++ and HLSL
@@ -346,6 +347,7 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/scratch_size.hlsl")
346347
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/shared_scan.hlsl")
347348
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/shuffle.hlsl")
348349
#workgroup2
350+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup2/basic.hlsl")
349351
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup2/arithmetic_config.hlsl")
350352
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup2/impl/virtual_wg_size_def.hlsl")
351353
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup2/impl/items_per_invoc_def.hlsl")

0 commit comments

Comments
 (0)