@@ -355,13 +355,14 @@ def test_eye(n_rows, n_cols, kw):
355355 _n_cols = n_rows if n_cols is None else n_cols
356356 ph .assert_shape ("eye" , out_shape = out .shape , expected = (n_rows , _n_cols ), kw = dict (n_rows = n_rows , n_cols = n_cols ))
357357 f_func = f"[eye({ n_rows = } , { n_cols = } )]"
358- for i in range (n_rows ):
359- for j in range (_n_cols ):
358+ k = kw .get ("k" , 0 )
359+ expected = xp .asarray ([[1 if j - i == k else 0
360+ for j in range (_n_cols )] for i in range (n_rows )]).reshape (n_rows , _n_cols )
361+ assert out .shape == expected .shape
362+ if xp .any (out != expected ):
363+ for i , j in zip (* xp .where (out != expected )):
360364 f_indexed_out = f"out[{ i } , { j } ]={ out [i , j ]} "
361- if j - i == kw .get ("k" , 0 ):
362- assert out [i , j ] == 1 , f"{ f_indexed_out } , should be 1 { f_func } "
363- else :
364- assert out [i , j ] == 0 , f"{ f_indexed_out } , should be 0 { f_func } "
365+ assert out [i , j ] == expected [i , j ], f"{ f_indexed_out } , should be { expected [i , j ]} { f_func } "
365366
366367
367368default_unsafe_dtypes = [xp .uint64 ]
0 commit comments