google/jax 8452

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Research language for array processing in the Haskell/ML family

A pedagogical implementation of Autograd

duvenaud/relax 132

Optimizing control variates for black-box gradient estimation

Recognizing and exploiting conjugacy without a domain-specific language

Prototypes of differentiable differential equation solvers in JAX.

my fish configuration

my .vim

Repo for Dechter, Johnson, Tavares final project for 6.945.

push event google/jax

commit sha 93be61189d0e5de376a2136138c52acdaeddb32e

omnistaging wip

push time in 37 minutes

push event google/jax

commit sha 0031075bb2d30c95ae4fb29d3457767d65ae07f6

Replaced jnp.sum by sum when the argument is a list (#3253)

push time in 2 hours

PR merged google/jax

As discussed on the chat, replacing jnp.sum by sum to solve compilation issues

pr closed time in 2 hours

push event google/jax

commit sha c2ca970867712147e4004de2c1d6ef0fdd5b4d6e

omnistaging wip

push time in 3 hours

push event google/jax

commit sha 7128a12cfe4669ba5189bf0518ff5114494827ed

omnistaging wip

push time in 3 hours

Pull request review comment google/jax

moved check_jaxpr code around to match eval_jaxpr

```
 def check_jaxpr(jaxpr: Jaxpr):
     exception_args = [msg, *e.args[1:]]
     raise exception_type(*exception_args) from e

-def _check_jaxpr(jaxpr: Jaxpr):
-  env = _JaxprTypeEnvironment()
-
-  env.write(unitvar)
-  map(env.write, jaxpr.constvars)
-  map(env.write, jaxpr.invars)
+def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):

-  for eqn in jaxpr.eqns:
-    check_jaxpr_eqn(env, eqn)
+  def read(v: Var) -> AbstractValue:
+    if isinstance(v, Literal):
+      return get_aval(v.val)
+    else:
+      if v not in env:
+        raise TypeError(f"Variable '{v}' not defined")
+      return env[v]

-  for subjaxpr in subjaxprs(jaxpr):
-    _check_jaxpr(subjaxpr)
+  def write(v: Var, a: AbstractValue) -> None:
+    if v in env:
+      raise TypeError(f"Variable '{v}' already bound")
+    # TODO(frostig): we'd rather check equality or just typecompat here, but
+    # partial_eval.tracers_to_jaxpr types eqn outvars as abstract_unit if the
+    # outvars are unused
+    if not typecompat(v.aval, a) and v.aval is not abstract_unit:
+      raise TypeError(f"Variable '{v}' inconsistently typed as {a}, "
+                      f"bound as {v.aval}")
+    env[v] = a

-  map(env.read, jaxpr.outvars)
+  env : Dict[Var, AbstractValue] = {}

-def _valid_eqn_assignment(dst_aval, src_aval):
-  # TODO(frostig): we'd rather this check simply be `typecompat` and not allow
-  # assignment to an AbstractUnit, but partial_eval.tracers_to_jaxpr types eqn
-  # outvars as AbstractUnit if the outvars are unused.
-  return dst_aval is abstract_unit or typecompat(dst_aval, src_aval)
+  write(unitvar, abstract_unit)
+  map(write, jaxpr.constvars, [v.aval for v in jaxpr.constvars])
+  map(write, jaxpr.invars, in_avals)

-def check_jaxpr_eqn(env, eqn):
-  invars = map(env.read, eqn.invars)
-  inferred_out_avals = type_transfer(eqn.primitive, invars, eqn.params)
-  outvars = map(env.write, eqn.outvars)
+  for eqn in jaxpr.eqns:
+    in_avals = map(read, eqn.invars)
+    if eqn.primitive.call_primitive:
+      out_avals = check_call(eqn.primitive, in_avals, eqn.params)
+    elif eqn.primitive.map_primitive:
+      out_avals = check_map(eqn.primitive, in_avals, eqn.params)
+    else:
+      out_avals = check_eqn(eqn.primitive, in_avals, eqn.params)
+    try:
+      map(write, eqn.outvars, out_avals)
+    except TypeError as e:
+      msg, = e.args
+      raise TypeError(msg + f" in '{eqn}'") from None

-  for outvar, inferred_out_aval in zip(outvars, inferred_out_avals):
-    if not _valid_eqn_assignment(outvar.aval, inferred_out_aval):
-      raise TypeError(
-          f"Jaxpr equation LHS {outvar} is {outvar.aval}, "
-          f"RHS is inferred as {inferred_out_aval}, in '{eqn}'")
+  map(read, jaxpr.outvars)

-def type_transfer(prim, invars, params):
-  in_avals = [v.aval for v in invars]
+def check_eqn(prim, in_avals, params):
```

@froystig pointed out that to maintain current behavior we need to check any jaxprs in `params` here.

comment created time in 11 hours

Pull request review comment google/jax

moved check_jaxpr code around to match eval_jaxpr

```
 def check_jaxpr(jaxpr: Jaxpr):
     exception_args = [msg, *e.args[1:]]
     raise exception_type(*exception_args) from e

-def _check_jaxpr(jaxpr: Jaxpr):
-  env = _JaxprTypeEnvironment()
-
-  env.write(unitvar)
-  map(env.write, jaxpr.constvars)
-  map(env.write, jaxpr.invars)
+def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):

-  for eqn in jaxpr.eqns:
-    check_jaxpr_eqn(env, eqn)
+  def read(v: Var) -> AbstractValue:
```

This should be `v: Union[Var, Literal]`!

comment created time in 11 hours

Pull request review comment google/jax

Improve speed of tracing dynamic_update_slice

```
 def _check_shapelike(fun_name, arg_name, obj):

 def _dynamic_slice_indices(operand, start_indices):
-  if not isinstance(start_indices, (tuple, list)):
-    if start_indices.ndim != 1:
-      raise ValueError("Slice indices must be a 1D sequence, got {}"
-                       .format(start_indices.shape))
-    start_indices = [squeeze(slice(start_indices, [i], [i+1]), dimensions=(0,))
-                     for i in range(operand.ndim)]
-  else:
-    start_indices = [onp.asarray(i, dtype=dtypes.int_) if isinstance(i, int)
-                     else i for i in start_indices]
   if len(start_indices) != operand.ndim:
     msg = ("Length of slice indices must match number of operand dimensions ({} "
            "vs {})")
     raise ValueError(msg.format(len(start_indices), operand.shape))
   # map int over operand.shape to raise any dynamic-shape errors
-  return [select(lt(i, _const(i, 0)), add(i, _const(i, int(d))), i)
-          for i, d in zip(start_indices, operand.shape)]
+  map(int, operand.shape)
```

I don't have a strong preference but another option is always to use a list comprehension.

This is tangential but I think we should always set `map, unsafe_map = safe_map, map` (a suggestion I think Dougal had recently). The Python 3 `map` is quite unsafe IMO, since I don't expect iteration to be side-effecting.

comment created time in 11 hours
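The idiom suggested in the comment above can be sketched standalone (names assumed here; JAX keeps its own version in its utilities): an eager, length-checked `map` that turns silent truncation of mismatched iterables into an error.

```python
# A standalone sketch (assumed names) of the safe_map idiom: like the builtin
# map, but eager and length-checked, so mismatched iterable lengths raise
# instead of being silently truncated.
def safe_map(f, *args):
    lists = [list(a) for a in args]
    n = len(lists[0])
    for l in lists[1:]:
        assert len(l) == n, f"length mismatch: {[len(l) for l in lists]}"
    return [f(*xs) for xs in zip(*lists)]

# Shadow the builtin, keeping the lazy one around under an explicit name:
map, unsafe_map = safe_map, map
```

After the shadowing assignment, any accidental length mismatch fails loudly rather than quietly dropping trailing elements.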

PR opened google/jax

This change is mostly stylistic; it brings `check_jaxpr` closer to `eval_jaxpr` (and the other jaxpr interpreters) in organization. There's a slight tweak to an error message which lets us save some slightly redundant code.

pr created time in a day

push event google/jax

commit sha 07208be6be3d3c46aca7b08209e28944fc106dd8

moved check_jaxpr code around to match eval_jaxpr

This change is mostly stylistic; it brings check_jaxpr closer to eval_jaxpr (and the other jaxpr interpreters) in organization. There's a slight tweak to an error message which lets us save some slightly redundant code.

push time in a day

push event google/jax

commit sha 38c01d7b2b3bc0af3fa2c1df65bf739f135dff2d

make mypy happy

push time in a day

issue comment google/jax

Hoist large constants into parameters

@hawkinsp when you say "XLA/GPU in particular does not handle very large constants embedded in programs well", can you unpack that a little bit? What goes wrong?

comment created time in a day

issue comment google/jax

Hoist large constants into parameters

Oh actually, I just noticed you wrote `np.arange`, not `jnp.arange`, so omnistaging won't help in this case either.

comment created time in a day

issue comment google/jax

Hoist large constants into parameters

I think omnistaging will cover this case, though there could be other cases where the constant is not just an iota (e.g. closing over the mnist dataset) that we might want to hoist to a parameter. Does that latter case seem important?

comment created time in a day
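The `np.arange`/`jnp.arange` distinction in these comments can be sketched with a toy example (hypothetical, not from the issue): a closed-over NumPy array is concrete host-side data that must be captured as a constant of the compiled program, whereas `jnp.arange` inside a `jit`-traced function can be staged out (as an iota) and materialized on the device at run time.

```python
import numpy as np
import jax
import jax.numpy as jnp

big_const = np.arange(8)  # concrete NumPy array, created eagerly on the host

@jax.jit
def f(x):
    # big_const is a captured constant: it already exists as data, so it must
    # be embedded in (or hoisted as a parameter of) the compiled program.
    return x + big_const

@jax.jit
def g(x):
    # jnp.arange is traced under jit: it can stay symbolic (an iota) and be
    # materialized on the device at run time, with no captured constant.
    return x + jnp.arange(8)

outs = f(jnp.zeros(8)), g(jnp.zeros(8))  # both equal [0., 1., ..., 7.]
```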

issue closed google/jax

Sign error in jax.custom_jvp and jax.custom_vjp docs

I'm having fun learning JAX but think I found a sign error in the math in these docs.

In the example, f(x, y) = sin(x) * y.

df/dx is cos(x) * y as expected. But df/dy should be sin(x), not -sin(x) as in the example code for both f_jvp and f_bwd.

I'm happy to create a pull request if that's the best practice here.

closed time in a day
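As a sanity check on the reported fix (an editorial aside, not part of the issue): the sign of df/dy can be verified with a plain central finite difference, no JAX machinery needed.

```python
# Central finite-difference check that df/dy = sin(x) for f(x, y) = sin(x) * y.
import math

def f(x, y):
    return math.sin(x) * y

def dfdy(x, y, eps=1e-6):
    return (f(x, y + eps) - f(x, y - eps)) / (2 * eps)

x, y = 0.5, 2.0
assert abs(dfdy(x, y) - math.sin(x)) < 1e-6    # matches sin(x)...
assert abs(dfdy(x, y) + math.sin(x)) > 0.5     # ...and is nowhere near -sin(x)
```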

issue comment google/jax

Sign error in jax.custom_jvp and jax.custom_vjp docs

Fixed by @jiawen in #3219!

comment created time in a day

push event google/jax

commit sha 7c90023ddbbb598a2f34a805ac8c4b19f69b82e1

Fix sign error in custom_jvp / custom_vjp. (#3213) (#3219)

f(x, y) = sin(x) * y. df/dy should be sin(x) instead of -sin(x).

push time in a day

PR merged google/jax

f(x, y) = sin(x) * y.

df/dy should be sin(x) instead of -sin(x).

pr closed time in a day

delete branch google/jax

delete branch : remove-dead-custom-jvp-multilinear-code

delete time in a day

push event google/jax

commit sha 572928dfa309e625e221fd084bd25ee010afa2eb

fix custom_jvp_call_jaxpr transpose function (#3231)

* make custom_jvp_call_jaxpr handle multilinear funs see #3226
* remove old comment

push time in a day

PR merged google/jax

(Ignore previous PR message, actually this was easy to fix!)

Follow-up to #3226.

pr closed time in a day

Pull request review comment google/jax

fix custom_jvp_call_jaxpr transpose function

```
 def batched_jvp_jaxpr_thunk():
     xla.lower_fun_initial_style(_custom_jvp_call_jaxpr_impl)

 # If a (multi)linear function is defined with a custom jvp, then
-# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. We transpose it
-# like a core.call.
-def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk,
-                                     avals):
+# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
+# already been linearized, we can drop the jvp rule.
+def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk):
   del jvp_jaxpr_thunk
-  name = 'custom_jvp_call_jaxpr_linear'
-  avals = [core.get_aval(l) for l in fun_jaxpr.literals] + avals
-  return ad.call_transpose(core.call_p, dict(name=name), fun_jaxpr.jaxpr,
-                           tuple(fun_jaxpr.literals) + args, cts, avals)
```

Thanks! Hopefully GitHub will squash that for us.

comment created time in a day

started benjaminp/six

started time in 2 days

started fancompute/legume

started time in 2 days

issue comment google/jax

Wowww cool, that is really helpful, @momchilmm!

comment created time in 2 days

Pull request review comment google/jax

Initial import of jax2tf into JAX core

```
+# JAX to TensorFlow converter
+
+WARNING: This is beta quality and not ready for production use. Please expect
+API changes!
+
+This package provides an experimental JAX interpreter that implements most JAX
+primitives using TensorFlow operations. In practice this means that you can take
+some code written in JAX and execute it using TensorFlow eager more, or stage it
```

"more" -> "mode" ?

comment created time in 2 days

push event google/jax

commit sha 5ccc941da6ea530f91b4be736a60c97d93964dd7

remove old comment

push time in 2 days

push event google/jax

commit sha 534f6870b9c8bda15706504d0b8b20222c251052

make custom_jvp_call_jaxpr handle multilinear funs see #3226

push time in 2 days

PR opened google/jax

This code was cruft accidentally left in. It's kind of tricky to get right so I think it's best to remove this for now until someone needs it. It would only come up if someone wants to use custom_jvp on a multilinear function.

Follow-up to #3226.

pr created time in 2 days

create branch google/jax

branch : remove-dead-custom-jvp-multilinear-code

created branch time in 2 days

issue comment google/jax

Support higher order functions in shapecheck

Brainstormed for a couple minutes with Dougal about this, and it might be possible!

comment created time in 2 days

Pull request review comment google/jax

Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)

```
 def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):

 def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
   assert is_undefined_primal(x) ^ is_undefined_primal(y)
+  if type(cotangent) is Zero:
+    return Zero
   if is_undefined_primal(x):
-    out = zero if cotangent is zero else lhs_rule(cotangent, y, **kwargs)
-    return out, None
+    out = lhs_rule(cotangent, y, **kwargs)
+    return Zero if out is Zero else (out, None)
   else:
-    out = zero if cotangent is zero else rhs_rule(cotangent, x, **kwargs)
-    return None, out
+    out = rhs_rule(cotangent, x, **kwargs)
+    return Zero if out is Zero else (None, out)

 def defjvp_zero(primitive):
   assert isinstance(primitive, Primitive)
   primitive_jvps[primitive] = partial(zero_jvp, primitive)

 def zero_jvp(primitive, primals, tangents, **params):
-  return primitive.bind(*primals, **params), zero
+  r = primitive.bind(*primals, **params)
+  return r, Zero.from_value(r)

-deflinear(zeros_like_p, lambda t: [zero])
+deflinear(zeros_like_p, lambda t: [Zero.from_value(t)])
 deflinear(core.identity_p, lambda t: (t,))
 deflinear(add_jaxvals_p, lambda t: (t, t))

 def instantiate_zeros(example, tangent):
-  if tangent is zero:
+  if type(tangent) is Zero:
     return zeros_like_jaxval(example)
```

We don't need to pass `example` into this function anymore AIUI, though it might be annoying to update all the call sites. Instead, perhaps we should add a check that `example` is consistent with `tangent.aval` when `type(tangent) is Zero`? WDYT?

comment created time in 2 days
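For orientation, the shape of the API being discussed can be sketched with simplified stand-ins (these are not JAX's actual classes): once a symbolic zero carries its own abstract value, instantiating it needs no `example` argument and `instantiate_zeros` can become unary, as the review suggests.

```python
class ShapedAval:
    """Stand-in for an abstract value: just a shape here."""
    def __init__(self, shape):
        self.shape = shape

class Zero:
    """A symbolic zero that knows its own abstract value."""
    def __init__(self, aval):
        self.aval = aval

    @classmethod
    def from_value(cls, val):
        # Recover the aval from a concrete value (nested lists as "arrays").
        shape = []
        while isinstance(val, list):
            shape.append(len(val))
            val = val[0]
        return cls(ShapedAval(tuple(shape)))

def instantiate_zeros(tangent):
    # Unary version: the aval alone determines the concrete zeros,
    # no example value needed.
    assert type(tangent) is Zero
    def build(shape):
        if not shape:
            return 0.0
        return [build(shape[1:]) for _ in range(shape[0])]
    return build(tangent.aval.shape)
```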

Pull request review comment google/jax

Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)

```
 def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):

 def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
   assert is_undefined_primal(x) ^ is_undefined_primal(y)
+  if type(cotangent) is Zero:
+    return Zero
   if is_undefined_primal(x):
-    out = zero if cotangent is zero else lhs_rule(cotangent, y, **kwargs)
-    return out, None
+    out = lhs_rule(cotangent, y, **kwargs)
+    return Zero if out is Zero else (out, None)
   else:
-    out = zero if cotangent is zero else rhs_rule(cotangent, x, **kwargs)
-    return None, out
+    out = rhs_rule(cotangent, x, **kwargs)
+    return Zero if out is Zero else (None, out)

 def defjvp_zero(primitive):
   assert isinstance(primitive, Primitive)
   primitive_jvps[primitive] = partial(zero_jvp, primitive)

 def zero_jvp(primitive, primals, tangents, **params):
-  return primitive.bind(*primals, **params), zero
+  r = primitive.bind(*primals, **params)
+  return r, Zero.from_value(r)

-deflinear(zeros_like_p, lambda t: [zero])
+deflinear(zeros_like_p, lambda t: [Zero.from_value(t)])
 deflinear(core.identity_p, lambda t: (t,))
 deflinear(add_jaxvals_p, lambda t: (t, t))

 def instantiate_zeros(example, tangent):
-  if tangent is zero:
+  if type(tangent) is Zero:
     return zeros_like_jaxval(example)
   else:
     return tangent

 def instantiate_zeros_aval(aval, tangent):
-  if tangent is zero:
+  if type(tangent) is Zero:
     return zeros_like_aval(aval)
```

Similarly here, we could have a check that the avals agree (or change this function to be unary and adapt all call sites).

comment created time in 2 days

pull request comment google/jax

Small cleanup for partial_eval

By the way, it might be good to run internal tests on this, since otherwise the pmap coverage is lacking. See go/jax-playbook#heading=h.10khh2ytxa1n.

comment created time in 2 days

pull request comment google/jax

Add support for buffer donation in `jit` and `pmap`.

Heads up, might need a rebase on #3210 when that goes in. Since that PR aims to be a cleanup, hopefully if there's any conflict it's only to make this PR easier.

comment created time in 2 days

Pull request review comment google/jax

Small cleanup for partial_eval

```
 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
     if call_primitive in call_partial_eval_rules:
       return call_partial_eval_rules[call_primitive](self, call_primitive, f,
                                                      tracers, params)
-    in_pvs, in_consts = unzip2([t.pval for t in tracers])
-    fun, aux = partial_eval(f, self, in_pvs)
-    out_flat = call_primitive.bind(fun, *in_consts, **params)
-    out_pvs, jaxpr, env = aux()
-    env_tracers = map(self.full_raise, env)
-    out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
+
+    jaxpr, out_pvs, consts, env_tracers = self.partial_eval(
```

Could you change `out_pvs` to `out_pvals`, and correspondingly change the iteration variables from `pv` to `pval`? There's a funny convention we've (mostly) stuck to where `pv` is a name for the first component of a PartialVal pair, and `pval` is a name for the full pair.

comment created time in 2 days

Pull request review comment google/jax

Small cleanup for partial_eval

```
 def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
       tracers = map(self.instantiate_const_abstracted, tracers)
     else:
       name = wrap_name(name, 'pe')
-    params = dict(params, name=name)
-    in_pvs, in_consts = unzip2([t.pval for t in tracers])
-    reduced_pvs = [None if pv is None else
-                   core.mapped_aval(params['axis_size'], pv) if m else pv
-                   for pv, m in zip(in_pvs, params['mapped_invars'])]
-    fun, aux = partial_eval(f, self, reduced_pvs)
-    out_flat = map_primitive.bind(fun, *in_consts, **params)
-    out_pvs_reduced, jaxpr, env = aux()
-    out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
-    out_pvs = [None if pv is None else core.unmapped_aval(params['axis_size'], pv)
-               for pv in out_pvs_reduced]
+
+    @curry
+    def modify_aval(mod, args):
+      pval, is_mapped = args
+      if pval.is_known() or not is_mapped:
+        return pval
+      return PartialVal((mod(params['axis_size'], pval[0]), pval[1]))
+
+    reduced_in_pvs = map(modify_aval(core.mapped_aval),
+                         zip([t.pval for t in tracers], params['mapped_invars']))
+    jaxpr, reduced_out_pvs, consts, env_tracers = self.partial_eval(
```

Here also: `pvals` instead of `pvs`?

comment created time in 2 days

Pull request review comment google/jax

Small cleanup for partial_eval

```
 def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
       tracers = map(self.instantiate_const_abstracted, tracers)
     else:
       name = wrap_name(name, 'pe')
-    params = dict(params, name=name)
-    in_pvs, in_consts = unzip2([t.pval for t in tracers])
-    reduced_pvs = [None if pv is None else
-                   core.mapped_aval(params['axis_size'], pv) if m else pv
-                   for pv, m in zip(in_pvs, params['mapped_invars'])]
-    fun, aux = partial_eval(f, self, reduced_pvs)
-    out_flat = map_primitive.bind(fun, *in_consts, **params)
-    out_pvs_reduced, jaxpr, env = aux()
-    out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
-    out_pvs = [None if pv is None else core.unmapped_aval(params['axis_size'], pv)
-               for pv in out_pvs_reduced]
+
+    @curry
+    def modify_aval(mod, args):
+      pval, is_mapped = args
+      if pval.is_known() or not is_mapped:
+        return pval
+      return PartialVal((mod(params['axis_size'], pval[0]), pval[1]))
```

Might be best not to shadow the builtin name `mod` here. How about just `modify`?

comment created time in 2 days

Pull request review comment google/jax

Small cleanup for partial_eval

```
 def _remat_partial_eval(trace, _, f, tracers, params):
   instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)

   # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
-  in_pvs, in_consts = unzip2(t.pval for t in instantiated_tracers)
-  fun, aux = partial_eval(f, trace, in_pvs)
-  with core.initial_style_staging():
-    out_flat = remat_call_p.bind(fun, *in_consts, **params)
-  out_pvs, jaxpr, env = aux()
-  env = map(trace.full_raise, env)
-  out_pval_consts1, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
-  out_pvals1 = [PartialVal((pv, const)) for pv, const in zip(out_pvs, out_pval_consts1)]
+  in_pvs = [t.pval for t in instantiated_tracers]
+  jaxpr, out_pvals1, consts, env_tracers = trace.partial_eval(
+      f, in_pvs, partial(remat_call_p.bind, **params))

   # Since we traced with everything marked as unknown, but we need to know which
   # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
-  in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in env]
-              + [raise_to_shaped(pv) for pv in in_pvs])
-  out_avals = [raise_to_shaped(pv if pv is not None
-                               else abstract_unit if var is unitvar
+  in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in env_tracers]
+              + [raise_to_shaped(pv.get_aval()) for pv in in_pvs])
+  out_avals = [raise_to_shaped(abstract_unit if var is unitvar
                                else get_aval(var.val) if type(var) is Literal
-                               else get_aval(const))
-              for var, pv, const in zip(jaxpr.outvars, out_pvs, out_pval_consts1)]
+                               else pv.get_aval())
```

`pv` -> `pval`?

comment created time in 2 days

Pull request review comment google/jax

Small cleanup for partial_eval

```
 def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
       tracers = map(self.instantiate_const_abstracted, tracers)
     else:
       name = wrap_name(name, 'pe')
-    params = dict(params, name=name)
-    in_pvs, in_consts = unzip2([t.pval for t in tracers])
-    reduced_pvs = [None if pv is None else
-                   core.mapped_aval(params['axis_size'], pv) if m else pv
-                   for pv, m in zip(in_pvs, params['mapped_invars'])]
-    fun, aux = partial_eval(f, self, reduced_pvs)
-    out_flat = map_primitive.bind(fun, *in_consts, **params)
-    out_pvs_reduced, jaxpr, env = aux()
-    out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
-    out_pvs = [None if pv is None else core.unmapped_aval(params['axis_size'], pv)
-               for pv in out_pvs_reduced]
+
+    @curry
+    def modify_aval(mod, args):
+      pval, is_mapped = args
+      if pval.is_known() or not is_mapped:
+        return pval
+      return PartialVal((mod(params['axis_size'], pval[0]), pval[1]))
+
+    reduced_in_pvs = map(modify_aval(core.mapped_aval),
```

Similar here: can you call these `reduced_in_pvals`?

comment created time in 2 days

pull request comment google/jax

Simplify handling of non-linear equations in backward_pass and fix remat

LGTM again! Let's merge :D

comment created time in 2 days

Pull request review comment google/jax

Simplify handling of non-linear equations in backward_pass and fix remat

```
 def traceable(num_primals, in_tree_def, *primals_and_tangents):
   yield out_flat, tree_def

-def call_transpose(primitive, params, call_jaxpr, args, ct):
+def call_transpose(primitive, params, call_jaxpr, args, ct, _):
   all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
   fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
   fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
   params = dict(params, name=wrap_name(params['name'], 'transpose'))
   out_flat = primitive.bind(fun, *all_args, **params)
   return tree_unflatten(out_tree(), out_flat)
 primitive_transposes[core.call_p] = partial(call_transpose, call_p)
-primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)

-def map_transpose(primitive, params, call_jaxpr, args, ct):
+
+def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals):
+  # backward_pass can only transpose linear computations, but the call_jaxpr embedded in
+  # remat contains primal (non-linear) equations too. Hence, we have to eliminate those
+  # (in this case via partial_eval) before we call into backward_pass again.
+  typed_call_jaxpr = core.TypedJaxpr(
+      call_jaxpr, [],
+      [raise_to_shaped(p.aval if is_undefined_primal(p) else get_aval(p)) for p in primals_in],
+      cotangent_in_avals)
+  primal_jaxpr, tangent_jaxpr, out_unknowns = \
+      pe.partial_eval_jaxpr(typed_call_jaxpr,
+                            unknowns=map(is_undefined_primal, primals_in),
+                            instantiate=True,
+                            trace_type=None)
+
+  # Make a primal_jaxpr into a function. The flattening business is necessary so
+  # that we can thread UndefinedPrimals through the bind call. Note that jaxpr_as_fun
+  # doesn't understand how to handle those either, but tangent arguments are usued,
+  # so it should just ignore them.
+  primal_fun = lu.wrap_init(core.jaxpr_as_fun(primal_jaxpr))
+  flat_primals, in_tree_def = tree_flatten(primals_in)
+  primal_fun, _ = flatten_fun_nokwargs(primal_fun, in_tree_def)
+
+  params = dict(params, name=wrap_name(params['name'], 'transpose'))
+  residuals = pe.remat_call_p.bind(primal_fun, *flat_primals, **params)[len(cotangents_in):]
```

Great fix and explanation! I agree those pesky details could be revised away but probably aren't worth doing right now.

comment created time in 2 days

push event google/jax

commit sha 9f8a4ad341acc8bab52c0746102193cb7a4da2ee

remove stray print statement from #1529

push time in 3 days

issue comment google/jax

Sign error in jax.custom_jvp and jax.custom_vjp docs

Yes, please make a PR if you can! This might be another instance of the bug in #3031.

comment created time in 3 days

push event jacobjinkelly/jax

commit sha 186b85bc5bea09c1089558beaf52fb88bc8edc6a

make jet support multi-output primitives

push time in 3 days

started adamhaber/stan2tfp

started time in 5 days

started poldrack/autoCV

started time in 5 days

push event google/jax

commit sha 9b12763b9b8d4f841be11bd530938dbc4f787b2b

revive the tracer leak checker

The tracer leak checker never worked with the jit (or pmap) compilation cache because

1. it relies on Python reference counting (via a weakref mechanism) to check that there are no more references to a MasterTrace (e.g. from any Tracers associated with it) once a trace is finished, but
2. the compilation caches (i.e. linear_util.cache) can include in their cache key the transforms stacked on the corresponding WrappedFun being jitted, and transforms (specifically trace_to_subjaxpr in partial_eval.py) can include MasterTraces.

Hence the cache keys included references to the MasterTraces, defeating the leak checking mechanism. This commit just makes an equal copy of any MasterTraces when building the cache key. MasterTraces are already hashable, with equality defined based on just their level and trace type. MasterTraces are only compared by identity in core.full_raise, and then only to determine if a sublift is needed (because a trace is encountering one of its own tracers but from inside at least one additional level of call scoping). That's not an issue for jit because of the caching, but it could be another issue for calls; TODO figure that out.

Co-authored-by: James Bradbury <jekbradbury@google.com>

push time in 6 days

issue comment google/jax

Unexpectedly high grad-of-scan memory usage

By the way, we're working on some other improvements that should make this work well even without `remat` by never instantiating the large `ones((3000, 3000))` array. We'd still need `remat` in general, but in this case the memory savings can be had by avoiding the large constant.
comment created time in 7 days

started google-research/big_transfer

started time in 7 days

started iodide-project/pyodide

started time in 7 days

push event google/jax

commit sha c459280e5678a96529dada852bf97b8cf01a4808

Added `associative_scan`. (#2170)

* Added `associative_scan`.
* Fixed problem where base case of associative scan could fail
* remove jax.numpy dependence in associative_scan

Co-authored-by: Matthew Johnson <mattjj@google.com>

push time in 8 days

PR merged google/jax

pr closed time in 8 days

push event dpiponi/jax

commit sha b18a4d8583c0e11e228a0792793d6f6e99292766

Disabled tests known to fail on Mac, and optionally slow tests. Issue: #2166

Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.

commit sha b2ef5bc09552e8ed39759df4ff49ea97e32db708

Canonicalize the shape in the wrapper functions in random.py. (#2165)

* Canonicalize the shape in the wrapper functions in random.py. This lets the user be more sloppy in using numpy arrays and statically known DeviceArrays for shapes, and still hit the jit cache. When they are not, the error is improved.
* Fix some errors.
* No need for the Poly workaround.
* Bypass canonicalization for None shapes in random.py.

commit sha 13316f35705fee2f43376655313e698b52d1965f

Fix type error in partial_eval.py. (#2171)

commit sha ae3003e9d42a8df41d6d8bbd62c0ba2b4c2c13ce

Simplify bound_subjaxprs.

Before, bound_subjaxprs was a tuple (0 or 1 values) of a pair of a Jaxpr and its constant values. Now we close up all such Jaxprs such that they do not take constvars and their constant values are part of the arguments. We also rename bound_subjaxprs to bound_subjaxpr (an optional Jaxpr). This is the first part of a simplification. In a subsequent PR I will move the bound_subjaxpr into params, as for most higher-order primitives.

commit sha 2d3cde3fdd093cb0654384c46248986c8541347d

Simplify the translation rules to not take constant values

commit sha 0045ed671c384b170666b2240529ebaea1f97ba4

Fix caching bug.

This was a very tricky bug. The compilation caching keys depend on many pieces of the transformation state, including, among other things, the pointer value of the Jaxpr. Since in the previous commit I added a `convert_constvars_jaxpr` call, every time it ran it produced the same semantic Jaxpr but a different pointer value, which was breaking caching. The fix is to cache `convert_constvars_jaxpr`.

commit sha bbe31335cd3abc253fe4fc793e297d948f9a458f

Fixed pytype complaint

commit sha 8407a65e1b940673ab8f83e41e810f72c9a5cee7

Merge pull request #2167 from gnecula/simple_jaxpr Simplify bound_subjaxprs.

commit sha 4582e4ed8d9a09748f79f32e54c5e5cd85444f61

Updates based on code reviews

commit sha b79c7948eeb7152e27ae05cdd563a2c4555af861

Removed dependency on distutils.strtobool

commit sha 86984b37dd34f40dc71f399318151257251555ec

Merge pull request #2169 from gnecula/bug_fix Disabled tests known to fail on Mac, and optionally slow tests.

commit sha e140520466b9a5081cc304af4f9fa72230579c65

make pmap inside of eager scan work, fixes #2018 (#2183)

* make pmap inside of eager scan work, fixes #2018
Co-authored-by: Sharad Vikram <sharadmv@google.com>
* Ensure AxisEnv is instantiated with tuples (#2186)
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>

commit sha 2e8798dd16acfe6373bf66780ae7402f7766854f

Use 64-bit integers for indexing if any tensor dimension exceeds 2^31 elements. (#2182)

commit sha be5b24fa5d8dc2aa6128772571836949328128d0

relax the ndim>=1 condition of tensordot (#2191)

* relax the ndim condition of tensordot
* add test for scalar input with axes=0

commit sha a0e1804e4376a359be6dafdd2aff3a80ed6e117b

Implementation of np.linalg.{cond, tensorinv} (#2125)

* add np.linalg.cond in a third_party module
* remove unnecessary type maps
* rename cond.py to linalg.py for consistency
* shift LICENSE to correct directory; formatting changes; completed testing
* Add implementation and testing for tensorinv
* fix tests for tensorinv to stop build stalling
* blank __init__.py; add extra testing for cond; remove index assignments
* jax.lax.cond is breaking on jax.numpy.linalg.norm
* fix control flow issues; update tests to use curried functions
* clean up imports; remove commented code
* remove control flow from tests; remove unneeded functions

commit sha 5c9438864e64c8b02b0e13fce9759d8a8ed3d488

fix cond batching bug reading axis size (#2193)

commit sha 051d7b895658f22e9ca64fc77961d61467e20e05

Fix broken link in README (#2196)

commit sha baf45f2c7a039ef91b503afc3c8bda8f089a5ab5

Fix expm(zeros((n, n))) == NaN. (#2131) (#2192)

* Fix expm(zeros((n, n))) == NaN. (#2131)
* update tests based on @sriharikrishna review

commit sha aa0ca2706219b258ccef3fdb95f70fa002c5b499

Implementation of np.linalg.tensorsolve. (#2119)

* Tensorsolve implementation
* Tests working for tensorsolve #1999
* Moved tensorsolve to third party directory

commit sha 76d77bfc147e0e281d0953260e3aded64e425230

Fix inconsistent indentation in `JaxprTrace.default_process_primitive`.

push time in 8 days

push event google/jax

commit sha 5c1de2836c11f5068f9df75fcce4fe887b458f1c

revise vmap in_axes/out_axes leaf type error msg (#3179)

from #3161 discussion

push time in 8 days

PR merged google/jax

from #3161 discussion

pr closed time in 8 days

PR opened google/jax

from #3161 discussion

pr created time in 9 days

issue comment google/jax

Binary search function in lax_control_flow_test.py doesn't work where anticipated

@shoyer mind taking a look at this one?

comment created time in 9 days

issue comment google/jax

Documentation on how to vmap on dictionary entries

Oh, good catch on that error message; that's out of date! I'll fix it.

comment created time in 9 days

push event google/jax

commit sha eb81d7e7ffcbfaff1e2b281ae6ee902d46653d6a

add dict in_axes example to vmap docstring (#3176)

* add dict in_axes example to vmap docstring, fixes #3161
* fix typo

push time in 9 days

PR merged google/jax

fixes #3161

pr closed time in 9 days

issue closedgoogle/jax

Documentation on how to vmap on dictionary entries

I have a related question to #2367: Is it possible to use `vmap`

on an array that is in a dictionary? And if yes, how?

To stick with the example given in #2367, how to `vmap`

'b', which is an array within a dictionary:

```
import jax.numpy as np
from jax import vmap

dictionary = {'a': 5., 'b': np.arange(5)}
c = 1.
d = 2.

def f(dct, c, d):
    return dct['a'] + dct['b'] + c + d

result = vmap(f, magic)(dictionary, c, d)
```

I don't understand how to create the magic axes tuple for this example, or whether it is even possible. I think it would be great if there were a related example in the vmap documentation.

closed time in 9 days

lukasbrauncomissue closedgoogle/jax

remat unnecessarily forces lazy values

```
import jax
import jax.numpy as jnp
import haiku as hk

N_LAYERS = 3

def f(rng, x, y):
    x = hk.dropout(rng, 0.1, x)
    return jax.nn.relu(jnp.matmul(x, y))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(428), 3)
x = jax.random.uniform(k1, (2, 1024, 1024))
y = jax.random.uniform(k2, (1024, 1024))

def g(rng, x, y):
    for k in jax.random.split(rng, N_LAYERS):
        x = jax.remat(f)(k, x, y)
    return jnp.sum(x)

# Profiling this line.
jax.jit(jax.grad(g, argnums=2))(k3, x, y).block_until_ready()
```

This results in the `lax.iota`

from the underlying `jax.random._random_bits`

being forced at trace time, rather than being evaluated at runtime. With a large enough N_LAYERS or `x.shape`

, this can result in OOMing the device *during tracing*.
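For readers unfamiliar with `jax.remat`: it recomputes intermediate activations on the backward pass instead of storing them, trading compute for memory. A minimal hypothetical sketch (not the reporter's code):

```python
import jax
import jax.numpy as jnp

def layer(x):
    # activations of this layer are recomputed during the backward
    # pass rather than being saved from the forward pass
    return jnp.tanh(x @ x)

loss = lambda x: jnp.sum(jax.remat(layer)(x))
g = jax.grad(loss)(jnp.ones((4, 4)))
```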

closed time in 9 days

trevorcaiissue commentgoogle/jax

remat unnecessarily forces lazy values

Fixed in #3169 !

comment created time in 9 days

issue commentgoogle/jax

Reductions can quietly muck with loop in/out types

Maybe it's to accumulate in a higher precision and thus avoid numerical issues?

comment created time in 9 days

push eventgoogle/jax

commit sha 0a5e8106f18aa032da5dd1a37c4d1f3d4f8722ce

fix typo

push time in 9 days

PR opened google/jax

fixes #3161

pr created time in 9 days

issue commentgoogle/jax

Documentation on how to vmap on dictionary entries

Yes, set `magic = ({'a': None, 'b': 0}, None, None)`

, like this:

```
import jax.numpy as np
from jax import vmap

dictionary = {'a': 5., 'b': np.arange(5)}
c = 1.
d = 2.

def f(dct, c, d):
    return dct['a'] + dct['b'] + c + d

result = vmap(f, in_axes=({'a': None, 'b': 0}, None, None))(dictionary, c, d)
```

This sentence in the vmap docstring is intended to describe the behavior:

If the positional arguments to fun are container types, the corresponding element of in_axes can itself be a matching container, so that distinct array axes can be mapped for different container elements. in_axes must be a container tree prefix of the positional argument tuple passed to fun.

Here, the argument in question is a dictionary, so we make the corresponding entry of `in_axes`

a dictionary too (with the same keys).
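The "tree prefix" part also means a single axis spec can stand in for a whole subtree. A small sketch, assuming both dict entries should be mapped over axis 0 (hypothetical example names):

```python
import jax.numpy as jnp
from jax import vmap

def f(dct, c, d):
    return dct['a'] + dct['b'] + c + d

# A bare 0 for the dict argument is a valid tree prefix: axis 0 of
# every leaf in the dict gets mapped.
batched = {'a': jnp.zeros(5), 'b': jnp.arange(5.)}
out = vmap(f, in_axes=(0, None, None))(batched, 1., 2.)
```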

comment created time in 9 days

PR opened google/jax

Attempting to fix #3163, but running into issues.

@hawkinsp can you share some wisdom, after reading #3163?

pr created time in 9 days

push eventgoogle/jax

commit sha ae9d1753462d07d75e9e002b13538c385949d9af

fix while_loop cond function batching (#3174) * fix while_loop cond function batching fixes #3164 * add test for #3164

push time in 9 days

PR merged google/jax

fixes #3164

pr closed time in 9 days

issue closedgoogle/jax

expm shape tracing issue through two vmaps and jvp

Hi,

I've been trying to narrow down an issue that arises when a vmap is applied to a jvp that goes through a calculation with its own vmap. I am new to JAX and may be missing something about shape tracing, but it seems to me that the problem occurs due to the expm matrix operation.

Consider the following code that either implements a Jacobian calculation for the matrix product operation:

e^{I*(w[0,0]+w[0,1])} e^{I*(w[1,0]+w[1,1])}

or

(I * w[0,0]) * (I * w[0,1]) * (I * w[1,0]) * (I * w[1,1])

where I is the 2x2 identity matrix and w is a set of weights stored as a 2x2 array. The executed calculation can be toggled by commenting/uncommenting 2 lines in matrix_operation(). The code works well for the 2nd operation but breaks down for the 1st one, throwing an assertion error related to the shape detected in lax_control_flow.py.

I am wondering whether it is a bug in expm or I am missing something? I would appreciate any help!

```
import jax.numpy as jnp
from jax import vmap, jvp
from functools import partial
from jax.numpy.linalg import multi_dot
from jax.scipy import linalg as la  # needed for la.expm below

def matrix_operation(w):
    assert w.shape == (2,)
    A = jnp.identity(2)
    return la.expm(A*(w[0]+w[1]))  # ----- breaks down with this line
    #return jnp.matmul(A*w[0], A*w[1])  #----- works with this line!

def cost_function(ws):
    same_w = vmap(matrix_operation)(ws)
    product = multi_dot(same_w)
    return product

def pushfwd_(func, weights, tangent):
    return jvp(func, (weights,), (tangent,))

r = 2
c = 2
weights = jnp.ones((r, c))
pushfwd = partial(pushfwd_, cost_function, weights)
# a set of vectors with a single non-zero entry equal to 1
# and same shape as weights
tangents = jnp.reshape(jnp.identity(r*c), (c*r, r, c))
print(vmap(pushfwd)(tangents))  # this breaks with la.expm above
#pushfwd(tangents[0])  # this works with la.expm above
```

closed time in 9 days

BohdanK-1QBitissue commentgoogle/jax

Double jit-compilation in loop when passing plain types

Thanks so much for the crystal-clear explanation!

The `weak_type`

parameter was added in #1709. The PR message has a description, but basically it exists to better follow NumPy-like dtype promotion behavior.

Still thinking about this...
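Roughly, a weakly-typed value (such as one produced from a Python scalar) adopts the dtype of the array it meets instead of promoting the result. A small sketch of that behavior:

```python
import jax.numpy as jnp

# The Python scalar 1.0 is weakly typed: it does not promote the
# float16 array to float32, the way a float32 array operand would.
x = jnp.ones(3, dtype=jnp.float16)
y = x + 1.0  # dtype stays float16
```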

comment created time in 9 days

push eventgoogle/jax

commit sha 752b283abee064c38cdbac0904bb98d31463cb0a

add test for #3164

push time in 9 days

Pull request review commentgoogle/jax

fix while_loop cond function batching

def _while_loop_batching_rule(args, dims, cond_nconsts, cond_jaxpr, body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr( body_jaxpr, size, batched, instantiate=carry_bat) cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(- cond_jaxpr, size, cconst_bat + carry_bat, instantiate=False)+ cond_jaxpr, size, cconst_bat + carry_bat,+ instantiate=bool(cond_jaxpr.out_avals[0].shape))

For posterity: this was the issue, not the change a few lines down.

comment created time in 9 days

PR opened google/jax

fixes #3164

pr created time in 9 days

issue commentgoogle/jax

expm shape tracing issue through two vmaps and jvp

I think this is a bug in our while_loop batching rule (and a while_loop is used in expm, via the fori_loop wrapper)! It's a bit hard to articulate, but I think I see it...
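For readers following along: `fori_loop` is sugar over `while_loop` whose cond function reads a loop counter, so a vmapped `expm` ends up exercising while_loop batching. A minimal sketch of that pattern (illustrative only, not the bug repro):

```python
import jax
import jax.numpy as jnp
from jax import lax

def cube(x):
    # fori_loop lowers to a while_loop; the cond checks the counter
    return lax.fori_loop(0, 3, lambda i, v: v * x, jnp.ones_like(x))

# vmapping batches the underlying while_loop carry
out = jax.vmap(cube)(jnp.arange(1., 4.))
```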

comment created time in 9 days

issue commentgoogle/jax

expm shape tracing issue through two vmaps and jvp

#3056 is specific to reverse-mode (i.e. VJPs) so I don't think it's related.

comment created time in 9 days

push eventgoogle/jax

commit sha f9c978e9d608e373965512ebc498c0a1338af3ed

improve docstring of jax.numpy.broadcast_to (#3173) thanks @joaogui1 !

push time in 9 days

PR merged google/jax

fixes #3168

thanks @joaogui1 !

pr closed time in 9 days

issue closedgoogle/jax

broadcast_to docs are unhelpful

Right now the docs state "Like Numpy’s broadcast_to but doesn’t necessarily return views." A few suggestions

- Link to numpy's docs
- Copy their docs in some way
- Write our own version
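For reference, a minimal usage sketch of the function whose docs are under discussion:

```python
import jax.numpy as jnp

# Like NumPy's broadcast_to, but it need not return a view into the
# input (JAX arrays are immutable, so the distinction rarely matters).
y = jnp.broadcast_to(jnp.arange(3), (2, 3))
```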

closed time in 9 days

joaogui1push eventgoogle/jax

commit sha a4094f72a4435e57e2c52e27085e55659797b54a

revise "Tracer with raw numpy" error message (#3160) * revise "Tracer with raw numpy" error message fixes #3133 * fix f-string typo * fix typo Co-authored-by: James Bradbury <jekbradbury@google.com> Co-authored-by: James Bradbury <jekbradbury@google.com>

push time in 9 days

PR merged google/jax

fixes #3133

pr closed time in 9 days

issue closedgoogle/jax

Tracing a function that indexes into Numpy array gives a poor error message

The following code fails on the last line

```
import jax
import jax.numpy as jnp
import numpy as np

f = lambda i: jnp.zeros((3, 3))[i, :]
g = lambda i: np.zeros((3, 3))[i, :]
a = np.array([1, 2])
f(a) # Okay
jax.jit(f)(a) # Okay
g(a) # Okay
jax.jit(g)(a) # Fail
```

with the standard error message

```
Tracer can't be used with raw numpy functions. You might have
import numpy as np
instead of
import jax.numpy as np
```

The cause of the error is attempting to trace the `__getitem__`

method of a raw numpy tensor. Normally "Tracer can't be used ..." errors are easy to spot because the offending call starts with `np.`

, but this error is a bit more subtle and takes more time to track down. Also, binary operations that mix numpy and JAX arrays work fine, so this is an exceptional case.

Is there any way to improve this error message / detect this case? At the extreme end, could jax do without implementing the `__array__`

method for implicit conversions (and replace with an explicit conversion method), to reduce the mental overhead associated with these conversions?
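One workaround, for reference (a sketch assuming the constant can simply be converted up front):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Converting the raw numpy constant to a jax array first means the
# traced index goes through jax's indexing machinery rather than
# numpy's __getitem__, which would try to call __array__ on the Tracer.
g = lambda i: jnp.asarray(np.zeros((3, 3)))[i, :]
out = jax.jit(g)(np.array([1, 2]))
```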

closed time in 9 days

john-m-jumperPull request review commentgoogle/jax

revise "Tracer with raw numpy" error message

class Tracer(object): __slots__ = ['_trace', '__weakref__'] def __array__(self, *args, **kw):- raise Exception("Tracer can't be used with raw numpy functions. "- "You might have\n"- " import numpy as np\n"- "instead of\n"- " import jax.numpy as jnp")+ msg = ("The numpy.ndarray conversion method __array__() was called on "+ f"the JAX Tracer object {self}.\n\n"+ "This error can occurr when a JAX Tracer object is passed to a raw "

I have never spelled that word correctly on my first try.

comment created time in 9 days

push eventgoogle/jax

commit sha c68f6746de4a672b1b76cc95234e4e9c95417c39

fix typo Co-authored-by: James Bradbury <jekbradbury@google.com>

push time in 9 days