Skip to content

Speedup Numba Scans with vector inputs #235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 10, 2023

Conversation

ricardoV94
Copy link
Member

Related to #233

In talks with @jessegrabowski and @aseyboldt we found out that there is a large overhead of asarray when indexing vector inputs to create taps, as these become non-array numerical variables that must be wrapped again into a scalar tensor (at least until we allow proper scalar inputs in Scan).

The new benchmark test runs 2.5x faster on my machine compared to main.

The scan function used to look like this

def scan(n_steps, outer_in_1, outer_in_2, outer_in_3, outer_in_4):

    outer_in_3_len = outer_in_3.shape[0]
    outer_in_3_mitsot_storage = outer_in_3
    outer_in_4_len = outer_in_4.shape[0]
    outer_in_4_sitsot_storage = outer_in_4

    i = 0
    cond = np.array(False)
    while i < n_steps and not cond.item():
        (inner_out_0, inner_out_1) = scan_inner_func(np.asarray(outer_in_1[i]), np.asarray(outer_in_2[i]), np.asarray(outer_in_3_mitsot_storage[(i) % outer_in_3_len]), np.asarray(outer_in_3_mitsot_storage[(i + 1) % outer_in_3_len]), np.asarray(outer_in_4_sitsot_storage[(i) % outer_in_4_len]))

        outer_in_3_mitsot_storage[(i + 2) % outer_in_3_len] = inner_out_0
        outer_in_4_sitsot_storage[(i + 1) % outer_in_4_len] = inner_out_1
        i += 1

    if (i + 2) > outer_in_3_len:
        outer_in_3_mitsot_storage_shift = (i + 2) % (outer_in_3_len)
        outer_in_3_mitsot_storage_left = outer_in_3_mitsot_storage[:outer_in_3_mitsot_storage_shift]
        outer_in_3_mitsot_storage_right = outer_in_3_mitsot_storage[outer_in_3_mitsot_storage_shift:]
        outer_in_3_mitsot_storage = np.concatenate((outer_in_3_mitsot_storage_right, outer_in_3_mitsot_storage_left))
    if (i + 1) > outer_in_4_len:
        outer_in_4_sitsot_storage_shift = (i + 1) % (outer_in_4_len)
        outer_in_4_sitsot_storage_left = outer_in_4_sitsot_storage[:outer_in_4_sitsot_storage_shift]
        outer_in_4_sitsot_storage_right = outer_in_4_sitsot_storage[outer_in_4_sitsot_storage_shift:]
        outer_in_4_sitsot_storage = np.concatenate((outer_in_4_sitsot_storage_right, outer_in_4_sitsot_storage_left))

    return outer_in_3_mitsot_storage, outer_in_4_sitsot_storage

And now looks like this

def scan(n_steps, outer_in_1, outer_in_2, outer_in_3, outer_in_4):

    outer_in_3_len = outer_in_3.shape[0]
    outer_in_3_mitsot_storage = outer_in_3
    outer_in_4_len = outer_in_4.shape[0]
    outer_in_4_sitsot_storage = outer_in_4

    outer_in_1_temp_scalar_0 = np.empty((), dtype=np.float64)
    outer_in_2_temp_scalar_0 = np.empty((), dtype=np.float64)
    outer_in_3_mitsot_storage_temp_scalar_0 = np.empty((), dtype=np.float64)
    outer_in_3_mitsot_storage_temp_scalar_1 = np.empty((), dtype=np.float64)
    outer_in_4_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.float64)

    i = 0
    cond = np.array(False)
    while i < n_steps and not cond.item():
        outer_in_1_temp_scalar_0[()] = outer_in_1[i]
        outer_in_2_temp_scalar_0[()] = outer_in_2[i]
        outer_in_3_mitsot_storage_temp_scalar_0[()] = outer_in_3_mitsot_storage[(i) % outer_in_3_len]
        outer_in_3_mitsot_storage_temp_scalar_1[()] = outer_in_3_mitsot_storage[(i + 1) % outer_in_3_len]
        outer_in_4_sitsot_storage_temp_scalar_0[()] = outer_in_4_sitsot_storage[(i) % outer_in_4_len]

        (inner_out_0, inner_out_1) = scan_inner_func(outer_in_1_temp_scalar_0, outer_in_2_temp_scalar_0, outer_in_3_mitsot_storage_temp_scalar_0, outer_in_3_mitsot_storage_temp_scalar_1, outer_in_4_sitsot_storage_temp_scalar_0)

        outer_in_3_mitsot_storage[(i + 2) % outer_in_3_len] = inner_out_0
        outer_in_4_sitsot_storage[(i + 1) % outer_in_4_len] = inner_out_1
        i += 1

    if (i + 2) > outer_in_3_len:
        outer_in_3_mitsot_storage_shift = (i + 2) % (outer_in_3_len)
        outer_in_3_mitsot_storage_left = outer_in_3_mitsot_storage[:outer_in_3_mitsot_storage_shift]
        outer_in_3_mitsot_storage_right = outer_in_3_mitsot_storage[outer_in_3_mitsot_storage_shift:]
        outer_in_3_mitsot_storage = np.concatenate((outer_in_3_mitsot_storage_right, outer_in_3_mitsot_storage_left))
    if (i + 1) > outer_in_4_len:
        outer_in_4_sitsot_storage_shift = (i + 1) % (outer_in_4_len)
        outer_in_4_sitsot_storage_left = outer_in_4_sitsot_storage[:outer_in_4_sitsot_storage_shift]
        outer_in_4_sitsot_storage_right = outer_in_4_sitsot_storage[outer_in_4_sitsot_storage_shift:]
        outer_in_4_sitsot_storage = np.concatenate((outer_in_4_sitsot_storage_right, outer_in_4_sitsot_storage_left))

    return outer_in_3_mitsot_storage, outer_in_4_sitsot_storage

@ricardoV94 ricardoV94 requested a review from aseyboldt March 6, 2023 16:56
@ricardoV94 ricardoV94 changed the title Numba scan tweak Speedubp Numba Scans with vector inputs Mar 6, 2023
@ricardoV94 ricardoV94 changed the title Speedubp Numba Scans with vector inputs Speedup Numba Scans with vector inputs Mar 6, 2023
@ricardoV94 ricardoV94 requested a review from Armavica March 29, 2023 07:13
Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, provided I understand what's going on correctly.

temp_storage = f"{storage_name}_temp_scalar_{tap_offset}"
storage_dtype = outer_in_var.type.numpy_dtype.name
temp_scalar_storage_alloc_stmts.append(
f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dumb question so I'm clear. is_vector refers to the values passed by outputs_info. This will be a vector if there are multiple taps requested. You're making this storage scalar as a place to break up this vector and store the individual values as they are fed into the inner function. Doing this prevents the need to call np.asarray on the output, because it's already an array, because these storage values are 0d arrays. Is that right?

Copy link
Member Author

@ricardoV94 ricardoV94 May 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The is_vector refers to any vector input (sequence or recurrent), which must be broken into scalar arrays in each iteration. This happens regardless of the number of taps (or you could say there's always at least one tap of -1).

Indexing vector inputs to create taps during scan, yields numeric variables which must be wrapped again into scalar arrays before passing into the inernal function.

This commit pre-allocates such arrays and reuses them during looping.
@ricardoV94 ricardoV94 requested a review from twiecki May 9, 2023 12:33
@twiecki twiecki merged commit cb417fe into pymc-devs:main May 10, 2023
@ricardoV94 ricardoV94 deleted the numba_scan_tweak branch June 21, 2023 08:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants