
google/jax 6851

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

skye/community 0

Stores documents used by the TensorFlow developer community

skye/Impala 0

Real-time Query for Hadoop

skye/impala-udf-samples 0

Sample UDF and UDAs for Impala.

skye/jax 0

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

skye/kite 0

Kite SDK

push event google/jax

Skye Wanderman-Milne

commit sha f6d483739a4b02636301206f77bfefbcc35cd6d5

Shrink nested pmap image in demo notebook. (#2293)

view details

push time in 7 hours

PR merged google/jax

Shrink nested pmap image in demo notebook. cla: yes
+1 -1

0 comments

1 changed file

skye

pr closed time in 7 hours

PR opened google/jax

Shrink nested pmap image in demo notebook.
+1 -1

0 comments

1 changed file

pr created time in 7 hours

push event skye/jax

Stephan Hoyer

commit sha 218a1711d2f0d8624f606a2e259f9e8d1a0fc374

Add a jit around lax_linalg.lu_pivots_to_permutation (#2277) I think this is almost always called inside a jit already, but adding this results in more interpretable JAXprs.

view details

brett koonce

commit sha 8372a70079698f3d602a1c52d5123d390da2010b

tweak readme pmap imports (#2276)

view details

Peter Hawkins

commit sha af0967fdbf1960d4f830c888103aa8624479c23d

Add an experimental lax.top_k operator. (#2280)

view details
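
For context, a minimal usage sketch of the `lax.top_k` operator this commit mentions; the array values are illustrative, not from the commit:

```
import jax.numpy as jnp
from jax import lax

x = jnp.array([3., 1., 4., 1., 5.])
values, indices = lax.top_k(x, 2)  # two largest entries and their positions
# values -> [5., 4.], indices -> [4, 2]
```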

Stephan Hoyer

commit sha 8c3e3b2dae0e4e9b7f6064499317349da3e57a70

Always jit scipy.ndimage.map_coordinates (#2286) Fixes GH2282

view details
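
For context, a minimal sketch of calling the function this commit now always jits; the sample image and coordinates are made up for illustration:

```
import jax.numpy as jnp
from jax.scipy import ndimage

image = jnp.arange(16.0).reshape(4, 4)
rows = jnp.array([0.5, 2.0])   # fractional row coordinates
cols = jnp.array([1.5, 3.0])   # fractional column coordinates
# order=1 requests linear interpolation at the given (row, col) points.
samples = ndimage.map_coordinates(image, [rows, cols], order=1)
```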

Peter Hawkins

commit sha 80abdf0c5307f4b917c281428009e32c66f9f1a9

Unbreak build and update XLA. (#2289)
* raise minimum Bazel version to 2.0.0 to match TensorFlow.
* set --experimental_repo_remote_exec since it is required by the TF build.
* bump TF/XLA version.
* use the --config=short_logs trick from TF to suppress build warnings.

view details

Skye Wanderman-Milne

commit sha 7885b1d03482490d1ebb75587b38106b10037dfa

Add nested pmap image to demo notebook. (#2292)

view details

Skye Wanderman-Milne

commit sha 2f35d0854565764d918564dc2cb41ba7714ff859

Shrink nested pmap image in demo notebook.

view details

push time in 7 hours

push event google/jax

Skye Wanderman-Milne

commit sha 7885b1d03482490d1ebb75587b38106b10037dfa

Add nested pmap image to demo notebook. (#2292)

view details

push time in 7 hours

PR merged google/jax

Add nested pmap image to demo notebook. cla: yes
+11 -1

0 comments

2 changed files

skye

pr closed time in 7 hours

PR opened google/jax

Add nested pmap image to demo notebook.
+11 -1

0 comments

2 changed files

pr created time in 7 hours

create branch skye/jax

branch : nested_pmap_image

created branch time in 7 hours

push event google/jax

brett koonce

commit sha 8372a70079698f3d602a1c52d5123d390da2010b

tweak readme pmap imports (#2276)

view details

push time in 3 days

PR merged google/jax

tweak readme pmap imports
+2 -1

1 comment

1 changed file

brettkoonce

pr closed time in 3 days

pull request comment google/jax

tweak readme pmap imports

Thanks!

brettkoonce

comment created time in 3 days

issue closed google/jax

Failed to load Starlark extension '@com_github_grpc_grpc//bazel:grpc_deps.bzl'.

Hi,

When compiling on CentOS 7, the build crashes immediately, stating that some extensions are missing. Is this a missing dependency?

Thanks,

Olivier

Bazel binary path: ./bazel-1.2.1-linux-x86_64
Python binary path: /opt/sw/arch/easybuild/2019b/software/Python/3.7.4-GCCcore-8.3.0/bin/python
Python version: 3.7
MKL-DNN enabled: yes
-march=native: yes
CUDA enabled: yes

Building XLA and installing it in the jaxlib source tree...
./bazel-1.2.1-linux-x86_64 run --verbose_failures=true --config=opt --config=mkl_open_source_only --config=cuda --define=xla_python_enable_gpu=true :install_xla_in_source_tree /usr/local/Software/jax/build
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /usr/local/Software/jax/.bazelrc:
  Inherited 'build' options: --repo_env PYTHON_BIN_PATH=/opt/sw/arch/easybuild/2019b/software/Python/3.7.4-GCCcore-8.3.0/bin/python --python_path=/opt/sw/arch/easybuild/2019b/software/Python/3.7.4-GCCcore-8.3.0/bin/python --repo_env TF_NEED_CUDA=1 --distinct_host_configuration=false --copt=-Wno-sign-compare -c opt --apple_platform_type=macos --macos_minimum_os=10.9 --announce_rc --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --spawn_strategy=standalone --strategy=Genrule=standalone --cxxopt=-std=c++14 --host_cxxopt=-std=c++14
INFO: Found applicable config definition build:opt in file /usr/local/Software/jax/.bazelrc: --copt=-march=native --host_copt=-march=native
INFO: Found applicable config definition build:mkl_open_source_only in file /usr/local/Software/jax/.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:cuda in file /usr/local/Software/jax/.bazelrc: --crosstool_top=@local_config_cuda//crosstool:toolchain --define=using_cuda=true --define=using_cuda_nvcc=true
Loading:
Loading: 0 packages loaded
ERROR: Failed to load Starlark extension '@com_github_grpc_grpc//bazel:grpc_deps.bzl'.
Cycle in the workspace file detected. This indicates that a repository is used prior to being defined.
The following chain of repository dependencies lead to the missing definition.
 - @com_github_grpc_grpc
This could either mean you have to add the '@com_github_grpc_grpc' repository with a statement like `http_archive` in your WORKSPACE file (note that transitive dependencies are not added automatically), or move an existing definition earlier in your WORKSPACE file.
ERROR: cycles detected during target parsing
INFO: Elapsed time: 0.085s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (0 packages loaded)
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully (0 packages loaded)
Traceback (most recent call last):
  File "build/build.py", line 365, in <module>
    main()
  File "build/build.py", line 360, in main
    shell(command)
  File "build/build.py", line 47, in shell
    output = subprocess.check_output(cmd)
  File "/opt/sw/arch/easybuild/2019b/software/Python/3.7.4-GCCcore-8.3.0/lib/python3.7/subprocess.py", line 395, in check_output
    **kwargs).stdout
  File "/opt/sw/arch/easybuild/2019b/software/Python/3.7.4-GCCcore-8.3.0/lib/python3.7/subprocess.py", line 487, in run
    output=stdout, stderr=stderr)
subprocess.CalledProcessError: Command '['./bazel-1.2.1-linux-x86_64', 'run', '--verbose_failures=true', '--config=opt', '--config=mkl_open_source_only', '--config=cuda', '--define=xla_python_enable_gpu=true', ':install_xla_in_source_tree', '/usr/local/Software/jax/build']' returned non-zero exit status 1.

closed time in 3 days

oliviermattelaer

push event google/jax

Skye Wanderman-Milne

commit sha 0a78c8c3015b7a628b3f5e157e4d952e58d55aef

Update WORKSPACE to pick up TF grpc upgrade and cleanup.

view details

push time in 3 days

push event google/jax

Skye Wanderman-Milne

commit sha cfb5666ac1ac9889625ddc10d02ed64914a883e7

Update WORKSPACE to setup upstream dependencies. (#2267) This is necessary as of https://github.com/tensorflow/tensorflow/commit/f396035891b0938364ea247a7dd243a147930c6e. Many thanks to @lezh for this fix!

view details

push time in 4 days

PR merged google/jax

Update WORKSPACE to setup upstream dependencies. cla: yes

This is necessary as of https://github.com/tensorflow/tensorflow/commit/f396035891b0938364ea247a7dd243a147930c6e.

Many thanks to @lezh for this fix!

Co-authored-by: Oliver Le Zhuang izhuangle@gmail.com

+24 -0

2 comments

1 changed file

skye

pr closed time in 4 days

push event skye/jax

Skye Wanderman-Milne

commit sha bce230321b52f28541f56d1ea5e15ca2992604b5

Update WORKSPACE to setup upstream dependencies. This is necessary as of https://github.com/tensorflow/tensorflow/commit/f396035891b0938364ea247a7dd243a147930c6e. Many thanks to @lezh for this fix!

view details

push time in 4 days

PR opened google/jax

Update WORKSPACE to setup upstream dependencies.

This is necessary as of https://github.com/tensorflow/tensorflow/commit/f396035891b0938364ea247a7dd243a147930c6e.

Many thanks to @lezh for this fix!

Co-authored-by: Oliver Le Zhuang izhuangle@gmail.com

+24 -0

0 comments

1 changed file

pr created time in 4 days

push event skye/jax

Skye Wanderman-Milne

commit sha f4ee487f011d3ac47c4046f6cd6df490bce71396

Update WORKSPACE to setup upstream dependencies. This is necessary as of https://github.com/tensorflow/tensorflow/commit/f396035891b0938364ea247a7dd243a147930c6e. Co-authored-by: Oliver Le Zhuang <izhuangle@gmail.com>

view details

push time in 4 days

create branch skye/jax

branch : workspace

created branch time in 4 days

push event google/jax

Skye Wanderman-Milne

commit sha 18420936c4ce5aca7349766de66c21ea8efdec53

_scatter_jvp bug fix (#2231)

view details

push time in 8 days

PR merged google/jax

_scatter_jvp bug fix cla: yes
+4 -2

0 comments

2 changed files

skye

pr closed time in 8 days

Pull request review comment google/jax

masking.py bug fix

 def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
   new_operand = pad(new_operand, _zero(operand),
                     ((0, 1, 0),) + tuple((0, 0, 0) for _ in operand_shape))
-  ids_shape = onp.array(updates_shape)
+  ids_shape = onp.array(updates_shape, dtype=onp.int32)

Added a comment.

skye

comment created time in 9 days

push event skye/jax

Skye Wanderman-Milne

commit sha fb95e3d7c7e9e2740580e5041bda79ccd5f79783

add comment

view details

push time in 9 days

Pull request review comment google/jax

masking.py bug fix

 def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
   new_operand = pad(new_operand, _zero(operand),
                     ((0, 1, 0),) + tuple((0, 0, 0) for _ in operand_shape))
-  ids_shape = onp.array(updates_shape)
+  ids_shape = onp.array(updates_shape, dtype=onp.int32)

I'm not sure. Paging @mattjj

skye

comment created time in 9 days

push event skye/jax

Skye Wanderman-Milne

commit sha 68dc2dfbc565a4141dfa6fe9f1c7e0cf032b296a

take 2

view details

push time in 9 days

PR opened google/jax

masking.py bug fix
+2 -2

0 comments

1 changed file

pr created time in 9 days

push event skye/jax

Skye Wanderman-Milne

commit sha 393938f38720e1c49553bd971f7ed4bd12239110

Update WORKSPACE to TF commit that builds with build_jaxlib_wheels.sh

view details

Skye Wanderman-Milne

commit sha b6dfb8bf18482eff334f9eeb50a09fe9de2e3e7f

Bump minimum bazel version to 0.26.0. (#2060) Fixes #2044

view details

Skye Wanderman-Milne

commit sha 9aba39e9be15b7758f328f37d9c2376d02a3b975

Revert lax_numpy.asclose() behavior to work with lists again. (#2059) This should be revisited to fix the issue originally addressed in https://github.com/google/jax/pull/2051.

view details

Skye Wanderman-Milne

commit sha f1339cd0b0f717a0e21697a48c3b59abd962b6ae

Remove missing PPA in Dockerfile. (#2061) This PPA has been removed by the owner: https://launchpad.net/~jonathonf/+archive/ubuntu/python-3.6 This causes `apt-get update` to fail when generating the Docker image. We don't seem to need this repository, so just remove it before calling `apt-get update`.

view details

Ziyad Edher

commit sha 0c95c26e97d6f53475086d640eef80a740954c78

Implement np.linalg.matrix_power (#2042)
* Implement numpy.linalg.matrix_power
* Write tests for numpy.linalg.matrix_power
* Check for input matrix shapes
* Move to matrix-multiplication operator in matrix power
* Improve error messages and directly use broadcasting
* Include matrix_power in documentation

view details
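
A minimal usage sketch of the function added here; the matrix is illustrative:

```
import jax.numpy as jnp

a = jnp.array([[2., 0.],
               [1., 1.]])
a_cubed = jnp.linalg.matrix_power(a, 3)  # equivalent to a @ a @ a
```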

Peter Hawkins

commit sha 632326ac5c95e83c05201a6a88168cf682496c96

Add unsupported wrapper around XLA RngUniform API. (#2068)

view details

Ziyad Edher

commit sha 0fca476c54b62fbbfe50ce36042a2a490c54b78c

Implement np.linalg.matrix_rank (#2008)
* Implement np.linalg.matrix_rank
* Test np.linalg.matrix_rank
* Use helper numpy testing function
* Fix issue with 1D matrix rank procedure
* Add new tests for 1D matrices and jit
* Do not check dtypes to circumvent int32 vs int64
* Include documentation for matrix_rank
* Fix ordering
* Use np.sum

view details
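
A minimal usage sketch of the function added here; the input is illustrative:

```
import jax.numpy as jnp

# A rank-deficient 3x3 matrix: the third row is the sum of the first two.
a = jnp.array([[1., 0., 0.],
               [0., 1., 0.],
               [1., 1., 0.]])
rank = jnp.linalg.matrix_rank(a)  # -> 2
```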

James Bradbury

commit sha a15aa9bd4d97cfaa65baf31ecc8ac47ae584ac7b

include call stack + transforms in XLA metadata (#2073)

view details

Roman Novak

commit sha 6a4bb9516925f28f613387d7fdc45e2643fc1de6

Make the reverse operator work on an empty list of dimensions. Example that this fixes:
```
from jax import lax
import jax.numpy as np
from jax.api import jacrev

x = np.ones((3, 5))

def f(x):
  return lax.conv_general_dilated(lhs=x, rhs=np.ones((5, 2)),
                                  window_strides=(), padding='VALID',
                                  dimension_numbers=('NC', 'IO', 'NC'))

jacrev(f)(x)
```
This currently fails with `ValueError: max() arg is an empty sequence`, raised from `lax._rev_shape_rule` when the conv transpose rule calls `rev(rhs, rhs_sdims)` with an empty dimension list (full traceback omitted).

view details

Peter Hawkins

commit sha ffa198e8ef14aa538254636ef1de6e06b389b496

Fix test failure on TPU. (#2088) Update GUARDED_BY annotations to use newer ABSL_GUARDED_BY form.

view details

Peter Hawkins

commit sha 832bb71c5dcfcbb35ca1f1f0dcaa1056af5dce7e

Add missing BUILD dependency. (#2089)

view details

Aidan Dang

commit sha e0ed5adc755ae1c16ae77a65b466844e424df957

Allow JVP for SVD when not computing singular vectors (#2076)
* Allow SVD JVP when not computing singular vectors
* Test SVD JVP when not computing full singular vecs

view details
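
A minimal sketch of the case this change enables: forward-mode differentiation through an SVD that returns only singular values (the input matrix is illustrative):

```
import jax
import jax.numpy as jnp

def singular_values(a):
  # compute_uv=False returns only the singular values.
  return jnp.linalg.svd(a, compute_uv=False)

a = jnp.diag(jnp.array([3., 2., 1.]))
jac = jax.jacfwd(singular_values)(a)  # Jacobian of shape (3, 3, 3)
```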

Chase Roberts

commit sha 82d6c6ce518d4d6925c42f3cd2cc1fd1a2d15146

Added better error messages. (#2058) #2057 Added better error messages for when a user accidentally uses a Python cast instead of the `jax.numpy` casting.

view details

Roman Novak

commit sha 95ccaae8058f8fb49c81680f1f9061bf96d8d95e

Add test for empty dimension list for reversion

view details

Peter Hawkins

commit sha 21551c2378b1f56f358d72beb814befb38ff0b25

Bump JAX version to 0.1.58.

view details

Peter Hawkins

commit sha 0fe601227fd5f6e9c1c7a6e650167f6113fe7322

Update README.md and CHANGELOG.md.

view details

Peter Hawkins

commit sha 9a0338d6aa1b6006be98983eb3d33c8507dcd383

Update README.md and CHANGELOG.md. (#2096)

view details

Peter Hawkins

commit sha 55f2d3be27eaf0f75aac9b2937e6fab87076315a

Update Jaxlib docker build.
* work around https://github.com/bazelbuild/bazel/issues/9254 by setting BAZEL_LINKLIBS=-lstdc++
* drop CUDA 9.0 support, since we use a batched kernel only present in CUDA 9.2 or later.
* drop Python 2.7 support.

view details

Peter Hawkins

commit sha 58f949f3168ebdc0264448022f90ba8271746356

Merge pull request #2098 from hawkinsp/jaxlib Update Jaxlib docker build.

view details

Peter Hawkins

commit sha b54c18efb4e30831c77cd8698dcdaa7864e74440

Use Device hash and equality instead of using a (class, id) pair. We couldn't figure out why we did it this way in the first place and all the tests we have pass.

view details

push time in 9 days

push event google/jax

Skye Wanderman-Milne

commit sha 2a6e60fd50f7779641302f7643ed57eec27e403e

Bump jaxlib version in README to 0.1.39

view details

push time in 9 days

push event google/jax

Skye Wanderman-Milne

commit sha 2ff7019923e1027ef13c9170cbd06a22086ba7de

Update WORKSPACE

view details

push time in 9 days

push event google/jax

Skye Wanderman-Milne

commit sha 89edd5eaf0672ae1853407996d873e8e57e0f71b

Update WORKSPACE

view details

push time in 10 days

push event google/jax

Skye Wanderman-Milne

commit sha bf91ebf67a22a139059416b6bf52277b9f90562e

Return error number in build.py on bad bazel version. (#2218) This prevents our build scripts from continuing on error.

view details

push time in 11 days

PR merged google/jax

Return error number in build.py on bad bazel version. cla: yes

This prevents our build scripts from continuing on error.

+2 -2

0 comments

1 changed file

skye

pr closed time in 11 days

PR opened google/jax

Return error number in build.py on bad bazel version.

This prevents our build scripts from continuing on error.

+2 -2

0 comments

1 changed file

pr created time in 11 days

push event skye/jax

Skye Wanderman-Milne

commit sha 96a65de6c89cb07001120d9681d2d8d5007e2416

Try downloading bazel before using pre-installed bazel. (#2217) This ensures we're using the right bazel version.

view details

Skye Wanderman-Milne

commit sha c5e8fc81b53a2ffc6fcb7af4e85bf0a1c81ed137

Return error number in build.py on bad bazel version. This prevents our build scripts from continuing on error.

view details

push time in 11 days

create branch skye/jax

branch : errno

created branch time in 11 days

push event skye/jax

Skye Wanderman-Milne

commit sha c2f07d2cb9db3400eae5b79001273d37b2626a40

Return error number in build.py on bad bazel version. This prevents our build scripts from continuing on error.

view details

push time in 11 days

push event google/jax

Skye Wanderman-Milne

commit sha 96a65de6c89cb07001120d9681d2d8d5007e2416

Try downloading bazel before using pre-installed bazel. (#2217) This ensures we're using the right bazel version.

view details

push time in 11 days

PR merged google/jax

Try downloading bazel before using pre-installed bazel. cla: yes

This ensures we're using the right bazel version.

+2 -2

0 comments

1 changed file

skye

pr closed time in 11 days

PR opened google/jax

Try downloading bazel before using pre-installed bazel.

This ensures we're using the right bazel version.

+2 -2

0 comments

1 changed file

pr created time in 11 days

create branch skye/jax

branch : always_download_bazel

created branch time in 11 days

created tag google/jax

tag jaxlib-v0.1.39

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

created time in 11 days

push event google/jax

Skye Wanderman-Milne

commit sha c6b65c0a11db8ffbe1f9bbde6752c1152480a1b1

Bump jaxlib version to 0.1.39 and update WORKSPACE.

view details

push time in 11 days

push event skye/jax

Sharad Vikram

commit sha 76d77bfc147e0e281d0953260e3aded64e425230

Fix inconsistent indentation in `JaxprTrace.default_process_primitive`.

view details

George Necula

commit sha 0b31342868232ab38dff0b1f34055970b6ad0f66

Added License to new files

view details

George Necula

commit sha e6f50dcb7d9bf1711ff7062a646dc8c3bf20059e

Merge pull request #2202 from gnecula/bug_fix Added License to new files

view details

George Necula

commit sha e81024f5053def119eddb7fb06ff6c4f7b5948a8

Disable linalg_test.testCond. Issue: 2203. This test was added in #2125 but is failing in internal tests.

view details

George Necula

commit sha 3e470af6427f8bd55096a107f5d989b24d7a1a23

Merge pull request #2204 from gnecula/bug_fix Disable linalg_test.testCond.

view details

George Necula

commit sha f3e9bef33eb4c118a4bde4c3fb7646c21775c988

Removed copyright code from third-party/numpy

view details

Peter Hawkins

commit sha e9d06ecf53a5227bb324b682fab74b628430ff9d

Reenable convolution gradient tests on TPU that now pass. (#2207)

view details

Tom Hennigan

commit sha 4c682b46bb05301212b0cfe8affa970e1feccde4

Add missing sources to jax build. (#2208)

view details

George Necula

commit sha dff64e0297a6bf74a50c66e4168ecfb957a90134

Fixed link to Google 3rd party OSS components

view details

George Necula

commit sha f6bd0a79f764a8e077095c723a312190565c1b8b

Merge pull request #2205 from gnecula/bug_fix Removed copyright from third-party/numpy

view details

Anselm Levskaya

commit sha 28e802c6f15d7ed87fee9bee6270137423b9967c

Fix Gotchas notebook regarding control flow differentiation. (#2194)

view details

George Necula

commit sha 20f9230f6e14d6f634a9a02b1185d6bb687905f2

Simplify Jaxpr: remove the bound_subjaxpr field, all subjaxprs are in params. The goal is to make the Jaxpr language more uniform: all higher-order primitives carry sub-Jaxprs that are part of the parameters, and they are all called xxx_jaxpr. As a side-effect, some code is simplified (e.g., the code that searches for sub-jaxprs). For now the code assumes that all the `call` (final-style) primitives carry exactly one subjaxpr with the parameter name `call_jaxpr`. These primitives are still processed differently in the internal code, but there is no reason any external consumer of a Jaxpr needs to know this.

view details

George Necula

commit sha f4b946ef23b74edfa19784f5d0b503484b15e7b1

Merge pull request #2199 from sharadmv/patch-1 Fix inconsistent indentation in `JaxprTrace.default_process_primitive`.

view details

George Necula

commit sha fb7e48f756ce939a55171156f29992634a5258ab

Merge pull request #2176 from gnecula/simple_jaxpr2 Simplify Jaxpr: remove the bound_subjaxpr field, all subjaxprs are in…

view details

Tom Hennigan

commit sha 9797ea2485540ac7bf9e0f48cba4a4c7f0a6d8bc

Implement size/ndim/__len__/repr/str/eq/hash for ShapeDtypeStruct. (#2206)

view details
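
A small sketch of the object this commit touches; the eval_shape call is just one common place it appears:

```
import jax
import jax.numpy as jnp

s = jax.ShapeDtypeStruct((2, 3), jnp.float32)
# With this change, size/ndim/len/repr/str/eq/hash are also expected to work,
# e.g. s.size == 6 and s.ndim == 2.
out = jax.eval_shape(lambda x: x.T, s)  # ShapeDtypeStruct with shape (3, 2)
```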

Matthew Johnson

commit sha 9e6fe64a66a4fac78ef9c8e57bb0818e4af6b619

bump version and update changelog for pypi

view details

Matthew Johnson

commit sha e41e24358c4956faa428e9e3a429f8f099811af1

fix multi-device error messages (#2213)
* fix multi-device error messages, travis tests
* don't run multi-device tests on travis (segfaults)
* fix typo

view details

Skye Wanderman-Milne

commit sha 5e77789afe6118b80341d341f09b8b390edbf703

Update Cloud TPU email address to jax-tpu@googlegroups.com

view details

Skye Wanderman-Milne

commit sha 323b35017644e86041b6d3d1c85a340f45e334f0

Allow ShardedDeviceArrays to represent arbitrary data shardings.

This change adds ShardedDeviceArray.logical_indices, which specifies what part of the full logical array each device buffer represents. The logical indices can be ints, slice objects, or tuples thereof. Previously, a ShardedDeviceArray's shards always represented a single row of the leading axis, but the indices allow specifying a wider range of shard shapes, e.g. multi-dimensional tiles. This also removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks.

I also added a microbenchmark to make sure this doesn't regress the shard_args fast path. The results are pretty noisy, but it looks like this is about the same speed as before. This is how I run the benchmark:

  TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py

view details

push time in 12 days

push event google/jax

Skye Wanderman-Milne

commit sha 5e77789afe6118b80341d341f09b8b390edbf703

Update Cloud TPU email address to jax-tpu@googlegroups.com

view details

push time in 12 days

pull request comment google/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

@gnecula @hawkinsp do you have any more comments or should I merge? (We can also make more follow-up changes, but I'd like to get this in soon since it's a pretty big change to leave unmerged for a long time.)

skye

comment created time in 13 days

PR closed google/jax

Add missing copyright notices. cla: yes
+43 -3

0 comments

3 changed files

skye

pr closed time in 14 days

PR opened google/jax

Add missing copyright notices.
+43 -3

0 comments

3 changed files

pr created time in 14 days

create branch skye/jax

branch : add_copyright

created branch time in 14 days

push event skye/jax

Ruizhe Zhao

commit sha 8c7fc3919d3e131da6a2121158084ed480dbec2a

Upgrade bazel from 0.29.1 to 1.2.1 (#2137)

view details

Peter Hawkins

commit sha fe041c75900023c1774bc81b59b094f38250f3b8

Set minimum Bazel version to 1.2.1.

view details

Colin

commit sha d6489103f754674eb5f16ded961bbbbc2c5817e5

Bump cell execution timeout (#2147)

Looking at the recent [doc failures](https://readthedocs.org/projects/jax/builds/), a few are due to:
- Cell timeouts (which this tries to fix),
- Execution timeout (readthedocs gives 900 seconds to build, total -- most of the time for jax is in executing the notebooks),
- Other somewhat random/inscrutable errors (and I could imagine a world in which one of the timeouts ends up triggering an inscrutable error in the execution).

view details

Roman Novak

commit sha 1022573b26a1996db524229de10fb84dbe6e08b3

Make stax pooling layers accept `spec=None` (#2145) Currently pooling layers have a default channel-last spec that is explicitly 2D. This change will make this default work for arbitrary input dimensionality.

view details

Stephan Hoyer

commit sha 0644f5c56175104d862cf7e03fe6f7cd14cdba88

Better batching rule for triangular_solve (#2138)
* Better batching rule for triangular_solve

  Now, if only the right hand side argument `b` is batched, we leverage triangular solve's builtin batching for handling multiple right-hand-side vectors. This makes the performance of `vmap` over only the second argument of linear solves equivalent to relying on builtin batching:

    rs = onp.random.RandomState(0)
    a = rs.randn(500, 500) + 0.1 * np.eye(500)
    b_mat = jax.device_put(rs.randn(500, 10))
    solve1 = jax.jit(np.linalg.solve)
    solve2 = jax.jit(jax.vmap(np.linalg.solve, in_axes=(None, 1), out_axes=1))

  Before:

    In [6]: %timeit jax.device_get(solve1(a, b_mat))
    3.88 ms ± 293 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

    # 8x slower :(
    In [9]: %timeit jax.device_get(solve2(a, b_mat))
    23.5 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

  Now:

    In [2]: %timeit jax.device_get(solve1(a, b_mat))
    3.76 ms ± 304 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

    # same speed :)
    In [3]: %timeit jax.device_get(solve2(a, b_mat))
    3.72 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

* Test failures
* Check b.ndim == 2 in triangular solve shape rule

view details

Peter Hawkins

commit sha 0b1d2fc3d187f779934cfaeb9188e1fcb208a6fc

Avoid accidental type promotion in gamma sampler gradient. (#2150) Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.

view details

Peter Hawkins

commit sha 3c9ae5e221316c82f1dda34aa7f12173b12e3a21

Add jax.scipy.stats.logistic to documentation. (#2149)

view details

George Necula

commit sha 4f5987ccd9c4447955d8dc3463613ac3bc44b6a3

Simplify Jaxpr: remove freevars. Freevars played a very small role, and they can be folded with the invars. This simplifies the Jaxpr data structure. We remove the `freevars` field from Jaxpr and from the bound_subjaxprs. The only non-trivial change is for xla_pmap, where we need to carry one extra parameter `mapped_invars` with a bitmap to encode which invars are mapped and which are broadcast. Previously, the freevars were broadcast.

view details

George Necula

commit sha a955fd9deee29f6de5023bf4077d678a56084785

Updated notebook that referred to freevars

view details

Stephan Hoyer

commit sha 2d0b8c2c609829e9745c1a1bc64c0fcf777fc899

Fix precision in triangular solve batching test for TPUs (#2159)

view details

Skye Wanderman-Milne

commit sha 7404e88b358a377cdf8d8e580349311185184af6

Adjust scipy_stats_test.py tolerance.

view details

Skye Wanderman-Milne

commit sha b19f7e935781bd848525c418c5a093b1d200dca5

WIP sharded_jit implementation (#2158)

view details

Anselm Levskaya

commit sha ffc55ee6008c054a2e58d01f64ba0ced36b36048

Update linspace edgecase to match numpy fix. (#2162)
* Update linspace edgecase to match numpy fix.
* only test fixed linspace behavior against newer numpy
* remove unneeded version pkg

view details

George Necula

commit sha 272620e66cd2f9c686a45306f469df873f535fc7

Added note to CHANGELOG.md

view details

Jonas Adler

commit sha 4080a1c2ce95dc4a90f899fe4bf9ad5ac6a7b8b3

Add np.fft.fftshift/ifftshift (#1850)

view details
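
A minimal usage sketch of the newly added functions:

```
import jax.numpy as jnp

freqs = jnp.fft.fftfreq(8)              # frequencies in standard FFT order
centered = jnp.fft.fftshift(freqs)      # zero frequency moved to the center
restored = jnp.fft.ifftshift(centered)  # back to the original ordering
```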

Lukas Prediger

commit sha ddc83e093778e227c7688f6ad16888b211a554ef

Added dtype arg for NN initializer factory methods (#2034)
* Added dtype arg for NN initializer factory methods

  Initializer factories in jax/nn/initializers.py (such as uniform(), normal(), glorot_normal(), etc) now have an optional `dtype` argument. The value passed in that argument becomes the default value for the same `dtype` argument of the initializer function returned by the factory.

* fixed failed test for delta_orthogonal in d12cdc47

view details
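
A minimal sketch of the factory-level `dtype` argument described above; the `stddev` parameter name is assumed from the existing factory signature:

```
import jax
import jax.numpy as jnp
from jax import random

# dtype passed to the factory becomes the default dtype of the returned initializer.
init = jax.nn.initializers.normal(stddev=1e-2, dtype=jnp.float32)
w = init(random.PRNGKey(0), (128, 256))  # float32 parameters
```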

George Necula

commit sha 862a1d594b67845f480bba7b342afb134ffc1a14

Moved the mapped_invars parameter setting to the process_map

view details

George Necula

commit sha d01210e9e338b8051da78fcc104b404b82ffd8a0

Merge pull request #1959 from gnecula/no_freevars An attempt to remove freevars from JAXPR.

view details

George Necula

commit sha b18a4d8583c0e11e228a0792793d6f6e99292766

Disabled tests known to fail on Mac, and optionally slow tests. Issue: #2166. Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.

view details

Pavel Sountsov

commit sha b2ef5bc09552e8ed39759df4ff49ea97e32db708

Canonicalize the shape in the wrapper functions in random.py. (#2165)
* Canonicalize the shape in the wrapper functions in random.py.

  This lets the user be more sloppy in using numpy arrays and statically known DeviceArrays for shapes, and still hit the jit cache. When they are not, the error is improved.

* Fix some errors.
* No need for the Poly workaround.
* Bypass canonicalization for None shapes in random.py.

view details
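
A minimal sketch of what the canonicalization allows, per the description above (the cache behavior is assumed from the commit message):

```
import numpy as onp
import jax

key = jax.random.PRNGKey(0)
x1 = jax.random.normal(key, (3, 4))             # tuple shape
x2 = jax.random.normal(key, onp.array([3, 4]))  # numpy array shape, canonicalized
# Both calls should produce the same result and hit the same compilation cache.
```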

push time in 14 days

issue comment google/jax

Add support for buffer donation (input/output aliasing)

And yes, I didn't get a chance to look at this after all. I forgot to ping here, sorry!

hawkinsp

comment created time in 16 days

Pull request review comment google/jax

Add basic support for partitions within JAX.

+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import core
+
+
+def get_var_partition_ids(jaxpr, partition_ids):
+  """Returns a map from each `Var` to a partition ID."""
+  var_partition_ids = {core.unitvar: partition_ids[0]}
+  # FIXME(cjfj): For now, all inputs are assumed to be on inner partition 0.
+  var_partition_ids.update((v, partition_ids[0]) for v in jaxpr.invars)
+  var_partition_ids.update((v, partition_ids[0]) for v in jaxpr.constvars)
+
+  for eqn in jaxpr.eqns:
+    if eqn.primitive.name == "partition":
+      inner_jaxpr = eqn.params["jaxpr"].jaxpr
+      inner_partition_ids = eqn.params["partition_ids"]
+      inner_var_partition_ids = get_var_partition_ids(inner_jaxpr, inner_partition_ids)
+      var_partition_ids.update(
+          (v1, inner_var_partition_ids[v]) for v, v1 in zip(inner_jaxpr.outvars, eqn.outvars))
+    else:
+      if eqn.primitive.name == "partition_put":
+        # Map from 'inner' partition ID to the global partition ID.
+        partition_id = partition_ids[eqn.params["partition_id"]]
+      else:
+        invars = [v for v in eqn.invars if not isinstance(v, core.Literal)]
+        if invars:
+          input_pids = [var_partition_ids[v] for v in invars]
+          partition_id = input_pids[0]
+          if any(pid != partition_id for pid in input_pids[1:]):
+            raise ValueError("mismatched partition IDs {}".format(input_pids))

Maybe print the equation or something, some hint as to where the mismatch happened

chr1sj0nes

comment created time in 17 days

Pull request review comment google/jax

Add basic support for partitions within JAX.

+# Copyright 2018 Google LLC+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     https://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.++import functools++from absl.testing import absltest++from jax import abstract_arrays+from jax import api+from jax import core+from jax import test_util as jtu+from jax.config import config+from jax.interpreters import xla+from jax.lax import lax+from jax.lax import lax_partition+from jax.lib import xla_bridge as xb+import numpy as onp+++config.parse_flags_with_absl()+++class PartitionTest(jtu.JaxTestCase):++  def test_jaxpr_to_xla(self):+    new_var = core.gensym('')+    a = new_var()+    b = new_var()+    c = new_var()+    d = new_var()+    e = new_var()++    partition_put_eqn0 = core.JaxprEqn(+        invars=(a,), outvars=(c,), primitive=lax_partition.partition_put_p,+        bound_subjaxpr=None, params={'partition_id': 1})+    partition_put_eqn1 = core.JaxprEqn(+        invars=(b,), outvars=(d,), primitive=lax_partition.partition_put_p,+        bound_subjaxpr=None, params={'partition_id': 1})+    add_eqn = core.JaxprEqn(+        invars=(c, d), outvars=(e,), primitive=lax.add_p,+        bound_subjaxpr=None, params={})++    partition_jaxpr = core.Jaxpr(+        constvars=(),+        invars=(a, b),+        outvars=(e,),+        eqns=(partition_put_eqn0, partition_put_eqn1, add_eqn))++    partition_typed_jaxpr = core.TypedJaxpr(+        jaxpr=partition_jaxpr,+        literals=(),+        in_avals=[abstract_arrays.ShapedArray((), onp.float32)] * 2,+        out_avals=[abstract_arrays.ShapedArray((), onp.float32)])++    jaxpr = core.Jaxpr(+        constvars=(),+        invars=(a, b),+        outvars=(c,),+        eqns=(core.JaxprEqn(+            invars=(a, b),+            outvars=(c,),+            primitive=lax_partition.partition_p,+            bound_subjaxpr=None,+            params={'jaxpr': partition_typed_jaxpr, 'partition_ids': [0, 1]}),))++    cb = xb.make_computation_builder('xla_computation')+    xla_args = [+        cb.ParameterFromNumpy(onp.array((), dtype=onp.float32)),+        cb.ParameterFromNumpy(onp.array((), dtype=onp.float32))+    ]+    outs = xla.jaxpr_subcomp(cb, jaxpr, 'cpu', None, (), '', *xla_args)+    computation = cb.Build(*outs)++    self.assertIn('sharding={maximal device=1}', computation.GetHloText())++  def test_jaxpr(self):+    @functools.partial(lax_partition.partition, num_partitions=2)+    def f(x):+      y = x * x+      z = 38. + lax_partition.partition_put(y, 1)+      return z++    jaxpr = api.make_jaxpr(f)(2.)+    self.assertEqual(+        '{ lambda  ; a.\n'+        '  let b = partition[ jaxpr={ lambda  ; a.\n'+        '                             let b = mul a a\n'+        '                                 c = partition_put[ partition_id=1 ] b\n'+        '                                 d = add c 38.0\n'+        '                             in [d] }\n'+        '                     partition_ids=(0, 1) ] a\n'+        '  in [b] }', str(jaxpr))++  def test_xla_computation(self):+    @functools.partial(lax_partition.partition, num_partitions=2)+    def f(x):+      y = x * x+      z = 38. 
+ lax_partition.partition_put(y, 1)+      return z++    computation = api.xla_computation(f)(2.)+    self.assertIn('sharding={maximal device=0}', computation.GetHloText())+    self.assertIn('sharding={maximal device=1}', computation.GetHloText())++  def test_simple(self):+    if xb.device_count() < 2:+      self.skipTest('requires two devices')++    @functools.partial(lax_partition.partition, num_partitions=2)+    def f(x):+      y = x * x+      z = 38. + lax_partition.partition_put(y, 1)+      return z++    self.assertEqual(42., f(2.))

Also inspect which device the result is on? (here and in the following tests)

chr1sj0nes

comment created time in 17 days

Pull request review comment google/jax

Add basic support for partitions within JAX.

+# Copyright 2018 Google LLC

2020

chr1sj0nes

comment created time in 17 days

Pull request review comment google/jax

Add basic support for partitions within JAX.

+# Copyright 2018 Google LLC+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     https://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.++from contextlib import contextmanager+from functools import partial+from functools import wraps+import itertools+import threading++import numpy as onp++from .. import core+from .. import linear_util as lu+from .. import partition_util+from ..abstract_arrays import raise_to_shaped+from ..api_util import flatten_fun+from ..config import flags+from ..interpreters import ad+from ..interpreters import partial_eval as pe+from ..interpreters import xla+from ..lib import xla_bridge as xb+from ..lib import xla_client as xc+from ..tree_util import tree_flatten+from ..tree_util import tree_map+from ..tree_util import tree_unflatten+from ..util import cache+from ..util import unzip2+++FLAGS = flags.FLAGS+++class _ThreadLocalState(threading.local):+  def __init__(self):+    self.partition_ids_stack = []+_thread_local_state = _ThreadLocalState()+++def _get_partition_ids(num_partitions, outer_partition_ids):+  """Returns the partitions IDs after remapping the `outer_partition_ids`."""+  partition_ids_stack = _thread_local_state.partition_ids_stack+  if partition_ids_stack:+    if outer_partition_ids is None:+      raise ValueError(+          "`outer_partition_ids` must be specified for nested calls to "+          "`jax.partition`.")+    elif len(outer_partition_ids) != num_partitions:+      raise ValueError(+          "`outer_partition_ids` length ({}) did not match the number of "+          "partitions ({})".format(len(outer_partition_ids), num_partitions))+    else:+      return tuple(partition_ids_stack[-1][i] for i in outer_partition_ids)+  else:+    if outer_partition_ids is not None:+      raise ValueError(+          "`outer_partition_ids` must not be set for outermost `jax.partition` "+          "call")+    else:+      return tuple(range(num_partitions))+++def _abstractify(x):+  return raise_to_shaped(core.get_aval(x))+++def partition(fun, num_partitions):+  """Partitions the given function across multiple devices.++  Operations will be executed on the same partition as their inputs (throwing an+  error if there is a mismatch). Values can be moved between partitions using+  `partition_put`.++  Example:+    ```+    @partial(partition, num_partitions=2)+    def f(w):+      x = w ** 2+      y = jax.partition_put(x, 1)+      z = 42 * y+      return z++    result = f(1.)+    ```++    `x` will be computed on partition 0, `z` on partition 1.++  Note that if a call to a partitioned function is nested inside another, it+  must specify `outer_partition_ids`. 
See `partition_put` for more information.++  Args:+    fun: The function to be partitioned.+    num_partitions: The number of partitions.++  Returns:+    A partitioned function, with an additional, keyword-only parameter,+    `outer_partition_ids` that must be used when called within another+    partitioned function.+  """+  @wraps(fun)+  def fun_partitioned(*args, outer_partition_ids=None, **kwargs):+    partition_ids = _get_partition_ids(num_partitions, outer_partition_ids)++    _thread_local_state.partition_ids_stack.append(partition_ids)+    try:+      in_vals, in_tree = tree_flatten((args, kwargs))+      in_avals = _map(_abstractify, in_vals)+      in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]+      fun1, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)+      jaxpr, out_pvals, consts = pe.trace_to_jaxpr(+          fun1, in_pvals, instantiate=True, stage_out_calls=True)+      out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])+      const_avals = _map(_abstractify, consts)+      jaxpr = core.TypedJaxpr(+          pe.convert_constvars_jaxpr(jaxpr), (), const_avals + in_avals, out_avals)+      outs = partition_p.bind(*itertools.chain(consts, in_vals),+                              jaxpr=jaxpr, partition_ids=partition_ids)+      return tree_unflatten(out_tree(), outs)+    finally:+      _thread_local_state.partition_ids_stack.pop()++  return fun_partitioned+++def _partition_impl(*args, **params):+  return _partition_callable(*map(xla.arg_spec, args), **params)(*args)+++@cache()+def _partition_callable(*args, jaxpr, partition_ids):+  nparts = max(partition_ids) + 1++  if nparts > xb.device_count():+    msg = ("compiling computation that requires {} devices, but only {} XLA "+           "devices are available")+    raise ValueError(msg.format(nparts, xb.device_count()))++  abstract_args, arg_devices = unzip2(args)+  result_handlers = tuple(map(partial(xla.aval_to_result_handler, None), jaxpr.out_avals))++  var_partition_ids = partition_util.get_var_partition_ids(jaxpr.jaxpr, partition_ids)+  out_partition_ids = tuple(var_partition_ids[v] for v in jaxpr.jaxpr.outvars)++  tuple_args=(len(abstract_args) > 100)+  built = _partition_computation(+      *abstract_args, tuple_args=tuple_args, jaxpr=jaxpr, partition_ids=partition_ids)++  device_assignment = xb.get_backend().get_default_device_assignment(1, nparts)+  device_assignment = onp.vectorize(lambda d: d.id)(onp.array(device_assignment))++  options = xb.get_compile_options(+      num_replicas=1, num_partitions=nparts, device_assignment=device_assignment)+  compiled = built.Compile(compile_options=options, backend=xb.get_backend())++  return xla.create_execute_fn(+      compiled, device_assignment=device_assignment, backend=None,+      tuple_args=tuple_args, result_handlers=result_handlers,+      out_partition_ids=out_partition_ids)+++@cache()+def _partition_computation(*avals, tuple_args, jaxpr, partition_ids):+  c = xb.make_computation_builder('partition_computation')+  backend = xb.get_backend()+  xla_args = xla._xla_callable_args(c, avals, tuple_args)+  out = _partition_translation_rule(+      c, xla.AxisEnv(), 'partition/', *xla_args,+      jaxpr=jaxpr, partition_ids=partition_ids, backend=backend)+  return c.Build(out)+++def _partition_abstract_eval(*args, jaxpr, **other_params):+  del args, other_params  # Unused.+  return _map(raise_to_shaped, jaxpr.out_avals)+++@contextmanager+def _op_partition(computation_builder, partition_id):+  if partition_id is None:+    yield+  else:+    sharding = xc.OpSharding()+    
sharding.type = xc.OpSharding.Type.MAXIMAL+    sharding.tile_assignment_devices = [partition_id]+    computation_builder.SetSharding(sharding)+    try:+      yield+    finally:+      computation_builder.ClearSharding()+++def _map(f, *xs): return tuple(map(f, *xs))+++def _partition_translation_rule(+    c, axis_env, name_stack, *in_nodes, jaxpr, partition_ids, backend):+  jaxpr = jaxpr.jaxpr  # The `jaxpr` parameter is a `TypedJaxpr`.+  nodes = {}++  def read_node(v):+    if type(v) is core.Literal:+      return c.Constant(xla.canonicalize_dtype(v.val))+    else:+      return nodes[v]++  def write_node(v, node):+    assert node is not None+    nodes[v] = node++  write_node(core.unitvar, c.Tuple())+  _map(write_node, jaxpr.invars, in_nodes)++  var_partition_ids = partition_util.get_var_partition_ids(jaxpr, partition_ids)++  for eqn in jaxpr.eqns:+    invars = [v for v in eqn.invars if type(v) is not core.Literal]++    if invars:+      # Take partition ID from the first input var. We've already checked, in+      # `_get_var_partition_ids`, that they all match.+      partition_id = var_partition_ids[invars[0]]+    else:+      partition_id = None++    with _op_partition(c, partition_id):+      out_nodes = xla.jaxpr_subcomp_eqn(+          c, eqn, backend, axis_env, name_stack, read_node)+    _map(write_node, eqn.outvars, out_nodes)+  return c.Tuple(*_map(read_node, jaxpr.outvars))+++def _partition_jvp(primals, tangents, jaxpr):+  raise NotImplementedError()  # TODO(cjfj)+++partition_p = core.Primitive('partition')+partition_p.multiple_results = True+partition_p.def_impl(_partition_impl)+partition_p.def_abstract_eval(_partition_abstract_eval)+ad.primitive_jvps[partition_p] = _partition_jvp+xla.initial_style_translations[partition_p] = _partition_translation_rule+++def partition_put(x, partition_id):+  """Places the given value(s) onto the specified partition.++  Note that for nested `partition` calls, the `partition_id` is the index into+  the list of partition IDs used by the innermost partitioned function.++  Example:+    ```+    @partial(partition, num_partitions=2)+    def f(w):+      x = w ** 2+      y = jax.partition_put(x, 1)+      z = 42 * y+      return z++    @partial(partition, num_partitions=4)+    def g(i):+      j = i ** 3+      k = f(i, outer_partition_ids=[1, 2])+      l = 42 * jax.partition_put(k, 3)+      return l++    b = g(a)+    ```++    `g` is the outermost partitioned function here, so partition IDs map+    directly to the logical XLA devices. `f` is nested, with its two partitions+    mapped onto `[1, 2]` using `outer_partition_ids`. `x` is therefore+    calculated on XLA partition 1, whilst `z` is calculated on XLA partition 2.++  Args:+    x: The value(s) to copy. This can be a tree of values.+    partition_id: The partition ID, within the innermost `partition` scope,+      on which to place the value(s).++  Returns:+    The input values, copied onto the specified partition.+  """+  return tree_map(partial(partition_put_p.bind, partition_id=partition_id), x)+++def _partition_put_impl(x, *, partition_id):+  del x, partition_id  # Unused.+  raise NotImplementedError(

nit: make this a ValueError or something, since presumably we'll never implement this?

chr1sj0nes

comment created time in 17 days

Pull request review comment google/jax

Add basic support for partitions within JAX.

 def _xla_callable(fun, device, backend, name, *arg_specs):       extend_name_stack(wrap_name(name, 'jit')), *xla_args)   built = c.Build(c.Tuple(*out_nodes)) +  if device:+    device_assignment = ((device.id,),)+  else:+    device_assignment = xb.get_backend(backend).get_default_device_assignment(nreps, nparts)+    device_assignment = onp.vectorize(lambda d: d.id)(onp.array(device_assignment))+   options = xb.get_compile_options(-      num_replicas=nreps,-      num_partitions=1,-      device_assignment=(device.id,) if device else None)+      num_replicas=nreps, num_partitions=nparts, device_assignment=device_assignment)   compiled = built.Compile(compile_options=options, backend=xb.get_backend(backend)) -  if nreps == 1:-    return partial(_execute_compiled, compiled, backend, result_handlers, tuple_args)-  else:-    return partial(_execute_replicated, compiled, backend, result_handlers, tuple_args)+  return create_execute_fn(+      compiled, device_assignment=device_assignment, backend=backend,+      tuple_args=tuple_args, result_handlers=result_handlers,+      out_partition_ids=out_partition_ids)++def create_execute_fn(+    compiled, *, device_assignment, backend, tuple_args, result_handlers, out_partition_ids):+  # TODO(cjfj): Expose device assignment from the executable.+  local_devices = compiled.local_devices()+  # We are assuming here that `local_devices` are ordered by replica then partition.+  # TODO(cjfj): Expose logical devices from the executable, rather than re-computing here.+  local_logical_devices = []+  local_device_ids = {d.id for d in local_devices}+  for replica_id, replica_devices in enumerate(device_assignment):+    for partition_id, device_id in enumerate(replica_devices):+      if device_id in local_device_ids:+        local_logical_devices.append((replica_id, partition_id))++  # We take our outputs from the replica 0 devices. Pre-compute the mappings+  # to these devices for each partition.+  local_replica0_devices = {+      part_id: local_device_id for local_device_id, (replica_id, part_id)+      in enumerate(local_logical_devices) if replica_id == 0}++  def execute(*args):

I'm a little worried about adding overheads to the non-partition path... what do you think about keeping the old _execute_compiled and _executed_replicated functions for now, with a TODO to benchmark if there's a noticeable difference?

chr1sj0nes

comment created time in 17 days

Pull request review comment google/jax

Add basic support for partitions within JAX.

 def _xla_callable(fun, device, backend, name, *arg_specs):       extend_name_stack(wrap_name(name, 'jit')), *xla_args)   built = c.Build(c.Tuple(*out_nodes)) +  if device:+    device_assignment = ((device.id,),)+  else:+    device_assignment = xb.get_backend(backend).get_default_device_assignment(nreps, nparts)+    device_assignment = onp.vectorize(lambda d: d.id)(onp.array(device_assignment))+   options = xb.get_compile_options(-      num_replicas=nreps,-      num_partitions=1,-      device_assignment=(device.id,) if device else None)+      num_replicas=nreps, num_partitions=nparts, device_assignment=device_assignment)   compiled = built.Compile(compile_options=options, backend=xb.get_backend(backend)) -  if nreps == 1:-    return partial(_execute_compiled, compiled, backend, result_handlers, tuple_args)-  else:-    return partial(_execute_replicated, compiled, backend, result_handlers, tuple_args)+  return create_execute_fn(+      compiled, device_assignment=device_assignment, backend=backend,+      tuple_args=tuple_args, result_handlers=result_handlers,+      out_partition_ids=out_partition_ids)++def create_execute_fn(+    compiled, *, device_assignment, backend, tuple_args, result_handlers, out_partition_ids):+  # TODO(cjfj): Expose device assignment from the executable.+  local_devices = compiled.local_devices()+  # We are assuming here that `local_devices` are ordered by replica then partition.+  # TODO(cjfj): Expose logical devices from the executable, rather than re-computing here.+  local_logical_devices = []+  local_device_ids = {d.id for d in local_devices}+  for replica_id, replica_devices in enumerate(device_assignment):+    for partition_id, device_id in enumerate(replica_devices):+      if device_id in local_device_ids:+        local_logical_devices.append((replica_id, partition_id))++  # We take our outputs from the replica 0 devices. Pre-compute the mappings+  # to these devices for each partition.

How does this work? Don't you need outputs from all replicas?

chr1sj0nes

comment created time in 17 days

Pull request review comment google/jax

Add basic support for partitions within JAX.

[flattened diff hunk abridged to the commented line:]
+  tuple_args=(len(abstract_args) > 100)

Add spaces around =, no need for parens

chr1sj0nes

comment created time in 17 days

Pull request review commentgoogle/jax

Add basic support for partitions within JAX.

[flattened diff hunk abridged to the commented lines:]
+      local_device_id = local_replica0_devices.get(out_partition_id)
+      if local_device_id is None:
+        results.append(None)

When does this happen?

chr1sj0nes

comment created time in 17 days

Pull request review commentgoogle/jax

Add basic support for partitions within JAX.

[flattened diff hunk abridged to the commented line:]
+def partition(fun, num_partitions):

Do you know if this works on non-TPU backends? You don't need to do it for this change, but consider adding a backend argument.

chr1sj0nes

comment created time in 17 days
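For reference, the docstring in this hunk includes a usage example along the following lines (reconstructed from the flattened text above; whether the decorator ends up exposed as jax.partition is up to this PR, so treat the import path as an assumption, and note this only runs with the PR applied):

from functools import partial
import jax  # `partition` and `partition_put` are the APIs proposed in this PR

@partial(jax.partition, num_partitions=2)
def f(w):
  x = w ** 2                   # computed on partition 0
  y = jax.partition_put(x, 1)  # move the value to partition 1
  z = 42 * y                   # computed on partition 1
  return z

result = f(1.)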

Pull request review commentgoogle/jax

Add basic support for partitions within JAX.

[flattened diff hunk: the Apache 2.0 license header at the top of the newly added partition module file]

I think this should be jax/interpreters/partition.py, instead of living in lax. Even though it's technically defining primitives, we put other jit-like primitives in interpreters (e.g. xla.py, pxla.py).

chr1sj0nes

comment created time in 17 days

Pull request review commentgoogle/jax

Add basic support for partitions within JAX.

[flattened diff hunk abridged to the commented line:]
+@cache()

This should be @lu.cache

chr1sj0nes

comment created time in 17 days

Pull request review commentgoogle/jax

Add basic support for partitions within JAX.

+# Copyright 2018 Google LLC

Copyright 2020

chr1sj0nes

comment created time in 17 days

push eventgoogle/jax

Skye Wanderman-Milne

commit sha 13316f35705fee2f43376655313e698b52d1965f

Fix type error in partial_eval.py. (#2171)

view details

push time in 18 days

PR merged google/jax

Fix type error in partial_eval.py. cla: yes

The tuple constructor takes a single iterable argument.

@gnecula this was introduced in https://github.com/google/jax/commit/862a1d594b67845f480bba7b342afb134ffc1a14 and caught via pytype. Are we missing some test coverage?

+2 -2

0 comment

1 changed file

skye

pr closed time in 18 days

PR opened google/jax

Fix type error in partial_eval.py.

The tuple constructor takes a single iterable argument.

@gnecula this was introduced in https://github.com/google/jax/commit/862a1d594b67845f480bba7b342afb134ffc1a14 and caught via pytype. Are we missing some test coverage?

+2 -2

0 comment

1 changed file

pr created time in 18 days

create branch skye/jax

branch : pytype

created branch time in 18 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented lines:]
+      if not ragged:
+        aval = ShapedArray(new_sizes, operand.dtype)
+        # TODO(skye): deal with replication

There's an assert in the SDA constructor that len(buffers) == len(indices).

skye

comment created time in 19 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented line:]
+    chunk_shape = operand.device_buffers[0].shape().dimensions()

It's possible I'm missing something and there's a simpler way to write this :) I find this code difficult to reason about... You're right about the extra constraints. I added a check that they're all the same size (although maybe we should make this a constraint on all SDAs), and added TODOs about the other constraints.

skye

comment created time in 19 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented lines:]
+      for buf, idx in zip(self.device_buffers, self.logical_indices):
+        if idx in already_copied_indices:
+          continue

We don't currently enforce it, but yes, overlapping buffers have the same data (since they represent pieces of a single logical buffer). Also note that we only have replicated buffers, at least for now, so two buffers either cover exactly the same index or don't overlap at all. In theory we could produce partially overlapping buffers (and this code would work, if inefficiently), but I don't think we will for any of the APIs we have planned.

skye

comment created time in 19 days
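A numpy-only sketch of the property described in this reply, mirroring the _value loop in the hunk above (all values invented): buffers whose logical indices coincide are replicas of the same data, so reassembly can skip any index it has already filled.

import numpy as onp

logical_shape = (4, 2)
# Two buffers are replicas of row 0; the remaining rows are unreplicated.
logical_indices = [0, 0, 1, 2, 3]
buffers = [onp.full((2,), v, dtype=onp.float32) for v in (7, 7, 1, 2, 3)]

value = onp.empty(logical_shape, onp.float32)
already_copied = []
for buf, idx in zip(buffers, logical_indices):
  if idx in already_copied:
    continue  # replicated buffer: same data, nothing new to copy
  value[idx] = buf
  already_copied.append(idx)
print(value)  # rows 0..3 each filled exactly once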

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk: the Apache 2.0 license header (Copyright 2020) at the top of a new file added in this PR]

Good idea, done.

skye

comment created time in 19 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented lines:]
-  def __init__(self, aval, device_buffers):
+  def __init__(self, aval, logical_indices, device_buffers):
+    assert len(logical_indices) == len(device_buffers)

Using numpy indexing seemed like a natural choice, because I needed some way to represent a variety of "chunks" of an array, and also a way to translate those chunks back and forth between actual numpy arrays (or DeviceArrays, etc.). We don't support all types of numpy indexing obviously, but it's very useful to use numpy indexing in the ways we do. Does that answer your questions?

skye

comment created time in 19 days
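A short sketch of the point about numpy indexing (shapes invented): ints, slices, and tuples thereof are enough to describe rows, chunks, and multi-dimensional tiles, and the same index moves a chunk back and forth between the logical array and a per-device array.

import numpy as onp

x = onp.arange(16).reshape(4, 4)

row_index   = 2                            # one row of the leading axis
chunk_index = slice(0, 2)                  # a chunk of two leading rows
tile_index  = (slice(0, 2), slice(2, 4))   # a 2x2 tile

for idx in (row_index, chunk_index, tile_index):
  shard = x[idx]        # logical array -> per-device chunk
  y = onp.zeros_like(x)
  y[idx] = shard        # per-device chunk -> back into a logical array
  assert (y[idx] == x[idx]).all()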

push eventskye/jax

Skye Wanderman-Milne

commit sha 4b92a49d8df8f299f889bb6cbfffe9fbc55a606d

Respond to review comments.

view details

push time in 19 days

push eventgoogle/jax

Skye Wanderman-Milne

commit sha b19f7e935781bd848525c418c5a093b1d200dca5

WIP sharded_jit implementation (#2158)

view details

push time in 19 days

PR merged google/jax

WIP sharded_jit implementation cla: yes
+300 -0

0 comment

1 changed file

skye

pr closed time in 19 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented lines:]
+  bufs = []
+  for idx, device in safe_zip(indices, devices):
+    # Lookup all buffers that contain the correct slice of the logical array.

Done.

skye

comment created time in 19 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented lines:]
+def _hashable_index(idx):
+  return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x,

Good catch. I added this constraint to the SDA docstring and added a TODO to check it. (We could in theory allow strides, but I don't know why anyone would want that...)

skye

comment created time in 19 days
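The hunk above turns slices into (start, stop) pairs so that indices can be used as dict keys, since slice objects were not hashable in the Python versions of the day. A standalone sketch of the same trick (the real code uses jax's tree_map; hashable_index here is just an illustrative helper):

def hashable_index(idx):
  # Convert an int / slice / tuple-of-those index into something hashable.
  if isinstance(idx, tuple):
    return tuple(hashable_index(i) for i in idx)
  if isinstance(idx, slice):
    return (idx.start, idx.stop)
  return idx

candidates = {}
candidates[hashable_index((slice(0, 2), 3))] = "buffer-0"
print(candidates)  # {((0, 2), 3): 'buffer-0'}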

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented lines:]
+    indices: list of logical indices into the argument with the same length as
+      `devices`. Can be an int, a slice object, or a tuple thereof. This

These are all great points!

How about tweaking the last two sentences to:

I added an abridged version of your suggestion and refer back to ShardedDeviceArray.logical_indices.

the integer values had to be collectively exhaustive

This is still true, in that the indices must cover the full logical array. I added this to the ShardedDeviceArray.logical_indices description. It would also be good to assert this in the SDA constructor, although I wanna add a benchmark before adding too much stuff to SDA creation. I added a TODO for now.

What constraints exist between the elements of args now? For example, when an element of indices is a tuple, does that mean all args have to have rank (number of axes) greater than or equal to the length of that tuple?

Good question :) Each arg's rank does have to be >= the number of subindices (that's a word I made up earlier today for the elements inside an index tuple). This is gonna become more complicated with replication and partitioning though. Unfortunately, I think we're gonna need different indices for different arguments, since arguments can be individually replicated or partitioned (or pmapped once James has his way). The only constraint will be that each argument is ultimately distributed onto the same number of devices.

we probably should have answers written down somewhere

Agreed. I'd like the code to be as self-documenting as possible, although a design doc or two describing the overall problem and possible solutions wouldn't hurt either. The details/invariants are also gonna change as we iterate and add features (e.g. maybe we'll canonicalize all indices into full-rank tuples, like James suggested). Please keep asking these questions!

skye

comment created time in 19 days
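A sketch of the invariants discussed in this reply, with made-up shapes: each device's index has at most as many subindices as the argument has axes, and together the indices must cover the full logical array (the coverage check below is only an illustration, not code from the PR).

import numpy as onp

arg = onp.arange(12.0).reshape(4, 3)      # rank 2
# One index per device: two leading-axis tiles, each with 2 subindices (<= rank).
indices = [(slice(0, 2), slice(0, 3)), (slice(2, 4), slice(0, 3))]

shards = [arg[idx] for idx in indices]    # what each device would receive
print([s.shape for s in shards])          # [(2, 3), (2, 3)]

covered = onp.zeros(arg.shape, dtype=bool)
for idx in indices:
  covered[idx] = True
assert covered.all(), "indices must be collectively exhaustive"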

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented lines:]
     if self._npy_value is None:
-      ids = self._ids()
+      # TODO(skye): remove this to avoid transferring replicated buffers?

The line below, self.copy_to_host_async(). It seems unnecessary and possibly incurs extra copies (in both the current code and this PR), but maybe I'm missing something.

skye

comment created time in 19 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented lines:]
+  for nshards in (2, 4, 8, 100, 500):
+    if nshards > jax.local_device_count(): continue

Probably not, unless you have 500 local devices. I run it on CPU with 500 devices just to see what happens (it's slow). It can also help see how it scales overall.

skye

comment created time in 19 days
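For anyone reproducing that many-device CPU run, the same XLA flag used on the command line elsewhere in this thread can also be set from Python, as long as it happens before jax initializes its backends (a sketch; the device count is arbitrary):

import os
os.environ["XLA_FLAGS"] = (os.environ.get("XLA_FLAGS", "") +
                           " --xla_force_host_platform_device_count=8")

import jax  # must come after the flag is set
print(jax.local_device_count())  # 8 host "devices" backed by the CPU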

push eventskye/jax

Skye Wanderman-Milne

commit sha 51b5b44bf641617f7ad836e338553549dbda3dd6

More documentation

view details

push time in 19 days

push eventskye/jax

Skye Wanderman-Milne

commit sha e023182e1766b38efbea005af2f1f28ea92b7dfd

James' suggestions Co-Authored-By: James Bradbury <jekbradbury@google.com>

view details

push time in 19 days

push eventgoogle/jax

Skye Wanderman-Milne

commit sha 7404e88b358a377cdf8d8e580349311185184af6

Adjust scipy_stats_test.py tolerance.

view details

push time in 20 days

PR opened google/jax

WIP sharded_jit implementation
+300 -0

0 comment

1 changed file

pr created time in 20 days

create branch skye/jax

branch : sharded_jit

created branch time in 20 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented lines:]
+  to the number of devices of the computation that produced it. Each buffer
+  represents a shard of the original array, indicated by its corresponding index

I changed it around some more based on your suggestion. Hopefully this + the below attributes section is clear now.

skye

comment created time in 20 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented line:]
+    indices: list of logical indices into the argument with the same length as

Wikipedia says an indexer is the operator overloading allowing a class to act like an array, so I think indices is the right term here. You index with an index.

skye

comment created time in 20 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented lines:]
-  def __init__(self, aval, device_buffers):
+  def __init__(self, aval, logical_indices, device_buffers):
+    assert len(logical_indices) == len(device_buffers)

This is a good idea. I'm gonna hold off for now since it'll be easy to add in a follow-up PR, and we might have a better idea of what makes sense later (plus I should probably write a benchmark to measure SDA creation time...).

skye

comment created time in 20 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented line:]
+  if type(operand) is pxla.ShardedDeviceArray and dimensions is None:

I think you're right! I think this works for now, because it assumes all SDAs are only sharded across leading axes, and our current APIs only produce SDAs with this property. However, we'll have to update this once this isn't the case. I'll add a TODO -- I think once we develop more infra for dealing with indices (e.g. to support efficient nested pmaps) this will be easier to address.

skye

comment created time in 20 days
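For readers following along, the leading-axis merge condition relied on here amounts to something like the check below (my reading of _is_axis_merge, not necessarily the exact helper): a reshape is a pure merge of the two leading axes when the trailing dimensions are untouched and the new leading dimension is their product.

def is_axis_merge(old_sizes, new_sizes):
  # e.g. (8, 4, 3) -> (32, 3): leading two axes merged, trailing axes unchanged.
  return (len(old_sizes) == len(new_sizes) + 1 and
          old_sizes[2:] == new_sizes[1:] and
          old_sizes[0] * old_sizes[1] == new_sizes[0])

assert is_axis_merge((8, 4, 3), (32, 3))
assert not is_axis_merge((8, 4, 3), (4, 8, 3))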

push eventskye/jax

Skye Wanderman-Milne

commit sha 8a906cc8d1ee29fddf93435f89183b5e41d7e2ba

Even more documentation

view details

push time in 20 days

issue commentgoogle/jax

JAX 0.1.58 doesn't see TPU in public Colab

You need the special preamble that switches to the new Cloud TPU stack that jax uses (we'd eventually like Cloud TPUs to work out of the box, but for now this boilerplate is necessary). See the first cell of any of our example Cloud TPU notebooks, e.g. https://github.com/google/jax/blob/master/cloud_tpu_colabs/NeurIPS_2019_JAX_demo.ipynb. You can probably bump the jax and jaxlib versions on that first pip install line too; please report back if that somehow breaks things!

romanngg

comment created time in 20 days
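From memory, the preamble being referred to looked roughly like the cell below; the driver version string, port, and flag plumbing are whatever the linked notebook's first cell actually uses, so treat this as a sketch rather than an authoritative recipe.

import os
import requests

# Ask the Colab TPU to switch to the driver that jax's tpu_driver backend speaks.
if "TPU_DRIVER_MODE" not in globals():
  url = ("http://" + os.environ["COLAB_TPU_ADDR"].split(":")[0] +
         ":8475/requestversion/tpu_driver_nightly")
  requests.post(url)
  TPU_DRIVER_MODE = 1

from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ["COLAB_TPU_ADDR"]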

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented lines:]
+  to the number of devices of the computation that produced it. Each buffer
+  represents a shard of the original array, indicated by its corresponding index

I added a description of the logical_indices field. Please let me know if it's still unclear or you have any suggestions for improvement.

skye

comment created time in 20 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

[flattened diff hunk abridged to the commented line:]
+def pmap_shard_args_benchmark():

Added a docstring, lemme know if I should include more detail.

skye

comment created time in 20 days

push eventskye/jax

Skye Wanderman-Milne

commit sha 157e910d97aea9c4d96ebec9e338bff8d00825e2

More docstring

view details

push time in 20 days

PR opened google/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

This change adds ShardedDeviceArray.logical_indices, which specifies what part of the full logical array each device buffer represents. The logical indices can be ints, slice objects, or tuples thereof. Previously, a ShardedDeviceArray's shards always represented a single row of the leading axis, but the indices allow specifying a wider range of shard shapes, e.g. multi-dimensional tiles.

This also removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks.

I also added a microbenchmark to make sure this doesn't regress the shard_args fast path. The results are pretty noisy, but it looks like this is about the same speed as before. This is how I run the benchmark:

TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py
+226 -107

0 comment

4 changed files

pr created time in 21 days
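A trimmed-down version of the microbenchmark described above, adapted from the benchmark file this PR adds (device and iteration counts are arbitrary; like the original, it measures dispatch of already-correctly-sharded inputs and does not block on results):

import time
import numpy as onp
import jax
import jax.numpy as np
from jax import pmap

nshards = min(4, jax.local_device_count())
args = [onp.random.random((nshards, 4)) for _ in range(10)]
sharded_args = pmap(lambda x: x)(args)   # produce correctly-sharded ShardedDeviceArrays
pmap_fn = pmap(lambda *args: np.sum(args))

pmap_fn(*sharded_args)                   # warm up / compile
start = time.time()
for _ in range(100):
  pmap_fn(*sharded_args)
print("100 dispatches took %.3fs" % (time.time() - start))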

push eventskye/jax

Skye Wanderman-Milne

commit sha b9dc926b266aac0d848fba766ff0eac8f074625f

Allow ShardedDeviceArrays to represent arbitrary data shardings.

This change adds ShardedDeviceArray.logical_indices, which specifies what part of the full logical array each device buffer represents. The logical indices can be ints, slice objects, or tuples thereof. Previously, a ShardedDeviceArray's shards always represented a single row of the leading axis, but the indices allow specifying a wider range of shard shapes, e.g. multi-dimensional tiles.

This also removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks.

I also added a microbenchmark to make sure this doesn't regress the shard_args fast path. The results are pretty noisy, but it looks like this is about the same speed as before. This is how I run the benchmark:

TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py

view details

push time in 21 days

create branch skye/jax

branch : sharded_device_array_indices

created branch time in 21 days

issue commentgoogle/jax

Could not find cublas.h

Can you try setting the TF_CUDA_PATHS environment variable? See https://github.com/tensorflow/tensorflow/blob/master/third_party/gpus/find_cuda_config.py#L26, which is the script that tries to find your CUDA install. We may need some extra plumbing to pass the env var through bazel, though; please report back if just setting the env var doesn't work.

dhpollack

comment created time in 21 days

push eventgoogle/jax

Skye Wanderman-Milne

commit sha efbdaf66bfa584cc635092919a23b684c7fb2247

Adjust scipy_stats_test.py tolerance.

view details

push time in 23 days

issue commentgoogle/jax

JAX build on CentOS 7 (RHEL variant)

Thank you for sharing these great instructions!

MilesCranmer

comment created time in 25 days

issue commentgoogle/jax

JAX/XLA compiling for one device and running on another

@mktal can you share the exact error you're getting, and also the code that's producing it? It's odd that you'd see this with two identical GPUs.

jlebar

comment created time in 25 days
