In the "An Introduction to JAX" page, the "Writing Vectorized Code" section uses the default `'xy'` (Cartesian) indexing for `meshgrid`:
```python
x_mesh, y_mesh = jnp.meshgrid(x, y)
```
This gives the `x_mesh`, `y_mesh` and `z_mesh` arrays the shape `(n_y, n_x)`, while the `z_vmap` returned from `f_vec` has shape `(n_x, n_y)`.
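A minimal sketch of the mismatch. It assumes a hypothetical non-symmetric `f` and an `f_vec` built from two nested `jax.vmap` calls (the tutorial's own `f` and `f_vec` may be defined differently), with small sizes so it runs quickly:

```python
import jax
import jax.numpy as jnp

# Hypothetical non-symmetric f; not the tutorial's function.
def f(x, y):
    return x + 2.0 * y

# Assumed construction of f_vec: outer vmap over x, inner vmap over y,
# so f_vec(x, y)[i, j] == f(x[i], y[j]) and the shape is (n_x, n_y).
f_vec = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))

x = jnp.linspace(-2, 2, 3)  # n_x = 3
y = jnp.linspace(-4, 4, 5)  # n_y = 5

x_mesh, y_mesh = jnp.meshgrid(x, y)  # default indexing='xy'
z_mesh = f(x_mesh, y_mesh)
z_vmap = f_vec(x, y)

print(x_mesh.shape, z_mesh.shape)  # (5, 3) (5, 3), i.e. (n_y, n_x)
print(z_vmap.shape)                # (3, 5), i.e. (n_x, n_y)
# Same values, just transposed relative to each other.
print(jnp.allclose(z_mesh.T, z_vmap))  # True
```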
Testing with:

```python
n = 6000
x = jnp.linspace(-2, 2, n)
y = jnp.linspace(-4, 4, 2*n)
```
[Screenshot: results with the default 'xy' Cartesian indexing]

Changing the call to `x_mesh, y_mesh = jnp.meshgrid(x, y, indexing='ij')` makes both versions agree.
[Screenshot: results with 'ij' matrix indexing]

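The same kind of hypothetical setup with `indexing='ij'` shows the shapes and values lining up. This is a sketch: `f` and the nested-`vmap` `f_vec` are assumptions, not the tutorial's code.

```python
import jax
import jax.numpy as jnp

# Hypothetical non-symmetric f; not the tutorial's function.
def f(x, y):
    return x + 2.0 * y

# Assumed f_vec: f_vec(x, y)[i, j] == f(x[i], y[j]), shape (n_x, n_y).
f_vec = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))

x = jnp.linspace(-2, 2, 3)  # n_x = 3
y = jnp.linspace(-4, 4, 5)  # n_y = 5

# 'ij' indexing: x_mesh[i, j] == x[i], y_mesh[i, j] == y[j], shape (n_x, n_y).
x_mesh, y_mesh = jnp.meshgrid(x, y, indexing='ij')
z_mesh = f(x_mesh, y_mesh)
z_vmap = f_vec(x, y)

print(z_mesh.shape, z_vmap.shape)    # (3, 5) (3, 5)
print(jnp.allclose(z_mesh, z_vmap))  # True
```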
In the tutorial's own example there is no visible difference, because `x` and `y` have the same shape and the function `f` is symmetric, but in general this is a subtle bug that can easily catch people out.
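To illustrate how symmetry can hide the problem, here is a sketch with a hypothetical symmetric `f_sym`, evaluated first with the same array on both axes and then on a non-square grid. As above, the nested-`vmap` `f_vec` is an assumed construction.

```python
import jax
import jax.numpy as jnp

# Hypothetical symmetric function: f_sym(a, b) == f_sym(b, a).
def f_sym(x, y):
    return x**2 + y**2

# Assumed f_vec: f_vec(x, y)[i, j] == f_sym(x[i], y[j]), shape (n_x, n_y).
f_vec = jax.vmap(jax.vmap(f_sym, in_axes=(None, 0)), in_axes=(0, None))

# Same array on both axes: the default 'xy' result is the transpose of the
# vmap result, but symmetry makes the vmap result a symmetric matrix,
# so the transpose goes unnoticed.
x = jnp.linspace(-2, 2, 4)
x_mesh, y_mesh = jnp.meshgrid(x, x)
z_same_mesh = f_sym(x_mesh, y_mesh)
z_same_vmap = f_vec(x, x)
print(jnp.allclose(z_same_mesh, z_same_vmap))  # True: the bug is hidden

# A non-square grid exposes the mismatch through the shapes alone.
y = jnp.linspace(-4, 4, 6)
x_mesh, y_mesh = jnp.meshgrid(x, y)
z_rect_mesh = f_sym(x_mesh, y_mesh)
z_rect_vmap = f_vec(x, y)
print(z_rect_mesh.shape, z_rect_vmap.shape)  # (6, 4) (4, 6)
```

Testing against a non-square grid, or a deliberately non-symmetric function, is a cheap way to catch this kind of indexing mix-up.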