Matthew Johnson (mattjj) · Google · San Francisco · people.csail.mit.edu/~mattjj · research scientist @ Google Brain

google/jax 10446

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

google-research/dex-lang 701

Research language for array processing in the Haskell/ML family

mattjj/autodidact 541

A pedagogical implementation of Autograd

google/xls 450

XLS: Accelerated HW Synthesis

duvenaud/relax 137

Optimizing control variates for black-box gradient estimation

jacobjinkelly/easy-neural-ode 100

Code for the paper "Learning Differential Equations that are Easy to Solve"

google-research/autoconj 35

Recognizing and exploiting conjugacy without a domain-specific language

duvenaud/jaxde 27

Prototypes of differentiable differential equation solvers in JAX.

mattjj/config-fish 4

my fish configuration

mattjj/config-vim 4

my .vim

PR merged google/jax

Cleanup: remove unnecessary rng_factory boilerplate in lax_test.py cla: yes pull ready

Similar to #5086

+203 -289

0 comments

1 changed file

jakevdp

pr closed time in 2 hours

push event google/jax

Jake VanderPlas

commit sha 3ea15b45410e36e40277e78a92e4864b69a2e0a1

Cleanup: remove unnecessary rng_factory boilerplate in lax_test.py

view details

jax authors

commit sha f0b13c19abaacae089c4d5ef912bb46996b1c41a

Merge pull request #5093 from jakevdp:cleanup-boilerplate PiperOrigin-RevId: 345590163

view details

push time in 2 hours

pull request comment google/jax

added prepend and append to diff

A sync & rebase ~an hour ago worked for #5060. Would you mind trying again?

saklani

comment created time in 2 hours

issue comment google/jax

Switch default PRNG to ThreeFry4x32

https://github.com/google/jax/issues/2294 might be related? It would allow you to make this change more transparently since no code would assume a key size; it would be hidden behind the opaque object.

hawkinsp

comment created time in 2 hours
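
To make the opaque-key idea above concrete, here is a hypothetical sketch (this was not JAX's API at the time; all names here are invented for illustration):

from dataclasses import dataclass
import numpy as np

# Hypothetical sketch: an opaque key type hides the raw key buffer, so a
# switch from ThreeFry2x32 to ThreeFry4x32 changes only the internals and
# never a user-visible shape.
@dataclass(frozen=True)
class OpaqueKey:
  _data: np.ndarray  # e.g. shape (2,) for ThreeFry2x32, (4,) for 4x32
  _impl: str         # name of the underlying PRNG implementation

def make_key(seed: int, impl: str = "threefry2x32") -> OpaqueKey:
  width = {"threefry2x32": 2, "threefry4x32": 4}[impl]
  data = np.zeros(width, dtype=np.uint32)
  data[-1] = np.uint32(seed & 0xFFFFFFFF)
  return OpaqueKey(data, impl)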

PR merged google/jax

Implement statistic pad mode in jax.numpy.pad cla: yes pull ready

Implemented maximum, mean, median, and minimum padding modes.

Related #5010

+115 -21

9 comments

2 changed files

minoring

pr closed time in 2 hours
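
For a sense of what the PR above enables, a minimal usage sketch; the semantics follow numpy.pad's documented statistic modes:

import jax.numpy as jnp

x = jnp.array([1., 2., 3., 4.])
# Pad two values on each side using the mean of the whole array (2.5).
jnp.pad(x, 2, mode="mean")     # [2.5, 2.5, 1., 2., 3., 4., 2.5, 2.5]
# Pad one value on each side using the array maximum (4.0).
jnp.pad(x, 1, mode="maximum")  # [4., 1., 2., 3., 4., 4.]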

push event google/jax

minoring

commit sha 327ac8dd149e85c7a512eff1ba9f6da0b11bfc0f

Implement statistic pad mode in jax.numpy.pad

view details

jax authors

commit sha c7d19fd861bd8ce931ed38d68e2b959e95684ce7

Merge pull request #5060 from minoring:support-padding-modes PiperOrigin-RevId: 345588715

view details

push time in 2 hours

issue comment google/jax

Improve __str__/__repr__ of PyTreeDefs

This would be really useful. I actually had to improve this for my code too since my PyTrees are really complicated. I frequently have to diff PyTreeDefs returned when the custom VJP complains that they don't match. It might also be nice to consider displaying PyTreeDef on multiple lines, with sub-elements indented.

j-towns

comment created time in 2 hours
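
For context, a minimal illustration of the repr being discussed (the exact text varies across JAX versions):

import jax

# A small but nested pytree; real training states are much deeper.
params = {"layer1": {"w": [1.0, 2.0], "b": 3.0}, "layer2": (4.0, 5.0)}
print(jax.tree_util.tree_structure(params))
# One line, e.g.: PyTreeDef({'layer1': {'b': *, 'w': [*, *]}, 'layer2': (*, *)})
# For deeply nested trees this gets hard to diff by eye, which is what a
# multi-line, indented repr would address.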

pull request comment google/jax

Implement statistic pad mode in jax.numpy.pad

Syncing to master worked! Thanks.

minoring

comment created time in 3 hours

issue comment google/jax

Using non-nvidia GPU

@hawkinsp @inailuig

Thank you for trying out JAX on AMD GPUs. I am on the TF framework team at AMD and would like to get a better understanding of the TF changes that are required to get JAX working. We would be more than happy to help out.

I also had a question for you: does JAX have unit tests that run on GPUs, and if so, can you point me to directions for running them? I would like to get them running internally on our platform.

thanks again

deven

ricardobarroslourenco

comment created time in 4 hours

created repository til-unc/mhcseqs

MHC sequences

created time in 4 hours

PR merged google/jax

Add test for chkstk_darwin symbol to jaxlib Mac builds. cla: yes pull ready

We don't know why some builds produce this and others do not, but we can at least test for it to prevent bad releases.

Issue #3867

+23 -0

0 comments

1 changed file

hawkinsp

pr closed time in 5 hours

push event google/jax

Peter Hawkins

commit sha 4a774978a2afae07eed50dd320aeb089f660a656

Add test for chkstk_darwin symbol to jaxlib Mac builds. We don't know why some builds produce this and others do not, but we can at least test for it to prevent bad releases.

view details

jax authors

commit sha 6f3a0c7bbe795ff0418fdcc0bba06165015c95e0

Merge pull request #5092 from hawkinsp:jaxlib PiperOrigin-RevId: 345560236

view details

push time in 5 hours

issue comment google/jax

Tensorboard profiling fails with jax

Hi @tomhennigan, will you be able to list the steps you've taken so far?

For example:

  1. Update/uninstall CUDA (in Colab):
!apt-get --purge remove cuda nvidia* libnvidia-*
!dpkg -l | grep cuda- | awk '{print $2}' | xargs -n1 dpkg --purge
!apt-get remove cuda-*
!apt autoremove
!apt-get update
  2. Install CUDA 11.1:
# Check that your version of Ubuntu on Colab is 18.04
!lsb_release -a
# Download CUDA 11.1 for Ubuntu 18.04
# (source: https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&target_distro=Ubuntu&target_version=1804&target_type=deblocal)
!wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin
!sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600
!wget https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda-repo-ubuntu1804-11-1-local_11.1.1-455.32.00-1_amd64.deb
!sudo dpkg -i cuda-repo-ubuntu1804-11-1-local_11.1.1-455.32.00-1_amd64.deb
!sudo apt-key add /var/cuda-repo-ubuntu1804-11-1-local/7fa2af80.pub
!sudo apt-get update
!sudo apt-get -y install cuda
  3. Check the CUDA version:
!nvcc --version
  4. Install cuDNN 8.0.6 by following https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux (requires being logged in to download).

  5. Install TensorFlow nightly and TensorBoard profiler nightly (Source: https://github.com/google/jax/issues/4427):

!pip install --upgrade tf-nightly tbp-nightly
  6. Import JAX and JAX Profiler:
import jax
import jax.profiler

etc.

tomweingarten

comment created time in 6 hours

pull request comment google/jax

added prepend and append to diff

Yeah, I'm not sure what's going on. The same errors are blocking #5060 as well.

saklani

comment created time in 8 hours

issue closed google/jax

jax.scipy.special.multigammaln fails under jit

Brief example below.

The need for d to be static is an implementation detail; the shape of the function result does not depend on the value of d.

Easiest fix would be to have better errors for a traced d.

Harder fix (and maybe not necessary) would be to implement the function in a way such that d need not be static.

from jax.scipy.special import multigammaln
from jax import jit
jit(multigammaln)(1, 2)
Traceback (most recent call last):
  File "tmp.py", line 4, in <module>
    jit(multigammaln)(1, 2)
  File "/Users/vanderplas/github/google/jax/jax/_src/scipy/special.py", line 162, in multigammaln
    lax.div(jnp.arange(d), _constant_like(a, 2))),
  File "/Users/vanderplas/github/google/jax/jax/_src/numpy/lax_numpy.py", line 2749, in arange
    start = require(start, msg("stop"))
jax._src.traceback_util.FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

It arose in jax.numpy.arange argument `stop`.

While tracing the function multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154, this concrete value was not available in Python because it depends on the value of the arguments to multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154 at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "tmp.py", line 4, in <module>
    jit(multigammaln)(1, 2)
  File "/Users/vanderplas/github/google/jax/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/vanderplas/github/google/jax/jax/api.py", line 218, in f_jitted
    out = xla.xla_call(
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 1226, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 1217, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 1229, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 595, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/xla.py", line 569, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/vanderplas/github/google/jax/jax/linear_util.py", line 251, in memoized_fun
    ans = call(fun, *args)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/xla.py", line 645, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/partial_eval.py", line 1230, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/vanderplas/github/google/jax/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/vanderplas/github/google/jax/jax/_src/scipy/special.py", line 162, in multigammaln
    lax.div(jnp.arange(d), _constant_like(a, 2))),
  File "/Users/vanderplas/github/google/jax/jax/_src/numpy/lax_numpy.py", line 2749, in arange
    start = require(start, msg("stop"))
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 919, in concrete_or_error
    raise_concretization_error(val, context)
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 896, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

It arose in jax.numpy.arange argument `stop`.

While tracing the function multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154, this concrete value was not available in Python because it depends on the value of the arguments to multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154 at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>

closed time in 9 hours

jakevdp

issue comment google/jax

jax.scipy.special.multigammaln fails under jit

I'm going to close this: I think requiring static d is fine, and #5074 improved the error message. In the scipy version, only an integer value is accepted.

jakevdp

comment created time in 9 hours
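
The static_argnums workaround that the improved error message points to, as a minimal sketch (assuming the two-argument signature multigammaln(a, d)):

from jax import jit
from jax.scipy.special import multigammaln

# Mark d (positional argument 1) as static so jnp.arange(d) sees a
# concrete Python int at trace time instead of an abstract tracer.
multigammaln_jit = jit(multigammaln, static_argnums=1)
print(multigammaln_jit(1.0, 2))  # works; note each new value of d recompiles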

PR merged google/jax

[jax2tf] Add note that the TF.js conversion is experimental. cla: yes pull ready

+4 -2

0 comments

2 changed files

bchetioui

pr closed time in 9 hours

push event google/jax

Benjamin Chetioui

commit sha e4c5be94a3c1f07b4d562ac2ba1d0c1866505a63

[jax2tf] Add note that the TF.js conversion is experimental.

view details

jax authors

commit sha 7f6f5f2f4e1989e38247a0e37144bea5b2d72d16

Merge pull request #4997 from bchetioui:tfjs_exp PiperOrigin-RevId: 345516398

view details

push time in 9 hours

pull request comment google/jax

added prepend and append to diff

Yeah, I gave it a shot. Let's see what happens.

saklani

comment created time in 9 hours

Pull request review comment google/jax

make convert_element_type_p not require old_dtype

  def testBroadcastInDim(self):
    self._Check(make_const, expected)

  def testConvertElementTypeMismatchedDTypeOldDType(self):

I think this one is pretty clearly obsolete at this point! I’d be fine removing it.

mattjj

comment created time in 9 hours

fork yifeif/ray

A system for parallel and distributed Python that unifies the ML ecosystem.

https://ray.readthedocs.io/en/latest/

fork in 9 hours

PR merged google/jax

Cleanup: remove unnecessary rng_factory boilerplate in lax_numpy_test cla: yes pull ready

This was previously the accepted style for specifying tests; now we generally specify rngs directly.

+185 -233

0 comments

1 changed file

jakevdp

pr closed time in 9 hours

push event google/jax

Jake VanderPlas

commit sha 01fbf780b1110f00314b5ce0cf78a4c3cd284320

Cleanup: remove unnecessary rng_factory boilerplate in lax_numpy_test

view details

jax authors

commit sha 7147d272bd24ceba58d50fee5805503a309c67e4

Merge pull request #5086 from jakevdp:test-rng-cleanup PiperOrigin-RevId: 345508883

view details

push time in 9 hours

PR merged google/jax

improve sinc jvp at zero cla: yes pull ready

Thanks to @dougalm for thinking this through with me!

The basic problem in #5054 is that we had for sinc an expression like

  return where(x == 0, 1., sin(pi * x) / (pi * x))

One way to read that is that at x == 0 we're replacing the function with its zeroth-order Taylor expansion. But that meant that at zero we weren't computing its correct derivatives! (We were getting lucky with the first derivative at zero: it happens to be zero because that's the slope of the sinc function at zero, but we were computing zero because we were differentiating a constant function.)

One way to think about the fix here is that we're replacing the constant function at zero with a truncated Taylor series at zero. For any fixed Taylor series truncation level, we'd only be able to support some finite order of autodiff before incorrectly truncating the derivatives to zero. So we play a trick: we use a custom_jvp rule to in effect generate all the Taylor series coefficients we need, lazily as we differentiate.

fixes #5054

+49 -4

3 comments

2 changed files

mattjj

pr closed time in 10 hours
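
In sketch form, the lazy-Taylor trick in the PR above looks roughly like the following, simplified to f(x) = sin(x)/x (the merged change also handles the pi scaling in sinc and dtype promotion):

from functools import partial
import jax
import jax.numpy as jnp

# Differentiating _taylor_at_zero(k, x) produces _taylor_at_zero(k + 1, x),
# so each application of grad lazily materializes one more Maclaurin
# coefficient instead of truncating the series at a fixed order.
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def _taylor_at_zero(k, x):
  # k-th derivative of sin(x)/x at x == 0: odd orders vanish, and the
  # even orders are (-1)**(k//2) / (k + 1).
  if k % 2:
    return jnp.zeros_like(x)
  return jnp.full_like(x, (-1) ** (k // 2) / (k + 1))

@_taylor_at_zero.defjvp
def _taylor_at_zero_jvp(k, primals, tangents):
  (x,), (t,) = primals, tangents
  return _taylor_at_zero(k, x), _taylor_at_zero(k + 1, x) * t

def sinc_unnormalized(x):
  safe_x = jnp.where(x == 0, 1.0, x)
  return jnp.where(x == 0, _taylor_at_zero(0, x),
                   jnp.sin(safe_x) / safe_x)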

push event google/jax

Matthew Johnson

commit sha 8cd855507179e21bfdc0e518025bf8957fc8e342

improve sinc jvp at zero, fixes #5054

view details

Matthew Johnson

commit sha e6c01aeacd70064bd3e6a160892ee51a713c1cda

fix broadcasting bug in sinc jvp rule

view details

jax authors

commit sha fa573ef07b28c9e44ad008fc3cbea60a57d2c07a

Merge pull request #5077 from google:sinc-jvp PiperOrigin-RevId: 345505504

view details

push time in 10 hours

issue closed google/jax

Add custom JVP rule for jnp.sinc

The current implementation has incorrect even-ordered derivatives at x=0.0:

import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import grad, vmap

x = jnp.linspace(-5, 5, 101)
y = vmap(grad(grad(jnp.sinc)))(x)

plt.plot(x, y)

(figure: plot of the second derivative of jnp.sinc, showing an incorrect value at x = 0)

Related to #5039

closed time in 10 hours

jakevdp

Pull request review comment google/jax

make convert_element_type_p not require old_dtype

 def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None):

 ad.defjvp_zero(lt_p)

-def _convert_element_type_shape_rule(operand, *, new_dtype, old_dtype):
+def _convert_element_type_shape_rule(operand, *, new_dtype):

What would you think about using dtype rather than new_dtype as the argument name?

mattjj

comment created time in 10 hours

Pull request review comment google/jax

make convert_element_type_p not require old_dtype

  def testBroadcastInDim(self):
    self._Check(make_const, expected)

  def testConvertElementTypeMismatchedDTypeOldDType(self):

delete?

mattjj

comment created time in 10 hours

issue comment google/jax

numerical issues with 4th deriv of sinc near zero

Interestingly, the numerical issues appear unrelated to the Maclaurin trick added in #5077:

import jax.numpy as jnp
from jax import vmap, grad
import matplotlib.pyplot as plt

def sinc(x):
  x = jnp.pi * jnp.asarray(x)
  safe_x = jnp.where(x == 0, 1, x)
  return jnp.where(x == 0, 0, jnp.sin(safe_x) / safe_x)

xs = jnp.linspace(-5, 5, 1001)
plt.plot(xs, vmap(grad(sinc))(xs))
plt.plot(xs, vmap(grad(grad(sinc)))(xs))
plt.plot(xs, vmap(grad(grad(grad(sinc))))(xs))
plt.plot(xs, vmap(grad(grad(grad(grad(sinc)))))(xs))

(figure: plots of the first through fourth derivatives of sinc, showing numerical noise near x = 0)

mattjj

comment created time in 10 hours

pull request comment google/jax

Implement statistic pad mode in jax.numpy.pad

We have some internal test failures that I believe would be remedied by syncing to master and force-pushing to this branch. Are you comfortable doing that?

minoring

comment created time in 10 hours
