
Commit 1863985

minhua-chen authored and facebook-github-bot committed
ensemble rowwise adagrad (fbgemm backend) (pytorch#2889)
Summary:
X-link: facebookresearch/FBGEMM#49
Pull Request resolved: pytorch#2889

ensemble rowwise adagrad (fbgemm diff)

Differential Revision: D60189486

Reviewed By: csmiler, spcyppt
1 parent 4ae45b7 · commit 1863985

File tree

fbgemm_gpu/FbgemmGpu.cmake
fbgemm_gpu/codegen/genscript/generate_backward_split.py
fbgemm_gpu/codegen/genscript/optimizers.py
fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template
fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py

5 files changed, +128 −1 lines changed


fbgemm_gpu/FbgemmGpu.cmake
Lines changed: 2 additions & 0 deletions

@@ -60,6 +60,7 @@ set(GPU_ONLY_OPTIMIZERS
     lamb
     partial_rowwise_adam
     partial_rowwise_lamb
+    ensemble_rowwise_adagrad
     lars_sgd
     none
     rowwise_adagrad_with_counter)

@@ -86,6 +87,7 @@ set(GPU_OPTIMIZERS ${COMMON_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS})
 set(VBE_OPTIMIZERS
     rowwise_adagrad
     rowwise_adagrad_with_counter
+    ensemble_rowwise_adagrad
     sgd
     dense)

fbgemm_gpu/codegen/genscript/generate_backward_split.py
Lines changed: 1 addition & 0 deletions

@@ -335,6 +335,7 @@ def generate() -> None:
         lars_sgd(),
         partial_rowwise_adam(),
         partial_rowwise_lamb(),
+        ensemble_rowwise_adagrad(),
         rowwise_adagrad(),
         approx_rowwise_adagrad(),
         rowwise_adagrad_with_weight_decay(),
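For orientation: this registration list is how generate() discovers optimizers. Each entry returns a Dict[str, Any] describing one optimizer (the new dict is defined in optimizers.py below), and kernels are emitted according to the dict's support flags. A self-contained toy sketch of that pattern; the stub and render_backward helper are hypothetical, not FBGEMM APIs:

from typing import Any, Dict

def ensemble_rowwise_adagrad_stub() -> Dict[str, Any]:
    # Hypothetical stand-in for the real ensemble_rowwise_adagrad() below.
    return {
        "optimizer": "ensemble_rowwise_adagrad",
        "has_cpu_support": False,
        "has_gpu_support": True,
        "has_vbe_support": True,
    }

def render_backward(opt: Dict[str, Any], variant: str) -> None:
    # Hypothetical stand-in for the template-rendering step.
    print(f"emit {variant} backward kernels for {opt['optimizer']}")

for opt in [ensemble_rowwise_adagrad_stub()]:
    if opt["has_gpu_support"]:
        render_backward(opt, "gpu")
    if opt["has_vbe_support"]:
        render_backward(opt, "vbe")  # VBE variant enabled by this diff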

fbgemm_gpu/codegen/genscript/optimizers.py
Lines changed: 96 additions & 0 deletions

@@ -1020,6 +1020,102 @@ def adam() -> Dict[str, Any]:
     }


+def ensemble_rowwise_adagrad() -> Dict[str, Any]:
+    split_precomputation = """
+    at::acc_type<cache_t, true> g_local_sum_square = 0.0;
+    """
+    split_precomputation += generate_optimized_grad_sum_loop_access(
+        """
+        const float4* grad = &{grad_vec}.acc;
+        auto gx = grad->x;
+        auto gy = grad->y;
+        auto gz = grad->z;
+        auto gw = grad->w;
+        g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
+        """
+    )
+    split_precomputation += """
+    const at::acc_type<cache_t, true> g_avg_square =
+        GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;
+
+    at::acc_type<cache_t, true> multiplier;
+    at::acc_type<bool, true> should_ema;
+    at::acc_type<bool, true> should_swap;
+    if (threadIdx.x == 0) {
+        at::acc_type<cache_t, true> new_sum_square_grads = momentum2[idx] + g_avg_square;
+        momentum2[idx] = new_sum_square_grads;
+        multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
+        row_counter[idx] += 1.0;
+        should_ema = (row_counter[idx] > step_start && (int64_t)round(fmod(row_counter[idx], step_ema)) == 0);
+        should_swap = (row_counter[idx] > step_start && (int64_t)round(fmod(row_counter[idx], step_swap)) == 0);
+    }
+    multiplier = SHFL_SYNC(multiplier, 0);
+    should_ema = SHFL_SYNC(should_ema, 0);
+    should_swap = SHFL_SYNC(should_swap, 0);
+    """
+
+    split_weight_update = """
+    weight_new.acc.x = weight_new.acc.x - multiplier * grad.acc.x;
+    weight_new.acc.y = weight_new.acc.y - multiplier * grad.acc.y;
+    weight_new.acc.z = weight_new.acc.z - multiplier * grad.acc.z;
+    weight_new.acc.w = weight_new.acc.w - multiplier * grad.acc.w;
+
+    if (should_ema) { // slow table ema
+        Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
+        m_t.acc.x = (1.0 - momentum) * weight_new.acc.x + momentum * m_t.acc.x;
+        m_t.acc.y = (1.0 - momentum) * weight_new.acc.y + momentum * m_t.acc.y;
+        m_t.acc.z = (1.0 - momentum) * weight_new.acc.z + momentum * m_t.acc.z;
+        m_t.acc.w = (1.0 - momentum) * weight_new.acc.w + momentum * m_t.acc.w;
+        m_t.store(&momentum1[idx * D + d]);
+    }
+
+    if (should_swap) { // slow-to-fast swap
+        Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
+        weight_new.acc.x = m_t.acc.x;
+        weight_new.acc.y = m_t.acc.y;
+        weight_new.acc.z = m_t.acc.z;
+        weight_new.acc.w = m_t.acc.w;
+    }
+    """
+
+    split_weight_update_cpu = ""  # TODO
+
+    return {
+        "optimizer": "ensemble_rowwise_adagrad",
+        "is_prototype_optimizer": True,
+        "args": OptimizerArgsSet.create(
+            [
+                OptimItem(
+                    ArgType.PLACEHOLDER_TENSOR,
+                    "momentum1",
+                    ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
+                ),
+                OptimItem(
+                    ArgType.PLACEHOLDER_TENSOR,
+                    "momentum2",
+                    ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
+                ),
+                OptimItem(ArgType.TENSOR, "row_counter"),
+                OptimItem(ArgType.FLOAT, "learning_rate"),
+                OptimItem(ArgType.FLOAT, "eps"),
+                OptimItem(ArgType.FLOAT, "step_ema"),
+                OptimItem(ArgType.FLOAT, "step_swap"),
+                OptimItem(ArgType.FLOAT, "step_start"),
+                OptimItem(ArgType.FLOAT, "momentum"),
+            ]
+        ),
+        "split_precomputation": split_precomputation,
+        "split_weight_update": split_weight_update,
+        "split_post_update": "",
+        "split_weight_update_cpu": split_weight_update_cpu,
+        "has_cpu_support": False,
+        "has_gpu_support": True,
+        "has_vbe_support": True,
+        "has_global_weight_decay_support": False,
+        "has_ssd_support": False,
+    }
+
+
 def partial_rowwise_adam() -> Dict[str, Any]:
     split_precomputation = """
     at::acc_type<cache_t, true> g_local_sum_square = 0.0;
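To make the kernel logic above easier to follow: per embedding row, this is row-wise Adagrad on the fast weights, plus an EMA of the fast weights into a slow copy (momentum1) every step_ema steps, and a slow-to-fast swap every step_swap steps, both gated on the per-row counter passing step_start. A rough single-row PyTorch sketch of that update; the standalone function and its names are illustrative, not part of this diff:

import torch

def ensemble_rowwise_adagrad_row(
    w: torch.Tensor,            # fast weights for one row, shape (D,)
    grad: torch.Tensor,         # gradient for that row, shape (D,)
    m1: torch.Tensor,           # slow (EMA) copy of the row, shape (D,)
    m2: torch.Tensor,           # 0-dim row-wise sum of squared gradients
    row_counter: torch.Tensor,  # 0-dim per-row step counter
    learning_rate: float,
    eps: float,
    momentum: float,
    step_ema: float,
    step_swap: float,
    step_start: float,
) -> None:
    D = w.numel()
    # Row-wise Adagrad state: accumulate the mean squared gradient of the row.
    m2 += grad.pow(2).sum() / D
    multiplier = learning_rate / (m2.sqrt() + eps)

    row_counter += 1.0
    ctr = float(row_counter)
    should_ema = ctr > step_start and round(ctr % step_ema) == 0
    should_swap = ctr > step_start and round(ctr % step_swap) == 0

    # Fast-table update: a plain row-wise Adagrad step.
    w -= multiplier * grad

    # Every step_ema steps, fold the fast weights into the slow EMA table.
    if should_ema:
        m1.mul_(momentum).add_((1.0 - momentum) * w)

    # Every step_swap steps, copy the slow table back over the fast weights.
    if should_swap:
        w.copy_(m1)

# Tiny demo on one row with D = 4:
w, m1 = torch.zeros(4), torch.zeros(4)
m2, ctr = torch.zeros(()), torch.zeros(())
ensemble_rowwise_adagrad_row(w, torch.ones(4), m1, m2, ctr,
                             learning_rate=0.1, eps=1e-8, momentum=0.9,
                             step_ema=2.0, step_swap=4.0, step_start=0.0)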

fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template
Lines changed: 28 additions & 1 deletion

@@ -67,6 +67,15 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
         {%- if "beta2" in args.split_function_arg_names %}
         beta2: float = 0.999,
         {%- endif %}
+        {%- if "step_ema" in args.split_function_arg_names %}
+        step_ema: float = 10000,
+        {%- endif %}
+        {%- if "step_swap" in args.split_function_arg_names %}
+        step_swap: float = 10000,
+        {%- endif %}
+        {%- if "step_start" in args.split_function_arg_names %}
+        step_start: float = 0,
+        {%- endif %}
         {%- if "weight_decay" in args.split_function_arg_names %}
         weight_decay: float = 0.0,
         {%- endif %}

@@ -95,6 +104,15 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
         {%- if "beta2" in args.split_function_arg_names %}
         beta2=beta2,
         {%- endif %}
+        {%- if "step_ema" in args.split_function_arg_names %}
+        step_ema=step_ema,
+        {%- endif %}
+        {%- if "step_swap" in args.split_function_arg_names %}
+        step_swap=step_swap,
+        {%- endif %}
+        {%- if "step_start" in args.split_function_arg_names %}
+        step_start=step_start,
+        {%- endif %}
         {%- if "weight_decay" in args.split_function_arg_names %}
         weight_decay=weight_decay,
         {%- endif %}

@@ -139,7 +157,7 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
             rowwise = False
         {% endif %}
         {% elif state_tensor == "momentum2" %}
-        {% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb"] %}
+        {% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb", "ensemble_rowwise_adagrad"] %}
             rowwise = True
         {% else %}
             rowwise = False

@@ -189,6 +207,15 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
         {%- if "beta2" in args.split_function_arg_names %}
         self.beta2 = beta2
         {%- endif %}
+        {%- if "step_ema" in args.split_function_arg_names %}
+        self.step_ema = step_ema
+        {%- endif %}
+        {%- if "step_swap" in args.split_function_arg_names %}
+        self.step_swap = step_swap
+        {%- endif %}
+        {%- if "step_start" in args.split_function_arg_names %}
+        self.step_start = step_start
+        {%- endif %}
         {%- if "weight_decay" in args.split_function_arg_names %}
         self.weight_decay = weight_decay
         {%- endif %}
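A quick way to see what the new Jinja guards do: the step_ema/step_swap/step_start parameters appear in the generated optimizer class only when they are in the optimizer's argument list. A runnable toy template, where arg_names stands in for args.split_function_arg_names:

from jinja2 import Template

# Same guard pattern as the template above: the parameter is emitted only
# when the optimizer's arg list contains "step_ema".
tmpl = Template(
    "{%- if 'step_ema' in arg_names %}step_ema: float = 10000,{% endif %}"
)

print(tmpl.render(arg_names=["learning_rate", "eps", "step_ema"]))
# -> step_ema: float = 10000,
print(repr(tmpl.render(arg_names=["learning_rate", "eps"])))
# -> '' (the block is dropped for optimizers without step_ema)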

fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
Lines changed: 1 addition & 0 deletions

@@ -33,6 +33,7 @@ class EmbOptimType(enum.Enum):
     SHAMPOO_V2 = "shampoo_v2"  # not currently supported for sparse embedding tables
     MADGRAD = "madgrad"
     EXACT_ROWWISE_WEIGHTED_ADAGRAD = "exact_row_wise_weighted_adagrad"  # deprecated
+    ENSEMBLE_ROWWISE_ADAGRAD = "ensemble_row_wise_adagrad"
     NONE = "none"

     def __str__(self) -> str:
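With this enum entry in place, the new optimizer can be selected by name. A minimal check, assuming an fbgemm_gpu build that includes this commit:

from fbgemm_gpu.split_embedding_configs import EmbOptimType

# The enum value uses the "row_wise" spelling, matching the line added above.
opt = EmbOptimType.ENSEMBLE_ROWWISE_ADAGRAD
assert opt.value == "ensemble_row_wise_adagrad"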
