#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"

+ #if FBGEMM_GPU_MEMCHECK
+ #define FBGEMM_MEM_CHECK_ONLY
+ #else
+ #define FBGEMM_MEM_CHECK_ONLY maybe_unused
+ #endif
+
using Tensor = at::Tensor;
using namespace fbgemm_gpu;

namespace {
- template <typename scalar_t, typename grad_t>
+ template <typename index_t, typename scalar_t, typename grad_t>
void split_embedding_backward_approx_cpu_kernel(
    Tensor grad_output,
    Tensor host_weights,
@@ -44,8 +50,11 @@ void split_embedding_backward_approx_cpu_kernel(
    {{ args.split_cpu_kernel_args | join(", ") }}) {
  auto grad_output_data = grad_output.accessor<grad_t, 2>();
  auto host_weights_data = host_weights.accessor<scalar_t, 1>();
- const auto indices_data = indices.accessor<int64_t, 1>();
- const auto offsets_data = offsets.accessor<int64_t, 1>();
+
+ [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "split_embedding_backward_approx_cpu_kernel";
+ const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1);
+ const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1);
+
  // If indice_weights are not defined, then this accessor won't be used
  auto indice_weights_data = indice_weights.defined()
      ? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
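Aside on the two helpers introduced in the hunk above (a reading aid, not the fbgemm_gpu implementation): MAKE_TA_WITH_NAME builds the indices/offsets accessors and tags them with func_name so that memcheck builds can report which kernel performed a bad access, while FBGEMM_MEM_CHECK_ONLY expands to the standard maybe_unused attribute when FBGEMM_GPU_MEMCHECK is off, so func_name does not trip unused-variable warnings in builds that never read it. A minimal standalone sketch of that attribute-macro pattern; only the two macro names are taken from the patch, everything else is illustrative:

    #include <cstdio>

    #if FBGEMM_GPU_MEMCHECK
    #define FBGEMM_MEM_CHECK_ONLY
    #else
    // Expands to [[maybe_unused]] at the use site below.
    #define FBGEMM_MEM_CHECK_ONLY maybe_unused
    #endif

    int main() {
      // Only read when the memcheck path is compiled in; the attribute
      // suppresses the unused-variable warning otherwise.
      [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "example_kernel";
    #if FBGEMM_GPU_MEMCHECK
      std::printf("running %s with memcheck enabled\n", func_name);
    #endif
      return 0;
    }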
@@ -133,75 +142,84 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu(
      !indice_weights.defined() && static_cast<PoolingMode>(pooling_mode) == PoolingMode::SUM;

  if (use_fbgemm) {
-   auto grad_stride = grad_output.size(1);
-   const float* grad_output_data = grad_output.data_ptr<float>();
-   float* host_weights_data = host_weights.data_ptr<float>();
-   const int64_t* indices_data = indices.data_ptr<int64_t>();
-   const int64_t* offsets_data = offsets.data_ptr<int64_t>();
-   const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
-   float* momentum1_data = momentum1_host.data_ptr<float>();
-
-   at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
-     int t_begin = tb_begin / B;
-     int t_end = (tb_end + B - 1) / B;
-     for (const auto t : c10::irange(t_begin,t_end)) {
-       auto D_begin = D_offsets_data[t];
-       auto D = D_offsets_data[t + 1] - D_offsets_data[t];
-       auto table_begin = weights_offsets_data[t];
-       auto momentum_begin = momentum1_offsets_data[t];
-
-       int64_t hash_size;
-       int t_temp = t + 1;
-       do {
-         hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
-         ++t_temp;
-       } while (hash_size == 0);
-
-       int b_begin = (t == t_begin) ? tb_begin % B : 0;
-       int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;
-
-       auto kernel =
-           fbgemm::GenerateRowWiseSparseAdaGradFused<int64_t, int64_t, float>(
-               D,
-               /*prefetch=*/16,
-               /*use_offsets=*/true,
-               /*use_stochastic_round=*/true,
-               /*grad_stride=*/grad_stride);
-       auto offsets_begin_ptr = offsets_data + t * B + b_begin;
-       auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
-       bool success = kernel(
-           b_end - b_begin,
-           index_size,
-           hash_size,
-           reinterpret_cast<float*>(host_weights_data + table_begin),
-           reinterpret_cast<const float*>(
-               grad_output_data + b_begin * grad_stride + D_begin),
-           reinterpret_cast<float*>(momentum1_data + momentum_begin),
-           indices_data + *offsets_begin_ptr,
-           offsets_begin_ptr,
-           eps,
-           // fbgemm follows caffe2 convention of negative learning rate
-           -learning_rate);
-
-       if (!success) {
-         fbgemm_gpu::report_embedding_error(
-             t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
+   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {
+
+     auto grad_stride = grad_output.size(1);
+     const float* grad_output_data = grad_output.data_ptr<float>();
+     float* host_weights_data = host_weights.data_ptr<float>();
+
+     const auto* indices_data = indices.data_ptr<index_t>();
+     const auto* offsets_data = offsets.data_ptr<index_t>();
+
+     const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
+     float* momentum1_data = momentum1_host.data_ptr<float>();
+
+     at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
+       int t_begin = tb_begin / B;
+       int t_end = (tb_end + B - 1) / B;
+
+       for (const auto t : c10::irange(t_begin,t_end)) {
+         auto D_begin = D_offsets_data[t];
+         auto D = D_offsets_data[t + 1] - D_offsets_data[t];
+         auto table_begin = weights_offsets_data[t];
+         auto momentum_begin = momentum1_offsets_data[t];
+
+         int64_t hash_size;
+         int t_temp = t + 1;
+         do {
+           hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
+           ++t_temp;
+         } while (hash_size == 0);
+
+         int b_begin = (t == t_begin) ? tb_begin % B : 0;
+         int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;
+
+         auto kernel =
+             fbgemm::GenerateRowWiseSparseAdaGradFused<index_t, index_t, float>(
+                 D,
+                 /*prefetch=*/16,
+                 /*use_offsets=*/true,
+                 /*use_stochastic_round=*/true,
+                 /*grad_stride=*/grad_stride);
+         auto offsets_begin_ptr = offsets_data + t * B + b_begin;
+         auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
+         bool success = kernel(
+             b_end - b_begin,
+             index_size,
+             hash_size,
+             reinterpret_cast<float*>(host_weights_data + table_begin),
+             reinterpret_cast<const float*>(
+                 grad_output_data + b_begin * grad_stride + D_begin),
+             reinterpret_cast<float*>(momentum1_data + momentum_begin),
+             indices_data + *offsets_begin_ptr,
+             offsets_begin_ptr,
+             eps,
+             // fbgemm follows caffe2 convention of negative learning rate
+             -learning_rate);
+
+         if (!success) {
+           fbgemm_gpu::report_embedding_error(
+               t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
+         }
       }
-   }
-   }); // parallel_for
+     }); // parallel_for
+   }); // dispatch indices.scalar_type()
+
    return;
  } // use_fbgemm

  {% endif %}

- FBGEMM_DISPATCH_FLOAT_AND_HALF(
-     grad_output.scalar_type(), "split_embedding_backward_cpu", [&] {
+ AT_DISPATCH_INDEX_TYPES(
+     indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {
+
+   FBGEMM_DISPATCH_FLOAT_AND_HALF(
+       grad_output.scalar_type(), "split_embedding_backward_approx_cpu_kernel_2", [&] {
        using grad_t = scalar_t;
-       FBGEMM_DISPATCH_FLOAT_AND_HALF(
-           host_weights.scalar_type(),
-           "split_embedding_backward_cpu_inner",
-           [&] {
-             split_embedding_backward_approx_cpu_kernel<scalar_t, grad_t>(
+
+       FBGEMM_DISPATCH_FLOAT_AND_HALF(
+           host_weights.scalar_type(), "split_embedding_backward_approx_cpu_kernel_3", [&] {
+             split_embedding_backward_approx_cpu_kernel<index_t, scalar_t, grad_t>(
                grad_output,
                host_weights,
                weights_offsets_data,
@@ -220,7 +238,8 @@ for (const auto t : c10::irange(t_begin,t_end)) {
              {% endif %}
              {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
          }); // dispatch host_weights.scalar_type()
-       }); // dispatch grad_output.scalar_type()
+       }); // dispatch grad_output.scalar_type()
+   }); // dispatch indices.scalar_type()

  return;
}
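For context on the dispatch pattern this patch settles on: AT_DISPATCH_INDEX_TYPES inspects indices.scalar_type() at runtime and instantiates its lambda body with index_t bound to int32_t or int64_t; nesting the two FBGEMM_DISPATCH_FLOAT_AND_HALF calls inside it then selects grad_t and scalar_t, so the kernel is compiled once per type combination and the right instantiation is picked at runtime. Below is a simplified, self-contained stand-in for that mechanism; dispatch_index_types and run_kernel are illustrative names, not ATen or fbgemm_gpu APIs.

    #include <cstdint>
    #include <cstdio>

    enum class IndexType { Int32, Int64 };

    // Simplified stand-in for AT_DISPATCH_INDEX_TYPES: choose the index type
    // at runtime and instantiate the templated body with it.
    template <typename F>
    void dispatch_index_types(IndexType t, F&& body) {
      switch (t) {
        case IndexType::Int32:
          body(std::int32_t{});
          break;
        case IndexType::Int64:
          body(std::int64_t{});
          break;
      }
    }

    // Illustrative kernel templated on the index type, mirroring the new
    // <index_t, ...> template parameter in the patch (scalar/grad types
    // omitted for brevity).
    template <typename index_t>
    void run_kernel(const index_t* indices, std::int64_t n) {
      std::int64_t sum = 0;
      for (std::int64_t i = 0; i < n; ++i) {
        sum += static_cast<std::int64_t>(indices[i]);
      }
      std::printf("visited %lld indices, sum = %lld\n",
                  static_cast<long long>(n), static_cast<long long>(sum));
    }

    int main() {
      // The generic lambda plays the role of the dispatch macro's body; note
      // that it must compile for every dispatched index_t, which is why the
      // data is created inside the lambda in a type-generic way.
      dispatch_index_types(IndexType::Int32, [&](auto tag) {
        using index_t = decltype(tag);
        const index_t indices[] = {1, 2, 3};
        run_kernel<index_t>(indices, 3);
      });
      return 0;
    }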