Nans in JAX multinomial dispatch #1327

Closed
@ricardoV94
Member

Description

PyMC tests are failing: pymc-devs/pymc#7740

Reproducible example:

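```python
import pytensor.tensor as pt

p = pt.eye(3)  # each row puts all probability mass on a single category
rv = pt.random.multinomial(n=5, p=p)
rv.eval(mode="JAX")
# Array([[ 5., nan, nan],
#        [ 0.,  5.,  0.],
#        [ 0.,  0.,  5.]], dtype=float64)
# expected: 5 * eye(3), i.e. the first row should be [5., 0., 0.]
```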

I guess it could be a problem with the binomial when p=0?
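Here is a possible mechanism behind the p=0 guess, sketched under an assumption the issue does not confirm: the JAX dispatch samples the multinomial as a chain of binomials. Every category except the last is drawn as Binomial(n_remaining, p_k / (1 - sum_{j<k} p_j)), and the last category takes the leftover trials. The helper `conditional_probs` is hypothetical. It uses NumPy to compute only those success probabilities for `p = eye(3)`.

```python
import numpy as np


def conditional_probs(p):
    """Success probability of each binomial in a sequential multinomial sampler.

    Hypothetical helper (assumption, not PyTensor code): category k, for all but
    the last category, is drawn as Binomial(n_remaining, p[k] / (1 - sum(p[:k]))).
    """
    p = np.asarray(p, dtype=float)
    # Mass already assigned to earlier categories: sum(p[:k]) for each k.
    used_mass = np.cumsum(p, axis=-1) - p
    remaining_mass = 1.0 - used_mass
    # Silence the 0/0 warning so the NaN shows up in the result instead.
    with np.errstate(divide="ignore", invalid="ignore"):
        cond = p / remaining_mass
    # The last category takes the leftover trials, so it needs no binomial draw.
    return cond[..., :-1]


print(conditional_probs(np.eye(3)))
```

Under that assumption, the rows come out as [1, nan], [0, 1] and [0, 0]. The first row gets 0/0 = NaN for its second category, and a NaN count would then also poison the leftover count n - sum(...) for the last category. That pattern is consistent with the output above, where only the first row has NaNs, in both trailing columns.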

CC @educhesne


educhesne (Contributor) commented on Mar 28, 2025

I am very sorry about this. I think there is a division by zero happening here.
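If the division by zero is in a conditional probability of the form p_k / (1 - sum_{j<k} p_j), then mapping "no probability mass left" to a success probability of 0 should remove the NaNs. The sketch below shows this guard with NumPy's `Generator.binomial`. `multinomial_via_binomials` is a hypothetical name, and the code illustrates the conditional-binomial technique, not PyTensor's JAX dispatch code. In JAX the same guard would presumably use `jnp.where`. The denominator has to be made safe in both branches, because `where` evaluates both.

```python
import numpy as np


def multinomial_via_binomials(rng, n, p):
    """Draw one multinomial sample per row of p via a chain of conditional binomials.

    Assumption: illustrates the general technique, not PyTensor's JAX dispatch.
    """
    p = np.asarray(p, dtype=float)
    n_left = np.full(p.shape[:-1], n, dtype=np.int64)  # trials not yet assigned
    mass_left = np.ones(p.shape[:-1])  # 1 - sum(p[:k]) for the current k
    counts = np.zeros(p.shape, dtype=np.int64)
    for k in range(p.shape[-1] - 1):
        # Guard: once all probability mass is used up (mass_left <= 0), later
        # categories get probability 0 instead of 0 / 0 = NaN.
        has_mass = mass_left > 0
        safe_mass = np.where(has_mass, mass_left, 1.0)
        p_cond = np.where(has_mass, p[..., k] / safe_mass, 0.0)
        p_cond = np.clip(p_cond, 0.0, 1.0)  # protect against round-off above 1
        counts[..., k] = rng.binomial(n_left, p_cond)
        n_left = n_left - counts[..., k]
        mass_left = mass_left - p[..., k]
    # The last category takes whatever trials remain.
    counts[..., -1] = n_left
    return counts


rng = np.random.default_rng(0)
print(multinomial_via_binomials(rng, 5, np.eye(3)))
```

With `p = np.eye(3)` and `n = 5`, this returns `5 * eye(3)` with no NaNs. In the first row, all of the mass is used up after the first category, so the remaining categories are drawn with probability 0 from 0 trials.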

educhesne (Contributor) commented on Mar 28, 2025

Do you know why it was not caught by the tests in the PR?

ricardoV94 (Member, Author) commented on Mar 28, 2025

I guess we were not testing edge cases.


        Nans in JAX multinomial dispatch · Issue #1327 · pymc-devs/pytensor