Jake Vanderplas (jakevdp) | Google | Oakland, CA | http://www.vanderplas.com | Python, Astronomy, Data Science

google/jax 9746

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

altair-viz/altair 5966

Declarative statistical visualization library for Python

altair-viz/pdvega 336

Interactive plotting for Pandas using Vega-Lite

altair-viz/altair-tutorial 257

Notebooks for the Altair tutorial

davidwhogg/MagicCube 195

don't ask

jakevdp/2013_fall_ASTR599 128

Content for my Astronomy 599 Course: Intro to scientific computing in Python

altair-viz/altair_widgets 102

Interactive data exploration with Altair

altair-viz/vega_datasets 94

A Python package for online & offline access to vega datasets

altair-viz/altair_pandas 71

Altair backend for pandas plotting

altair-viz/altair-transform 48

Evaluation of Vega-Lite transforms in Python

delete branch jakevdp/jax

delete branch : axis-concrete

delete time in 16 hours


PR opened google/jax

Use core.concrete_or_error() to improve errors in reductions better_errors

Old error:

>>> import jax.numpy as jnp
>>> from jax import jit                                                                             
>>> jit(jnp.sum)(jnp.arange(4), 0) 
[...]
TypeError: Unexpected type of axis argument: <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

New error:

>>> jit(jnp.sum)(jnp.arange(4), 0) 
[...]
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

axis argument to jnp.sum().

The error occurred while tracing the function sum at /Users/vanderplas/github/google/jax/jax/numpy/lax_numpy.py:1737.

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
+3 -0

0 comment

2 changed files

pr created time in 17 hours

create branch jakevdp/jax

branch : axis-concrete

created branch time in 17 hours


delete branch jakevdp/jax

delete branch : unop-check-array

delete time in 20 hours

push event jakevdp/jax

Jake VanderPlas

commit sha 80d5f9ce6c4994a0fe32180f7763ee512550231b

jax.numpy: improved errors for invalid inputs to unary ops

view details

Alex Botev

commit sha 6aa35a2a7351e3c2bbf2fa1f58628801e95c8791

Adding an option to return the output tree in make_jaxpr

view details

Jake VanderPlas

commit sha ab273acb2173e8f1a5343f26432cc27cb8a1a991

Call _check_arraylike in jax.numpy to improve error messages This PR only adds the call to places where non-array inputs currently lead to errors. There remain a number of other functions where adding this check would lead to potentially breaking changes; these are deliberately left out of this PR.

view details

jax authors

commit sha 80fa22cf90459cdee8a6f1fbee7a0f7fedcfcb0c

Merge pull request #4381 from jakevdp:check-arraylike PiperOrigin-RevId: 333170269

view details

Peter Hawkins

commit sha 7e1b826ef56ff98307fda0d7614d27d6092dd3aa

Enable fast TPU LU decomposition for complex types.

view details

George Necula

commit sha 625be69333ad46c7e0b4cb765060054faf5f2927

[host_callback] Update the documentation The module-level documentation was out of date.

view details

Adam Paszke

commit sha 2b7580c2d206e692a83fbc37eaeeb846e0a4462f

Consider lists as groups of axis names too

view details

Benjamin Chetioui

commit sha 3360bee9e96f2fbc8bcdb4746e58bfba1f771476

[jax2tf] Adjust tolerance in flaky float32 eigh test.

view details

Peter Hawkins

commit sha 89b989654e45e742973d1475f02a1119334eb632

Enable tests for complex QR decomposition on TPU.

view details

jax authors

commit sha f3d6132042ac1bfb1e541f8692569efc9169bf08

Merge pull request #4393 from hawkinsp:qr PiperOrigin-RevId: 333279018

view details

jax authors

commit sha c6cd2f91df0f37dd8987b6bed8847597ace9d121

Merge pull request #4392 from SIben:fix_eigh_flakiness_cpu PiperOrigin-RevId: 333294460

view details

jax authors

commit sha c9b3df3a64b8a6398d6e3a705e534c4494502fea

Merge pull request #4382 from hawkinsp:lu PiperOrigin-RevId: 333296510

view details

jax authors

commit sha c875ab3ec9b2ab4794e4068e64172c9869e1b618

Merge pull request #4391 from apaszke:axis_index_handle_list PiperOrigin-RevId: 333304709

view details

Adam Paszke

commit sha 0d5f15f5c0f862d3d5cbf8f6853409cdc1091de1

Fix the abstract eval and translation rule for all_to_all The previous rules assumed that `split_axis == concat_axis` (i.e. that the used collective is equivalent to `pswapaxes`). Since we expose this as part of our API, we should probably make sure that we handle other cases too. Fixes #1332.

view details

Matthew Johnson

commit sha c42d736e347d059290ab20c013635d62a1ee6c45

remove limit on size of random arrays

view details

Matthew Johnson

commit sha 96f5a3c4026c929664e75e07262bba7a4c8d2044

fix test for non-omnistaging

view details

Matthew Johnson

commit sha 71f5f9972cd305d4060637115a7ff316087d229e

skip checks in big randomness test

view details

Matthew Johnson

commit sha d607164d35e05077d60969a8e6b145f10aca95e0

make_jaxpr return_shape use ShapeDtypeStruct, test

view details

jax authors

commit sha c7e0ef4075ed5166cd2a289057b75b682002a1e9

Merge pull request #4398 from google:lift-randomness-limit PiperOrigin-RevId: 333433816

view details

Matthew Johnson

commit sha ebf7c1b6127d6df3a64a4d969297ca70e126a341

add jax logo file

view details

push time in 20 hours

issue comment google/jax

Feature request: support more flexible handling of seeds and seed-like objects like `numpy.random.SeedSequence`

Hi – this is an interesting idea, and would definitely make interoperability between np.random and jax.random easier.

One concern I have is that numpy's random interface is stateful, whereas JAX's is not. This can make things behave in unexpected ways when the two are combined with JAX's jit. For example:

from jax import jit
import jax.numpy as jnp
import numpy as np

def _numpy_generator_to_prngkey(seed: np.random.Generator) -> jnp.ndarray:
    return jnp.array(
        seed.integers(0, np.iinfo(np.uint32).max, (2,), dtype=np.uint32,
                      endpoint=True))
    
seed = np.random.default_rng(seed=42)

print(_numpy_generator_to_prngkey(seed) == _numpy_generator_to_prngkey(seed))
# False

_numpy_generator_to_prngkey = jit(_numpy_generator_to_prngkey, static_argnums=[0])
print(_numpy_generator_to_prngkey(seed) == _numpy_generator_to_prngkey(seed))
# True

In the jitted version of the function, the seed is only generated once at tracing time. I suspect that this sort of thing might lead to subtle bugs that would be very difficult to identify.

Do you have thoughts on that?

gehring

comment created time in 21 hours

issue comment altair-viz/altair

Plots not showing up when served with voila without an ipywidget

I'm not sure about the details of the voila frontend. I'd suggest asking on the Voila issue tracker.

afonit

comment created time in a day

issue comment altair-viz/altair

Question: Can you plot a Gaussian distribution?

Sure, for example:

import altair as alt
import numpy as np
import pandas as pd
from scipy.stats import norm

x = np.linspace(-5, 5, 1000)
df = pd.DataFrame({'x': x, 'y': norm.pdf(x)})
alt.Chart(df).mark_line().encode(
    x='x',
    y='y'
)


davidpickup

comment created time in 2 days

push event jakevdp/jax

Jake VanderPlas

commit sha ab273acb2173e8f1a5343f26432cc27cb8a1a991

Call _check_arraylike in jax.numpy to improve error messages This PR only adds the call to places where non-array inputs currently lead to errors. There remain a number of other functions where adding this check would lead to potentially breaking changes; these are deliberately left out of this PR.

view details

jax authors

commit sha 80fa22cf90459cdee8a6f1fbee7a0f7fedcfcb0c

Merge pull request #4381 from jakevdp:check-arraylike PiperOrigin-RevId: 333170269

view details

Jake VanderPlas

commit sha 2cef06d99c8d13b848420cace1bfa8a5e553f9d3

Add test coverage for jnp.cov aweights & fweights

view details

push time in 4 days

delete branch jakevdp/jax

delete branch : check-arraylike

delete time in 4 days

PR opened google/jax

Add test coverage for jnp.cov aweights & fweights
+38 -25

0 comment

2 changed files

pr created time in 4 days

create branch jakevdp/jax

branch : cov-weights

created branch time in 4 days

issue comment google/jax

Jax numpy does not throw IndexError when array index out of bounds

Hi - thanks for the report!

This is a known feature in JAX that you can read about here: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Out-of-Bounds-Indexing
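
For illustration, a minimal sketch (my example; the behavior is described in the doc linked above): rather than raising, an out-of-bounds read is clamped to the valid index range.

import jax.numpy as jnp

x = jnp.arange(4)

# NumPy raises IndexError here; JAX instead clamps the index
# to the array bounds, returning the last element.
print(x[10])  # DeviceArray(3, dtype=int32)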

jawolf314

comment created time in 4 days

Pull request review comment google/jax

Allow JAX objects to be represented by multiple buffers

+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest, parameterized
+
+import numpy as np
+
+from jax import test_util as jtu
+import jax.numpy as jnp
+from jax import core, jit, lax, lazy, make_jaxpr
+from jax.interpreters import xla
+from jax.lib import xla_client
+xops = xla_client.ops
+
+from jax.config import config
+config.parse_flags_with_absl()
+
+# Define a sparse array data structure. The important feature here is that
+# it is a jaxpr object that is backed by two device buffers.
+class SparseArray:
+  """Simple sparse COO array data structure."""
+  def __init__(self, aval, data, indices):
+    self.aval = aval
+    self.shape = aval.shape
+    self.data = data
+    self.indices = indices
+
+  @property
+  def index_dtype(self):
+    return self.indices.dtype
+
+  @property
+  def dtype(self):
+    return self.data.dtype
+
+  @property
+  def nnz(self):
+    return self.data.shape[0]
+
+  def __repr__(self):
+    return repr(list((tuple(ind), d) for ind, d in zip(self.indices, self.data)))
+
+
+class AbstractSparseArray(core.ShapedArray):
+  __slots__ = ['index_dtype', 'nnz', 'data_aval', 'indices_aval']
+  _num_buffers = 2
+
+  def __init__(self, shape, dtype, index_dtype, nnz):
+    super(AbstractSparseArray, self).__init__(shape, dtype)
+    self.index_dtype = index_dtype
+    self.nnz = nnz
+    self.data_aval = core.ShapedArray((nnz,), dtype)
+    self.indices_aval = core.ShapedArray((nnz, len(shape)), index_dtype)
+
+  @core.aval_property
+  def data(self):
+    return sp_data_p.bind(self)
+
+  @core.aval_property
+  def indices(self):
+    return sp_indices_p.bind(self)
+
+def abstract_sparse_array(arr):
+  return AbstractSparseArray(arr.shape, arr.dtype, arr.index_dtype, arr.nnz)
+
+def sparse_array_result_handler(device, aval):
+  def build_sparse_array(data_buf, indices_buf):
+    data = xla.DeviceArray(aval.data_aval, device, lazy.array(aval.data_aval.shape), data_buf)
+    indices = xla.DeviceArray(aval.indices_aval, device, lazy.array(aval.indices_aval.shape), indices_buf)
+    return SparseArray(aval, data, indices)
+  return build_sparse_array
+
+def sparse_array_shape_handler(a):
+  return (
+    xla.xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape),
+    xla.xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape),
+  )
+
+def sparse_array_device_put_handler(a, device):
+  return (
+    xla.xb.get_device_backend(device).buffer_from_pyval(a.data, device),
+    xla.xb.get_device_backend(device).buffer_from_pyval(a.indices, device)
+  )
+
+core.pytype_aval_mappings[SparseArray] = abstract_sparse_array

Added the TODO for now.

jakevdp

comment created time in 4 days


Pull request review comment google/jax

Allow JAX objects to be represented by multiple buffers

 def __init__(self,
     if indices is None:
       indices = spec_to_indices(aval.shape, sharding_spec)
     self.aval = aval
+    assert all(isinstance(b, xc.Buffer) for b in device_buffers), f"Expected flattened list of device buffers; got {device_buffers}"

Done.

jakevdp

comment created time in 4 days


Pull request review comment google/jax

Allow JAX objects to be represented by multiple buffers

 class DeviceArray:
   _HAS_DYNAMIC_ATTRIBUTES = True

   def __init__(self, aval: core.ShapedArray, device: Optional[Device],
-               lazy_expr: lazy.LazyExpr, device_buffer: PyLocalBuffer):
+               lazy_expr: lazy.LazyExpr,
+               device_buffer: PyLocalBuffer):
+    if isinstance(device_buffer, Sequence):

Done

jakevdp

comment created time in 4 days


Pull request review comment google/jax

Allow JAX objects to be represented by multiple buffers

 def apply_primitive(prim, *args, **params):
   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
   return compiled_fun(*args)

+def _partition_outputs(avals, outs):
+  nouts = [aval._num_buffers for aval in avals]
+  assert sum(nouts) == len(outs), f"Internal error: sum(nouts)={sum(nouts)} should equal len(outs)={len(outs)}."

Done.

jakevdp

comment created time in 4 days


push event jakevdp/jax

Jake VanderPlas

commit sha 699e46ac5a574272d34b79bb46c1df601bd0a461

raise TypeError if device_buffer is a tuple

view details

push time in 4 days

push event jakevdp/jax

Adam Paszke

commit sha e0d1b375fa593c5ef777e7e650a3a3e75996cedc

Delete dead axis_index code The primitive was moved to `lax_parallel.py` some time ago, so the one in `core` should no longer be used. This is probably a result of a botched rebase.

view details

Adam Paszke

commit sha 332a9ba1ad1a08f4deb09b82ff0d4c1556b431eb

Fix axis_index inside nested pmaps The previous translation rule has assumed that `axis_index` is always taken over the outermost axis in the `axis_env`, and was always producing the same output, no matter which axis has been specified. This fixes the translation rule to start taking the `axis_name` into account. Additionally, this adds support for querying the index along multiple axes, which will be useful for `gmap`.

view details

Adam Paszke

commit sha 8ac19c722211ffe4c7fb0bdbeeb236818d000291

Fix a faulty soft_pmap rule for axis_index The rule didn't specify the precision for the `np.arange` constant, which caused an accidental dtype promotion in X64 mode. Previously the error had luckily been hidden behind a coercion that followed `axis_index` in that test, but the new implementation has surfaced it.

view details

jax authors

commit sha 99ffcc44d5d57c6bbb2ea94f92ed703d9492d6a1

Merge pull request #4378 from apaszke:axis_index_nd PiperOrigin-RevId: 333106474

view details

jax authors

commit sha 533fe28b47a39d8df923cb221ee92ffc31003816

Merge pull request #4377 from apaszke:axis_index PiperOrigin-RevId: 333107273

view details

Jake VanderPlas

commit sha 05adb3f0236c7cba43dca6b7c411ee71217b86d0

Allow jax objects to be represented by multiple buffers

view details

Jake VanderPlas

commit sha 592bfd960ff2683a7dd9709b4abe1095b5448f39

address some review comments

view details

Jake VanderPlas

commit sha e3cfaf6b32e62730a85c31e6dc4e3f83d0c79ea5

Remove check for sequences of buffers in DeviceArray

view details

Jake VanderPlas

commit sha ade0a3b9026eae55a173b5a3b4aded6692ea016e

check for buffer type in DeviceArray constructor

view details

push time in 4 days

issue commentgoogle/jax

jnp.median scaling vs np.median

Hi - thanks for the report. Because XLA doesn't provide an efficient routine for computing the median, JAX currently does this via a full array sort, which accounts for the poor scaling.
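
For the curious, a minimal sketch (my own, assuming a 1D array) of what a sort-based median looks like:

import jax.numpy as jnp

def median_via_sort(x):
    # Sort the full array, then average the middle element(s);
    # for odd lengths the two indices coincide.
    s = jnp.sort(x)
    n = s.shape[0]
    return (s[(n - 1) // 2] + s[n // 2]) / 2

print(median_via_sort(jnp.array([3., 1., 4., 1., 5.])))  # 3.0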

SamDuffield

comment created time in 4 days

issue comment vega/vega-lite

Tooltip dates are off by one

Here is my first suggestion working properly, showing the correct tooltip: Open the Chart in the Vega Editor

matthewcornell

comment created time in 4 days

issue comment vega/vega-lite

Tooltip dates are off by one

Two things:

  • If you're going the route of using a full ISO 8601 date string, I wouldn't add a timezone. I'd just use the full string (like 2012-01-02T00:00:00 instead of 2012-01-02). Javascript will parse this as local time, and vega will display it as local time unless you explicitly specify UTC.

  • If you're going the route of using utc time units, you'll need to use them everywhere, including the tooltip.
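
To make the second point concrete, a small Altair sketch (my own, with hypothetical data) that applies the utc time unit consistently to both the axis and the tooltip:

import altair as alt
import pandas as pd

df = pd.DataFrame({'date': pd.to_datetime(['2012-01-01', '2012-01-02']),
                   'value': [1, 2]})

# The utc variant of the time unit must appear everywhere the
# field is used, including the tooltip.
alt.Chart(df).mark_line().encode(
    x=alt.X('date:T', timeUnit='utcyearmonthdate'),
    y='value:Q',
    tooltip=alt.Tooltip('date:T', timeUnit='utcyearmonthdate')
)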

matthewcornell

comment created time in 4 days

PR opened google/jax

Call _check_arraylike in jax.numpy to improve error messages

This PR only adds the call to places where non-array inputs currently lead to errors. There remain a number of other functions where adding this check would lead to potentially breaking changes; these are deliberately left out of this PR.

+92 -10

0 comment

1 changed file

pr created time in 4 days

create branch jakevdp/jax

branch : check-arraylike

created branch time in 4 days

push event jakevdp/jax

Jake VanderPlas

commit sha 611c8001930d12719ca12f94d7430b2a98d9dcec

Remove check for sequences of buffers in DeviceArray

view details

push time in 4 days

push event jakevdp/jax

Qiao Zhang

commit sha bbe3a6a9a25e0669ccba08c1d3c0297c35649136

Improve segment_sum stability by k-way summation.

view details

Qiao Zhang

commit sha 49a01d36c8b71902ed4136219c11c38d5a35b77c

Use jvp(expm) to compute expm_frechet.

view details

Jake VanderPlas

commit sha 05cc7e7352f2e2b350cae8aae4f18c509701a624

device_put_sharded: remove incorrect type annotation

view details

Adam Paszke

commit sha c4f98eb8fa2947e0db4a65bb67daa3243e7d103d

Add back the batching rule for ppermute Just make sure it's correct this time and add a test.

view details

Adam Paszke

commit sha 2081e5acee2039930c250ccb91483ed6d6cfe580

Test pmap/vmap interactions of all reduction collectives

view details

Qiao Zhang

commit sha d7564c506e752eaf495a744950822ea8e52e6cee

Add test matrices that exercise more code path.

view details

Qiao Zhang

commit sha 614acce43c1900afb2f2c3fecf89139866aeea88

Change segment_sum to use no bucketing by default.

view details

Qiao Zhang

commit sha e76ebea9cb6cf6648467f24253519e52430d3827

Merge branch 'master' into rm_fretchet

view details

Jake VanderPlas

commit sha ce1ce9cb276045324e9122830aa3890fef16af0f

Implement jnp.array_equiv

view details

Qiao Zhang

commit sha 83f14012ea9fc016a79a69287a91ba206993826d

Bump tol of float32 for complex64 inner product.

view details

Qiao Zhang

commit sha 35d231990c41ca5d1913f1817e75c45a1f934c02

Add ceil_of_ratio util and bucket_size TODO.

view details

Jake VanderPlas

commit sha 18054e05a8f64bbaa9351ebd27564684888c1e96

call _check_arraylike in jnp.diff

view details

jax authors

commit sha ada6f30f59d44a574af111b540b4771f9708b0ef

Merge pull request #4347 from jakevdp:array-equiv PiperOrigin-RevId: 332946445

view details

johnpjf

commit sha be50847ceeafc3d08de31bae21e9404b6283f05f

Make scale_and_translate take spatial dimensions

view details

johnpjf

commit sha ae910cdd311800fd1f2057b62d44028575820f60

Updating image_test

view details

jax authors

commit sha 2cb795e0d9dd513d613d059fbf5d6b4a7a9b9eaf

Merge pull request #4366 from jakevdp:diff-empty PiperOrigin-RevId: 332960638

view details

Jake VanderPlas

commit sha 2cf8d49f5b305bb6814136c81f35cdbfde187176

jnp.moveaxis: fix bug when axes are integer dtype

view details

jax authors

commit sha 04fa89a12c0c109f3b86d0e578ff0a647472204a

Merge pull request #4299 from zhangqiaorjc:segsum PiperOrigin-RevId: 332964956

view details

Qiao Zhang

commit sha b6e9da36eb7f265ee43fd61ac87d9240c8330452

Remove unused var. Bump tol.

view details

Matthew Johnson

commit sha 2abb37c286c92d70b8fb0104aacea60cda7675ce

move a _device_put_raw under broadcast impl Before this change, we had to interact with the device to construct an array of zeros, even if we were staging everything out (e.g. with jax.xla_computation and omnistaging).

view details

push time in 4 days

pull request comment google/jax

jax.numpy: improved errors for invalid inputs to unary ops

Is it possible to (usefully) shorten the stack trace even further by introducing something of this sort at the level of, say, _wraps?

I looked into this... it seems like it's possible to slightly shorten the traceback by adding the logic to _wraps. It would look something like this:

>>> jnp.negative([1, 2, 3])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-2-f2d2f6fccc05> in <module>
----> 1 jnp.negative([1, 2, 3])

~/github/google/jax/jax/numpy/_util.py in _op(*args, **kwargs)
     61       from jax.numpy.lax_numpy import _check_arraylike
     62       def _op(*args, **kwargs):
---> 63         _check_arraylike(fun.__name__, *(args[i] for i in check_arraylike))
     64       op = _op
     65     if not hasattr(fun, '__doc__') or fun.__doc__ is None:

~/github/google/jax/jax/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    295                     if not _arraylike(arg))
    296     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 297     raise TypeError(msg.format(fun_name, type(arg), pos))
    298 
    299 

TypeError: negative requires ndarray or scalar arguments, got <class 'list'> at position 0.

The problem is that this is hard to do in a way that's not brittle: you have to somehow tell _wraps which arguments must be arraylike, and those arguments can be passed either positionally or by keyword, so it would involve parsing the signature of the wrapped function and combining it with extra information passed to _wraps in order to validate the right parts of *args and **kwargs.
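
For reference, a rough sketch of what that would entail (hypothetical code, not what's in this PR): bind the call against the wrapped function's signature and validate the named arguments.

import inspect
from functools import wraps
from jax.numpy.lax_numpy import _check_arraylike

def _wraps_with_check(check_arraylike_names):
    # Hypothetical: decorate an implementation so that the named
    # arguments are validated whether passed positionally or by keyword.
    def decorator(fun):
        sig = inspect.signature(fun)
        @wraps(fun)
        def op(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            for name in check_arraylike_names:
                if name in bound.arguments:
                    _check_arraylike(fun.__name__, bound.arguments[name])
            return fun(*args, **kwargs)
        return op
    return decorator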

Given that, I think I'm going to stick with doing the validation within the function's implementation.

jakevdp

comment created time in 4 days

Pull request review comment google/jax

Allow JAX objects to be represented by multiple buffers

 def __init__(self,
     if indices is None:
       indices = spec_to_indices(aval.shape, sharding_spec)
     self.aval = aval
+    assert all(isinstance(b, xc.Buffer) for b in device_buffers), f"Expected flattened list of device buffers; got {device_buffers}"

I added this assertion after running an internal presubmit and finding that it was nearly impossible to locate the root cause of problems. Downstream dependencies were creating ShardedDeviceArrays with device_buffers as a list of tuples, and this was causing an entirely unintelligible error when, later in the code, the device array was passed to some jitted function. This assertion was the only way I could identify the root cause of those errors.

What do you think about leaving it in place, but adding a TODO to delete it once all downstream dependencies are updated?

jakevdp

comment created time in 4 days


delete branch jakevdp/jax

delete branch : moveaxis-fix

delete time in 4 days

delete branch jakevdp/jax

delete branch : type-annotation

delete time in 4 days

delete branch jakevdp/jax

delete branch : diff-empty

delete time in 5 days

PR opened google/jax

jnp.moveaxis: fix bug when axes are integer dtype

Before:

In [1]: import jax.numpy as jnp                                                                                                          

In [2]: x = jnp.arange(6).reshape(2, 3) 

In [3]: jnp.moveaxis(x, jnp.int32(0), jnp.int32(1))                                                                                      
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-53fba72727a1> in <module>
----> 1 jnp.moveaxis(x, jnp.int32(0), jnp.int32(1))

~/github/google/jax/jax/numpy/lax_numpy.py in moveaxis(a, source, destination)
   1261   if isinstance(destination, int):
   1262     destination = (destination,)
-> 1263   source = tuple(_canonicalize_axis(i, ndim(a)) for i in source)
   1264   destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination)
   1265   if len(source) != len(destination):

~/github/google/jax/jax/interpreters/xla.py in __iter__(self)
   1060   def __iter__(self):
   1061     if self.ndim == 0:
-> 1062       raise TypeError("iteration over a 0-d array")  # same as numpy error
   1063     else:
   1064       return self._value.__iter__()

TypeError: iteration over a 0-d array

After:

In [1]: import jax.numpy as jnp                                                                                                          

In [2]: x = jnp.arange(6).reshape(2, 3)

In [3]: jnp.moveaxis(x, jnp.int32(0), jnp.int32(1))                                                                                      
Out[3]: 
DeviceArray([[0, 3],
             [1, 4],
             [2, 5]], dtype=int32)
+9 -4

0 comment

1 changed file

pr created time in 5 days

create branch jakevdp/jax

branch : moveaxis-fix

created branch time in 5 days

issue comment vega/vega-lite

Tooltip dates are off by one

But concretely, if you serialize your dates in full ISO 8601 form (e.g. '2012-01-02T00:00:00'), Javascript will treat them as local dates and your code will work as expected.

Alternatively, if you don't want to change the format of the input dates, you can use the UTC equivalent of your time units; i.e. change "timeUnit": "yearmonthdate" to "timeUnit": "utcyearmonthdate" and it should work as expected.

matthewcornell

comment created time in 5 days

issue comment vega/vega-lite

Tooltip dates are off by one

I spent a lot of time in Altair trying to make the interaction between Pandas dates and Javascript dates sensible from the user's perspective. If you're generating data in Python and figures in Vega-Lite and getting hung up on timezones in date conversions, Altair might be the remedy you're looking for.

matthewcornell

comment created time in 5 days

PR opened google/jax

jax.numpy: improved errors for invalid inputs to unary ops

Before:

In [1]: import jax.numpy as jnp                                                                                                          

In [2]: jnp.negative([1, 2, 3])                                                                                                          
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-2-f2d2f6fccc05> in <module>
----> 1 jnp.negative([1, 2, 3])

~/github/google/jax/jax/numpy/lax_numpy.py in <lambda>(x)
    347       return lax_fn(x)
    348   else:
--> 349     fn = lambda x: lax_fn(x)
    350   if lax_doc:
    351     doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()

~/github/google/jax/jax/lax/lax.py in neg(x)
     89 def neg(x: Array) -> Array:
     90   r"""Elementwise negation: :math:`-x`."""
---> 91   return neg_p.bind(x)
     92 
     93 def sign(x: Array) -> Array:

~/github/google/jax/jax/core.py in bind(self, *args, **params)
    264     top_trace = find_top_trace(args)
    265     tracers = map(top_trace.full_raise, args)
--> 266     out = top_trace.process_primitive(self, tracers, params)
    267     return map(full_lower, out) if self.multiple_results else full_lower(out)
    268 

~/github/google/jax/jax/core.py in process_primitive(self, primitive, tracers, params)
    572 
    573   def process_primitive(self, primitive, tracers, params):
--> 574     return primitive.impl(*tracers, **params)
    575 
    576   def process_call(self, primitive, f, tracers, params):

~/github/google/jax/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    222 def apply_primitive(prim, *args, **params):
    223   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 224   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
    225   return compiled_fun(*args)
    226 

~/github/google/jax/jax/interpreters/xla.py in arg_spec(x)
    214 
    215 def arg_spec(x):
--> 216   aval = abstractify(x)
    217   try:
    218     return aval, x._device

~/github/google/jax/jax/interpreters/xla.py in abstractify(x)
    163     aval_fn = pytype_aval_mappings.get(typ)
    164     if aval_fn: return aval_fn(x)
--> 165   raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
    166 
    167 def _make_abstract_python_scalar(typ, _):

TypeError: Argument '[1, 2, 3]' of type '<class 'list'>' is not a valid JAX type

After:

In [1]: import jax.numpy as jnp                                                                                                          

In [2]: jnp.negative([1, 2, 3])                                                                                                          
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-2-f2d2f6fccc05> in <module>
----> 1 jnp.negative([1, 2, 3])

~/github/google/jax/jax/numpy/lax_numpy.py in <lambda>(x)
    345     fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
    346   else:
--> 347     fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x))
    348   if lax_doc:
    349     doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()

~/github/google/jax/jax/numpy/lax_numpy.py in _promote_args(fun_name, *args)
    300 def _promote_args(fun_name, *args):
    301   """Convenience function to apply Numpy argument shape and dtype promotion."""
--> 302   _check_arraylike(fun_name, *args)
    303   return _promote_shapes(fun_name, *_promote_dtypes(*args))
    304 

~/github/google/jax/jax/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    295                     if not _arraylike(arg))
    296     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 297     raise TypeError(msg.format(fun_name, type(arg), pos))
    298 
    299 

TypeError: negative requires ndarray or scalar arguments, got <class 'list'> at position 0.
+2 -4

0 comment

1 changed file

pr created time in 5 days

create branch jakevdp/jax

branch : unop-check-array

created branch time in 5 days

issue comment google/jax

Error when creating test for new implementation

Yes, it's fine for the shape of input arrays to determine the shape of outputs within JIT-compiled computations. Problems arise when the values within input arrays determine the shape of outputs.

There's a bit of relevant discussion of this here: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT
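
A minimal sketch (my example) of the distinction:

import jax
import jax.numpy as jnp

@jax.jit
def shape_dependent(x):
    # Fine: the output shape depends only on the shape of x,
    # which is known at trace time.
    return jnp.zeros(x.shape)

@jax.jit
def value_dependent(x):
    # Fails under jit: the output shape depends on a value in x,
    # which is an abstract tracer at trace time.
    return jnp.zeros(int(x[0]))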

alexminnaar

comment created time in 5 days

delete branch jakevdp/jax

delete branch : array-equiv

delete time in 5 days

PR opened google/jax

Fix jnp.diff behavior on scalars

Fixes jnp.diff erroneously passing-through scalars:

In [1]: import numpy as np                                                                                                               

In [2]: import jax.numpy as jnp

In [3]: jnp.diff(0)                                                                                                                      
Out[3]: 0                                                                                                    

In [4]: np.diff(0)                                                                                                                       
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-3-85ad82741d0e> in <module>
----> 1 np.diff(0)

<__array_function__ internals> in diff(*args, **kwargs)

~/.local/share/virtualenvs/jax-LBbfM5ix/lib/python3.8/site-packages/numpy/lib/function_base.py in diff(a, n, axis, prepend, append)
   1244     nd = a.ndim
   1245     if nd == 0:
-> 1246         raise ValueError("diff requires input that is at least one dimensional")
   1247     axis = normalize_axis_index(axis, nd)
   1248 

ValueError: diff requires input that is at least one dimensional
+6 -6

0 comment

2 changed files

pr created time in 5 days

create branch jakevdp/jax

branch : diff-empty

created branch time in 5 days

issue comment google/jax

Error when creating test for new implementation

In JAX's JIT compilation model, the shapes of arrays cannot be dependent on data values, unless those values are static. In your function, the bins value determines the shape of the output, so it must be a static parameter in the computation.

There are various ways to ensure a parameter is static; in the test suite, this is generally done by leaving such parameters out of args_maker and passing them to the function in question within a closure. You can see an example here, in the test for the 1D histogram: https://github.com/google/jax/blob/e0af77fbca88ec882ef33c18926f8382f7178aed/tests/lax_numpy_test.py#L2286-L2300

Notice that bins and density are treated as static arguments, i.e. they're not parameters passed to the jitted function. If bins were passed to the function by args_maker in this test, it would result in the same error you're seeing in your test.
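
Outside the test suite, the equivalent (a minimal sketch of my own) is to mark bins as static via static_argnums:

from functools import partial
import jax
import jax.numpy as jnp

# bins determines the output shape, so it must be a static argument;
# jit will recompile for each distinct value of bins.
@partial(jax.jit, static_argnums=1)
def counts(x, bins):
    hist, _ = jnp.histogram(x, bins=bins)
    return hist

print(counts(jnp.arange(10.0), 5))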

Does that make sense?

alexminnaar

comment created time in 5 days

push event jakevdp/jax

Jean-Baptiste Lespiau

commit sha 9f53d2a8d883bc6eefc1f701cdc6639bdc774f84

Internal change PiperOrigin-RevId: 332920102

view details

jax authors

commit sha 55c6bdfe9c5d0631200cb76e4f56481b3656b03f

Clean-up todos related to the upgrade of jaxlib. PiperOrigin-RevId: 332932271

view details

Jake VanderPlas

commit sha 67bd83e4f04e47688f809235d57400ee7056776a

Allow jax objects to be represented by multiple buffers

view details

push time in 5 days

push event jakevdp/jax

Jake VanderPlas

commit sha 6b9dfb139679eff878ffcf1022d392cf5c0123ad

fix incorrect indentation

view details

John Flynn

commit sha 7fd7009c231bf9a86e009f6e86ccd4b374cb29dc

Expose scale_and_translate as a public function, fix a bug in implementation when translation is not 0. Change implementation to use native JAX everywhere allowing vmaping and gradients wrt scale and translation.

view details

jax authors

commit sha c33335b2d6246565d0f85a312793d56c08c6e8c5

Merge pull request #4229 from johnpjf:changelist/330579231 PiperOrigin-RevId: 332886460

view details

jax authors

commit sha 24fe07ebd14fd234fa5c26142dd4870b13eecd65

Merge pull request #4345 from jakevdp:indentation PiperOrigin-RevId: 332894503

view details

Jake VanderPlas

commit sha 47869afe136be71e69c0c179ca53fd2009384bf0

Allow jax objects to be represented by multiple buffers

view details

push time in 5 days

push event jakevdp/jax

Roy Frostig

commit sha 90cb99fc040b6cceaf81ff0bb0cea552c6a0baec

insert a stop_gradient in lax.create_token (as it forces a false dependency on the operand)

view details

Benjamin Chetioui

commit sha 9ddd252b56e01f7342aca0d2a98060eefc624cec

[jax2tf] Add primitive conversion for cholesky_p.

view details

Benjamin Chetioui

commit sha b3930f0b7e957e95078f77d011138f90ce95643d

Make the custom_assert a one-liner.

view details

Chase Roberts

commit sha 2bc92f5593c04dfd9a7bbb65ed77b70092038031

Fixed ppermute translation rule (#4349)

view details

Roy Frostig

commit sha 16e60360940f2baa4f3bc4f94b58680af8295198

trivial change to test source sync PiperOrigin-RevId: 332544315

view details

Matthew Johnson

commit sha f172fb74e17f250f769ca398d2eb563d88153ac2

plumb donate_argnums into jax.xla_computation

view details

jax authors

commit sha cfbaca0507b451305d05d5b2f482d2be605ed77a

Merge pull request #4330 from google:create-token-stop-grad PiperOrigin-RevId: 332564392

view details

Matthew Johnson

commit sha e88579f22b9185790638b3562ce0a97f994a2af3

fix typo

view details

Matthew Johnson

commit sha a6b3fa2c551ad567d260a99fbd6d9a1c60d5c649

add trivial test

view details

jax authors

commit sha ff5c14570251c4d0c9b469c05a47e62373207853

Merge pull request #4353 from google:xla-computation-donate-argnums PiperOrigin-RevId: 332573843

view details

jax authors

commit sha d9f9b50dfaa83199b9d569ebda6698e4dc88d020

Merge pull request #4343 from SIben:test_cholesky PiperOrigin-RevId: 332627994

view details

Matthew Johnson

commit sha 50dd9c5016ce0bceaad7256208e790b0b02a3142

don't initialize backend in xla_computation This should allow us to use donate_argnums *and* build HLO computations for backends not available at build time.

view details

Matthew Johnson

commit sha 1aab5ced9fbb1f97bb31f84b264b9b684c83108c

fix logic

view details

Matthew Johnson

commit sha 1092fa1d9b3de680aca5a883b781e6ba97c940be

fix logic, skip test

view details

jax authors

commit sha 85d070f0cdc665b5aad6f3341d09fe4a60d0368d

Merge pull request #4356 from google:xla-computation-dont-initialize-backend PiperOrigin-RevId: 332683955

view details

Benjamin Chetioui

commit sha 1cde76b130298c4d890d38242d19cff7b65ffaf9

[jax2tf] Implementation of the conversion of eig_p.

view details

Benjamin Chetioui

commit sha 015bc3c2cc7736720e362b15274d59e85e08359c

Replace manual conj + transpose with call to adjoint.

view details

Benjamin Chetioui

commit sha 1d94363df02a964df02f0c77795bb9939e76b1fb

Ignore eig conversion test on TPU/GPU, as it is unimplemented in JAX.

view details

jax authors

commit sha 695e8d88c3a827e0d73bf47022128ad7c1761420

Merge pull request #4338 from SIben:test_eig_p PiperOrigin-RevId: 332813984

view details

Jake VanderPlas

commit sha 2d0de43d88b3898a7c0fd0408b12ba8eb4836c31

Make xla.device_put() return tuples

view details

push time in 5 days

push event jakevdp/jax

Jake VanderPlas

commit sha ce1ce9cb276045324e9122830aa3890fef16af0f

Implement jnp.array_equiv

view details

push time in 5 days

issue comment vega/vega-lite

Tooltip dates are off by one

You can see the issue directly in your javascript console:

> new Date("2011-10-02")
Sat Oct 01 2011 17:00:00 GMT-0700 (Pacific Daylight Time)

You need to use a full ISO 8601 date string; otherwise the date will be parsed as UTC.

matthewcornell

comment created time in 5 days

push event jakevdp/jax

Jake VanderPlas

commit sha e4219391d5f2f1b9e93c3cb71fd01e29881fd87f

Implement jnp.array_equiv

view details

push time in 5 days

push event jakevdp/jax

Roy Frostig

commit sha 90cb99fc040b6cceaf81ff0bb0cea552c6a0baec

insert a stop_gradient in lax.create_token (as it forces a false dependency on the operand)

view details

Benjamin Chetioui

commit sha 9ddd252b56e01f7342aca0d2a98060eefc624cec

[jax2tf] Add primitive conversion for cholesky_p.

view details

Benjamin Chetioui

commit sha b3930f0b7e957e95078f77d011138f90ce95643d

Make the custom_assert a one-liner.

view details

Chase Roberts

commit sha 2bc92f5593c04dfd9a7bbb65ed77b70092038031

Fixed ppermute translation rule (#4349)

view details

Roy Frostig

commit sha 16e60360940f2baa4f3bc4f94b58680af8295198

trivial change to test source sync PiperOrigin-RevId: 332544315

view details

Matthew Johnson

commit sha f172fb74e17f250f769ca398d2eb563d88153ac2

plumb donate_argnums into jax.xla_computation

view details

jax authors

commit sha cfbaca0507b451305d05d5b2f482d2be605ed77a

Merge pull request #4330 from google:create-token-stop-grad PiperOrigin-RevId: 332564392

view details

Matthew Johnson

commit sha e88579f22b9185790638b3562ce0a97f994a2af3

fix typo

view details

Matthew Johnson

commit sha a6b3fa2c551ad567d260a99fbd6d9a1c60d5c649

add trivial test

view details

jax authors

commit sha ff5c14570251c4d0c9b469c05a47e62373207853

Merge pull request #4353 from google:xla-computation-donate-argnums PiperOrigin-RevId: 332573843

view details

jax authors

commit sha d9f9b50dfaa83199b9d569ebda6698e4dc88d020

Merge pull request #4343 from SIben:test_cholesky PiperOrigin-RevId: 332627994

view details

Matthew Johnson

commit sha 50dd9c5016ce0bceaad7256208e790b0b02a3142

don't initialize backend in xla_computation This should allow us to use donate_argnums *and* build HLO computations for backends not available at build time.

view details

Matthew Johnson

commit sha 1aab5ced9fbb1f97bb31f84b264b9b684c83108c

fix logic

view details

Matthew Johnson

commit sha 1092fa1d9b3de680aca5a883b781e6ba97c940be

fix logic, skip test

view details

jax authors

commit sha 85d070f0cdc665b5aad6f3341d09fe4a60d0368d

Merge pull request #4356 from google:xla-computation-dont-initialize-backend PiperOrigin-RevId: 332683955

view details

Benjamin Chetioui

commit sha 1cde76b130298c4d890d38242d19cff7b65ffaf9

[jax2tf] Implementation of the conversion of eig_p.

view details

Benjamin Chetioui

commit sha 015bc3c2cc7736720e362b15274d59e85e08359c

Replace manual conj + transpose with call to adjoint.

view details

Benjamin Chetioui

commit sha 1d94363df02a964df02f0c77795bb9939e76b1fb

Ignore eig conversion test on TPU/GPU, as it is unimplemented in JAX.

view details

jax authors

commit sha 695e8d88c3a827e0d73bf47022128ad7c1761420

Merge pull request #4338 from SIben:test_eig_p PiperOrigin-RevId: 332813984

view details

Jake VanderPlas

commit sha b0b3011843158151e3261da935997203c06da61d

Implement jnp.array_equiv

view details

push time in 5 days


PR opened google/jax

device_put_sharded: remove incorrect type annotation
+1 -1

0 comment

1 changed file

pr created time in 5 days

create branch jakevdp/jax

branch : type-annotation

created branch time in 5 days

issue comment altair-viz/altair

QUESTION: proxy URLs to CSV files are being interpreted as JSON

I'm not sure how Vega determines the default, but alt.UrlData has a format parameter that can be used to specify the format.
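
For example (with a hypothetical URL), you could declare the format explicitly:

import altair as alt

# Declare the format explicitly so it isn't inferred from the URL,
# which here has no .csv extension.
data = alt.UrlData(url='https://example.com/proxy/data', format={'type': 'csv'})

alt.Chart(data).mark_point().encode(x='x:Q', y='y:Q')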

dmoore247

comment created time in 5 days

issue comment johannesjmeyer/rsmf

Altair support

The vega-cli package has methods to save vega/vega-lite charts to png, svg, and pdf. If none of those are suitable for the purposes here, that package would be the place to implement other formats.

jedbrown

comment created time in 5 days

issue comment altair-viz/altair

Support for vega custom projections

Altair's only control over a chart comes from what is defined in the Vega-Lite specification. Yes, you can specify a custom projection in the specification, but in order for it to do anything, that custom projection has to be defined in the renderer.

So where is the renderer? It's in the JupyterLab frontend code... the nteract frontend code... the VSCode frontend code... the Streamlit frontend code. Except in the simplest circumstances, Altair has no control over how the chart is rendered, beyond what can be specified in the Vega-Lite spec.

Theoretically, would it be possible to coordinate all these disparate projects and decide on a standard whereby Javascript extensions could be injected into them? Maybe. Probably not, though, given how problematic user-generated Javascript can be security-wise. At best, it would be a feature that works in some frontends and not in others.

vectro

comment created time in 5 days

issue closed altair-viz/altair

Support for vega custom projections

Vega supports the use of custom projections using vega.projection, and third-party d3 projections are available (see e.g. this library). But Altair has a hard-coded list of supported projections, meaning it's not possible to use Altair with custom Vega projections. There should be a way to override the check of known projections in Altair, so that one can use other custom projections.

closed time in 6 days

vectro

issue comment altair-viz/altair

Support for vega custom projections

Thanks for the request!

This can be done by changing the Javascript embedding code for the vega-lite specifications that altair produces.

That said, Altair is fundamentally a tool to create vega-lite specifications, and vega projections are not expressible in vega-lite. Given that, I'm going to close this feature request as out-of-scope for the project.

vectro

comment created time in 6 days

issue comment altair-viz/altair

Setting the size of mark_rect

You can do this with mark_square() - this is akin to mark_point() and mark_circle() in that it generates points that can have a size encoding, whereas rect generates rectangles whose size is controlled by x,x2,y,y2 encodings:

import altair as alt
import numpy as np
import pandas as pd

# Compute x^2 + y^2 across a 2D grid
x, y = np.meshgrid(range(-10, 10), range(-5, 5))
z = x ** 2 + y ** 2

# Convert this grid to columnar data expected by Altair
source = pd.DataFrame({'x': x.ravel(),
                       'y': y.ravel(),
                       'z': z.ravel()})

alt.Chart(source).mark_square().encode(
    x='x:O',
    y='y:O',
    color='z:Q',
    size='z:Q'
)


ilyasustun

comment created time in 8 days

PR closed google/jax

Implement jnp.array_equiv cla: yes
+43 -2

0 comment

4 changed files

jakevdp

pr closed time in 8 days

PR opened google/jax

Implement jnp.array_equiv
+43 -2

0 comment

4 changed files

pr created time in 8 days

push event jakevdp/jax

Jake VanderPlas

commit sha a31c6fc458eb05f40f84c5f1089f47f675d48127

fix shape mismatch

view details

push time in 8 days

push event jakevdp/jax

Qiao Zhang

commit sha b3a098747aa6c674265796ddff8191a2e6a81efe

Make expm transposable and remove custom_jvp rule. (#4314) * Make expm transposable and remove custom_jvp rule. * Add check_grads for up to 2nd order derivative.

view details

Benjamin Chetioui

commit sha d478e346ac2cb1c7f531c61fb2888ccbbe9a3e60

Fix conditional in eig and expand eig test suite. (#4320) * Fix conditional in eig and expand eig test suite.

view details

George Necula

commit sha 0ac25c760af2788ba0a14f8a0f5827035e4c29e0

[jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 (#4278) * [jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 We now fixed tf.raw_ops.AddV2 to support uint32. It was already supporting uint8, so it is a better choice now than tf.math.add. This allowed us to use the threefry implementation using uint32.

view details

George Necula

commit sha ded7b3854c37f93652e7eed4f4cd50522f3be70f

[jax2tf] Revert '[jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 (#4278)' (#4332) Generates errors due to Grappler replacing AddV2 with AddN, which is not implemented for uint32

view details

George Necula

commit sha 8376d92049624bf0784647b17b1f09015acd0947

Disable testExpmGrad on TPU, pending investigation of compiler error (#4333)

view details

Peter Hawkins

commit sha 2911bcd63427bb433193e5450c967a79ddec70f5

Enable complex-valued Cholesky decomposition tests on TPU> (#4339)

view details

Jake Vanderplas

commit sha 6a89f60683d19f2797be8e777069a0ba7c515444

fix benchmark sums (#4329)

view details

Matthew Johnson

commit sha 6614f94890429c6c4c9dd46f932e22653dd6316d

rename and simplify TypedJaxpr -> ClosedJaxpr (#4328) rename and simplify TypedJaxpr -> ClosedJaxpr This change: * simplifies code that constructs TypedJaxprs/ClosedJaxprs (because in_avals / out_avals no longer need to be constructed), making them easier to work with; * correspondingly rules out a class of errors (mismatches between invars/outvars and in_avals/out_avals); * provides a more descriptive class name (ClosedJaxprs are like jaxprs but they're closed in that they are packaged with their constant values). This is part 1 of an attempt to remove TypedJaxprs completely, or at least significantly reduce our use of them. However, I'm not getting rid of them entirely in this first step because it'd require bigger changes (basically allowing all constants to be represented as literals, rather than only scalars) that would not only touch a lot more code (jaxpr formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering) but also might affect XLA lowering right before a conference deadline (ICLR). Plus I'm trying to make big changes in smaller steps :) Co-authored-by: George Necula <gcnecula@gmail.com>

view details

Jake VanderPlas

commit sha 56d3333716f462b551ab1c57e7c2dec397827a62

Implement jnp.array_equiv

view details

push time in 8 days

pull request comment google/jax

Allow JAX objects to be represented by multiple buffers

api_benchmark results: master branch:

------------------------------------------------------------------------
Benchmark                              Time             CPU   Iterations
------------------------------------------------------------------------
jit_trivial_dispatch              120578 ns       120539 ns         5381
jit_trivial                       136459 ns       136437 ns         5146
jit_simple_dispatch               495457 ns       494981 ns         1461
jit_simple                        510835 ns       510391 ns         1185
jit_simple_many_args_dispatch    2863614 ns      2863011 ns          246
jit_simple_many_args             2872962 ns      2872501 ns          234
jit_dispatch_without_transfer    6348334 ns      6347968 ns           91
jit_dispatch_with_transfer       6126572 ns      6126012 ns          117

with this change:

------------------------------------------------------------------------
Benchmark                              Time             CPU   Iterations
------------------------------------------------------------------------
jit_trivial_dispatch              106872 ns       106829 ns         6227
jit_trivial                       127626 ns       127625 ns         5493
jit_simple_dispatch               499879 ns       499377 ns         1417
jit_simple                        512144 ns       512095 ns         1342
jit_simple_many_args_dispatch    3088399 ns      3087974 ns          227
jit_simple_many_args             3043566 ns      3042474 ns          226
jit_dispatch_without_transfer    6320484 ns      6320214 ns          107
jit_dispatch_with_transfer       6268184 ns      6266837 ns          116

jakevdp

comment created time in 8 days

push event jakevdp/jax

Jake VanderPlas

commit sha 76a3bccb19fd861375d78f14280f4fc5e7adc642

fix unused import

view details

push time in 8 days

PR opened google/jax

fix incorrect indentation
+3 -3

0 comment

1 changed file

pr created time in 8 days

create branch jakevdp/jax

branch : indentation

created branch time in 8 days

push event jakevdp/jax

Jake VanderPlas

commit sha f0aaaedde92f2792799b97f7ee5bde6abbb6f328

Error early when ShardedDeviceArray is passed nested buffers.

view details

push time in 8 days

push event jakevdp/jax

Jake VanderPlas

commit sha b0caa6e86fe52b77d689d7a4acd79370e024e6d8

more pytype errors

view details

push time in 8 days

push event jakevdp/jax

Matthew Johnson

commit sha 6614f94890429c6c4c9dd46f932e22653dd6316d

rename and simplify TypedJaxpr -> ClosedJaxpr (#4328) rename and simplify TypedJaxpr -> ClosedJaxpr This change: * simplifies code that constructs TypedJaxprs/ClosedJaxprs (because in_avals / out_avals no longer need to be constructed), making them easier to work with; * correspondingly rules out a class of errors (mismatches between invars/outvars and in_avals/out_avals); * provides a more descriptive class name (ClosedJaxprs are like jaxprs but they're closed in that they are packaged with their constant values). This is part 1 of an attempt to remove TypedJaxprs completely, or at least significantly reduce our use of them. However, I'm not getting rid of them entirely in this first step because it'd require bigger changes (basically allowing all constants to be represented as literals, rather than only scalars) that would not only touch a lot more code (jaxpr formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering) but also might affect XLA lowering right before a conference deadline (ICLR). Plus I'm trying to make big changes in smaller steps :) Co-authored-by: George Necula <gcnecula@gmail.com>

view details

Jake VanderPlas

commit sha 50d84c518ae9e16a94585976296cd6cc1cff8110

Make xla.device_put() return tuples

view details

Jake VanderPlas

commit sha 4055bf638ccd620167aea1b12da3498e54fc2f42

Make xla.aval_to_xla_shape return a tuple

view details

Jake VanderPlas

commit sha 21ae6c88aa1ca7c4c856d3fbba178c5d8a2acade

make xla.aval_to_result_handler return number of args

view details

Jake VanderPlas

commit sha 5dc684dbdf7b515d9495e69e0115efabb0ce4e81

Add ability to support multiple device buffers per object

view details

Jake VanderPlas

commit sha cf0ce9112dd73f2802da284d4ab52189a157d219

Add custom object tests

view details

Jake VanderPlas

commit sha 80c66ca16aed8b206149787a8a683d128d32ded9

fix multi-buffer device put

view details

Jake VanderPlas

commit sha e455e48c425d9c70bf89e6e3fc2f1f747bafa621

streamline tests

view details

Jake VanderPlas

commit sha 8bb730e797331595a75ed6bf134595c646164c70

return xops tuple from translation rule

view details

Jake VanderPlas

commit sha eeb0e3d0b637be0fecdfc7ff8bf798b5ee02fd58

cleanup: use itertools.chain.from_iterable

view details

Jake VanderPlas

commit sha 693e77506cefe81ad4cb458b59382bafcb30df57

more minor cleanups

view details

Jake VanderPlas

commit sha 5ecc37518ff74dc92f84b6b9dc36a300cac81fcf

minor fixes

view details

Jake VanderPlas

commit sha d601e4b84dd934c42dafa4fc887a1b939aacfbc3

fix missing f-string

view details

Jake VanderPlas

commit sha 2f3c2f2eb8d9a3714debdc6e1919a943ec080db2

fix interaction between multiple_results and destructure

view details

Jake VanderPlas

commit sha f2f3b2097a8ffb0c10a7cff7d19adc647eaf11b8

simplify xla_result_handler definitions

view details

Jake VanderPlas

commit sha 117cd70456ade0175b162adfab67486db8e79787

further simplify result handlers

view details

Jake VanderPlas

commit sha 62cdfc97fd154168b49579193d603bec72c13edf

more simplifications of result handlers

view details

Jake VanderPlas

commit sha 821df45f0daba7741743e0dfd173b505c45410e8

fix internal pytype

view details

Jake VanderPlas

commit sha 09297cbfdbdf6500001d9eef8b956917bcfbfd63

add primitives for buffer access

view details

Jake VanderPlas

commit sha f76f9d5805cdef10a2e4a4141e6986a56e0ecd6b

Add matvec test

view details

push time in 8 days

push event jakevdp/jax

Jake VanderPlas

commit sha 8f6f3619f1e2b44b889c103101aa7563876f7acf

fix type annotation

view details

push time in 8 days

issue comment google/jax

Implement NumPy sorting routines

Yes, I don't think anyone has worked on partitions yet. It's not trivial to do, because they're not (yet?) implemented in XLA.

It may be possible to implement using scans (similar to how jnp.searchsorted is implemented), but in the end it might actually be faster to implement partitioning via a full sort, since XLA knows how to do this efficiently.
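
To illustrate the last point, a trivial sketch (my own, not an optimized implementation): a fully sorted array satisfies the np.partition contract for every kth, so a sort-based version is correct, if wasteful in theory.

import jax.numpy as jnp

def partition_via_sort(a, kth, axis=-1):
    # A sorted array is a valid partition for any kth: everything
    # before position kth is <= a[kth], everything after is >=.
    del kth  # present only to mirror the np.partition signature
    return jnp.sort(a, axis=axis)

print(partition_via_sort(jnp.array([3, 4, 2, 1]), 2))  # [1 2 3 4]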

shoyer

comment created time in 8 days

push event jakevdp/jax

Matthew Johnson

commit sha 11007ba0e3577dab0c19ac37f60061fdccecb016

test eval_context works w/ and w/o omnistaging (#4325)

view details

Benjamin Chetioui

commit sha 8a4ee3d8516662b6f6fd95046f4d28119b3c1db2

Fix shape checking rule for conv_general_dilated. (#4318) * Fix shape checking rule for conv_general_dilated. This closes google/jax#4316. * Added test based on google/jax#4316. * Change test name to be more accurate.

view details

Jake Vanderplas

commit sha e0af77fbca88ec882ef33c18926f8382f7178aed

Implement jnp.ravel_multi_index() (#4313)

view details

Qiao Zhang

commit sha b3a098747aa6c674265796ddff8191a2e6a81efe

Make expm transposable and remove custom_jvp rule. (#4314) * Make expm transposable and remove custom_jvp rule. * Add check_grads for up to 2nd order derivative.

view details

Benjamin Chetioui

commit sha d478e346ac2cb1c7f531c61fb2888ccbbe9a3e60

Fix conditional in eig and expand eig test suite. (#4320) * Fix conditional in eig and expand eig test suite.

view details

George Necula

commit sha 0ac25c760af2788ba0a14f8a0f5827035e4c29e0

[jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 (#4278) * [jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 We now fixed tf.raw_ops.AddV2 to support uint32. It was already supporting uint8, so it is a better choice now than tf.math.add. This allowed us to use the threefry implementation using uint32.

view details

George Necula

commit sha ded7b3854c37f93652e7eed4f4cd50522f3be70f

[jax2tf] Revert '[jax2tf] Replace tf.math.add with tf.raw_ops.AddV2 (#4278)' (#4332) Generates errors due to Grappler replacing AddV2 with AddN, which is not implemented for uint32

view details

George Necula

commit sha 8376d92049624bf0784647b17b1f09015acd0947

Disable testExpmGrad on TPU, pending investigation of compiler error (#4333)

view details

Peter Hawkins

commit sha 2911bcd63427bb433193e5450c967a79ddec70f5

Enable complex-valued Cholesky decomposition tests on TPU (#4339)

view details

Jake Vanderplas

commit sha 6a89f60683d19f2797be8e777069a0ba7c515444

fix benchmark sums (#4329)

view details

Jake VanderPlas

commit sha b8cb85cf75e87696a6ea613e8b7205d68aa1ffb0

Make xla.device_put() return tuples

view details

Jake VanderPlas

commit sha 5792b0980586ea169dfd924958f02f2907e6a410

Make xla.aval_to_xla_shape return a tuple

view details

Jake VanderPlas

commit sha 10eafcf98bf113c442642677d89e09ded4107f9e

make xla.aval_to_result_handler return number of args

view details

Jake VanderPlas

commit sha bdb15be9fffac6e1fd59c139bd0c13c4bd8b4653

Add support for multiple device buffers per object

view details

Jake VanderPlas

commit sha 96dc04226e757fb7a02c942bb4ddedc8814037ec

Add custom object tests

view details

Jake VanderPlas

commit sha 10ad10404b2f88b81107cad3fc73356b0827091d

fix multi-buffer device put

view details

Jake VanderPlas

commit sha aa24c2f0ebacabde0cdead539978aa8fc1e4dc10

streamline tests

view details

Jake VanderPlas

commit sha 1549952c722fe23da685f6a82e05c081fffcaceb

return xops tuple from translation rule

view details

Jake VanderPlas

commit sha b3278991ba5f235206dc6ac25fff66254f87a5bb

cleanup: use itertools.chain.from_iterable

view details

Jake VanderPlas

commit sha b5b4901a7562c5bf3e19c9a5834d11b643ccecf9

more minor cleanups

view details

push time in 8 days

PullRequestReviewEvent

Pull request review comment google/jax

Solving #2347

 def atleast_3d(*arys):

 def array(object, dtype=None, copy=True, order="K", ndmin=0):
   if order is not None and order != "K":
     raise NotImplementedError("Only implemented for order='K'")
+  # check if the given dtype is compatible with JAX
   lax._check_user_dtype_supported(dtype, "array")
   dtype = dtype and dtypes.canonicalize_dtype(dtype)

   if _can_call_numpy_array(object):
     object = _np_array(object, dtype=dtype, ndmin=ndmin)
+
+  if type(object) is np.ndarray:
+    # check if the inferred np type is compatible with JAX
+    _inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
+    lax._check_user_dtype_supported(_inferred_dtype, "array")
+

No worries!

tudorcebere

comment created time in 8 days

PullRequestReviewEvent

delete branch jakevdp/jax

delete branch : benchmarks

delete time in 8 days

push event google/jax

Jake Vanderplas

commit sha 6a89f60683d19f2797be8e777069a0ba7c515444

fix benchmark sums (#4329)

view details

push time in 8 days

PR merged google/jax

fix benchmark sums cla: yes

This was broken by #4195

+2 -2

0 comment

1 changed file

jakevdp

pr closed time in 8 days

Pull request review comment google/jax

Solving #2347

 def atleast_3d(*arys):

 def array(object, dtype=None, copy=True, order="K", ndmin=0):
   if order is not None and order != "K":
     raise NotImplementedError("Only implemented for order='K'")
+  # check if the given dtype is compatible with JAX
   lax._check_user_dtype_supported(dtype, "array")
   dtype = dtype and dtypes.canonicalize_dtype(dtype)

   if _can_call_numpy_array(object):
     object = _np_array(object, dtype=dtype, ndmin=ndmin)
+
+  if type(object) is np.ndarray:
+    # check if the inferred np type is compatible with JAX
+    _inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
+    lax._check_user_dtype_supported(_inferred_dtype, "array")
+

The new version seems redundant: if the input is a numpy array, the typecheck now happens twice. What I had in mind is this:

def array(object, dtype=None, copy=True, order="K", ndmin=0):
  if order is not None and order != "K":
    raise NotImplementedError("Only implemented for order='K'")

  # check if the given dtype is compatible with JAX
  lax._check_user_dtype_supported(dtype, "array")
  dtype = dtype and dtypes.canonicalize_dtype(dtype)

  if _can_call_numpy_array(object):
    object = _np_array(object, dtype=dtype, ndmin=ndmin)

  assert type(object) not in dtypes.python_scalar_dtypes

  if type(object) is np.ndarray:
    _inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
    lax._check_user_dtype_supported(_inferred_dtype, "array")
    out = _device_put_raw(object)
    if dtype: assert _dtype(out) == dtype
  elif isinstance(object, (DeviceArray, core.Tracer)):
    #...

This gives a single check that handles every case where object is an array.
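For context, the main input this check guards against is a numpy array whose dtype has no JAX equivalent, such as dtype=object; a minimal illustration (the exact error message is assumed, not quoted from JAX):

import numpy as np
import jax.numpy as jnp

# dtype=object arrays have no JAX equivalent; with the check in place,
# jnp.array should fail fast with a clear TypeError rather than a
# confusing error later during tracing.
x = np.array([1, "two", 3.0], dtype=object)
jnp.array(x)  # expected: TypeError about an unsupported dtype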

tudorcebere

comment created time in 8 days

PullRequestReviewEvent

Pull request review comment google/jax

Solving #2347

 def atleast_3d(*arys):

 def array(object, dtype=None, copy=True, order="K", ndmin=0):
   if order is not None and order != "K":
     raise NotImplementedError("Only implemented for order='K'")
+  # check if the given dtype is compatible with JAX
   lax._check_user_dtype_supported(dtype, "array")
   dtype = dtype and dtypes.canonicalize_dtype(dtype)

   if _can_call_numpy_array(object):
     object = _np_array(object, dtype=dtype, ndmin=ndmin)
+
+  if type(object) is np.ndarray:
+    # check if the inferred np type is compatible with JAX
+    _inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
+    lax._check_user_dtype_supported(_inferred_dtype, "array")
+

Understood, but another path into this block is when the input is a numpy array with dtype object, and without looking at the source of _can_call_numpy_array it was not clear to me whether that block would be executed in that case. It turns out it is. But given that indirection, I thought it would be clearer to put the check in the place where it's obvious at first glance that it will run for every numpy array input.

What do you think?

tudorcebere

comment created time in 8 days

Pull request review comment google/jax

Solving #2347

 def atleast_3d(*arys):

 def array(object, dtype=None, copy=True, order="K", ndmin=0):
   if order is not None and order != "K":
     raise NotImplementedError("Only implemented for order='K'")
+  # check if the given dtype is compatible with JAX
   lax._check_user_dtype_supported(dtype, "array")
   dtype = dtype and dtypes.canonicalize_dtype(dtype)

   if _can_call_numpy_array(object):
     object = _np_array(object, dtype=dtype, ndmin=ndmin)
+
+  if type(object) is np.ndarray:
+    # check if the inferred np type is compatible with JAX
+    _inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
+    lax._check_user_dtype_supported(_inferred_dtype, "array")
+

I think the check should be put inside the existing `if type(object) is np.ndarray` block just below this.

tudorcebere

comment created time in 8 days

PullRequestReviewEvent

issue comment scipy/scipy

scipy.sparse.csgraph.connected_components does not work for directed graphs

No problem, glad you got it figured out!

gunan

comment created time in 9 days

issue comment google/jax

Help with array slice indices

Yeah, that definitely looks like it could be expressed as a convolution. I'd suggest going that route rather than writing all the loops manually. You can construct a kernel that essentially encodes the sums that happen at each location, and do the whole thing in one operation.
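A minimal sketch of that idea, with illustrative shapes (the 3x3 all-ones kernel encodes "sum each 3x3 neighborhood"; the actual kernel would depend on the sums in the original code):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(25.0).reshape(1, 1, 5, 5)  # (batch, channel, height, width)
kernel = jnp.ones((1, 1, 3, 3))           # all-ones kernel: sum over each 3x3 window
# A single convolution replaces the manual loops: each output element
# is the sum of the corresponding 3x3 neighborhood of the input.
sums = lax.conv(x, kernel, window_strides=(1, 1), padding="VALID")
print(sums.shape)  # (1, 1, 3, 3)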

AashinShazar

comment created time in 9 days

Pull request review comment google/jax

Allow JAX objects to be represented by multiple buffers

 def shard_args(devices: Sequence[xb.xla_client.Device],

 shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
 shard_arg_handlers[core.Unit] = \
-    lambda x, devices, _: [xla.device_put(core.unit, d) for d in devices]
+    lambda x, devices, _: list(it.chain.from_iterable(xla.device_put(core.unit, d) for d in devices))

 def _shard_array(x, devices, indices):
-  return [xla.device_put(x[i], d) for (i, d) in zip(indices, devices)]
+  return list(it.chain.from_iterable(xla.device_put(x[i], d) for (i, d) in zip(indices, devices)))
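For context, a minimal illustration of the flattening pattern the new version relies on (the buffer values are hypothetical placeholders):

import itertools as it

# Each xla.device_put now returns a sequence of buffers rather than a
# single buffer, so the per-argument sequences must be flattened into
# one flat list:
per_arg = [["buf0a", "buf0b"], ["buf1"], ["buf2a", "buf2b"]]
flat = list(it.chain.from_iterable(per_arg))
print(flat)  # ['buf0a', 'buf0b', 'buf1', 'buf2a', 'buf2b']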

Done, see below.

jakevdp

comment created time in 9 days

PullRequestReviewEvent

pull request comment google/jax

Allow JAX objects to be represented by multiple buffers

Running benchmarks/pmap_benchmark.py to compare master with this branch.

Results on master:

---------Benchmark results for pmap_shard_sharded_device_array_nargs=10_nshards=1---------
mean=0.018155 std=0.000451 %std=2.484562 total=10.003359
#iters=551 #warmup=1

---------Benchmark results for pmap_shard_sharded_device_array_nargs=100_nshards=1---------
mean=0.085240 std=0.001963 %std=2.303262 total=10.058309
#iters=118 #warmup=1

---------Benchmark results for pmap_shard_sharded_device_array_nargs=101_nshards=1---------
mean=0.088433 std=0.003865 %std=4.370711 total=10.081317
#iters=114 #warmup=1

---------Benchmark results for pmap_shard_sharded_device_array_nargs=500_nshards=1---------
mean=0.389694 std=0.013617 %std=3.494267 total=10.132035
#iters=26 #warmup=1

---------Benchmark results for pmap_shard_sharded_device_array_nargs=1000_nshards=1---------
mean=0.792406 std=0.023802 %std=3.003702 total=10.301281
#iters=13 #warmup=1

---------Benchmark results for pmap_shard_sharded_device_array_nargs=5000_nshards=1---------
mean=4.381965 std=0.090272 %std=2.060070 total=13.145895
#iters=3 #warmup=1

---------Benchmark summary for pmap_shard_sharded_device_array---------
  nargs    nshards       mean     %std    relative
-------  ---------  ---------  -------  ----------
     10          1  0.0181549  2.48456     1
    100          1  0.0852399  2.30326     4.69514
    101          1  0.0884326  4.37071     4.871
    500          1  0.389694   3.49427    21.4649
   1000          1  0.792406   3.0037     43.6469
   5000          1  4.38197    2.06007   241.365

---------Benchmark results for pmap_shard_device_array_nargs=10_nshards=1---------
mean=0.015208 std=0.001568 %std=10.307562 total=10.006748
#iters=658 #warmup=1

---------Benchmark results for pmap_shard_device_array_nargs=100_nshards=1---------
mean=0.130824 std=0.003322 %std=2.539380 total=10.073420
#iters=77 #warmup=1

---------Benchmark results for pmap_shard_device_array_nargs=500_nshards=1---------
mean=0.725450 std=0.077750 %std=10.717554 total=10.156296
#iters=14 #warmup=1

---------Benchmark summary for pmap_shard_device_array---------
  nargs    nshards       mean      %std    relative
-------  ---------  ---------  --------  ----------
     10          1  0.0152078  10.3076      1
    100          1  0.130824    2.53938     8.60239
    500          1  0.72545    10.7176     47.7024

---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=1---------
mean=0.024094 std=0.003001 %std=12.455350 total=10.023266
#iters=416 #warmup=1

---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=1---------
mean=0.093125 std=0.004977 %std=5.344291 total=10.057464
#iters=108 #warmup=1

---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=1---------
mean=0.436330 std=0.085204 %std=19.527318 total=10.035598
#iters=23 #warmup=1

---------Benchmark results for pmap_shard_outputs_nouts=1000_nshards=1---------
mean=0.828440 std=0.016951 %std=2.046166 total=10.769719
#iters=13 #warmup=1

---------Benchmark results for pmap_shard_outputs_nouts=5000_nshards=1---------
mean=4.229745 std=0.007748 %std=0.183187 total=12.689235
#iters=3 #warmup=1

---------Benchmark summary for pmap_shard_outputs---------
  nouts    nshards       mean       %std    relative
-------  ---------  ---------  ---------  ----------
     10          1  0.0240944  12.4553       1
    100          1  0.0931247   5.34429      3.86499
    500          1  0.43633    19.5273      18.1092
   1000          1  0.82844     2.04617     34.3831
   5000          1  4.22974     0.183187   175.549

---------Benchmark results for ShardedDeviceArray_indexing_indices_fn=integer_indices---------
mean=8.168176 std=0.702482 %std=8.600236 total=16.336353
#iters=2 #warmup=1

---------Benchmark results for ShardedDeviceArray_indexing_indices_fn=integer_2D_indices---------
mean=11.738572 std=0.000000 %std=0.000000 total=11.738572
#iters=1 #warmup=1

---------Benchmark summary for ShardedDeviceArray_indexing---------
indices_fn              mean     %std    relative
------------------  --------  -------  ----------
integer_indices      8.16818  8.60024     1
integer_2D_indices  11.7386   0           1.43711

Results on this branch:

---------Benchmark results for pmap_shard_sharded_device_array_nargs=10_nshards=1---------
mean=0.019015 std=0.001062 %std=5.584433 total=10.001693
#iters=526 #warmup=1

---------Benchmark results for pmap_shard_sharded_device_array_nargs=100_nshards=1---------
mean=0.094039 std=0.007497 %std=7.972502 total=10.062211
#iters=107 #warmup=1

---------Benchmark results for pmap_shard_sharded_device_array_nargs=101_nshards=1---------
mean=0.090250 std=0.003046 %std=3.375537 total=10.017800
#iters=111 #warmup=1

---------Benchmark results for pmap_shard_sharded_device_array_nargs=500_nshards=1---------
mean=0.407412 std=0.035182 %std=8.635389 total=10.185308
#iters=25 #warmup=1

---------Benchmark results for pmap_shard_sharded_device_array_nargs=1000_nshards=1---------
mean=0.764858 std=0.011401 %std=1.490611 total=10.708011
#iters=14 #warmup=1

---------Benchmark results for pmap_shard_sharded_device_array_nargs=5000_nshards=1---------
mean=4.366896 std=0.048288 %std=1.105768 total=13.100688
#iters=3 #warmup=1

---------Benchmark summary for pmap_shard_sharded_device_array---------
  nargs    nshards       mean     %std    relative
-------  ---------  ---------  -------  ----------
     10          1  0.0190146  5.58443     1
    100          1  0.0940394  7.9725      4.94563
    101          1  0.0902505  3.37554     4.74637
    500          1  0.407412   8.63539    21.4263
   1000          1  0.764858   1.49061    40.2247
   5000          1  4.3669     1.10577   229.66

---------Benchmark results for pmap_shard_device_array_nargs=10_nshards=1---------
mean=0.016362 std=0.000771 %std=4.711918 total=10.013431
#iters=612 #warmup=1

---------Benchmark results for pmap_shard_device_array_nargs=100_nshards=1---------
mean=0.159848 std=0.019341 %std=12.099536 total=10.070426
#iters=63 #warmup=1

---------Benchmark results for pmap_shard_device_array_nargs=500_nshards=1---------
mean=0.756038 std=0.040831 %std=5.400649 total=10.584529
#iters=14 #warmup=1

---------Benchmark summary for pmap_shard_device_array---------
  nargs    nshards       mean      %std    relative
-------  ---------  ---------  --------  ----------
     10          1  0.0163618   4.71192     1
    100          1  0.159848   12.0995      9.76958
    500          1  0.756038    5.40065    46.2074

---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=1---------
mean=0.024195 std=0.001320 %std=5.455058 total=10.016746
#iters=414 #warmup=1

---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=1---------
mean=0.094919 std=0.010104 %std=10.644888 total=10.061403
#iters=106 #warmup=1

---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=1---------
mean=0.410957 std=0.023368 %std=5.686321 total=10.273933
#iters=25 #warmup=1

---------Benchmark results for pmap_shard_outputs_nouts=1000_nshards=1---------
mean=0.885113 std=0.035264 %std=3.984146 total=10.621352
#iters=12 #warmup=1

---------Benchmark results for pmap_shard_outputs_nouts=5000_nshards=1---------
mean=4.575136 std=0.038652 %std=0.844824 total=13.725407
#iters=3 #warmup=1

---------Benchmark summary for pmap_shard_outputs---------
  nouts    nshards       mean       %std    relative
-------  ---------  ---------  ---------  ----------
     10          1  0.024195    5.45506      1
    100          1  0.0949189  10.6449       3.92307
    500          1  0.410957    5.68632     16.9852
   1000          1  0.885113    3.98415     36.5824
   5000          1  4.57514     0.844824   189.094

---------Benchmark results for ShardedDeviceArray_indexing_indices_fn=integer_indices---------
mean=7.891942 std=0.006513 %std=0.082526 total=15.783884
#iters=2 #warmup=1

---------Benchmark results for ShardedDeviceArray_indexing_indices_fn=integer_2D_indices---------
mean=12.761788 std=0.000000 %std=0.000000 total=12.761788
#iters=1 #warmup=1

---------Benchmark summary for ShardedDeviceArray_indexing---------
indices_fn              mean       %std    relative
------------------  --------  ---------  ----------
integer_indices      7.89194  0.0825257     1
integer_2D_indices  12.7618   0             1.61707

It looks like for functions with many buffers (e.g. pmap_shard_outputs) there is a ~5% cost associated with the buffer tupling/flattening introduced by this code (4.37 seconds -> 4.58 seconds).

jakevdp

comment created time in 9 days

issue comment scipy/scipy

scipy.sparse.csgraph.connected_components does not work for directed graphs

Note that the default is weakly-connected components. If you want strongly-connected components, you can pass the connection='strong' keyword.
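A quick illustration using the graph from this thread (edges 0->1, 0->2, 1->2, 3->4):

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

graph = csr_matrix((np.ones(4), ([0, 0, 1, 3], [1, 2, 2, 4])), shape=(5, 5))

# Default: connection='weak', so edge directions are ignored.
n, labels = connected_components(graph, directed=True)
print(n, labels)  # 2 [0 0 0 1 1]

# connection='strong' requires mutual reachability; this graph has no
# cycles, so every node ends up in its own component.
n, labels = connected_components(graph, directed=True, connection='strong')
print(n)  # 5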

gunan

comment created time in 9 days

issue comment scipy/scipy

scipy.sparse.csgraph.connected_components does not work for directed graphs

When I run your code using scipy 1.5.2, I get the following output:

  (0, 1)	1
  (0, 2)	1
  (1, 2)	1
  (3, 4)	1
[0 0 0 1 1]

This appears to be correct: nodes 0, 1, 2 form one weakly-connected group, and nodes 3, 4 form a second. (With the default connection='weak', edge directions are ignored.)

Is this the output that you're seeing?

gunan

comment created time in 9 days

PR opened google/jax

fix benchmark sums

Broken by #4195

+2 -2

0 comment

1 changed file

pr created time in 9 days

create branch jakevdp/jax

branch : benchmarks

created branch time in 9 days
