google/jax 9746
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
altair-viz/altair 5966
Declarative statistical visualization library for Python
Interactive plotting for Pandas using Vega-Lite
altair-viz/altair-tutorial 257
Notebooks for the Altair tutorial
don't ask
Content for my Astronomy 599 Course: Intro to scientific computing in Python
Interactive data exploration with Altair
A Python package for online & offline access to vega datasets
Altair backend for pandas plotting
altair-viz/altair-transform 48
Evaluation of Vega-Lite transforms in Python
PR opened google/jax
Old error:
>>> import jax.numpy as jnp
>>> from jax import jit
>>> jit(jnp.sum)(jnp.arange(4), 0)
[...]
TypeError: Unexpected type of axis argument: <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
New error:
>>> jit(jnp.sum)(jnp.arange(4), 0)
[...]
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
axis argument to jnp.sum().
The error occured while tracing the function sum at /Users/vanderplas/github/google/jax/jax/numpy/lax_numpy.py:1737.
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
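As the new message suggests, one fix is to mark `axis` as static so that it remains a concrete Python value during tracing. A minimal sketch, assuming a recent JAX release (the exact exception raised for a traced `axis` has changed between versions, so the sketch catches it generically):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4)

# Passing axis as a regular argument to a jitted jnp.sum turns it into a tracer.
try:
    jax.jit(jnp.sum)(x, 0)
    traced_axis_failed = False
except Exception:  # exact error type varies across JAX versions
    traced_axis_failed = True

# static_argnums=1 keeps `axis` as a concrete Python int at trace time.
sum_static_axis = jax.jit(jnp.sum, static_argnums=1)
total = sum_static_axis(x, 0)
print(total)
```

Closing over the axis instead, e.g. `jax.jit(lambda a: jnp.sum(a, axis=0))`, has the same effect.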
pr created time in 17 hours
push event jakevdp/jax
commit sha 80d5f9ce6c4994a0fe32180f7763ee512550231b
jax.numpy: improved errors for invalid inputs to unary ops
commit sha 6aa35a2a7351e3c2bbf2fa1f58628801e95c8791
Adding an option to return the output tree in make_jaxpr
commit sha ab273acb2173e8f1a5343f26432cc27cb8a1a991
Call _check_arraylike in jax.numpy to improve error messages This PR only adds the call to places where non-array inputs currently lead to errors. There remain a number of other functions where adding this check would lead to potentially breaking changes; these are deliberately left out of this PR.
commit sha 80fa22cf90459cdee8a6f1fbee7a0f7fedcfcb0c
Merge pull request #4381 from jakevdp:checkarraylike PiperOriginRevId: 333170269
commit sha 7e1b826ef56ff98307fda0d7614d27d6092dd3aa
Enable fast TPU LU decomposition for complex types.
commit sha 625be69333ad46c7e0b4cb765060054faf5f2927
[host_callback] Update the documentation The module-level documentation was out of date.
commit sha 2b7580c2d206e692a83fbc37eaeeb846e0a4462f
Consider lists as groups of axis names too
commit sha 3360bee9e96f2fbc8bcdb4746e58bfba1f771476
[jax2tf] Adjust tolerance in flaky float32 eigh test.
commit sha 89b989654e45e742973d1475f02a1119334eb632
Enable tests for complex QR decomposition on TPU.
commit sha f3d6132042ac1bfb1e541f8692569efc9169bf08
Merge pull request #4393 from hawkinsp:qr PiperOriginRevId: 333279018
commit sha c6cd2f91df0f37dd8987b6bed8847597ace9d121
Merge pull request #4392 from SIben:fix_eigh_flakiness_cpu PiperOriginRevId: 333294460
commit sha c9b3df3a64b8a6398d6e3a705e534c4494502fea
Merge pull request #4382 from hawkinsp:lu PiperOriginRevId: 333296510
commit sha c875ab3ec9b2ab4794e4068e64172c9869e1b618
Merge pull request #4391 from apaszke:axis_index_handle_list PiperOriginRevId: 333304709
commit sha 0d5f15f5c0f862d3d5cbf8f6853409cdc1091de1
Fix the abstract eval and translation rule for all_to_all The previous rules assumed that `split_axis == concat_axis` (i.e. that the used collective is equivalent to `pswapaxes`). Since we expose this as part of our API, we should probably make sure that we handle other cases too. Fixes #1332.
commit sha c42d736e347d059290ab20c013635d62a1ee6c45
remove limit on size of random arrays
commit sha 96f5a3c4026c929664e75e07262bba7a4c8d2044
fix test for non-omnistaging
commit sha 71f5f9972cd305d4060637115a7ff316087d229e
skip checks in big randomness test
commit sha d607164d35e05077d60969a8e6b145f10aca95e0
make_jaxpr return_shape use ShapeDtypeStruct, test
commit sha c7e0ef4075ed5166cd2a289057b75b682002a1e9
Merge pull request #4398 from google:liftrandomnesslimit PiperOriginRevId: 333433816
commit sha ebf7c1b6127d6df3a64a4d969297ca70e126a341
add jax logo file
push time in 20 hours
issue comment google/jax
Hi, this is an interesting idea, and it would definitely make interoperability between `np.random` and `jax.random` easier.
One concern I have is that numpy's random interface is stateful, where JAX's is not. This makes things behave in possibly unexpected ways when they are combined with JAX's jit. For example:
from jax import jit
import jax.numpy as jnp
import numpy as np

def _numpy_generator_to_prngkey(seed: np.random.Generator) -> jnp.ndarray:
  return jnp.array(
      seed.integers(0, np.iinfo(np.uint32).max, (2,), dtype=np.uint32,
                    endpoint=True))

seed = np.random.default_rng(seed=42)
print(_numpy_generator_to_prngkey(seed) == _numpy_generator_to_prngkey(seed))
# False

_numpy_generator_to_prngkey = jit(_numpy_generator_to_prngkey, static_argnums=[0])
print(_numpy_generator_to_prngkey(seed) == _numpy_generator_to_prngkey(seed))
# True
In the jitted version of the function, the seed is only generated once at tracing time. I suspect that this sort of thing might lead to subtle bugs that would be very difficult to identify.
Do you have thoughts on that?
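A minimal sketch of one way around this, assuming the goal is simply to seed JAX from a NumPy generator: do the stateful draw outside of `jit` and pass the resulting key in as an ordinary (traced) argument. The `sample` function is a hypothetical example.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def sample(key):
    # The key is a traced argument, so every call can receive a different key.
    return jax.random.normal(key, (3,))


rng = np.random.default_rng(seed=42)

# Draw seeds from the stateful NumPy generator *outside* of jit, so each
# conversion advances the generator's state as the caller expects.
key1 = jax.random.PRNGKey(int(rng.integers(0, 2**31)))
key2 = jax.random.PRNGKey(int(rng.integers(0, 2**31)))

a = sample(key1)
b = sample(key2)
a_again = sample(key1)  # same key, same values: JAX's RNG is a pure function of the key
```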
comment created time in 21 hours
issue comment altair-viz/altair
Plots not showing up when served with voila without an ipywidget
I'm not sure about the details of the voila frontend. I'd suggest asking on the Voila issue tracker.
comment created time in a day
issue comment altair-viz/altair
Question: Can you plot a Gaussian distribution?
Sure, for example:
import altair as alt
import numpy as np
import pandas as pd
from scipy.stats import norm

# Evaluate the standard normal PDF on a symmetric grid around the mean.
x = np.linspace(-5, 5, 1000)
df = pd.DataFrame({'x': x, 'y': norm.pdf(x)})

alt.Chart(df).mark_line().encode(
    x='x',
    y='y'
)
comment created time in 2 days
push event jakevdp/jax
commit sha ab273acb2173e8f1a5343f26432cc27cb8a1a991
Call _check_arraylike in jax.numpy to improve error messages This PR only adds the call to places where non-array inputs currently lead to errors. There remain a number of other functions where adding this check would lead to potentially breaking changes; these are deliberately left out of this PR.
commit sha 80fa22cf90459cdee8a6f1fbee7a0f7fedcfcb0c
Merge pull request #4381 from jakevdp:checkarraylike PiperOriginRevId: 333170269
commit sha 2cef06d99c8d13b848420cace1bfa8a5e553f9d3
Add test coverage for jnp.cov aweights & fweights
push time in 4 days
PR opened google/jax
pr created time in 4 days
issue comment google/jax
Jax numpy does not throw IndexError when array index out of bounds
Hi, thanks for the report!
This is a known feature in JAX that you can read about here: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Out-of-Bounds-Indexing
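If you need specific out-of-bounds behavior, the `.at` property lets you request it explicitly. A minimal sketch, assuming a JAX version where `.at[...].get` accepts the `mode` and `fill_value` arguments:

```python
import jax.numpy as jnp

x = jnp.arange(5)

# Reads: clamp the index into range, or substitute a fill value.
clipped = x.at[10].get(mode="clip")                 # index clamped to the last element
filled = x.at[10].get(mode="fill", fill_value=-1)   # out of bounds -> fill_value

# Writes: updates at out-of-bounds indices are skipped rather than raising.
updated = x.at[10].set(100)
```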
comment created time in 4 days
Pull request review comment google/jax
Allow JAX objects to be represented by multiple buffers
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest, parameterized
+
+import numpy as np
+
+from jax import test_util as jtu
+import jax.numpy as jnp
+from jax import core, jit, lax, lazy, make_jaxpr
+from jax.interpreters import xla
+from jax.lib import xla_client
+xops = xla_client.ops
+
+from jax.config import config
+config.parse_flags_with_absl()
+
+# Define a sparse array data structure. The important feature here is that
+# it is a jaxpr object that is backed by two device buffers.
+class SparseArray:
+  """Simple sparse COO array data structure."""
+  def __init__(self, aval, data, indices):
+    self.aval = aval
+    self.shape = aval.shape
+    self.data = data
+    self.indices = indices
+
+  @property
+  def index_dtype(self):
+    return self.indices.dtype
+
+  @property
+  def dtype(self):
+    return self.data.dtype
+
+  @property
+  def nnz(self):
+    return self.data.shape[0]
+
+  def __repr__(self):
+    return repr(list((tuple(ind), d) for ind, d in zip(self.indices, self.data)))
+
+
+class AbstractSparseArray(core.ShapedArray):
+  __slots__ = ['index_dtype', 'nnz', 'data_aval', 'indices_aval']
+  _num_buffers = 2
+
+  def __init__(self, shape, dtype, index_dtype, nnz):
+    super(AbstractSparseArray, self).__init__(shape, dtype)
+    self.index_dtype = index_dtype
+    self.nnz = nnz
+    self.data_aval = core.ShapedArray((nnz,), dtype)
+    self.indices_aval = core.ShapedArray((nnz, len(shape)), index_dtype)
+
+  @core.aval_property
+  def data(self):
+    return sp_data_p.bind(self)
+
+  @core.aval_property
+  def indices(self):
+    return sp_indices_p.bind(self)
+
+def abstract_sparse_array(arr):
+  return AbstractSparseArray(arr.shape, arr.dtype, arr.index_dtype, arr.nnz)
+
+def sparse_array_result_handler(device, aval):
+  def build_sparse_array(data_buf, indices_buf):
+    data = xla.DeviceArray(aval.data_aval, device, lazy.array(aval.data_aval.shape), data_buf)
+    indices = xla.DeviceArray(aval.indices_aval, device, lazy.array(aval.indices_aval.shape), indices_buf)
+    return SparseArray(aval, data, indices)
+  return build_sparse_array
+
+def sparse_array_shape_handler(a):
+  return (
+    xla.xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape),
+    xla.xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape),
+  )
+
+def sparse_array_device_put_handler(a, device):
+  return (
+    xla.xb.get_device_backend(device).buffer_from_pyval(a.data, device),
+    xla.xb.get_device_backend(device).buffer_from_pyval(a.indices, device)
+  )
+
+core.pytype_aval_mappings[SparseArray] = abstract_sparse_array
Added the TODO for now.
comment created time in 4 days
Pull request review comment google/jax
Allow JAX objects to be represented by multiple buffers
 def __init__(self,
     if indices is None:
       indices = spec_to_indices(aval.shape, sharding_spec)
     self.aval = aval
+    assert all(isinstance(b, xc.Buffer) for b in device_buffers), f"Expected flattened list of device buffers; got {device_buffers}"
Done.
comment created time in 4 days
Pull request review comment google/jax
Allow JAX objects to be represented by multiple buffers
 class DeviceArray:
   _HAS_DYNAMIC_ATTRIBUTES = True
   def __init__(self, aval: core.ShapedArray, device: Optional[Device],
-               lazy_expr: lazy.LazyExpr, device_buffer: PyLocalBuffer):
+               lazy_expr: lazy.LazyExpr,
+               device_buffer: PyLocalBuffer):
+    if isinstance(device_buffer, Sequence):
Done
comment created time in 4 days
Pull request review comment google/jax
Allow JAX objects to be represented by multiple buffers
 def apply_primitive(prim, *args, **params):
   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
   return compiled_fun(*args)
 
+
+def _partition_outputs(avals, outs):
+  nouts = [aval._num_buffers for aval in avals]
+  assert sum(nouts) == len(outs), f"Internal error: sum(nouts)={sum(nouts)} should equal len(outs)={len(outs)}."
Done.
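The diff above shows only the first lines of `_partition_outputs`. As a standalone, hypothetical sketch of the grouping it sets up (not the PR's implementation; `partition_outputs` and its argument names are illustrative), a flat list of output buffers can be split using each value's buffer count:

```python
from itertools import accumulate
from typing import List, Sequence


def partition_outputs(num_buffers: Sequence[int], outs: Sequence) -> List[list]:
    """Split a flat list of buffers into one group per abstract value.

    num_buffers[i] is how many buffers represent the i-th output, e.g. 1 for a
    dense array or 2 for a COO sparse array holding data and indices.
    """
    assert sum(num_buffers) == len(outs), (
        f"sum(num_buffers)={sum(num_buffers)} should equal len(outs)={len(outs)}")
    ends = list(accumulate(num_buffers))  # exclusive end offset of each group
    starts = [0] + ends[:-1]              # start offset of each group
    return [list(outs[start:end]) for start, end in zip(starts, ends)]
```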
comment created time in 4 days
push event jakevdp/jax
commit sha 699e46ac5a574272d34b79bb46c1df601bd0a461
raise TypeError if device_buffer is a tuple
push time in 4 days
push event jakevdp/jax
commit sha e0d1b375fa593c5ef777e7e650a3a3e75996cedc
Delete dead axis_index code The primitive was moved to `lax_parallel.py` some time ago, so the one in `core` should no longer be used. This is probably a result of a botched rebase.
commit sha 332a9ba1ad1a08f4deb09b82ff0d4c1556b431eb
Fix axis_index inside nested pmaps The previous translation rule has assumed that `axis_index` is always taken over the outermost axis in the `axis_env`, and was always producing the same output, no matter which axis has been specified. This fixes the translation rule to start taking the `axis_name` into account. Additionally, this adds support for querying the index along multiple axes, which will be useful for `gmap`.
commit sha 8ac19c722211ffe4c7fb0bdbeeb236818d000291
Fix a faulty soft_pmap rule for axis_index The rule didn't specify the precision for the `np.arange` constant, which caused an accidental dtype promotion in X64 mode. Previously the error has luckily been hidden behind a coercion that followed `axis_index` in that test, but the new implementation has surfaced it.
commit sha 99ffcc44d5d57c6bbb2ea94f92ed703d9492d6a1
Merge pull request #4378 from apaszke:axis_index_nd PiperOriginRevId: 333106474
commit sha 533fe28b47a39d8df923cb221ee92ffc31003816
Merge pull request #4377 from apaszke:axis_index PiperOriginRevId: 333107273
commit sha 05adb3f0236c7cba43dca6b7c411ee71217b86d0
Allow jax objects to be represented by multiple buffers
commit sha 592bfd960ff2683a7dd9709b4abe1095b5448f39
address some review comments
commit sha e3cfaf6b32e62730a85c31e6dc4e3f83d0c79ea5
Remove check for sequences of buffers in DeviceArray
commit sha ade0a3b9026eae55a173b5a3b4aded6692ea016e
check for buffer type in DeviceArray constructor
push time in 4 days
issue comment google/jax
jnp.median scaling vs np.median
Hi, thanks for the report. Because XLA doesn't currently provide an efficient routine for computing the median, JAX currently computes it via a full array sort, which accounts for the poor scaling.
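To illustrate why the cost tracks a full sort: a sort-based median, sketched below as a simplification (not JAX's actual `jnp.median` code; it ignores `axis`, `keepdims` and NaN handling), pays O(n log n) for the sort even though only the middle one or two elements are needed, whereas NumPy's `np.median` uses `np.partition`, a selection algorithm.

```python
import numpy as np
import jax.numpy as jnp


def median_via_sort(x):
    """Median of all elements of x, computed with a full sort."""
    x = jnp.sort(jnp.ravel(x))
    n = x.shape[0]          # static Python int, so the branch below is fine under jit
    mid = n // 2
    if n % 2:               # odd length: the single middle element
        return x[mid]
    return (x[mid - 1] + x[mid]) / 2  # even length: mean of the two middle elements
```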
comment created time in 4 days
issue comment vega/vega-lite
Here is my first suggestion working properly, showing the correct tooltip: Open the Chart in the Vega Editor
comment created time in 4 days
issue comment vega/vega-lite
Two things:

1. If you're going the route of using a full ISO 8601 date string, I wouldn't add a timezone. I'd just use the full string (like `2012-01-02T00:00:00` instead of `2012-01-02`). JavaScript will parse this as local time, and Vega will display it as local time unless you explicitly specify UTC.
2. If you're going the route of using UTC time units, you'll need to use them everywhere, including the tooltip.
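A minimal sketch of the second option, written as a plain Python dict rather than JSON (the field names and data values are hypothetical): the same UTC time unit is applied on the x axis and in the tooltip, so both display the same date.

```python
import json

# Hypothetical data: date-only strings, which JavaScript parses as UTC midnight.
values = [{"date": "2012-01-02", "value": 1}, {"date": "2012-01-03", "value": 3}]

spec = {
    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
    "data": {"values": values},
    "mark": "point",
    "encoding": {
        # UTC variant of the time unit on the axis...
        "x": {"field": "date", "type": "temporal", "timeUnit": "utcyearmonthdate"},
        "y": {"field": "value", "type": "quantitative"},
        # ...and the same UTC time unit in the tooltip.
        "tooltip": [
            {"field": "date", "type": "temporal", "timeUnit": "utcyearmonthdate"},
            {"field": "value", "type": "quantitative"},
        ],
    },
}

print(json.dumps(spec, indent=2))
```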
comment created time in 4 days
PR opened google/jax
This PR only adds the call to places where non-array inputs currently lead to errors. There remain a number of other functions where adding this check would lead to potentially breaking changes; these are deliberately left out of this PR.
pr created time in 4 days
push event jakevdp/jax
commit sha 611c8001930d12719ca12f94d7430b2a98d9dcec
Remove check for sequences of buffers in DeviceArray
push time in 4 days
push event jakevdp/jax
commit sha bbe3a6a9a25e0669ccba08c1d3c0297c35649136
Improve segment_sum stability by k-way summation.
commit sha 49a01d36c8b71902ed4136219c11c38d5a35b77c
Use jvp(expm) to compute expm_frechet.
commit sha 05cc7e7352f2e2b350cae8aae4f18c509701a624
device_put_sharded: remove incorrect type annotation
commit sha c4f98eb8fa2947e0db4a65bb67daa3243e7d103d
Add back the batching rule for ppermute Just make sure it's correct this time and add a test.
commit sha 2081e5acee2039930c250ccb91483ed6d6cfe580
Test pmap/vmap interactions of all reduction collectives
commit sha d7564c506e752eaf495a744950822ea8e52e6cee
Add test matrices that exercise more code path.
commit sha 614acce43c1900afb2f2c3fecf89139866aeea88
Change segment_sum to use no bucketing by default.
commit sha e76ebea9cb6cf6648467f24253519e52430d3827
Merge branch 'master' into rm_fretchet
commit sha ce1ce9cb276045324e9122830aa3890fef16af0f
Implement jnp.array_equiv
commit sha 83f14012ea9fc016a79a69287a91ba206993826d
Bump tol of float32 for complex64 inner product.
commit sha 35d231990c41ca5d1913f1817e75c45a1f934c02
Add ceil_of_ratio util and bucket_size TODO.
commit sha 18054e05a8f64bbaa9351ebd27564684888c1e96
call _check_arraylike in jnp.diff
commit sha ada6f30f59d44a574af111b540b4771f9708b0ef
Merge pull request #4347 from jakevdp:arrayequiv PiperOriginRevId: 332946445
commit sha be50847ceeafc3d08de31bae21e9404b6283f05f
Make scale_and_translate take spatial dimensions
commit sha ae910cdd311800fd1f2057b62d44028575820f60
Updating image_test
commit sha 2cb795e0d9dd513d613d059fbf5d6b4a7a9b9eaf
Merge pull request #4366 from jakevdp:diffempty PiperOriginRevId: 332960638
commit sha 2cf8d49f5b305bb6814136c81f35cdbfde187176
jnp.moveaxis: fix bug when axes are integer dtype
commit sha 04fa89a12c0c109f3b86d0e578ff0a647472204a
Merge pull request #4299 from zhangqiaorjc:segsum PiperOriginRevId: 332964956
commit sha b6e9da36eb7f265ee43fd61ac87d9240c8330452
Remove unused var. Bump tol.
commit sha 2abb37c286c92d70b8fb0104aacea60cda7675ce
move a _device_put_raw under broadcast impl Before this change, we had to interact with the device to construct an array of zeros, even if we were staging everything out (e.g. with jax.xla_computation and omnistaging).
push time in 4 days
pull request comment google/jax
jax.numpy: improved errors for invalid inputs to unary ops
Is it possible to (usefully) shorten the stack trace even further by introducing something of this sort at the level of, say, `_wraps`?
I looked into this... it seems like it's possible to slightly shorten the traceback by adding the logic to `_wraps`. It would look something like this:
>>> jnp.negative([1, 2, 3])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-2-f2d2f6fccc05> in <module>
----> 1 jnp.negative([1, 2, 3])

~/github/google/jax/jax/numpy/_util.py in _op(*args, **kwargs)
     61 from jax.numpy.lax_numpy import _check_arraylike
     62 def _op(*args, **kwargs):
---> 63 _check_arraylike(fun.__name__, *(args[i] for i in check_arraylike))
     64 op = _op
     65 if not hasattr(fun, '__doc__') or fun.__doc__ is None:

~/github/google/jax/jax/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    295 if not _arraylike(arg))
    296 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 297 raise TypeError(msg.format(fun_name, type(arg), pos))
    298
    299

TypeError: negative requires ndarray or scalar arguments, got <class 'list'> at position 0.
The problem is that this is hard to do in a way that's not brittle, because you have to somehow specify to `_wraps` which arguments must be arraylike, and these can be passed either positionally or by keyword. So it would involve parsing the function signature of the wrapped operation and combining that with some extra info passed to `_wraps` in order to validate the right parts of `*args` and `**kwargs`.
Given that, I think I'm going to stick with doing the validation within the function's implementation.
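To make the brittleness concrete, here is a rough, hypothetical sketch (not JAX code; `check_arraylike_args` and `_is_arraylike` are illustrative names) of what the `_wraps`-level approach would require: every wrapped function must declare which parameters are arraylike, and the wrapper must bind `*args`/`**kwargs` to the signature before it can check them.

```python
import functools
import inspect

import numpy as np


def _is_arraylike(x):
    # Simplified stand-in for an arraylike test: NumPy arrays/scalars and Python numbers.
    return isinstance(x, (np.ndarray, np.generic, int, float, complex))


def check_arraylike_args(*names):
    """Hypothetical decorator that validates the named parameters of `fun`.

    Signature.bind maps *args and **kwargs onto parameter names, so the check
    applies whether a value was passed positionally or by keyword.
    """
    def decorator(fun):
        sig = inspect.signature(fun)

        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            for name in names:
                if name in bound.arguments and not _is_arraylike(bound.arguments[name]):
                    raise TypeError(
                        f"{fun.__name__} requires ndarray or scalar arguments, "
                        f"got {type(bound.arguments[name])} for {name!r}.")
            return fun(*args, **kwargs)
        return wrapper
    return decorator


@check_arraylike_args("x1", "x2")
def add(x1, x2, dtype=None):
    return np.add(x1, x2, dtype=dtype)
```

Each wrapped function needs an accurate list of parameter names, and a wrapped function that itself takes `*args` would defeat the name-based lookup, which is the kind of fragility described above.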
comment created time in 4 days
Pull request review comment google/jax
Allow JAX objects to be represented by multiple buffers
 def __init__(self,
     if indices is None:
       indices = spec_to_indices(aval.shape, sharding_spec)
     self.aval = aval
+    assert all(isinstance(b, xc.Buffer) for b in device_buffers), f"Expected flattened list of device buffers; got {device_buffers}"
I added this assertion after running an internal presubmit and finding that it was nearly impossible to locate the root cause of problems. Downstream dependencies were creating Sharded Device Arrays with `device_buffers` as a list of tuples, and this was causing an entirely unintelligible error when, later in the code, the device array was passed to some jitted function. This assertion was the only way I could identify the root cause of those errors.
What do you think about leaving it in place, but adding a TODO to delete it once all downstream dependencies are updated?
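The general pattern is to validate inputs eagerly at construction time, so that a malformed argument fails with a message that names it. A minimal, hypothetical sketch (the `Buffer` and `ShardedArray` classes are stand-ins, not JAX classes); it raises `TypeError` instead of using `assert`, since assertions are skipped when Python runs with `-O`:

```python
class Buffer:
    """Hypothetical stand-in for a single device buffer."""


class ShardedArray:
    """Hypothetical container that must hold a flat list of Buffer objects."""

    def __init__(self, device_buffers):
        device_buffers = list(device_buffers)
        # Validate eagerly, naming the offending input, rather than letting a
        # malformed (e.g. nested) list surface later as an unrelated error.
        if not all(isinstance(b, Buffer) for b in device_buffers):
            raise TypeError(
                f"Expected flattened list of device buffers; got {device_buffers}")
        self.device_buffers = device_buffers
```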
comment created time in 4 days
PR opened google/jax
Before:
In [1]: import jax.numpy as jnp
In [2]: x = jnp.arange(6).reshape(2, 3)
In [3]: jnp.moveaxis(x, jnp.int32(0), jnp.int32(1))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-53fba72727a1> in <module>
----> 1 jnp.moveaxis(x, jnp.int32(0), jnp.int32(1))

~/github/google/jax/jax/numpy/lax_numpy.py in moveaxis(a, source, destination)
   1261 if isinstance(destination, int):
   1262 destination = (destination,)
-> 1263 source = tuple(_canonicalize_axis(i, ndim(a)) for i in source)
   1264 destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination)
   1265 if len(source) != len(destination):

~/github/google/jax/jax/interpreters/xla.py in __iter__(self)
   1060 def __iter__(self):
   1061 if self.ndim == 0:
-> 1062 raise TypeError("iteration over a 0-d array") # same as numpy error
   1063 else:
   1064 return self._value.__iter__()

TypeError: iteration over a 0-d array
After:
In [1]: import jax.numpy as jnp
In [2]: x = jnp.arange(6).reshape(2, 3)
In [3]: jnp.moveaxis(x, jnp.int32(0), jnp.int32(1))
Out[3]:
DeviceArray([[0, 3],
             [1, 4],
             [2, 5]], dtype=int32)
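The root cause in the "Before" traceback is that the axis arguments were only wrapped into tuples when they were Python `int`s, so integer scalars of other types fell through to iteration. A sketch of a more permissive normalization, using NumPy only (`canonicalize_axes` is a hypothetical name, not the code from this PR):

```python
import operator

import numpy as np


def canonicalize_axes(axes, ndim):
    """Normalize an axis argument to a tuple of non-negative ints.

    Accepts Python ints, NumPy integer scalars, 0-d integer arrays, or sequences.
    """
    if np.ndim(axes) == 0:
        # Wrap any scalar (including np.int32 and 0-d arrays) instead of
        # iterating over it, which is what raised "iteration over a 0-d array".
        axes = (axes,)
    result = []
    for ax in axes:
        ax = operator.index(ax)  # accepts integer-like objects, rejects floats
        if not -ndim <= ax < ndim:
            raise ValueError(f"axis {ax} is out of bounds for array of dimension {ndim}")
        result.append(ax % ndim)
    return tuple(result)
```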
pr created time in 5 days
issue comment vega/vega-lite
But concretely, if you serialize your dates in full ISO 8601 form (e.g. `'2012-01-02T00:00:00'`), JavaScript will treat them as local dates and your code will work as expected.
Alternatively, if you don't want to change the format of the input dates, you can use the UTC equivalent of your time units; i.e. change `"timeUnit": "yearmonthdate"` to `"timeUnit": "utcyearmonthdate"` and it should work as expected.
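On the Python side, producing the full ISO 8601 form from date objects takes one call per value. A minimal sketch using only the standard library (the dates are made up):

```python
import datetime

dates = [datetime.date(2012, 1, 2), datetime.date(2012, 1, 3)]

# Serialize as full ISO 8601 datetimes with no timezone suffix, so JavaScript
# parses them as local time (a date-only string like "2012-01-02" parses as UTC).
serialized = [datetime.datetime.combine(d, datetime.time()).isoformat() for d in dates]
print(serialized)
```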
comment created time in 5 days
issue comment vega/vega-lite
I spent a lot of time in Altair trying to make the interaction between Pandas dates and JavaScript dates sensible from the perspective of the user. If you're generating data in Python and figures in Vega-Lite and are getting hung up on timezones in date conversions, Altair might be the remedy you're looking for.
comment created time in 5 days
PR opened google/jax
Before:
In [1]: import jax.numpy as jnp
In [2]: jnp.negative([1, 2, 3])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-2-f2d2f6fccc05> in <module>
----> 1 jnp.negative([1, 2, 3])

~/github/google/jax/jax/numpy/lax_numpy.py in <lambda>(x)
    347 return lax_fn(x)
    348 else:
--> 349 fn = lambda x: lax_fn(x)
    350 if lax_doc:
    351 doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()

~/github/google/jax/jax/lax/lax.py in neg(x)
     89 def neg(x: Array) -> Array:
     90 r"""Elementwise negation: :math:`-x`."""
---> 91 return neg_p.bind(x)
     92
     93 def sign(x: Array) -> Array:

~/github/google/jax/jax/core.py in bind(self, *args, **params)
    264 top_trace = find_top_trace(args)
    265 tracers = map(top_trace.full_raise, args)
--> 266 out = top_trace.process_primitive(self, tracers, params)
    267 return map(full_lower, out) if self.multiple_results else full_lower(out)
    268

~/github/google/jax/jax/core.py in process_primitive(self, primitive, tracers, params)
    572
    573 def process_primitive(self, primitive, tracers, params):
--> 574 return primitive.impl(*tracers, **params)
    575
    576 def process_call(self, primitive, f, tracers, params):

~/github/google/jax/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    222 def apply_primitive(prim, *args, **params):
    223 """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 224 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
    225 return compiled_fun(*args)
    226

~/github/google/jax/jax/interpreters/xla.py in arg_spec(x)
    214
    215 def arg_spec(x):
--> 216 aval = abstractify(x)
    217 try:
    218 return aval, x._device

~/github/google/jax/jax/interpreters/xla.py in abstractify(x)
    163 aval_fn = pytype_aval_mappings.get(typ)
    164 if aval_fn: return aval_fn(x)
--> 165 raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
    166
    167 def _make_abstract_python_scalar(typ, _):

TypeError: Argument '[1, 2, 3]' of type '<class 'list'>' is not a valid JAX type
After:
In [1]: import jax.numpy as jnp
In [2]: jnp.negative([1, 2, 3])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-2-f2d2f6fccc05> in <module>
----> 1 jnp.negative([1, 2, 3])

~/github/google/jax/jax/numpy/lax_numpy.py in <lambda>(x)
    345 fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
    346 else:
--> 347 fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x))
    348 if lax_doc:
    349 doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()

~/github/google/jax/jax/numpy/lax_numpy.py in _promote_args(fun_name, *args)
    300 def _promote_args(fun_name, *args):
    301 """Convenience function to apply Numpy argument shape and dtype promotion."""
--> 302 _check_arraylike(fun_name, *args)
    303 return _promote_shapes(fun_name, *_promote_dtypes(*args))
    304

~/github/google/jax/jax/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    295 if not _arraylike(arg))
    296 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 297 raise TypeError(msg.format(fun_name, type(arg), pos))
    298
    299

TypeError: negative requires ndarray or scalar arguments, got <class 'list'> at position 0.
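From the caller's side, the remedy the new message points to is an explicit conversion. A short sketch, assuming a JAX version that includes this check (the exact message text may differ between versions):

```python
import jax.numpy as jnp

values = [1, 2, 3]

# Python lists are rejected by jax.numpy functions with a TypeError...
try:
    jnp.negative(values)
    rejected = False
except TypeError as err:
    rejected = True
    print(err)

# ...so convert explicitly first.
negated = jnp.negative(jnp.asarray(values))
```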
pr created time in 5 days
issue comment google/jax
Error when creating test for new implementation
Yes, it's fine for the shapes of input arrays to determine the shapes of outputs within JIT-compiled computations. Problems arise when the values within input arrays determine the shapes of outputs.
There's a bit of relevant discussion of this here: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT
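The distinction can be seen in a small sketch (the function names are hypothetical): an output shape derived from `x.shape` is fine under `jax.jit`, while one derived from a traced value fails unless that argument is marked static.

```python
import jax
import jax.numpy as jnp


@jax.jit
def double_length(x):
    # Output shape depends only on the *shape* of x, which is static under jit.
    return jnp.zeros(2 * x.shape[0])


def ones_of_length(n):
    # Output shape depends on the *value* of n.
    return jnp.ones(n)


out = double_length(jnp.arange(3.0))

try:
    jax.jit(ones_of_length)(4)  # n is traced, so its value is unknown at trace time
    value_dependent_failed = False
except Exception:  # the exact error type varies across JAX versions
    value_dependent_failed = True

# Marking n as static makes its value part of the compilation cache key.
ones4 = jax.jit(ones_of_length, static_argnums=0)(4)
```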
comment created time in 5 days
PR opened google/jax
Fixes `jnp.diff` erroneously passing through scalars:
In [1]: import numpy as np
In [2]: import jax.numpy as jnp
In [3]: jnp.diff(0)
Out[3]: 0
In [4]: np.diff(0)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-3-85ad82741d0e> in <module>
----> 1 np.diff(0)

<__array_function__ internals> in diff(*args, **kwargs)

~/.local/share/virtualenvs/jax-LBbfM5ix/lib/python3.8/site-packages/numpy/lib/function_base.py in diff(a, n, axis, prepend, append)
   1244 nd = a.ndim
   1245 if nd == 0:
-> 1246 raise ValueError("diff requires input that is at least one dimensional")
   1247 axis = normalize_axis_index(axis, nd)
   1248

ValueError: diff requires input that is at least one dimensional
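For comparison, a short sketch of the intended behavior, assuming a JAX version that includes this fix: 1-D inputs are differenced as before, while a scalar input is rejected as in NumPy (the sketch catches the error generically rather than assuming its exact type).

```python
import jax.numpy as jnp

# First differences of a 1-D array: out[i] = a[i + 1] - a[i].
d = jnp.diff(jnp.array([1, 4, 9, 16]))

# A 0-d input has no neighbors to difference, so it should be rejected.
try:
    jnp.diff(0)
    scalar_rejected = False
except Exception:
    scalar_rejected = True
```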
pr created time in 5 days
issue comment google/jax
Error when creating test for new implementation
In JAX's JIT compilation model, the shapes of arrays cannot depend on data values unless those values are static. In your function, the `bins` value determines the shape of the output, so it must be a static parameter in the computation.
There are various ways to ensure a parameter is static, but in the test suite this is generally done by leaving it out of `args_maker` and passing it to the function in question within a closure. You can see an example of that here, in the test for the 1D histogram: https://github.com/google/jax/blob/e0af77fbca88ec882ef33c18926f8382f7178aed/tests/lax_numpy_test.py#L2286-L2300
Notice that `bins` and `density` are treated as static arguments, i.e. they're not parameters passed to the jitted function. If `bins` were passed to the function by `args_maker` in this test, it would result in the same error you're seeing in your test.
Does that make sense?
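Schematically, the pattern looks like the following sketch (`bin_edges` is a hypothetical stand-in for the function under test, and the test-harness plumbing is omitted): `args_maker` supplies only the array inputs, while `bins` is bound via a closure, here `functools.partial`, before `jax.jit` sees the function.

```python
import functools

import numpy as np
import jax
import jax.numpy as jnp


def bin_edges(x, bins):
    # Hypothetical function: the *value* of `bins` determines the output shape.
    return jnp.linspace(x.min(), x.max(), bins + 1)


rng = np.random.RandomState(0)
args_maker = lambda: [rng.randn(10).astype(np.float32)]  # array arguments only

bins = 5
# Close over `bins` so it never becomes a traced argument of the jitted function.
jitted = jax.jit(functools.partial(bin_edges, bins=bins))
edges = jitted(*args_maker())
```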
comment created time in 5 days
push event jakevdp/jax
commit sha 9f53d2a8d883bc6eefc1f701cdc6639bdc774f84
Internal change PiperOriginRevId: 332920102
commit sha 55c6bdfe9c5d0631200cb76e4f56481b3656b03f
Cleanup todos related to the upgrade of jaxlib. PiperOriginRevId: 332932271
commit sha 67bd83e4f04e47688f809235d57400ee7056776a
Allow jax objects to be represented by multiple buffers
push time in 5 days
push event jakevdp/jax
commit sha 6b9dfb139679eff878ffcf1022d392cf5c0123ad
fix incorrect indentation
commit sha 7fd7009c231bf9a86e009f6e86ccd4b374cb29dc
Expose scale_and_translate as a public function, fix a bug in implementation when translation is not 0. Change implementation to use native JAX everywhere allowing vmapping and gradients wrt scale and translation.
commit sha c33335b2d6246565d0f85a312793d56c08c6e8c5
Merge pull request #4229 from johnpjf:changelist/330579231 PiperOriginRevId: 332886460
commit sha 24fe07ebd14fd234fa5c26142dd4870b13eecd65
Merge pull request #4345 from jakevdp:indentation PiperOriginRevId: 332894503
commit sha 47869afe136be71e69c0c179ca53fd2009384bf0
Allow jax objects to be represented by multiple buffers
push time in 5 days
push event jakevdp/jax
commit sha 90cb99fc040b6cceaf81ff0bb0cea552c6a0baec
insert a stop_gradient in lax.create_token (as it forces a false dependency on the operand)
commit sha 9ddd252b56e01f7342aca0d2a98060eefc624cec
[jax2tf] Add primitive conversion for cholesky_p.
commit sha b3930f0b7e957e95078f77d011138f90ce95643d
Make the custom_assert a one-liner.
commit sha 2bc92f5593c04dfd9a7bbb65ed77b70092038031
Fixed ppermute translation rule (#4349)
commit sha 16e60360940f2baa4f3bc4f94b58680af8295198
trivial change to test source sync PiperOriginRevId: 332544315
commit sha f172fb74e17f250f769ca398d2eb563d88153ac2
plumb donate_argnums into jax.xla_computation
commit sha cfbaca0507b451305d05d5b2f482d2be605ed77a
Merge pull request #4330 from google:createtokenstopgrad PiperOriginRevId: 332564392
commit sha e88579f22b9185790638b3562ce0a97f994a2af3
fix typo
commit sha a6b3fa2c551ad567d260a99fbd6d9a1c60d5c649
add trivial test
commit sha ff5c14570251c4d0c9b469c05a47e62373207853
Merge pull request #4353 from google:xlacomputationdonateargnums PiperOriginRevId: 332573843
commit sha d9f9b50dfaa83199b9d569ebda6698e4dc88d020
Merge pull request #4343 from SIben:test_cholesky PiperOriginRevId: 332627994
commit sha 50dd9c5016ce0bceaad7256208e790b0b02a3142
don't initialize backend in xla_computation This should allow us to use donate_argnums *and* build HLO computations for backends not available at build time.
commit sha 1aab5ced9fbb1f97bb31f84b264b9b684c83108c
fix logic
commit sha 1092fa1d9b3de680aca5a883b781e6ba97c940be
fix logic, skip test
commit sha 85d070f0cdc665b5aad6f3341d09fe4a60d0368d
Merge pull request #4356 from google:xlacomputationdontinitializebackend PiperOriginRevId: 332683955
commit sha 1cde76b130298c4d890d38242d19cff7b65ffaf9
[jax2tf] Implementation of the conversion of eig_p.
commit sha 015bc3c2cc7736720e362b15274d59e85e08359c
Replace manual conj + transpose with call to adjoint.
commit sha 1d94363df02a964df02f0c77795bb9939e76b1fb
Ignore eig conversion test on TPU/GPU, as it is unimplemented in JAX.
commit sha 695e8d88c3a827e0d73bf47022128ad7c1761420
Merge pull request #4338 from SIben:test_eig_p PiperOriginRevId: 332813984
commit sha 2d0de43d88b3898a7c0fd0408b12ba8eb4836c31
Make xla.device_put() return tuples
push time in 5 days
push event jakevdp/jax
commit sha ce1ce9cb276045324e9122830aa3890fef16af0f
Implement jnp.array_equiv
push time in 5 days
issue comment vega/vega-lite
You can see the issue directly in your javascript console:
> new Date("2011-10-02")
Sat Oct 01 2011 17:00:00 GMT-0700 (Pacific Daylight Time)
You need to use a full ISO 8601 date string; otherwise it will be parsed as UTC.
comment created time in 5 days
push event jakevdp/jax
commit sha e4219391d5f2f1b9e93c3cb71fd01e29881fd87f
Implement jnp.array_equiv
push time in 5 days
push event jakevdp/jax
commit sha 90cb99fc040b6cceaf81ff0bb0cea552c6a0baec
insert a stop_gradient in lax.create_token (as it forces a false dependency on the operand)
commit sha 9ddd252b56e01f7342aca0d2a98060eefc624cec
[jax2tf] Add primitive conversion for cholesky_p.
commit sha b3930f0b7e957e95078f77d011138f90ce95643d
Make the custom_assert a one-liner.
commit sha 2bc92f5593c04dfd9a7bbb65ed77b70092038031
Fixed ppermute translation rule (#4349)
commit sha 16e60360940f2baa4f3bc4f94b58680af8295198
trivial change to test source sync PiperOriginRevId: 332544315
commit sha f172fb74e17f250f769ca398d2eb563d88153ac2
plumb donate_argnums into jax.xla_computation
commit sha cfbaca0507b451305d05d5b2f482d2be605ed77a
Merge pull request #4330 from google:createtokenstopgrad PiperOriginRevId: 332564392
commit sha e88579f22b9185790638b3562ce0a97f994a2af3
fix typo
commit sha a6b3fa2c551ad567d260a99fbd6d9a1c60d5c649
add trivial test
commit sha ff5c14570251c4d0c9b469c05a47e62373207853
Merge pull request #4353 from google:xlacomputationdonateargnums PiperOriginRevId: 332573843
commit sha d9f9b50dfaa83199b9d569ebda6698e4dc88d020
Merge pull request #4343 from SIben:test_cholesky PiperOriginRevId: 332627994
commit sha 50dd9c5016ce0bceaad7256208e790b0b02a3142
don't initialize backend in xla_computation This should allow us to use donate_argnums *and* build HLO computations for backends not available at build time.
commit sha 1aab5ced9fbb1f97bb31f84b264b9b684c83108c
fix logic
commit sha 1092fa1d9b3de680aca5a883b781e6ba97c940be
fix logic, skip test
commit sha 85d070f0cdc665b5aad6f3341d09fe4a60d0368d
Merge pull request #4356 from google:xlacomputationdontinitializebackend PiperOriginRevId: 332683955
commit sha 1cde76b130298c4d890d38242d19cff7b65ffaf9
[jax2tf] Implementation of the conversion of eig_p.
commit sha 015bc3c2cc7736720e362b15274d59e85e08359c
Replace manual conj + transpose with call to adjoint.
commit sha 1d94363df02a964df02f0c77795bb9939e76b1fb
Ignore eig conversion test on TPU/GPU, as it is unimplemented in JAX.
commit sha 695e8d88c3a827e0d73bf47022128ad7c1761420
Merge pull request #4338 from SIben:test_eig_p PiperOriginRevId: 332813984
commit sha b0b3011843158151e3261da935997203c06da61d
Implement jnp.array_equiv
push time in 5 days
issue comment altair-viz/altair
QUESTION: proxy URLs to CSV files are being interpreted as JSON
I'm not sure how Vega determines the default, but alt.UrlData has a format parameter that can be used to specify the format.
comment created time in 5 days
issue comment johannesjmeyer/rsmf
The vega-cli package has methods to save vega/vega-lite charts to png, svg, and pdf. If none of those are suitable for the purposes here, that package would be the place to implement other formats.
comment created time in 5 days
issue comment altair-viz/altair
Support for vega custom projections
Altair's only control over a chart comes from what is defined in the Vega-Lite specification. Yes, you can specify a custom projection in the specification, but for it to do anything, that custom projection has to be defined in the renderer.
So where is the renderer? It's in the JupyterLab frontend code... the nteract frontend code... the VS Code frontend code... the Streamlit frontend code. Except in the simplest circumstances, Altair has no control over how the chart is rendered beyond what can be specified in the Vega-Lite spec.
Theoretically, would it be possible to coordinate all these disparate projects and agree on a standard whereby JavaScript extensions could be injected into them? Maybe. Probably not, though, given how problematic user-generated JavaScript can be security-wise. At best, it would be a feature that works in some frontends and not in others.
comment created time in 5 days
issue closed altair-viz/altair
Support for vega custom projections
Vega supports the use of custom projections using vega.projection, and third-party d3 projections are available (see e.g., this library). But Altair has a hard-coded list of supported projections, meaning it's not possible to use Altair with custom Vega projections. There should be a way to override the check of known projections in Altair, so that one can use other custom projections.
closed time in 6 days
issue comment altair-viz/altair
Support for vega custom projections
Thanks for the request!
This can be done by changing the JavaScript embedding code for the vega-lite specifications that Altair produces.
That said, Altair is fundamentally a tool to create Vega-Lite specifications, and Vega projections are not expressible in Vega-Lite. Given that, I'm going to close this feature request as out of scope for the project.
comment created time in 6 days
issue comment altair-viz/altair
You can do this with mark_square(). This is akin to mark_point() and mark_circle() in that it generates points that can have a size encoding, whereas mark_rect() generates rectangles whose size is controlled by the x, x2, y, and y2 encodings:
import altair as alt
import numpy as np
import pandas as pd
# Compute x^2 + y^2 across a 2D grid
x, y = np.meshgrid(range(-10, 10), range(-5, 5))
z = x ** 2 + y ** 2
# Convert this grid to columnar data expected by Altair
source = pd.DataFrame({'x': x.ravel(),
                       'y': y.ravel(),
                       'z': z.ravel()})
alt.Chart(source).mark_square().encode(
x='x:O',
y='y:O',
color='z:Q',
size='z:Q'
)
comment created time in 8 days
PR closed google/jax
pr closed time in 8 days
PR opened google/jax
pr created time in 8 days
push event jakevdp/jax
commit sha a31c6fc458eb05f40f84c5f1089f47f675d48127
fix shape mismatch
push time in 8 days
push event jakevdp/jax
commit sha b3a098747aa6c674265796ddff8191a2e6a81efe
Make expm transposable and remove custom_jvp rule. (#4314) * Make expm transposable and remove custom_jvp rule. * Add check_grads for up to 2nd order derivative.
commit sha d478e346ac2cb1c7f531c61fb2888ccbbe9a3e60
Fix conditional in eig and expand eig test suite. (#4320) * Fix conditional in eig and expand eig test suite.
commit sha 0ac25c760af2788ba0a14f8a0f5827035e4c29e0
[jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 (#4278) * [jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 We now fixed tf.raw_ops.AddV2 to support uint32. It was already supporting uint8, so it is a better choice now than tf.math.add. This allowed us to use the threefry implementation using uint32.
commit sha ded7b3854c37f93652e7eed4f4cd50522f3be70f
[jax2tf] Revert '[jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 (#4278)' (#4332) Generates errors due to Grappler replacing AddV2 with AddN, which is not implemented for uint32
commit sha 8376d92049624bf0784647b17b1f09015acd0947
Disable testExpmGrad on TPU, pending investigation of compiler error (#4333)
commit sha 2911bcd63427bb433193e5450c967a79ddec70f5
Enable complex-valued Cholesky decomposition tests on TPU> (#4339)
commit sha 6a89f60683d19f2797be8e777069a0ba7c515444
fix benchmark sums (#4329)
commit sha 6614f94890429c6c4c9dd46f932e22653dd6316d
rename and simplify TypedJaxpr -> ClosedJaxpr (#4328) rename and simplify TypedJaxpr -> ClosedJaxpr This change: * simplifies code that constructs TypedJaxprs/ClosedJaxprs (because in_avals / out_avals no longer need to be constructed), making them easier to work with; * correspondingly rules out a class of errors (mismatches between invars/outvars and in_avals/out_avals); * provides a more descriptive class name (ClosedJaxprs are like jaxprs but they're closed in that they are packaged with their constant values). This is part 1 of an attempt to remove TypedJaxprs completely, or at least significantly reduce our use of them. However, I'm not getting rid of them entirely in this first step because it'd require bigger changes (basically allowing all constants to be represented as literals, rather than only scalars) that would not only touch a lot more code (jaxpr formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering) but also might affect XLA lowering right before a conference deadline (ICLR). Plus I'm trying to make big changes in smaller steps :) Co-authored-by: George Necula <gcnecula@gmail.com>
commit sha 56d3333716f462b551ab1c57e7c2dec397827a62
Implement jnp.array_equiv
push time in 8 days
pull request comment google/jax
Allow JAX objects to be represented by multiple buffers
api_benchmark results on the master branch:
------------------------------------------------------------------
Benchmark                            Time          CPU  Iterations
------------------------------------------------------------------
jit_trivial_dispatch            120578 ns    120539 ns        5381
jit_trivial                     136459 ns    136437 ns        5146
jit_simple_dispatch             495457 ns    494981 ns        1461
jit_simple                      510835 ns    510391 ns        1185
jit_simple_many_args_dispatch  2863614 ns   2863011 ns         246
jit_simple_many_args           2872962 ns   2872501 ns         234
jit_dispatch_without_transfer  6348334 ns   6347968 ns          91
jit_dispatch_with_transfer     6126572 ns   6126012 ns         117
with this change:
------------------------------------------------------------------
Benchmark                            Time          CPU  Iterations
------------------------------------------------------------------
jit_trivial_dispatch            106872 ns    106829 ns        6227
jit_trivial                     127626 ns    127625 ns        5493
jit_simple_dispatch             499879 ns    499377 ns        1417
jit_simple                      512144 ns    512095 ns        1342
jit_simple_many_args_dispatch  3088399 ns   3087974 ns         227
jit_simple_many_args           3043566 ns   3042474 ns         226
jit_dispatch_without_transfer  6320484 ns   6320214 ns         107
jit_dispatch_with_transfer     6268184 ns   6266837 ns         116
comment created time in 8 days
push event jakevdp/jax
commit sha 76a3bccb19fd861375d78f14280f4fc5e7adc642
fix unused import
push time in 8 days
PR opened google/jax
pr created time in 8 days
push event jakevdp/jax
commit sha f0aaaedde92f2792799b97f7ee5bde6abbb6f328
Error early when ShardedDeviceArray is passed nested buffers.
push time in 8 days
push event jakevdp/jax
commit sha b0caa6e86fe52b77d689d7a4acd79370e024e6d8
more pytype errors
push time in 8 days
push event jakevdp/jax
commit sha 6614f94890429c6c4c9dd46f932e22653dd6316d
rename and simplify TypedJaxpr -> ClosedJaxpr (#4328) rename and simplify TypedJaxpr -> ClosedJaxpr This change: * simplifies code that constructs TypedJaxprs/ClosedJaxprs (because in_avals / out_avals no longer need to be constructed), making them easier to work with; * correspondingly rules out a class of errors (mismatches between invars/outvars and in_avals/out_avals); * provides a more descriptive class name (ClosedJaxprs are like jaxprs but they're closed in that they are packaged with their constant values). This is part 1 of an attempt to remove TypedJaxprs completely, or at least significantly reduce our use of them. However, I'm not getting rid of them entirely in this first step because it'd require bigger changes (basically allowing all constants to be represented as literals, rather than only scalars) that would not only touch a lot more code (jaxpr formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering) but also might affect XLA lowering right before a conference deadline (ICLR). Plus I'm trying to make big changes in smaller steps :) Co-authored-by: George Necula <gcnecula@gmail.com>
commit sha 50d84c518ae9e16a94585976296cd6cc1cff8110
Make xla.device_put() return tuples
commit sha 4055bf638ccd620167aea1b12da3498e54fc2f42
Make xla.aval_to_xla_shape return a tuple
commit sha 21ae6c88aa1ca7c4c856d3fbba178c5d8a2acade
make xla.aval_to_result_handler return number of args
commit sha 5dc684dbdf7b515d9495e69e0115efabb0ce4e81
Add ability support for multiple device buffers per object
commit sha cf0ce9112dd73f2802da284d4ab52189a157d219
Add custom object tests
commit sha 80c66ca16aed8b206149787a8a683d128d32ded9
fix multibuffer device put
commit sha e455e48c425d9c70bf89e6e3fc2f1f747bafa621
streamline tests
commit sha 8bb730e797331595a75ed6bf134595c646164c70
return xops tuple from translation rule
commit sha eeb0e3d0b637be0fecdfc7ff8bf798b5ee02fd58
cleanup: use itertools.chain.from_iterable
commit sha 693e77506cefe81ad4cb458b59382bafcb30df57
more minor cleanups
commit sha 5ecc37518ff74dc92f84b6b9dc36a300cac81fcf
minor fixes
commit sha d601e4b84dd934c42dafa4fc887a1b939aacfbc3
fix missing f-string
commit sha 2f3c2f2eb8d9a3714debdc6e1919a943ec080db2
fix interaction between multiple_results and destructure
commit sha f2f3b2097a8ffb0c10a7cff7d19adc647eaf11b8
simplify xla_result_handler definitions
commit sha 117cd70456ade0175b162adfab67486db8e79787
further simplify result handlers
commit sha 62cdfc97fd154168b49579193d603bec72c13edf
more simplifications of result handlers
commit sha 821df45f0daba7741743e0dfd173b505c45410e8
fix internal pytype
commit sha 09297cbfdbdf6500001d9eef8b956917bcfbfd63
add primitives for buffer access
commit sha f76f9d5805cdef10a2e4a4141e6986a56e0ecd6b
Add matvec test
push time in 8 days
push event jakevdp/jax
commit sha 8f6f3619f1e2b44b889c103101aa7563876f7acf
fix type annotation
push time in 8 days
issue comment google/jax
Implement NumPy sorting routines
Yes, I don't think anyone has worked on partitions yet. It's not trivial to do, because partitioning is not (yet?) implemented in XLA.
It may be possible to implement it using scans (similar to how jnp.searchsorted is implemented), but in the end it might actually be faster to implement partitioning via a full sort, since XLA knows how to sort efficiently.
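The full-sort route can be sketched as follows. partition_via_sort and argpartition_via_sort are hypothetical helper names, not JAX API. They rely only on the fact that a fully sorted array already satisfies np.partition's contract, at O(n log n) cost rather than the linear average cost of a selection algorithm:

```python
import jax.numpy as jnp


def partition_via_sort(a, kth, axis=-1):
    """Hypothetical helper: a full sort puts the kth element in its sorted
    position with smaller-or-equal values before it and larger-or-equal
    values after it, which is everything np.partition promises."""
    del kth  # a full sort places every position correctly, including kth
    return jnp.sort(a, axis=axis)


def argpartition_via_sort(a, kth, axis=-1):
    """Hypothetical argpartition counterpart built on a full argsort."""
    del kth
    return jnp.argsort(a, axis=axis)


x = jnp.array([7.0, 2.0, 9.0, 4.0, 1.0, 8.0])
print(partition_via_sort(x, 2), argpartition_via_sort(x, 2))
```

The trade-off is the one described in the comment: more work than strictly necessary, but it runs on XLA's existing, well-optimized sort.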
comment created time in 8 days
push event jakevdp/jax
commit sha 11007ba0e3577dab0c19ac37f60061fdccecb016
test eval_context works w/ and w/o omnistaging (#4325)
commit sha 8a4ee3d8516662b6f6fd95046f4d28119b3c1db2
Fix shape checking rule for conv_general_dilated. (#4318) * Fix shape checking rule for conv_general_dilated. This closes google/jax#4316. * Added test based on google/jax#4316. * Change test name to be more accurate.
commit sha e0af77fbca88ec882ef33c18926f8382f7178aed
Implement jnp.ravel_multi_index() (#4313)
commit sha b3a098747aa6c674265796ddff8191a2e6a81efe
Make expm transposable and remove custom_jvp rule. (#4314) * Make expm transposable and remove custom_jvp rule. * Add check_grads for up to 2nd order derivative.
commit sha d478e346ac2cb1c7f531c61fb2888ccbbe9a3e60
Fix conditional in eig and expand eig test suite. (#4320) * Fix conditional in eig and expand eig test suite.
commit sha 0ac25c760af2788ba0a14f8a0f5827035e4c29e0
[jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 (#4278) * [jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 We now fixed tf.raw_ops.AddV2 to support uint32. It was already supporting uint8, so it is a better choice now than tf.math.add. This allowed us to use the threefry implementation using uint32.
commit sha ded7b3854c37f93652e7eed4f4cd50522f3be70f
[jax2tf] Revert '[jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 (#4278)' (#4332) Generates errors due to Grappler replacing AddV2 with AddN, which is not implemented for uint32
commit sha 8376d92049624bf0784647b17b1f09015acd0947
Disable testExpmGrad on TPU, pending investigation of compiler error (#4333)
commit sha 2911bcd63427bb433193e5450c967a79ddec70f5
Enable complex-valued Cholesky decomposition tests on TPU> (#4339)
commit sha 6a89f60683d19f2797be8e777069a0ba7c515444
fix benchmark sums (#4329)
commit sha b8cb85cf75e87696a6ea613e8b7205d68aa1ffb0
Make xla.device_put() return tuples
commit sha 5792b0980586ea169dfd924958f02f2907e6a410
Make xla.aval_to_xla_shape return a tuple
commit sha 10eafcf98bf113c442642677d89e09ded4107f9e
make xla.aval_to_result_handler return number of args
commit sha bdb15be9fffac6e1fd59c139bd0c13c4bd8b4653
Add ability support for multiple device buffers per object
commit sha 96dc04226e757fb7a02c942bb4ddedc8814037ec
Add custom object tests
commit sha 10ad10404b2f88b81107cad3fc73356b0827091d
fix multibuffer device put
commit sha aa24c2f0ebacabde0cdead539978aa8fc1e4dc10
streamline tests
commit sha 1549952c722fe23da685f6a82e05c081fffcaceb
return xops tuple from translation rule
commit sha b3278991ba5f235206dc6ac25fff66254f87a5bb
cleanup: use itertools.chain.from_iterable
commit sha b5b4901a7562c5bf3e19c9a5834d11b643ccecf9
more minor cleanups
push time in 8 days
Pull request review comment google/jax
def atleast_3d(*arys):
 def array(object, dtype=None, copy=True, order="K", ndmin=0):
   if order is not None and order != "K":
     raise NotImplementedError("Only implemented for order='K'")
+  # check if the given dtype is compatible with JAX
   lax._check_user_dtype_supported(dtype, "array")
   dtype = dtype and dtypes.canonicalize_dtype(dtype)
   if _can_call_numpy_array(object):
     object = _np_array(object, dtype=dtype, ndmin=ndmin)
+
+  if type(object) is np.ndarray:
+    # check if the inferred np type is compatible with JAX
+    _inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
+    lax._check_user_dtype_supported(_inferred_dtype, "array")
+
No worries!
comment created time in 8 days
push event google/jax
commit sha 6a89f60683d19f2797be8e777069a0ba7c515444
fix benchmark sums (#4329)
push time in 8 days
PR merged google/jax
This was broken by #4195
pr closed time in 8 days
Pull request review comment google/jax
The new version seems redundant: if the input is a numpy array, the typecheck now happens twice. What I had in mind is this:
def array(object, dtype=None, copy=True, order="K", ndmin=0):
  if order is not None and order != "K":
    raise NotImplementedError("Only implemented for order='K'")
  # check if the given dtype is compatible with JAX
  lax._check_user_dtype_supported(dtype, "array")
  dtype = dtype and dtypes.canonicalize_dtype(dtype)
  if _can_call_numpy_array(object):
    object = _np_array(object, dtype=dtype, ndmin=ndmin)
  assert type(object) not in dtypes.python_scalar_dtypes
  if type(object) is np.ndarray:
    _inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
    lax._check_user_dtype_supported(_inferred_dtype, "array")
    out = _device_put_raw(object)
    if dtype: assert _dtype(out) == dtype
  elif isinstance(object, (DeviceArray, core.Tracer)):
    # ...
That is, a single check that takes care of every case where object is an array.
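The design point here, routing every array through one branch that holds the dtype check, can be sketched outside JAX. Everything below is hypothetical: check_dtype_supported stands in for lax._check_user_dtype_supported, and the whitelist of supported dtype kinds is an assumption for illustration:

```python
import numpy as np

# Assumed whitelist of dtype kinds: bool, signed int, unsigned int, float, complex.
_SUPPORTED_KINDS = frozenset("biufc")


def check_dtype_supported(dtype, fun_name):
    """Hypothetical stand-in for a dtype-support check."""
    dtype = np.dtype(dtype)
    if dtype.kind not in _SUPPORTED_KINDS:
        raise TypeError(f"{fun_name}: dtype {dtype} is not supported")


def to_array(obj, dtype=None):
    """Hypothetical converter that validates array dtypes in a single place."""
    if dtype is not None:
        check_dtype_supported(dtype, "to_array")
    if isinstance(obj, (bool, int, float, complex, list, tuple)):
        obj = np.array(obj, dtype=dtype)
    if isinstance(obj, np.ndarray):
        # Every ndarray reaches this branch, whether the caller passed it in
        # directly (including object-dtype arrays) or it was just built above,
        # so this one check covers all array inputs.
        check_dtype_supported(obj.dtype, "to_array")
        return obj if dtype is None else obj.astype(dtype)
    raise TypeError(f"to_array: unsupported input type {type(obj).__name__}")
```

Because the check sits where all ndarray inputs converge, it does not matter which upstream branch produced the array, which is the same property the suggested placement gives in jnp.array.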
comment created time in 8 days
Pull request review comment google/jax
Understood, but another path to this is if the input is a numpy array with dtype object, and it was not clear to me without looking at the source of _can_call_numpy_array whether that block would be executed in that case. It turns out it is. But because of that obfuscation, I thought it would be clearer to put the check in a place where it's obvious at first glance that it will be called for every numpy array input.
What do you think?
comment created time in 8 days
Pull request review comment google/jax
I think the check should be put in the existing if type(object) is np.ndarray block just below this.
comment created time in 8 days
issue comment scipy/scipy
scipy.sparse.csgraph.connected_components does not work for directed graphs
No problem, glad you got it figured out!
comment created time in 9 days
issue comment google/jax
Yeah, that definitely looks like it could be expressed as a convolution. I'd suggest going that route rather than writing all the loops manually. You can construct a kernel that essentially encodes the sums that happen at each location, and do the whole thing in one operation.
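The issue's original loops aren't reproduced here, so this sketch assumes a hypothetical stand-in computation: summing each element's 3x3 neighborhood with zero padding. A 3x3 kernel of ones encodes that sum at every location, and jax.scipy.signal.convolve2d with mode="same" computes all of them in one operation. Because the kernel is symmetric, the flip that convolution applies makes no difference:

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def neighbor_sum_loops(x):
    """Explicit loops: sum of each element's 3x3 neighborhood (zero padding)."""
    n, m = x.shape
    out = np.zeros_like(x)
    for i in range(n):
        for j in range(m):
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    ii, jj = i + di, j + dj
                    if 0 <= ii < n and 0 <= jj < m:
                        out[i, j] += x[ii, jj]
    return out


def neighbor_sum_conv(x):
    """The same computation as a single convolution with a kernel of ones."""
    kernel = jnp.ones((3, 3), dtype=x.dtype)
    return convolve2d(x, kernel, mode="same")


x = np.arange(20, dtype=np.float32).reshape(4, 5)
print(np.allclose(neighbor_sum_loops(x), np.asarray(neighbor_sum_conv(jnp.asarray(x)))))
```

Different per-location sums only change the kernel's weights, not the structure of the call.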
comment created time in 9 days
Pull request review comment google/jax
Allow JAX objects to be represented by multiple buffers
def shard_args(devices: Sequence[xb.xla_client.Device],
 shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
 shard_arg_handlers[core.Unit] = \
-    lambda x, devices, _: [xla.device_put(core.unit, d) for d in devices]
+    lambda x, devices, _: list(it.chain.from_iterable(xla.device_put(core.unit, d) for d in devices))
 def _shard_array(x, devices, indices):
-  return [xla.device_put(x[i], d) for (i, d) in zip(indices, devices)]
+  return list(it.chain.from_iterable(xla.device_put(x[i], d) for (i, d) in zip(indices, devices)))
Done, see below.
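This diff changes device_put so that each call yields a tuple of buffers, and then flattens the per-device results into one list. A small sketch of that pattern follows. fake_device_put is a hypothetical stand-in that pretends every object is backed by two buffers. It is not JAX's xla.device_put:

```python
import itertools as it


def fake_device_put(x, device):
    """Hypothetical stand-in: returns a tuple of two placeholder buffers."""
    return tuple(f"{x}/buf{k}@{device}" for k in range(2))


def shard_array(x, devices, indices):
    # Each call returns a tuple; chain.from_iterable concatenates those
    # tuples into one flat list instead of producing a list of tuples.
    return list(it.chain.from_iterable(
        fake_device_put(x[i], d) for i, d in zip(indices, devices)))


print(shard_array([10, 20], ["dev0", "dev1"], [0, 1]))
```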
comment created time in 9 days
pull request comment google/jax
Allow JAX objects to be represented by multiple buffers
Running benchmarks/pmap_benchmark.py.
Results on master:
Benchmark results for pmap_shard_sharded_device_array_nargs=10_nshards=1
mean=0.018155 std=0.000451 %std=2.484562 total=10.003359
#iters=551 #warmup=1
Benchmark results for pmap_shard_sharded_device_array_nargs=100_nshards=1
mean=0.085240 std=0.001963 %std=2.303262 total=10.058309
#iters=118 #warmup=1
Benchmark results for pmap_shard_sharded_device_array_nargs=101_nshards=1
mean=0.088433 std=0.003865 %std=4.370711 total=10.081317
#iters=114 #warmup=1
Benchmark results for pmap_shard_sharded_device_array_nargs=500_nshards=1
mean=0.389694 std=0.013617 %std=3.494267 total=10.132035
#iters=26 #warmup=1
Benchmark results for pmap_shard_sharded_device_array_nargs=1000_nshards=1
mean=0.792406 std=0.023802 %std=3.003702 total=10.301281
#iters=13 #warmup=1
Benchmark results for pmap_shard_sharded_device_array_nargs=5000_nshards=1
mean=4.381965 std=0.090272 %std=2.060070 total=13.145895
#iters=3 #warmup=1
Benchmark summary for pmap_shard_sharded_device_array
nargs  nshards       mean     %std  relative
-----  -------  ---------  -------  --------
   10        1  0.0181549  2.48456         1
  100        1  0.0852399  2.30326   4.69514
  101        1  0.0884326  4.37071     4.871
  500        1   0.389694  3.49427   21.4649
 1000        1   0.792406   3.0037   43.6469
 5000        1    4.38197  2.06007   241.365
Benchmark results for pmap_shard_device_array_nargs=10_nshards=1
mean=0.015208 std=0.001568 %std=10.307562 total=10.006748
#iters=658 #warmup=1
Benchmark results for pmap_shard_device_array_nargs=100_nshards=1
mean=0.130824 std=0.003322 %std=2.539380 total=10.073420
#iters=77 #warmup=1
Benchmark results for pmap_shard_device_array_nargs=500_nshards=1
mean=0.725450 std=0.077750 %std=10.717554 total=10.156296
#iters=14 #warmup=1
Benchmark summary for pmap_shard_device_array
nargs  nshards       mean     %std  relative
-----  -------  ---------  -------  --------
   10        1  0.0152078  10.3076         1
  100        1   0.130824  2.53938   8.60239
  500        1    0.72545  10.7176   47.7024
Benchmark results for pmap_shard_outputs_nouts=10_nshards=1
mean=0.024094 std=0.003001 %std=12.455350 total=10.023266
#iters=416 #warmup=1
Benchmark results for pmap_shard_outputs_nouts=100_nshards=1
mean=0.093125 std=0.004977 %std=5.344291 total=10.057464
#iters=108 #warmup=1
Benchmark results for pmap_shard_outputs_nouts=500_nshards=1
mean=0.436330 std=0.085204 %std=19.527318 total=10.035598
#iters=23 #warmup=1
Benchmark results for pmap_shard_outputs_nouts=1000_nshards=1
mean=0.828440 std=0.016951 %std=2.046166 total=10.769719
#iters=13 #warmup=1
Benchmark results for pmap_shard_outputs_nouts=5000_nshards=1
mean=4.229745 std=0.007748 %std=0.183187 total=12.689235
#iters=3 #warmup=1
Benchmark summary for pmap_shard_outputs
nouts  nshards       mean      %std  relative
-----  -------  ---------  --------  --------
   10        1  0.0240944   12.4553         1
  100        1  0.0931247   5.34429   3.86499
  500        1    0.43633   19.5273   18.1092
 1000        1    0.82844   2.04617   34.3831
 5000        1    4.22974  0.183187   175.549
Benchmark results for ShardedDeviceArray_indexing_indices_fn=integer_indices
mean=8.168176 std=0.702482 %std=8.600236 total=16.336353
#iters=2 #warmup=1
Benchmark results for ShardedDeviceArray_indexing_indices_fn=integer_2D_indices
mean=11.738572 std=0.000000 %std=0.000000 total=11.738572
#iters=1 #warmup=1
Benchmark summary for ShardedDeviceArray_indexing
indices_fn             mean     %std  relative
------------------  -------  -------  --------
integer_indices     8.16818  8.60024         1
integer_2D_indices  11.7386        0   1.43711
Results on this branch:
Benchmark results for pmap_shard_sharded_device_array_nargs=10_nshards=1
mean=0.019015 std=0.001062 %std=5.584433 total=10.001693
#iters=526 #warmup=1
Benchmark results for pmap_shard_sharded_device_array_nargs=100_nshards=1
mean=0.094039 std=0.007497 %std=7.972502 total=10.062211
#iters=107 #warmup=1
Benchmark results for pmap_shard_sharded_device_array_nargs=101_nshards=1
mean=0.090250 std=0.003046 %std=3.375537 total=10.017800
#iters=111 #warmup=1
Benchmark results for pmap_shard_sharded_device_array_nargs=500_nshards=1
mean=0.407412 std=0.035182 %std=8.635389 total=10.185308
#iters=25 #warmup=1
Benchmark results for pmap_shard_sharded_device_array_nargs=1000_nshards=1
mean=0.764858 std=0.011401 %std=1.490611 total=10.708011
#iters=14 #warmup=1
Benchmark results for pmap_shard_sharded_device_array_nargs=5000_nshards=1
mean=4.366896 std=0.048288 %std=1.105768 total=13.100688
#iters=3 #warmup=1
Benchmark summary for pmap_shard_sharded_device_array
nargs  nshards       mean     %std  relative
-----  -------  ---------  -------  --------
   10        1  0.0190146  5.58443         1
  100        1  0.0940394   7.9725   4.94563
  101        1  0.0902505  3.37554   4.74637
  500        1   0.407412  8.63539   21.4263
 1000        1   0.764858  1.49061   40.2247
 5000        1     4.3669  1.10577    229.66
Benchmark results for pmap_shard_device_array_nargs=10_nshards=1
mean=0.016362 std=0.000771 %std=4.711918 total=10.013431
#iters=612 #warmup=1
Benchmark results for pmap_shard_device_array_nargs=100_nshards=1
mean=0.159848 std=0.019341 %std=12.099536 total=10.070426
#iters=63 #warmup=1
Benchmark results for pmap_shard_device_array_nargs=500_nshards=1
mean=0.756038 std=0.040831 %std=5.400649 total=10.584529
#iters=14 #warmup=1
Benchmark summary for pmap_shard_device_array
nargs  nshards       mean     %std  relative
-----  -------  ---------  -------  --------
   10        1  0.0163618  4.71192         1
  100        1   0.159848  12.0995   9.76958
  500        1   0.756038  5.40065   46.2074
Benchmark results for pmap_shard_outputs_nouts=10_nshards=1
mean=0.024195 std=0.001320 %std=5.455058 total=10.016746
#iters=414 #warmup=1
Benchmark results for pmap_shard_outputs_nouts=100_nshards=1
mean=0.094919 std=0.010104 %std=10.644888 total=10.061403
#iters=106 #warmup=1
Benchmark results for pmap_shard_outputs_nouts=500_nshards=1
mean=0.410957 std=0.023368 %std=5.686321 total=10.273933
#iters=25 #warmup=1
Benchmark results for pmap_shard_outputs_nouts=1000_nshards=1
mean=0.885113 std=0.035264 %std=3.984146 total=10.621352
#iters=12 #warmup=1
Benchmark results for pmap_shard_outputs_nouts=5000_nshards=1
mean=4.575136 std=0.038652 %std=0.844824 total=13.725407
#iters=3 #warmup=1
Benchmark summary for pmap_shard_outputs
nouts  nshards       mean      %std  relative
-----  -------  ---------  --------  --------
   10        1   0.024195   5.45506         1
  100        1  0.0949189   10.6449   3.92307
  500        1   0.410957   5.68632   16.9852
 1000        1   0.885113   3.98415   36.5824
 5000        1    4.57514  0.844824   189.094
Benchmark results for ShardedDeviceArray_indexing_indices_fn=integer_indices
mean=7.891942 std=0.006513 %std=0.082526 total=15.783884
#iters=2 #warmup=1
Benchmark results for ShardedDeviceArray_indexing_indices_fn=integer_2D_indices
mean=12.761788 std=0.000000 %std=0.000000 total=12.761788
#iters=1 #warmup=1
Benchmark summary for ShardedDeviceArray_indexing
indices_fn             mean       %std  relative
------------------  -------  ---------  --------
integer_indices     7.89194  0.0825257         1
integer_2D_indices  12.7618          0   1.61707
It looks like for functions with many buffers (e.g. pmap_shard_outputs) there is a cost of roughly 7-8% associated with the buffer tupling/flattening introduced by this code (0.83 -> 0.89 seconds at nouts=1000, and 4.23 -> 4.58 seconds at nouts=5000).
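To check the size of this cost, the per-size ratios can be computed directly from the pmap_shard_outputs means reported in the two benchmark summaries. The values below are copied from those summaries and nothing else is assumed:

```python
# Mean times in seconds for pmap_shard_outputs, copied from the two summaries.
master = {10: 0.0240944, 100: 0.0931247, 500: 0.43633, 1000: 0.82844, 5000: 4.22974}
branch = {10: 0.024195, 100: 0.0949189, 500: 0.410957, 1000: 0.885113, 5000: 4.57514}


def relative_change(before, after):
    """Return after / before for every key present in both mappings."""
    return {k: after[k] / before[k] for k in before if k in after}


ratios = relative_change(master, branch)
for nouts, ratio in sorted(ratios.items()):
    print(f"nouts={nouts:>4}: {100 * (ratio - 1):+.1f}%")
```

On these numbers the branch is about 7-8% slower at nouts=1000 and 5000 but takes about 6% less time at nouts=500, where master's %std is 19.5. Single runs at the smaller sizes are therefore probably too noisy to read much into.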
comment created time in 9 days
issue comment scipy/scipy
scipy.sparse.csgraph.connected_components does not work for directed graphs
Note that the default is weakly connected components. If you want strongly connected components, you can set the connection='strong' keyword.
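Here is a minimal comparison of the two modes. It assumes a directed graph with edges 0->1, 0->2, 1->2 and 3->4, matching the edges in the printed output further down this thread:

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

# Assumed directed graph: edges 0->1, 0->2, 1->2 and 3->4.
rows = np.array([0, 0, 1, 3])
cols = np.array([1, 2, 2, 4])
graph = csr_matrix((np.ones(len(rows)), (rows, cols)), shape=(5, 5))

# Weak connectivity (the default): edge direction is ignored.
n_weak, labels_weak = connected_components(graph, directed=True, connection="weak")

# Strong connectivity: nodes must be mutually reachable along directed edges.
n_strong, labels_strong = connected_components(graph, directed=True, connection="strong")

print(n_weak, labels_weak)
print(n_strong, labels_strong)
```

With no directed cycles in this graph, every node forms its own strongly connected component. The weak grouping is the one reported in the output below.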
comment created time in 9 days
issue comment scipy/scipy
scipy.sparse.csgraph.connected_components does not work for directed graphs
When I run your code using scipy 1.5.2, I get the following output:
(0, 1) 1
(0, 2) 1
(1, 2) 1
(3, 4) 1
[0 0 0 1 1]
This appears to be correct (nodes 0, 1, 2 are in a group with mutual connections, and nodes 3, 4 are in a second group with mutual connections).
Is this the output that you're seeing?
comment created time in 9 days
PR opened google/jax
Broken by #4195
pr created time in 9 days