Matthew Johnson (mattjj) · research scientist @ Google Brain · Google · San Francisco · people.csail.mit.edu/~mattjj

google/jax 9359

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

google-research/dex-lang 638

Research language for array processing in the Haskell/ML family

mattjj/autodidact 493

A pedagogical implementation of Autograd

duvenaud/relax 133

Optimizing control variates for black-box gradient estimation

jacobjinkelly/easy-neural-ode 54

Code for the paper "Learning Differential Equations that are Easy to Solve"

google-research/autoconj 33

Recognizing and exploiting conjugacy without a domain-specific language

duvenaud/jaxde 26

Prototypes of differentiable differential equation solvers in JAX.

mattjj/config-fish 4

my fish configuration

mattjj/config-vim 4

my .vim

push event google/jax

Matthew Johnson

commit sha 241267c275b14305ad85f0688f679d660ed90182

omnistaging on by default


push time in 14 hours

push event google/jax

Matthew Johnson

commit sha e7b74caa79b9592ab3b97d7c899f9c7e030ae56a

omnistaging on by default


push time in 14 hours

issue comment google/jax

Fastest implementation of an elementwise gradient?

One more thought: to get the second derivatives, you just need to write grad(grad(...)) inside the double-vmap. Getting rid of batch dimensions makes things so much easier!
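(For concreteness, here's a minimal self-contained sketch of that pattern with a hypothetical scalar kernel k, not the exact code from this thread:)

import jax.numpy as jnp
from jax import grad, vmap

def k(x1, x2, length_scale):
    # hypothetical scalar RBF kernel of two scalar inputs
    return jnp.exp(-0.5 * (x1 - x2) ** 2 / length_scale ** 2)

# second derivative w.r.t. x2, batched over both arguments via nested vmap
d2k_dx2 = vmap(vmap(grad(grad(k, argnums=1), argnums=1),
                    (0, None, None)), (None, 0, None))

x = jnp.linspace(0.0, 1.0, 5)
print(d2k_dx2(x, x, 1.0).shape)  # (5, 5)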

tawe141

comment created time in 15 hours

push event google/jax

Matthew Johnson

commit sha 69ddb9f8ad17948384fd398e0b22b30a3b34031a

omnistaging on by default


push time in 15 hours

issue comment google/jax

Fastest implementation of an elementwise gradient?

There may be a nicer way to write the nested vmap using the vectorize wrapper.
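(As a rough, untested sketch of what that might look like with jnp.vectorize, again using a hypothetical scalar kernel k rather than the code from this thread:)

import jax.numpy as jnp
from jax import grad

def k(x1, x2, length_scale):
    # hypothetical scalar RBF kernel of two scalar inputs
    return jnp.exp(-0.5 * (x1 - x2) ** 2 / length_scale ** 2)

# broadcast the per-pair derivative over all (x1, x2) pairs
dk_dx2 = jnp.vectorize(grad(k, argnums=1), excluded={2})

x = jnp.linspace(0.0, 1.0, 5)
print(dk_dx2(x[:, None], x[None, :], 1.0).shape)  # (5, 5)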

tawe141

comment created time in 15 hours

issue comment google/jax

Fastest implementation of an elementwise gradient?

Thanks for the question! I agree with what @shoyer said.

To say a bit more, if f: R^n -> R^m then we expect its Jacobian to be an m x n matrix. The function being differentiated here, namely RBF.forward with respect to its first argument, takes an input of shape (50, 1) and produces an output of shape (50, 50), which is why we expect to see a Jacobian of shape (50, 50, 50, 1). (If we want to think in flattened terms, we'd say that the input has dimension n=50 and the output has dimension m=2500.)
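(As a toy illustration of that shape rule, separate from the kernel code in question:)

import jax.numpy as jnp
from jax import jacfwd

def g(x):            # x has shape (50, 1)
    return x * x.T   # output has shape (50, 50) via broadcasting

x = jnp.ones((50, 1))
print(jacfwd(g)(x).shape)  # (50, 50, 50, 1): output shape followed by input shape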

The trouble here is we just want to see the data axes of size 50 as batch axes. That is, we really want to think of the kernel as taking two vectors and outputting a scalar. We only carry along the batch axis in the implementation for vectorization efficiency.

But with vmap, we don't need to batch by hand, and as a result it can be easier to express the functions and derivatives we want. Concretely, take a look at the forward_ and dkdx2_ methods below (note the underscores on the end), and the assertion:

class RBF(Kernel):
    # ...

    def forward_(self, x1: np.ndarray, x2: np.ndarray, thetas: np.ndarray):
        assert thetas.shape == (1,)
        length_scale = thetas[0]
        dist_sq = np.vdot(x1, x1) + np.vdot(x2, x2) - 2 * np.vdot(x1, x2)
        return np.exp(-0.5 * dist_sq / length_scale**2)

class RBFGrad(RBF):
    def __init__(self, length_scale=1.0):
        super(RBFGrad, self).__init__(length_scale)
        self.dkdx1 = jit(jacfwd(super(RBFGrad, self).forward, argnums=0))
        self.dkdx2 = jit(jacfwd(super(RBFGrad, self).forward, argnums=1))
        self.dk2dx1dx2 = jit(jacfwd(jacrev(super(RBFGrad, self).forward, argnums=0), argnums=1))

        self.dkdx2_ = jit(vmap(vmap(grad(super().forward_, argnums=1), (0, None, None)), (None, 0, None)))

    def forward(self, x1: np.ndarray, x2: np.ndarray, thetas: np.ndarray):
        K = super().forward(x1, x2, thetas)
        dx2 = self.dkdx2(x1, x2, thetas).sum(-2)
        dx2_ = self.dkdx2_(x1, x2, thetas)
        assert np.allclose(dx2, dx2_)
        # ...

I rewrote the forward_ method so that it more obviously applies to single vectors at a time. It also uses the identity ||u - v||^2 = ||u||^2 + ||v||^2 - 2 u · v, which gives us more matrix multiplies and I think is often faster (though one should benchmark against computing ||u - v||^2 directly, as before). I also added the dkdx2_ method, which uses vmap to do all the batching, and compared it against the old calculation (it may be faster to use jvp than grad, but this was convenient). That second bit, with the vmaps, is the main thing I wanted to illustrate, as it's basically @shoyer's advice.

One other piece of advice would be to put jit on RBFGrad.forward, because the more code XLA can see the more it can optimize things.

WDYT?

tawe141

comment created time in 15 hours

issue closed google/jax

Broadcasting minval and maxval in jax.random.uniform

Hi JAX team!

Currently, it seems that broadcasting of minval and maxval in jax.random.uniform is not supported. https://github.com/google/jax/blob/a169743f64e9ddf826cdad5c225e3e18a980a1db/jax/random.py#L343

In numpy, this is supported

np.random.uniform(low=jnp.zeros(2), high=jnp.ones(2), size=(10,2))

In JAX, this results in a TypeError: max got arrays of different rank: (2,), (10, 2).

jax.random.uniform(
    key=jax.random.PRNGKey(42),
    shape=(10,2), 
    minval=jnp.zeros(2), 
    maxval=jnp.ones(2))

Manually broadcasting fixes the issue:

jax.random.uniform(
    key=jax.random.PRNGKey(42),
    shape=(1,2), 
    minval=jnp.broadcast_to(jnp.zeros(2), (1,2)), 
    maxval=jnp.broadcast_to(jnp.ones(2), (1,2)))

Is JAX's behaviour intended? If not, I am happy to send a PR fixing this.

closed time in 16 hours

ethanluoyc

push event google/jax

Matthew Johnson

commit sha 85eea219e3e5041f3f9d5e7884db62a27492d58c

omnistaging on by default


push time in 16 hours

PR opened google/jax

omnistaging on by default
+767 -755

0 comment

14 changed files

pr created time in 16 hours

create branch google/jax

branch : omnistaging-on-by-default

created branch time in 16 hours

issue comment google/jax

host_callback doesn't work inside grad(odeint)

Let me know if I can help! I looked at some of this bookkeeping recently.

shoyer

comment created time in 17 hours

issue comment google/jax

any suggestions on how to improve performance for gradient step with odeint?

By the way, XLA:CPU (and hence JAX on CPU) has some known performance issues with float64, e.g. I believe the 64bit GEMM kernel being called is a slow Eigen one, while the 32bit one is from MKL-DNN. I mention it just because I don't want it to be a confounder, though I haven't looked at the details here enough to know if it could be relevant.

I'd love if an outcome of this investigation is that we should add some alternative VJPs for odeint, because that sounds really fun!

ibulu

comment created time in 17 hours

pull request comment google/jax

Added initial implementation of numpy equivalent for trim_zeros to jax

Hey @gsp-27 , thanks for the contribution!

It turns out that trim_zeros isn't a good candidate for JIT compilation with XLA, because the output shape depends on the values of the input. Because XLA programs have static shapes in their type system, that means we'd need to create and compile a separate program for most input values. That's why you're seeing issues with _CompileAndCheck.
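(To make that concrete, here's a tiny illustration with plain NumPy: two inputs of the same shape produce outputs of different shapes, so no single fixed-shape XLA program can cover them.)

import numpy as np

print(np.trim_zeros(np.array([0, 1, 2, 0])).shape)  # (2,)
print(np.trim_zeros(np.array([0, 0, 1, 0])).shape)  # (1,)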

We could still include an implementation of trim_zeros in jax.numpy and just allow it to error when people try to include it in jit functions, but that limits the upside, so it might also make sense just to leave it out. If we did want to include it, we'd likely need another implementation technique rather than the Python loop here, perhaps based on lax.while_loop (though I'm not sure about that).

Overall I'd recommend not adding trim_zeros for the above reason.

What do you think?

gsp-27

comment created time in a day

issue closed google/jax

RuntimeError on GCP TPU

I tried to run the example code in the README to use JAX on TPU from a GCE VM, but I got a RuntimeError.

2020-08-11 08:26:26.301652: E external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc:1068] Failed to open the gRPC driver: 12: :
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/random.py", line 85, in PRNGKey
    k1 = convert(np.bitwise_and(np.right_shift(seed, 32), 0xFFFFFFFF))
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/random.py", line 81, in <lambda>
    convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lax/lax.py", line 397, in convert_element_type
    operand, new_dtype=new_dtype, old_dtype=old_dtype)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/core.py", line 276, in bind
    return self.impl(*args, **kwargs)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 224, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 236, in xla_primitive_callable
    backend = xb.get_device_backend(device)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lib/xla_bridge.py", line 176, in get_device_backend
    return get_backend(platform)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lib/xla_bridge.py", line 170, in get_backend
    return backend(platform)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lib/xla_bridge.py", line 148, in _get_tpu_driver_backend
    _tpu_backend = tpu_client.TpuBackend.create(worker=backend_target)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jaxlib/tpu_client.py", line 59, in create
    return _tpu_client.TpuClient.Get(worker)
RuntimeError: Unimplemented: Failed to connect to remote server at address: grpc://10.240.1.2:8470. Error from gRPC: . Details:

Any idea why this happens?

closed time in a day

JaySunnn

issue comment google/jax

RuntimeError on GCP TPU

@jekbradbury @skye you are the greatest!

JaySunnn

comment created time in a day

delete branch google/jax

delete branch : fewer-ode-gpu-tests

delete time in a day

push event google/jax

Matthew Johnson

commit sha c564aca77710df0599715d4231b7d5b7dd46984a

skip more ode tests on gpu, b/c slow to compile (#4028)


push time in a day

PR merged google/jax

skip more ode tests on gpu, b/c slow to compile cla: yes
+3 -3

0 comment

1 changed file

mattjj

pr closed time in a day

PR opened google/jax

skip more ode tests on gpu, b/c slow to compile
+3 -3

0 comment

1 changed file

pr created time in a day

create branch google/jax

branch : fewer-ode-gpu-tests

created branch time in a day

issue comment google/jax

RuntimeError on GCP TPU

@skye can you take a look at this one?

JaySunnn

comment created time in 2 days

started locuslab/monotone_op_net

started time in 2 days

push event google/jax

Cambridge Yang

commit sha fe9f264b55f8b99c57f803db9eb7a2c8df897e9b

cumulative jet rules (#4000)


push time in 2 days

PR merged google/jax

cumulative jet rules cla: yes

Below is my attempt at adding jet rules for cumulative operations, following jvp rules defined in lax.py

+26 -1

3 comments

2 changed files

thisiscam

pr closed time in 2 days

push event google/jax

Jamie Townsend

commit sha 1e07712955939d6f8f461fc259b12a20808782b3

Fix typos in api.py docstrings (#4021)


push time in 2 days

PR merged google/jax

Fix some typos in api.py docstrings cla: yes
+20 -20

2 comments

1 changed file

j-towns

pr closed time in 2 days

pull request comment google/jax

Fix some typos in api.py docstrings

Thanks for catching these! Possibly related, my 'r' key is broken and often types 0 or 2 copies of the letter! :P

j-towns

comment created time in 2 days

Pull request review comment google/jax

Initial version of vmap collectives

  def process_primitive(self, primitive, tracers, params):
    vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims_in):
      return primitive.bind(*vals_in, **params)
+    elif config.omnistaging_enabled and primitive in collective_rules:
+      axes_names = params['axis_name']
+      if not isinstance(axes_names, (tuple, list)):
+        axes_names = (axes_names,)
+      for i, axis_name in enumerate(axes_names):
+        frame = core.axis_frame(axis_name)
+        if frame.tag is self.master:

Yeah, I agree.

apaszke

comment created time in 2 days

pull request comment google/jax

Remove type restrictions

Thank you, David!

majnemer

comment created time in 2 days

delete branch google/jax

delete branch : make-jaxpr-works-on-tracers

delete time in 2 days

push event google/jax

David Majnemer

commit sha d100327bf33515cedca41eedfc30d6fb49de7ef2

Remove type restrictions (#4017)

We support half16 on TPU


push time in 2 days

PR merged google/jax

Remove type restrictions cla: yes

We support half16 on TPU

+0 -1

0 comment

1 changed file

majnemer

pr closed time in 2 days

delete branch google/jax

delete branch : revise-custom-interpreters-notebook

delete time in 2 days

push event google/jax

Matthew Johnson

commit sha 09d8ac14de8c9edc2c9a14becb4963bfdebda605

use fewer internal APIs in custom interp notebook (#4016)


push time in 2 days

PR merged google/jax

use fewer internal APIs in custom interp notebook cla: yes

fixes #4001

+594 -766

0 comment

1 changed file

mattjj

pr closed time in 2 days

push event google/jax

Matthew Johnson

commit sha 6a3b920507dcae7a4e4dfa513155222fa0c6feb1

make make_jaxpr work on tracer example args (#4014)

(don't use xla.abstractify)


David Majnemer

commit sha 265c3faa405260fb563dd492dabad0752b42e842

Remove type restrictions (#4011)

We support s8, u8, s16, u16, half16 on TPU


Matthew Johnson

commit sha f50216316f692d7c261aeedfc74a24e1424f9a42

use fewer internal APIs in custom interp notebook


push time in 2 days

push event google/jax

David Majnemer

commit sha 265c3faa405260fb563dd492dabad0752b42e842

Remove type restrictions (#4011)

We support s8, u8, s16, u16, half16 on TPU


push time in 2 days

PR merged google/jax

Remove type restrictions cla: yes

We support s8, u8, s16, u16, half16 on TPU

+6 -29

1 comment

1 changed file

majnemer

pr closed time in 2 days

pull request comment google/jax

Remove type restrictions

Internal tests pass!

majnemer

comment created time in 2 days

PR opened google/jax

use fewer internal APIs in custom interp notebook
+594 -766

0 comment

1 changed file

pr created time in 2 days

issue comment google/jax

host_callback doesn't work inside grad(odeint)

@gnecula want to take this one?

shoyer

comment created time in 2 days

create branch google/jax

branch : revise-custom-interpreters-notebook

created branch time in 2 days

push event google/jax

Matthew Johnson

commit sha 6a3b920507dcae7a4e4dfa513155222fa0c6feb1

make make_jaxpr work on tracer example args (#4014)

(don't use xla.abstractify)


push time in 2 days

PR merged google/jax

make make_jaxpr work on tracer example args cla: yes

(don't use xla.abstractify)

+1 -1

0 comment

1 changed file

mattjj

pr closed time in 2 days

PR opened google/jax

make make_jaxpr work on tracer example args

(don't use xla.abstractify)

+1 -1

0 comment

1 changed file

pr created time in 2 days

create branch google/jax

branch : make-jaxpr-works-on-tracers

created branch time in 2 days

Pull request review comment google/jax

DeviceArray.__iter__ returns DeviceArrays, without host sync

  def __iter__(self):
    if self.ndim == 0:
      raise TypeError("iteration over a 0-d array")  # same as numpy error
    else:
-      return self._value.__iter__()
+      device = self.device_buffer.device()
+      if device is None or device.platform == 'cpu':
+        return iter(self._value)

Yeah, good point. We should revise this case so that it returns CPU DeviceArrays, maybe.

mattjj

comment created time in 3 days

PR opened google/jax

custom_jvp/vjp closure issues, just experimenting for now
+64 -35

0 comment

2 changed files

pr created time in 3 days

create branch google/jax

branch : custom-jvp-closure-fixes

created branch time in 3 days

push event apaszke/jax

Matthew Johnson

commit sha 2055d6ce421afd3dd8e278c5c26fdb6a5b3addeb

deflake


push time in 3 days

Pull request review comment google/jax

Initial version of vmap collectives

  def process_primitive(self, primitive, tracers, params):
    vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims_in):
      return primitive.bind(*vals_in, **params)
+    elif config.omnistaging_enabled and primitive in collective_rules:
+      axes_names = params['axis_name']
+      if not isinstance(axes_names, (tuple, list)):
+        axes_names = (axes_names,)
+      for i, axis_name in enumerate(axes_names):
+        frame = core.axis_frame(axis_name)
+        if frame.tag is self.master:

I know you explained this to me before, but I forget: is the tag just to handle possible shadowing? If so, after this came up in a convo with James and Roy last week, I've started to think we should make shadowing illegal, which might also help us slightly simplify this logic.

For now, just flagging this for more discussion!

apaszke

comment created time in 3 days

Pull request review comment google/jax

Initial version of vmap collectives

 def _psum_transpose_rule(cts, axis_name, axis_index_groups):
 pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
 ad.deflinear(psum_p, _psum_transpose_rule)
 pxla.multi_host_supported_collectives.add(psum_p)
+batching.collective_rules[psum_p] = \
+    lambda vals, dims, axis_size, **_: [v.sum(d) if d is not batching.not_mapped else

Currently we might need to lower this to a reduce-sum followed by a broadcast, as per this comment and this one. But also we might want to revise that, so that psum just means reduce-sum... I'm not sure what the best thing to do is, just wanted to flag this as a discussion topic!

apaszke

comment created time in 3 days

pull request comment google/jax

cumulative jet rules

Wow, awesome! Thanks for pushing on this.

(By the way, our GitHub CI tests have been really flaky for the last ~week, failing on dependency installation. So if you see dependency installation failures, assume they're not your fault!)

thisiscam

comment created time in 3 days

delete branch google/jax

delete branch : flax-omnistaging-bug

delete time in 3 days

push event google/jax

Matthew Johnson

commit sha d46ea969533dbfe4451517ac9a5e51dbdeee6d5d

add helper for flax to be omnistaging-compatible (#4004)


push time in 3 days

PR merged google/jax

add helper for flax to be omnistaging-compatible cla: yes
+9 -0

0 comment

1 changed file

mattjj

pr closed time in 3 days

PR opened google/jax

add helper for flax to be omnistaging-compatible
+9 -0

0 comment

1 changed file

pr created time in 3 days

create branch google/jax

branch : flax-omnistaging-bug

created branch time in 3 days

issue comment google/jax

conda-based installation

why is this still not officially supported?

Well, we have a small team, and so even though there's a lot of work worth doing, we have to choose what to prioritize.

I believe JAX works great with WSL, so if that works on your Windows setup you might want to give it a try.

@ericmjl I love your attitude! We've benefitted a huge amount from open-source contributions, and I hope we can get a lot more over time! We're all on the same team here, doing our best to push things forward.

ericmjl

comment created time in 4 days

issue closed google/jax

random uniform with dtype bfloat16 crashes on TPU backend

repro:

from jax import numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (3,), dtype=jnp.bfloat16)

Observations:
- crashes on TPU backends (both internal and Cloud TPU, as far as I can tell)
- CPU and GPU backends don't seem to crash

Expected: even if bfloat16 isn't supported by random.* on TPU, it would be better to raise an error than to crash.

closed time in 4 days

levskaya

issue comment google/jax

random uniform with dtype bfloat16 crashes on TPU backend

Thanks, @majnemer! I just double-checked on an internal TPU colab and indeed it works. Thanks to your team for the fix.

levskaya

comment created time in 4 days

issue closed google/jax

Error in matrix operation

code:



def gumbel_sample(log_probs, temperature=1.0):
    """Gumbel sampling from a categorical distribution."""
    u = numpy.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    g = -np.log(-np.log(u))
    return np.argmax(log_probs + g * temperature, axis=-1)

def predict(num_chars, prefix):
    inp = [ord(c) for c in prefix]
    result = [c for c in prefix]
    max_len = len(prefix) + num_chars
    for _ in range(num_chars):
        cur_inp = np.array(inp + [0] * (max_len - len(inp)))
        outp = model(cur_inp[None, :])  # Add batch dim.
        next_char = gumbel_sample(outp[0, len(inp)])
        inp += [int(next_char)]
       
        if inp[-1] == 1:
            break  # EOS
        result.append(chr(int(next_char)))
    
    return "".join(result)

print(predict(32, ""))

Error:

LayerError                                Traceback (most recent call last)
<ipython-input-27-6f9f9afc30e6> in <module>
     22     return "".join(result)
     23 
---> 24 print(predict(32, ""))

<ipython-input-27-6f9f9afc30e6> in predict(num_chars, prefix)
     12     for _ in range(num_chars):
     13         cur_inp = np.array(inp + [0] * (max_len - len(inp)))
---> 14         outp = model(cur_inp[None, :])  # Add batch dim.
     15         next_char = gumbel_sample(outp[0, len(inp)])
     16         inp += [int(next_char)]

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
    165       self.state = state  # Needed if the model wasn't fully initialized.
    166     state = self.state
--> 167     outputs, new_state = self.pure_fn(x, weights, state, rng)
    168     self.state = new_state
    169     self.weights = weights

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    448       name, trace = self._name, _short_traceback(skip=3)
    449       raise LayerError(name, 'pure_fn',
--> 450                        self._caller, signature(x), trace) from None
    451 
    452   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-13-6f7ffe0c061e>, line 21
  layer input shapes: ShapeDtype{shape:(1, 32), dtype:int32}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Embedding_256_512 (in pure_fn):
  layer created in file [...]/<ipython-input-13-6f7ffe0c061e>, line 18
  layer input shapes: ShapeDtype{shape:(1, 32), dtype:int32}

  File [...]/trax/layers/core.py, line 150, in forward
    return jnp.take(self.weights, x, axis=0)

  File [...]/jax/numpy/lax_numpy.py, line 3298, in take
    slice_sizes=tuple(slice_sizes))

  File [...]/jax/lax/lax.py, line 835, in gather
    slice_sizes=canonicalize_shape(slice_sizes))

  File [...]/site-packages/jax/core.py, line 273, in bind
    return self.impl(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 228, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)

  File [...]/jax/interpreters/xla.py, line 262, in xla_primitive_callable
    *avals, **params)

  File [...]/jax/interpreters/xla.py, line 320, in primitive_computation
    raise RuntimeError(msg) from e

RuntimeError: Invalid argument: Slice size at index 0 in gather op is out of range, must be within [0, 1), got 1.: 
This is a bug in JAX's shape-checking rules; please report it!

closed time in 4 days

Keshav15

issue comment google/jax

Error in matrix operation

No worries, we always welcome issues, and also welcome them fixing themselves :)

Keshav15

comment created time in 4 days

issue closed google/jax

jax.random.multivariate_normal produces incorrect output with batches of covariance matrices

The tensordot on this line does not perform the intended broadcast matrix-vector multiply. I think you need this (or equivalent),

return mean + np.einsum('...ij,...j->...i', chol_factor, normal_samples)

Here's an example:

import jax.numpy as np
import jax.random

mean = np.zeros((10, 4))
cov = np.eye(4)[None,...].repeat(10, axis=0)
rng = jax.random.PRNGKey(0)
sample = jax.random.multivariate_normal(rng, mean, cov)
print(sample.shape) # is (10, 10, 4); should be (10, 4).

closed time in 4 days

slinderman

issue comment google/jax

jax.random.multivariate_normal produces incorrect output with batches of covariance matrices

Merged the fix from Scott in #4002!

slinderman

comment created time in 4 days

push event google/jax

Scott Linderman

commit sha ea88c55f555780cdccc27c13d19fd393e695e120

Fixes and tests for jax.random.multivariate_normal (#4002)

* Fix bug #3997: change `jax.random.multivariate_normal` to handle batches of covariance matrices. It works as long as mean and covariance are broadcast-compatible, as specified in the docstring.
* Fix bug in multivariate_normal shape checking. Minor bug: should be checking for compatibility of `shape`, `mean`, and the last two dimensions of the _covariance_ matrix.
* Add test for multivariate_normal shapes. This test checks that `jax.random.multivariate_normal` produces the expected output shape for various combinations of event dimension and `mean`, `covariance`, and `shape` shapes.
* Fix linter issues in tests/random_test.py: trim trailing whitespace and respect the 80-char limit.
* Really trimming whitespace in tests/random_test.py. Arg. Have to fix my editor to do this automatically.


push time in 4 days

PR merged google/jax

Fixes and tests for jax.random.multivariate_normal cla: yes

This PR addresses bug https://github.com/google/jax/issues/3997 and another shape issue exposed during testing.

+27 -2

0 comment

2 changed files

slinderman

pr closed time in 4 days

issue comment google/jax

jax.random.multivariate_normal produces incorrect output with batches of covariance matrices

I left this bug in there just for you to find, Scott.

slinderman

comment created time in 5 days

issue comment google/jax

jax.random.multivariate_normal produces incorrect output with batches of covariance matrices

Yeah! Let's get @slinderman's first of many JAX PRs here!

slinderman

comment created time in 5 days

issue comment google/jax

any suggestions on how to improve performance for gradient step with odeint?

As a wild guess, not packing values together with jnp.stack (as in the return value of SEIRD_mobility_coupled) can help both compilation time and performance.

ibulu

comment created time in 5 days

issue comment google/jax

any suggestions on how to improve performance for gradient step with odeint?

Is it slow to execute (after compiling), or just to compile? That is, if you try evaluating the gradient a second time, is it faster on the second evaluation?

ibulu

comment created time in 5 days

issue comment google/jax

Defining both custom_jvp and custom_vjp

Thanks for the question!

No, that's not currently possible. It's documented in the custom derivatives tutorial (see the subsection "Forward-mode autodiff cannot be used on the jax.custom_vjp function and will raise an error" for more).
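(As a side note, and just as a minimal sketch rather than code from this thread: a custom_jvp rule by itself covers both modes, since JAX can transpose the linear JVP to get reverse mode.)

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    t, = tangents
    return f(x), jnp.cos(x) * t

print(jax.jvp(f, (1.0,), (1.0,)))  # forward mode works
print(jax.grad(f)(1.0))            # reverse mode works too, via transposing the JVP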

Can you say more about why you might need to define both?

thisiscam

comment created time in 5 days

issue comment google/jax

inside pmap, grad(lambda x: psum(loss(x))) inconsistent with jvp(lambda x: psum(loss(x)))

Okay, back to it!

Notice the Jacobian of the agg_loss function written with the broadcast is [[0, 1], [0, 1]]:

import jax.numpy as jnp
from jax import lax
from jax import jacfwd

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(0), (2,))
  return jacfwd(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))
# [[0. 1.]
#  [0. 1.]]

So, while keeping the current (at HEAD) bulk array definition of psum as a reduce-sum followed by a broadcast, the SPMD AD semantics is consistent so long as we take grad to mean "compute the VJP against a ones vector broadcast along all named axes":

import jax.numpy as jnp
from jax import vjp, jvp, pmap, grad
from jax import lax


### reverse-mode

# At HEAD, we define this SPMD program:
def f(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [0. 2.]

# To mean the same as this bulk array vjp-with-ones program (`grad` is always
# defined as vjp-with-ones plus an error check for scalar outputs that we don't
# include in the definition of SPMD semantics):
def grad2(f):
  def gradfun(x):
    ans, f_vjp = vjp(f, x)
    x_bar, = f_vjp(jnp.ones_like(ans))
    return x_bar
  return gradfun

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
  return grad2(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))
# [0. 2.]


# ### forward-mode

# At HEAD, we define this SPMD program:
def f(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return jvp(agg_loss, (w,), (1.,))[1]
print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [1. 1.]

# To mean the same as this bulk array jvp-with-ones program:
def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
  return jvp(agg_loss, (w,), (jnp.ones_like(w),))[1]
print(f(jnp.ones(2), jnp.arange(2)))
# [1. 1.]

(In Autograd, like TF today, we used to define grad as vjp-with-ones, but exactly this sort of confusion is why we made it raise an error for non-scalar outputs. Yet we didn't make that check work with SPMD functions, in the sense that grad will in effect happily allow broadcasting along named mapped axes!)

If the semantics at HEAD are self-consistent, except the error semantics for grad, do we need to change anything, except perhaps to avoid this potential confusion by making grad error semantics consistent in the positional and SPMD worlds?

Maybe yes. One problem with the current semantics is that if we make @sharadmv's use of grad here an error (rather than a vjp-with-ones) not only would that have been surprising to him, but also it would break pretty much all existing SPMD neural net training; they'd have to write vjp-with-ones themselves, e.g. by defining grad2 as above. Even then, the answers can be surprising: within the context of the SPMD function, it looks like we're calling grad/grad2 on a scalar-input scalar-output function (but for the closed-over value of data which is different on each device) with the same primal input value on every device, yet getting different grad2 results on different devices (perhaps not noticing that if we looked at the primal output value we'd also have a different value on each device, which might make getting different gradients less surprising).

In any case, while I now think the semantics at HEAD are actually consistent (modulo error semantics for grad), this example has clearly shown that they can be confusing, especially when closing over mapped data.

sharadmv

comment created time in 6 days

issue comment google/jax

inside pmap, grad(lambda x: psum(loss(x))) inconsistent with jvp(lambda x: psum(loss(x)))

I changed the issue title because I think the current (i.e. at HEAD) semantics make sense, but can make grad(f)(x) and jvp(f, (x,), (1.,)) differ for scalar-input scalar-output f when inside a pmap, which seems worth revising. The current semantics just take a different correspondence to bulk array programs than the one in @jekbradbury's previous comment.

The original intention with psum was to correspond to a bulk array operation that included broadcasting, not just a reduce-sum:

pmap(lambda x: psum(x, 'i'), axis_name='i')
==
lambda x: lax.broadcast(x.sum(0), (num_devices,))

So given this pmap code:

def f(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))

we should expect it to behave like this bulk array version:

import jax.numpy as jnp
from jax import grad
from jax import lax

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))
  return grad(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))

But that code triggers an error, because in api.py we check to ensure grad is only being applied to scalar-output functions. Just for defining semantics, let's sidestep that error check:

import jax.numpy as jnp
from jax import vjp
from jax import lax

def grad(f):
  def gradfun(x):
    ans, f_vjp = vjp(f, x)
    x_bar, = f_vjp(jnp.ones_like(ans))
    return x_bar
  return gradfun

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))
  return grad(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))
# prints [0. 2.]

In effect, we had to write our own grad that was willing to produce a broadcasted ones cotangent vector.

These SPMD semantics explain the reverse-mode answer in my first comment above.

The trouble with this definition, though, is that it can make grad disagree with jvp and numerical differencing on what appears locally to be a scalar-input scalar-output function when inside a pmap, as in the examples in my first comment. I want to write out an explanation for what's going on, but I've got to step away for a moment and wanted to send this comment first. To be continued!

In any case, I think @jekbradbury 's proposal for revising the semantics is likely to be better. I just want to pin down both the old and new semantics as best we can.

sharadmv

comment created time in 6 days

push event google/jax

Matthew Johnson

commit sha 161e81cebed0432c7d211ca33a837aadc2d7d0c0

support collectives with new pmap semantics

Co-authored-by: James Bradbury <jekbradbury@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>


push time in 6 days

create branch google/jax

branch : avals-with-names

created branch time in 7 days

issue comment google/jax

`grad(lambda x: psum(loss(x)))` computes incorrect gradients

@jekbradbury that's brilliant! Let's do it.

sharadmv

comment created time in 7 days

pull request comment google/jax

Rm two unused lines of code from lax_parallel.psum_bind

This was a bad merge on my part! I believe this passes internal tests.

j-towns

comment created time in 7 days

issue comment google/jax

Prevent statements with side effects from being re-ordered

Yeah that's my thinking too!

dionhaefner

comment created time in 7 days

issue comment google/jax

Prevent statements with side effects from being re-ordered

The omnistaging change ensures that we trace the Python to a jaxpr in a way that preserves operation ordering, but I suspect we still need to thread tokens in the XLA HLO lowering. That is, JAX now has the tools to avoid reordering without user-level tokens, but XLA may still reorder things.

@hawkinsp does that sound right to you?

dionhaefner

comment created time in 7 days

issue comment google/jax

Very slow compile when doing stochastic variational inference on the parameters of an ODE integrator

Thanks for raising this; that's too slow! Which backend is this on (CPU/GPU/TPU)? It might be faster or slower to compile on different backends (often slowest on CPU, fastest on TPU).

ahmadsalim

comment created time in 7 days

issue comment google/jax

Cannot use static scalar arrays in comparison when using omnistaging

I wonder if this is an interaction between the laziness of np.ones and omnistaging...

dionhaefner

comment created time in 7 days

issue comment google/jax

Implement complex QR decomposition in HLO (TPU)

GE and LT are "greater than or equal" and "less than", respectively. That error is saying that those operations aren't implemented for complex numbers. (NumPy implements them with a kind of weird convention, since the complex numbers don't have a natural total ordering.)

mganahl

comment created time in 7 days

issue comment google/jax

`grad(lambda x: psum(loss(x)))` computes incorrect gradients

I'm starting to agree with James that we might need to track spmd-symmetry in jaxprs, and reverse mode should do something interesting with that information. It's tricky: if we have a jaxpr to transpose with both symmetric and non-symmetric outputs, we may need to run two backward passes, one for the symmetric outputs (where we symmetrize the resulting input cotangents) and one for the non-symmetric outputs (where we don't).

sharadmv

comment created time in 7 days

issue comment google/jax

`grad(lambda x: psum(loss(x)))` computes incorrect gradients

Thanks! One possible desideratum is not to do something different on guaranteed-symmetric values versus just-happens-to-be-symmetric values. Notice that in this example if data just happens to be symmetric, we get the right answer. Also, if we noticed (say in backward_pass) that the function to be transposed is symmetric (in the sense that when it's given a symmetric input it produces a symmetric output), and in that case always symmetrized the cotangents of the jaxpr inputs, then we'd compute the right answer in both guaranteed-symmetric (new!) and just-happens-to-be-symmetric (already got that one) cases.

sharadmv

comment created time in 7 days

issue comment google/jax

`grad(lambda x: psum(loss(x)))` computes incorrect gradients

Can you define 'replicated thing' and 'non-replicated thing'? Do you mean a function (agg_loss in this case), and if so, what does it mean to be replicated (maybe: has any collectives in it)?

sharadmv

comment created time in 7 days

issue comment google/jax

`grad(lambda x: psum(loss(x)))` computes incorrect gradients

This is confusing! It seems to be an issue about how we define transposes of SPMD functions (i.e. it's not just a dumb software bug). No conclusions yet, but wanted to dump some thoughts here.

Here are some test programs:

import jax.numpy as jnp
from jax import jvp, vmap, pmap, grad, lax

def distributed(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return jvp(agg_loss, (w,), (1.,))[1]
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1. 1.]

def distributed(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return (agg_loss(w+1e-3) - agg_loss(w)) / 1e-3
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1.0000466 1.0000466]

def distributed(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [0. 2.]

The jaxpr being transposed in the last one is

{ lambda c ; b.
  let d = mul b c
      e = psum[ axis_index_groups=None
                axis_name=batch ] d
  in (e,) }

Operationally, the transpose evaluation proceeds like this: grad feeds in a cotangent for e of 1.0. Then we compute a cotangent for d of 2.0, reflecting the fact that if we perturb the value of d by epsilon (on each replica) then the value of e changes by 2 epsilon. Then we compute a cotangent for b by multiplying by 0. on the first replica and 1. on the second replica, leading to the final result [0. 2.].

Beyond the numbers, we have a symmetry issue. The first two versions produce a symmetric-across-replicas result, which makes sense because agg_loss's last operation is a psum and so its output must be symmetric across replicas. But the reverse-mode version can't produce a symmetric result because it multiplies by the mapped value data at the end. (The symmetric result makes the most sense because agg_loss always produces a symmetric result given a symmetric input.)

Hmmmm...

sharadmv

comment created time in 7 days

issue opened google/jax

improve jnp.tile lowering to broadcast/reshape

created time in 7 days

Pull request review comment google-research/dex-lang

Expand raytracer example

[review diff context omitted: the hunk covers the expanded Dex raytracer example (scene and surface types, sdScene/raymarch, light sampling, trace, and takePicture), with this review comment anchored on the line `k = ixkey (ixkey rootKey i) j`]

This should be called "the David trick" because he always suggests it and it always works.

dougalm

comment created time in 7 days

issue comment google/jax

Slow Python tracing time with odeint inside tfp.optimizer.lbfgs_minimize

Copying from that comment on 3370 for convenience, I think we just need to memoize this line to avoid retracing.

tfrerix

comment created time in 8 days

issue comment google/jax

Bad interaction between batch trace and custom jvp trace

Thanks for flagging this, and the great repros. I think this is a duplicate of a known bug, which I should look up and link here. The basic issue is that nondiff_argnums works basically like lexical closure over tracers, but that check is disallowing lexical closure over tracers!

@NeilGirdhar has a nice solution, and I believe a PR that I still haven't gotten to.

@brianwa84 how blocking is this for your work? I'd like to get to it this week, but I'm just trying to prioritize stuff. Let me know :)

brianwa84

comment created time in 8 days

delete branch google/jax

delete branch : omnistaging3

delete time in 8 days

issue comment google/jax

Performance regression in compilation of JITed Jacobian

Thanks for letting us know about this! (By the way, unfortunately 64bit matmul on CPU can be slow to execute too: #3832.)

@hawkinsp any thoughts? Is this one for XLA:CPU folks?

ahoenselaar

comment created time in 8 days

pull request comment google/jax

Cleanup: update license copyrights to 2020

I think copyrights are meant to be from when the file was written, and so they shouldn't be uniformly bumped to the current year.

jakevdp

comment created time in 9 days

push event google/jax

Adrià Puigdomènech

commit sha 4e873f417ab4f3e68a102548167f2b0a005edfad

Avoid re-flattening in jit() when no donate_argnums are present. (#3945)

Following the same special-casing of static_argnums, this should provide a speedup, especially when the number of arguments provided is large.


push time in 10 days

pull request comment google/jax

Avoid re-flattening in jit() and pmap() when no donate_argnums are present.

Thank you!

kosklain

comment created time in 10 days

PR merged google/jax

Avoid re-flattening in jit() and pmap() when no donate_argnums are present. cla: yes

Following the same special-casing of static_argnums, this should provide a speedup, especially when the number of arguments provided is large.

+8 -2

0 comment

1 changed file

kosklain

pr closed time in 10 days

pull request comment google/jax

Avoid lexically capturing the train_images value in MNIST VAE example.

If I understand @froystig correctly, #3238 would have the nice effect of turning this closed-over constant into a buffer kept on device. So that's pretty cool.

Seems like we should add device_put to this PR and merge it, and also follow up on #3238 (which has bigger picture considerations, but also might let us simplify this example back down again if we want to).

hawkinsp

comment created time in 10 days
