
google/jax 9779

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

percyliang/sempre 758

Semantic Parser with Execution

google-research/dex-lang 661

Research language for array processing in the Haskell/ML family

percyliang/refdb 23

Stores paper references, outputs to bib/html, does basic sanity checking on bib entries

froystig/rust 0

a safe, concurrent, practical language

froystig/vowpal_wabbit 0

Vowpal Wabbit is a machine learning system which pushes the frontier of machine learning with techniques such as online, hashing, allreduce, reductions, learning2search, active, and interactive learning.

PullRequestReviewEvent

PR opened google/jax

mark lax traceables as entry points for filtered stack traces
+134 -5

0 comment

3 changed files

pr created time in 6 days

create branch google/jax

branch : lax-api-boundary

created branch time in 6 days

PullRequestReviewEvent

issue comment google/jax

Differentiation rule for 'create_token'

As of #4330, you should no longer need to make that stop_gradient call yourself.

PhilipVinc

comment created time in 9 days

PullRequestReviewEvent

delete branch google/jax

delete branch : create-token-stop-grad

delete time in 12 days

push event google/jax

Roy Frostig

commit sha 16e60360940f2baa4f3bc4f94b58680af8295198

trivial change to test source sync PiperOrigin-RevId: 332544315

view details

push time in 12 days

PR opened google/jax

insert a stop_gradient in lax.create_token

see #4292

The data dependency forced on the operand is false. There's no need to follow it for derivatives, and doing so can lead to unexpected errors.
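
A minimal sketch of the user-visible effect, assuming only the public jax.lax API (not the PR's one-line diff itself):

import jax
import jax.numpy as jnp

def f(x):
  token = jax.lax.create_token(x)  # the token carries a data dependency on x
  del token                        # (in real code it would thread through ordered effects)
  return jnp.sum(x ** 2)

# With stop_gradient inserted inside create_token, differentiating f no longer
# tries to propagate a derivative through the token's dependence on x.
print(jax.grad(f)(jnp.ones(3)))    # [2. 2. 2.]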

+1 -1

0 comment

1 changed file

pr created time in 13 days

create branch google/jax

branch : create-token-stop-grad

created branch time in 13 days

PullRequestReviewEvent

issue comment google/jax

psum inconsistent across vmap/pmap context

related: #3970

froystig

comment created time in 13 days

issue opened google/jax

psum inconsistent across vmap/pmap context

In particular, the psum primitive should behave as a broadcasting sum ("allreduce").

A fix would likely involve changing this rule, plus tests for (i) intended behavior and (ii) consistent behavior across pmap and vmap.
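
As a sketch of the consistency being asked for (illustrative, assuming the public vmap/pmap collectives API):

import jax
import jax.numpy as jnp

f = lambda v: jax.lax.psum(v, axis_name='i')
x = jnp.arange(4.)

out_vmap = jax.vmap(f, axis_name='i')(x)
# out_pmap = jax.pmap(f, axis_name='i')(x)   # requires at least 4 devices
# Intended behavior: both return the broadcast sum [6., 6., 6., 6.].
print(out_vmap)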

cc @apaszke @mattjj

created time in 13 days

issue comment google/jax

Differentiation rule for 'create_token'

It's helpful to see the code example now, and in particular the token's dependence on x. If you replace the statement token = jax.lax.create_token(x) with:

token = jax.lax.create_token(jax.lax.stop_gradient(x))

I'd expect things to work as you intend. Do they?

PhilipVinc

comment created time in 13 days

pull request comment google/jax

Implement jnp.ravel_multi_index()

I had core.concrete_or_error in mind, which is used elsewhere throughout lax_numpy. It allows tracers as long as they're concrete.
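
For reference, a hedged sketch of that pattern (the helper name here is made up; the concrete_or_error call is as used elsewhere in lax_numpy):

from jax import core

def _ravel_dim(d):
  # Accepts Python ints and concrete tracers; raises an informative error if d
  # is an abstract (traced) value, rather than failing later inside NumPy code.
  return core.concrete_or_error(int, d, "dims argument of jnp.ravel_multi_index()")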

jakevdp

comment created time in 13 days

PullRequestReviewEvent

delete branch google/jax

delete branch : rtdfix

delete time in 15 days

PR closed google/jax

fix docs build by adding requirement cla: yes

This is an attempt to resolve #4202

+1 -0

2 comments

1 changed file

froystig

pr closed time in 15 days

pull request comment google/jax

fix docs build by adding requirement

It is no longer needed!

froystig

comment created time in 15 days

issue closed google/jax

Broken docs check

The docs/readthedocs.org:jax check fails across PRs lately, e.g. here. The only error in the log is:

ERROR: After October 2020 you may experience errors when installing or updating packages. This is because pip will change the way that it resolves dependency conflicts.

We recommend you use --use-feature=2020-resolver to test your packages with the new resolver before it becomes the default.

myst-parser 0.12.8 requires docutils>=0.15, but you'll have docutils 0.14 which is incompatible.

closed time in 15 days

JuliusKunze

issue comment google/jax

Broken docs check

I'm going to close this because it indeed seems to have been fixed.

JuliusKunze

comment created time in 15 days

issue comment google/jax

Differentiation rule for 'create_token'

As you say, there isn't a derivative here, so we shouldn't define a false one. JAX allows differentiation with respect to only some function inputs. At the user level, this corresponds to passing argnums to functions such as grad. Does this resolve things within your use case?
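
For instance, a small sketch (not from this thread) of differentiating with respect to only the first argument:

import jax
import jax.numpy as jnp

def loss(params, token_source):
  # token_source is an input we do not differentiate with respect to
  return jnp.sum(params ** 2)

dloss = jax.grad(loss, argnums=0)         # differentiate w.r.t. params only
print(dloss(jnp.ones(3), jnp.zeros(2)))   # [2. 2. 2.]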

PhilipVinc

comment created time in 15 days

PullRequestReviewEvent

issue comment google/jax

Broken docs check

This looks like it may have been fixed (by RTD?) in the past 16+ hours, as the doc builds have started passing again. @shoyer, any thoughts on whether to add the requirement anyway? PR is ready, but I can discard it too.

JuliusKunze

comment created time in a month

PR opened google/jax

fix docs build by adding requirement

This is an attempt to resolve #4202

+1 -0

0 comment

1 changed file

pr created time in a month

create branch google/jax

branch : rtdfix

created branch time in a month

issue comment google/jax

`jax.scipy.linalg.expm` causes an infinite loop inside two nested `fori_loop`/`scan`s.

Good point. We could return early in a jit-friendly way using lax.cond, but I'm wary of affecting performance and introducing new global flags. This might also obscure an underlying solve-related bug that will only reappear later.

From what I can tell, the "hanging loop" isn't one directly in our codebase. At some point we pass the inf/nan values over to a lower-level linalg routine. It may be the getrf in our lax.lu_p translation, which comes from lapack via jaxlib, or it may be an XLA solve op after that. It seems better to check and return early at whatever that more upstream location is.
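
For concreteness, a sketch of the "return early in a jit-friendly way" idea (illustrative only; expm_guarded is a made-up wrapper, not a proposed patch):

import jax
import jax.numpy as jnp
from jax import lax

def expm_guarded(a):
  ok = jnp.all(jnp.isfinite(a))
  # lax.cond keeps this jittable: the non-finite branch returns NaNs instead of
  # handing inf/nan values to the underlying solve, where the hang occurs.
  return lax.cond(ok, jax.scipy.linalg.expm, lambda x: jnp.full_like(x, jnp.nan), a)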

C-J-Cundy

comment created time in a month

issue comment google/jax

`jax.scipy.linalg.expm` causes an infinite loop inside two nested `fori_loop`/`scan`s.

I think this is expm hanging on invalid numeric input. The bug doesn't seem to require loop/scan. I've distilled the original example down to:

import numpy as onp
import scipy as osp
from jax import scipy as jsp

sp = jsp

def potential(x):
  W = onp.array([[0.0, x[0]], [x[1], 0.0]])
  return onp.trace(sp.linalg.expm(W * W))

def leapfrog(q):
  def scan_fun(q):
    V = potential(q)
    q -= V
    return q, (q, V)

  q, (q, V) = scan_fun(q)
  q, (q, V) = scan_fun(q)
  q, (q, V) = scan_fun(q)       # error happens here
  return q

leapfrog(onp.random.normal(size=2) * 3)

which hangs similarly. If I substitute sp = osp on line 5, I encounter:

Traceback (most recent call last):
  ...
  File ".../scipy/sparse/linalg/matfuncs.py", line 671, in _expm
    s = max(int(np.ceil(np.log2(eta_5 / theta_13))), 0)
ValueError: cannot convert float NaN to integer

The third invocation of potential passes it [-inf -inf]. Those inf values are handed to expm.

We could consider raising an error on input like this. A question is where to do so. Because expm reduces to a matrix solve, I suspect that this attempts a solve on a terribly conditioned matrix. Maybe we should raise an error in our solve routines more generally.

Standard scipy's error isn't directly clear either, but at least the program halts.

Paging @shoyer who knows expm better. Thoughts?

C-J-Cundy

comment created time in a month

PullRequestReviewEvent

pull request comment google/jax

Add NumPy backend

This is a large PR! It looks really promising to me still. It'll take us some time to review (whether it's @mattjj or me or both of us), but we'll try to give at least some partial feedback soon.

JuliusKunze

comment created time in a month

PullRequestReviewEvent

Pull request review comment google/jax

single-operand cond

 def f_aug(*args):
    return _make_typed_jaxpr(f_aug, jaxpr.in_avals)
+
+def _join_cond_pe_staged_jaxpr_inputs(jaxprs, res_avals_per_jaxpr):
+  newvar = core.gensym('~')     # TODO(frostig): safer gensym

Resolved by #3211

froystig

comment created time in a month

PullRequestReviewEvent

push event google/jax

Jean-Baptiste Lespiau

commit sha 6bed4ee3b2c9b7f90883118dffc183ca0ed39774

Temporarily disable jax_jit tests. (#4118)

view details

push time in a month

PR merged google/jax

Temporarily disable C++ jax_jit tests. cla: yes

I need to update Tensorflow with a breaking change, and this breaks at Google3 head.

+3 -1

0 comment

1 changed file

jblespiau

pr closed time in a month

PullRequestReviewEvent

pull request comment google/jax

Doc: change suggested way of starting the profiler

Maybe we should keep the resulting object alive within the profiler module? Then this requirement of caller behavior (and documentation) won't be needed, if I understand correctly.
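
A rough sketch of that suggestion (hypothetical names, not the actual jax.profiler code):

_profiler_server = None  # module-level reference keeps the server alive

def start_server(port, server_factory):
  # Retain the server at module scope so callers need not hold a reference
  # themselves; server_factory stands in for whatever constructs the server.
  global _profiler_server
  _profiler_server = server_factory(port)
  return _profiler_server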

wrzadkow

comment created time in a month

issue closed google/jax

Suggestion: expose `_jit_is_disabled` utility

Hi there,

This is just a small comment / suggestion. I just found the _jit_is_disabled() utility and I'm finding it very useful to do some asserts that aren't possible in jit context.

The general structure:

@jax.jit
def func(arr):
    if jax.api._jit_is_disabled():
        assert arr.min() > 0, "data-dependent check failed"
    # implement the rest of the function
    ...

Are you planning on making _jit_is_disabled a first-class citizen in the jax api, e.g. jax.jit_is_disabled()?

Cheers!

closed time in a month

KristianHolsheimer

issue comment google/jax

Suggestion: expose `_jit_is_disabled` utility

The predicate _jit_is_disabled() does not imply that arrays are concrete. In your example, this means that the expression under assert might not work as intended. For instance, suppose we vmap the function you've defined, as in:

@jax.vmap
@jax.jit
def func(arr): ...

and then run:

with jax.disable_jit():
  func(jax.numpy.ones((3, 4)))

This will result in an error:

jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in `bool`).
[...]

Even though jit is effectively off, vmap will still lead to abstract interpretation of func, and arr will be an abstract tracer value at the point of the data-dependent check.

We don't plan to make this particular function part of the public API. More broadly, JAX's interface intentionally offers no means of asking whether a function is being evaluated under a transformation (or, in particular, whether jit is not). In part, doing so runs against JAX's central functional-programming requirement: func as written behaves differently based on state outside the function.

KristianHolsheimer

comment created time in a month

Pull request review comment google/jax

Add NumPy backend

[standard Apache 2.0 license header elided]
+from contextlib import contextmanager
+from typing import Any, Dict, Callable, Union, Sequence, Optional, Tuple
+import builtins
+import itertools
+import numpy as np
+import opt_einsum
+import scipy.special
+
+from .. import config, core, dtypes, curry, random, util
+from ..lax import lax
+from ..lax.lax import GatherDimensionNumbers, Shape
+from ..interpreters import xla
+
+_slice = builtins.slice
+_min = builtins.min
+_map = builtins.map
+
+map = util.safe_map
+zip = util.safe_zip
+
+Array = Union[np.ndarray, xla.DeviceArray]
+
+np_impl: Dict[core.Primitive, Callable[..., Any]] = {}
+
+class NumpyEvalTrace(core.Trace):
+  def pure(self, x): return x
+  lift = sublift = pure
+
+  def process_primitive(self, primitive, tracers, params):
+    impl = np_impl.get(primitive)
+    if impl is None:
+      raise NotImplementedError("NumPy backend does not yet support the "
+                                f"'{primitive}' primitive.")
+    return impl(*tracers, **params)
+
+  def process_call(self, primitive, f, tracers, params):
+    return f.call_wrapped(*tracers)
+  process_map = process_call
+
+@contextmanager
+def numpy_eval():
+  """
+  Makes JAX code evaluate with NumPy instead of XLA.
+
+  >>> @numpy_eval()
+  ... def some_fun(...):
+  ...  # jax code
+
+  Inline:
+
+  >>> from jax.interpreters.numpy_eval import numpy_eval
+  ... with numpy_eval():
+  ...   # jax code
+
+  ``jit(numpy_eval()(some_fun))`` will collapse constant calculations using NumPy,
+  so that no operations are performed on the XLA device during compilation.
+
+  Using `numpy_eval` is thread-safe.
+  """
+  assert config.omnistaging_enabled, \
+    "The NumPy backend requires omnistaging. Set flag JAX_OMNISTAGING=True or " \
+    "call jax.config.enable_omnistaging() before using jax."
+  with core.new_base_master(NumpyEvalTrace):
+    yield

[several hundred further diff lines register per-primitive NumPy implementations in np_impl -- elementwise math, bit operations, comparisons, type conversions, convolutions, dot_general, reshapes, padding, slicing, and reductions -- ending with the reduce implementation that the comment below refers to:]

+def reduce(operand: Array, init_value: Array, computation: Callable,
+           dimensions: Sequence[int], jaxpr=None, consts=None) -> np.ndarray:
+  reducer = _reducer(computation, init_value)

Brought up in #4109: I'd recommend ignoring computation here, and using core.jaxpr_as_fun on jaxpr instead to reliably obtain a callable form of the reduction binop.

JuliusKunze

comment created time in a month

PullRequestReviewEvent

issue comment google/jax

reduce_window primitive signature inconsistent with reduce

The computation parameter of reduce_p seems only to be read by the primitive's batching rule at this point. As I understand it, that rule uses it as a workaround/shortcut in order to call lax.reduce rather than reduce_p.bind directly. This seems like something we ought to clean up.

Generally, the computation parameter isn't compatible with the Jaxpr IR. There is no notion of indirect invocation in the IR, by design. As you point out, jaxpr and consts encode what we need, and they are valid within the IR. Note, for instance, that the translation rule (i.e. the XLA backend) uses them and ignores computation.

For those reasons, I would propose that if we make any change, it would be to remove computation from reduce_p, rather than to add it elsewhere. Consider jaxpr and consts to be canonical.

Based on the changes in #3923: are you looking for a simple way to invoke a jaxpr? We typically use core.jaxpr_as_fun to turn a jaxpr into a callable function.
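
For example, a small sketch of that utility (exact output types may vary by version):

from jax import core, make_jaxpr
import jax.numpy as jnp

closed_jaxpr = make_jaxpr(lambda a, b: jnp.sin(a) + b)(1., 2.)
fn = core.jaxpr_as_fun(closed_jaxpr)  # a plain callable over flattened inputs
print(fn(1., 2.))                     # flat list of outputs, here [sin(1.) + 2.]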

JuliusKunze

comment created time in a month

issue closed google/jax

Can't differentiate scipy.linalg.expm for M Layer paper

Would it be possible to implement this interesting paper idea with jax / flax?

Intelligent Matrix Exponentiation paper: https://arxiv.org/pdf/2008.03936.pdf code: https://github.com/google-research/google-research/tree/master/m_layer wiki: https://en.wikipedia.org/wiki/Matrix_exponential

expm docs: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.expm.html#jax.scipy.linalg.expm

expm frechet looks like the right thing to make the gradient: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.expm_frechet.html#jax.scipy.linalg.expm_frechet

I tried a few options, here's a simple one:

import jax
from flax import nn
jnp = jax.numpy
vec_expm = jnp.vectorize(jax.scipy.linalg.expm, signature='(k)->(k)')

@nn.module
def MLayer(x, D=D_CODE):
    x = nn.Dense(x, D ** 2)
    x = x.reshape(x.shape[:-1] + (D, D))
    x = vec_expm(x)
    x = x.reshape(x.shape[:-2] + (D ** 2,))
    x = nn.Dense(x, D)
    return x

however this crashes because of this:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

Would it work if we wire up expm_frechet primitive?

a simpler reproduction without other nn stuff:

import jax.numpy as jnp
import jax

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (2, 2))

def f(x):
    y = jax.scipy.linalg.expm(x)
    credit = jnp.sum(jnp.abs(y))
    return credit, y

f = jax.value_and_grad(f)

c, y = f(x)
print(c, y)

full trace:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-14-f2d123ee1505> in <module>
----> 1 f(x)

~/miniconda3/lib/python3.8/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    485     tree_map(partial(_check_input_dtype_grad, holomorphic), dyn_args)
    486     if not has_aux:
--> 487       ans, vjp_py = _vjp(f_partial, *dyn_args)
    488     else:
    489       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

~/miniconda3/lib/python3.8/site-packages/jax/api.py in _vjp(fun, *primals, **kwargs)
   1514   if not has_aux:
   1515     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1516     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1517     out_tree = out_tree()
   1518   else:

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    108 def vjp(traceable, primals, has_aux=False):
    109   if not has_aux:
--> 110     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    111   else:
    112     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     95   _, in_tree = tree_flatten(((primals, primals), {}))
     96   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 97   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
     98   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
     99   assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    421   with core.new_master(trace_type, bottom=bottom) as master:
    422     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 423     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    424     assert not env
    425     del master

~/miniconda3/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

<ipython-input-10-995182967638> in wrapped(*args, **kwargs)
      1 def value_and_jacobian(fun):
      2     def wrapped(*args, **kwargs):
----> 3         return fun(*args, **kwargs), jax.jacfwd(fun)(*args, **kwargs)
      4     return wrapped

~/miniconda3/lib/python3.8/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    491     dtype = dtypes.result_type(ans)
    492     tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
--> 493     g = vjp_py(np.ones((), dtype=dtype))
    494     g = g[0] if isinstance(argnums, int) else g
    495     if not has_aux:

~/miniconda3/lib/python3.8/site-packages/jax/api.py in _vjp_pullback_wrapper(cotangent_dtypes, io_tree, fun, py_args)
   1458              "match type of corresponding primal output ({})")
   1459       raise TypeError(msg.format(_dtype(a), dtype))
-> 1460   ans = fun(*args)
   1461   return tree_unflatten(out_tree, ans)
   1462 

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/ad.py in unbound_vjp(pvals, jaxpr, consts, *cts)
    115     cts = tuple(map(ignore_consts, cts, pvals))
    116     dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
--> 117     arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
    118     return map(instantiate_zeros, arg_cts)
    119 

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, primals_in, cotangents_in)
    200         cts_in_avals = [v.aval for v in eqn.outvars]
    201         call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
--> 202         cts_out = get_primitive_transpose(eqn.primitive)(
    203             params, call_jaxpr, invals, cts_in, cts_in_avals)
    204       else:

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/ad.py in call_transpose(primitive, params, call_jaxpr, args, ct, _)
    486     new_params = update_params(new_params, map(is_undefined_primal, args),
    487                                [type(x) is not Zero for x in ct])
--> 488   out_flat = primitive.bind(fun, *all_args, **new_params)
    489   return tree_unflatten(out_tree(), out_flat)
    490 primitive_transposes[core.call_p] = partial(call_transpose, call_p)

~/miniconda3/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1132 
   1133   def bind(self, fun, *args, **params):
-> 1134     return call_bind(self, fun, *args, **params)
   1135 
   1136   def process(self, trace, fun, tracers, params):

~/miniconda3/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1124   else:
   1125     tracers = map(top_trace.full_raise, args)
-> 1126     outs = primitive.process(top_trace, fun, tracers, params)
   1127   return apply_todos(env_trace_todo(), map(full_lower, outs))
   1128 

~/miniconda3/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1135 
   1136   def process(self, trace, fun, tracers, params):
-> 1137     return trace.process_call(self, fun, tracers, params)
   1138 
   1139   def post_process(self, trace, out_tracers, params):

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/ad.py in process_call(self, call_primitive, f, tracers, params)
    273     update_params = call_param_updaters.get(call_primitive)
    274     new_params = update_params(params, nz_tangents) if update_params else params
--> 275     result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
    276     primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    277     return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

~/miniconda3/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1132 
   1133   def bind(self, fun, *args, **params):
-> 1134     return call_bind(self, fun, *args, **params)
   1135 
   1136   def process(self, trace, fun, tracers, params):

~/miniconda3/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1124   else:
   1125     tracers = map(top_trace.full_raise, args)
-> 1126     outs = primitive.process(top_trace, fun, tracers, params)
   1127   return apply_todos(env_trace_todo(), map(full_lower, outs))
   1128 

~/miniconda3/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1135 
   1136   def process(self, trace, fun, tracers, params):
-> 1137     return trace.process_call(self, fun, tracers, params)
   1138 
   1139   def post_process(self, trace, out_tracers, params):

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in process_call(self, primitive, f, tracers, params)
    179                   else PartialVal.unknown(mapped_aval(pval[0]))
    180                   for pval, is_mapped in zip(in_pvals, params['mapped_invars'])]
--> 181     jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
    182         f, in_pvals, partial(primitive.bind, **params))
    183     if primitive.map_primitive:

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in partial_eval(self, f, pvals, app)
    279     f = trace_to_subjaxpr(f, self.master, False)
    280     f, aux = partial_eval_wrapper(f, tuple(in_avals))
--> 281     out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
    282     out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    283     out_pvs = map(PartialVal, zip(out_avals, out_consts))

~/miniconda3/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1132 
   1133   def bind(self, fun, *args, **params):
-> 1134     return call_bind(self, fun, *args, **params)
   1135 
   1136   def process(self, trace, fun, tracers, params):

~/miniconda3/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1121   if top_trace is None:
   1122     with new_sublevel():
-> 1123       outs = primitive.impl(fun, *args, **params)
   1124   else:
   1125     tracers = map(top_trace.full_raise, args)

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    524 
    525 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 526   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    527                                *unsafe_map(arg_spec, args))
    528   try:

~/miniconda3/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    222       fun.populate_stores(stores)
    223     else:
--> 224       ans = call(fun, *args)
    225       cache[key] = (ans, fun.stores)
    226     return ans

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    595   else:
    596     pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
--> 597     jaxpr, pvals, consts = pe.trace_to_jaxpr(
    598         fun, pvals, instantiate=False, stage_out=True, bottom=True)
    599   map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    421   with core.new_master(trace_type, bottom=bottom) as master:
    422     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 423     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    424     assert not env
    425     del master

~/miniconda3/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

~/miniconda3/lib/python3.8/site-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, primals_in, cotangents_in)
    203             params, call_jaxpr, invals, cts_in, cts_in_avals)
    204       else:
--> 205         cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
    206                                                          **eqn.params)
    207     cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out

~/miniconda3/lib/python3.8/site-packages/jax/lax/lax_control_flow.py in _while_transpose_error(*_, **kwargs)
    536 
    537 def _while_transpose_error(*_, **kwargs):
--> 538   raise ValueError("Reverse-mode differentiation does not work for "
    539                    "lax.while_loop or lax.fori_loop. "
    540                    "Try using lax.scan instead.")

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

closed time in a month

bionicles

issue comment google/jax

Can't differentiate scipy.linalg.expm for M Layer paper

The Fréchet derivative is implemented for the forward mode but JAX doesn't currently support custom JVP and VJP.

Indeed. Implementing the VJP of expm is open as #3447.

Closing this issue since it is essentially a duplicate in terms of the work required.

bionicles

comment created time in a month

delete branch google/jax

delete branch : jaxpr-stats

delete time in a month

delete branch google/jax

delete branch : frames-path

delete time in a month

delete branch google/jax

delete branch : primitive-shape-test

delete time in a month

push event google/jax

Roy Frostig

commit sha 5135fd176db23adc6cdb78b48155a46d7acdbde9

fix jaxpr util test under enable_x64

view details

push time in a month

PR merged google/jax

fix jaxpr util test in enable_x64 mode cla: yes
+7 -6

0 comment

1 changed file

froystig

pr closed time in a month

PR opened google/jax

fix jaxpr util test in enable_x64 mode
+7 -6

0 comment

1 changed file

pr created time in a month

create branch google/jax

branch : primitive-shape-test

created branch time in a month

push event google/jax

Roy Frostig

commit sha 8cc9579ca0c9a4bc72d1d68ccc8e0dba68c7b281

check path prefixes using os.path instead of string comparisons

view details

push time in a month

PR merged google/jax

check path prefixes using os.path instead of string comparisons cla: yes
+10 -2

0 comment

1 changed file

froystig

pr closed time in a month

push event google/jax

Roy Frostig

commit sha 908d54ab4651dd6757b3ea0f602eccf913a0199a

utilities to collect summary statistics of jaxprs

view details

Roy Frostig

commit sha d778a6d0741439eef038c29269534d25fdf168e2

move experimental.jaxpr_stats to jaxpr_util

view details

push time in a month

PR merged google/jax

summary statistics of jaxpr equations cla: yes

Example:

$ cat test.py
from jax import jit, make_jaxpr, numpy as jnp
from jax.experimental import jaxpr_stats as js

def f(x, y):
  s = jit(jnp.sin)(x)
  return jnp.sin(s) + jnp.cos(y)

j = make_jaxpr(f)(1., 1.).jaxpr
print(j)

print()
print('summary of primitives:')
h = js.primitives(j)
js.print_histogram(h)

print()
print('summary of primitives by source point:')
h = js.primitives_by_source(j)
js.print_histogram(h)

print()
print('summary of primitives by output shapes:')
h = js.primitives_by_shape(j)
js.print_histogram(h)
$ python test.py
{ lambda  ; a b.
  let c = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b = sin a
                                 in (b,) }
                    device=None
                    donated_invars=(False,)
                    name=sin ] a
      d = sin c
      e = cos b
      f = add d e
  in (f,) }

summary of primitives:
2 sin
1 xla_call
1 cos
1 add

summary of primitives by source point:
1 xla_call @ test.py:5 (f)
1 sin @ test.py:6 (f)
1 sin @ test.py:5 (f)
1 cos @ test.py:6 (f)
1 add @ test.py:6 (f)

summary of primitives by output shapes:
2 sin :: float32[]
1 xla_call :: float32[]
1 cos :: float32[]
1 add :: float32[]
+209 -0

0 comment

2 changed files

froystig

pr closed time in a month

PR opened google/jax

check path prefixes using os.path instead of string comparisons
+10 -2

0 comment

1 changed file

pr created time in a month

create branch google/jax

branch : frames-path

created branch time in a month

push event google/jax

Alex Alemi

commit sha afeefa6f1fb8c5fdd296e5bedf4d4e1c1df70b45

Add typing and namedtuple to `optimizers.py`, improve documentation. (#3570)

view details

push time in a month

PR merged google/jax

Add typing and namedtuple to `optimizers.py`, improve documentation. cla: yes

Add QHM and QHAdam optimizers to optimizers.py and create an Optimizer namedtuple to improve usability. Improve documentation of both optimizers.py and optix.py to include example usages.

+70 -17

2 comments

2 changed files

alexalemi

pr closed time in a month

PullRequestReviewEvent

push event google/jax

Roy Frostig

commit sha dbca9e682c61f77813341794dc85699085512652

unrevert #3674 (revert #3791)

view details

Roy Frostig

commit sha fe69d3c6f0875c243ac3a01cae66c50c3cad3dd0

always deref all locals that indirectly reach stack frames in the exception-reraise handler

view details

push time in a month

PR merged google/jax

stack traces without jax-internal frames, take 2 cla: yes

This PR brings back the changes from #3674 (reverted in #3791), and fixes the remaining issue among those that @mattjj and I believe were introduced in the original PR. Namely, the function that re-raises internally-encountered exceptions with an appended __cause__ now deletes local variables explicitly, so as to avoid creating reference cycles that involve stack frames.
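
Roughly, the re-raising pattern described (an illustrative sketch, not the PR's code):

def reraise_with_cause(exc):
  filtered = RuntimeError("error raised while tracing or compiling jax code")
  try:
    raise filtered from exc
  finally:
    # The raised exception's traceback references this frame, and this frame's
    # locals reference the exception; deleting the locals breaks the reference
    # cycle so stack frames are freed promptly.
    del filtered, exc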

+312 -2

0 comment

4 changed files

froystig

pr closed time in a month

Pull request review comment google/jax

stack traces without jax-internal frames

+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import traceback
+import types
+
+from .api_util import wraps
+
+_jax_path = os.path.dirname(__file__)
+_include_paths = [
+    os.path.join(_jax_path, path) for path in (
+        'config.py', 'dlpack.py', 'experimental', 'lax', 'lax_linalg.py',

I agree that this can easily become stale and would be better done automatically when possible, but I'm not sure where to begin prior to a source tree restructuring. To create this list, I subjectively inspected jax/ and picked out what seemed appropriate. I'd welcome ideas for how to avoid forming such a list, or offer that we revisit this when we restructure our source directories.

froystig

comment created time in a month

push event google/jax

James Bradbury

commit sha 16ab9cb8c7c4fe8a15765371650e7bf1fa5bcd7f

support multi-host pmap with omnistaging (#4075)

view details

Matthew Johnson

commit sha 8232f2deee5d1e0c78496699f35225b2f5d6226b

adapt _TempAxisName for unhashable objs (#4077) adapt _TempAxisName for unhashable objs

view details

Matthew Johnson

commit sha 91207011889000850cd5a90d081530e18e5badaf

allow xla_computation to psum a constant (#4078) * allow xla_computation to psum a constant * allow axis_env to be None

view details

George Necula

commit sha 22b92c5122ab5af6f5e4560f9be08f5649ae7653

Increase tolerance for CPU test LaxBackedNumpyTests::testCorrCoef (#4080)

view details

George Necula

commit sha c7aff1da06072db8fb074f09a8215615d607adc2

Revert "Use pytree from xla_client. (#4063)" (#4081) This reverts commit d8de6b61411179dcd2f63d7639bbcd69b30ac15f. Tryting to revert because it seems that this produces test failures in Google.

view details

Benjamin Chetioui

commit sha ec90c3587adb569491bf54c7b447844491925337

[jax2tf] Fix bfloat16 bug in select_and_gather_add conversion. (#4058) * [jax2tf] Fix bfloat16 bug in select_and_gather_add conversion. This fix makes it possible to run bfloat16 tests for the jax2tf conversion of select_and_gather_add.

view details

Benjamin Chetioui

commit sha 4c22e012d27ca865d429492f0e5956791907a981

[jax2tf] Explicitly raise an error when attempting to convert _select_and_scatter_add_p. (#4084)

view details

Adam Paszke

commit sha 1ba4e06c9101f5fcc3fa587e76b1247b72a5fa1b

Initial version of gmap (#4006) Co-authored-by: Matthew Johnson <mattjj@google.com>

view details

Roy Frostig

commit sha 13b32300efa1a78e7d909d5ae2df01943c3ae183

unrevert #3674 (revert #3791)

view details

Roy Frostig

commit sha a5c4e6fa3f9f135ad0aac46a63925c2188bbc801

always deref all locals that indirectly reach stack frames in the exception-reraise handler

view details

push time in a month

push event google/jax

James Bradbury

commit sha 9ab07d8574766a1a5fb738b0c45c91891aa00a5d

support axis_index_groups in psum(const) (#4070) * support axis_index_groups in psum(const) * add test for psum(constant, axis_index_groups) * rm trailing whitespace * Update lax_parallel.py

view details

George Necula

commit sha 1dbdaac7fcb88e29abe8c3d060666c478566bc56

[jax2tf] avoid import errors when omnistaging is enabled (#4072) * [jax2tf] avoid import errors when omnistaging is enabled

view details

Philipp Thölke

commit sha 1316562b0b96eb4ce10b0661bb113e1be3b9406c

Canonicalize result dtype to fix double precision problem in ldexp (#4069)

view details

Roy Frostig

commit sha 5f284615ca066a58bc25038f5b6be951a73a05ae

unrevert #3674 (revert #3791)

view details

Roy Frostig

commit sha 9bbfdaa9781bf6225048ee3153c8e61875143fc9

always deref all locals that indirectly reach stack frames in the exception-reraise handler

view details

push time in a month

PR opened google/jax

stack traces without jax-internal frames, take 2

This PR brings back the changes from #3674 (reverted in #3791), and fixes the remaining issue among those that @mattjj and I believe were introduced in the original PR. Namely, the function that re-raises internally-encountered exceptions with an appended __cause__ now deletes local variables explicitly, so as to avoid creating reference cycles that involve stack frames.

+312 -2

0 comment

4 changed files

pr created time in 2 months

create branch google/jax

branch : stack-traces

created branch time in 2 months

delete branch google/jax

delete branch : tycheck-cond-scan-params

delete time in 2 months

delete branch google/jax

delete branch : rm-auto-parallel

delete time in 2 months

delete branch google/jax

delete branch : changelog-unroll

delete time in 2 months

push event google/jax

Roy Frostig

commit sha df6a3da44959daf333eb636fd5b2dc45f7a7cedc

add scan unrolling to a previous changelog entry

view details

push time in 2 months

PR merged google/jax

add scan unrolling to a previous changelog entry cla: yes
+3 -0

0 comment

1 changed file

froystig

pr closed time in 2 months

PR opened google/jax

add scan unrolling to a previous changelog entry
+3 -0

0 comment

1 changed file

pr created time in 2 months

create branch google/jax

branch : changelog-unroll

created branch time in 2 months

push event google/jax

Roy Frostig

commit sha 34f90f55e62356856300cf4538ce48461443707a

remove auto-parallelization transformation

view details

push time in 2 months

PR merged google/jax

remove auto-parallelization transformation cla: yes

... also known over time as papply and shard. This PR removes the interpreter, rules, tests, and the corresponding function in the api module.

+0 -928

0 comment

4 changed files

froystig

pr closed time in 2 months

PR opened google/jax

remove auto-parallelization transformation

... also known over time as papply and shard. This PR removes the interpreter, rules, tests, and the corresponding function in the api module.

+0 -928

0 comment

4 changed files

pr created time in 2 months

create branch google/jax

branch : rm-auto-parallel

created branch time in 2 months

push event google/jax

Roy Frostig

commit sha 76bf9a8d1b13c81a1d8beb84c79697eedba7ed0c

sketch out psum changes under new pmap semantics Co-authored-by: James Bradbury <jekbradbury@google.com> Co-authored-by: Matthew Johnson <mattjj@google.com>

view details

push time in 2 months

pull request comment google/jax

Avoid lexically capturing the train_images value in MNIST VAE example.

Aside: if we were to go the route of #3238, this wouldn't be necessary. Also, an automatic hoist could be done conditionally, say, based on JIT backend or based on the size of a constant.

hawkinsp

comment created time in 2 months

push event google/jax

Roy Frostig

commit sha cd64d2eed533fca91dff6c73379edb067cb5ac59

typecheck scan and cond params

view details

push time in 2 months

PR merged google/jax

typecheck scan and cond params cla: yes
+105 -13

0 comment

3 changed files

froystig

pr closed time in 2 months

push event google/jax

Jake Vanderplas

commit sha ca8dc20a4b04adc65c65710b3190e549b6eaa54a

make jnp.abs() work for unsigned inputs (#3914)

view details

Justin Lebar

commit sha e8c7d9e2812cc8ca1a1b8ac78860d0882122fa30

s/Three-fry/Threefry/ (#3918) Per http://www.thesalmons.org/john/random123/

view details

Matthew Johnson

commit sha abfd33570b865b6529404d360919e7a506beba4f

update version and changelog for pypi (#3921) * update version and changelog for pypi * fix typos

view details

Matthew Johnson

commit sha 146cf49fa04ce8e476da6981f52688b63bef7f4e

delay backend check for xla_computation (#3920)

view details

Jake Vanderplas

commit sha 0ec1e251e9994c6ba6f9fb5ca433f91d71a84518

Fix jnp.tile for cases with zero reps (fixes #3919) (#3922)

view details

Roy Frostig

commit sha 7b64ae32033207928f78eee678f074a632a0d839

typecheck scan and cond params

view details

push time in 2 months

PR opened google/jax

typecheck scan and cond params
+105 -13

0 comment

3 changed files

pr created time in 2 months

push event google/jax

Jake Vanderplas

commit sha ee7f0353497bd6193e2532dd834f8dcc7cef6bf5

jax.random: use correct x32/x64 default dtypes. (#3841) This is a no-op in the current package, but will make things cleaner during the x64 deprecation.

view details

Matthew Johnson

commit sha c9d8acd2e921efdbc01b2208fa9f790a9c312933

put core trace state in a threading.local class (#3869) this is a refinement of the fix in #3845, so that we no longer need TraceState.set_state (and so that #3370 is easier to adapt)

view details

Vaibhav Srivastav

commit sha 3aa37d3af4324f53c72270296ca905189ab6fd99

Replicating sort_complex functionality from np.sort_complex to jax.numpy (#3870)

view details

Matthew Johnson

commit sha 616d63b19ae08c6d58f1de8d7c41fe78ebded157

fix vmap error, fixes #3877 (#3879)

view details

Stephan Hoyer

commit sha dd7ab39e4d8e8b0bc13ebfa70890c6236fea1e8e

Fix formatting in the custom derivatives notebook (#3876) Sphinx is apparently quite picky about consistent use of headers: you can't skip a header level. We were getting warnings like "WARNING: Title level inconsistent" in the docs build, and sub-headers weren't showing up on this page after the first section.

view details

Jake Vanderplas

commit sha d7733c30d60972efaa461ad74ed9de852e402708

Cleanup: canonicalize several dtypes to prevent noisy warnings (#3874)

view details

Jamie Townsend

commit sha 7506a3e5f0137f48415def878bea71c5977dfbd0

Fix flaky generated_fun_test.py test (#3885)

view details

David Majnemer

commit sha 33faf6a46e9c221357f69ef1bf365fed0d56d438

TPUs support half precision arithmetic (#3878) * TPUs support half precision arithmetic * update jax2tf tests to handle fp16 Co-authored-by: Matthew Johnson <mattjj@google.com>

view details

Jamie Townsend

commit sha e28db33b01d1c611cae05d7eced92cb71e62d3f0

Fix dynamic_slice, dynamic_update_slice scalar batching, fixes #3883 (#3888) * Add test for issue 3883 * Fix dynamic_slice, dynamic_update_slice scalar batching, fixes #3883

view details

Matthew Johnson

commit sha 30980742c57f3b507d74b5ae79e9a75b5f57bc08

refine population_count type check (#3887) * refine population_count type check fixes #3886 * allow signed/unsigned ints for population_count https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/service/shape_inference.cc;l=314?q=xla%20f:shape_inference.cc * make lax_reference.population_count handle signed

view details

Jake Vanderplas

commit sha c38bc36803221728a05e76d962afd64c3f188c42

jnp.linspace & friends: more carefully handle dtypes (#3859)

view details

Jake Vanderplas

commit sha 190b6af5724771d07fa27e594c33d8b57ad02483

Improve searchsorted implementation (#3873)

view details

Jake Vanderplas

commit sha e0a8d4486134a26adf4ddb68f792d61a260d796d

Add jnp.modf() & improve test coverage for related functions (#3894)

view details
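
A short sketch of the new jnp.modf, which mirrors np.modf by returning the fractional and integral parts of each element.

```python
import jax.numpy as jnp

# Both outputs carry the sign of the input.
frac, whole = jnp.modf(jnp.array([2.5, -1.25]))
print(frac)   # [ 0.5  -0.25]
print(whole)  # [ 2. -1.]
```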

Peter Hawkins

commit sha 659dd39d74e8aecf1af0292a9e0e6a5b7570a12f

Add MLPerf results link. (#3896)

view details

Joshua George Albert

commit sha 02009e0cf09b0b018363bd522e855ab54a1fefba

BFGS algorithm (#3101)
* BFGS algorithm, addressing https://github.com/google/jax/issues/1400
* addresses @shoyer comments of PR
* skip dtype checks
* backslash in docstring
* increase closeness tol
* increase closeness atol to 1.6e-6
* addresses jakevdp comments
* same line search as scipy
* same results format
* same (and more) testing as in scipy for line search and bfgs
* 2 spacing
* documenting
* analytic hessian non default but still available
* NamedTuple classes
* small fix in setup_method
* small doc string addition
* increase atol to 2e-5 for comparison
* removed experimental analytic_hessian
* using jnp.where for all binary replace operations
* removed _nojit as this is what disable_jit does
* fix indentation mangling
* remove remaining _nojit
* fixing more indentation mangling
* segregate third_party test
* use parametrise
* use parametrise
* minor nitpicking
* fix some errors
* use _CompileAndCheck
* replace f_0 and g_0 for (ugly) scipy variable names
* remove unused function
* fix spacing
* add args argument to minimize
* adhere fmin_bfgs to scipy api
* remove unused function
* ignore F401
* look into unittest
* fix unittest error
* delete unused function
* more adherence to scipy's api
* add scipy's old_old_fval arg though unused
* increase line_search default maxiter to 20 (10 not enough in some cases)
* remove unused imports
* add ord=norm to the initial convergence check
* remove helper function
* merge jax/master
* Resolve a remnant conflict from merging master to solve ReadTheDocs issue.
* Add an informative termination message and status number.
* Revert changes to unrelated files
* cleanup bfgs_minimize
* cleanup minimize.py
* Move minimize_bfgs.py to _bfgs.py
* Move more modules around
* improve docs
* high precision einsum
* Formatting in line search
* fixup
* Type checking
* fix mypy failures
* minor fixup
Co-authored-by: Stephan Hoyer <shoyer@google.com>

view details
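
A minimal sketch of the BFGS solver added above, via the scipy-style jax.scipy.optimize.minimize interface (at this point only method="BFGS" is supported; the objective below is just an example).

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def rosenbrock(x):
  # Smooth test objective with a known minimum at x = [1, 1].
  return jnp.sum(100. * (x[1:] - x[:-1] ** 2) ** 2 + (1. - x[:-1]) ** 2)

result = minimize(rosenbrock, jnp.zeros(2), method="BFGS")
print(result.x, result.fun)  # approximately [1. 1.] and ~0.
```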

Stephan Hoyer

commit sha 242b3249c6012b116fa214433491b4f5449dd49f

Add missing license headers (#3899) Oops!

view details

Stephan Hoyer

commit sha b0ef4838d7cbfa546d7f985604848bbe512cf241

Fixes to test_scipy_optimize.py for Google internal tests (#3902)

view details

Jake Vanderplas

commit sha c28a7111476aeda81572b3a6b33b8ae5fd6bf993

Cleanup: pass function name rather than function object (#3897)

view details

Matthew Johnson

commit sha c8771e12e0712901ff8ab8f53ab773ebf756f99b

add omnistaging flag placeholder (#3904)

view details

Matthew Johnson

commit sha de645c5b8bba6c8d3e0c82d7f8f62cdde137bbcb

update version and changelog for pypi (#3906)

view details

push time in 2 months

create branch google/jax

branch : tycheck-cond-scan-params

created branch time in 2 months

push event google/jax

Peter Hawkins

commit sha a6e2d20b315ca63a34c03ea8be2dd34d6a0da2b0

Add support for base dilation and window dilation to reduce window op… (#3803)

view details
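
A hedged sketch of the new reduce_window dilation parameters: a 1D max-reduction whose window compares elements two apart (parameter values here are illustrative).

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.)
# Window of 2 elements, dilated by 2, so each output is max(x[i], x[i + 2]).
out = lax.reduce_window(
    x, -jnp.inf, lax.max,
    window_dimensions=(2,), window_strides=(1,), padding='VALID',
    base_dilation=(1,), window_dilation=(2,))
print(out)  # [2. 3. 4. 5.]
```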

Jake Vanderplas

commit sha 71f80a50b1daf265aa9c5092d2ec70459f1e98d3

Fix type mismatch in jet rule for abs (#3807)

view details

Roy Frostig

commit sha 63b331548e5569cec35073195afa076f6599a9ce

utilities to collect summary statistics of jaxprs

view details

Roy Frostig

commit sha e3e2f3267797772f2c8120dd81d7a06b82c6d0c5

move experimental.jaxpr_stats to jaxpr_util

view details

push time in 2 months

delete branch google/jax

delete branch : unroll-scan

delete time in 2 months

push event google/jax

Roy Frostig

commit sha a7b4d3790427897f195afd78318ca01434957b41

count equations generated by a source location

view details

push time in 2 months

push event google/jax

Roy Frostig

commit sha 8e6f72db659a7262248fde654f320cbf5ff75b24

count equations generated by a source location

view details

push time in 2 months

push event google/jax

Skye Wanderman-Milne

commit sha 8f4ba7e679b889ce8b75ef8fa07a1df947d89e52

Allow specifying both `devices` and `axis_size` to pmap. (#3475)
This allows providing custom device assignments to nested pmaps or pmap-of-sharded_jit when running on a multi-host platform.

view details
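
A sketch of passing both arguments together, assuming at least two local devices are available (device selection here is arbitrary).

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()[:2]
# Explicit device assignment plus an explicit axis size.
f = jax.pmap(lambda x: x * 2, devices=devices, axis_size=len(devices))
print(f(jnp.arange(len(devices))))  # [0 2], one element per device
```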

Jake Vanderplas

commit sha 19f308b9edeff7c7837785309467ff5e1726feb6

implement jax.random.choice (#3463)

view details
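
A minimal usage sketch of the new jax.random.choice (shapes and probabilities below are arbitrary).

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Three distinct values drawn uniformly from range(10).
sample = jax.random.choice(key, 10, shape=(3,), replace=False)

# Five draws with replacement under explicit probabilities.
weighted = jax.random.choice(key, jnp.array([0, 1, 2]), shape=(5,),
                             p=jnp.array([0.1, 0.1, 0.8]))
print(sample, weighted)
```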

Jake Vanderplas

commit sha 1ed60b3f96a67c8787643e83da9897a01ae6a819

support stacked doubling (#3490)

view details

8bitmp3

commit sha f0dff9c19b3e540c69b8cbb53c109219c52f9a50

Fix a link rendering to Autograd's reverse-mode Jacobian method (#3501)

view details

8bitmp3

commit sha 7f9fc27c487c4088acc7ceaef8e34adfc87666c5

Another small fix of the link rendering in the Autodiff Cookbook - vmap transformation (#3502)

view details

Matthew Johnson

commit sha 3fff83790742a1b0a468032ecff6f8809fde57d9

pin numpy version in setup.py to avoid warnings (#3509)

view details

8bitmp3

commit sha f8570ec0493134814554ba198c3ef7d2854e1ff4

Update JAX FAQ.rst, jax.device_put (#3496)

view details

Matthew Johnson

commit sha a81e732d3d5398ef5cd74eaf518b6cb2d58877db

fix jet typo: jnp not np (#3507)

view details

Skye Wanderman-Milne

commit sha 18798a307bc1a40e3dd2e457d09e6194efc29710

Revert "Remove documentation presubmit action now that the RTD presubmit is enabled. (#3480)" (#3497) This (partially) reverts commit 68dcbdd1189cd938bf6023e4e1efaf64c71629aa. @hawkinsp points out this was running doctest, which the RTD presubmit doesn't do AFAIK. We still don't build the docs here, as the RTD presubmit takes care of that. Co-authored-by: Stephan Hoyer <shoyer@google.com>

view details

igorwilbert

commit sha e5d4ca31a8b59fb93229d3629035007e5aa329cc

Fix typo understanding jaxprs page on readthedocs (#3513)

view details

Jake Vanderplas

commit sha 71023227c6da30e317fa6924c7b4e610dc8368cd

Add private class wrapper for double-double arithmetic (#3521)

view details

Skye Wanderman-Milne

commit sha 8fe6da08d4c634bf6126ae8e7caa30f557137582

Add instructions for how to use the TensorBoard profiler to the profiling docs. (#3481)

view details

Matthew Johnson

commit sha 2f7108f78ba4d7fd12a8ad6232685d0d10b28c01

remove the lower_fun default multiple_results=True (#3524)

view details

Neil

commit sha 046006e047543d5c24f75a930ab87ef56c247032

Fix typo: np.bool -> np.bool_ (#3525)
Replaced np.bool (which is just bool) with np.bool_, which is numpy's Boolean type.

view details

Tom Hennigan

commit sha ca5b0b180e9401e8ef0bffb0ef65b461cfc12096

Cast int8 to bool for lax.not in jax2tf. (#3519)

view details

Jake Vanderplas

commit sha 33c455a1a8c0fce75ac9ec6d094353221a22b938

Add jax.scipy.signal.detrend (#3516)

view details
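
A small sketch of the new jax.scipy.signal.detrend, which follows scipy's interface (the signal below is synthetic).

```python
import jax.numpy as jnp
from jax.scipy.signal import detrend

t = jnp.linspace(0., 1., 100)
x = 3. * t + jnp.sin(2. * jnp.pi * 5. * t)

print(detrend(x))                   # best-fit line removed (type='linear')
print(detrend(x, type='constant'))  # mean removed
```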

Matthew Johnson

commit sha 02494924f9554fb2b4317043319b3731a13fccdd

fix an issue with newer versions of pytype (#3526)

view details

clemisch

commit sha 5ee6bc00340e07d1b4fd705bddc2b496cd21f25a

Remove unnecessary static_argnum in np.gradient (#3512)

view details

Thomas Keck

commit sha 490c8533c889127200a9d4a7ed283cfbacc96586

Adds boolean support for bitwise not and unittests for boolean support on logical operations. (#3483)
Co-authored-by: Thomas Keck <thomaskeck@google.com>

view details

Peter Hawkins

commit sha 86fcfbfa1af09a8776dc0a8c8de5b4d661103ea7

Fix memory leak when no axis is provided to pmap. (#3394)
* Fix memory leak when no axis is provided to pmap.
* Work around flake8 false positive.
Co-authored-by: Matthew Johnson <mattjj@google.com>

view details

push time in 2 months
