Check copy alignment for MMA and Epilogue #438

Merged
16 changes: 16 additions & 0 deletions include/cutlass/detail/layout.hpp
@@ -317,6 +317,18 @@ constexpr bool is_tma_copy_engine() {
return false;
}

template<class GmemTiledCopy>
constexpr bool is_xe_2d_copy_engine() {
if constexpr (cute::is_void_v<GmemTiledCopy>) {
return false;
}
// TODO(Codeplay): Add a marker base class to identify all xe_2d copy operations
#if defined(SYCL_INTEL_TARGET)
return true;
#endif
return false;
}

template <class X, class = void>
struct RawDtype { using type = X; };

@@ -356,6 +368,10 @@ get_alignment_count_from_gmem_tiled_copy() {
}
return 128 / sizeof_bits<Element>::value;
}
// Intel 2D copy
else if constexpr (is_xe_2d_copy_engine<GmemTiledCopy>()) {
return 128;
}
else {
// For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN
return GmemTiledCopy::NumValSrc;
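Note that the Xe branch above returns a flat count of 128, while the TMA branch (and the new can_implement checks later in this PR) derive the count by dividing a 128-bit requirement by the element width. A minimal standalone C++ sketch of that per-dtype math, with sizeof_bits as a local stand-in for cutlass::sizeof_bits and illustrative element types:

// Standalone sketch: mapping a 128-bit copy requirement to a minimum
// element count per data type. sizeof_bits is a local stand-in for
// cutlass::sizeof_bits; the element types are illustrative.
#include <cstdint>
#include <cstdio>

template <class T>
struct sizeof_bits { static constexpr int value = int(8 * sizeof(T)); };

template <class Element>
constexpr int min_aligned_elements() {
  return 128 / sizeof_bits<Element>::value;
}

int main() {
  std::printf("32-bit elements: %d\n", min_aligned_elements<float>());         // 4
  std::printf("16-bit elements: %d\n", min_aligned_elements<std::uint16_t>()); // 8
  std::printf(" 8-bit elements: %d\n", min_aligned_elements<std::int8_t>());   // 16
}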
38 changes: 34 additions & 4 deletions include/cutlass/epilogue/collective/xe_array_epilogue.hpp
@@ -236,11 +236,41 @@ class CollectiveEpilogue<
}

template <class ProblemShape>
CUTLASS_HOST_DEVICE static bool
static bool
can_implement(
ProblemShape const& problem_shape,
[[maybe_unused]] Arguments const& args) {
return true;
ProblemShape problem_shape,
Arguments const& args) {
constexpr int copy_alignment_bits = 128;

bool implementable = true;
bool fusion_implementable = true;

for (int i = 0; i < problem_shape.groups(); ++i) {
auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;

if constexpr (is_destination_supported) {
constexpr int min_aligned_elements_D = copy_alignment_bits / sizeof_bits<ElementD>::value;
implementable &= cutlass::detail::check_alignment<min_aligned_elements_D>(cute::make_shape(M,N,L), args.dD);
}

if constexpr (is_source_supported) {
constexpr int min_aligned_elements_C = copy_alignment_bits / sizeof_bits<ElementC>::value;
implementable &= cutlass::detail::check_alignment<min_aligned_elements_C>(cute::make_shape(M,N,L), args.dC);
}

fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread);
}

if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n");
}

if (!fusion_implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
}

return implementable && fusion_implementable;
}

CUTLASS_HOST_DEVICE
33 changes: 30 additions & 3 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -219,9 +219,36 @@ class CollectiveEpilogue<
template <class ProblemShape>
CUTLASS_HOST_DEVICE static bool
can_implement(
ProblemShape const& problem_shape,
[[maybe_unused]] Arguments const& args) {
return true;
ProblemShape const& problem_shapes,
Arguments const& args) {
constexpr int copy_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shapes, 1);
auto [M,N,K,L] = problem_shape_MNKL;

bool implementable = true;
bool fusion_implementable = true;

if constexpr (is_destination_supported) {
constexpr int min_aligned_elements_D = copy_alignment_bits / sizeof_bits<ElementD>::value;
implementable &= cutlass::detail::check_alignment<min_aligned_elements_D>(cute::make_shape(M,N,L), args.dD);
}

if constexpr (is_source_supported) {
constexpr int min_aligned_elements_C = copy_alignment_bits / sizeof_bits<ElementC>::value;
implementable &= cutlass::detail::check_alignment<min_aligned_elements_C>(cute::make_shape(M,N,L), args.dC);
}

fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread);

if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n");
}

if (!fusion_implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
}

return implementable && fusion_implementable;
}

CUTLASS_HOST_DEVICE
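Both epilogue overloads follow the same accumulate-and-trace shape: every requirement AND-s into a flag so each failure is still reported before one combined result is returned. A dependency-free sketch of that pattern, where check_alignment is a simplified stand-in for cutlass::detail::check_alignment:

// Sketch of the can_implement pattern: each requirement AND-s into a flag
// so every failure is traced before the combined result is returned.
#include <cstdio>

constexpr int copy_alignment_bits = 128;

// Simplified stand-in: treat a tensor as aligned when its leading dimension
// is a multiple of the minimum element count for its data type.
static bool check_alignment(long leading_dim, int min_elements) {
  return leading_dim % min_elements == 0;
}

bool can_implement_sketch(long ldC, long ldD, int bitsC, int bitsD, bool fusion_ok) {
  bool implementable = true;
  implementable &= check_alignment(ldD, copy_alignment_bits / bitsD);  // destination D
  implementable &= check_alignment(ldC, copy_alignment_bits / bitsC);  // source C
  if (!implementable)
    std::printf(" CAN IMPLEMENT: XE 2D copy alignment requirements not met\n");
  if (!fusion_ok)
    std::printf(" CAN IMPLEMENT: FusionCallbacks requirements not met\n");
  return implementable && fusion_ok;
}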
29 changes: 24 additions & 5 deletions include/cutlass/gemm/collective/xe_array_mma.hpp
@@ -139,11 +139,6 @@ struct CollectiveMma<MainloopIntelXeXMX16Group<Stages, Schedule>, TileShape_, El
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
(void) workspace;

// Batches/Groups are managed by using appropriate pointers to input matrices
const int32_t mock_L = 1;
ElementA const* ptr_A_first_batch = reinterpret_cast<ElementA const*>(args.ptr_A);
ElementB const* ptr_B_first_batch = reinterpret_cast<ElementB const*>(args.ptr_B);

auto problem_shape_MNK = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1));
auto init_M = get<0>(problem_shape_MNK);
auto init_N = get<1>(problem_shape_MNK);
@@ -157,6 +152,30 @@
};
}

template<class ProblemShape>
static bool
can_implement(
ProblemShape problem_shapes,
Arguments const& args) {
constexpr int copy_alignment_bits = 128;
bool implementable = true;

constexpr int min_aligned_elements_A = copy_alignment_bits / sizeof_bits<ElementA>::value;
constexpr int min_aligned_elements_B = copy_alignment_bits / sizeof_bits<ElementB>::value;
for (int i = 0; i < problem_shapes.groups(); i++) {
auto problem_shape_MNKL = append<4>(problem_shapes.get_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;
implementable &= cutlass::detail::check_alignment<min_aligned_elements_A>(cute::make_shape(M,K,L), args.dA);
implementable &= cutlass::detail::check_alignment<min_aligned_elements_B>(cute::make_shape(N,K,L), args.dB);
}

if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n");
}

return implementable;
}

/// Perform a subgroup-scoped matrix multiply-accumulate
template <class FrgTensorD, class TensorA, class TensorB, class FrgTensorC, class KTileIterator,
class BlkCoord, class LoadTensors>
23 changes: 23 additions & 0 deletions include/cutlass/gemm/collective/xe_mma.hpp
@@ -146,6 +146,29 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element
return Params{tiled_copy_a, tiled_copy_b};
}

template<class ProblemShape>
static bool
can_implement(
ProblemShape problem_shapes,
Arguments const& args) {
constexpr int copy_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shapes, 1);
auto [M,N,K,L] = problem_shape_MNKL;

bool implementable = true;

constexpr int min_aligned_elements_A = copy_alignment_bits / sizeof_bits<ElementA>::value;
implementable &= cutlass::detail::check_alignment<min_aligned_elements_A>(cute::make_shape(M,K,L), args.dA);
Comment on lines +161 to +162

Reviewer (Collaborator): I see now you are checking only one alignment instead of separate inner and outer dimensions and strides. How does that work with the 2D block copy requirements?

Author (Collaborator): I added the batch alignment too.
constexpr int min_aligned_elements_B = copy_alignment_bits / sizeof_bits<ElementB>::value;
implementable &= cutlass::detail::check_alignment<min_aligned_elements_B>(cute::make_shape(N,K,L), args.dB);

if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n");
}

return implementable;
}

/// Perform a subgroup-scoped matrix multiply-accumulate
template <class FrgTensorD, class TensorA, class TensorB, class FrgTensorC, class KTileIterator, class BlkCoord>
CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, FrgTensorC const &src_accum,
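The non-grouped check above validates a single (M,N,K,L); the grouped variant in xe_array_mma.hpp earlier repeats the same A/B predicate per group, so one misaligned group rejects the whole grouped GEMM. A sketch under that assumption, where GroupProblem is a hypothetical stand-in for ProblemShape::get_problem_shape(i) together with the dA/dB strides:

// Sketch of the grouped check: the same alignment predicate repeated per
// group; a single misaligned group fails the whole grouped GEMM.
#include <vector>

struct GroupProblem { long M, N, K, ldA, ldB; };  // hypothetical stand-in

bool can_implement_groups_sketch(const std::vector<GroupProblem>& groups,
                                 int min_elems_A, int min_elems_B) {
  bool implementable = true;
  for (const auto& g : groups) {
    implementable &= (g.ldA % min_elems_A == 0);  // A over shape (M,K,L)
    implementable &= (g.ldB % min_elems_B == 0);  // B over shape (N,K,L)
  }
  return implementable;
}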
55 changes: 14 additions & 41 deletions include/cutlass/gemm/kernel/xe_gemm.hpp
@@ -159,47 +159,20 @@ class GemmUniversal<

static bool
can_implement(Arguments const& args) {
auto m = get<0>(args.problem_shape);
auto n = get<1>(args.problem_shape);
auto k = get<2>(args.problem_shape);
auto l = get<3>(args.problem_shape);
bool is_batch = l > 1;
// all these requirements are in bytes
constexpr int inner_alignment_requirement = 16;
constexpr int outer_alignment_requirement = 64;
constexpr int width_alignment_requirement = 4;

auto check_stride = [is_batch](auto stride, int el_size){
auto a = get<0>(stride);
auto b = get<1>(stride);
auto valid_is_unit = a == 1 || b == 1;
auto inner = a == 1 ? b : a;
auto valid_inner = inner % (inner_alignment_requirement / el_size) == 0;
auto valid_outer = !is_batch || get<2>(stride) % (outer_alignment_requirement / el_size) == 0;
return valid_is_unit && valid_inner && valid_outer;
};
bool strides_valid = check_stride(args.mainloop.dA, sizeof(ElementA)) &&
check_stride(args.mainloop.dB, sizeof(ElementB)) &&
// TODO(Codeplay): Use proper check when ElementC is correctly set.
((args.epilogue.ptr_C == nullptr) ||
check_stride(args.epilogue.dC, sizeof(ElementC))) &&
check_stride(args.epilogue.dD, sizeof(ElementD));
// TODO(codeplay): base *_valid on the atom shapes
auto check_dim = [](int dimm, int el_size, bool do_check){
return !do_check || dimm % (width_alignment_requirement / el_size) == 0;
};
bool m_valid = m > 0 && check_dim(m, sizeof(ElementA), get<0>(args.mainloop.dA) == _1{}) &&
check_dim(m, sizeof(ElementC), get<0>(args.epilogue.dC) == _1{}) &&
check_dim(m, sizeof(ElementD), get<0>(args.epilogue.dD) == _1{});
bool n_valid = n > 0 && check_dim(n, sizeof(ElementB), get<1>(args.mainloop.dB) == _1{}) &&
check_dim(n, sizeof(ElementC), get<1>(args.epilogue.dC) == _1{}) &&
check_dim(n, sizeof(ElementD), get<1>(args.epilogue.dD) == _1{});
bool k_valid = k > 0 && check_dim(k, sizeof(ElementA), get<0>(args.mainloop.dA) == _1{}) &&
check_dim(k, sizeof(ElementB), get<0>(args.mainloop.dB) == _1{});
bool shape_implementable = m_valid && n_valid && k_valid && strides_valid;
bool mode_implementable = args.mode == GemmUniversalMode::kGemm ||
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
return shape_implementable && mode_implementable && TileScheduler::can_implement(args.scheduler);
auto problem_shape_MNKL = append<4>(args.problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;

bool implementable = true;

implementable = implementable && (args.mode == GemmUniversalMode::kGemm ||
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4));

implementable &= TileScheduler::can_implement(args.scheduler);

implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);

return implementable;
}

static int
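For the common cases, the new element-count thresholds reproduce the byte-level inner check they replace: with 16-bit A, 128 bits / 16 bits = 8 elements = 16 bytes, matching the old inner_alignment_requirement of 16 bytes; with 32-bit C/D, 128 / 32 = 4 elements, again 16 bytes. The per-dimension width check and the batched outer-stride check are no longer hand-rolled here; they are delegated to cutlass::detail::check_alignment inside the collective can_implement overloads shown above, and the review thread earlier notes that the batch alignment is covered there too.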
19 changes: 15 additions & 4 deletions include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp
@@ -180,10 +180,21 @@ class GemmUniversal<
}

static bool
can_implement(Arguments const& args) {
bool mode_implementable = args.mode == GemmUniversalMode::kGrouped
&& rank(typename ProblemShape::UnderlyingProblemShape{}) == 3;
return mode_implementable && TileScheduler::can_implement(args.scheduler);
can_implement(Arguments const& args) {
constexpr int copy_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(args.problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;

bool implementable = true;

implementable = implementable && (args.mode == GemmUniversalMode::kGrouped ||
(args.mode == GemmUniversalMode::kBatched && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3));

implementable = implementable && TileScheduler::can_implement(args.scheduler);

implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);

return implementable;
}

static size_t
20 changes: 13 additions & 7 deletions test/unit/gemm/device/gemm_testbed_3x.hpp
@@ -4024,8 +4024,8 @@ template <typename Gemm, template <class T> class ActivationFunctor =
cutlass::epilogue::thread::Identity>
// TODO(Codeplay): remove the test_batch option once batching is enabled for all tests
bool TestXe(
double alpha = 1.0, double beta = 0.0,
bool test_batch = true, int max_alignment = 8,
double alpha = 1.0, double beta = cute::is_same_v<typename Gemm::GemmKernel::ElementC, void> ? 0.0 : 1.0,
bool test_batch = true,
CheckEquality check_relative_equality = CheckEquality::RELATIVE) {
using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar;
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
@@ -4040,14 +4040,20 @@

// For M & N we test a small and a big size
// For K, we currently only support K = TileShapeK
// TODO(codeplay): unhardcode max_alignment

std::vector<int> problem_size_m{max_alignment, 512 - 3 * max_alignment};
std::vector<int> problem_size_n{max_alignment, 512 - 2 * max_alignment};
int max_alignment_m = std::max({Gemm::kAlignmentA, Gemm::kAlignmentC, Gemm::kAlignmentD});
int max_alignment_n = std::max({Gemm::kAlignmentB, Gemm::kAlignmentC, Gemm::kAlignmentD});
if constexpr (std::is_base_of_v<cutlass::epilogue::fusion::FusionOperation, typename Gemm::EpilogueOutputOp>) {
max_alignment_m = std::max(max_alignment_m, Gemm::EpilogueOutputOp::AlignmentAux);
max_alignment_n = std::max(max_alignment_n, Gemm::EpilogueOutputOp::AlignmentAux);
}
std::vector<int> problem_size_m = {max_alignment_m, 512 - 3 * max_alignment_m};
std::vector<int> problem_size_n = {max_alignment_n, 512 - 2 * max_alignment_n};
std::vector<int> problem_size_l = test_batch ? std::vector{1, 3, 4} : std::vector{1};

constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages;
constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{});
std::vector<int> problem_size_k{TileShapeK, TileShapeK*32};
int max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB);
std::vector<int> problem_size_k = {max_alignment_k, TileShapeK * (Stages + 1) - max_alignment_k};

using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode;
std::vector decomposition_modes = {DecompositionMode::Heuristic};
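A runnable sketch of the new size selection with illustrative alignment values (the kAlignment constants below stand in for Gemm::kAlignmentA/C/D, which are not shown in this diff):

// Sketch: each tested extent is either exactly the largest relevant
// alignment or a large size offset by a multiple of it, so every generated
// problem satisfies the new can_implement alignment checks.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Illustrative stand-ins for Gemm::kAlignmentA, kAlignmentC, kAlignmentD.
  const int kAlignmentA = 8, kAlignmentC = 4, kAlignmentD = 4;
  const int max_alignment_m = std::max({kAlignmentA, kAlignmentC, kAlignmentD});
  // One minimal extent and one large extent, both multiples of the alignment.
  const std::vector<int> problem_size_m = {max_alignment_m, 512 - 3 * max_alignment_m};
  std::printf("M sizes: %d, %d\n", problem_size_m[0], problem_size_m[1]);  // 8, 488
}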
@@ -131,7 +131,7 @@ TEST(XE_Device_GemmUniversal_f16n_f16t_f32n_tensor_op_f32, 256x256x32_LD32x32) {

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 1.0));
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
}

TEST(XE_Device_GemmUniversal_f16n_f16t_f32n_tensor_op_f32, 256x256x32_LD16x32) {
@@ -204,7 +204,7 @@ TEST(XE_Device_GemmUniversal_f16n_f16t_f32n_tensor_op_f32, 256x256x32_LD16x32) {

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 1.0));
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
}

TEST(XE_Device_GemmUniversal_f16n_f16t_f32n_tensor_op_f32, 256x256x32_LDA8x32_LDB16x32) {
@@ -277,7 +277,7 @@ TEST(XE_Device_GemmUniversal_f16n_f16t_f32n_tensor_op_f32, 256x256x32_LDA8x32_LD

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 1.0));
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
}

////////////////////////////////////////////////////////////////////////////////
@@ -130,7 +130,7 @@ TEST(XE_Device_GemmUniversal_f16t_s4n_f32t_mixed_input_tensor_op_f32, 128x128x64
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// TODO(Codeplay): gemm batch doesn't work for mixed type
bool passed = test::gemm::device::TestXe<Gemm>(1.0, 1.0, false, 16);
bool passed = test::gemm::device::TestXe<Gemm>(1.0, 1.0, false);
EXPECT_TRUE(passed);
}
////////////////////////////////////////////////////////////////////////////////
@@ -130,7 +130,7 @@ TEST(XE_Device_GemmUniversal_f16t_s4t_f32t_mixed_input_tensor_op_f32, 128x128x64
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// TODO(Codeplay): gemm batch doesn't work for mixed type
bool passed = test::gemm::device::TestXe<Gemm>(1.0, 1.0, false, 32);
bool passed = test::gemm::device::TestXe<Gemm>(1.0, 1.0, false);
EXPECT_TRUE(passed);
}
////////////////////////////////////////////////////////////////////////////////
@@ -106,7 +106,7 @@ TEST(XE_Device_Gemm_bf16t_bf16t_f32t_tensor_op_gmma_f32_epilogue, 256x256x32_Lin

using Gemm = XE_Device_Gemm_bf16_bf16_f32_tensor_op_gmma_f32_epilogue<CollectiveEpilogue>::Gemm;

bool passed = test::gemm::device::TestXe<Gemm, epilogue::thread::ReLu>(1.0, 1.0);
bool passed = test::gemm::device::TestXe<Gemm, epilogue::thread::ReLu>();
EXPECT_TRUE(passed);
}

@@ -195,7 +195,7 @@ TEST(XE_Device_Gemm_bf16t_bf16t_f32_tensor_op_gmma_f32_epilogue, 256x256x32_LinC

using Gemm = XE_Device_Gemm_bf16_bf16_f32_tensor_op_gmma_f32_epilogue<CollectiveEpilogue>::Gemm;

bool passed = test::gemm::device::TestXe<Gemm>(1.0, 0.0);
bool passed = test::gemm::device::TestXe<Gemm>();
EXPECT_TRUE(passed);
}
}