Dask LightGBM breaks if num_rows * bagging_fraction > int32_t max #5861

@adfea9c0

Description

LightGBM will silently fail if the total number of data points used in a single iteration exceeds the int32_t maximum of 2^31 - 1. This is not the same as a limit on the total number of rows: the overflow appears to happen after bagging, so the relevant quantity is num_rows * bagging_fraction. The failure mode is that training outputs a model containing a single constant tree.
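
To see the arithmetic, here is a minimal sketch of the suspected overflow (plain NumPy, not LightGBM's actual code): a post-bagging count above the int32_t maximum wraps to a negative value when forced into 32 bits.

import numpy as np

# Hypothetical illustration of the suspected overflow, not LightGBM code:
# a global post-bagging sample count that no longer fits in int32_t.
sampled = int(2_200_000_000 * 1.0)         # num_rows * bagging_fraction

print(sampled > np.iinfo(np.int32).max)    # True: exceeds 2^31 - 1
print(np.int64(sampled).astype(np.int32))  # -2094967296: wraps negative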

Reproducible example

The snippet below contains a function run, which generates a small problem and runs distributed LightGBM on it with simple settings that ensure more than one tree is needed to fit it. So when only a single tree is output, we can take that as an indication that training failed (you can inspect the model yourself to verify that it contains a single constant tree).

import time

import pandas as pd
import lightgbm as lgb

import dask
import dask.array
import dask.dataframe
import dask.distributed

client = dask.distributed.Client(scheduler_file="redacted")

def run(client, num_rows, bagging_fraction):
    client.restart()
    time.sleep(3)
    
    # Construct a 'simple' problem
    X = dask.array.random.random(size=(num_rows, 4))
    y = dask.array.where(X[:,0] > X[:,1], X[:,2], X[:,3])
        
    # Parameters
    params = {
        "verbose": -1,
        "seed": 1,
        "deterministic": True,
        # Turn on bagging
        "bagging_freq": 1,
        "bagging_fraction": bagging_fraction,
        # Some parameters that guarantee we need to fit more than one tree.
        "num_iterations": 3,
        "max_depth": 1,
        "num_leaves": 2,
        "learning_rate": 0.1
    }
    
    model = lgb.DaskLGBMRegressor(
        client=client, tree_learner="data", silent=True, **params
    ).fit(X, y)
    
    return model.booster_.num_trees()

df = pd.DataFrame(
    data=[
        [
            num_rows,
            bagging_fraction,
            run(client, num_rows, bagging_fraction),
            num_rows * bagging_fraction < 2**31
        ]
        for num_rows, bagging_fraction in [
            (2_000_000_000, 1.0),
            (2_200_000_000, 1.0),
            (2_200_000_000, 0.8),
            (4_000_000_000, 0.6),
            (4_000_000_000, 0.5),
        ]
    ],
    columns=["num_rows", "bagging_fraction", "num_trees", "rows * frac < 2^31"]
)
df

Produces this table:

     num_rows  bagging_fraction  num_trees   rows * frac < 2^31
2_000_000_000               1.0          3                 True
2_200_000_000               1.0          1                False
2_200_000_000               0.8          3                 True
4_000_000_000               0.6          1                False
4_000_000_000               0.5          3                 True
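
(For reference, 2^31 = 2,147,483,648: 2,200,000,000 × 1.0 and 4,000,000,000 × 0.6 = 2,400,000,000 both cross that threshold, while 4,000,000,000 × 0.5 = 2,000,000,000 stays under it.)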

Environment info

I'm using LightGBM 3.3.2 and running a Dask cluster with 10 workers, although I suspect the worker count is irrelevant and the bug is inside LightGBM itself.

Additional Comments

It seems there is some 32-bit integer computation somewhere in LightGBM that should really be 64-bit. I know data_size_t is an int32_t, but I don't think changing that is the only way to fix this. Per-node indexing can still be 32-bit, as evidenced by e.g. the 4-billion-row job with 0.5 bagging succeeding; it's only the global data count that is an issue.
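
As a rough sketch of what I suspect (written in Python for brevity; this is not the actual LightGBM code): each worker's local bagged count fits comfortably in 32 bits, but accumulating them into a 32-bit global count wraps.

import numpy as np

# Hypothetical sketch, not LightGBM's actual code: per-worker bagged
# counts each fit in int32_t, but a 32-bit global accumulator wraps.
per_worker = np.full(10, 400_000_000, dtype=np.int64)  # 4B rows over 10 workers

print(per_worker.max() < 2**31)        # True: each local count fits in int32_t
print(per_worker.sum())                # 4000000000: correct with a 64-bit sum
print(per_worker.sum(dtype=np.int32))  # -294967296: 32-bit accumulator wrapped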

I'm happy to help fix this, but I'm not familiar enough with LightGBM internals to know exactly which 32-bit fields might overflow here. If you give me some pointers, I could help.
