Description
LightGBM will silently fail if the total number of data points used in a single iteration exceeds the int32_t upper limit of 2^31 - 1. Note this is not the same as the total number of rows exceeding that limit: the bug seems to happen after bagging, so the relevant quantity is num_rows * bagging_fraction. The failure mode is that training outputs a model containing a single constant tree.
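For intuition, here is a minimal sketch of the wraparound itself, using ctypes to emulate a C int32_t. The as_int32 helper is mine, and exactly which counter inside LightGBM overflows is an assumption on my part:

import ctypes

def as_int32(n):
    # Emulate storing a row count in a C int32_t. ctypes does no overflow
    # checking, so values past 2**31 - 1 silently wrap to negative numbers.
    return ctypes.c_int32(n).value

print(as_int32(2_000_000_000))  # 2000000000 -- still fits
print(as_int32(2_200_000_000))  # -2094967296 -- silent wraparound

A negative "number of bagged rows" somewhere in the pipeline would plausibly explain the degenerate single-tree output rather than a hard crash.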
Reproducible example
The snippet below contains a function run, which generates a small problem and runs distributed LightGBM on it with settings that guarantee more than one tree is needed to fit it. So when only a single tree is output, we can take that as an indication that training failed (you can inspect the model yourself to verify that it contains a single constant tree).
import time

import pandas as pd

import lightgbm as lgb
import dask
import dask.array
import dask.dataframe
import dask.distributed

client = dask.distributed.Client(scheduler_file="redacted")

def run(client, num_rows, bagging_fraction):
    client.restart()
    time.sleep(3)

    # Construct a 'simple' problem
    X = dask.array.random.random(size=(num_rows, 4))
    y = dask.array.where(X[:, 0] > X[:, 1], X[:, 2], X[:, 3])

    # Parameters
    params = {
        "verbose": -1,
        "seed": 1,
        "deterministic": True,
        # Turn on bagging
        "bagging_freq": 1,
        "bagging_fraction": bagging_fraction,
        # Some parameters that guarantee we need to fit more than one tree.
        "num_iterations": 3,
        "max_depth": 1,
        "num_leaves": 2,
        "learning_rate": 0.1,
    }

    model = lgb.DaskLGBMRegressor(
        client=client, tree_learner="data", silent=True, **params
    ).fit(X, y)

    return model.booster_.num_trees()

df = pd.DataFrame(
    data=[
        [
            num_rows,
            bagging_fraction,
            run(client, num_rows, bagging_fraction),
            num_rows * bagging_fraction < 2**31,
        ]
        for num_rows, bagging_fraction in [
            (2_000_000_000, 1.0),
            (2_200_000_000, 1.0),
            (2_200_000_000, 0.8),
            (4_000_000_000, 0.6),
            (4_000_000_000, 0.5),
        ]
    ],
    columns=["num_rows", "bagging_fraction", "num_trees", "rows * frac < 2^31"],
)
df
Produces this table:
num_rows       bagging_fraction  num_trees  rows * frac < 2^31
2_000_000_000  1.0               3          True
2_200_000_000  1.0               1          False
2_200_000_000  0.8               3          True
4_000_000_000  0.6               1          False
4_000_000_000  0.5               3          True
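The failing cases line up exactly with the last column; you can predict them without training anything. A sketch of just the predicate, no LightGBM involved:

for num_rows, frac in [
    (2_000_000_000, 1.0),
    (2_200_000_000, 1.0),
    (2_200_000_000, 0.8),
    (4_000_000_000, 0.6),
    (4_000_000_000, 0.5),
]:
    bagged = int(num_rows * frac)  # rows actually used per iteration
    print(num_rows, frac, "fits in int32:", bagged < 2**31)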
Environment info
I'm using LightGBM 3.3.2. I'm running a Dask cluster with 10 workers, although I suspect that's irrelevant and the bug is inside LightGBM itself.
Additional Comments
It seems there is some 32-bit integer computation somewhere in LightGBM that should really be 64-bit. I know data_size_t is an int32_t, but I don't think changing that is the only way to fix this. Per-node indexing can stay 32-bit, as evidenced by e.g. the 4-billion-row job with 0.5 bagging succeeding; it's only the global data count that is an issue.
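A back-of-the-envelope check of that claim, assuming the 4-billion-row job is split evenly across my 10 workers (the even split is an assumption):

num_rows = 4_000_000_000
num_workers = 10
for frac in (0.5, 0.6):
    global_bagged = int(num_rows * frac)
    per_worker_bagged = int(num_rows // num_workers * frac)
    print(
        f"frac={frac}: "
        f"per-worker count fits in int32: {per_worker_bagged < 2**31}, "
        f"global count fits: {global_bagged < 2**31}"
    )

At both fractions every per-worker count fits comfortably in an int32; only the global count at 0.6 does not, matching which jobs failed.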
I'm happy to help fix this, but I'm not familiar enough with LightGBM's internals to know exactly which 32-bit fields might overflow here. If you can give me some pointers, I'd be glad to help.