google/jax 10446
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Research language for array processing in the Haskell/ML family
A pedagogical implementation of Autograd
google/xls 450
XLS: Accelerated HW Synthesis
duvenaud/relax 137
Optimizing control variates for black-box gradient estimation
jacobjinkelly/easy-neural-ode 100
Code for the paper "Learning Differential Equations that are Easy to Solve"
Recognizing and exploiting conjugacy without a domain-specific language
Prototypes of differentiable differential equation solvers in JAX.
my fish configuration
my .vim
push event google/jax
commit sha 3ea15b45410e36e40277e78a92e4864b69a2e0a1
Cleanup: remove unnecessary rng_factory boilerplate in lax_test.py
commit sha f0b13c19abaacae089c4d5ef912bb46996b1c41a
Merge pull request #5093 from jakevdp:cleanup-boilerplate PiperOrigin-RevId: 345590163
push time in 2 hours
pull request comment google/jax
added prepend and append to diff
A sync & rebase ~an hour ago worked for #5060. Would you mind trying again?
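For reference, a usage sketch of the feature under review, assuming it mirrors numpy.diff's prepend/append semantics:
import jax.numpy as jnp
x = jnp.array([1, 3, 6])
print(jnp.diff(x, prepend=0))   # [1 2 3]: differences of [0, 1, 3, 6]
print(jnp.diff(x, append=10))   # [2 3 4]: differences of [1, 3, 6, 10]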
comment created time in 2 hours
issue comment google/jax
Switch default PRNG to ThreeFry4x32
https://github.com/google/jax/issues/2294 might be related? It would allow you to make this change more transparently since no code would assume a key size; it would be hidden behind the opaque object.
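For illustration, a hypothetical sketch of the opaque-key idea; the OpaqueKey wrapper below is illustrative, not a JAX API:
from dataclasses import dataclass
import jax.numpy as jnp
import jax.random as random

@dataclass(frozen=True)
class OpaqueKey:
    # Implementation-defined backing data, e.g. uint32[2] for ThreeFry4x32;
    # callers never see its shape or dtype, so the PRNG can change freely.
    _data: jnp.ndarray

    def split(self, num=2):
        return tuple(OpaqueKey(k) for k in random.split(self._data, num))

key = OpaqueKey(random.PRNGKey(0))
k1, k2 = key.split()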
comment created time in 2 hours
PR merged google/jax
Implemented maximum, mean, median, and minimum padding modes.
Related #5010
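A quick usage sketch of the new modes, assuming they mirror numpy.pad's statistic modes:
import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(jnp.pad(x, 2, mode="mean"))     # pads both ends with the mean, 2.5
print(jnp.pad(x, 2, mode="maximum"))  # pads with the maximum, 4.0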
pr closed time in 2 hours
push event google/jax
commit sha 327ac8dd149e85c7a512eff1ba9f6da0b11bfc0f
Implement statistic pad mode in jax.numpy.pad
commit sha c7d19fd861bd8ce931ed38d68e2b959e95684ce7
Merge pull request #5060 from minoring:support-padding-modes PiperOrigin-RevId: 345588715
push time in 2 hours
issue comment google/jax
Improve __str__/__repr__ of PyTreeDefs
This would be really useful. I've had to improve this for my own code too, since my PyTrees are quite complicated: I frequently have to diff the PyTreeDefs returned when a custom VJP complains that they don't match. It might also be nice to display PyTreeDefs across multiple lines, with sub-elements indented.
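For illustration, the kind of diff I mean, via the public jax.tree_util API (the line-diff presentation here is just a sketch):
import difflib
import jax
t1 = jax.tree_util.tree_structure({"a": [1, 2], "b": 3})
t2 = jax.tree_util.tree_structure({"a": [1, 2, 3], "b": 3})
print("\n".join(difflib.unified_diff([str(t1)], [str(t2)], lineterm="")))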
comment created time in 2 hours
pull request comment google/jax
Implement statistic pad mode in jax.numpy.pad
Syncing to master worked! Thanks.
comment created time in 3 hours
issue comment google/jax
@hawkinsp @inailuig
Thank you for trying out JAX on AMD GPUs. I'm on the TF framework team at AMD and would like to get a better understanding of the TF changes required to get JAX working. We would be more than happy to help out.
I also had a question: does JAX have unit tests that run on GPUs, and if so, can you point me to instructions for running them? I would like to get them running internally on our platform.
Thanks again,
Deven
comment created time in 4 hours
PR merged google/jax
We don't know why some builds produce this and others do not, but we can at least test for it to prevent bad releases.
Issue #3867
pr closed time in 5 hours
push event google/jax
commit sha 4a774978a2afae07eed50dd320aeb089f660a656
Add test for chkstk_darwin symbol to jaxlib Mac builds. We don't know why some builds produce this and others do not, but we can at least test for it to prevent bad releases.
commit sha 6f3a0c7bbe795ff0418fdcc0bba06165015c95e0
Merge pull request #5092 from hawkinsp:jaxlib PiperOrigin-RevId: 345560236
push time in 5 hours
issue comment google/jax
Tensorboard profiling fails with jax
Hi @tomhennigan, could you list the steps you've taken so far?
For example:
- Update/uninstall CUDA (in Colab):
!apt-get --purge remove cuda nvidia* libnvidia*
!dpkg -l | grep cuda | awk '{print $2}' | xargs -n1 dpkg --purge
!apt-get remove cuda*
!apt autoremove
!apt-get update
- Install CUDA 11.1:
# Check that your version of Ubuntu on Colab is 18.04
!lsb_release -a
# Download CUDA 11.1 for Ubuntu 18.04
# (source: https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&target_distro=Ubuntu&target_version=1804&target_type=deb_local)
!wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin
!sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600
!wget https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda-repo-ubuntu1804-11-1-local_11.1.1-455.32.00-1_amd64.deb
!sudo dpkg -i cuda-repo-ubuntu1804-11-1-local_11.1.1-455.32.00-1_amd64.deb
!sudo apt-key add /var/cuda-repo-ubuntu1804-11-1-local/7fa2af80.pub
!sudo apt-get update
!sudo apt-get -y install cuda
- Check the CUDA version:
!nvcc --version
- Install cuDNN 8.0.6 by following https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#install-linux (requires being logged in to download).
- Install TensorFlow nightly and TensorBoard profiler nightly (source: https://github.com/google/jax/issues/4427):
!pip install --upgrade tf-nightly tbp-nightly
- Import JAX and the JAX profiler:
import jax
import jax.profiler
etc.
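For reference, a minimal sketch of the profiler-server step that TensorBoard's profile plugin connects to, per the jax.profiler docs (the matmul workload is just an example):
import jax.numpy as jnp
import jax.profiler
server = jax.profiler.start_server(9999)  # TensorBoard captures from this port
x = jnp.ones((1000, 1000))
(x @ x).block_until_ready()  # run something worth profiling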
comment created time in 6 hours
pull request comment google/jax
added prepend and append to diff
Yeah, I'm not sure what's going on. The same errors are blocking #5060 as well.
comment created time in 8 hours
issue closed google/jax
jax.scipy.special.multigammaln fails under jit
Brief example below. The need for `d` to be static is an implementation detail; the shape of the function's result does not depend on the value of `d`. The easiest fix would be a better error for a traced `d`. A harder fix (and maybe not necessary) would be to implement the function in a way such that `d` need not be static.
from jax.scipy.special import multigammaln
from jax import jit
jit(multigammaln)(1, 2)
Traceback (most recent call last):
File "tmp.py", line 4, in <module>
jit(multigammaln)(1, 2)
File "/Users/vanderplas/github/google/jax/jax/_src/scipy/special.py", line 162, in multigammaln
lax.div(jnp.arange(d), _constant_like(a, 2))),
File "/Users/vanderplas/github/google/jax/jax/_src/numpy/lax_numpy.py", line 2749, in arange
start = require(start, msg("stop"))
jax._src.traceback_util.FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
It arose in jax.numpy.arange argument `stop`.
While tracing the function multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154, this concrete value was not available in Python because it depends on the value of the arguments to multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154 at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "tmp.py", line 4, in <module>
jit(multigammaln)(1, 2)
File "/Users/vanderplas/github/google/jax/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/vanderplas/github/google/jax/jax/api.py", line 218, in f_jitted
out = xla.xla_call(
File "/Users/vanderplas/github/google/jax/jax/core.py", line 1226, in bind
return call_bind(self, fun, *args, **params)
File "/Users/vanderplas/github/google/jax/jax/core.py", line 1217, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/vanderplas/github/google/jax/jax/core.py", line 1229, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/vanderplas/github/google/jax/jax/core.py", line 595, in process_call
return primitive.impl(f, *tracers, **params)
File "/Users/vanderplas/github/google/jax/jax/interpreters/xla.py", line 569, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/Users/vanderplas/github/google/jax/jax/linear_util.py", line 251, in memoized_fun
ans = call(fun, *args)
File "/Users/vanderplas/github/google/jax/jax/interpreters/xla.py", line 645, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
File "/Users/vanderplas/github/google/jax/jax/interpreters/partial_eval.py", line 1230, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/Users/vanderplas/github/google/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/Users/vanderplas/github/google/jax/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/vanderplas/github/google/jax/jax/_src/scipy/special.py", line 162, in multigammaln
lax.div(jnp.arange(d), _constant_like(a, 2))),
File "/Users/vanderplas/github/google/jax/jax/_src/numpy/lax_numpy.py", line 2749, in arange
start = require(start, msg("stop"))
File "/Users/vanderplas/github/google/jax/jax/core.py", line 919, in concrete_or_error
raise_concretization_error(val, context)
File "/Users/vanderplas/github/google/jax/jax/core.py", line 896, in raise_concretization_error
raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
It arose in jax.numpy.arange argument `stop`.
While tracing the function multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154, this concrete value was not available in Python because it depends on the value of the arguments to multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154 at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
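As the error message suggests, a minimal workaround sketch marks `d` as static via `static_argnums`, at the cost of one compilation per distinct value of `d`:
from jax import jit
from jax.scipy.special import multigammaln
f = jit(multigammaln, static_argnums=1)  # treat the second argument, d, as static
print(f(1.0, 2))  # now traces with a concrete d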
closed time in 9 hours
jakevdp issue comment google/jax
jax.scipy.special.multigammaln fails under jit
I'm going to close this: I think requiring a static `d` is fine, and #5074 improved the error message. In the scipy version, only an integer value is accepted.
comment created time in 9 hours
PR merged google/jax
pr closed time in 9 hours
push event google/jax
commit sha e4c5be94a3c1f07b4d562ac2ba1d0c1866505a63
[jax2tf] Add note that the TF.js conversion is experimental.
commit sha 7f6f5f2f4e1989e38247a0e37144bea5b2d72d16
Merge pull request #4997 from bchetioui:tfjs_exp PiperOrigin-RevId: 345516398
push time in 9 hours
pull request comment google/jax
added prepend and append to diff
Yeah, I gave it a shot. Let's see what happens.
comment created time in 9 hours
Pull request review comment google/jax
make convert_element_type_p not require old_dtype
def testBroadcastInDim(self):
    ...
    self._Check(make_const, expected)

def testConvertElementTypeMismatchedDTypeOldDType(self):
I think this one is pretty clearly obsolete at this point! I’d be fine removing it.
comment created time in 9 hours
fork yifeif/ray
A system for parallel and distributed Python that unifies the ML ecosystem.
https://ray.readthedocs.io/en/latest/
fork in 9 hours
PR merged google/jax
This was previously the accepted style for specifying tests; now we generally specify rngs directly.
pr closed time in 9 hours
push event google/jax
commit sha 01fbf780b1110f00314b5ce0cf78a4c3cd284320
Cleanup: remove unnecessary rng_factory boilerplate in lax_numpy_test
commit sha 7147d272bd24ceba58d50fee5805503a309c67e4
Merge pull request #5086 from jakevdp:test-rng-cleanup PiperOrigin-RevId: 345508883
push time in 9 hours
PR merged google/jax
Thanks to @dougalm for thinking this through with me!
The basic problem in #5054 is that for `sinc` we had an expression like
return where(x == 0, 1., sin(pi * x) / (pi * x))
One way to read that is that at `x == 0` we're replacing the function with its zeroth-order Taylor expansion. But that meant we weren't computing correct derivatives at zero! (We were getting lucky with the first derivative: it happens to be zero because that's the slope of the sinc function at zero, but we were computing zero only because we were differentiating a constant function.)
One way to think about the fix is that we replace the constant function at zero with a truncated Taylor series at zero. For any fixed truncation level, we'd only be able to support some finite order of autodiff before incorrectly truncating the derivatives to zero. So we play a trick: we use a `custom_jvp` rule to, in effect, generate all the Taylor series coefficients we need, lazily as we differentiate.
fixes #5054
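A minimal sketch of the trick (not the exact code from #5077; the helper and its coefficients are reconstructed from the Maclaurin series of sin(u)/u):
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_jvp, nondiff_argnums=(0,))
def _sinc_maclaurin(k, x):
    # k-th derivative of u -> sin(u)/u at u == 0: zero for odd k,
    # and (-1)**(k//2) / (k + 1) for even k.
    if k % 2:
        return jnp.zeros_like(x)
    return jnp.full_like(x, (-1) ** (k // 2) / (k + 1))

@_sinc_maclaurin.defjvp
def _sinc_maclaurin_jvp(k, primals, tangents):
    # Differentiating the k-th coefficient lazily produces the (k+1)-th,
    # so every order of autodiff sees the correct coefficient at zero.
    (x,), (t,) = primals, tangents
    return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t

def sinc(x):
    pi_x = jnp.pi * x
    safe = jnp.where(pi_x == 0, 1.0, pi_x)
    return jnp.where(pi_x == 0, _sinc_maclaurin(0, pi_x), jnp.sin(safe) / safe)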
pr closed time in 10 hours
push event google/jax
commit sha 8cd855507179e21bfdc0e518025bf8957fc8e342
improve sinc jvp at zero, fixes #5054
commit sha e6c01aeacd70064bd3e6a160892ee51a713c1cda
fix broadcasting bug in sinc jvp rule
commit sha fa573ef07b28c9e44ad008fc3cbea60a57d2c07a
Merge pull request #5077 from google:sinc-jvp PiperOrigin-RevId: 345505504
push time in 10 hours
issue closed google/jax
Add custom JVP rule for jnp.sinc
The current implementation has incorrect even-order derivatives at x=0.0:
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import grad, vmap
x = jnp.linspace(-5, 5, 101)
y = vmap(grad(grad(jnp.sinc)))(x)
plt.plot(x, y)
Related to #5039
closed time in 10 hours
jakevdp Pull request review comment google/jax
make convert_element_type_p not require old_dtype
 def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None):
 ...
 ad.defjvp_zero(lt_p)
-def _convert_element_type_shape_rule(operand, *, new_dtype, old_dtype):
+def _convert_element_type_shape_rule(operand, *, new_dtype):
What would you think about using `dtype` rather than `new_dtype` as the argument name?
comment created time in 10 hours
Pull request review comment google/jax
make convert_element_type_p not require old_dtype
def testBroadcastInDim(self):
    ...
    self._Check(make_const, expected)

def testConvertElementTypeMismatchedDTypeOldDType(self):
delete?
comment created time in 10 hours
issue comment google/jax
numerical issues with 4th deriv of sinc near zero
Interestingly, the numerical issues appear unrelated to the Maclaurin trick added in #5077:
import jax.numpy as jnp
from jax import vmap, grad
import matplotlib.pyplot as plt
def sinc(x):
x = jnp.pi * jnp.asarray(x)
safe_x = jnp.where(x == 0, 1, x)
return jnp.where(x == 0, 0, jnp.sin(safe_x) / safe_x)
xs = jnp.linspace(-5, 5, 1001)
plt.plot(xs, vmap(grad(sinc))(xs))
plt.plot(xs, vmap(grad(grad(sinc)))(xs))
plt.plot(xs, vmap(grad(grad(grad(sinc))))(xs))
plt.plot(xs, vmap(grad(grad(grad(grad(sinc)))))(xs))
comment created time in 10 hours
pull request comment google/jax
Implement statistic pad mode in jax.numpy.pad
We have some internal test failures that I believe would be remedied by syncing to master and force-pushing to this branch. Are you comfortable doing that?
comment created time in 10 hours