google/xls 379

XLS: Accelerated HW Synthesis

regehr/opt-fuzz 26

llvm opt fuzzer and bounded exhaustive test generator

regehr/llvm-stress 1

fork of llvm-stress

majnemer/clang 0

Mirror of official clang git repository located at http://llvm.org/git/clang. Updated every five minutes.

majnemer/compiler-rt 0

Mirror of official compiler-rt git repository located at http://llvm.org/git/compiler-rt. Updated every five minutes.

majnemer/compiler-tests 0

This repo contains Microsoft compiler-tests to validate Windows platform particulars.

majnemer/docs 0

TensorFlow documentation

majnemer/draft 0

C++ standards drafts

majnemer/jax 0

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

PR opened google/jax

Remove type restrictions

We support half16 on TPU

+0 -1

0 comments

1 changed file

PR created a month ago

push event majnemer/jax

David Majnemer

commit sha 6f5afa3494238b2c0c1a8c5289bede0e1bc3a4f7

Remove type restrictions. We support half16 on TPU

pushed a month ago

PR opened google/jax

Remove type restrictions

We support s8, u8, s16, u16, half16 on TPU

+6 -29

0 comments

1 changed file

PR created a month ago

push event majnemer/jax

David Majnemer

commit sha 6bab85b8d3e347e65d6b5a8bee00951a9b3600ef

Remove type restrictions. We support s8, u8, s16, u16, half16 on TPU

pushed a month ago

PR opened google/jax

Remove type restrictions

dynamic_slice and dynamic_update_slice work with S16 types on TPU.

+0 -2

0 comments

1 changed file

PR created a month ago

push event majnemer/jax

David Majnemer

commit sha c8801fbd9bbb2d6563f7186d26e857c2ed3866ff

Remove type restrictions. dynamic_slice and dynamic_update_slice work with S16 types on TPU.

pushed a month ago

issue comment google/jax

random uniform with dtype bfloat16 crashes on TPU backend

I think this should work now.

levskaya

comment created a month ago

PR opened google/jax

TPUs support half precision arithmetic

+2 -2

0 comments

1 changed file

PR created 2 months ago

create branch majnemer/jax

branch : tpu-half16

branch created 2 months ago

push event majnemer/jax

Matthew Johnson

commit sha f8bab4ae7c6293148be219b3ce0fd963a9c955cf

update version and changelog for pypi

Skye Wanderman-Milne

commit sha 66ba734882daf8206e41c96495a9ea3c53082bb1

Add note to docs describing how pytree arguments work. (#3284) Addresses #3095. I'm not sure if we wanna link to this from API docstrings. This also subsumes the original pytrees notebook.

Ayush Shridhar

commit sha b998044ffed2f36f683c9dce4fbd2914d5b7de44

Add np.polyadd (#3261)

Jake Vanderplas

commit sha 0db57cb541d58f223cabb02bed4aca51a3d876c3

Fix validation code in lax.conv (#3279)

Julius Kunze

commit sha d1dbf7c7d8216158024697abc9db7d3576b22ebd

Implement mask for some primitives + jit. (#2922)
* Implement mask for slice, conv, pad, transpose, where
* Remove tentative mask(jit)
* Add explanatory comment to dot_general masking rule
* Rm reshape from select masking rule
* Rm unnecessary check from lax slice abstract_eval rule
* Revert to standard indentation in masking_test.py
* Begin simplifying masking tests
* Finish drafting masking check function
* More progress simplifying tests
* Add conv masking in batch dim
* Finish fixing up tests
* Revert to old API, making out_shape compulsory again
* More efficient conv masking rule
* Tidy up masking_test imports
* Check that out tree is preserved by masking
* fix flake errors

Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>

Jake Vanderplas

commit sha c77c0838fea4a4673754c5348a56d00bcb7f8bad

deflake jax.numpy and add to flake8 check (#3312)

Skye Wanderman-Milne

commit sha 5ad9feda5f74a88e053a33c5f5330044186b128f

Fix handling of infeed token inside sharded_jit (#3313)

James Bradbury

commit sha 4f5547dd85596bff775cc1ed42c013321833a8b6

Don't AD through max-subtraction in softmax (#2260)
* Don't AD through max-subtraction in softmax
* Also stop-grad the max in logsumexp

Roy Frostig

commit sha dc4c9f045007959b6bcd9c1c97b3f958f09fc706

change cond primitive to an indexed conditional with multiple branch functions

in the core:
* bind and check cond primitive in indexed form
* rewrite abstract evaluation rule
* rewrite translation rule
* rewrite partial evaluation rule
* rewrite batching rule
* rewrite JVP rule
* rewrite transpose rule
* update jaxpr typechecker
* update pretty printer
* update outfeed-usage check
* update reference jaxpr in cond jaxpr test
* update reference regexes in HLO test

in experimental modules:
* update host_callback rewriter
* update loops expression builder
* generalize tf_impl rule

Roy Frostig

commit sha 6015a2a6893af98195a0b121056527e1124ab76f

introduce lax.switch

Roy Frostig

commit sha bd3cab9768370a58d172a5bdb2d39de776189a49

update jaxpr doc to reflect lax.switch and indexed cond

Roy Frostig

commit sha c49bb754543f89fc44bcec2ab4b7824f3b869be0

update changelog with lax.switch

George Necula

commit sha afa9276f0869305afe12cbaec88fe3fb535de807

Implement jax_to_tf.scan (#3307)

George Necula

commit sha 71f1c5cafeab50f0358aa194f0a03a0666e4b4db

Refactoring of jax_to_tf tests: (#3262) (#3308)
* Moved control-flow tests into their own file
* Added a helper module tf_test_util, with a helper function ConvertAndCompare
* Used self.assertAllClose instead of numpy.testing.assert_allclose because the former iterates over lists and tuples (and is standard in other JAX tests)
* Used @parameterized.named_parameters for parameterized tests, for nicer test names.

Jamie Townsend

commit sha c04dea1c84b657d014412a04a5312e7e525b7501

Begin implementing mask(jit)

Jamie Townsend

commit sha dfe3462746d701b08f3d1ee814534f228d2fa199

Add device_put_p abstract_eval rule

Jamie Townsend

commit sha 0f0032727b68e3fc07164c3505ad5c80c9c08503

Implement MaskTrace.post_process_call

Jamie Townsend

commit sha 38d483737d0e35f77ed79fc279574b4b2dc46937

Fix x64 test

Matthew Johnson

commit sha 9c0a58a8e774171e5e465bd81d2fad481c5264fc

add float dtype checks to random.py (#3320) fixes #3317

Jake VanderPlas

commit sha 45444363449ee651386d283235661c4f18339a47

Improve error when zero-sized arrays passed to convolve

pushed 2 months ago

issue comment google/jax

[jax2tf] Incorrect out of bound index handling for lax.dynamic_slice

Can scatter ignore out of bound indices?

tberghammer

comment created 3 months ago

push event majnemer/jax

David Majnemer

commit sha 6653cdbb3b5d5329847c97469fd158519be2417b

Fix lax_reference's round for edge case inputs
- round(8388609) would compute trunc(8388609 + 0.5) == 8388610. Fix this by not modifying sufficiently large inputs.
- round(0.499999970198) would compute trunc(0.499999970198 + 0.5) == 1.0. Fix this by explicitly special casing the first float before 0.5.

pushed 3 months ago
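The two edge cases in the lax_reference round commit can be reproduced with a minimal pure-Python sketch. It assumes round means round-half-away-from-zero computed as trunc(x ± 0.5) in float32 arithmetic, and it emulates float32 by packing through the `struct` module. The helper names `f32`, `naive_round`, and `fixed_round` are hypothetical and for illustration only; this is not the actual lax_reference code.

```python
import math
import struct

# Hypothetical helpers for illustration; not the actual lax_reference code.


def f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


# Largest float32 strictly below 0.5, i.e. 0.5 - 2**-25 (about 0.499999970198).
LARGEST_BELOW_HALF = f32(0.5 - 2.0**-25)

# Every float32 with magnitude >= 2**23 is already an integer.
INTEGRAL_THRESHOLD = 2.0**23


def naive_round(x: float) -> float:
    """Round half away from zero as trunc(x +/- 0.5), with the add rounded to float32."""
    x = f32(x)
    return float(math.trunc(f32(x + math.copysign(0.5, x))))


def fixed_round(x: float) -> float:
    """Same rounding rule, with the two edge cases from the commit message handled."""
    x = f32(x)
    if not math.isfinite(x) or abs(x) >= INTEGRAL_THRESHOLD:
        # Large inputs are already integral; adding 0.5 could round up to the next even integer.
        return x
    if abs(x) == LARGEST_BELOW_HALF:
        # x +/- 0.5 rounds to +/-1.0 in float32, so return a signed zero directly.
        return math.copysign(0.0, x)
    return float(math.trunc(f32(x + math.copysign(0.5, x))))


print(naive_round(8388609.0), fixed_round(8388609.0))  # 8388610.0 8388609.0
print(naive_round(LARGEST_BELOW_HALF), fixed_round(LARGEST_BELOW_HALF))  # 1.0 0.0
```

In the first case, 8388609.5 lies exactly halfway between two adjacent float32 values and ties-to-even picks 8388610. In the second, 1 - 2**-25 is likewise a tie that rounds up to 1.0, which is why that single input needs a special case.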
