
google/jax 8411

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

skye/community 0

Stores documents used by the TensorFlow developer community

skye/Impala 0

Real-time Query for Hadoop

skye/impala-udf-samples 0

Sample UDF and UDAs for Impala.

skye/jax 0

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

skye/kite 0

Kite SDK

push event skye/jax

Skye Wanderman-Milne

commit sha 888c9c77b3daee21b2fb2ce20657acc57da52b7a

Implement pmap of sharded_jit

view details

Peter Hawkins

commit sha f349979302548a9b1e93796875ce01b49f187eb8

Relax tolerance of LaxVmapTest.testDot for float64 inputs. (#3167)

view details

James Bradbury

commit sha 8b32c5ddfcf74af0b8428c3855f403450d01551e

Avoid running trivial jitted subcomputations in pe (#3169)

view details

Skye Wanderman-Milne

commit sha 12f26d3c8c6b3e020e17024d3cbd39b62fd631bb

Improve ``devices`` and related documentation (#3155)

view details

Matthew Johnson

commit sha a4094f72a4435e57e2c52e27085e55659797b54a

revise "Tracer with raw numpy" error message (#3160) * revise "Tracer with raw numpy" error message fixes #3133 * fix f-string typo * fix typo Co-authored-by: James Bradbury <jekbradbury@google.com> Co-authored-by: James Bradbury <jekbradbury@google.com>

view details

Matthew Johnson

commit sha f9c978e9d608e373965512ebc498c0a1338af3ed

improve docstring of jax.numpy.broadcast_to (#3173) thanks @joaogui1 !

view details
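For readers skimming the feed, a minimal sketch of the function whose docstring is improved above (not code from the commit itself); ``jax.numpy.broadcast_to`` mirrors ``numpy.broadcast_to``:

```python
import jax.numpy as jnp

# Broadcast a length-3 vector to shape (2, 3).
x = jnp.array([1.0, 2.0, 3.0])
print(jnp.broadcast_to(x, (2, 3)))
```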

Matthew Johnson

commit sha ae9d1753462d07d75e9e002b13538c385949d9af

fix while_loop cond function batching (#3174) * fix while_loop cond function batching fixes #3164 * add test for #3164

view details

Jake Vanderplas

commit sha 6e3c8b1d9beea2cd777e2dede22d1cc29eb608eb

Fix arr.view() on TPU & improve tests (#3141)

view details

Matthew Johnson

commit sha eb81d7e7ffcbfaff1e2b281ae6ee902d46653d6a

add dict in_axes example to vmap docstring (#3176) * add dict in_axes example to vmap docstring fixes #3161 * fix typo

view details
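As a hedged illustration of the docstring example mentioned in the commit above (the function and values here are made up), ``jax.vmap`` accepts a dict-valued ``in_axes`` matching a dict argument:

```python
import jax
import jax.numpy as jnp

# "a" is mapped over its leading axis; "b" is broadcast to every example.
def f(d):
    return d["a"] + d["b"]

batched = jax.vmap(f, in_axes=({"a": 0, "b": None},))
print(batched({"a": jnp.arange(3.0), "b": jnp.float32(10.0)}))  # [10. 11. 12.]
```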

Matthew Johnson

commit sha 5c1de2836c11f5068f9df75fcce4fe887b458f1c

revise vmap in_axes/out_axes leaf type error msg (#3179) from #3161 discussion

view details

Dan Piponi

commit sha c459280e5678a96529dada852bf97b8cf01a4808

Added `associative_scan`. (#2170) * Added `associative_scan`. * Fixed problem where base case of associative scan could fail * remove jax.numpy dependence in associative_scan Co-authored-by: Matthew Johnson <mattjj@google.com>

view details
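As a rough illustration of the new function (not code from the PR), ``lax.associative_scan`` computes all prefix reductions of an associative binary operator, e.g. a cumulative sum:

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(1.0, 9.0)
# Prefix sums computed with a parallel (associative) scan: [1. 3. 6. 10. ...]
print(lax.associative_scan(jnp.add, xs))
```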

Jake Vanderplas

commit sha bb2127cebd3ade161109ee4919a92aaff5c788c1

Future-proof view test against signaling NaNs (#3178)

view details

Skye Wanderman-Milne

commit sha 4cbd14ca89efca7377894c74607342b848bdb644

Update jax version to 0.1.68 (#3181)

view details

Roy Frostig

commit sha 82a9af519adb9ba0581d46656786040c3dedf53f

typecheck jaxpr equations

view details

Roy Frostig

commit sha 94b1f631ea772a7ad7e2a05586d013cd9349c4eb

raise TypeError for jaxpr typechecking errors

view details

Roy Frostig

commit sha 1205f7a00fbc67f9ae31cd2e46c7583372733efd

factor out jaxpr equation checks

view details

Roy Frostig

commit sha 8e70769cba6e1a95fb58fcd3137e878b32c3317d

factor out jaxpr-check context and variable environment

view details

Roy Frostig

commit sha 0c2c5584827feafea4487693cdc3fbcb9e109435

check that variables are typed equally throughout a jaxpr

view details

Roy Frostig

commit sha cc34ed26939adcc9adaee94e3ad6991b8a98f5c5

check aval compatibility, not strict equality, when typechecking jaxpr equations

view details

Roy Frostig

commit sha 42e7e20eabd3b415742700df5d75fcd6ab58fad3

update check_jaxpr doc

view details

push time in 6 hours

push event google/jax

Skye Wanderman-Milne

commit sha 888c9c77b3daee21b2fb2ce20657acc57da52b7a

Implement pmap of sharded_jit

view details

Skye Wanderman-Milne

commit sha d8ede0106a6dcf8f71d97dc5e47636398dcb9124

Update jax/interpreters/pxla.py Co-authored-by: James Bradbury <jekbradbury@google.com>

view details

Skye Wanderman-Milne

commit sha ecd893626ff890906a3394551f55264b96d7713e

Address comments

view details

Skye Wanderman-Milne

commit sha 6ffde8061d20b3f3c2ce4196e7640dee4b2548dc

Implement pmap of sharded_jit (#3144) * Implement pmap of sharded_jit * Update jax/interpreters/pxla.py Co-authored-by: James Bradbury <jekbradbury@google.com> * Address comments Co-authored-by: James Bradbury <jekbradbury@google.com>

view details

push time in 8 hours

pull request comment google/jax

Update installation directions in README to mention expected CUDA location.

@refraction-ray great points about symlinks. I didn't read the original thread closely enough and wasn't sure if the XLA_FLAGS option actually worked, but rereading it looks like it does :) I just added a blurb about that as well. @hawkinsp I didn't mention the unstable aspect, we can always update the README again and it probably won't change for a while at least (right?). But good to know.
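For context, a minimal sketch of the XLA_FLAGS option referred to above, which points jaxlib at a CUDA install outside the default location (the path below is an assumption for illustration):

```python
import os

# Assumed CUDA location; replace with wherever CUDA is actually installed.
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/local/cuda-10.1"

import jax  # must be imported after XLA_FLAGS is set

print(jax.devices())
```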

skye

comment created time in 10 hours

push event skye/jax

Skye Wanderman-Milne

commit sha ff22abaf0cb67c5a4e2bb6fd658eea2d1cc93996

Mention XLA_FLAGS option

view details

push time in 10 hours

pull request comment google/jax

Enable XLA SPMD partitioning by default.

Yeah, I forgot I needed a new jaxlib. I'll mark this as a draft for now.

skye

comment created time in 10 hours

issue comment google/jax

cannot find libdevice

Hi, sorry for the delay on this. I've created a PR with updated installation instructions: https://github.com/google/jax/pull/3190. Please comment if you have any suggestions. We can do even more to address this situation (@hawkinsp suggested bundling libdevice with jaxlib), but hopefully this will help for now.

murphyk

comment created time in 4 days

PR opened google/jax

Update installation directions in README to mention expected CUDA location.

See https://github.com/google/jax/issues/989

+11 -3

0 comment

1 changed file

pr created time in 4 days

create branch skye/jax

branch : cuda_install_readme

created branch time in 4 days

push event google/jax

Jascha Sohl-Dickstein

commit sha 190f88dedebd90e87d73e6b19d591afaf5357b7f

Update Common_Gotchas_in_JAX.ipynb (#3189) typo fix

view details

push time in 4 days

PR merged google/jax

typo fix in Common_Gotchas_in_JAX.ipynb cla: yes
+2 -2

1 comment

1 changed file

Sohl-Dickstein

pr closed time in 4 days

pull request comment google/jax

typo fix in Common_Gotchas_in_JAX.ipynb

Thanks Jascha!

Sohl-Dickstein

comment created time in 4 days

issue comment google/jax

Unexpectedly high grad-of-scan memory usage

You're right that grad causes every (y + z) to be stored. Since the result of f is computed using x * (y + z), it needs to save the (y + z) values to compute the gradient. You can try using the new jax.remat, which causes values needed by the gradient computation to be recomputed instead of stored, thus saving memory. This probably makes sense for a scan like this, where you're creating a large amount of easy-to-compute values. See https://github.com/google/jax/pull/1749 for examples of using remat. I think doing scan(remat(scanned), ...) should work in this case.
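A minimal sketch of the ``scan(remat(...), ...)`` suggestion, with a made-up scanned step and shapes (not code from the issue):

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(x, yz):
    y, z = yz
    # Without remat, (y + z) is saved for the backward pass at every iteration;
    # jax.remat recomputes it instead, trading compute for memory.
    return x, x * (y + z)

ys = jnp.ones(10_000)
zs = jnp.ones(10_000)

def f(x):
    _, outs = lax.scan(jax.remat(step), x, (ys, zs))
    return outs.sum()

print(jax.grad(f)(2.0))  # 20000.0
```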

cc @mattjj who created remat

jeffgortmaker

comment created time in 4 days

issue closed google/jax

how to vectorize custom functions with vmap

Hello everyone, I am trying to vectorize a function which takes 2 parameters: (model_params, x). Both of these parameters are list of numpy arrays.

I would like to vectorize this function on the second parameter. My problem is that the shapes of the arrays in x parameter are not fixed. I tried to fill them with zeros to have the same shape and get rid of the extra values once I am in the function but it didn't work.

I would appreciate if you have any suggestions for me.

Thank you!
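A hedged sketch of one way to make this work: pad every array to a common length, pass the true lengths alongside, and mask the padded entries inside the vmapped function (all names and shapes below are made up):

```python
import jax
import jax.numpy as jnp

def f(model_params, x_padded, length):
    w = model_params["w"]
    # Ignore the zero padding beyond the row's true length.
    mask = jnp.arange(x_padded.shape[0]) < length
    return jnp.sum(jnp.where(mask, x_padded * w, 0.0))

params = {"w": 2.0}
rows = [jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0, 5.0])]
max_len = max(r.shape[0] for r in rows)
x = jnp.stack([jnp.pad(r, (0, max_len - r.shape[0])) for r in rows])
lengths = jnp.array([r.shape[0] for r in rows])

batched = jax.vmap(f, in_axes=(None, 0, 0))
print(batched(params, x, lengths))  # [6. 24.]
```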

closed time in 5 days

cagrikymk

Pull request review comment google/jax

Implement pmap of sharded_jit

 def _xla_callable_device(nreps, backend, device, arg_devices):
     else:
       assert False  # Unreachable given the error check in _xla_callable
 
-def _xla_callable_args(c, avals, tuple_args, replicated=None):
+# Used within _xla_callable_args and _xla_param to distinguish between None (no
+# sharding annotation set) and replicated.
+_replicated_param = object()

Yeahh this kinda sucks. I added this because I need to distinguish between three options for what to do with a single parameter in the non-tuple case: partitioned (tuple of ints), replicated (this thing), no annotation (None). I don't need it elsewhere because everywhere else deals with sequences of values, so no annotation is None and everything else is a sequence. I think replication already has True, False, None as options so it doesn't need this stupid extra option.

If I push the trinary logic into xla_bridge, this might be a little cleaner. But I still will probably need the extra replicated option, and then it'll be everywhere instead of in this one internal place... so I'm not sure. I'm gonna try it in a follow-up change along with the context manager change George suggested.

skye

comment created time in 5 days

Pull request review comment google/jax

Implement pmap of sharded_jit

 def _pval_to_result_handler(axis_size, nrep, pval, devices, backend):
   else:
     if pv is not core.abstract_unit:
       unsharded_aval = ShapedArray((axis_size,) + pv.shape, pv.dtype)
-      sharding_spec = _pmap_sharding_spec(nrep, axis_size, pv, True)
+      sharding_spec = _pmap_sharding_spec(nrep, axis_size, npart, parts, pv,
+                                          True)
       indices = spec_to_indices(unsharded_aval.shape, sharding_spec)
     else:
       sharding_spec = indices = None
       unsharded_aval = pv
     return aval_to_result_handler(sharding_spec, indices, unsharded_aval)
 
-def _pmap_sharding_spec(nrep, axis_size, sharded_aval, mapped):
+def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, mapped):
+  if not mapped and npart > 1:
+    # TODO(skye): ShardingSpec assumes replication is treated as the innermost
+    # axis, but in this case, it needs to be the outer axis.

Added your name to the TODO :) I think two replication factors would work for this case, but not the most general case of in_axes/out_axes=N for N > 1 on pmap. Introducing a new replication factor until we figure out a more general solution seems strictly better than the current situation though, and doesn't seem like too much work (I still have staging around ShardedDeviceArray.__init__ so we don't have to update all callers when we update how sharding works).

skye

comment created time in 5 days

Pull request review comment google/jax

Implement pmap of sharded_jit

 def _xla_callable_device(nreps, backend, device, arg_devices):
     else:
       assert False  # Unreachable given the error check in _xla_callable
 
-def _xla_callable_args(c, avals, tuple_args, replicated=None):
+# Used within _xla_callable_args and _xla_param to distinguish between None (no
+# sharding annotation set) and replicated.
+_replicated_param = object()
+
+def _xla_callable_args(
+    c, avals, tuple_args, replicated=None,
+    partitions: Optional[Sequence[Optional[Tuple[int]]]] = None):

In practice it's always a Tuple, but Sequence works. I changed it.

skye

comment created time in 5 days

Pull request review comment google/jax

Implement pmap of sharded_jit

 def get_num_partitions(*partitions):
 
 class ResultToPopulate(object): pass
 result_to_populate = ResultToPopulate()
 
-def _pvals_to_results_handler(size, nrep, out_pvals, devices, backend):
+def _pvals_to_results_handler(
+    size, nrep, npart,
+    out_parts: Optional[Tuple[PartitionsOrReplicated, ...]],
+    out_pvals, devices, backend):
   nouts = len(out_pvals)
-  handlers = [_pval_to_result_handler(size, nrep, pval, devices, backend)
-              for pval in out_pvals]
+  if out_parts is None:
+    out_parts = (None,) * len(out_pvals)
+  handlers = [
+      _pval_to_result_handler(size, nrep, npart, parts, pval, devices, backend)
+      for pval, parts in safe_zip(out_pvals, out_parts)
+  ]
+
   def handler(out_bufs):
-    buffers = [[result_to_populate] * nrep for _ in range(nouts)]
+    assert nrep * npart == len(out_bufs)
+    buffers = [[result_to_populate] * nrep * npart for _ in range(nouts)]

Perhaps! I have a microbenchmark for this path in pmap_benchmark.py, we could try it out. I'm gonna leave the logic as-is for this change though.

skye

comment created time in 5 days

Pull request review comment google/jax

Implement pmap of sharded_jit

 def dynamic_fun(dummy, *args):
 
 multi_host_supported_collectives: Set[core.Primitive] = set()
 
+PartitionsOrReplicated = Optional[Tuple[int, ...]]
+
+def _find_partitions(jaxpr) -> Tuple[
+    Optional[Tuple[PartitionsOrReplicated, ...]],
+    Optional[Tuple[PartitionsOrReplicated, ...]],
+    int]:
+  """Returns (in_partitions, out_partitions, num_partitions)."""
+  for eqn in jaxpr.eqns:
+    if eqn.primitive.name == "sharded_call":

It would be good to raise an error in this situation though. This is currently a sharp edge in general with sharded_jit, since it doesn't work under other transformations (e.g. jit), but it won't stop you from trying. I'd like to improve this situation, either by figuring out how composition should work or at least figuring out how to raise good errors, but not in this change :)

skye

comment created time in 5 days

Pull request review comment google/jax

Implement pmap of sharded_jit

 def dynamic_fun(dummy, *args):
                  for d in xb.local_devices(host_id)]
     else:
       devices = xb.get_backend(backend).get_default_device_assignment(
-          num_global_replicas)
+          num_global_replicas, num_partitions)
   else:
-    if num_local_replicas != len(local_devices):
+    if num_local_shards != len(local_devices):
       local_devices_str = ", ".join(map(str, local_devices))
       raise ValueError(
           "Leading axis size of input to pmapped function must equal the "
           "number of local devices passed to pmap. Got axis_size=%d, "
           "num_local_devices=%d.\n(Local devices passed to pmap: %s)"
           % (axis_size, len(local_devices), local_devices_str))
-    if num_global_replicas != len(devices):
-      raise ValueError("compiling computation that requires %s replicas, "
-                       "but %s devices were specified"
-                       % (num_global_replicas, len(devices)))
+    if num_global_shards != len(devices):
+      raise ValueError("compiling computation that creates %s shards, "
+                       "but %s devices were specified" %
+                       (num_global_shards, len(devices)))
 
-  device_assignment = tuple(d.id for d in devices)
+  device_assignment = tree_map(lambda d: d.id, devices)

It may be either 1D or 2D, depending on if we got a 1D devices list/tuple or a 2D assignment from get_default_device_assignment. I added a comment.

skye

comment created time in 5 days

push event skye/jax

Skye Wanderman-Milne

commit sha ecd893626ff890906a3394551f55264b96d7713e

Address comments

view details

push time in 5 days

push event skye/jax

Skye Wanderman-Milne

commit sha d8ede0106a6dcf8f71d97dc5e47636398dcb9124

Update jax/interpreters/pxla.py Co-authored-by: James Bradbury <jekbradbury@google.com>

view details

push time in 5 days

Pull request review comment google/jax

Implement pmap of sharded_jit

 def dynamic_fun(dummy, *args):
     # XLA computation at all; we handle this as a special case so we can stage
     # out multi-replica XLA computations regardless of the hardware available.
     # The 'None' values here are just dummies we know will be ignored.
-    handlers = [_pval_to_result_handler(axis_size, None, pval, local_devices,
-                                        backend)
-                for pval in out_pvals]
+    handlers = [
+        _pval_to_result_handler(axis_size, None, None, None, pval, local_devices,
+                                backend) for pval in out_pvals
+    ]
     results = [handler(None) for handler in handlers]
     return lambda *_: results
 
   jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
   num_local_replicas = axis_size * jaxpr_replicas

No this part works! It should just be doing something equivalent to this diff hunk: https://github.com/google/jax/pull/3107/files#diff-6609803133099eb146acd9d3128dccc3R624-R632, but ideally in a way that still raises the original error when appropriate. I think we just need to move that logic to after we've traced the jaxpr.

skye

comment created time in 5 days

Pull request review comment google/jax

Implement pmap of sharded_jit

 def dynamic_fun(dummy, *args):
     # XLA computation at all; we handle this as a special case so we can stage
     # out multi-replica XLA computations regardless of the hardware available.
     # The 'None' values here are just dummies we know will be ignored.
-    handlers = [_pval_to_result_handler(axis_size, None, pval, local_devices,
-                                        backend)
-                for pval in out_pvals]
+    handlers = [
+        _pval_to_result_handler(axis_size, None, None, None, pval, local_devices,
+                                backend) for pval in out_pvals
+    ]
     results = [handler(None) for handler in handlers]
     return lambda *_: results
 
   jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
   num_local_replicas = axis_size * jaxpr_replicas
   num_global_replicas = global_axis_size * jaxpr_replicas
+  arg_parts, out_parts, num_partitions = _find_partitions(jaxpr)
+
+  num_local_shards = num_local_replicas * num_partitions
+  num_global_shards = num_global_replicas * num_partitions
+
   axis_env = xla.AxisEnv(num_global_replicas, (axis_name,), (global_axis_size,), devices)
 
   tuple_args = len(sharded_avals) > 100  # pass long arg lists as tuple for TPU
 
   c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
   xla_consts = _map(partial(xb.constant, c), consts)
   replicated = [not m for m in mapped_invars]
-  xla_args = xla._xla_callable_args(c, sharded_avals, tuple_args, replicated)
+  xla_args = xla._xla_callable_args(c, sharded_avals, tuple_args, replicated,
+                                    arg_parts)
   out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts,
                                 extend_name_stack(wrap_name(name, 'pmap')), *xla_args)
-  built = c.build(xops.Tuple(c, out_nodes))
+  build_out_tuple = partial(xops.Tuple, c, out_nodes)
+  if out_parts is not None:
+    out_tuple = xb.with_sharding(c, out_parts, build_out_tuple)

Context manager seems like a good idea, and I might also push the to-shard-or-not-to-shard logic into it, since it's pretty awkward. I'm gonna do this as a follow-up PR, I filed https://github.com/google/jax/issues/3183 for now.

skye

comment created time in 5 days

issue opened google/jax

Change xla_bridge.with_sharding to be a context manager

See https://github.com/google/jax/pull/3144#discussion_r427283109

created time in 5 days

push event google/jax

Skye Wanderman-Milne

commit sha a3e0cd1293236dca6843740ceebda0423c8998f1

Fix pxla.shard_args bug (#3170)

view details

push time in 5 days

PR merged google/jax

Fix pxla.shard_args bug cla: yes
+5 -1

0 comment

2 changed files

skye

pr closed time in 5 days

created tag google/jax

tag jax-v0.1.68

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

created time in 6 days

push event google/jax

Skye Wanderman-Milne

commit sha 4cbd14ca89efca7377894c74607342b848bdb644

Update jax version to 0.1.68 (#3181)

view details

push time in 6 days

PR merged google/jax

Update jax version to 0.1.68 cla: yes
+8 -3

0 comment

2 changed files

skye

pr closed time in 6 days

PR opened google/jax

Update jax version to 0.1.68
+8 -3

0 comment

2 changed files

pr created time in 6 days

push event skye/jax

Skye Wanderman-Milne

commit sha ef4debcaad5a5ac5182899e385f45ca64f5ce600

Update jax version to 0.1.67 (#3065)

view details

Jake Vanderplas

commit sha 96fbfeee7c53aefa9c19c4e71b4417f0676d7637

Add lax implementation of np.trapz (#3042)

view details

Peter Hawkins

commit sha 88d3422a5cedfbd74d9ce3039b79b8f0dcb02222

Add special case for integer scalars to jax.numpy.power. (#3066) * Add special case for integer scalars to jax.numpy.power.

view details

James Bradbury

commit sha 4a84b91304d78633d571ab706a031cd9b1a5de7a

lower in_axes=None to XLA replication annotation (#3025) * lower in_axes=None to XLA replication annotation * ignore replicated value for tokens

view details

Skye Wanderman-Milne

commit sha 11760ca9344f252a4cc61933ec1e7ff9cf9f1cde

Refactor aval_to_result_handler to take unsharded aval. (#3067) This is in preparation for calling it from the sharded_jit. Currently aval_to_result_handler is specific to pmap, but this change makes it work for any kind of sharding.

view details

Tom Hennigan

commit sha abdf504e9e52c303fe4b4d3938661085f7bcb323

Avoid recompilation of rolled loops in threefry2x32. (#3069)

view details

Peter Hawkins

commit sha cd966f28ed652f61d960a5b2118ddf5ee79d9534

Disable check_type for trapz test due to test failures. (#3071)

view details

joao guilherme

commit sha d2f84d635bedaa5a415339f9ffe7f19ba5d56855

Change instances of onp to np and np to jnp (#3044)

view details

Sharad Vikram

commit sha e9d33946d61be55ea0e774f9c3a019aee1c65a0c

Make jnp.array convert empty list to DeviceArray (#3049) * Make jnp.array convert empty list to DeviceArray * Add additional tests for empty classes with __array__ Co-authored-by: Peter Hawkins <phawkins@google.com>

view details

Peter Hawkins

commit sha 91d1e0ddbd06360a5295a0e8a68ba055c57009e9

Disable trapz test on TPU. (#3078)

view details

Peter Hawkins

commit sha 22d14fd7ddd317e828324de358b51925c74212ea

Remove workaround for Mac linear algebra bug that is fixed in the minimum jaxlib version. (#3080)

view details

Jake Vanderplas

commit sha 59aab01629379c1339acb7a985d1d64e5c001b08

Implement .view() method of jax.numpy arrays (#3073)

view details

Jake Vanderplas

commit sha 777636af4a74ff602de05abb7f49e8fcdecd5663

promote integer inputs to float in jnp.median() and jnp.quantile() (#3082)

view details

Jake Vanderplas

commit sha 6bd01602191fdd9f261ea46e169b70b636e857d2

disable arr.view() test on TPU (#3089) * disable arr.view() test on TPU * Update lax_numpy_test.py Use decorator. * Update lax_numpy_test.py Co-authored-by: Peter Hawkins <hawkinsp@cs.stanford.edu>

view details

Skye Wanderman-Milne

commit sha 16cf84514862b73a191a02eb551028eb2b16c700

Update jaxlib version to 0.1.47 in README

view details

Roy Frostig

commit sha 92176418700127d70433c2113588e04833116902

take a single operand in lax.cond and deprecate the old calling convention

view details
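As a hedged sketch of the new calling convention named in the commit above (the function is made up), ``lax.cond`` takes a single operand that is passed to whichever branch is selected:

```python
import jax.numpy as jnp
from jax import lax

def safe_reciprocal(x):
    # Single-operand form: both branches receive the same operand.
    return lax.cond(x != 0, lambda v: 1.0 / v, lambda v: v * 0.0, x)

print(safe_reciprocal(jnp.float32(4.0)))  # 0.25
```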

Roy Frostig

commit sha 28e698ed72c28b6a46997225cd0319112eac5ca9

update uses of cond in lax control flow tests

view details

Roy Frostig

commit sha 97738be44febb56735f81584bc90659334feeeb8

bind a single-operand cond primitive and update jaxpr typechecks

view details

Roy Frostig

commit sha 027a53900099bdf73e4d78b549c9f40e9ee9a7d4

translation rule for single-operand cond primitive

view details

Roy Frostig

commit sha fc4ab77bc61373298cb15f1aaa0e3396a40d244e

merge constvars when forming cond branch jaxprs

view details

push time in 6 days

push event google/jax

Skye Wanderman-Milne

commit sha 12f26d3c8c6b3e020e17024d3cbd39b62fd631bb

Improve ``devices`` and related documentation (#3155)

view details

push time in 6 days

PR merged google/jax

Improve ``devices`` and related documentation cla: yes

Fixes #2898

+37 -13

0 comment

2 changed files

skye

pr closed time in 6 days

issue closed google/jax

jax.devices returns only the devices from the default backend

The documentation of jax.devices says that it returns all devices. This is not true in multi-backend setups: it returns only the devices from the default backend. This was confusing to me initially, and I also encountered this confusion in issue #2785.

The simplest change would be to the documentation: if no backend is specified, only the devices on the default backend are returned (along with an explanation of what the default backend is).

A better change, though, may be to say that it returns all devices, in the order of backend priority, such that devices()[0] is the same as now. This may break some code though.
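A small sketch of the behaviour under discussion (output depends on the machine); ``jax.devices`` accepts an optional backend name, and ``jax.local_devices`` lists devices attached to this host:

```python
import jax

print(jax.devices())        # devices of the default backend only
print(jax.devices("cpu"))   # devices of an explicitly named backend
print(jax.local_devices())  # devices local to this host
```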

closed time in 6 days

gnecula

PR opened google/jax

Fix pxla.shard_args bug
+5 -1

0 comment

2 changed files

pr created time in 6 days

push event skye/jax

Skye Wanderman-Milne

commit sha e6e22280710f1d6884e2a007d1056c6a63e646b0

Fix pxla.shard_args bug

view details

push time in 6 days

push event skye/jax

Daniel Johnson

commit sha 98d46d39aa9138fbd8143411b99dc42ef0a36ad3

Implement indexing helpers

view details

Daniel Johnson

commit sha 72593f0489315cefe685f36f6c89ae8b81dfc860

Modify syntax to `x.at[idx].set(y)` and similar.

view details

Daniel Johnson

commit sha ab0a903172fdff74287b047ae7e0e26b6fc1c0fc

Merge branch 'master' into index_sugar

view details

Daniel Johnson

commit sha 9e429907c1236e600473d2ff31008cb81e9d1395

Add support for `mul`

view details

Daniel Johnson

commit sha f4f67b065f5b4663474df8e91785524101387a36

Remove unused textwrap

view details

Skye Wanderman-Milne

commit sha 7a61aea530f49fa17f813c5bfb72c1b491269bfd

Add type hint to fix pytype error. (#2727) Without this, pytype (correctly) points out that AbstractValues do not have shape/type information.

view details

Jake Vanderplas

commit sha f9a3aed0b61971b4306f081fd39114d2ea578092

Implement numpy.linalg.multi_dot (#2726) * Implement numpy.linalg.multi_dot * Thread precision through multi_dot

view details

Peter Hawkins

commit sha 14618d1af49a53e6eaecf2207effb36646a21e4a

Update XLA. (#2733)

view details

Skye Wanderman-Milne

commit sha fe8d1d1b4a10d0485ef9ed47ff0c2eaf6fca3547

Temporarily make ShardedDeviceArray.__init__ optionally accept old si… (#2730) This allows us to incrementally update ShardedDeviceArray creators to the new constructor introduced in https://github.com/google/jax/commit/07571ae4dd3fceee580aa49c4490f99ce7f6b6de.

view details

Skye Wanderman-Milne

commit sha 0e29bd4ba3ee8996f0362393dad3baf39c803dc5

Fix some bugs in _reshape_sharded_device_array (#2732)

view details

Chris Jones

commit sha 40884dbd303879cd893b4c734d201c643f2a000b

Fix copy-paste error It looks as though `_device_put_scalar` should be used here. If not, `_device_put_scalar` should be removed, as it is otherwise unused.

view details

Matthew Johnson

commit sha 61c76d45beea467e4e438434ef12c1a9e5d86f66

Merge pull request #2738 from chr1sj0nes/patch-1 Actually use xla._device_put_scalar for device-putting scalars

view details

Peter Hawkins

commit sha d356c41f77401fb796a763f19beda65f86300f50

Release jaxlib 0.1.44. (#2740)

view details

Jamie Townsend

commit sha 708107ebe30826d428b56cebe553b0e9016a0c80

Add numpy.rint to lax numpy (#2724) * Add numpy.rint to lax numpy * Use round_to_nearest_even for numpy.rint * Add rint to jax.numpy docs * Fix np.rint float promotion

view details

Chris Jones

commit sha 903b50ebf102b3152d60a7cc0e465bdeeee2c8f5

Some cleanup and reformatting in `xla.py`. - Make creation of a few dictionaries more readable. - Use f-strings where possible. - Remove unused imports and function parameters. - Don't format string before passing to `log` function.

view details

Chris Jones

commit sha f7070e56d19b0cbdac7d6f2474bae6dde673d612

Merge branch 'master' into changelist/306845248

view details

Peter Hawkins

commit sha 42887172b04c8933d628d5101221d38b4815114f

Add a regression test that runs the same computation on all devices that are present. (#2741)

view details

Roy Frostig

commit sha d906c89e0bf39f6b41d360db27b1c70b9b29fdad

fix scipy_signal_test convolve failures

view details

Matthew Johnson

commit sha 6d889efd30d8c1b13eceb2493e0cb3b726fbec0a

Merge pull request #2743 from chr1sj0nes/changelist/306845248 Some cleanup and reformatting in `xla.py`.

view details

Daniel Johnson

commit sha a5efe842562af81cae4df7aca95e52981dea2802

Update names and documentation.

view details

push time in 6 days

push event skye/jax

Peter Hawkins

commit sha 0557248fbd983ab5f86707b3d71fc62a05d028a4

Check for unsupported dtypes and issue a helpful error. (#2885)

view details

James Bradbury

commit sha 43efbe2db89ceed0bee36fd0d876962394201722

Reset parameter replication default (#2880) * Reset parameter replication default * add tests

view details

Peter Hawkins

commit sha 3e87e8f9608338b9a44f85a8164e02f46a801968

Add relu6, hard_swish, and hard_sigmoid to docs. (#2886)

view details

Peter Hawkins

commit sha b8cbc9583a01bfe69b03c751950e763a40f6416d

Fix lax_reference implementation of round() to match lax. (#2894) lax.round() is documented to round half away from zero, but np.round() rounds to nearest even.

view details

Peter Hawkins

commit sha c8d1700bd3f34feae7d7dcc4d19b186190d7fcb7

Make sure gather/scatter indices in lax gradient tests aren't out of bounds. (#2895) Out-of-bounds gathers are clamped to be in bounds, but out-of-bounds scatters are dropped entirely. This can cause gradient tests to fail because the two operations aren't duals of one another, as the gradient rules expect.

view details

Jacob Kelly

commit sha 1f7ebabfc8445cba9b26f46275bbac61b8bce151

add jets for sines fns (#2892) refactor remove duplicate

view details

George Necula

commit sha b39da1f842f1363b8b36052c0837407de0be9c2d

Fix jit with device placement (#2883) In setups with multiple backends, a jit happens on the default backend, unless we give a `backend` parameter. This is true even if the inputs are committed to a device on the non-default backend, or if we pass a `device` parameter to jit.

view details

George Necula

commit sha 8d4b6857adbd52c3093f034b35c23a6dc2e20492

Fix typo in tests; caught on GPU and TPU (#2902)

view details

Roy Frostig

commit sha 3216f5ca4647a00a5a53663846a56dc0faf1fc01

err on empty operand in numpy argmin and argmax fixes #2899

view details

Skye Wanderman-Milne

commit sha 815a92e411649cb74023087d38bafffed1841003

Remove assert from ShardedDeviceArray staging. (#2908) This would erroneously fail on Cloud TPU because the TPU client has its own buffer type.

view details

Skye Wanderman-Milne

commit sha 3aa953d8d40ec407036a49d0de2195ce609c8d6c

Update jax version to 0.1.65 (#2909)

view details

Matthew Johnson

commit sha bb608339b5323fffb257597ee86ba657accb1b96

update changelog

view details

Roy Frostig

commit sha a4deae392ef8e2a901016f3445e5714b97019fe2

err on empty operand dimension in numpy argmin and argmax see #2899

view details

Matthew Johnson

commit sha e06bde8cc0b1bef42f38f9b4a018ecb2daaab1cf

revise xla.device_put device logic (#2907) * revise xla.device_put device logic, fixes #2905 * remove test of behavior we don't want Previously, we were testing that for a DeviceArray x, writing jax.device_put(x) would evaluate to a DeviceArray *on the default device*. Instead, we should be happy with just returning the same DeviceArray without any movement.

view details

George Necula

commit sha 2e9047d3884451294874fd485209d3218ee968fb

Add flag to enable checking, and turn on checking in tests. (#2900) Fix an error in check_jaxpr.

view details

George Necula

commit sha ac023bf28fc368eebe70ab99ef691fc831422d75

Fixed a few places where device sticky-ness was lost. Added FAQ (#2882) * Fixed a few places where device sitckyness was lost. Added FAQ for device placement. I have also added a new test (multi_device_test.test_computation_follows_data), written more as part of the documentation. It is shorted than the old test_computation_follows_data (which is still there, renamed as test_computation_follows_data_old). I believe there is no extra coverage in test_computation_follows_data_old w.r.t. all the other tests we have. * Fix mypy annotations and updates based on comments * Undid some changes, will make another PR

view details

Peter Hawkins

commit sha 25e8280d8a56c48d7b6dc360aa7b4005ab1ee79c

Relax some test tolerances. (#2917)

view details

James Bradbury

commit sha 279a077c04106a7fd39cb55e24dc7c11bf5ffc15

Avoid tuple allreduce lowering of psum on TPUs (#2914) Tuple-shaped allreduces aren't supported in an XLA:TPU optimization pass (see internal bug), but since our use of them on GPU is due to compiler nondeterminism that isn't present on TPU, it should be fine to avoid this bug by disabling tuple psum on TPU.

view details

Tom Hennigan

commit sha 0736679c331fa2dd53abea4f8dd41ca7db3d0978

Explicitly broadcast values in nn.one_hot and nn.initializers.orthogonal. (#2901) At head the following fails:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```

view details

Peter Hawkins

commit sha 1b5642880ee4344ca88948c560a96bf525f9c28b

Fix test flakiness in autodiff tests for min/max type functions (#2918) * Fix test flakiness in autodiff tests for clamp, reduce, and reduce-window. We change the tests to avoid computing numerical gradients in the neighborhood of nondifferentiable points where, for example, the maximum element in a reduce-max changes. The autodiff approximation is only valid within an epsilon ball around a point, and close to an inflection point the approximation may not be valid. * Only test reduce-grad-mul for float types.

view details

push time in 6 days

push event skye/jax

Skye Wanderman-Milne

commit sha 1901691c8ae5900f81bcb32ecbe79aa8fe3ca5e1

Address comments

view details

push time in 6 days

Pull request review comment google/jax

Improve ``devices`` and related documentation

 def devices(backend=None):
   return get_backend(backend).devices()
 
-def local_devices(host_id=None, backend=None):
-  """Returns a list of devices local to a given host (this host by default)."""
+def local_devices(host_id: int = None, backend: str = None):
+  """Like ``devices``, but only returns devices local to a given host.
+
+  If ``host_id`` is ``None``, returns devices local to this host.

I just tried it. A bad backend raises an error, but a bad host_id currently returns an empty list. I made it raise an error and added a small test.

skye

comment created time in 7 days

Pull request review comment google/jax

Improve ``devices`` and related documentation

 def device_count(backend=None):
   return int(get_backend(backend).device_count())
 
-def local_device_count(backend=None):
+def local_device_count(backend: str =None):
   """Returns the number of devices on this host."""
   return int(get_backend(backend).local_device_count())
 
-def devices(backend=None):
-  """Returns a list of all devices.
+def devices(backend: str = None):
+  """Returns a list of all devices for a given backend.
 
-  Each device is represented by a subclass of Device (e.g. CpuDevice,
-  GpuDevice). The length of the returned list is equal to
-  ``device_count()``. Local devices can be identified by comparing
+  Each device is represented by a subclass of ``Device`` (e.g. ``CpuDevice``,
+  ``GpuDevice``). The length of the returned list is equal to
+  ``device_count(backend)``. Local devices can be identified by comparing
   ``Device.host_id`` to ``host_id()``.
 
+  If ``backend`` is ``None``, returns all the devices from the default backend.

Done.

skye

comment created time in 7 days

Pull request review comment google/jax

Improve ``devices`` and related documentation

 def device_count(backend=None):
   return int(get_backend(backend).device_count())
 
-def local_device_count(backend=None):
+def local_device_count(backend: str =None):
   """Returns the number of devices on this host."""
   return int(get_backend(backend).local_device_count())
 
-def devices(backend=None):
-  """Returns a list of all devices.
+def devices(backend: str = None):
+  """Returns a list of all devices for a given backend.
 
-  Each device is represented by a subclass of Device (e.g. CpuDevice,
-  GpuDevice). The length of the returned list is equal to
-  ``device_count()``. Local devices can be identified by comparing
+  Each device is represented by a subclass of ``Device`` (e.g. ``CpuDevice``,

It should be the first local device, but I'm not 100% sure that's true on TPU pods (we use the default XLA device assignment, which in theory can do whatever it wants). We should probably explicitly set a device in jax in the default single-device case, so I'm gonna leave this out for now.

skye

comment created time in 7 days

issue comment google/jax

Jitting and memory leaks

I somehow clicked close and comment while I was typing :)

xwinxu

comment created time in 7 days


issue closed google/jax

Jitting and memory leaks

I am using a computing cluster to run my jobs. I notice that Slurm keeps killing my jobs within seconds of submission due to out of memory errors, which I believe is because JAX is asking for more memory than allocated since GPU usage seems to increase, potentially due to a memory leak. I am running cuda10.1.

This is the stack trace I get:

RuntimeError: Internal: Unable to launch cuBLAS gemm on stream 0x562d67df53e0
2020-05-18 19:33:00.557492: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:940] could not synchronize on CUDA context: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered :: *** Begin stack trace ***
        _PyDict_DelItem_KnownHash
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_FastCallDict
        PyObject_CallFunctionObjArgs
        PyObject_ClearWeakRefs
        PyDict_SetItem
        PyDict_SetItemString
        PyImport_Cleanup
        Py_FinalizeEx
        _Py_UnixMain
        __libc_start_main
*** End stack trace ***
2020-05-18 19:33:00.557560: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_executable.cc:88] Check failed: pair.first->SynchronizeAllActivity() 
slurmstepd: error: *** JOB 505887 ON gpu117 CANCELLED AT 2020-05-18T19:33:01 ***

I also get this sometimes as in this issue:

 File "...py", line 126, in apply_fun
    outputs = sdeint(augmented_drift, augmented_diffusion, aug_init(inputs), np.linspace(0, 1, 100), rng)
  File "...py", line 49, in sdeint
    _, states = lax.scan(f_scan, dict(curr_state=s_init), xs)
  File "...anaconda/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 955, in scan
    linear=(False,) * (len(consts) + len(in_flat)))
  File "...anaconda/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 1337, in scan_bind
    num_carry=num_carry, linear=linear)
  File ".../anaconda/lib/python3.7/site-packages/jax/core.py", line 214, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "...anaconda/lib/python3.7/site-packages/jax/interpreters/batching.py", line 134, in process_primitive
    val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
  File "...anaconda/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 1278, in _scan_batching_rule
    num_consts=num_consts, num_carry=num_carry, linear=linear)
  File "...anaconda/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 1337, in scan_bind
    num_carry=num_carry, linear=linear)
  File "...anaconda/lib/python3.7/site-packages/jax/core.py", line 211, in bind
    return self.impl(*args, **kwargs)
  File "...anaconda/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 981, in _scan_impl
    _, *outs = while_loop(cond_fun, body_fun, init_val)
  File "...anaconda/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 246, in while_loop
    body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
  File "...anaconda/lib/python3.7/site-packages/jax/core.py", line 211, in bind
    return self.impl(*args, **kwargs)
  File ".../anaconda/lib/python3.7/site-packages/jax/interpreters/xla.py", line 218, in apply_primitive
    return compiled_fun(*args)
  File "...anaconda/lib/python3.7/site-packages/jax/interpreters/xla.py", line 315, in _execute_compiled_primitive
    out_bufs = compiled.execute(input_bufs)
RuntimeError: Resource exhausted: Out of memory while trying to allocate 208652400 bytes.

I'm wondering where the inefficiencies are in my code if this may be the issue. I have one jit in my update function, which calls value_and_grad and updates the optimizer (Adam). And I use lax.scan to replace my for loops, which are called when making a prediction. The update function is the most top-level, so I have not jitted anything else. I've also tried calling xla._xla_callable.cache_clear() (from this issue) to no avail. This error happens quite inconsistently across my jobs and I'd appreciate any guidance in investigating the issue.

Thanks for your time!

closed time in 7 days

xwinxu

issue comment google/jax

Jitting and memory leaks

That first stack is quite troubling! I'm guessing this can happen in low-memory conditions, but it's not a very friendly error message... if we resolve the OOM issue, this error will probably go away too.

Did you switch to scan from a regular Python loop? That may have resulted in more memory being used, since it's presumably working with larger arrays. From your stack trace, it looks like the scan isn't running under a jit. If you can, it might help to jit
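A minimal sketch of the suggestion to run the scan under jit by jitting the function that calls it (the function and shapes below are made up, not from the issue):

```python
import jax
import jax.numpy as jnp
from jax import lax

def cumulative_model(init, xs):
    def step(carry, x):
        carry = carry + x
        return carry, carry
    _, outs = lax.scan(step, init, xs)
    return outs

# Jitting the caller lets XLA compile the whole scan once, rather than
# dispatching each piece from Python.
fast = jax.jit(cumulative_model)
print(fast(0.0, jnp.arange(5.0)))  # [0. 1. 3. 6. 10.]
```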

xwinxu

comment created time in 7 days

PR opened google/jax

Improve ``devices`` and related documentation

Fixes #2898

+26 -12

0 comment

1 changed file

pr created time in 7 days

create branch skye/jax

branch : devices_doc

created branch time in 7 days

PR opened google/jax

Enable XLA SPMD partitioning by default.

This option is needed for sharded_jit. Future APIs may use MPMD partitioning instead.

+6 -1

0 comment

1 changed file

pr created time in 7 days

create branch skye/jax

branch : spmd

created branch time in 7 days

issue comment google/jax

This is a bug in JAX's shape-checking rules; please report it!

@lucasb-eyer I'm surprised you're getting exactly the same error with an all-gather, since the error specifically talks about add: Expected element type in shape to be arithmetic type for operation add; got PRED. I haven't had a chance to dig into this myself with your code yet though.

Either way, I was able to repro this error with a small psum(bool) unit test, so there's definitely an issue there. I was gonna start by fixing the unit test, then rerunning your code with the fix and see what happens. It's also possible the unit test will expose the more general issue if there is one.

lucasb-eyer

comment created time in 7 days

pull request comment google/jax

initial draft of hard coded docstrings

I think what I wrote in my above comment is still true. Re: Jake's point about numpy versions, I would guess having an accurate version of older numpy docs would usually be better than the updated numpy version (especially since I'm guessing most updates to numpy documentation are to introduce new features that jax doesn't support).

It sounds like everyone else prefers a note at the top though, so let's do that, and we can revisit if people are still confused. It looks like _wraps already supports this (link), so it's just a matter of adding the note.

joaogui1

comment created time in 7 days

push event skye/jax

Skye Wanderman-Milne

commit sha 888c9c77b3daee21b2fb2ce20657acc57da52b7a

Implement pmap of sharded_jit

view details

push time in 8 days

push event skye/jax

Skye Wanderman-Milne

commit sha 083cdd384141d5d67cb1a39f1039fbc9a296976d

Fix in pxla._inner_partitions (#3146) In cb77f2a22de49e85da93f43b7dc448aa238d5207, I switched to looking for sharding_constraint_p's name since sharding_constraint_p itself is defined in sharded_jit.py, but didn't quite get the update right.

view details

Skye Wanderman-Milne

commit sha d53bab930651dab87c19fbcdb7ce021c4336d2d0

Improve sharded_jit error message and fix test (#3145)

view details

Skye Wanderman-Milne

commit sha aeda1bddd3c99fd09bb65dc13fc5da2c36392a5e

Implement pmap of sharded_jit

view details

push time in 8 days

push event google/jax

Skye Wanderman-Milne

commit sha d53bab930651dab87c19fbcdb7ce021c4336d2d0

Improve sharded_jit error message and fix test (#3145)

view details

push time in 8 days

PR merged google/jax

Improve sharded_jit error message and fix test cla: yes
+30 -2

0 comment

2 changed files

skye

pr closed time in 8 days

push event skye/jax

Skye Wanderman-Milne

commit sha e87d0e7d25ab1d6289eb9dd4d06b02f5b076e6d1

Make mypy happy

view details

push time in 8 days

push event google/jax

Skye Wanderman-Milne

commit sha 083cdd384141d5d67cb1a39f1039fbc9a296976d

Fix in pxla._inner_partitions (#3146) In cb77f2a22de49e85da93f43b7dc448aa238d5207, I switched to looking for sharding_constraint_p's name since sharding_constraint_p itself is defined in sharded_jit.py, but didn't quite get the update right.

view details

push time in 8 days

PR merged google/jax

Fix in pxla._inner_partitions cla: yes

In cb77f2a22de49e85da93f43b7dc448aa238d5207, I switched to looking for sharding_constraint_p's name since sharding_constraint_p itself is defined in sharded_jit.py, but didn't quite get the update right.

+1 -1

0 comment

1 changed file

skye

pr closed time in 8 days

push event skye/jax

Skye Wanderman-Milne

commit sha 095bb56b5c96568fd9bba950c82e32c033d49d36

Make mypy happy

view details

push time in 8 days

PR opened google/jax

Fix in pxla._inner_partitions

In cb77f2a22de49e85da93f43b7dc448aa238d5207, I switched to looking for sharding_constraint_p's name since sharding_constraint_p itself is defined in sharded_jit.py, but didn't quite get the update right.

+1 -1

0 comment

1 changed file

pr created time in 8 days

create branch skye/jax

branch : sharded_jit_fix

created branch time in 8 days

PR opened google/jax

Improve sharded_jit error message and fix test
+29 -2

0 comment

2 changed files

pr created time in 8 days

push event skye/jax

Skye Wanderman-Milne

commit sha 82f283ace807a4379b2b609f0630e6042abcb65b

Improve sharded_jit error message and fix test

view details

push time in 8 days

push event skye/jax

Skye Wanderman-Milne

commit sha 8b6260b20356adacbab4e6fd060c09a0388a738f

Improve sharded_jit error message and fix test

view details

push time in 8 days

create branch skye/jax

branch : sharded_jit_error

created branch time in 8 days

PR opened google/jax

Implement pmap of sharded_jit
+309 -46

0 comment

5 changed files

pr created time in 8 days

push event skye/jax

Skye Wanderman-Milne

commit sha cb77f2a22de49e85da93f43b7dc448aa238d5207

Move some sharded_jit functionality into pxla.py (#3142) Specifically: * Move `_inner_partitions` * Move `get_num_partitions` * Move and slightly modify the logic for finding and validating inner partitions to a new function, `reconcile_num_partitions` * Move `_partitioned_sharding_spec` and rename to `partitioned_sharding_spec` This is in preparation for enabling pmap-of-sharded_jit, since pmap will need access to this functionality as well.

view details

Skye Wanderman-Milne

commit sha bc2f0c1508baf468adcbbffe9d7a4986656edcd3

Implement pmap of sharded_jit

view details

push time in 8 days

push event google/jax

Skye Wanderman-Milne

commit sha cb77f2a22de49e85da93f43b7dc448aa238d5207

Move some sharded_jit functionality into pxla.py (#3142) Specifically: * Move `_inner_partitions` * Move `get_num_partitions` * Move and slightly modify the logic for finding and validating inner partitions to a new function, `reconcile_num_partitions` * Move `_partitioned_sharding_spec` and rename to `partitioned_sharding_spec` This is in preparation for enabling pmap-of-sharded_jit, since pmap will need access to this functionality as well.

view details

push time in 8 days

PR merged google/jax

Move some sharded_jit functionality into pxla.py cla: yes

Specifically:

  • Move _inner_partitions
  • Move get_num_partitions
  • Move and slightly modify the logic for finding and validating inner partitions to a new function, reconcile_num_partitions
  • Move _partitioned_sharding_spec and rename to partitioned_sharding_spec

This is in preparation for enabling pmap-of-sharded_jit, since pmap will need access to this functionality as well.

+84 -72

1 comment

2 changed files

skye

pr closed time in 8 days

create branch skye/jax

branch : sharded_jit_pmap

created branch time in 8 days

pull request comment google/jax

Move some sharded_jit functionality into pxla.py

I forgot one, I'm gonna assume you're still cool with this change.

skye

comment created time in 8 days

push event skye/jax

Skye Wanderman-Milne

commit sha 7a97f2bd826f5bb125d54ad42406225a25eea4a8

Move _partitioned_sharding_spec too

view details

push time in 8 days

PR opened google/jax

Move some sharded_jit functionality into pxla.py

Specifically:

  • Move _inner_partitions
  • Move get_num_partitions
  • Move and slightly modify the logic for finding and validating inner partitions to a new function, reconcile_num_partitions

This is in preparation for enabling pmap-of-sharded_jit, since pmap will need access to this functionality as well.

+63 -51

0 comment

2 changed files

pr created time in 8 days

push event skye/jax

Sandu Ursu

commit sha a9c1b3865969639a3917c361588d0ab7409b3012

Added link to README (#3139)

view details

Peter Hawkins

commit sha 36e7fad1e2342e8291147d61b8de4f5ab6489315

Add a primitive integer_pow() for values raised to a fixed integer scalar. (#3140) * Add a primitive integer_pow() for values raised to fixed integer scalar. Use integer_pow() in the RHS JVP of div(). Also use it in square() and reciprocal(). Fixes #3136

```
In [1]: from jax import grad, make_jaxpr

In [2]: def inv(x): return 1/x

In [3]: print(grad(grad(grad(grad(grad(grad(inv))))))(4.))
0.043945312

In [4]: make_jaxpr(grad(grad(grad(grad(grad(grad(inv)))))))(4.)
Out[4]:
{ lambda  ; a.
  let b = integer_pow[ y=-7 ] a
      c = mul -6.0 b
      d = mul -120.0 c
  in (d,) }

In [5]:
```

* Use x ** 3 in gelu definition.

view details

Stephan Hoyer

commit sha 460343293fd9b5c7e8953834a019aa9a27c1e102

Fix gradient for jax.scipy.ndimage.map_coordinates (#3110) * Fix gradient for jax.scipy.ndimage.map_coordinates Fixes GH3024 * minor refactor for clarification

view details

Skye Wanderman-Milne

commit sha 9152b760e9a2b9f64f7eff7c58d9ef9a04ae7e53

Add with_sharding_constraint method to be used within sharded_jit. (#3100) See the with_sharding_constraint docstring for a description of what this method does. Depending on how we decide nested sharded_jits should work, an alternative implementation for with_sharding_constraint could be:

```python
def with_sharding_constraint(x, partitions):
  return sharded_jit(lambda x: x, in_parts=partitions, out_parts=partitions)
```

In this case, we could get rid of the with_sharding_constraint primitive, and possibly even the API. This implementation gets the job done for now without committing to a nested sharded_jit behavior, and is also much easier to take the gradient of than sharded_jit.

view details

Skye Wanderman-Milne

commit sha 001f8cdaf8d304b14ad992bf811a41c4be050755

Move some sharded_jit functionality into pxla.py Specifically: * Move `_inner_partitions` * Move `get_num_partitions` * Move and slightly modify the logic for finding and validating inner partitions to a new function, `reconcile_num_partitions` This is in preparation for enabling pmap-of-sharded_jit, since pmap will need access to this functionality as well.

view details

push time in 8 days

create branch skye/jax

branch : sharded_jit_move_to_pxla2

created branch time in 8 days

push event google/jax

Skye Wanderman-Milne

commit sha 9152b760e9a2b9f64f7eff7c58d9ef9a04ae7e53

Add with_sharding_constraint method to be used within sharded_jit. (#3100) See the with_sharding_constraint docstring for a description of what this method does. Depending on how we decide nested sharded_jits should work, an alternative implementation for with_sharding_constraint could be:

```python
def with_sharding_constraint(x, partitions):
  return sharded_jit(lambda x: x, in_parts=partitions, out_parts=partitions)
```

In this case, we could get rid of the with_sharding_constraint primitive, and possibly even the API. This implementation gets the job done for now without committing to a nested sharded_jit behavior, and is also much easier to take the gradient of than sharded_jit.

view details

push time in 8 days

PR merged google/jax

Add with_sharding_constraint method to be used within sharded_jit. cla: yes

See the with_sharding_constraint docstring for a description of what this method does.

Depending on how we decide nested sharded_jits should work, an alternative implementation for with_sharding_constraint could be:

def with_sharding_constraint(x, partitions):
    return sharded_jit(lambda x: x, in_parts=partitions, out_parts=partitions)

In this case, we could get rid of the with_sharding_constraint primitive, and possibly even the API. This implementation gets the job done for now without committing to a nested sharded_jit behavior, and is also much easier to take the gradient of than sharded_jit.

+172 -5

1 comment

2 changed files

skye

pr closed time in 8 days

issue comment google/jax

This is a bug in JAX's shape-checking rules; please report it!

Ok! I actually like that better, it just felt weird to have some lax primitives do casting and others not. But this way is more useful!

lucasb-eyer

comment created time in 8 days

push event skye/jax

Skye Wanderman-Milne

commit sha a611384e819809346f98d2230e68fc374d5fc0c6

fix test

view details

push time in 8 days

push event skye/skye.github.io

Skye Wanderman-Milne

commit sha 84fcabf12e41a97a9c3d319193c2280115a32dd7

mute replay videos

view details

push time in 8 days

issue comment google/jax

This is a bug in JAX's shape-checking rules; please report it!

@jekbradbury is right, in this case a boolean array is being passed to pmean. I think we should improve the error message, but still make it an error and require an explicit cast in this situation. However, there is an argument to be made for providing some kind of implicit casting functionality (either in psum or even a separate API endpoint?), since that's what numpy does, and even though psum is in lax there isn't a numpy equivalent.

Any thoughts? Otherwise I'll just improve the error message for now.
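A hedged sketch of the explicit cast being discussed, with a made-up reduction and shapes:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

# psum of a boolean mask across devices by casting to an arithmetic dtype first.
@partial(jax.pmap, axis_name="i")
def count_true(mask):
    return lax.psum(mask.astype(jnp.int32), axis_name="i")

n = jax.local_device_count()
masks = (jnp.arange(n * 4).reshape(n, 4) % 2) == 0
print(count_true(masks))
```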

lucasb-eyer

comment created time in 8 days

push event skye/jax

Skye Wanderman-Milne

commit sha b480f03a3898a9179af47626570ede11ff96831c

fix regex string

view details

push time in 8 days

push event skye/jax

Skye Wanderman-Milne

commit sha 817500209ce4467ee8c7235f43badc6e894e61c1

fix bad rename

view details

push time in 8 days

push event skye/jax

Roy Frostig

commit sha 92176418700127d70433c2113588e04833116902

take a single operand in lax.cond and deprecate the old calling convention

view details

Roy Frostig

commit sha 28e698ed72c28b6a46997225cd0319112eac5ca9

update uses of cond in lax control flow tests

view details

Roy Frostig

commit sha 97738be44febb56735f81584bc90659334feeeb8

bind a single-operand cond primitive and update jaxpr typechecks

view details

Roy Frostig

commit sha 027a53900099bdf73e4d78b549c9f40e9ee9a7d4

translation rule for single-operand cond primitive

view details

Roy Frostig

commit sha fc4ab77bc61373298cb15f1aaa0e3396a40d244e

merge constvars when forming cond branch jaxprs

view details

Roy Frostig

commit sha 226b5821d7f5787cc06fadde51c1099cdd3c0230

batching rule for single-operand cond primitive

view details

Roy Frostig

commit sha 34d75c2eacfcda7465acb31de317ecb19edf3d82

JVP rule for single-operand cond primitive

view details

Roy Frostig

commit sha 139c2a968705f1821341eb479d65fa181df9dba8

fix cond-with-constants tests to use jax.numpy where needed

view details

Roy Frostig

commit sha 37bb6f0913e05618b590d71295d4395d7fdbdb45

partial evaluation rule for single-operand cond primitive

view details

Roy Frostig

commit sha 30bb5fdd206716a4fae60461101a9746af4092bf

transpose rule for single-operand cond primitive

view details

Roy Frostig

commit sha efc1104cde39508d3b80c3169d1fab25fafede02

have loops module generate same-argument jaxprs for single-operand cond

view details

Roy Frostig

commit sha df622798abf442b1cc28881d6e3abc287ff4a7fe

cache jaxprs formed by cond branch staging

view details

Roy Frostig

commit sha cc6cea200708c4f3848508710e33121ce2367809

update reference jaxpr in cond-related jaxpr test

view details

Roy Frostig

commit sha f90bd4f5104a754ee11286a6436f24e599fb773e

move cond operand to final argument position

view details

Roy Frostig

commit sha de03c99b52723e0e86062258e7144f8a49494c6f

update jaxpr doc and tests with single-operand cond

view details

Roy Frostig

commit sha 76612c62e569438634b01bc2d89f27f899689783

test and fix cond when branch-staging is off

view details

Roy Frostig

commit sha e8f12b6bed513cf2166871e0d656e4751894fd78

remove deprecation warning on five-argument cond

view details

Roy Frostig

commit sha 34efdd98417e7d1597f18191e04c8752ad5a68b5

style changes in lax.cond

view details

Roy Frostig

commit sha 2950eb11c2716a27c4eb034444607a8434f84533

comment on joining staged jaxprs in partial evaluation of conditionals

view details

Roy Frostig

commit sha 6e3bfc339e79bf103a3a8c9f71ed71983aa1e573

comment on joining constants for conditional branch jaxprs

view details

push time in 8 days
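
The cond-related commits in the push above move lax.cond from the old five-argument form to a single-operand convention. A minimal sketch of the two calling conventions, as described by those commit messages (not a definitive reference for the final API):

import jax.numpy as jnp
from jax import lax

x = jnp.array(3.0)

# New single-operand convention: cond(pred, true_fun, false_fun, operand).
out = lax.cond(x > 0, lambda v: v + 1.0, lambda v: v - 1.0, x)

# Old five-argument convention, deprecated by these commits:
#   lax.cond(pred, true_operand, true_fun, false_operand, false_fun)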

Pull request review commentgoogle/jax

Add set_sharding method to be used within sharded_jit.

 def wrapped(*args, **kwargs):
     return tree_unflatten(out_tree(), out)
   return wrapped
+
+
+def _set_sharding_impl(x, partitions):
+  # TODO(skye): can we also prevent this from being called in other
+  # non-sharded_jit contexts? (e.g. pmap, control flow)
+  raise NotImplementedError(
+      "set_sharding() should only be called inside sharded_jit()")
+
+def _set_sharding_translation_rule(c, x_node, partitions):
+  return xb.set_sharding(c, x_node, partitions)
+
+set_sharding_p = core.Primitive("set_sharding")
+set_sharding_p.def_impl(_set_sharding_impl)
+set_sharding_p.def_abstract_eval(lambda x, partitions: x)
+ad.deflinear(set_sharding_p, lambda ct, partitions: (set_sharding(ct, partitions),))
+xla.translations[set_sharding_p] = _set_sharding_translation_rule
+
+def set_sharding(x, partitions: Optional[PartitionSpec]):

Good idea, done!

skye

comment created time in 8 days

push eventskye/jax

Skye Wanderman-Milne

commit sha 22b038ff3073b0c66a3dac1f6c62163f62194bd4

rename set_sharding to with_sharding_constraint

view details

push time in 8 days
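
The commit above renames set_sharding to with_sharding_constraint. A rough usage sketch, mirroring the draft docstring reviewed elsewhere in this feed; the import path, the num_shards keyword, and the exact decorator arguments follow that draft and may not match the final API:

from functools import partial

import jax.numpy as jnp
from jax.interpreters.sharded_jit import (
    sharded_jit, with_sharding_constraint, PartitionSpec)

@partial(sharded_jit, in_parts=None, out_parts=None, num_shards=2)
def f(x):
  y = x + 1
  # Inputs and outputs of `f` are replicated, but `y` is constrained to be
  # split in two along its leading axis across the two partitions.
  y = with_sharding_constraint(y, PartitionSpec(2, 1))
  return y * 2

out = f(jnp.ones((8, 8)))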

push eventskye/jax

Skye Wanderman-Milne

commit sha f0d1852187820504128c81889a6d27edc7319fba

Update jax/interpreters/sharded_jit.py Co-authored-by: James Bradbury <jekbradbury@google.com>

view details

push time in 8 days

push eventgoogle/jax

Skye Wanderman-Milne

commit sha ed0e227ef516cd48c4a7ab2f9709f883faa1e034

Add sharded_jit translation rule. (#3099) This is potentially dangerous, because it lets sharded_jit() be called inside other call primitives (e.g. jit, pmap), which isn't supported yet. I'm adding it now because I'm planning to implement pmap-of-sharded_jit soon, and it will help with testing a set_sharding API I'm also planning to add soon.

view details

push time in 8 days

PR merged google/jax

Add sharded_jit translation rule. cla: yes

This is potentially dangerous, because it lets sharded_jit() be called inside other call primitives (e.g. jit, pmap), which isn't supported yet. I'm adding it now because I'm planning to implement pmap-of-sharded_jit soon, and it will help with testing a set_sharding API I'm also planning to add soon.

+49 -4

0 comment

2 changed files

skye

pr closed time in 8 days

Pull request review commentgoogle/jax

Add set_sharding method to be used within sharded_jit.

 def wrapped(*args, **kwargs):
     return tree_unflatten(out_tree(), out)
   return wrapped
+
+
+def _set_sharding_impl(x, partitions):
+  # TODO(skye): can we also prevent this from being called in other
+  # non-sharded_jit contexts? (e.g. pmap, control flow)
+  raise NotImplementedError(
+      "set_sharding() should only be called inside sharded_jit()")
+
+def _set_sharding_translation_rule(c, x_node, partitions):
+  return xb.set_sharding(c, x_node, partitions)
+
+set_sharding_p = core.Primitive("set_sharding")
+set_sharding_p.def_impl(_set_sharding_impl)
+set_sharding_p.def_abstract_eval(lambda x, partitions: x)
+ad.deflinear(set_sharding_p, lambda ct, partitions: (set_sharding(ct, partitions),))
+xla.translations[set_sharding_p] = _set_sharding_translation_rule
+
+def set_sharding(x, partitions: Optional[PartitionSpec]):
+  """Identity-like function that specifies how ``x`` should be sharded.
+
+  WARNING: this feature is still under active development! It may not work well,
+  and may change without warning!
+
+  This should only be called inside a function transformed by ``sharded_jit``.
+  It refines how ``f`` is sharded. ``partitions`` must correspond to the same
+  number of total partitions dictated by the outer ``sharded_jit`` and any other
+  ``set_sharding`` calls. In the case where only replication has been specified,
+  any ``partitions`` are valid.
+
+  Example usage:
+    @partial(sharded_jit, in_parts=None, out_parts=None, num_shards=2)
+    def f(x):
+      y = x + 1
+      y = set_sharding(y, PartitionSpec(2,1))
+      return y * 2
+
+  In this example, the inputs and outputs of ``f`` will be replicated, but the
+  inner value of ``y`` will be partitioned in half. ``f`` will run on two
+  devices due to the set_sharding call.
+
+  Args:
+    x: Array value
+    partitions: PartitionSpec indicating how ``x`` should be partitioned, or
+      None for replication.
+
+  Returns:
+    A new version of ``x`` with the specified sharding applied.

I added a bunch more verbiage in an attempt to explain this :) Lemme know if you have any suggestions to make it more clear.

skye

comment created time in 8 days

push eventskye/jax

Skye Wanderman-Milne

commit sha b0a8bfca88dcbd1291526067198ab4e0177f0df4

Expand docstring

view details

push time in 8 days

Pull request review commentgoogle/jax

Add set_sharding method to be used within sharded_jit.

 def get_num_partitions(*partitions):
   return num_partitions_set.pop()

+def _inner_partitions(jaxpr, expected_num_parts: Optional[int]):
+  """Returns the total number of partitions from PartitionSpecs inside `jaxpr`.
+
+  Also validates that this number matches `expected_num_parts` if provided.
+  """
+  for eqn in jaxpr.eqns:
+    if eqn.primitive == set_sharding_p:
+      parts = eqn.params["partitions"]
+      nparts = get_num_partitions(parts)
+      if expected_num_parts is None:
+        expected_num_parts = nparts
+      elif nparts != expected_num_parts:
+        # TODO(skye): raise this error as we trace the jaxpr

Agreed we shouldn't bake any of this logic into the core tracing machinery. I like your tentative SPMD tracer idea. I was vaguely imagining using dynamic-scoped tracing and/or expanding pmap's DynamicAxisEnv concept to do final-style tracing for all parallel constructs, but it might be too complicated. I think we'll get more clarity as we expand the current initial-style transforms to compose with each other.

skye

comment created time in 8 days
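
As a side note on the mechanism in the _inner_partitions diff above: it works by walking jaxpr.eqns and inspecting each equation's primitive and params. A standalone sketch of that jaxpr-scanning pattern (the count_eqns helper below is hypothetical, not part of JAX):

import jax
from jax import lax

def count_eqns(jaxpr, prim_name):
  # Walk the jaxpr's equations and count those bound to the named primitive,
  # analogous to how _inner_partitions looks for set_sharding_p equations.
  return sum(1 for eqn in jaxpr.eqns if eqn.primitive.name == prim_name)

jaxpr = jax.make_jaxpr(lambda x: lax.sin(x) + lax.sin(2. * x))(1.0).jaxpr
print(count_eqns(jaxpr, "sin"))  # 2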

push eventskye/jax

Skye Wanderman-Milne

commit sha 1bfa298628c37822ffee4b9eda3298a739a58252

Update jax/interpreters/sharded_jit.py Co-authored-by: James Bradbury <jekbradbury@google.com>

view details

push time in 8 days

Pull request review commentgoogle/jax

Add sharded_jit translation rule.

 def _partitioned_sharding_spec(num_partitions: int,
         replication_factor=1)

+def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,
+                                  in_parts, out_parts_thunk, num_partitions,
+                                  backend, name, call_jaxpr):
+  subc = xc.XlaBuilder(f"sharded_jit_{name}")
+
+  # We assume any extra leading in_nodes are constants and replicate them.
+  num_extra_nodes = len(in_nodes) - len(in_parts)
+  assert num_extra_nodes >= 0
+  in_parts = (None,) * num_extra_nodes + in_parts
+
+  args = []
+  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
+    # We use xb.set_sharding instead of xb.with_sharding because inlined calls
+    # shouldn't have shardings set directly on the inputs or outputs.
+    arg = xb.parameter(subc, i, c.GetShape(n))
+    args.append(xb.set_sharding(subc, arg, sharding))
+
+  out_nodes = xla.jaxpr_subcomp(
+      subc, call_jaxpr, backend, axis_env, (),
+      extend_name_stack(name_stack, wrap_name(name, "sharded_jit")), *args)

True. It doesn't seem very principled or robust to depend on the name strings, but it would work in practice (unless you call a non-pmapped function "pmap" or something, but that's silly). I'm gonna leave it as-is for now, but will keep this in mind as a potential option.

skye

comment created time in 9 days

push eventskye/skye.github.io

Skye Wanderman-Milne

commit sha e4f24185330185f3cde27b1238cffcd5c0eff0b2

fix debugging

view details

push time in 9 days

push eventskye/skye.github.io

Skye Wanderman-Milne

commit sha 9945d83cc4fe92b4a247ac82b523a897953afc15

fixes

view details

push time in 9 days

push eventskye/skye.github.io

Skye Wanderman-Milne

commit sha 467716a5b2c5b1a6dc5d803aa44d732c0f127352

fixes

view details

push time in 9 days

push eventskye/skye.github.io

Skye Wanderman-Milne

commit sha c06e05958eda308eae37816154fc84935b05ba28

fixes

view details

push time in 9 days

push eventskye/skye.github.io

Skye Wanderman-Milne

commit sha c2fd4a26ed67228056def0fe44f2e4a211012b85

fixes

view details

push time in 9 days
