Tom Hennigan tomhennigan @DeepMind London, UK https://tom.gd Software Engineer at @deepmind. Previously: @google, @duedil-ltd.

google/jax 9033

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

deepmind/sonnet 8425

TensorFlow-based neural network library

deepmind/dm-haiku 663

JAX-based neural network library

drf/amsn2 110

The new version of aMSN client

deepmind/tree 85

tree is a library for working with nested data structures

tomhennigan/monzo-fs 23

🏦 📁 A FUSE filesystem to access your Monzo bank account(s)

tomhennigan/nyan_logger 4

Nyan cat `logging.Formatter` implementation.

tomhennigan/dnssniff 3

dnssniff is a libpcap-based capture utility that logs DNS requests seen on an interface.

tarnfeld/gif-man 2

A simple SIMBL bundle for Skype for inline gifs, videos and other awesomeness.

Pull request review comment google/jax

Add simple JAX API microbenchmarks.

+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Microbenchmarks for JAX `api` functions."""
+import functools
+import operator
+
+import jax
+import jax.numpy as jnp
+
+import benchmark
+
+
+def required_devices(num_devices_required):
+  """Helper to skip benchmarks that require more devices."""
+  def helper1(f):
+    @functools.wraps(f)
+    def helper2(state):
+      if jax.device_count() < num_devices_required:
+        state.skip_with_error(f"requires {num_devices_required} devices")
+        return
+      return f(state)
+    return helper2
+  return helper1
+
+
+@benchmark.register
+def jit_trivial(state):
+  f = jax.jit(swap)
+  a, b = f(1, 2)
+
+  while state:
+    a, b = f(a, b)

I think the goal here is to measure dispatch time, and IIRC jit(swap) does not actually result in an XLA computation, so the block_until_ready is not necessary:

c, d = jit(lambda a, b: (b, a))(a, b)
assert c is b and d is a

Agree we can probably skip assignment.
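
For contrast, a minimal sketch of a dispatch benchmark where block_until_ready is needed because the jitted computation is non-trivial (this assumes the open-source google_benchmark bindings rather than the internal benchmark module used in the diff):

import google_benchmark as benchmark
import jax
import jax.numpy as jnp

@benchmark.register
def jit_nontrivial_dispatch(state):
  f = jax.jit(lambda x: x * 2.0)
  x = f(jnp.ones([1000, 1000])).block_until_ready()  # compile and warm up
  while state:
    f(x).block_until_ready()  # without this we would only time async dispatch

if __name__ == "__main__":
  benchmark.main()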

chr1sj0nes

comment created time in 12 hours

issue closed deepmind/dm-haiku

Folder Structure

I've noticed _src has become quite large. I think eventually splitting it up into folders makes more sense. We could have:

  • nn
  • initializers
  • regularizers
  • losses
  • metrics

closed time in 3 days

cgarciae

issue comment deepmind/dm-haiku

Folder Structure

Thanks for the suggestion! The folder structure here matches Sonnet (2) and our experience with managing a flat structure there has been that it works quite well for navigating code and has not been a maintenance burden.

Additionally, we don't expect the root of _src to grow much over the next year or two. We may add more nets and examples (which are already in subdirs), but the core set of features and modules has grown extremely slowly compared to other libraries, so I don't think we need to do this pre-emptively.

Haiku (and Sonnet) both strongly encourage development to occur outside the core library (e.g. DeepMind recently released ACME which includes Haiku implementations of a number of common RL agents/networks) and new modules are added to the core library very rarely (as you can see from #46 we are quite conservative 😄 ).

So I think for now we should maintain the structure as is.

cgarciae

comment created time in 3 days

issue closed deepmind/dm-haiku

hk.Module should be an abstract Callable

Hey, I use pyright / pylance for type checking and they are pretty unhappy that hk.Module doesn't define an abstract __call__ method; I get type errors all over the place when writing code that takes arbitrary hk.Modules. Given that most of Haiku is already typed, this would be a nice addition.

closed time in 3 days

cgarciae

issue comment deepmind/dm-haiku

hk.Module should be an abstract Callable

Thanks for the FR! In Haiku we don't special-case __call__ (or other methods), meaning that module instances don't actually have to be callable. As an example, for a VAE you may want a single module that defines def encode(self, x) and def decode(self, z) but not __call__.

It is common for modules to be callable of course, and we have considered adding def __call__(self, *a, **k) -> Any: raise NotImplementedError to the Module base class (purely for type hints). However, our current thinking is that this is not more useful as a type hint than users writing Callable[..., Any], and it might even be harmful (modules that aren't actually callable would pass static analysis). Another downside IMHO is that we cannot define a calling convention other than *a, **k -> Any, because users can (and do) do anything with their __call__ method.

Using Callable instead of hk.Module may have other benefits, for example in most places where you could pass a callable module you could also pass a function. As a concrete example in our Sequential module we take a list of callables and this means users can pass lambda x: x or a module instance or a JAX function etc.

We're keen where possible to encourage JAX code to be decoupled from Haiku; we feel this is best for the ecosystem overall, and it means users are not locked into a particular way of using JAX by our libraries.

Concretely, if you're thinking about requiring hk.Module in your type signatures, I would suggest instead requiring Callable[..., Any], or even better defining the calling convention you require too: Callable[[jnp.ndarray], jnp.ndarray].
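
As a minimal sketch of that suggestion (the function and module names here are just for illustration), the same helper then works with a Haiku module instance, a plain JAX function, or a lambda:

from typing import Callable

import haiku as hk
import jax
import jax.numpy as jnp

def residual(f: Callable[[jnp.ndarray], jnp.ndarray], x: jnp.ndarray) -> jnp.ndarray:
  # Accepts anything with the right calling convention, not just hk.Module.
  return x + f(x)

def net(x: jnp.ndarray) -> jnp.ndarray:
  x = residual(hk.Linear(x.shape[-1]), x)  # a module instance
  x = residual(jax.nn.relu, x)             # a plain function
  return x

net = hk.transform(net)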

cgarciae

comment created time in 3 days

issue closed deepmind/dm-haiku

hk.add_loss

To enable users to easily create per-layer weight and activity regularizers, plus other forms of losses created by intermediate layers, it would be very useful if Haiku had an hk.add_loss utility that, when called within a transform, appends the loss to a list of losses which the user could later retrieve as an additional output from apply. I guess this would require an additional flag to hk.transform and friends.

closed time in 3 days

cgarciae

issue comment deepmind/dm-haiku

hk.add_loss

Hey @cgarciae , thanks for the FR! In general we try to keep Haiku fairly lean and encourage features (e.g. training loops, optimizers etc) to be solved in other libraries (then they can benefit all JAX users not just Haiku users) or built from existing Haiku/JAX features.

Wrt using existing features, you might consider using hk.set_state for this. This is a fairly general mechanism in Haiku for logging values associated with modules. You could use hk.data_structures.filter to extract all losses from the state dict:

def f(x):
  y = hk.nets.ResNet50(1000)(x, True)
  loss_1 = y.sum()
  loss_2 = loss_1 ** 2
  hk.set_state("loss_1", loss_1)
  return loss_2

f = hk.transform_with_state(f)

rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 224, 224, 3])
params, state = f.init(rng, x)

# Apply as usual:
params, state = f.apply(params, state, rng, x)

# Extract losses from state:
is_loss = lambda m, n, v: n.startswith("loss")
losses = hk.data_structures.filter(is_loss, state)
print(losses)  # frozendict({'~': frozendict({'loss_1': DeviceArray(0., dtype=float32)})})

If you want to implement this as a standalone feature (e.g. decoupled from hk.set_state) then I've forked the following from @ibab who has implemented something similar (his version is more robust with thread safety and nesting).

from contextlib import contextmanager

loggers = []

@contextmanager
def context():
  data = {}
  loggers.append(data)
  try:
    yield data
  finally:
    loggers.pop()

def log(name, value):
  # NOTE: log(..) ignored when not logging.
  if loggers:
    data = loggers[-1]
    data[name] = value

def f(x):
  x = x ** 2
  log("a", x)
  x = x ** 2
  log("b", x)
  return x

def g():
  with context() as data:
    y = f(2)
  return y, data

y, data = g()
assert y == 16
assert data == {'a': 4, 'b': 16}

I hope that's useful, please feel free to reopen if this does not solve your usecase.

cgarciae

comment created time in 3 days

issue comment deepmind/dm-haiku

JAX version in Colab

Thanks for the report, we've avoided pinning jaxlib since afaik users should pick a specific distribution of this depending on their local setup (c.f. https://github.com/google/jax#installation).

I'm open to pinning a minimum version of jax, but currently at HEAD Haiku also relies on unreleased JAX. We need to track JAX commits relatively closely because we use some of the more niche features in JAX (e.g. we have a tracer for to_dot and a custom primitive for named_call), and both JAX and Haiku are synced to Alphabet's monorepo, where libraries need to depend on each other and build at HEAD.

I'm planning to do a new release this week (since we just flipped the default value of apply_rng in transform) and will chat with the JAX team to get their advice on how to make this dependency more explicit.
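
As a purely hypothetical sketch (not Haiku's actual code) of what making the dependency more explicit could look like, with a placeholder version number:

from distutils.version import LooseVersion

import jax

_MINIMUM_JAX_VERSION = "0.1.71"  # placeholder, not a real pin

if LooseVersion(jax.__version__) < LooseVersion(_MINIMUM_JAX_VERSION):
  raise ImportError(
      f"This library requires jax>={_MINIMUM_JAX_VERSION} but found {jax.__version__}. "
      "See https://github.com/google/jax#installation for how to install a "
      "matching jaxlib.")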

joaogui1

comment created time in 3 days

Pull request review comment google/jax

Compare treedefs by `num_leaves` not `traversal_` in `tree_transpose`.

 std::unique_ptr<PyTreeDef> PyTreeDef::Compose(const PyTreeDef& inner) const {
       out->traversal_.push_back(n);
     }
   }
+  const auto& root = traversal_.back();
+  const auto& inner_root = inner.traversal_.back();
+  auto& out_root = out->traversal_.back();
+  out_root.num_nodes = (root.num_nodes - root.num_leaves) +

Re mutating in place: afaik out->traversal_.push_back(n) in the loop above will make a copy of the node, so we're only mutating the copy.

Re restoring for all interior nodes: I think you're right. In the cases I was looking at carefully we have a very simple traversal (kLeaf, .., kLeaf, kCustom), but for a deeply nested tree we would need to recompute the full traversal, I think. I'll take a look at it later this week, but I'm surprised our tests pass as is?

tomhennigan

comment created time in 3 days

PR opened google/jax

Compare treedefs by `num_leaves` not `traversal_` in `tree_transpose`.

In general for a kCustom node it is not guaranteed that a.compose(b) will have the same traversal_ as some structure c (which is the composition of a+b). We have a real world example in deepmind/dm-haiku with our FlatMapping type and I've put a simpler example in tree_util_tests.py.

Since this check seems largely to be input validation, I've changed it to compare the expected number of leaves (which is cheaper than using compose as the previous implementation did). This still catches common errors and is guaranteed to work for any well-formed pytree. Additionally, I had to fix the leaf and node counts for composed pytrees, which were wrong at HEAD.
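
For reference, a minimal sketch of what tree_transpose does and the kind of structural mismatch this validation guards against (the structures here are illustrative only):

import jax

outer = jax.tree_util.tree_structure((0, 0))
inner = jax.tree_util.tree_structure({"x": 0, "y": 0})

# A pytree whose outer structure is `outer` and whose inner structures are `inner`.
tree = ({"x": 1, "y": 2}, {"x": 3, "y": 4})
print(jax.tree_util.tree_transpose(outer, inner, tree))
# {'x': (1, 3), 'y': (2, 4)}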

+91 -4

0 comment

3 changed files

pr created time in 5 days

create branch tomhennigan/jax

branch : changelist/319616371

created branch time in 5 days

pull request comment google/jax

Correctly handle bfloat16 input and output from jax2tf functions.

How did you find this bug? Was it like in the get_concrete_function test?

Yeah Tamas found these with exactly that style of test.

tomhennigan

comment created time in 5 days

PR opened google/jax

Correctly handle bfloat16 input and output from jax2tf functions.

TF and JAX have different NumPy dtypes for bfloat16, so we need to be careful to use the right one. I think there are a few other cases in jax2tf where we should be using to_tf_dtype rather than passing v.dtype directly into tf ops (e.g. I am a bit surprised to only update convert_element_type_p), however I think a follow-up adding tests for those cases would be best.
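
A small illustration of the dtype mismatch being described (using only public APIs; at the time of this PR the two bfloat16 scalar types were distinct Python types):

import jax.numpy as jnp
import tensorflow as tf

# JAX and TF each expose their own NumPy-compatible bfloat16 scalar type, so
# passing a JAX dtype straight into a TF op (or comparing dtypes across the
# boundary) needs an explicit conversion such as to_tf_dtype.
print(jnp.bfloat16)                # JAX's bfloat16 type
print(tf.bfloat16.as_numpy_dtype)  # TF's bfloat16 type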

+28 -6

0 comment

2 changed files

pr created time in 6 days

create branch tomhennigan/jax

branch : changelist/319521416

created branch time in 6 days

issue closed deepmind/sonnet

Device Placement with Tensorflow 2

I'm trying to specify execution of a Sonnet module on the CPU instead of the GPU. How can I specify device placement of Sonnet modules in Tensorflow 2? A previous issue suggested using functions with tf.device like this:

from tensorflow.core.framework import node_def_pb2

[...]
def device_setter(op):
  _variable_ops = ["Variable", "VariableV2", "VarHandleOp"]
  node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
  return  '/cpu:0' if node_def.op in _variable_ops else '/gpu:0'
[...]   

with tf.device(device_setter):
  outputs = rcnn_output(t)

However, with tensorflow 2, tf.device can no longer accept functions as arguments. How can I work around this?

closed time in 8 days

agnusmaximus

issue comment deepmind/sonnet

Device Placement with Tensorflow 2

Hi @agnusmaximus, if you want to run on the CPU then create and use your modules inside a with tf.device("CPU"): block. For example:

x = tf.ones([1, 28 * 28])
with tf.device("CPU"):
  mod = snt.nets.MLP([300, 100, 10])
  logits = mod(x)

It may be useful to know that outside of a with tf.device scope TensorFlow's default device placement policy will silently move tensors from the CPU to the GPU. You may want to set TF's device policy to explicit if you want to ensure all computation remains on the CPU:

tf.config.experimental.set_device_policy("explicit")

If you want your whole TF program to only run using CPU then you can configure TensorFlow to ignore all other devices using the following at the start of your program:

# Ignore all devices other than available CPU(s).
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_visible_devices(cpus)
agnusmaximus

comment created time in 8 days

Pull request review comment google/jax

[jax2tf] Add support for first-order AD to converted functions

 def add_unit(v: TfValOrUnit, aval: core.AbstractValue): tf_impl: Dict[core.Primitive,               Callable[..., Any]] = {} -def disable_gradient(fun):-  """Prevents the wrapped function from being differentiated."""--  def grad_disabled(*dy, variables=None):-    raise ValueError("jax2tf currently does not support gradients. Please "-                     "reach out to jax-core@ if this feature is a blocker "-                     "for you, we are working on it!")--  is_tensor = lambda t: isinstance(t, (tf.Tensor, tf.Variable))--  def wrapper(*args, **kwargs):-    flat_args, treedef = jax.tree_flatten((args, kwargs))-    tensor_idxs = [i for i, t in enumerate(flat_args) if is_tensor(t)]-    tensor_args = [t for t in flat_args if is_tensor(t)]--    @tf.custom_gradient-    def inner_wrapper(*tensor_args):-      flat_copy = list(flat_args)-      for i, t in zip(tensor_idxs, tensor_args):-        flat_copy[i] = t--      args, kwargs = jax.tree_unflatten(treedef, flat_copy)-      out = fun(*args, **kwargs)-      return out, grad_disabled--    return inner_wrapper(*tensor_args)--  return wrapper---def convert(fun):+def convert(fun,+            with_gradient=False):

I feel like True might be a safer default to start with? Also nit that we don't need to split the line here.

gnecula

comment created time in 8 days

Pull request review comment google/jax

[jax2tf] Add support for first-order AD to converted functions

 def add_unit(v: TfValOrUnit, aval: core.AbstractValue): tf_impl: Dict[core.Primitive,               Callable[..., Any]] = {} -def disable_gradient(fun):-  """Prevents the wrapped function from being differentiated."""--  def grad_disabled(*dy, variables=None):-    raise ValueError("jax2tf currently does not support gradients. Please "-                     "reach out to jax-core@ if this feature is a blocker "-                     "for you, we are working on it!")--  is_tensor = lambda t: isinstance(t, (tf.Tensor, tf.Variable))--  def wrapper(*args, **kwargs):-    flat_args, treedef = jax.tree_flatten((args, kwargs))-    tensor_idxs = [i for i, t in enumerate(flat_args) if is_tensor(t)]-    tensor_args = [t for t in flat_args if is_tensor(t)]--    @tf.custom_gradient-    def inner_wrapper(*tensor_args):-      flat_copy = list(flat_args)-      for i, t in zip(tensor_idxs, tensor_args):-        flat_copy[i] = t--      args, kwargs = jax.tree_unflatten(treedef, flat_copy)-      out = fun(*args, **kwargs)-      return out, grad_disabled--    return inner_wrapper(*tensor_args)--  return wrapper---def convert(fun):+def convert(fun,+            with_gradient=False):   """Transforms `fun` to be executed by TensorFlow.    Args:     fun: Function to be transformed. Its arguments and return value should be       JAX arrays, or (nested) standard Python containers (tuple/list/dict)       thereof.+    with_gradient: if set, will add a tf.custom_gradient to the converted+      function, by converting the ``jax.vjp(fun)``. Only first-order+      differentiation is supported for now. If the converted function is+      saved in a SavedModel, the custom gradients are currently lost and+      an error will be raised if a gradient computation is attempted.    Returns:     A version of `fun` that expects TfVals as arguments (or     tuple/lists/dicts) thereof, and returns TfVals as outputs.   """   api._check_callable(fun) -  @disable_gradient-  def wrapped_fun(*args: TfValOrUnit) -> TfValOrUnit:-    # TODO(necula): remove the jit disabling once we handle all control-flow.-    # Disabling the jit helps to avoid some unsupported jax primitives.-    # E.g. scan will be statically unrolled.-    f = lu.wrap_init(fun)+  def converted_fun(*args: TfVal) -> TfVal:+    # This function may take pytrees of TfVals. We can only set+    # tf.custom_gradient on functions that take a flat argument list.     args_flat, in_tree = tree_util.tree_flatten((args, {}))     for a in args_flat:       if not _is_tfvalorunit(a):         msg = (f"Argument {a} of type {type(a)} of jax2tf.convert(f) should "                "be NumPy array, scalar, tf.Variable, or tf.Tensor")         raise TypeError(msg)-    flat_fun, out_tree = flatten_fun(f, in_tree)-    out_flat = _interpret_fun(flat_fun, args_flat)-    return tree_util.tree_unflatten(out_tree(), out_flat) -  return wrapped_fun+    f = lu.wrap_init(fun)+    # out_tree_thunk() will be the output tree, after running _interpret_fun.+    flat_fun, out_tree_thunk = flatten_fun(f, in_tree)++    # Prepare the grad_fn for tf.custom_gradient.+    def converted_grad_fn(*out_cts_flat: TfVal, **kwargs):+      # TODO(cl/318778369): change **kwargs with variables=None+      variables = kwargs.get("variables", [])+      if variables:+        raise ValueError("Unexpected variables used in forward pass. "+                         "This should not happen for first-order differentiation. 
"+                         f"variables={variables}")++      def fun_vjp_jax(args_jax, out_cts_jax):+        # One may think that we can get the pullback while we are converting+        # the main function in the first place. That is problematic, because the+        # pullback may contain captured tracers from the conversion of the+        # main function. Those tracers will confuse the conversion of the+        # pullback. So, we construct the vjp anew.

I'm not sure I fully understand the implications of this, but does it mean that if I have a custom gradient that closes over some expensive-to-compute intermediate value from the forward pass and uses it when computing the gradient, then when running via jax2tf we will always recompute that value (because the custom gradient for our function is "disconnected" from the forward pass)?

As a concrete example for this pseudocode:

@custom_gradient
def op(x):
  intermediate = x ** 2
  def grad(g):
    return 5 * g * intermediate
  return intermediate ** 2, grad

I think we would produce a TF graph with two copies of the forward pass, one to compute the forward result and one in order for us to compute grad (i.e. we'll recompute intermediate in order to compute grad). A smart enough compiler could fix this for us, I guess. I'm a bit worried because this is a pretty common use case for custom gradients, so we should at least call it out as a major limitation.
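
For contrast, a minimal jax.custom_vjp version of the pseudocode above (names are illustrative): the forward rule returns intermediate as a residual, so JAX reuses it in the backward pass rather than recomputing it. The concern is that this sharing between forward and backward passes may be lost after jax2tf conversion.

import jax

@jax.custom_vjp
def op(x):
  return (x ** 2) ** 2

def op_fwd(x):
  intermediate = x ** 2
  return intermediate ** 2, intermediate  # (output, residual saved for backward)

def op_bwd(intermediate, g):
  return (5 * g * intermediate,)  # cotangent w.r.t. x

op.defvjp(op_fwd, op_bwd)

print(jax.grad(op)(3.0))  # uses the saved residual, no recompute of x ** 2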

gnecula

comment created time in 8 days


issue comment google/jax

Buffer donation to a jit function

FYI buffer donation is only supported on TPU at the moment; the XLA team are working to support this on CPU/GPU, but that may be why we cannot use the donation here.
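
A minimal sketch of the donation API being discussed (donate_argnums as added in #2936); on TPU the buffer backing x can be reused for the output, while on CPU/GPU the donation is currently ignored:

import jax
import jax.numpy as jnp

# Donate the first positional argument so XLA may reuse its buffer for the result.
step = jax.jit(lambda x: x + 1, donate_argnums=0)

x = jnp.zeros([1024])
for _ in range(10):
  x = step(x)  # after each call the previous `x` buffer may be invalidated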

romanngg

comment created time in 9 days

PR opened google/jax

Retain original docstring when vmap'ing functions.
+16 -0

0 comment

2 changed files

pr created time in 11 days

create branch tomhennigan/jax

branch : changelist/318766285

created branch time in 11 days

PR opened google/jax

Cast int8 to bool for lax.not in jax2tf
+15 -1

0 comment

2 changed files

pr created time in 17 days

push event tomhennigan/jax

Tom Hennigan

commit sha 98e28c56441e4b4f30f400040d3d0d40f6f6510b

Cast int8 to bool for lax.not in jax2tf.

view details

push time in 17 days

create branch tomhennigan/jax

branch : changelist/317680303

created branch time in 17 days

Pull request review comment google/jax

[jax2tf] Fixed the handling of `core.unit` in control-flow primitives.

 def doit():
   def _interpret_fun(fun: lu.WrappedFun,
-                   in_vals: Sequence[TfVal]) -> Sequence[TfVal]:
+                   in_vals: Sequence[TfValOrUnit]) -> Sequence[TfValOrUnit]:
   with core.new_master(TensorFlowTrace) as master:
     fun = _interpret_subtrace(fun, master)
-    out_vals = fun.call_wrapped(*in_vals)
+    out_vals: Sequence[TfValOrUnit] = fun.call_wrapped(*in_vals)
     del master
   return out_vals
 
 @lu.transformation
-def _interpret_subtrace(master: core.MasterTrace, *in_vals: TfVal):
+def _interpret_subtrace(master: core.MasterTrace, *in_vals: TfValOrUnit):
   trace = TensorFlowTrace(master, core.cur_sublevel())
-  in_tracers = [TensorFlowTracer(trace, val) for val in in_vals]
-  outs = yield in_tracers, {}
-  out_tracers = map(trace.full_raise, outs)
-  out_vals = [t.val for t in out_tracers]
+  in_tracers = tuple(TensorFlowTracer(trace, val) for val in in_vals)
+  outs = yield in_tracers, {}  # type : Sequence[TfValOrUnit]
+  out_tracers: Sequence[TensorFlowTracer] = map(trace.full_raise, outs)  # type: ignore[assignment]

The type here is a lie since map returns a generator? Would be easier to type that instead (or not type).

gnecula

comment created time in 25 days

Pull request review comment google/jax

[jax2tf] Fixed the handling of `core.unit` in control-flow primitives.

 from tensorflow.compiler.tf2xla.python import xla as tfxla  # type: ignore[import]
 from tensorflow.compiler.xla import xla_data_pb2  # type: ignore[import]
 
-# A value suitable in a TF tracing context: tf.Tensor, tf.Var, or
-# Python scalar or numpy.ndarray.
+# A value suitable in a TF tracing context: tf.Tensor, tf.Variable, tf.EagerTensor,
+# or Python scalar or numpy.ndarray.
 TfVal = Any
 
+# During JAX transformations we sometimes produce a Jaxpr that has arguments
+# of abstract value core.abstract_unit and results equal to core.unit.
+# These are arguments and results that are not used in the computation.
+# Whenever we are in a JAX tracing context we must use `core.unit` values
+# in those places. However, when we move to TF we have to turn them into
+# some small TFVal; it does not matter which value since it will never be used
+# in an actual operation.
+TfValOrUnit = Union[TfVal, core.Unit]
+
+_unit_tfval = tf.constant(9, tf.uint8)

Perhaps f32 NaN? In case we leak this in somewhere, it would be better to propagate NaNs than to scale/shift by 9.

gnecula

comment created time in 25 days

Pull request review comment google/jax

[jax2tf] Fixed the handling of `core.unit` in control-flow primitives.

 from tensorflow.compiler.tf2xla.python import xla as tfxla  # type: ignore[import]
 from tensorflow.compiler.xla import xla_data_pb2  # type: ignore[import]
 
-# A value suitable in a TF tracing context: tf.Tensor, tf.Var, or
-# Python scalar or numpy.ndarray.
+# A value suitable in a TF tracing context: tf.Tensor, tf.Variable, tf.EagerTensor,

FYI I don't think tf.EagerTensor is a public type in TF.

gnecula

comment created time in 25 days

issue comment deepmind/rlax

Install error with jaxlib version 0.1.47

Are you on Windows? I think at the moment jaxlib is not supported there (except on CPU via WSL), which is why you can't find a distribution.

renos-zabounidis

comment created time in 25 days

pull request comment google/jax

Attach source info to Jaxpr equations.

Have you guys tried benchmarking compilation time with this turned on? @LenaMartens has been working on enabling named_call [0] for all Haiku users and we've found that adding this incurs a 2x regression in compile time (we have a benchmark internally that we can share with you if useful). We're not 100% sure where this comes from, but one thing we have in common with you is adding a bunch of additional metadata to the computation.

I'd also be keen to see what this looks like in practice for users of OO JAX libraries (Haiku/Flax et al). I suspect this will be something useful like basic.py:123 Linear:__call__ but I'm guessing 😄

[0] https://dm-haiku.readthedocs.io/en/latest/api.html#named-call

hawkinsp

comment created time in a month

PR opened google/jax

Improve error message when passing an invalid dtype.

I spotted this when debugging an issue like:

self.assertEqual(x_jax, x_tf, check_dtypes=True)

The fix here is of course to use x_tf.numpy(), but it was not clear where the error was from originally.

+8 -1

0 comment

2 changed files

pr created time in a month

create branch tomhennigan/jax

branch : changelist/315865003

created branch time in a month

issue comment google/jax

Add support for buffer donation (input/output aliasing)

FYI this was fixed for TPU in #2936 and XLA team are in the process of supporting this on GPU right now.

hawkinsp

comment created time in a month


pull request comment deepmind/dm-haiku

Adds Identity initializer

It was pytype (gain was marked as float but we pass a jnp.ndarray in one of the tests). I've fixed this as part of importing the change back into the DeepMind codebase. All the tests are passing now, and I fixed another minor linter issue (adding a newline in the docstrings).

I'll see if we can add this as a GitHub Action too; that and pylint are two tools where internally we treat warnings as errors to try to keep code quality high.

FYI we have a slightly janky setup (same as Google) where we mirror our internal version of Haiku to GitHub. This means that to merge PRs we pull the change into our internal repo, merge it there, and then a copybara robot marks the PR as merged. That's why there is a bit of a delay between approving a PR and it being merged. This one should be merged in a few hours once someone internally LGTMs my proposal to import it 😄

joaogui1

comment created time in a month

created tag deepmind/dm-haiku

tag v0.0.1

JAX-based neural network library

created time in a month

release deepmind/dm-haiku

v0.0.1

released time in a month

delete tag deepmind/dm-haiku

delete tag : v0.0.1

delete time in a month

created tag deepmind/dm-haiku

tag v0.0.1

JAX-based neural network library

created time in a month

release deepmind/dm-haiku

v0.0.1

released time in a month

pull request comment deepmind/dm-haiku

Adds Identity initializer

Sorry for the delay @joaogui1, I've been less productive than usual the last few weeks. Looks good!

joaogui1

comment created time in a month

create branch deepmind/dm-haiku

branch : pr/40

created branch time in a month

issue closed deepmind/sonnet

bug: Batch Normalization moving average of variance is initialized as 0

The moving average of the variance in Batch Normalization is initialized as 0 here and here. As a result, when I directly evaluate a randomly initialized network without training it first, the results are always NaN.

closed time in a month

pluskid

issue comment deepmind/sonnet

bug: Batch Normalization moving average of variance is initialized as 0

Hey @pluskid, I suspect this is only the case if you are setting eps=0; we (in line with the paper) include a small epsilon (default: 1e-5) to avoid division by zero (https://sonnet.readthedocs.io/en/latest/api.html#batchnorm), and indeed for a randomly initialized BatchNorm layer I do not get NaN for any combination of is_training and test_local_stats:

x = tf.random.normal([4, 4])
for is_training in (True, False):
  for test_local_stats in (True, False):
    bn = snt.BatchNorm(True, True)
    y = bn(x, is_training=is_training, test_local_stats=test_local_stats)
    assert not tf.reduce_any(tf.math.is_nan(y))

If you want to keep eps=0 and avoid division by zero you could pass test_local_stats=True to your BatchNorm layers to use the mean/variance of the current batch rather than the average.

pluskid

comment created time in a month


pull request comment google/jax

Refactoring of jax_to_tf tests:

Overall looks good, but as a metapoint I personally much prefer keeping a 1:1 association between unit tests and source files. For example control_flow_test.py should pair with control_flow.py. I don't think that is the style in JAX core, but is there a developer guide which articulates how code is structured in JAX?

gnecula

comment created time in a month

Pull request review comment google/jax

Refactoring of jax_to_tf tests:

+# Copyright 2020 Google LLC+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     https://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.++from typing import Any, Callable, Sequence, Tuple+import tensorflow as tf  # type: ignore[import]++from jax.experimental import jax_to_tf+from jax import test_util as jtu++class JaxToTfTestCase(jtu.JaxTestCase):++  def assertAllClose(self, x, y, *, atol=None, rtol=None,+                     check_dtypes=False, canonicalize_dtypes=False):+    """Compares recursively list/tupls of arrays."""+    # TODO: turn on check_dtypes (get error comparing TF dtype and JAX dtype)

I'd suggest asserting that the kwargs here have their default value as well as the comment.

gnecula

comment created time in a month

Pull request review comment google/jax

Refactoring of jax_to_tf tests:

+# Copyright 2020 Google LLC+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     https://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.+"""Tests for the jax_to_tf conversion for control-flow primitives."""++from absl.testing import absltest+from absl.testing import parameterized+from typing import Any, Callable, Sequence, Tuple++import jax+import jax.lax as lax+import jax.numpy as jnp+from jax import test_util as jtu+import numpy as np++from jax.experimental.jax_to_tf.tests import tf_test_util++from jax.config import config+config.parse_flags_with_absl()+++class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):++  @parameterized.named_parameters(jtu.cases_from_list(+    dict(testcase_name=f"_function={with_function}",+         with_function=with_function)+    for with_function in [False, True]))+  def test_cond(self, with_function=False):+    def f_jax(pred, x):+      return lax.cond(pred, lambda t: t + 1., lambda f: f, x)++    self.ConvertAndCompare(f_jax, True, 1., with_function=with_function)+    self.ConvertAndCompare(f_jax, False, 1., with_function=with_function)++  @parameterized.named_parameters(jtu.cases_from_list(+    dict(testcase_name=f"_function={with_function}",+         with_function=with_function)+    for with_function in [False, True]))+  def test_cond_multiple_results(self, with_function=False):+    def f_jax(pred, x):+      return lax.cond(pred, lambda t: (t + 1., 1.), lambda f: (f + 2., 2.), x)++    self.ConvertAndCompare(f_jax, True, 1., with_function=with_function)+    self.ConvertAndCompare(f_jax, False, 1., with_function=with_function)++  @parameterized.named_parameters(jtu.cases_from_list(+    dict(testcase_name=f"_function={with_function}",+         with_function=with_function)+    for with_function in [False, True]))+  def test_while_single_carry(self, with_function=False):+    """A while with a single carry"""+    def func(x):+      # Equivalent to:+      #      for(i=x; i < 4; i++);+      return lax.while_loop(lambda c: c < 4, lambda c: c + 1, x)++    self.ConvertAndCompare(func, 0, with_function=with_function)++  @parameterized.named_parameters(jtu.cases_from_list(+    dict(testcase_name=f"_function={with_function}",+         with_function=with_function)+    for with_function in [False, True]))+  def test_while(self, with_function=False):+    # Some constants to capture in the conditional branches+    cond_const = np.ones(3, dtype=np.float32)+    body_const1 = np.full_like(cond_const, 1.)+    body_const2 = np.full_like(cond_const, 2.)++    def func(x):+      # Equivalent to:+      #      c = [1, 1, 1]+      #      for(i=0; i < 3; i++)+      #        c += [1, 1, 1] + [2, 2, 2]+      #+      # The function is set-up so that it captures constants in the+      # body of the functionals. 
This covers some cases in the representation+      # of the lax.while primitive.+      def cond(idx_carry):+        i, c = idx_carry+        return i < jnp.sum(lax.tie_in(i, cond_const))  # Capture cond_const++      def body(idx_carry):+        i, c = idx_carry+        return (i + 1, c + body_const1 + body_const2)++      return lax.while_loop(cond, body, (0, x))++    self.ConvertAndCompare(func, cond_const, with_function=with_function)+++  @parameterized.named_parameters(jtu.cases_from_list(+    dict(testcase_name=f"_function={with_function}",+         with_function=with_function)+    for with_function in [False, True]))+  def test_while_batched(self, with_function=True):+    """A while with a single carry"""+    def product(x, y):+      # Equivalent to "x * y" implemented as:+      #      res = 0.+      #      for(i=0; i < y; i++)+      #         res += x+      return lax.while_loop(lambda idx_carry: idx_carry[0] < y,+                            lambda idx_carry: (idx_carry[0] + 1,+                                               idx_carry[1] + x),+                            (0, 0.))++    # We use vmap to compute result[i, j] = i * j+    xs = np.arange(4, dtype=np.int32)+    ys = np.arange(5, dtype=np.int32)++    def product_xs_y(xs, y):+      return jax.vmap(product, in_axes=(0, None))(xs, y)+    def product_xs_ys(xs, ys):+      return jax.vmap(product_xs_y, in_axes=(None, 0))(xs, ys)++    self.ConvertAndCompare(product_xs_ys, xs, ys, with_function=with_function)++  @parameterized.named_parameters(jtu.cases_from_list(+    dict(testcase_name=f"_function={with_function}",+         with_function=with_function)+    for with_function in [False, True]))+  def test_scan(self, with_function=False):+    def f_jax(xs, ys):+      body_const = np.ones((2, ), dtype=np.float32)  # Test constant capture+      def body(res0, inputs):+        x, y = inputs+        return res0 + x * y, body_const+      return lax.scan(body, 0., (xs, ys))++    arg = np.arange(10, dtype=np.float32)+    self.ConvertAndCompare(f_jax, arg, arg, with_function=with_function)+++if __name__ == "__main__":+  absltest.main()

google3 import will fail if there is not a newline at the end here.

gnecula

comment created time in a month

push event tomhennigan/jax

Tom Hennigan

commit sha 464199e49a1335cf7da09d3f114987832a94c267

Add support for buffer donation in `jit` and `pmap`. For a computation of the form: >>> f = lambda x: x ** 2 >>> f = jax.jit(f) >>> while run: ... x = f(x) JAX must currently always have two copies of `x` in device memory since there is no reliable way in Python to determine whether there will be future uses of `x`. This causes two classes of problem: 1. Users at the limit of available device are constrained by the additional copy of their parameters and other state while they typically only require one copy. This typically frees 100M+ of device memory and is a critical optimization for larger models to match state of the art performance in other frameworks. 2. This constant alloc/free of the input/output buffers can cause memory fragmentation on some platforms (although having a reusing allocator and limiting run-ahead may be a better solution for this problem). We propose fixing this by using input/output aliasing as supported by XLA. We will support this in JAX by allowing certain arguments of jit/pmap decorated functions to be donated and reused as outputs: >>> f = lambda x: x ** 2 >>> f = jit(f, donate_argnums=0) >>> while run: ... x = f(x) JAX will determine that the donated input `x` can alias with the output of the function and it will instruct XLA it _must_ write the result to this buffer. If a user tries to reuse a buffer after it has been donated they get an error that the buffer is invalid: >>> y = f(x) >>> jax.device_get(x) ... RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer. The semantics of `donate_argnums` follows that of `static_argnums`, namely that it identifies positional arguments to the computation that are to be donated to the computation and used as part of the output. One feature that is also enabled by this is invalidating buffers that should only be used once, for example PRNGKeys: >>> @partial(jit, donate_argnums=0) ... def move(x): ... # Do something complex enough for JAX to just optimize it away. ... return tree_map(lambda x: x + x - x, x) >>> def safe_eager_uniform(key, *a, **k): ... assert hasattr(key, 'device_buffer'), "random must run eagerly" ... key = move(key) ... return jax.random.uniform(key, *a, **k) This is not a complete answer to random safety since it is still possible to reuse a key as part of a traced computation, however it can be used to support this feature (somewhat inefficiently) in eager mode.

view details

push time in a month

push event tomhennigan/jax

Tom Hennigan

commit sha ee57d74156b8819d35f4067a614874c38bc0ee3e

Add support for buffer donation in `jit` and `pmap`. For a computation of the form: >>> f = lambda x: x ** 2 >>> f = jax.jit(f) >>> while run: ... x = f(x) JAX must currently always have two copies of `x` in device memory since there is no reliable way in Python to determine whether there will be future uses of `x`. This causes two classes of problem: 1. Users at the limit of available device are constrained by the additional copy of their parameters and other state while they typically only require one copy. This typically frees 100M+ of device memory and is a critical optimization for larger models to match state of the art performance in other frameworks. 2. This constant alloc/free of the input/output buffers can cause memory fragmentation on some platforms (although having a reusing allocator and limiting run-ahead may be a better solution for this problem). We propose fixing this by using input/output aliasing as supported by XLA. We will support this in JAX by allowing certain arguments of jit/pmap decorated functions to be donated and reused as outputs: >>> f = lambda x: x ** 2 >>> f = jit(f, donate_argnums=0) >>> while run: ... x = f(x) JAX will determine that the donated input `x` can alias with the output of the function and it will instruct XLA it _must_ write the result to this buffer. If a user tries to reuse a buffer after it has been donated they get an error that the buffer is invalid: >>> y = f(x) >>> jax.device_get(x) ... RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer. The semantics of `donate_argnums` follows that of `static_argnums`, namely that it identifies positional arguments to the computation that are to be donated to the computation and used as part of the output. One feature that is also enabled by this is invalidating buffers that should only be used once, for example PRNGKeys: >>> @partial(jit, donate_argnums=0) ... def move(x): ... # Do something complex enough for JAX to just optimize it away. ... return tree_map(lambda x: x + x - x, x) >>> def safe_eager_uniform(key, *a, **k): ... assert hasattr(key, 'device_buffer'), "random must run eagerly" ... key = move(key) ... return jax.random.uniform(key, *a, **k) This is not a complete answer to random safety since it is still possible to reuse a key as part of a traced computation, however it can be used to support this feature (somewhat inefficiently) in eager mode.

view details

push time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def jit(fun: Callable, static_argnums: Union[int, Iterable[int]] = (),
       XLA's DeviceAssignment logic and is usually to use ``jax.devices()[0]``.
     backend: This is an experimental feature and the API is likely to change.
       Optional, a string representing the xla backend. 'cpu','gpu', or 'tpu'.
+    donate_argnums: Specify which arguments are "donated" to the computation.
+      When arguments are donated XLA will use their memory to store the result

Ah thanks, I hadn't updated the docstring to match the implementation. Have reworded.

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def dynamic_fun(dummy, *args):
     out_tuple = xb.with_sharding(c, out_parts, build_out_tuple)
   else:
     out_tuple = build_out_tuple()
+  backend = xb.get_backend(backend)
+  if backend.platform == "tpu":
+    donated_invars = xla.setup_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)

Done. I was following c.setup_alias, I think that's a typo too, so I'll fix it in a follow-up.

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def type_transfer(prim, invars, params):
       if "mapped_invars" not in params:
         raise TypeError(
             f"Map primitive {prim} missing 'mapped_invars' parameter")
+      if "donated_invars" not in params:

Happy to do so, but we'd need to match on primitive name which doesn't feel ideal, since donation doesn't apply to all call primitives (e.g. remat does not have this argument) but it does apply to all (well, the only) map primitive.

tomhennigan

comment created time in a month

push event tomhennigan/jax

Tom Hennigan

commit sha 92cbfa5e5db57b630e1faa4e6140995fb2593d97

Add support for buffer donation in `jit` and `pmap`. For a computation of the form: >>> f = lambda x: x ** 2 >>> f = jax.jit(f) >>> while run: ... x = f(x) JAX must currently always have two copies of `x` in device memory since there is no reliable way in Python to determine whether there will be future uses of `x`. This causes two classes of problem: 1. Users at the limit of available device are constrained by the additional copy of their parameters and other state while they typically only require one copy. This typically frees 100M+ of device memory and is a critical optimization for larger models to match state of the art performance in other frameworks. 2. This constant alloc/free of the input/output buffers can cause memory fragmentation on some platforms (although having a reusing allocator and limiting run-ahead may be a better solution for this problem). We propose fixing this by using input/output aliasing as supported by XLA. We will support this in JAX by allowing certain arguments of jit/pmap decorated functions to be donated and reused as outputs: >>> f = lambda x: x ** 2 >>> f = jit(f, donate_argnums=0) >>> while run: ... x = f(x) JAX will determine that the donated input `x` can alias with the output of the function and it will instruct XLA it _must_ write the result to this buffer. If a user tries to reuse a buffer after it has been donated they get an error that the buffer is invalid: >>> y = f(x) >>> jax.device_get(x) ... RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer. The semantics of `donate_argnums` follows that of `static_argnums`, namely that it identifies positional arguments to the computation that are to be donated to the computation and used as part of the output. One feature that is also enabled by this is invalidating buffers that should only be used once, for example PRNGKeys: >>> @partial(jit, donate_argnums=0) ... def move(x): ... # Do something complex enough for JAX to just optimize it away. ... return tree_map(lambda x: x + x - x, x) >>> def safe_eager_uniform(key, *a, **k): ... assert hasattr(key, 'device_buffer'), "random must run eagerly" ... key = move(key) ... return jax.random.uniform(key, *a, **k) This is not a complete answer to random safety since it is still possible to reuse a key as part of a traced computation, however it can be used to support this feature (somewhat inefficiently) in eager mode.

view details

push time in a month

pull request comment google/jax

Add support for buffer donation in `jit` and `pmap`.

FYI two known issues that I think we can address in a follow up:

  1. If you try and donate the same array twice in the same computation (e.g. jit(f, donate_argnums=(0, 1))(x, x)) you get a deadlock.
  2. Buffers that are donated but not used are currently not marked as "deleted" meaning that you can reuse them (this is not a correctness issue, but might lead to confusing "this buffer is deleted" errors if you move your code to TPU where input/output aliasing is supported). In a follow up we should delete these buffers at call time.
tomhennigan

comment created time in a month

pull request comment google/jax

Add support for buffer donation in `jit` and `pmap`.

Thanks for the comments everyone, this PR is ready for your review again 😄 I've rebased on master and made some major changes based on the design review:

  1. We now interact correctly with JVPTrace, following the guidance of mapped_invars (thanks for the help @mattjj!!).
  2. Donation is now optional, on CPU/GPU we simply log a warning at "compile" time that the buffers aren't going to be used and do nothing else.
  3. Function arguments and type hints are now consistent with the rest of JAX.

PTAL!

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
   else:
     return partial(_execute_replicated, compiled, result_handlers)
 
+def configure_aliasing(
+    c: xc.XlaBuilder,
+    xla_args: Sequence[xe.XlaOp],
+    out_tuple: xe.XlaOp,
+    arg_donate: Sequence[bool],
+    tuple_args: bool,
+):
+  """Configures input/output "must" aliasing based on `arg_donate`."""
+  assert len(xla_args) == len(arg_donate)
+
+  # First for every input array add it to `donations` iff it is a member of
+  # `arg_donate`.
+  donations = defaultdict(deque)

I think this method is pretty clear now; adding types would probably complicate the code a fair bit and not add much value.

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def f_pmapped(*args, **kwargs):
         global_axis_size=axis_size,
         devices=tuple(devices) if devices is not None else devices,
         name=flat_fun.__name__,
-        mapped_invars=tuple(axis is not None for axis in in_axes_flat))
+        mapped_invars=tuple(axis is not None for axis in in_axes_flat),
+        donate=donate)

Have kept this as donate_argnums on the API and donate_invars on the primitive.

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
   else:
     return partial(_execute_replicated, compiled, result_handlers)
 
+def configure_aliasing(
+    c: xc.XlaBuilder,
+    xla_args: Sequence[xe.XlaOp],
+    out_tuple: xe.XlaOp,
+    arg_donate: Sequence[bool],
+    tuple_args: bool,
+):
+  """Configures input/output "must" aliasing based on `arg_donate`."""
+  assert len(xla_args) == len(arg_donate)
+
+  # First for every input array add it to `donations` iff it is a member of
+  # `arg_donate`.
+  donations = defaultdict(deque)
+  for arg_index, arg in enumerate(xla_args):

I'm going to assume the current tests are fine and close this out, but please re-open if you feel strongly.

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
   else:
     return partial(_execute_replicated, compiled, result_handlers)
 
+def configure_aliasing(
+    c: xc.XlaBuilder,
+    xla_args: Sequence[xe.XlaOp],
+    out_tuple: xe.XlaOp,
+    arg_donate: Sequence[bool],
+    tuple_args: bool,
+):
+  """Configures input/output "must" aliasing based on `arg_donate`."""
+  assert len(xla_args) == len(arg_donate)
+
+  # First for every input array add it to `donations` iff it is a member of
+  # `arg_donate`.
+  donations = defaultdict(deque)
+  for arg_index, arg in enumerate(xla_args):

We test the result of this (e.g. on supported platforms donated input buffers are deleted and reused for outputs; if there is no match we raise a warning), but I haven't unit tested the method directly; it feels like an implementation detail to me. Do you feel strongly that this should be tested in isolation? To do it well I would need to refactor the method a bit so it doesn't require us to build up a computation.

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):   else:     return partial(_execute_replicated, compiled, result_handlers) +def configure_aliasing(+    c: xc.XlaBuilder,+    xla_args: Sequence[xe.XlaOp],+    out_tuple: xe.XlaOp,+    arg_donate: Sequence[bool],+    tuple_args: bool,+):+  """Configures input/output "must" aliasing based on `arg_donate`."""+  assert len(xla_args) == len(arg_donate)++  # First for every input array add it to `donations` iff it is a member of+  # `arg_donate`.+  donations = defaultdict(deque)+  for arg_index, arg in enumerate(xla_args):+    if arg_donate[arg_index]:+      for param_index, element in flatten_shape(c.GetShape(arg)):+        key = (element.dimensions(), element.numpy_dtype())+        if tuple_args:+          param_number = 0+          param_index = (arg_index,) + tuple(param_index)+          donations[key].append((param_number, param_index, arg_index))+        else:+          param_number = arg_index+          donations[key].append((param_number, param_index, arg_index))++  # Consume donations for outputs.+  for output_index, element in flatten_shape(c.GetShape(out_tuple)):+    key = (element.dimensions(), element.numpy_dtype())+    if donations.get(key, ()):+      param_number, param_index, _ = donations[key].popleft()+      c.SetUpAlias(output_index, param_number, param_index)++  # Check that all donations have been consumed such that buffer donation+  # "must alias".+  leftovers = tree_leaves([tuple(v) for v in donations.values()])+  if leftovers:+    leftovers = "\n".join([+        f"- flat_args[{arg_index}] (spec: {dtype.name}{list(dims)})"+        for (dims, dtype), donation in donations.items()+        for _, _, arg_index in donation])++    raise ValueError(

Per the design review I think we're good with "donation" being "use if you like" (allowing us to not use them at all on CPU/GPU today and still be "in contract").

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def jaxpr_collectives(jaxpr):  ### xla_call underlying jit -def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name):-  compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))+def _xla_call_impl(+    fun: lu.WrappedFun,+    *args,+    device: Optional[xe.Device],+    backend,+    name: str,+    donate: Sequence[bool],+):+  if len(args) != len(donate):+    if any(donate):+      raise ValueError("Taking the JVP of a jit/pmap decorated function which "+                       "donates some arguments is not supported. Maybe try "+                       "`jit(partial(jvp, f), donate_argnums=...)((x,), (x,))` "+                       "instead of "+                       "`jvp(jit(f, donate_argnums=...), (x,), (x,))`")+    donate = (False,) * len(args)++  arg_specs = safe_map(arg_spec, args, donate)+  compiled_fun = _xla_callable(fun, device, backend, name, *arg_specs)   try:     return compiled_fun(*args)   except FloatingPointError:     print("Invalid value encountered in the output of a jit function. "           "Calling the de-optimized version.")     return fun.call_wrapped(*args)  # probably won't return +def flatten_shape(s: xe.Shape) -> Sequence[Tuple[Sequence[int], xe.Shape]]:

I think it's good to have this function to handle that case; my preference would be to keep the general implementation since it is not all that complex. I guess we could shave a few lines off if we made this switch on tuple_args. Please feel free to reopen if you feel strongly.

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def jaxpr_collectives(jaxpr):
 
 ### xla_call underlying jit
 
-def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name):
-  compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
+def _xla_call_impl(
+    fun: lu.WrappedFun,
+    *args,
+    device: Optional[xe.Device],
+    backend,
+    name: str,
+    donate: Sequence[bool],
+):
+  if len(args) != len(donate):

With @mattjj's help this is now resolved properly 😄

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def execute_replicated(compiled, backend, in_handler, out_handler, *args):
 
 xla_pmap_p.def_custom_bind(xla_pmap)
 xla_pmap_p.def_impl(xla_pmap_impl)
 
-def _pmap_translation_rule(c, axis_env,
-                           in_nodes, name_stack, axis_name, axis_size,
-                           global_axis_size, devices, name,
-                           call_jaxpr, *, backend=None, mapped_invars):
+def _pmap_translation_rule(

Heh I have a different opinion 😄 . Have reverted this (and all other signatures) to the current "condensed and lightly typed" style matching the rest of the file.

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def argnums_partial(f, dyn_argnums, args):
   dyn_args = tuple(args[i] for i in dyn_argnums)
   return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args
 
+def donation_vector(donate_argnums, args) -> Tuple[bool, ...]:
+  """Returns a tuple with a boolean value for each leaf in args."""
+  res = []
+  for i, arg in enumerate(args):
+    donate = bool(i in donate_argnums)
+    res.extend((donate,) * tree_structure(arg).num_leaves)
+  return tuple(res)
+
+def rebase_donate_argnums(donate_argnums: Iterable[int],
+                          static_argnums: Iterable[int]):
+  """Rebases donate_argnums on static_argnums."""

Have added a comment.

tomhennigan

comment created time in a month

Pull request review comment google/jax

Add support for buffer donation in `jit` and `pmap`.

 def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):   else:     return partial(_execute_replicated, compiled, result_handlers) +def configure_aliasing(+    c: xc.XlaBuilder,+    xla_args: Sequence[xe.XlaOp],+    out_tuple: xe.XlaOp,+    arg_donate: Sequence[bool],+    tuple_args: bool,+):+  """Configures input/output "must" aliasing based on `arg_donate`."""+  assert len(xla_args) == len(arg_donate)++  # First for every input array add it to `donations` iff it is a member of+  # `arg_donate`.+  donations = defaultdict(deque)+  for arg_index, arg in enumerate(xla_args):+    if arg_donate[arg_index]:+      for param_index, element in flatten_shape(c.GetShape(arg)):+        key = (element.dimensions(), element.numpy_dtype())+        if tuple_args:+          param_number = 0+          param_index = (arg_index,) + tuple(param_index)+          donations[key].append((param_number, param_index, arg_index))+        else:+          param_number = arg_index+          donations[key].append((param_number, param_index, arg_index))++  # Consume donations for outputs.+  for output_index, element in flatten_shape(c.GetShape(out_tuple)):+    key = (element.dimensions(), element.numpy_dtype())+    if donations.get(key, ()):+      param_number, param_index, _ = donations[key].popleft()+      c.SetUpAlias(output_index, param_number, param_index)

As per discussion in design review we'll leave this API as is for now (just specifying what buffers you are letting JAX reason about and not telling us specifically what to do with them).

tomhennigan

comment created time in a month

Pull request review commentgoogle/jax

Add support for buffer donation in `jit` and `pmap`.

 def jit(fun: Callable, static_argnums: Union[int, Iterable[int]] = (),       XLA's DeviceAssignment logic and is usually to use ``jax.devices()[0]``.     backend: This is an experimental feature and the API is likely to change.       Optional, a string representing the xla backend. 'cpu','gpu', or 'tpu'.+    donate_argnums: Specify which arguments are "donated" to the computation.

Per the discussion in the design review I've made consumption of donated buffers optional, and as of this PR all we'll do is set up "must aliases" on TPU iff we find a match between the input and output buffers.
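For illustration, a minimal usage sketch (the function and array names are invented; on CPU/GPU, where input/output aliasing is not supported, the donation currently has no effect):

import functools

import jax
import jax.numpy as jnp

# Donating `params` lets XLA reuse its buffer for the output, which has the
# same shape/dtype, avoiding a copy for this in-place style update.
@functools.partial(jax.jit, donate_argnums=(0,))
def update(params, grads):
  return params - 0.1 * grads

params = jnp.ones((1024,))
grads = jnp.ones((1024,))
params = update(params, grads)  # The old `params` buffer should no longer be used.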

tomhennigan

comment created time in a month

Pull request review commentgoogle/jax

Add support for buffer donation in `jit` and `pmap`.

 def vjp(x_tangent):     b = np.dot(a + np.eye(a.shape[0]), real_x)     print(gf(a, b))  # doesn't crash +class BufferDonationTest(jtu.JaxTestCase):++  @jtu.skip_on_devices("cpu", "gpu")  # Buffer donation only supported on TPU.+  def test_jit_donate_argnums_invalidates_input(self):+    # We can't just use `lambda x: x` because JAX simplifies this away.+    move = jit(lambda x: x + x - x, donate_argnums=0)+    x = np.ones([])+    y = move(x)+    self.assertDeleted(x)+    self.assertEqual(y, 1.)++  @jtu.skip_on_devices("cpu", "gpu")  # Buffer donation only supported on TPU.+  def test_jit_donate_argnums_must_alias(self):+    f = jit(lambda x: np.expand_dims(x, 0), donate_argnums=0)+    with self.assertRaisesRegex(+        ValueError,+        r"donated but not used.*\n.*flat_args\[0\] \(spec: float32\[\]\)"):+      f(np.ones([]))++  @jtu.skip_on_devices("cpu", "gpu")  # Buffer donation only supported on TPU.+  def test_jit_donate_argnums_static_argnums(self):+    jit_fun = jit(lambda a, b, c, d: ((a + b + c), (a + b + d)),+                  static_argnums=(0, 1), donate_argnums=(2, 3))++    a = np.array(1)+    b = np.array(2)+    c = jax.device_put(np.array([1., 1.]))+    d = jax.device_put(np.array([1., 1., 1.]))+    e, f = jit_fun(a, b, c, d)+    onp.testing.assert_allclose(e, np.array([4., 4.]))+    onp.testing.assert_allclose(f, np.array([4., 4., 4.]))+    self.assertNotDeleted(a)+    self.assertNotDeleted(b)+    self.assertDeleted(c)+    self.assertDeleted(d)++  def test_jit_nested_donate_ignored(self):+    jit_fun = jit(lambda x: jit(lambda y: y ** 2, donate_argnums=0)(x))+    a = jax.device_put(np.array(1))+    with self.assertRaisesRegex(ValueError, "nested.*not supported"):+      jit_fun(a)++  def test_jvp_of_jit_donate_not_supported(self):+    x = np.array(1.)+    with self.assertRaisesRegex(ValueError,+                                "JVP of.* jit.* which donates.* not supported"):+      api.jvp(jit(lambda x: x ** 2, donate_argnums=0), (x,), (x,))++  # === pmap ===++  @jtu.skip_on_devices("cpu", "gpu")  # Buffer donation only supported on TPU.+  def test_pmap_donate_argnums_invalidates_input(self):+    # We can't just use `lambda x: x` because JAX simplifies this away.+    move = pmap(lambda x: x + x - x, donate_argnums=0)+    n = jax.local_device_count()+    x = pmap(lambda x: x)(np.ones([n]))+    y = move(x)+    self.assertDeleted(x)+    onp.testing.assert_allclose(y, [1.] * n)++  @jtu.skip_on_devices("cpu", "gpu")  # Buffer donation only supported on TPU.+  def test_pmap_donate_argnums_must_alias(self):+    f = pmap(lambda x: np.expand_dims(x, 0), donate_argnums=0)+    with self.assertRaisesRegex(+        ValueError,+        r"donated but not used.*\n.*flat_args\[0\] \(spec: float32\[\]\)"):+      f(np.ones([1]))++  def test_pmap_nested_donate_raises(self):+    pmap_fun = jit(lambda x: pmap(lambda y: y ** 2, donate_argnums=0)(x))+    a = pmap(lambda x: x)(np.array([1]))+    with self.assertRaisesRegex(ValueError, "nested.*not supported"):+      pmap_fun(a)++  assertDeleted = lambda self, x: self._assertDeleted(x, True)+  assertNotDeleted = lambda self, x: self._assertDeleted(x, False)++  def _assertDeleted(self, x, deleted):+    if hasattr(x, "device_buffer"):+      self.assertEqual(x.device_buffer.is_deleted(), deleted)

When a buffer is donated and used as part of a computation it is marked as deleted, and further uses will raise an error. We discussed leaving a tombstone on the buffer object to give better errors; I'll leave this for a follow-up.
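A small sketch of that behaviour, based on the tests above (the exact error type and message are backend/version dependent, so treat this as illustrative):

import jax
import jax.numpy as jnp

# `x + x - x` is used (as in the tests) because `lambda x: x` would be
# simplified away and the donated buffer would never be consumed.
move = jax.jit(lambda x: x + x - x, donate_argnums=0)

x = jnp.ones([])
y = move(x)  # On TPU the output reuses x's buffer and x is marked deleted.
# x.device_buffer.is_deleted() is now True; any further use of x (e.g. `x + 1`)
# raises an error rather than silently reading a freed buffer.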

tomhennigan

comment created time in a month

Pull request review commentgoogle/jax

Add support for buffer donation in `jit` and `pmap`.

 def _get_device(device, backend): xla_call_p.def_impl(_xla_call_impl) pe.staged_out_calls.add(xla_call_p) -def _xla_call_translation_rule(c, axis_env,-                               in_nodes, name_stack, backend, name,-                               call_jaxpr, device=None):+def _xla_call_translation_rule(+    c: xc.XlaBuilder,+    axis_env,+    in_nodes,+    name_stack: str,+    backend,+    name: str,+    call_jaxpr,+    device: Optional[xe.Device],+    donate: Sequence[bool],+):   del device  # Ignored.+  if not all(not x for x in donate):+    raise ValueError("Donating buffers passed to a jit nested inside a jit or "+                     "pmap is not supported.")+   subc = xb.make_computation_builder(f"jit_{name}")   args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]   out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),                             extend_name_stack(name_stack, wrap_name(name, 'jit')), *args)   subc = subc.Build(xops.Tuple(subc, out_nodes))   return xops.Call(c, subc, list(in_nodes))-ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)++def _xla_call_transpose_rule(params, call_jaxpr, args, ct):+  params["donate"] = (False,) * len(args)

See above comments wrt JVPTrace, we are now handling this correctly where inputs to call primitives are changed and I've removed this.

tomhennigan

comment created time in a month

Pull request review commentgoogle/jax

Add support for buffer donation in `jit` and `pmap`.

 def jaxpr_collectives(jaxpr):  ### xla_call underlying jit -def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name):-  compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))+def _xla_call_impl(+    fun: lu.WrappedFun,+    *args,+    device: Optional[xe.Device],

I've dropped these hints for now based on other comments re arg formatting.

tomhennigan

comment created time in a month

Pull request review commentgoogle/jax

Add support for buffer donation in `jit` and `pmap`.

 def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):  ### the xla_pmap primitive and its rules are comparable to xla_call in xla.py +def _arg_spec(val, donate):+  aval = xla.abstractify(val)+  return aval, donate+ def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size, global_axis_size,-                  devices, name, mapped_invars):-  abstract_args = map(xla.abstractify, args)+                  devices, name, mapped_invars, donate):+  donate += (False,) * (len(args) - len(donate))  # AD changes args not donate.

Have fixed this properly with @mattjj's help, we now correctly handle this following mapped_invars (extending to call as well as map).

tomhennigan

comment created time in a month

Pull request review commentgoogle/jax

Add support for buffer donation in `jit` and `pmap`.

 def argnums_partial(f, dyn_argnums, args):   dyn_args = tuple(args[i] for i in dyn_argnums)   return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args +def donation_vector(donate_argnums, args) -> Tuple[bool, ...]:+  """Returns a tuple with a boolean value for each leaf in args."""+  res = []+  for i, arg in enumerate(args):+    donate = bool(i in donate_argnums)+    res.extend((donate,) * tree_structure(arg).num_leaves)+  return tuple(res)++def rebase_donate_argnums(donate_argnums: Iterable[int],+                          static_argnums: Iterable[int]):+  """Rebases donate_argnums on static_argnums."""+  if not static_argnums:+    return donate_argnums++  if not all(a not in static_argnums for a in donate_argnums):+    raise ValueError(f"`static_argnums` {static_argnums} and "+                     f"`donate_argnums` {donate_argnums} cannot intersect.")++  def rebase(a):+    n = 0+    for i in static_argnums:

Fixed.

tomhennigan

comment created time in a month

push eventtomhennigan/jax

Jake VanderPlas

commit sha d8d71407dc8c977185fe26b3cdf309b2dae6546f

Deprecate random.shuffle() and implement random.permutation() for multi-dimensional matrices.

view details

Jake Vanderplas

commit sha 6425ca2aedfb237b13cc9323302033dae6ff72c9

Merge pull request #2925 from jakevdp/shuffle Deprecate random.shuffle() and implement random.permutation() for multi-dim inputs

view details

Peter Hawkins

commit sha ee38e1b3b6874355bc13dd3d5a0ec7682734a212

Update XLA. (#2929) Includes a fix that may help with issue #2906.

view details

Stephan Hoyer

commit sha 46ce80b03212dfff86624e341d8a2b59ac474482

jax.random.poisson (#2805)

* jax.random.poisson

  The implementation for lam < 10 was directly copied from TensorFlow probability: https://github.com/tensorflow/probability/blob/v0.10.0-rc0/tensorflow_probability/python/internal/backend/numpy/random_generators.py#L155

  I adapted the implementation for lam > 10 from TensorFlow: https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc

  The methods themselves match both TensorFlow and NumPy: https://github.com/numpy/numpy/blob/v1.18.3/numpy/random/src/distributions/distributions.c#L574

* add a check for even larger lambda
* increment iter count
* remove comment that makes no sense
* Fix chi-squared tests in random_test.py

  As far as I can tell, the previous implementation of the chi-squared test for samples from discrete probability distributions was broken. It should have been asserting that the p-value was greater than 0.01, e.g., as illustrated here: http://hamelg.blogspot.com/2015/11/python-for-data-analysis-part-25-chi.html

  This hid a few other bugs, such as a miscalculation of expected frequencies. Fortunately, the existing random tests for Bernoulli and Categorical *mostly* still pass, with the exception of multi-dimensional logits for Categorical. Those tests are disabled by this PR.

* Fix accept condition (based on correct chi-squared test)
* Add moment checks for Poisson
* Add batching test, more Poisson rates

view details

Peter Hawkins

commit sha a18257860a78c63ad4c5d460b206b62872e371f9

Update XLA. (#2932) Mention illegal instruction fix in changelog.

view details

Matthew Johnson

commit sha 64f12a42463f90bdd35e5c969028a0e539e253e4

improve docs and error message for odeint *args (#2931) cf. #2920

view details

Matthew Johnson

commit sha 9f7115ece2dcc0ef324e736bd5f15607fb9c15ee

reduce use of lax on static data (e.g. shapes) (#2933) * reduce use of lax on static data (e.g. shapes) * use f-string for error message

view details

James Bradbury

commit sha 1cc6b7dd6cf9e14cbc0df5730605df96761bfca7

support axis argument in nn.glu (#2879) * support axis argument in nn.glu * also add basic correctness test * Update nn_test.py

view details

George Necula

commit sha d315564ebf9097625d23dac7ead018253cfb3bee

Fixed a few more places where device commitment was lost. (#2913)

* trivial jit computations were forcing commitment to the default device
* a device_put with a device specification would not set the commitment if the data was already (uncommitted) on the specified device.
* added tests for the above
* once the above were fixed, LaztTest.test_zeros_ones_compilation started to fail because the `sticky` parameter to lazy_force_computation was changing. Fixed this by removing stickiness from the compilation key.
* Expanded docstring for jax.device_put; expanded the device placement FAQ entry.

view details

Roman Ring

commit sha 525235d8c976d60ffecf4a017b6b528565a9cc7b

Fix a codeblock in the "understanding jaxpr" doc. (#2942) This fixes an issue where the codeblock didn't render properly on the website.

view details

Peter Hawkins

commit sha 4d236b5c47294b39e3b494554e23ad2d686b8fe6

Update XLA to fix build failures. (#2950)

view details

tamaranorman

commit sha 04102e5b9d98f2f79d3f4b5c1b9bb21216ff4c3e

Allow ConvDimensionNumbers to be passed into conv_transpose (#2915)

view details

Peter Hawkins

commit sha d61d6f44dc2f0155238d4ed0cac25f1f8f631661

Fix a number of flaky tests. (#2953) * relax some test tolerances. * disable 'random' preconditioner in CG test (#2951). * ensure that scatter and top-k tests don't create ties.

view details

Peter Hawkins

commit sha 72efa783ab889acd69a7163d050f491fa0421f5e

Fix spurious rank promotion warning. (#2954)

view details

Stephan Hoyer

commit sha 5a0bf46234481887d18f1c8623c8a78d4a2a842e

DOC: add a table of contents for top level API docs (#2946) This makes them easier to scan.

view details

Stephan Hoyer

commit sha 6aab9e5f50073ee5264bdacf05fb58472d80cf2c

DOC: write a new dosctring for jax.numpy.vectorize (#2944) * DOC: write a new dosctring for jax.numpy.vectorize This version is customized entirely for JAX. * review and typo fixes

view details

joschkabraun

commit sha c9c653aaf089f3e135093f8322a44ec286478e0c

Implementation numpy.ediff1d (#2729)

* Implementation of numpy.ediff1d
* Added testing for numpy.ediff1d implementation
* Made ediff1d jit-compatible
* Implemented corrections: style and more testing
* Adapted tests
* changed tests
* modified tests
* Incorporated changes
* Style changes
* Added line between tests
* Changed op_record test

view details

Tom Hennigan

commit sha 4c2c5ad5f4ce10b160770c9625f5c794850f06e2

Add a note about jax.pmap when leading dim is smaller than num devices. (#2949)

view details

Peter Hawkins

commit sha 91746842532a3152538db2b8f2eff0363789b524

Cache test_utils.format_shape_and_dtype_string. (#2959) A significant fraction of time when collecting test cases is spent building shape and dtype strings (which are usually similar and usually thrown away.)

view details

yurodiviy

commit sha 3e522373a00124eec8c996bc1ed89937095e3d86

Raise an error in np.var when array is complex and dtype is not (#2288) Co-authored-by: vlad <veryfakemail@ya.ru>

view details

push time in a month

Pull request review commentgoogle/jax

Initial import of jax2tf into JAX core

+# Copyright 2020 Google LLC+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     https://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.++import functools+from absl.testing import absltest+import jax+from jax import test_util as jtu+import numpy as np++from jax.experimental import tf_bridge++from jax.experimental import stax+from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,  # pylint: disable=g-multiple-import+                                   FanOut, Flatten, GeneralConv, Identity,+                                   MaxPool, Relu, LogSoftmax)++from jax.config import config+config.parse_flags_with_absl()+++def ConvBlock(kernel_size, filters, strides=(2, 2)):  # pylint: disable=invalid-name+  ks = kernel_size+  filters1, filters2, filters3 = filters+  Main = stax.serial(  # pylint: disable=invalid-name+      Conv(filters1, (1, 1), strides), BatchNorm(), Relu,+      Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,+      Conv(filters3, (1, 1)), BatchNorm())+  Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())  # pylint: disable=invalid-name+  return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu)+++def IdentityBlock(kernel_size, filters):  # pylint: disable=invalid-name+  ks = kernel_size+  filters1, filters2 = filters+  def make_main(input_shape):+    # the number of output channels depends on the number of input channels+    return stax.serial(+        Conv(filters1, (1, 1)), BatchNorm(), Relu,+        Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,+        Conv(input_shape[3], (1, 1)), BatchNorm())+  Main = stax.shape_dependent(make_main)  # pylint: disable=invalid-name+  return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)+++def ResNet50(num_classes):  # pylint: disable=invalid-name

Given that we're inside JAX core now we could just from jax.examples import resnet50 and replace this all with resnet50.ResNet50(..)?

gnecula

comment created time in a month

issue commentgoogle/jax

ResNet50 example should not have biases for Conv modules

BTW I think folks are also very likely to fork this code given its prominent location in the examples dir, so I think we should be clearer about:

  1. How confident are we in the implementation? Has it been shown to reproduce any published experiment?
  2. Is this ResNetV1 or V2?
  3. What statistics are used during test time (a: "local batch, same as training").
tomhennigan

comment created time in a month

issue commentgoogle/jax

Automatically treat dataclasses as pytrees

For context, in the various other tree libraries (tf.nest and dm-tree) we have pushed back on dataclasses automatically being treated as nests, because for these "struct" types it is not clear that treating them as containers is intended in all cases. These libraries treat namedtuple/typing.NamedTuple as trees because it has a weird duality of being a structure (supporting named attribute access) and an iterable (supporting x.__iter__()). attr.s (basically dataclasses) slipped in as a historical accident (and because these APIs don't have a way for users to register custom types).

Instead of treating all dataclasses as pytrees, could we instead create a drop-in replacement for dataclass for users who know they want this behavior? Here's an example implementation which is basically a fork of flax.struct:

from dataclasses import dataclass
from typing import Any, Type, TypeVar
import jax
import jax.numpy as jnp

T = TypeVar("T")

def jax_tree(cls: T) -> T:
  is_data = lambda x: isinstance(x, jnp.ndarray) or hasattr(x, '__jax_dataclass')

  def flatten_fun(obj):
    meta = {}
    data = {}
    for k, v in obj.__dict__.items():
      if isinstance(v, list):  # We can add other containers here.
        are_data = list(map(is_data, v))
        assert all(are_data) or not any(are_data)
        data[k] = v
      elif is_data(v):
        data[k] = v
      else:
        meta[k] = v
    meta['__data_keys'] = list(data.keys())
    data = list(data.values())
    return tuple(data), tuple(meta.items())

  def unflatten_fun(meta, data):
    meta = dict(meta)
    data = dict(zip(meta.pop('__data_keys'), data))
    return cls(**meta, **data)

  jax.tree_util.register_pytree_node(cls, flatten_fun, unflatten_fun)

  cls.__jax_dataclass = True
  return dataclass(cls)

jax.tree = jax_tree

@jax.tree
class Bar:
  c: jnp.ndarray

@jax.tree
class Foo(object):
  a: jnp.ndarray
  b: Bar

>>> foo = Foo(jnp.ones([]), Bar(jnp.zeros([])))
>>> jax.tree_leaves(foo)
[DeviceArray(1., dtype=float32), DeviceArray(0., dtype=float32)]
shoyer

comment created time in a month

issue openedgoogle/jax

ResNet50 example should not have biases for Conv modules

The ResNet50 example has too many parameters (161 are expected but stax produces 214). This is afaik because the conv layers include biases where they should not. Per the paper author [0] the batchnorm layers should contain the biases.

The delta comes from the initial conv (1), the conv blocks (4 blocks x 4 convs each) and the identity blocks (12 blocks x 3 convs each), leading to the 53 extra parameters.
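A quick sanity check of that arithmetic (block counts assume the standard ResNet50 layout of 4 conv blocks and 12 identity blocks):

initial_conv = 1
conv_block_convs = 4 * 4       # 4 conv blocks, each: 3 main-path convs + 1 shortcut conv
identity_block_convs = 12 * 3  # 12 identity blocks, each with 3 convs
extra_biases = initial_conv + conv_block_convs + identity_block_convs
assert extra_biases == 53 == 214 - 161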

[0] https://github.com/KaimingHe/deep-residual-networks/issues/10

created time in a month

pull request commentdeepmind/dm-haiku

Adds Zeros, Ones and Identity initializers

Thanks, I think including Identity is a useful addition, please add a test (you should be able to port some over from Sonnet 2).

joaogui1

comment created time in 2 months

Pull request review commentdeepmind/dm-haiku

Adds Zeros, Ones and Identity initializers

 def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:     q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))     q_mat = jnp.moveaxis(q_mat, 0, self.axis)     return jax.lax.convert_element_type(self.scale, dtype) * q_mat+++  class Zeros(Initializer):+    """Initializer that generates tensors initialized to 0."""++    def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:+      return jnp.zeros(shape, dtype)++  class Ones(Initializer):+    """Initializer that generates tensors initialized to 1."""++    def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:+      return jnp.ones(shape, dtype)

Thanks, looks good now.

joaogui1

comment created time in 2 months

Pull request review commentdeepmind/dm-haiku

Adding regularizers

 State = Mapping[str, Mapping[str, jnp.ndarray]] Padding = Callable[[int], Sequence[int]] Paddings = Union[Padding, Sequence[Padding]]+Regularizer = Callable[[jnp.ndarray], float]

This takes any tree of ndarrays and returns an ndarray, so I think the type should be Callable[[Any], jnp.ndarray] or Callable[[Params], jnp.ndarray].

joaogui1

comment created time in 2 months

Pull request review commentdeepmind/dm-haiku

Adding regularizers

+# Lint as: python3+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     http://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.+# ==============================================================================+"""Haiku regularizers."""++from haiku._src import base+from haiku._src.typing import Shape, DType, Regularizer+from jax import tree_flatten+import jax.numpy as jnp+++class L1(Regularizer):+  """L1 regularizer."""++  def __init__(self, scale):+    """Create an L1 regularizer.+    Args:+      scale: A non-negative regularization factor.+    Raises:+      ValueError: if scale is <0.+    """+    if scale < 0:+        raise ValueError("scale must be a non-negative value")+    self.scale = scale+++  def __call__(self, parameters) -> jnp.array:+    leaves, _ = tree_flatten(parameters)

jax.tree_leaves ?

joaogui1

comment created time in 2 months

Pull request review commentdeepmind/dm-haiku

Adds Zeros, Ones and Identity initializers

 def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:     q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))     q_mat = jnp.moveaxis(q_mat, 0, self.axis)     return jax.lax.convert_element_type(self.scale, dtype) * q_mat+++  class Zeros(Initializer):+    """Initializer that generates tensors initialized to 0."""++    def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:+      return jnp.zeros(shape, dtype)++  class Ones(Initializer):+    """Initializer that generates tensors initialized to 1."""++    def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:+      return jnp.ones(shape, dtype)+++  class Identity(Initializer):+    """Initializer that generates the identity matrix.+    Constructs a 2D identity matrix or batches of these.+    """++    def __init__(self, gain: float = 1.0):+      """Constructs an identity initializer.+      Args:+         gain: Multiplicative factor to apply to the identity matrix.+      """+    self.gain = gain++    def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:+      if len(shpe) < 2:

Typo here.

joaogui1

comment created time in 2 months

Pull request review commentdeepmind/dm-haiku

Adds Zeros, Ones and Identity initializers

 def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:     q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))     q_mat = jnp.moveaxis(q_mat, 0, self.axis)     return jax.lax.convert_element_type(self.scale, dtype) * q_mat+++  class Zeros(Initializer):+    """Initializer that generates tensors initialized to 0."""++    def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:+      return jnp.zeros(shape, dtype)++  class Ones(Initializer):+    """Initializer that generates tensors initialized to 1."""++    def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:+      return jnp.ones(shape, dtype)

(In Sonnet we used a type for Initializer while in Haiku we just used Callable[.., ..] so any callable with the right signature is a valid init function.)

joaogui1

comment created time in 2 months

Pull request review commentdeepmind/dm-haiku

Adds Zeros, Ones and Identity initializers

 def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:     q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))     q_mat = jnp.moveaxis(q_mat, 0, self.axis)     return jax.lax.convert_element_type(self.scale, dtype) * q_mat+++  class Zeros(Initializer):+    """Initializer that generates tensors initialized to 0."""++    def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:+      return jnp.zeros(shape, dtype)++  class Ones(Initializer):+    """Initializer that generates tensors initialized to 1."""++    def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:+      return jnp.ones(shape, dtype)+++  class Identity(Initializer):+    """Initializer that generates the identity matrix.+    Constructs a 2D identity matrix or batches of these.+    """++    def __init__(self, gain: float = 1.0):+      """Constructs an identity initializer.+      Args:+         gain: Multiplicative factor to apply to the identity matrix.+      """+    self.gain = gain

Indent plz 😄

joaogui1

comment created time in 2 months

Pull request review commentdeepmind/dm-haiku

Adds Zeros, Ones and Identity initializers

 def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:     q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))     q_mat = jnp.moveaxis(q_mat, 0, self.axis)     return jax.lax.convert_element_type(self.scale, dtype) * q_mat+++  class Zeros(Initializer):+    """Initializer that generates tensors initialized to 0."""++    def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:+      return jnp.zeros(shape, dtype)++  class Ones(Initializer):+    """Initializer that generates tensors initialized to 1."""++    def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:+      return jnp.ones(shape, dtype)

I think for Zeros and Ones it is probably more readable to just use jnp.zeros and jnp.ones. These actually match the Initializer signature so type checking works too 😄
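For example, a hedged sketch of what that looks like in practice (hk.Linear's b_init parameter accepts any callable with the (shape, dtype) -> ndarray signature):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  # jnp.ones already matches the Initializer signature (shape, dtype) -> ndarray,
  # so it can be passed directly rather than being wrapped in a Ones class.
  return hk.Linear(16, b_init=jnp.ones)(x)

net = hk.transform(forward)
params = net.init(jax.random.PRNGKey(42), jnp.zeros([1, 8]))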

joaogui1

comment created time in 2 months

pull request commentgoogle/jax

Add support for buffer donation in `jit` and `pmap`.

I asked a few people to meet with me to review the design of this PR. Overall the proposed API was approved by the group and we will now iterate on a few changes to the implementation that will not be user visible (e.g. interaction with JVPTrace).

One major change as a result of the review is that we are leaning into the "gifting" nature of "donation" (previously we would error if we could not use a donation, now we will simply discard it). The original implementation focused exclusively on the case where a donated buffer is aliased with an output; however, there are other future cases where donation may make sense (e.g. reusing an input buffer for an intermediate computation) and on some backends we may not be able to make use of donated buffers (e.g. on CPU/GPU input/output aliasing is not supported).

We decided that donation will be best effort: on backends that don't offer anything to do with donated buffers (e.g. CPU/GPU) we will log a warning and invalidate the Python objects (to avoid accidental reuse). We will additionally no longer require perfect aliasing of inputs and outputs (meaning that on backends which do support donation you can donate more than you need).

There were concerns raised about providing the user control over which specific input buffers alias the output. For now we will document that input/output aliasing is managed by traversing the flat inputs and flat outputs in order and matching the first output with an appropriate shape/dtype to the input. This means you can have fine-grained control over aliasing by re-ordering outputs in Python.
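For illustration, a hedged sketch of what that matching rule implies (names invented; assumes a backend where input/output aliasing is supported):

import jax
import jax.numpy as jnp

# Both inputs are donated and both outputs share their shape/dtype, so outputs
# are matched against donated buffers in flat order: the first output takes the
# first matching donated input. Returning (y * 2.0, x + 1.0) instead would flip
# which donated buffer each output may reuse.
f = jax.jit(lambda x, y: (x + 1.0, y * 2.0), donate_argnums=(0, 1))
a, b = f(jnp.zeros((8,)), jnp.ones((8,)))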

Thanks to everyone who participated for an engaging discussion! I hope to have the patch ready to submit by the end of the month.

tomhennigan

comment created time in 2 months

Pull request review commentdeepmind/dm-haiku

[hk] Disallow parameter creation in apply when params is None

 def apply_fn(   return TransformedWithState(init_fn, apply_fn)  -def check_mapping(name: str, mapping: T) -> T:+def check_mapping(name: str, mapping: Optional[T]) -> [T]:   # TODO(tomhennigan) Remove support for empty non-Mappings.   if mapping and not isinstance(mapping, Mapping):     raise TypeError(f"{name} argument does not appear valid: {mapping!r}. "                     "For reference the parameters for apply are "                     "`apply(params, rng, ...)`` for `hk.transform` and "                     "`apply(params, state, rng, ...)` for "                     "`hk.transform_with_state`.")-  return mapping+  if mapping is not None:+    return mapping+  else:+    # Convert None to empty dict.+    return dict()

Maybe we should re-arrange this as follows?

if mapping is None:
  mapping = {}

.. check mapping is valid ..

return mapping
trevorcai

comment created time in 2 months

pull request commentdeepmind/dm-haiku

[hk] Disallow parameter creation in apply when params is None

The types on the apply functions don't indicate that Optional[T] is accepted; should we update that too?

trevorcai

comment created time in 2 months

PR opened google/jax

Avoid recompilation of rolled loops in threefry2x32.
+20 -17

0 comment

1 changed file

pr created time in 2 months

create branch tomhennigan/jax

branch : changelist/311191008

created branch time in 2 months

pull request commentgoogle/jax

Expose functools.reduce initializer argument to tree_util.tree_reduce

I have another suggestion! ... just kidding 😄 looks good to me, thanks for sticking with it.

@gnecula can we merge this?

bastings

comment created time in 2 months

pull request commentgoogle/jax

Expose functools.reduce initializer argument to tree_util.tree_reduce

Just need to replace jax.tree_reduce with tree_reduce and they should pass 😄

bastings

comment created time in 2 months

pull request commentgoogle/jax

Expose functools.reduce initializer argument to tree_util.tree_reduce

Thanks! One thing that makes me a bit sad now is that the function is not self-documenting 😢. I was going to suggest adding a docstring explaining how *args is interpreted, but I think there is an alternative way to support this which only obscures the final argument of the function. WDYT of the suggestion below?

In my opinion this is better than the *args version, even if it is a slight abuse of the varargs syntax. We are, however, safe: an exception is thrown if the user passes more than one initializer, so there is no way to accidentally misuse this function, and we can change how we implement this in the future if we change our minds:

import functools
import jax

def tree_reduce(function, tree, *initializer):
  if initializer:  # We need to differentiate between initializer explicitly passed or not.
    initializer, = initializer
    return functools.reduce(function, jax.tree_leaves(tree), initializer)
  else:
    return functools.reduce(function, jax.tree_leaves(tree))

tree_reduce(lambda a, b: a + b, [])
bastings

comment created time in 2 months

pull request commentgoogle/jax

Expose functools.reduce initializer argument to tree_util.tree_reduce

👍 for this change, but None is a supported initializer value for functools.reduce and I think it should be supported here too:

>>> reduce(lambda x, y: x + y, [])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: reduce() of empty sequence with no initial value
>>> reduce(lambda x, y: x + y, [], None)

The neatest way I know to do this is to check the length of args:

def tree_reduce(*args):
  if len(args) == 2:
    f, seq = args
    return reduce(f, tree_leaves(seq))
  else:
    f, seq, default = args
    return reduce(f, seq, default)

Or to test for a non-None sentinel:

raise_on_empty = object()

def tree_reduce(f, seq, default=raise_on_empty):
  seq = tree_leaves(seq)
  if default is raise_on_empty:
    return reduce(f, seq)
  else:
    return reduce(f, seq, default)
bastings

comment created time in 2 months
