Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Research language for array processing in the Haskell/ML family

Optimizing control variates for black-box gradient estimation

Code for the paper "Learning Differential Equations that are Easy to Solve"

Recognizing and exploiting conjugacy without a domain-specific language

Prototypes of differentiable differential equation solvers in JAX.

my fish configuration

my .vim

omnistaging on by default

push time in 14 hours

commit sha e7b74caa79b9592ab3b97d7c899f9c7e030ae56a

omnistaging on by default

push time in 14 hours

One more thought: to get the second derivatives, you just need to write grad(grad(...)) inside the double-vmap. Getting rid of batch dimensions makes things so much easier!
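
As a rough illustration of that suggestion, here is a minimal sketch that assumes a hypothetical RBF kernel written for a single pair of scalar inputs (the name rbf and the test points are made up for this example, not taken from the issue). The nested grad is applied to the unbatched function, and the two vmaps add the batch axes afterwards.

```python
import jax
import jax.numpy as jnp

def rbf(x1, x2, length_scale):
    """Hypothetical RBF kernel on one pair of scalar inputs (no batch axes)."""
    return jnp.exp(-0.5 * (x1 - x2) ** 2 / length_scale**2)

# Mixed second derivative d^2 k / (dx1 dx2) for a single pair of points.
d2k_dx1dx2 = jax.grad(jax.grad(rbf, argnums=1), argnums=0)

# Inner vmap batches over x2, outer vmap batches over x1.
d2k_matrix = jax.vmap(jax.vmap(d2k_dx1dx2, (None, 0, None)), (0, None, None))

xs = jnp.linspace(-1.0, 1.0, 5)
H = d2k_matrix(xs, xs, 1.0)  # H[i, j] is the derivative at (xs[i], xs[j])
print(H.shape)
```

For vector-valued inputs the inner grad returns a vector, so the outer transformation would need to be something like jacfwd rather than grad.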

tawe141

comment created time in 15 hours

omnistaging on by default

push time in 15 hours

There may be a nicer way to write the nested vmap using the vectorize wrapper.
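
A sketch of what that might look like, assuming "the vectorize wrapper" refers to jnp.vectorize and using a hypothetical single-pair kernel rbf (not the code from the issue). The signature string declares the core dimensions, and ordinary broadcasting over the leading axes replaces the nested vmap.

```python
import jax
import jax.numpy as jnp

def rbf(x1, x2, length_scale):
    """Hypothetical RBF kernel on one pair of vectors, returning a scalar."""
    diff = x1 - x2
    return jnp.exp(-0.5 * jnp.vdot(diff, diff) / length_scale**2)

dk_dx2 = jax.grad(rbf, argnums=1)

# Nested-vmap version: the outer vmap batches x1, the inner vmap batches x2.
nested = jax.vmap(jax.vmap(dk_dx2, (None, 0, None)), (0, None, None))

# jnp.vectorize version: core dimensions are named in the signature.
vectorized = jnp.vectorize(dk_dx2, signature="(d),(d),()->(d)")

x1 = jnp.linspace(0.0, 1.0, 8).reshape(4, 2)
x2 = jnp.linspace(-1.0, 1.0, 6).reshape(3, 2)

a = nested(x1, x2, 1.0)                              # shape (4, 3, 2)
b = vectorized(x1[:, None, :], x2[None, :, :], 1.0)  # shape (4, 3, 2)
print(a.shape, b.shape)
```

Whether this reads as nicer is a matter of taste: the vmap version spells out which argument each batch axis comes from, while the vectorize version leans on NumPy-style broadcasting.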

tawe141

comment created time in 15 hours

Thanks for the question! I agree with what @shoyer said.

To say a bit more, if f: R^n -> R^m then we expect its Jacobian to be an m x n matrix. The function being differentiated here, namely RBF.forward with respect to its first argument, takes an input of shape (50, 1) and produces an output of shape (50, 50), which is why we expect to see a Jacobian of shape (50, 50, 50, 1). (If we want to think in flattened terms, we'd say that the input has dimension n=50 and the output has dimension m=2500.)

The trouble here is we just want to see the data axes of size 50 as batch axes. That is, we really want to think of the kernel as taking two vectors and outputting a scalar. We only carry along the batch axis in the implementation for vectorization efficiency.

But with vmap, we don't need to batch by hand, and as a result it can be easier to express the functions and derivatives we want. Concretely, take a look at the forward_ and dkdx2_ methods below (note the underscores on the end), and the assertion:

class RBF(Kernel):
    # ...

    def forward_(self, x1: np.ndarray, x2: np.ndarray, thetas: np.ndarray):
        assert thetas.shape == (1,)
        length_scale = thetas[0]
        dist_sq = np.vdot(x1, x1) + np.vdot(x2, x2) - 2 * np.vdot(x1, x2)
        return np.exp(-0.5 * dist_sq / length_scale**2)

class RBFGrad(RBF):
    def __init__(self, length_scale=1.0):
        # ...
        self.dk2dx1dx2 = jit(jacfwd(jacrev(super(RBFGrad, self).forward, argnums=0), argnums=1))

        self.dkdx2_ = jit(vmap(vmap(grad(super().forward_, argnums=1), (0, None, None)), (None, 0, None)))

    def forward(self, x1: np.ndarray, x2: np.ndarray, thetas: np.ndarray):
        K = super().forward(x1, x2, thetas)
        dx2 = self.dkdx2(x1, x2, thetas).sum(-2)
        dx2_ = self.dkdx2_(x1, x2, thetas)
        assert np.allclose(dx2, dx2_)
        # ...


I rewrote the forward_ method so that it more obviously applies to single vectors at a time, and also to use the polarization identity ||u - v||^2 = ||u||^2 + ||v||^2 - 2 u \cdot v which gives us more matrix multiplies and I think is often faster (though one should benchmark against computing ||u - v||^2 more directly as before). I also added the dkdx2_ method, which uses vmap to do all the batching, and compared against the old calculation (it may be faster to use jvp than grad but this was convenient). That second bit, with the vmaps, is the main thing I wanted to illustrate, as it was basically @shoyer's advice.

One other piece of advice would be to put jit on RBFGrad.forward, because the more code XLA can see the more it can optimize things.

WDYT?

tawe141

comment created time in 15 hours

Hi JAX team!

Currently, it seems that broadcasting of minval and maxval in jax.random.uniform is not supported. https://github.com/google/jax/blob/a169743f64e9ddf826cdad5c225e3e18a980a1db/jax/random.py#L343

In numpy, this is supported

np.random.uniform(low=jnp.zeros(2), high=jnp.ones(2), size=(10,2))


In JAX, this results in the error "TypeError: max got arrays of different rank: (2,), (10, 2)."

jax.random.uniform(
    key=jax.random.PRNGKey(42),
    shape=(10,2),
    minval=jnp.zeros(2),
    maxval=jnp.ones(2))


Manually broadcasting the bounds fixes this issue

jax.random.uniform(
    key=jax.random.PRNGKey(42),
    shape=(1,2),


Is JAX's behaviour intended? If not, I am happy to send a PR fixing this.
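
One workaround that does not depend on how jax.random.uniform treats its bounds is to sample from the unit interval and rescale. The sketch below uses a hypothetical helper name, uniform_broadcast, which is not part of JAX; it relies only on ordinary array broadcasting.

```python
import jax
import jax.numpy as jnp

def uniform_broadcast(key, shape, minval, maxval):
    """Hypothetical helper: sample U[0, 1) and rescale with broadcasting."""
    minval = jnp.asarray(minval)
    maxval = jnp.asarray(maxval)
    u = jax.random.uniform(key, shape)    # scalar default bounds 0 and 1
    return minval + (maxval - minval) * u  # bounds broadcast against shape

samples = uniform_broadcast(
    jax.random.PRNGKey(42), (10, 2),
    jnp.array([0.0, 5.0]), jnp.array([1.0, 6.0]))
print(samples.shape)
```

Each column then has its own range, as in the NumPy call above.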

closed time in 16 hours

ethanluoyc

commit sha 85eea219e3e5041f3f9d5e7884db62a27492d58c

omnistaging on by default

push time in 16 hours

+767 -755

0 comment

14 changed files

pr created time in 16 hours

created branch time in 16 hours

Let me know if I can help! I looked at some of this bookkeeping recently.

shoyer

comment created time in 17 hours

By the way, XLA:CPU (and hence JAX on CPU) has some known performance issues with float64, e.g. I believe the 64bit GEMM kernel being called is a slow Eigen one, while the 32bit one is from MKL-DNN. I mention it just because I don't want it to be a confounder, though I haven't looked at the details here enough to know if it could be relevant.

I'd love if an outcome of this investigation is that we should add some alternative VJPs for odeint, because that sounds really fun!

ibulu

comment created time in 17 hours

Hey @gsp-27 , thanks for the contribution!

It turns out that trim_zeros isn't a good candidate for JIT compilation with XLA, because the output shape depends on the values of the input. Because XLA programs have static shapes in their type system, that means we'd need to create and compile a separate program for most input values. That's why you're seeing issues with _CompileAndCheck.

We could still include an implementation of trim_zeros in jax.numpy and just allow it to error when people try to include it in jit functions, but that means the upside here is limited so it might also make sense just to leave it out. If we did want to include it, we'd likely need another implementation technique rather than the Python loop implementation here, perhaps based on lax.while_loop (though I'm not sure about that).

Overall I'd recommend not adding trim_zeros for the above reason.
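
To make the static-shape constraint concrete, here is one possible split, offered only as a sketch and not as a proposed jax.numpy implementation. The helper name trim_bounds is made up for this example. The jit-compiled part returns two scalars, whose shapes never depend on the input values, and the value-dependent slicing happens outside jit in ordinary Python.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def trim_bounds(x):
    """Hypothetical helper: start/stop indices of the trimmed region.

    Both outputs are scalars, so the compiled program has static shapes.
    """
    nonzero = (x != 0).astype(jnp.int32)
    first = jnp.argmax(nonzero)                     # index of first nonzero
    last = x.shape[0] - jnp.argmax(nonzero[::-1])   # one past the last nonzero
    has_nonzero = jnp.any(nonzero)
    return jnp.where(has_nonzero, first, 0), jnp.where(has_nonzero, last, 0)

x = jnp.array([0, 0, 1, 2, 0, 3, 0])
first, last = trim_bounds(x)

# The slice has a data-dependent length, so it is taken outside of jit.
trimmed = np.asarray(x)[int(first):int(last)]
print(trimmed)
```

The final slice is exactly the step that cannot live inside a jit function, which is the limitation described above.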

What do you think?

gsp-27

comment created time in a day

I tried to run the example code in README to use jax on TPU from a GCE VM, but I got RuntimeError.

2020-08-11 08:26:26.301652: E external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc:1068] Failed to open the gRPC driver: 12: :
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/random.py", line 85, in PRNGKey
k1 = convert(np.bitwise_and(np.right_shift(seed, 32), 0xFFFFFFFF))
File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/random.py", line 81, in <lambda>
convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lax/lax.py", line 397, in convert_element_type
operand, new_dtype=new_dtype, old_dtype=old_dtype)
File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/core.py", line 276, in bind
return self.impl(*args, **kwargs)
File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 224, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 236, in xla_primitive_callable
backend = xb.get_device_backend(device)
File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lib/xla_bridge.py", line 176, in get_device_backend
return get_backend(platform)
File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lib/xla_bridge.py", line 170, in get_backend
return backend(platform)
File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lib/xla_bridge.py", line 148, in _get_tpu_driver_backend
_tpu_backend = tpu_client.TpuBackend.create(worker=backend_target)
File "/home/sj2660089/.local/lib/python3.7/site-packages/jaxlib/tpu_client.py", line 59, in create
return _tpu_client.TpuClient.Get(worker)
RuntimeError: Unimplemented: Failed to connect to remote server at address: grpc://10.240.1.2:8470. Error from gRPC: . Details:


Any idea why this happens?

closed time in a day

JaySunnn

@jekbradbury @skye you are the greatest!

JaySunnn

comment created time in a day

delete branch : fewer-ode-gpu-tests

delete time in a day

commit sha c564aca77710df0599715d4231b7d5b7dd46984a

skip more ode tests on gpu, b/c slow to compile (#4028)

push time in a day

+3 -3

0 comment

1 changed file

mattjj

pr closed time in a day

+3 -3

0 comment

1 changed file

pr created time in a day

created branch time in a day

@skye can you take a look at this one?

JaySunnn

comment created time in 2 days

started locuslab/monotone_op_net

started time in 2 days

commit sha fe9f264b55f8b99c57f803db9eb7a2c8df897e9b

cumulative jet rules (#4000)

push time in 2 days

cumulative jet rules cla: yes

Below is my attempt at adding jet rules for cumulative operations, following jvp rules defined in lax.py

+26 -1

2 changed files

thisiscam

pr closed time in 2 days

commit sha 1e07712955939d6f8f461fc259b12a20808782b3

Fix typos in api.py docstrings (#4021)

push time in 2 days

+20 -20

1 changed file

j-towns

pr closed time in 2 days

Thanks for catching these! Possibly related, my 'r' key is broken and often types 0 or 2 copies of the letter! :P

j-towns

comment created time in 2 days

   def process_primitive(self, primitive, tracers, params):
     vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
     if all(bdim is not_mapped for bdim in dims_in):
       return primitive.bind(*vals_in, **params)
+    elif config.omnistaging_enabled and primitive in collective_rules:
+      axes_names = params['axis_name']
+      if not isinstance(axes_names, (tuple, list)):
+        axes_names = (axes_names,)
+      for i, axis_name in enumerate(axes_names):
+        frame = core.axis_frame(axis_name)
+        if frame.tag is self.master:

Yeah, I agree.

apaszke

comment created time in 2 days

Thank you, David!

majnemer

comment created time in 2 days

delete branch : make-jaxpr-works-on-tracers

delete time in 2 days

commit sha d100327bf33515cedca41eedfc30d6fb49de7ef2

Remove type restrictions (#4017) We support half16 on TPU

push time in 2 days

Remove type restrictions cla: yes

We support half16 on TPU

+0 -1

0 comment

1 changed file

majnemer

pr closed time in 2 days

delete branch : revise-custom-interpreters-notebook

delete time in 2 days

commit sha 09d8ac14de8c9edc2c9a14becb4963bfdebda605

use fewer internal APIs in custom interp notebook (#4016)

push time in 2 days

fixes #4001

+594 -766

0 comment

1 changed file

mattjj

pr closed time in 2 days

closed time in 2 days

mattjj

commit sha 6a3b920507dcae7a4e4dfa513155222fa0c6feb1

make make_jaxpr work on tracer example args (#4014) (don't use xla.abstractify)

Remove type restrictions (#4011) We support s8, u8, s16, u16, half16 on TPU

commit sha f50216316f692d7c261aeedfc74a24e1424f9a42

use fewer internal APIs in custom interp notebook

push time in 2 days

Remove type restrictions (#4011) We support s8, u8, s16, u16, half16 on TPU

push time in 2 days

Remove type restrictions cla: yes

We support s8, u8, s16, u16, half16 on TPU

+6 -29

1 comment

1 changed file

majnemer

pr closed time in 2 days

Internal tests pass!

majnemer

comment created time in 2 days

+594 -766

0 comment

1 changed file

pr created time in 2 days

@gnecula want to take this one?

shoyer

comment created time in 2 days

created branch time in 2 days

commit sha 6a3b920507dcae7a4e4dfa513155222fa0c6feb1

make make_jaxpr work on tracer example args (#4014) (don't use xla.abstractify)

push time in 2 days

(don't use xla.abstractify)

+1 -1

0 comment

1 changed file

mattjj

pr closed time in 2 days

(don't use xla.abstractify)

+1 -1

0 comment

1 changed file

pr created time in 2 days

created branch time in 2 days

   def __iter__(self):
     if self.ndim == 0:
       raise TypeError("iteration over a 0-d array")  # same as numpy error
     else:
-      return self._value.__iter__()
+      device = self.device_buffer.device()
+      if device is None or device.platform == 'cpu':
+        return iter(self._value)

Yeah, good point. We should revise this case so that it returns CPU DeviceArrays, maybe.

mattjj

comment created time in 3 days

+64 -35

0 comment

2 changed files

pr created time in 3 days

created branch time in 3 days

push event apaszke/jax

deflake

push time in 3 days

   def process_primitive(self, primitive, tracers, params):
     vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
     if all(bdim is not_mapped for bdim in dims_in):
       return primitive.bind(*vals_in, **params)
+    elif config.omnistaging_enabled and primitive in collective_rules:
+      axes_names = params['axis_name']
+      if not isinstance(axes_names, (tuple, list)):
+        axes_names = (axes_names,)
+      for i, axis_name in enumerate(axes_names):
+        frame = core.axis_frame(axis_name)
+        if frame.tag is self.master:

I know you explained this to me before, but I forget: is the tag just to handle possible shadowing? If so, after this came up in a convo with James and Roy last week, I've started to think we should make shadowing illegal, which might also help us slightly simplify this logic.

For now, just flagging this for more discussion!

apaszke

comment created time in 3 days

 def _psum_transpose_rule(cts, axis_name, axis_index_groups):
 pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
 ad.deflinear(psum_p, _psum_transpose_rule)
 pxla.multi_host_supported_collectives.add(psum_p)
+batching.collective_rules[psum_p] = \
+    lambda vals, dims, axis_size, **_: [v.sum(d) if d is not batching.not_mapped else

Currently we might need to lower this to a reduce-sum followed by a broadcast, as per this comment and this one. But also we might want to revise that, so that psum just means reduce-sum... I'm not sure what the best thing to do is, just wanted to flag this as a discussion topic!

apaszke

comment created time in 3 days

Wow, awesome! Thanks for pushing on this.

(By the way, our GitHub CI tests have been really flakey for the last ~week, failing on dependency installation. So if you see dependency installation failures, assume they're not your fault!)

thisiscam

comment created time in 3 days

delete branch : flax-omnistaging-bug

delete time in 3 days

commit sha d46ea969533dbfe4451517ac9a5e51dbdeee6d5d

add helper for flax to be omnistaging-compatible (#4004)

push time in 3 days

+9 -0

0 comment

1 changed file

mattjj

pr closed time in 3 days

+9 -0

0 comment

1 changed file

pr created time in 3 days

created branch time in 3 days

why is this still not officially supported?

Well, we have a small team, and so even though there's a lot of work worth doing, we have to choose what to prioritize.

I believe JAX works great with WSL, so if that works on your Windows setup you might want to give it a try.

@ericmjl I love your attitude! We've benefitted a huge amount from open-source contributions, and I hope we can get a lot more over time! We're all on the same team here, doing our best to push things forward.

ericmjl

comment created time in 4 days

repro:

from jax import numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (3,), dtype=jnp.bfloat16)


observation:
- crashes on TPU backends (both internal and cloud TPU as far as I can tell)
- CPU and GPU backends don't seem to crash

expect: even if bfloat16 isn't supported by random.* on TPU it would be better to error-out rather than crashing

closed time in 4 days

levskaya

Thanks, @majnemer! I just double-checked on an internal TPU colab and indeed it works. Thanks to your team for the fix.

levskaya

comment created time in 4 days

code:



def gumbel_sample(log_probs, temperature=1.0):
    """Gumbel sampling from a categorical distribution."""
    u = numpy.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    g = -np.log(-np.log(u))
    return np.argmax(log_probs + g * temperature, axis=-1)

def predict(num_chars, prefix):
    inp = [ord(c) for c in prefix]
    result = [c for c in prefix]
    max_len = len(prefix) + num_chars
    for _ in range(num_chars):
        cur_inp = np.array(inp + [0] * (max_len - len(inp)))
        outp = model(cur_inp[None, :])  # Add batch dim.
        next_char = gumbel_sample(outp[0, len(inp)])
        inp += [int(next_char)]

        if inp[-1] == 1:
            break  # EOS
        result.append(chr(int(next_char)))

    return "".join(result)

print(predict(32, ""))


Error:

LayerError                                Traceback (most recent call last)
<ipython-input-27-6f9f9afc30e6> in <module>
22     return "".join(result)
23
---> 24 print(predict(32, ""))

<ipython-input-27-6f9f9afc30e6> in predict(num_chars, prefix)
12     for _ in range(num_chars):
13         cur_inp = np.array(inp + [0] * (max_len - len(inp)))
---> 14         outp = model(cur_inp[None, :])  # Add batch dim.
15         next_char = gumbel_sample(outp[0, len(inp)])
16         inp += [int(next_char)]

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
165       self.state = state  # Needed if the model wasn't fully initialized.
166     state = self.state
--> 167     outputs, new_state = self.pure_fn(x, weights, state, rng)
168     self.state = new_state
169     self.weights = weights

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
448       name, trace = self._name, _short_traceback(skip=3)
449       raise LayerError(name, 'pure_fn',
--> 450                        self._caller, signature(x), trace) from None
451
452   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/<ipython-input-13-6f7ffe0c061e>, line 21
layer input shapes: ShapeDtype{shape:(1, 32), dtype:int32}

File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Embedding_256_512 (in pure_fn):
layer created in file [...]/<ipython-input-13-6f7ffe0c061e>, line 18
layer input shapes: ShapeDtype{shape:(1, 32), dtype:int32}

File [...]/trax/layers/core.py, line 150, in forward
return jnp.take(self.weights, x, axis=0)

File [...]/jax/numpy/lax_numpy.py, line 3298, in take
slice_sizes=tuple(slice_sizes))

File [...]/jax/lax/lax.py, line 835, in gather
slice_sizes=canonicalize_shape(slice_sizes))

File [...]/site-packages/jax/core.py, line 273, in bind
return self.impl(*args, **kwargs)

File [...]/jax/interpreters/xla.py, line 228, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)

File [...]/jax/interpreters/xla.py, line 262, in xla_primitive_callable
*avals, **params)

File [...]/jax/interpreters/xla.py, line 320, in primitive_computation
raise RuntimeError(msg) from e

RuntimeError: Invalid argument: Slice size at index 0 in gather op is out of range, must be within [0, 1), got 1.:
This is a bug in JAX's shape-checking rules; please report it!


closed time in 4 days

Keshav15

No worries, we always welcome issues, and also welcome them fixing themselves :)

Keshav15

comment created time in 4 days

The tensordot on this line does not perform the intended broadcast matrix-vector multiply. I think you need this (or equivalent),

return mean + np.einsum('...ij,...j->...i', chol_factor, normal_samples)


Here's an example:

import jax.numpy as np
import jax.random

mean = np.zeros((10, 4))
cov = np.eye(4)[None,...].repeat(10, axis=0)
rng = jax.random.PRNGKey(0)
sample = jax.random.multivariate_normal(rng, mean, cov)
print(sample.shape) # is (10, 10, 4); should be (10, 4).
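
The shape difference can be seen in isolation with a small sketch. The tensordot call below is an assumed stand-in for the problematic line, which is not quoted in this report; the einsum call is the one suggested above. A tensordot contracts one axis and takes an outer product over every remaining axis, so the two batch axes are paired up all-against-all instead of elementwise.

```python
import jax.numpy as jnp

chol_factor = jnp.tile(2.0 * jnp.eye(4), (10, 1, 1))  # (10, 4, 4)
normal_samples = jnp.ones((10, 4))                    # (10, 4)
mean = jnp.zeros((10, 4))

# Assumed stand-in for the buggy call: every sample meets every factor.
pairwise = jnp.tensordot(normal_samples, chol_factor, axes=((-1,), (1,)))
print(pairwise.shape)  # (10, 10, 4)

# Broadcast matrix-vector multiply: the batch axes are matched elementwise.
batched = mean + jnp.einsum('...ij,...j->...i', chol_factor, normal_samples)
print(batched.shape)   # (10, 4)
```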


closed time in 4 days

slinderman

Merged the fix from Scott in #4002!

slinderman

comment created time in 4 days

commit sha ea88c55f555780cdccc27c13d19fd393e695e120

Fixes and tests for jax.random.multivariate_normal (#4002)

* Fix bug #3997, change jax.random.multivariate_normal to handle batches of covariance matrices. It works as long as mean and covariance are broadcast-compatible, as specified in the docstring.
* Fix bug in multivariate_normal shape checking. Minor bug: should be checking for compatibility of shape, mean, and the last two dimensions of the _covariance_ matrix.
* Add test for multivariate_normal shapes. This test checks that jax.random.multivariate_normal produces the expected output shape for various combinations of event dimension and mean, covariance, and shape shapes.
* Fix linter issues in tests/random_test.py. Trimming trailing whitespace and 80 char limit.
* Really trimming whitespace in tests/random_test.py. Arg. Have to fix my editor to do this automatically.

push time in 4 days

This PR addresses bug https://github.com/google/jax/issues/3997 and another shape issue exposed during testing.

+27 -2

0 comment

2 changed files

slinderman

pr closed time in 4 days

created time in 4 days

I left this bug in there just for you to find, Scott.

slinderman

comment created time in 5 days

Yeah! Let's get @slinderman's first of many JAX PRs here!

slinderman

comment created time in 5 days

As a wild guess, not packing values together with jnp.stack (as in the return value of SEIRD_mobility_coupled) can help both compilation time and performance.

ibulu

comment created time in 5 days

Is it slow to execute (after compiling), or just to compile? That is, if you try evaluating the gradient a second time, is it faster on the second evaluation?
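
One way to separate the two costs is sketched below. The toy gradient function is a placeholder, not the model from this issue, and the helper name timed is made up. The first call to a jit-compiled function includes tracing and compilation, while later calls with the same input shapes reuse the compiled program, so timing both calls shows where the time goes. block_until_ready is there because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def loss_grad(w):
    # Placeholder objective; substitute the real loss here.
    return jax.grad(lambda v: jnp.sum(jnp.tanh(v) ** 2))(w)

def timed(fn, *args):
    """Run fn and return (result, wall-clock seconds), waiting for the device."""
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()
    return out, time.perf_counter() - start

w = jnp.ones(1000)
g1, first_call = timed(loss_grad, w)    # includes compilation
g2, second_call = timed(loss_grad, w)   # reuses the compiled program
print(f"first call: {first_call:.4f}s, second call: {second_call:.4f}s")
```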

ibulu

comment created time in 5 days

Thanks for the question!

No, that's not currently possible. It's documented in the custom derivatives tutorial (see the subsection "Forward-mode autodiff cannot be used on the jax.custom_vjp function and will raise an error" for more).
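
If the goal is a single custom rule that works in both modes, one option (a suggestion added here, not something requested in the question) is jax.custom_jvp: JAX derives reverse mode from a JVP rule by transposing it, provided the rule is linear in the tangent. A minimal sketch, with log1pexp as an illustrative function:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # The derivative of log(1 + exp(x)) is sigmoid(x); the rule is linear in x_dot.
    return ans, x_dot * jax.nn.sigmoid(x)

_, fwd = jax.jvp(log1pexp, (1.0,), (1.0,))  # forward mode uses the rule directly
rev = jax.grad(log1pexp)(1.0)               # reverse mode transposes the same rule
print(fwd, rev)
```

This does not help when the backward pass itself must be hand-written, which is the case custom_vjp exists for.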

Can you say more about why you might need to define both?

thisiscam

comment created time in 5 days

Okay, back to it!

Notice the Jacobian of the agg_loss function written with the broadcast is [[0, 1], [0, 1]]:

import jax.numpy as jnp
from jax import lax
from jax import jacfwd

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))
  return jacfwd(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))
# [[0. 1.]
#  [0. 1.]]


So, while keeping the current (at HEAD) bulk array definition of psum as a reduce-sum followed by a broadcast, the SPMD AD semantics is consistent so long as we take grad to mean "compute the VJP against a ones vector broadcast along all named axes":

import jax.numpy as jnp
from jax import vjp, jvp, pmap, grad
from jax import lax

### reverse-mode

# At HEAD, we define this SPMD program:
def f(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [0. 2.]

# To mean the same as this bulk array vjp-with-ones program (grad is always
# defined as vjp-with-ones plus an error check for scalar outputs that we don't
# include in the definition of SPMD semantics):
def grad2(f):
  def gradfun(x):
    ans, f_vjp = vjp(f, x)
    x_bar, = f_vjp(jnp.ones_like(ans))
    return x_bar
  return gradfun

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
  return grad2(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))
# [0. 2.]

### forward-mode

# At HEAD, we define this SPMD program:
def f(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return jvp(agg_loss, (w,), (1.,))[1]
print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [1. 1.]

# To mean the same as this bulk array jvp-with-ones program:
def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
  return jvp(agg_loss, (w,), (jnp.ones_like(w),))[1]
print(f(jnp.ones(2), jnp.arange(2)))
# [1. 1.]


(In Autograd, like TF today, we used to define grad as vjp-with-ones, but exactly this sort of confusion is why we made it raise an error for non-scalar outputs. Yet we didn't make that check work with SPMD functions, in the sense that grad will in effect happily allow broadcasting along named mapped axes!)

If the semantics at HEAD are self-consistent, except the error semantics for grad, do we need to change anything, except perhaps to avoid this potential confusion by making grad error semantics consistent in the positional and SPMD worlds?

Maybe yes. One problem with the current semantics is that if we make @sharadmv's use of grad here an error (rather than a vjp-with-ones) not only would that have been surprising to him, but also it would break pretty much all existing SPMD neural net training; they'd have to write vjp-with-ones themselves, e.g. by defining grad2 as above. Even then, the answers can be surprising: within the context of the SPMD function, it looks like we're calling grad/grad2 on a scalar-input scalar-output function (but for the closed-over value of data which is different on each device) with the same primal input value on every device, yet getting different grad2 results on different devices (perhaps not noticing that if we looked at the primal output value we'd also have a different value on each device, which might make getting different gradients less surprising).

In any case, while I now think the semantics at HEAD are actually consistent (modulo error semantics for grad), this example has clearly shown that they can be confusing, especially when closing over mapped data.

comment created time in 6 days

I changed the issue title because I think the current (i.e. at HEAD) semantics make sense, but can make grad(f)(x) and jvp(f, (x,), (1.,)) differ for scalar-input scalar-output f when inside a pmap, which seems worth revising. The current semantics just take a different correspondence to bulk array programs than the one in @jekbradbury's previous comment.

The original intention with psum was to correspond to a bulk array operation that included broadcasting, not just a reduce-sum:

pmap(lambda x: psum(x, 'i'), axis_name='i')
==


So given this pmap code:

def f(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))


we should expect it to behave like this bulk array version:

import jax.numpy as jnp
from jax import grad, lax

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))
  return grad(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))


But that code triggers an error, because in api.py we check to ensure grad is only being applied to scalar-output functions. Just for defining semantics, let's sidestep that error check:

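import jax.numpy as jnp
from jax import vjp
from jax import lax

def grad2(f):
  def gradfun(x):
    ans, f_vjp = vjp(f, x)
    x_bar, = f_vjp(jnp.ones_like(ans))
    return x_bar
  return gradfun

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))
  return grad2(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))
# prints [0. 2.]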


In effect, we had to write our own grad that was willing to produce a broadcasted ones cotangent vector.

These SPMD semantics explain the reverse-mode answer in my first comment above.

The trouble with this definition, though, is that it can make grad disagree with jvp and numerical differences on what appears locally to be a scalar-input scalar-output function when inside a pmap, as in the examples in my first comment. I want to write out an explanation for what's going on, but I've got to step away for a moment and wanted to send this comment first. To be continued!

In any case, I think @jekbradbury 's proposal for revising the semantics is likely to be better. I just want to pin down both the old and new semantics as best we can.

comment created time in 6 days

push time in 6 days

created branch time in 7 days

@jekbradbury that's brilliant! Let's do it.

comment created time in 7 days

This was a bad merge on my part! I believe this passes internal tests.

j-towns

comment created time in 7 days

Yeah that's my thinking too!

dionhaefner

comment created time in 7 days

The omnistaging change ensures that we trace the Python to a jaxpr in a way that preserves operation ordering, but I suspect we still need to thread tokens in the XLA HLO lowering. That is, JAX now has the tools to avoid reordering without user-level tokens, but XLA may still reorder things.

@hawkinsp does that sound right to you?

dionhaefner

comment created time in 7 days

Thanks for raising this; that's too slow! Which backend is this on (CPU/GPU/TPU)? It might be faster or slower to compile on different backends (often slowest on CPU, fastest on TPU).

comment created time in 7 days

I wonder if this is an interaction between the laziness of np.ones and omnistaging...

dionhaefner

comment created time in 7 days

GE and LT are "greater than or equal to" and "less than", respectively. That error is saying that those operations aren't implemented for complex numbers. (NumPy implements them with a kind of weird convention, since complex numbers don't have a total ordering.)

mganahl

comment created time in 7 days

I'm starting to agree with James that we might need to track spmd-symmetry in jaxprs, and reverse mode should do something interesting with that information. It's tricky: if we have a jaxpr to transpose with both symmetric and non-symmetric outputs, we may need to run two backward passes, one for the symmetric outputs (where we symmetrize the resulting input cotangents) and one for the non-symmetric outputs (where we don't).

comment created time in 7 days

Thanks! One possible desideratum is not to do something different on guaranteed-symmetric values versus just-happens-to-be-symmetric values. Notice that in this example if data just happens to be symmetric, we get the right answer. Also, if we noticed (say in backward_pass) that the function to be transposed is symmetric (in the sense that when it's given a symmetric input it produces a symmetric output), and in that case always symmetrized the cotangents of the jaxpr inputs, then we'd compute the right answer in both guaranteed-symmetric (new!) and just-happens-to-be-symmetric (already got that one) cases.

comment created time in 7 days

Can you define 'replicated thing' and 'non-replicated thing'? Do you mean a function (agg_loss in this case), and if so, what does it mean for it to be replicated (maybe: it has any collectives in it)?

comment created time in 7 days

This is confusing! It seems to be an issue about how we define transposes of SPMD functions (i.e. it's not just a dumb software bug). No conclusions yet, but wanted to dump some thoughts here.

Here are some test programs:

import jax.numpy as jnp
from jax import vmap, pmap, grad, jvp, lax

def distributed(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return jvp(agg_loss, (w,), (1.,))[1]
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1. 1.]

def distributed(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return (agg_loss(w+1e-3) - agg_loss(w)) / 1e-3
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1.0000466 1.0000466]

def distributed(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [0. 2.]


The jaxpr being transposed in the last one is

{ lambda c ; b.
  let d = mul b c
      e = psum[ axis_index_groups=None
                axis_name=batch ] d
  in (e,) }


Operationally, the transpose evaluation proceeds like this: grad feeds in a cotangent for e of 1.0. Then we compute a cotangent for d of 2.0, reflecting the fact that if we perturb the value of d by epsilon (on each replica) then the value of e changes by 2 epsilon. Then we compute a cotangent for b by multiplying by 0. on the first replica and 1. on the second replica, leading to the final result [0. 2.].
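
The same numbers can be reproduced on a single device by emulating the psum with a bulk-array reduce-sum followed by a broadcast. This is only a sketch under that assumed correspondence: the function name bulk is made up, and b, c, d, e mirror the jaxpr variables above.

```python
import jax
import jax.numpy as jnp
from jax import lax

data = jnp.arange(2.0)  # c in the jaxpr: one value per replica
w = jnp.ones(2)         # b in the jaxpr: the same value on every replica

def bulk(b):
    d = b * data
    return lax.broadcast(d.sum(), (2,))  # e: reduce-sum, then broadcast

# Reverse mode: feed a cotangent of 1.0 for e on every replica.
e, bulk_vjp = jax.vjp(bulk, w)
(ct_b,) = bulk_vjp(jnp.ones_like(e))
print(ct_b)     # [0. 2.]

# Forward mode: perturb b by 1.0 on every replica.
_, tangent = jax.jvp(bulk, (w,), (jnp.ones_like(w),))
print(tangent)  # [1. 1.]
```

The transpose of the broadcast sums the two unit cotangents to give 2.0 for d, and the final multiply by data gives [0. 2.], matching the step-by-step account above.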

Beyond the numbers, we have a symmetry issue. The first two versions produce a symmetric-across-replicas result, which makes sense because agg_loss's last operation is a psum and so its output must be symmetric across replicas. But the reverse-mode version can't produce a symmetric result because it multiplies by the mapped value data at the end. (The symmetric result makes the most sense because agg_loss always produces a symmetric result given a symmetric input.)

Hmmmm...

comment created time in 7 days

created time in 7 days

 def sampleCosineWeightedHemisphere (k:Key) (normal: Vec3) : Vec3 =
   rr = (rx .* uu) + (ry .* vv) + (rz .* normal)
   normalize rr
 
+' ### Raytracer
+
+Distance = Real
+def Image (n:Int) :Type = Fin n => Fin n => Color -- TODO: hide the size
+
+Position  = Vec 3
+Direction = Vec 3  -- Should be normalized. TODO: use a newtype wrapper
 
--- Misc. rendering params.
-MAX_ITERS = 50
-HORIZON = 20.0
-MAX_DEPTH = 3
-
-def raymarch (ro:Vec3) (rd: Vec3) : (Object & Distance) =
-  -- Move along ray until we hit an obect.
-  def step (i:(Fin MAX_ITERS)) (pair:(Object & Distance)) : (Object & Distance) =
-    (_, t) = pair
-    (obj, distance) = sdScene (ro + (t .* rd))
-    (obj, t + distance)
-
-  (obj_id, t) = fold (Obj_None, 0.0) step
-  obj_id = select (t > HORIZON) Obj_None obj_id
-  (obj_id, t)
-
-
-def calcNormal (p:Vec3) : Vec3 =
-  dist = \p. snd $ sdScene p
-  -- normalize(grad(dist)(p))
-
-  -- derivative approximation via midpoint rule.
-  -- Todo: Switch to autodiff when it works.
-  eps = 0.001
-  dx = [eps, 0.0, 0.0]
-  dy = [0.0, eps, 0.0]
-  dz = [0.0, 0.0, eps]
-  -- extract just the distance component
-  nor = [(dist(p+dx)) - (dist(p-dx)),
-         (dist(p+dy)) - (dist(p-dy)),
-         (dist(p+dz)) - (dist(p-dz))]
-  normalize nor
-
-
-def sample_light_point (key:Key) : Vec3 =
-  -- Samples a point uniformly from the surface of the light.
-  -- Todo: remove hardcoded values
-  -- Todo: Allow multiple lights
-  (key1, key2) = splitKey key
-  p_light_x = randuniform key1 (-0.25) (0.25)
-  p_light_z = randuniform key2 (2.0 - 0.25) (2.0 + 0.25)
-  [p_light_x, 3.9, p_light_z]
-
-
-def direct_light (key:Key) (p:Vec3) (nor:Vec3) (brdf:Color) : Vec3 =
-
-  -- Check for line-of-sight to a random point on the light.
-  p_light = sample_light_point key
-  wi_light = normalize (p_light - p)
-  (obj2, t2) = raymarch (p + 0.001 .* nor) wi_light
-
-  -- Compute radiance of light from this direction.
-  vis = islight obj2
-  cos1 = relu (vdot nor wi_light)
-  cos2 = relu (vdot nor_light (negvec wi_light))
-  pdf_A = 1.0 / LIGHT_AREA
-  square_distance = sum $ for i. sq (p_light.i - p.i)
-  scale = (cos1 * cos2) / (pdf_A * square_distance)
-  li = for i. scale * emittedRadiance.i * brdf.i
-
-  case vis of
-    False -> zero
-    True -> li
-
-
-def scatter_eye_rays
-    (depth:(Fin MAX_DEPTH)) (carry:(Key & Vec3 & Vec3)) :
-    ((Key & Vec3 & Vec3) & (Bool & Vec3 & Color)) =
-  (rng, ro, rd) = carry
-  (obj, t) = raymarch ro rd
-  brdf = brdf_map obj
-
-  is_light = islight obj
-  did_intersect = not $ isnone obj
-
-  li_e = (relu ( vdot (negvec rd) nor_light)) .* emittedRadiance
-  radiance = select is_light li_e zero
-
-  p = ro + t .* rd
-  nor = calcNormal p
-
-  -- Contribution directly from light.
-  (rng, subkey) = splitKey rng
-  li_direct = direct_light subkey p nor brdf
-  radiance = radiance + select did_intersect li_direct zero
-
-  -- Sample bounced ray for future indirect contributions.
-  (rng, subkey) = splitKey rng
-  rd2 = sampleCosineWeightedHemisphere subkey nor
-
-  carry = (rng, ro, rd)
-  outputs = (did_intersect, radiance, brdf)
-  (carry, outputs)
-
-
--- Add light from each step if there was an intersection.
-def accumulate_outgoing (li_indirect:Vec3) (x:(Bool & Vec3 & Color)) : Vec3 =
-  (did_intersect, radiance, brdf) = x
-  radiance + select did_intersect (for d. brdf.d * li_indirect.d) zero
-
--- Main loop.
-def trace (rng:Key) (ro:Vec3) (rd:Vec3) (depth:(Fin MAX_DEPTH)) : Vec3 =
-  init = (rng, ro, rd)
-  (carry, outputs) = scan init scatter_eye_rays -- Forward pass.
-  fold zero \i c. accumulate_outgoing c (reverse outputs).i -- Backward pass.
-
-
-' Setup and draw image
-
-num_samples = 3
-N = 500 -- pixel width and height of image.
-
-xs = linspace (Fin N) 1.0 0.0 -- Reverse order because of pinhole camera
-rng = newKey 0
+BlockHalfWidths = Vec 3
+Radius = Real
+Radiance = Color
+
+data ObjectGeom =
+  Wall Direction Distance
+  Block Position BlockHalfWidths Angle
+  Sphere Position Radius
+
+data Surface =
+  Matte Color
+  Mirror
+
+OrientedSurface = (Direction & Surface)
+
+data Object =
+  PassiveObject ObjectGeom Surface
+  -- position, half-width, intensity (assumed to point down)
+  Light Position Real Radiance
+
+Ray = (Position & Direction)
+Filter = Color
+
+-- TODO: use a record
+-- num samples, num bounces
+Params = (Int & Int)
+
+-- TODO: use a list instead, once they work
+data Scene n:Type = MkScene (n=>Object)
+
+def sampleReflection ((nor, surf):OrientedSurface) ((pos, dir):Ray) (k:Key) : Ray =
+  newDir = case surf of
+    Matte _ -> sampleCosineWeightedHemisphere nor k
+    -- TODO: surely there's some change-of-solid-angle correction we need to
+    -- consider when reflecting off a curved surface.
+    Mirror -> dir - (2.0 * dot dir nor) .* nor
+  (pos, newDir)
+
+def probReflection ((nor, surf):OrientedSurface) (_:Ray) ((_, outRayDir):Ray) : Real =
+  case surf of
+    Matte _ -> relu $ dot nor outRayDir
+    Mirror  -> 0.0  -- TODO: this should be a delta function of some sort
+
+def applyFilter (filter:Filter) (radiance:Radiance) : Radiance =
+  for i. filter.i * radiance.i
+
+def surfaceFilter (filter:Filter) (surf:Surface) : Filter =
+  case surf of
+    Matte color -> for i. filter.i * color.i
+    Mirror      -> filter
+
+def sdObject (pos:Position) (obj:Object) : Distance =
+  case obj of
+    PassiveObject geom _ -> case geom of
+      Wall nor d -> d + dot nor pos
+      Block blockPos halfWidths angle ->
+        pos' = rotateY (pos - blockPos) angle
+        length $ for i. max ((abs pos'.i) - halfWidths.i) 0.0
+      Sphere spherePos r ->
+        pos' = pos - spherePos
+        max (length pos' - r) 0.0
+    Light squarePos hw _ ->
+      pos' = pos - squarePos
+      halfWidths = [hw, 0.01, hw]
+      length $ for i. max ((abs pos'.i) - halfWidths.i) 0.0
+
+def sdScene (scene:Scene n) (pos:Position) : (Object & Distance) =
+  (MkScene objs) = scene
+  minimumBy snd for i.
+    obj = objs.i
+    (obj, sdObject pos obj)
+
+-- TODO: use AD!
+def calcNormal (scene:Scene n) (pos:Position) : Direction =
+  dist = \p. snd $ sdScene scene p
+  normalize (gradNumerical dist pos)
+
+data RayMarchResult =
+  -- incident ray, surface normal, surface properties
+  HitObj Ray OrientedSurface
+  HitLight Radiance
+  -- Could refine with failure reason (beyond horizon, failed to converge etc)
+  HitNothing
+
+def raymarch (scene:Scene n) (ray:Ray) : RayMarchResult =
+  max_iters = 100
+  tol = 0.01
+  startLength = 10.0 * tol -- trying to escape the current surface
+  (rayOrigin, rayDir) = ray
+  iter (10.0 * tol) \i rayLength.
+    case i >= max_iters of
+      True -> Done HitNothing
+      False ->
+        rayPos = rayOrigin + rayLength .* rayDir
+        (obj, d) = sdScene scene $ rayPos
+        -- 0.9 ensures we come close to the surface but don't touch it
+        dNew = rayLength + 0.9 * d
+        case d < tol of
+          False -> Continue $ dNew
+          True ->
+            surfNorm = calcNormal scene rayPos
+            case positiveProjection rayDir surfNorm of
+              True ->
+                -- Oops, we didn't escape the surface we're leaving..
+                -- (Is there a more standard way to do this?)
+                Continue dNew
+              False ->
+                -- We made it!
+                Done $ case obj of
+                  PassiveObject _ surf -> HitObj (rayPos, rayDir) (surfNorm, surf)
+                  Light _ _ radiance   -> HitLight radiance
+
+def rayDirectRadiance (scene:Scene n) (ray:Ray) : Radiance =
+  case raymarch scene ray of
+    HitLight intensity -> intensity
+    HitNothing -> zero
+    HitObj _ _ -> zero
+
+def sampleSquare (hw:Real) (k:Key) : Position =
+ (kx, kz) = splitKey k
+ x = randuniform (- hw) hw kx
+ z = randuniform (- hw) hw kz
+ [x, 0.0, z]
+
+def sampleLightRadiance
+      (scene:Scene n) (osurf:OrientedSurface) (inRay:Ray) (k:Key) : Radiance =
+  (surfNor, surf) = osurf
+  (rayPos, _) = inRay
+  (MkScene objs) = scene
+  snd $ withAccum \radiance.
+    for i. case objs.i of
+      PassiveObject _ _ -> ()
+      Light lightPos hw _ ->
+        (dirToLight, distToLight) = directionAndLength $
+                                      lightPos + sampleSquare hw k - rayPos
+        case positiveProjection dirToLight surfNor of
+          False -> () -- light on the far side of current surface
+          True ->
+            fracSolidAngle = (relu $ dot dirToLight yHat) * sq hw / (pi * sq distToLight)
+            outRay = (rayPos, dirToLight)
+            coeff = fracSolidAngle * probReflection osurf inRay outRay
+            radiance += coeff .* rayDirectRadiance scene outRay
+
+def trace (params:Params) (scene:Scene n) (init_ray:Ray) (k:Key) : Color =
+  (_, max_bounces) = params
+  -- TODO: we ought to be able to use an accumulator here, but there's a bug
+  noFilter = [1.0, 1.0, 1.0]
+  iter (noFilter, zero, init_ray) $
+    \i (filter, radiance, ray).
+      case i >= max_bounces of
+        True -> Done radiance
+        False -> case raymarch scene ray of
+          HitNothing -> Done radiance
+          HitLight intensity -> case i == 0 of
+            True  -> Done intensity  -- TODO: scale etc
+            False -> Done radiance
+          HitObj incidentRay osurf ->
+            (k1, k2) = splitKey $ hash k i
+            lightRadiance = sampleLightRadiance scene osurf incidentRay k1
+            outRayHemisphere = sampleReflection osurf incidentRay k2
+            newFilter = surfaceFilter filter (snd osurf)
+            newRadiance = radiance + applyFilter newFilter lightRadiance
+            Continue (newFilter, newRadiance, outRayHemisphere)
+
+def avgRayColor (params:Params) (scene:Scene m) (ray:Ray) (k:Key) : Color =
+  (num_samples, _) = params
+  sampleAveraged (trace params scene ray) num_samples k
+
+-- TODO: add number of pixels once we can hide sizes
+-- sensor half-width, pinhole-sensor distance, pinhole position
+-- (Assumes we're looking towards -z.)
+Camera = (Position & Real & Real)
+
+def cameraRays (n:Int) (camera:Camera) : Fin n => Fin n => Ray =
+  -- images indexed from top-left
+  (pos, halfWidth, sensorDist) = camera
+  ys = reverse $ linspace (Fin n) (neg halfWidth) halfWidth
+  xs =           linspace (Fin n) (neg halfWidth) halfWidth
+  for i j. (pos, normalize [xs.j, ys.i, neg sensorDist])
+
+def takePicture
+       (params:Params) (scene:Scene m) (n:Int) (camera:Camera)
+      : ColorImage n n =
+  rays = cameraRays n camera
+  rootKey = newKey 0
+  for i j.
+    k = ixkey (ixkey rootKey i) j

This should be called "the David trick" because he always suggests it and it always works.

dougalm

comment created time in 7 days

Copying from that comment on 3370 for convenience, I think we just need to memoize this line to avoid retracing.

tfrerix

comment created time in 8 days

Thanks for flagging this, and the great repros. I think this is a duplicate of a known bug, which I should look up and link here. The basic issue is that nondiff_argnums works basically like lexical closure over tracers, but that check is disallowing lexical closure over tracers!

@NeilGirdhar has a nice solution, and I believe a PR that I still haven't gotten to.

@brianwa84 how blocking is this for your work? I'd like to get to it this week, but I'm just trying to prioritize stuff. Let me know :)

brianwa84

comment created time in 8 days

delete branch : omnistaging3

delete time in 8 days

Thanks for letting us know about this! (By the way, unfortunately 64bit matmul on CPU can be slow to execute too: #3832.)

@hawkinsp any thoughts? Is this one for XLA:CPU folks?

ahoenselaar

comment created time in 8 days

I think copyrights are meant to be from when the file was written, and so they shouldn't be uniformly bumped to the current year.

jakevdp

comment created time in 9 days

Avoid re-flattening in jit() when no donate_argnums are present. (#3945) Following the same special-casing of static_argnums, this should provide a speedup, especially when the number of arguments provided is large.

push time in 10 days

Thank you!

kosklain

comment created time in 10 days

Following the same special-casing of static_argnums, this should provide a speedup, especially when the number of arguments provided is large.

+8 -2

0 comment

1 changed file

kosklain

pr closed time in 10 days

Seems like we should add device_put to this PR and merge it, and also follow up on #3238 (which has bigger picture considerations, but also might let us simplify this example back down again if we want to).