Matthew Johnson (mattjj), research scientist @ Google Brain
Google, San Francisco
people.csail.mit.edu/~mattjj

google/jax 11770

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

google-research/dex-lang 852

Research language for array processing in the Haskell/ML family

mattjj/autodidact 583

A pedagogical implementation of Autograd

google/xls 494

XLS: Accelerated HW Synthesis

jacobjinkelly/easy-neural-ode 144

Code for the paper "Learning Differential Equations that are Easy to Solve"

duvenaud/relax 143

Optimizing control variates for black-box gradient estimation

google-research/autoconj 35

Recognizing and exploiting conjugacy without a domain-specific language

duvenaud/jaxde 26

Prototypes of differentiable differential equation solvers in JAX.

mattjj/config-fish 4

my fish configuration

mattjj/config-vim 4

my .vim

issue closed google/jax

Failure to build jaxlib v0.1.62 on Windows

I have just tried to build the newly released jaxlib-v0.1.62 using the following command on a Windows 10 machine with VS2019.

v0.1.61 builds fine with the same command using Bazel v3.7. I have also tried with Bazel v4.0 and receive the same error.

python .\build\build.py --enable_cuda --cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0" --cudnn_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0" --cuda_compute_capabilities="7.5" --cuda_version="11.0" --cudnn_version="8.0.5"

and received the output as listed below.

    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\


Bazel binary path: C:\bazel\bazel.EXE
Python binary path: C:/Users/Adam/anaconda3/python.exe
Python version: 3.8
MKL-DNN enabled: yes
Target CPU features: release
CUDA enabled: yes
CUDA toolkit path: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0
CUDNN library path: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0
CUDA compute capabilities: 7.5
CUDA version: 11.0
CUDNN version: 8.0.5
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
C:\bazel\bazel.EXE run --verbose_failures=true --config=short_logs --config=mkl_open_source_only --config=cuda --define=xla_python_enable_gpu=true :build_wheel -- --output_path=C:\sdks\jax-jaxlib-v0.1.62\dist
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=1 --terminal_columns=80
INFO: Reading rc options for 'run' from c:\sdks\jax-jaxlib-v0.1.62\.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Options provided by the client:
  Inherited 'build' options: --python_path=C:/Users/Adam/anaconda3/python.exe
INFO: Reading rc options for 'run' from c:\sdks\jax-jaxlib-v0.1.62\.bazelrc:
  Inherited 'build' options: --repo_env PYTHON_BIN_PATH=C:/Users/Adam/anaconda3/python.exe --action_env=PYENV_ROOT --python_path=C:/Users/Adam/anaconda3/python.exe --repo_env TF_NEED_CUDA=1 --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 --repo_env TF_NEED_ROCM=0 --action_env TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 --distinct_host_configuration=false -c opt --apple_platform_type=macos --macos_minimum_os=10.9 --announce_rc --define open_source_build=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --spawn_strategy=standalone --strategy=Genrule=standalone --enable_platform_specific_config --define=with_tpu_support=true --action_env CUDA_TOOLKIT_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0 --action_env CUDNN_INSTALL_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0 --action_env TF_CUDA_VERSION=11.0 --action_env TF_CUDNN_VERSION=8.0.5
INFO: Found applicable config definition build:short_logs in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:mkl_open_source_only in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:cuda in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --crosstool_top=@local_config_cuda//crosstool:toolchain --@local_config_cuda//:enable_cuda
INFO: Found applicable config definition build:windows in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --copt=/D_USE_MATH_DEFINES --host_copt=/D_USE_MATH_DEFINES --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN --copt=-DNOGDI --host_copt=-DNOGDI --copt=/Zc:preprocessor --cxxopt=/std:c++14 --host_cxxopt=/std:c++14 --linkopt=/DEBUG --host_linkopt=/DEBUG --linkopt=/OPT:REF --host_linkopt=/OPT:REF --linkopt=/OPT:ICF --host_linkopt=/OPT:ICF --experimental_strict_action_env=true
ERROR: C:/users/adam/_bazel_adam/nzquhzn2/external/org_tensorflow/tensorflow/core/common_runtime/BUILD:1647:16: Illegal ambiguous match on configurable attribute "deps" in @org_tensorflow//tensorflow/core/common_runtime:core_cpu_internal:
@org_tensorflow//tensorflow:windows
@org_tensorflow//tensorflow:with_tpu_support
Multiple matches are not allowed unless one is unambiguously more specialized.
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: C:/users/adam/_bazel_adam/nzquhzn2/external/org_tensorflow/tensorflow/core/common_runtime/BUILD:1647:16: Illegal ambiguous match on configurable attribute "deps" in @org_tensorflow//tensorflow/core/common_runtime:core_cpu_internal:
@org_tensorflow//tensorflow:windows
@org_tensorflow//tensorflow:with_tpu_support
Multiple matches are not allowed unless one is unambiguously more specialized.
INFO: Elapsed time: 1.698s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (1 packages loaded, 184 targets configured)
Traceback (most recent call last):
  File ".\build\build.py", line 516, in <module>
    main()
  File ".\build\build.py", line 511, in main
    shell(command)
  File ".\build\build.py", line 51, in shell
    output = subprocess.check_output(cmd)
  File "C:\Users\Adam\anaconda3\lib\subprocess.py", line 411, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "C:\Users\Adam\anaconda3\lib\subprocess.py", line 512, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['C:\\bazel\\bazel.EXE', 'run', '--verbose_failures=true', '--config=short_logs', '--config=mkl_open_source_only', '--config=cuda', '--define=xla_python_enable_gpu=true', ':build_wheel', '--', '--output_path=C:\\sdks\\jax-jaxlib-v0.1.62\\dist']' returned non-zero exit status 1.

closed time in 7 minutes

oracle3001

issue comment google/jax

Failure to build jaxlib v0.1.62 on Windows

This should be fixed via https://github.com/google/jax/pull/5983 (the issue was that jaxlib doesn't like when you try to build for GPU and TPU at the same time). Pull and try again? (I'm gonna close this because I think it's fixed, but please lemme know if that's not the case)
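To make the Bazel error above easier to parse, here is a minimal Starlark sketch, under the assumption that the failing attribute uses an ordinary select(); the dependency labels are purely illustrative and this is not the actual TensorFlow BUILD rule:

# Illustrative only: a select() whose conditions can both match at once.
# On a Windows build with --define=with_tpu_support=true, Bazel cannot pick a
# winner because neither config_setting specializes the other, which is the
# "Illegal ambiguous match" reported in the log above.
deps = select({
    "@org_tensorflow//tensorflow:windows": [":windows_only_deps"],
    "@org_tensorflow//tensorflow:with_tpu_support": [":tpu_deps"],
    "//conditions:default": [":common_deps"],
})

Dropping the TPU define for GPU wheels, as the linked PR does, removes the second matching condition and with it the ambiguity.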

oracle3001

comment created time in 7 minutes

Pull request review comment google/jax

Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.

 _os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
 del _os
+# Set Cloud TPU env vars if necessary before transitively loading C++ backend
+from .cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
+_cloud_tpu_init()

Good idea, done. I made it raise a warning with the exception, because in theory this should never fail (like so much of jax!). Lemme know if you think it should be silent instead.
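For readers following along, a minimal sketch of the guarded initialization being described, assuming the import path shown in the diff above; the wrapper name and warning text here are illustrative rather than the actual jax source:

import warnings

def _cloud_tpu_init_with_warning():
  # Guarded so that a failure here can never make `import jax` raise;
  # problems surface as a warning instead of an exception.
  try:
    from jax.cloud_tpu_init import cloud_tpu_init  # path assumed from the diff above
    cloud_tpu_init()
  except Exception as exc:
    warnings.warn(f"cloud_tpu_init() failed ({exc!r}); "
                  "Cloud TPU env vars were not set automatically.")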

skye

comment created time in 18 minutes

PR merged google/jax

Reviewers
Don't build with Cloud TPU support for GPU or Mac wheels. cla: yes pull ready

This should make our builds simpler and less failure-prone.

+9 -4

0 comment

2 changed files

skye

pr closed time in 30 minutes

push event google/jax

Skye Wanderman-Milne

commit sha 7b7a2a1ba83fec86f6fa6e4526a12c5e33f80ebb

Don't build with Cloud TPU support for GPU or Mac wheels. This should make our builds simpler and less failure-prone.

view details

jax authors

commit sha a61c43c32ae5e63bdb672a4c76630d2b717cf2c5

Merge pull request #5983 from skye:enable_tpu PiperOrigin-RevId: 361692916

view details

push time in 30 minutes

Pull request review comment google/jax

Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.

 _os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
 del _os
+# Set Cloud TPU env vars if necessary before transitively loading C++ backend
+from .cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
+_cloud_tpu_init()

Would it be worth wrapping this in a try/except? I'm worried this could make jax unimportable if something unanticipated goes wrong.

skye

comment created time in an hour

PR opened google/jax

Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.

This removes the need to manually set these env vars when running on a Cloud TPU pod slice.

+110 -0

0 comment

3 changed files

pr created time in an hour

PR merged google/jax

DOC: add transformations doc to HTML & reorganize contents cla: yes pull ready

This seems to fit better in the actual documentation, where the code can be doctested, which actually caught some existing mistakes.

This is a first-step toward making the github README a bit lighter and making the HTML docs a bit more organized and useful. As part of that, I re-organized our documentation to separate API docs from the advanced tutorials. I think some further reorganization based on this will be useful.

+314 -23

0 comment

5 changed files

jakevdp

pr closed time in an hour

push event google/jax

Jake VanderPlas

commit sha 749ad95514c8068afa71632bd0d917207c34f1fb

DOC: add transformations doc to HTML & reorganize contents

view details

jax authors

commit sha 591a484247a8af0b14aac8eac32f15c769259ea7

Merge pull request #5908 from jakevdp:transforms-doc PiperOrigin-RevId: 361685827

view details

push time in an hour

Pull request review comment google/jax

DOC: add transformations doc to HTML & reorganize contents

 Indices and tables
 * :ref:`genindex`
 * :ref:`modindex`
 * :ref:`search`
+
+
+.. _Autograd: https://github.com/hips/autograd)

Done!

jakevdp

comment created time in an hour

PR opened google/jax

Reviewers
Don't build with Cloud TPU support for GPU or Mac wheels.

This should make our builds simpler and less failure-prone.

+9 -4

0 comment

2 changed files

pr created time in an hour

push event google/jax

Skye Wanderman-Milne

commit sha 902038a71803bc5323d5c55f42394be642614639

Revert breaking change: Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM. This removes the need to manually set these env vars when running on a Cloud TPU pod slice. PiperOrigin-RevId: 361681134

view details

push time in 2 hours

PR merged google/jax

Reviewers
Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM. cla: yes pull ready

This removes the need to manually set these env vars when running on a Cloud TPU pod slice.

I manually tested this on a v2-8 and v2-32.

+110 -0

2 comments

3 changed files

skye

pr closed time in 2 hours

push event google/jax

Skye Wanderman-Milne

commit sha 5a2859e1b6e306b9483beba34946ea82b679e1a0

Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM. This removes the need to manually set these env vars when running on a Cloud TPU pod slice.

view details

jax authors

commit sha 35214a1c0d1b4da048f2c4c8150dd7440857b0ae

Merge pull request #5962 from skye:cloud_tpu_env_vars PiperOrigin-RevId: 361671219

view details

push time in 2 hours

PR merged google/jax

Initial commit of seventh JAX-101 notebook cla: yes pull ready
+714 -0

0 comment

3 changed files

jakevdp

pr closed time in 2 hours

push event google/jax

Jake VanderPlas

commit sha 35933238d1e40bae2564e2e4f639b3ef2e6ae7b1

Initial commit of seventh JAX-101 notebook

view details

jax authors

commit sha 93d477935a0c3f829fd3c7c337d35cb72ee6844f

Merge pull request #5980 from jakevdp:jax-101 PiperOrigin-RevId: 361671022

view details

push time in 2 hours

issue opened google/jax

jax.experimental.enable_x64 fails in JIT context

minimal repro:

from jax import jit, experimental

@jit
def f(a):
  with experimental.enable_x64():
    return 1 + a
f(1)
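
For comparison, a sketch of the same computation with the context manager applied around the jit call instead of inside the traced function; this is only meant to illustrate the contrasting usage pattern, not a workaround confirmed in this issue:

from jax import jit, experimental

def f(a):
  return 1 + a

with experimental.enable_x64():
  print(jit(f)(1))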

created time in 3 hours

Pull request review comment google/jax

Initial commit of seventh JAX-101 notebook

+---+jupytext:+  formats: ipynb,md:myst+  text_representation:+    extension: .md+    format_name: myst+    format_version: 0.13+    jupytext_version: 1.10.0+kernelspec:+  display_name: Python 3+  name: python3+---+++++ {"id": "Ga0xSM8xhBIm"}++# Stateful computations++[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/07-state.ipynb)++*Authors: Vladimir Mikulik*++This section explores how JAX constrains the implementation of stateful programs.+++++ {"id": "Avjnyrjojo8z"}++## Motivation++In machine learning, program state most often comes in the form of:+* model parameters,+* optimizer state, and+* stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization).++Some JAX transformations, most notably `jax.jit`, impose constraints on the functions they transform. In particular, the function transformed by `jax.jit` must have no side-effects. This is because any such side-effects will only be executed once, when the python version of the function is run during compilation. These side-effects will not be executed by the compiled function on subsequent runs.++Changing program state is one kind of side-effect. So, if we can't have side effects, how do we update model parameters, the optimizer state, and use stateful layers in our models? This colab will explain this in detail, but the short answer is: with [functional programming](https://en.wikipedia.org/wiki/Functional_programming).++```{code-cell} ipython3+:id: S-n8XRC2kTnI++import jax+import jax.numpy as jnp+```+++++ {"id": "s_-6semKkSzp"}++## A simple example: Counter++Let's start by looking at a simple stateful program: a counter.++```{code-cell} ipython3+:id: B3aoCHpjg8gm+:outputId: 5cbcfbf5-5c42-498f-a175-050438518337++class Counter:+  """A simple counter."""++  def __init__(self):+    self.n = 0++  def count(self) -> int:+    """Increments the counter and returns the new value."""+    self.n += 1+    return self.n++  def reset(self):+    """Resets the counter to zero."""+    self.n = 0+++counter = Counter()++for _ in range(3):+  print(counter.count())+```+++++ {"id": "SQ-RNLfdiw04"}++The `n` attribute maintains the counter's _state_ between successive calls of `count`. It is modified as a side effect of calling `count`.++Let's say we want to count fast, so we `jax.jit` the `count` method. (In this example, this wouldn't actually help speed anyway, for many reasons, but treat this as a toy model of wanting to JIT-compile the update of model parameters, where `jax.jit` makes an enormous difference).++```{code-cell} ipython3+:id: 5jSjmJMon03W+:outputId: d952f16b-9b30-4753-ed94-cc914a929a36++counter.reset()+fast_count = jax.jit(counter.count)++for _ in range(3):+  print(fast_count())+```+++++ {"id": "weiI0V7_pKGv"}++Oh no! Our counter isn't working. This is because the line+```+self.n += 1+```+in `count` is only called once, when JAX compiles the method call. Moreover, since the return value doesn't depend on the arguments to `count`, once it returns the first 1, subsequent calls to `fast_count` will always return 1. This won't do. So, how do we fix it?++## The solution: explicit state++Part of the problem with our counter was that the returned value didn't depend on the arguments, meaning a constant was "baked into" the compiled output. But it shouldn't be a constant -- it should depend on the state. 
Well, then why don't we make the state into an argument?++```{code-cell} ipython3+:id: 53pSdK4KoOEZ+:outputId: 5ac72b9c-7029-4bf2-de8d-1d412bd74c79++from typing import Tuple++CounterState = int++class CounterV2:++  def count(self, n: CounterState) -> Tuple[int, CounterState]:+    # You could just return n+1, but here we separate its role as +    # the output and as the counter state for didactic purposes.+    return n+1, n+1++  def reset(self) -> CounterState:+    return 0++counter = CounterV2()+state = counter.reset()++for _ in range(3):+  value, state = counter.count(state)+  print(value)+```+++++ {"id": "PrBjmgZtq89b"}++In this new version of `Counter`, we moved `n` to be an argument of `count`, and added another return value that represents the new, updated, state. To use this counter, we now need to keep track of the state explicitly. But in return, we can now safely `jax.jit` this counter:++```{code-cell} ipython3+:id: LO4Xzcq_q8PH+:outputId: 25c06a56-f2bf-4c54-a3c3-6e093d484362++state = counter.reset()+fast_count = jax.jit(counter.count)++for _ in range(3):+  value, state = fast_count(state)+  print(value)+```+++++ {"id": "MzMSWD2_sgnh"}++## A general strategy++We can apply the same process to any stateful method to convert it into a stateless one. We took a class of the form++```python+class StatefulClass++  state: State++  def stateful_method(*args, **kwargs) -> Output:+```++and turned it into a class of the form++```python+class StatelessClass++  def stateless_method(state: State, *args, **kwargs) -> (Output, State):+```++This is a common [functional programming](https://en.wikipedia.org/wiki/Functional_programming) pattern, and, essentially, is the way that state is handled in all JAX programs.++Notice that the need for a class becomes less clear once we have rewritten it this way. We could just keep `stateless_method`, since the class is no longer doing any work. This is because, like the strategy we just applied, object-oriented programming (OOP) is a way to help programmers understand program state. ++In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?++Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey.+++++ {"id": "I2SqRx14_z98"}++## Simple worked example: Linear Regression++Let's apply this strategy to a simple machine learning model: linear regression via gradient descent.++Here, we only deal with one kind of state: the model parameters. 
But generally, you'll see many kinds of state being threaded in and out of JAX functions, like optimizer state, layer statistics for batchnorm, and others.++The function to look at carefully is `update`.++```{code-cell} ipython3+:id: wQdU7DoAseW6++from typing import NamedTuple++class Params(NamedTuple):+  weight: jnp.ndarray+  bias: jnp.ndarray+++def init(rng) -> Params:+  """Returns the initial model params."""+  weights_key, bias_key = jax.random.split(rng)+  weight = jax.random.normal(weights_key, ())+  bias = jax.random.normal(bias_key, ())+  return Params(weight, bias)+++def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:+  """Computes the least squares error of the model's predictions on x against y."""+  pred = params.weight * x + params.bias+  return jnp.mean((pred - y) ** 2)+++LEARNING_RATE = 0.005++@jax.jit+def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:+  """Performs one SGD update step on params using the given data."""+  grad = jax.grad(loss)(params, x, y)++  # If we were using Adam or another stateful optimizer,+  # we would also do something like+  # ```+  # updates, new_optimizer_state = optimizer(grad, optimizer_state)+  # ```+  # and then use `updates` instead of `grad` to actually update the params.+  # (And we'd include `new_optimizer_state` in the output, naturally.)++  new_params = jax.tree_multimap(+      lambda param, g: param - g * LEARNING_RATE, params, grad)++  return new_params+```+++++ {"id": "dKySWouu2-Hu"}++Notice that we manually pipe the params in and out of the update function.++```{code-cell} ipython3+:id: jQCYYy0yxO6K+:outputId: 1f3b69d2-e90b-4065-cbcc-6422978d25c2++import matplotlib.pyplot as plt++rng = jax.random.PRNGKey(42)++# Generate true data from y = w*x + b + noise+true_w, true_b = 2, -1+x_rng, noise_rng = jax.random.split(rng)+xs = jax.random.normal(x_rng, (128, 1))+noise = jax.random.normal(noise_rng, (128, 1)) * 0.5+ys = xs * true_w + true_b + noise++# Fit regression+params = init(rng)+for _ in range(1000):+  params = update(params, xs, ys)++plt.scatter(xs, ys)+plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')+plt.legend()+plt.show()+```+++++ {"id": "1wq3L6Xg1UHP"}++## Taking it further++The strategy described above is how any (jitted) JAX program must handle state. ++Handling parameters manually seems fine if you're dealing with two parameters, but what if it's a neural net with dozens of layers? You might already be getting worried about two things:++1) Are we supposed to initialize them all manually, essentially repeating what we already write in the forward pass definition?++2) Are we supposed to pipe all these things around manually?++The details can be tricky to handle, but there are examples of libraries that take care of this for you. One example+is [Haiku](https://github.com/deepmind/dm-haiku), which allows you to write OOP-style neural nets without explicitly tracking their parameters, and then to transform them into functional programs.

Great idea - done!

jakevdp

comment created time in 3 hours

Pull request review comment google/jax

Initial commit of seventh JAX-101 notebook

+---+jupytext:+  formats: ipynb,md:myst+  text_representation:+    extension: .md+    format_name: myst+    format_version: 0.13+    jupytext_version: 1.10.0+kernelspec:+  display_name: Python 3+  name: python3+---+++++ {"id": "Ga0xSM8xhBIm"}++# Stateful computations++[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/07-state.ipynb)++*Authors: Vladimir Mikulik*++This section explores how JAX constrains the implementation of stateful programs.+++++ {"id": "Avjnyrjojo8z"}++## Motivation++In machine learning, program state most often comes in the form of:+* model parameters,+* optimizer state, and+* stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization).++Some JAX transformations, most notably `jax.jit`, impose constraints on the functions they transform. In particular, the function transformed by `jax.jit` must have no side-effects. This is because any such side-effects will only be executed once, when the python version of the function is run during compilation. These side-effects will not be executed by the compiled function on subsequent runs.++Changing program state is one kind of side-effect. So, if we can't have side effects, how do we update model parameters, the optimizer state, and use stateful layers in our models? This colab will explain this in detail, but the short answer is: with [functional programming](https://en.wikipedia.org/wiki/Functional_programming).++```{code-cell} ipython3+:id: S-n8XRC2kTnI++import jax+import jax.numpy as jnp+```+++++ {"id": "s_-6semKkSzp"}++## A simple example: Counter++Let's start by looking at a simple stateful program: a counter.++```{code-cell} ipython3+:id: B3aoCHpjg8gm+:outputId: 5cbcfbf5-5c42-498f-a175-050438518337++class Counter:+  """A simple counter."""++  def __init__(self):+    self.n = 0++  def count(self) -> int:+    """Increments the counter and returns the new value."""+    self.n += 1+    return self.n++  def reset(self):+    """Resets the counter to zero."""+    self.n = 0+++counter = Counter()++for _ in range(3):+  print(counter.count())+```+++++ {"id": "SQ-RNLfdiw04"}++The `n` attribute maintains the counter's _state_ between successive calls of `count`. It is modified as a side effect of calling `count`.++Let's say we want to count fast, so we `jax.jit` the `count` method. (In this example, this wouldn't actually help speed anyway, for many reasons, but treat this as a toy model of wanting to JIT-compile the update of model parameters, where `jax.jit` makes an enormous difference).++```{code-cell} ipython3+:id: 5jSjmJMon03W+:outputId: d952f16b-9b30-4753-ed94-cc914a929a36++counter.reset()+fast_count = jax.jit(counter.count)++for _ in range(3):+  print(fast_count())+```+++++ {"id": "weiI0V7_pKGv"}++Oh no! Our counter isn't working. This is because the line+```+self.n += 1+```+in `count` is only called once, when JAX compiles the method call. Moreover, since the return value doesn't depend on the arguments to `count`, once it returns the first 1, subsequent calls to `fast_count` will always return 1. This won't do. So, how do we fix it?++## The solution: explicit state++Part of the problem with our counter was that the returned value didn't depend on the arguments, meaning a constant was "baked into" the compiled output. But it shouldn't be a constant -- it should depend on the state. 
Well, then why don't we make the state into an argument?++```{code-cell} ipython3+:id: 53pSdK4KoOEZ+:outputId: 5ac72b9c-7029-4bf2-de8d-1d412bd74c79++from typing import Tuple++CounterState = int++class CounterV2:++  def count(self, n: CounterState) -> Tuple[int, CounterState]:+    # You could just return n+1, but here we separate its role as +    # the output and as the counter state for didactic purposes.+    return n+1, n+1++  def reset(self) -> CounterState:+    return 0++counter = CounterV2()+state = counter.reset()++for _ in range(3):+  value, state = counter.count(state)+  print(value)+```+++++ {"id": "PrBjmgZtq89b"}++In this new version of `Counter`, we moved `n` to be an argument of `count`, and added another return value that represents the new, updated, state. To use this counter, we now need to keep track of the state explicitly. But in return, we can now safely `jax.jit` this counter:++```{code-cell} ipython3+:id: LO4Xzcq_q8PH+:outputId: 25c06a56-f2bf-4c54-a3c3-6e093d484362++state = counter.reset()+fast_count = jax.jit(counter.count)++for _ in range(3):+  value, state = fast_count(state)+  print(value)+```+++++ {"id": "MzMSWD2_sgnh"}++## A general strategy++We can apply the same process to any stateful method to convert it into a stateless one. We took a class of the form++```python+class StatefulClass++  state: State++  def stateful_method(*args, **kwargs) -> Output:+```++and turned it into a class of the form++```python+class StatelessClass++  def stateless_method(state: State, *args, **kwargs) -> (Output, State):+```++This is a common [functional programming](https://en.wikipedia.org/wiki/Functional_programming) pattern, and, essentially, is the way that state is handled in all JAX programs.++Notice that the need for a class becomes less clear once we have rewritten it this way. We could just keep `stateless_method`, since the class is no longer doing any work. This is because, like the strategy we just applied, object-oriented programming (OOP) is a way to help programmers understand program state. ++In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?++Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey.+++++ {"id": "I2SqRx14_z98"}++## Simple worked example: Linear Regression++Let's apply this strategy to a simple machine learning model: linear regression via gradient descent.++Here, we only deal with one kind of state: the model parameters. 
But generally, you'll see many kinds of state being threaded in and out of JAX functions, like optimizer state, layer statistics for batchnorm, and others.++The function to look at carefully is `update`.++```{code-cell} ipython3+:id: wQdU7DoAseW6++from typing import NamedTuple++class Params(NamedTuple):+  weight: jnp.ndarray+  bias: jnp.ndarray+++def init(rng) -> Params:+  """Returns the initial model params."""+  weights_key, bias_key = jax.random.split(rng)+  weight = jax.random.normal(weights_key, ())+  bias = jax.random.normal(bias_key, ())+  return Params(weight, bias)+++def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:+  """Computes the least squares error of the model's predictions on x against y."""+  pred = params.weight * x + params.bias+  return jnp.mean((pred - y) ** 2)+++LEARNING_RATE = 0.005++@jax.jit+def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:+  """Performs one SGD update step on params using the given data."""+  grad = jax.grad(loss)(params, x, y)++  # If we were using Adam or another stateful optimizer,+  # we would also do something like+  # ```+  # updates, new_optimizer_state = optimizer(grad, optimizer_state)+  # ```+  # and then use `updates` instead of `grad` to actually update the params.+  # (And we'd include `new_optimizer_state` in the output, naturally.)++  new_params = jax.tree_multimap(+      lambda param, g: param - g * LEARNING_RATE, params, grad)++  return new_params+```+++++ {"id": "dKySWouu2-Hu"}++Notice that we manually pipe the params in and out of the update function.++```{code-cell} ipython3+:id: jQCYYy0yxO6K+:outputId: 1f3b69d2-e90b-4065-cbcc-6422978d25c2++import matplotlib.pyplot as plt++rng = jax.random.PRNGKey(42)++# Generate true data from y = w*x + b + noise+true_w, true_b = 2, -1+x_rng, noise_rng = jax.random.split(rng)+xs = jax.random.normal(x_rng, (128, 1))+noise = jax.random.normal(noise_rng, (128, 1)) * 0.5+ys = xs * true_w + true_b + noise++# Fit regression+params = init(rng)+for _ in range(1000):+  params = update(params, xs, ys)++plt.scatter(xs, ys)+plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')+plt.legend()+plt.show()+```+++++ {"id": "1wq3L6Xg1UHP"}++## Taking it further++The strategy described above is how any (jitted) JAX program must handle state. ++Handling parameters manually seems fine if you're dealing with two parameters, but what if it's a neural net with dozens of layers? You might already be getting worried about two things:++1) Are we supposed to initialize them all manually, essentially repeating what we already write in the forward pass definition?++2) Are we supposed to pipe all these things around manually?++The details can be tricky to handle, but there are examples of libraries that take care of this for you. One example+is [Haiku](https://github.com/deepmind/dm-haiku), which allows you to write OOP-style neural nets without explicitly tracking their parameters, and then to transform them into functional programs.

Instead of calling out Haiku here, how about we link to https://github.com/google/jax#neural-network-libraries instead.

jakevdp

comment created time in 3 hours

Pull request review comment google/jax

Initial commit of seventh JAX-101 notebook

+{+ "cells": [+  {+   "cell_type": "markdown",+   "metadata": {+    "id": "Ga0xSM8xhBIm"+   },+   "source": [+    "# Stateful computations\n",+    "\n",+    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/07-state.ipynb)\n",+    "\n",+    "*Authors: Vladimir Mikulik*\n",+    "\n",+    "This section explores how JAX constrains the implementation of stateful programs."+   ]+  },+  {+   "cell_type": "markdown",+   "metadata": {+    "id": "Avjnyrjojo8z"+   },+   "source": [+    "## Motivation\n",+    "\n",+    "In machine learning, program state most often comes in the form of\n",+    "* model parameters,\n",+    "* optimizer state, and\n",+    "* stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization).\n",+    "\n",+    "Some JAX transformations, most notably `jax.jit`, impose constraints on the functions they transform. In particular, the function transformed by `jax.jit` must have no side-effects. This is because any such side-effects will only be executed once, when the python version of the function is run during compilation. These side-effects will not be executed by the compiled function on subsequent runs.\n",+    "\n",+    "Changing program state is one kind of side-effect. So, if we can't have side effects, how do we update model parameters, the optimizer state, and use stateful layers in our models? This colab will explain this in detail, but the short answer is: with [functional programming](https://en.wikipedia.org/wiki/Functional_programming)."+   ]+  },+  {+   "cell_type": "code",+   "execution_count": 1,+   "metadata": {+    "id": "S-n8XRC2kTnI"+   },+   "outputs": [],+   "source": [+    "import jax\n",+    "import jax.numpy as jnp"+   ]+  },+  {+   "cell_type": "markdown",+   "metadata": {+    "id": "s_-6semKkSzp"+   },+   "source": [+    "## A simple example: Counter\n",+    "\n",+    "Let's start by looking at a simple stateful program: a counter."+   ]+  },+  {+   "cell_type": "code",+   "execution_count": 2,+   "metadata": {+    "id": "B3aoCHpjg8gm",+    "outputId": "5cbcfbf5-5c42-498f-a175-050438518337"+   },+   "outputs": [+    {+     "name": "stdout",+     "output_type": "stream",+     "text": [+      "1\n",+      "2\n",+      "3\n"+     ]+    }+   ],+   "source": [+    "class Counter:\n",+    "  \"\"\"A simple counter.\"\"\"\n",+    "\n",+    "  def __init__(self):\n",+    "    self.n = 0\n",+    "\n",+    "  def count(self) -> int:\n",+    "    \"\"\"Increments the counter and returns the new value.\"\"\"\n",+    "    self.n += 1\n",+    "    return self.n\n",+    "\n",+    "  def reset(self):\n",+    "    \"\"\"Resets the counter to zero.\"\"\"\n",+    "    self.n = 0\n",+    "\n",+    "\n",+    "counter = Counter()\n",+    "\n",+    "for _ in range(3):\n",+    "  print(counter.count())"+   ]+  },+  {+   "cell_type": "markdown",+   "metadata": {+    "id": "SQ-RNLfdiw04"+   },+   "source": [+    "The `n` attribute maintains the counter's _state_ between successive calls of `count`. It is modified as a side effect of calling `count`.\n",+    "\n",+    "Let's say we want to count fast, so we `jax.jit` the `count` method. (In this example, this wouldn't actually help speed anyway, for many reasons, but treat this as a toy model of wanting to jit-compile the update of model parameters, where `jax.jit` makes an enormous difference)."

Suggestion:

    "Let's say we want to count fast, so we `jax.jit` the `count` method. (In this example, this wouldn't actually help speed anyway, for many reasons, but treat this as a toy model of wanting to JIT-compile the update of model parameters, where `jax.jit` makes an enormous difference)."
jakevdp

comment created time in 4 hours

Pull request review comment google/jax

Initial commit of seventh JAX-101 notebook

+---
+jupytext:
+  formats: ipynb,md:myst
+  text_representation:
+    extension: .md
+    format_name: myst
+    format_version: 0.13
+    jupytext_version: 1.10.0
+kernelspec:
+  display_name: Python 3
+  name: python3
+---
+
+++++ {"id": "Ga0xSM8xhBIm"}
+
+# Stateful computations
+
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/07-state.ipynb)
+
+*Authors: Vladimir Mikulik*
+
+This section explores how JAX constrains the implementation of stateful programs.
+
+++++ {"id": "Avjnyrjojo8z"}
+
+## Motivation
+
+In machine learning, program state most often comes in the form of
+* model parameters,
+* optimizer state, and
+* stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization).

Suggestion: nit

In machine learning, program state most often comes in the form of:
* model parameters,
* optimizer state, and
* stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization).
jakevdp

comment created time in 5 hours

Pull request review comment google/jax

Initial commit of seventh JAX-101 notebook

+---+jupytext:+  formats: ipynb,md:myst+  text_representation:+    extension: .md+    format_name: myst+    format_version: 0.13+    jupytext_version: 1.10.0+kernelspec:+  display_name: Python 3+  name: python3+---+++++ {"id": "Ga0xSM8xhBIm"}++# Stateful computations++[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/07-state.ipynb)++*Authors: Vladimir Mikulik*++This section explores how JAX constrains the implementation of stateful programs.+++++ {"id": "Avjnyrjojo8z"}++## Motivation++In machine learning, program state most often comes in the form of+* model parameters,+* optimizer state, and+* stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization).++Some JAX transformations, most notably `jax.jit`, impose constraints on the functions they transform. In particular, the function transformed by `jax.jit` must have no side-effects. This is because any such side-effects will only be executed once, when the python version of the function is run during compilation. These side-effects will not be executed by the compiled function on subsequent runs.++Changing program state is one kind of side-effect. So, if we can't have side effects, how do we update model parameters, the optimizer state, and use stateful layers in our models? This colab will explain this in detail, but the short answer is: with [functional programming](https://en.wikipedia.org/wiki/Functional_programming).+++```{code-cell} ipython3+:id: S-n8XRC2kTnI++import jax+import jax.numpy as jnp+```+++++ {"id": "s_-6semKkSzp"}++++## A simple example: Counter++Let's start by looking at a simple stateful program: a counter.++```{code-cell} ipython3+---+id: B3aoCHpjg8gm+outputId: 5cbcfbf5-5c42-498f-a175-050438518337+---+class Counter:+  """A simple counter."""++  def __init__(self):+    self.n = 0++  def count(self) -> int:+    """Increments the counter and returns the new value."""+    self.n += 1+    return self.n++  def reset(self):+    """Resets the counter to zero."""+    self.n = 0+++counter = Counter()++for _ in range(3):+  print(counter.count())+```+++++ {"id": "SQ-RNLfdiw04"}++The `n` attribute maintains the counter's _state_ between successive calls of `count`. It is modified as a side effect of calling `count`.++Let's say we want to count fast, so we `jax.jit` the `count` method. (In this example, this wouldn't actually help speed anyway, for many reasons, but treat this as a toy model of wanting to jit-compile the update of model parameters, where `jax.jit` makes an enormous difference).++```{code-cell} ipython3+---+id: 5jSjmJMon03W+outputId: d952f16b-9b30-4753-ed94-cc914a929a36+---+counter.reset()+fast_count = jax.jit(counter.count)++for _ in range(3):+  print(fast_count())+```+++++ {"id": "weiI0V7_pKGv"}++Oh no! Our counter isn't working. This is because the line+```+self.n += 1+```+in `count` was only called once, when JAX was compiling the method call. Moreover, since the return value doesn't depend on the arguments to `count`, once it returned the first 1, subsequent calls to `fast_count` will always return 1. This won't do. So, how do we fix it?++## The solution: explicit state++Part of the problem with our counter was that the returned value didn't depend on the arguments, meaning a constant was "baked into" the compiled output. But it shouldn't be a constant -- it should depend on the state. 
Well, then why don't we make the state into an argument?++```{code-cell} ipython3+---+id: 53pSdK4KoOEZ+outputId: 5ac72b9c-7029-4bf2-de8d-1d412bd74c79+---+from typing import Tuple++CounterState = int++class CounterV2:++  def count(self, n: CounterState) -> Tuple[int, CounterState]:+    # You could just return n+1, but here we separate its role as +    # the output and as the counter state for didactic purposes.+    return n+1, n+1++  def reset(self) -> CounterState:+    return 0++counter = CounterV2()+state = counter.reset()++for _ in range(3):+  value, state = counter.count(state)+  print(value)+```+++++ {"id": "PrBjmgZtq89b"}++In this new version of Counter, we moved `n` to be an argument of `count`, and added another return value that represents the new, updated, state. To use this counter, we now need to keep track of the state explicitly. But in return, we can now safely `jax.jit` this counter:

Suggestion: nit (formatting)

In this new version of `Counter`, we moved `n` to be an argument of `count`, and added another return value that represents the new, updated, state. To use this counter, we now need to keep track of the state explicitly. But in return, we can now safely `jax.jit` this counter:

(similar to "In our case, the CounterV2 class is nothin..." further below.)

jakevdp

comment created time in 4 hours

Pull request review comment google/jax

Initial commit of seventh JAX-101 notebook

+---+jupytext:+  formats: ipynb,md:myst+  text_representation:+    extension: .md+    format_name: myst+    format_version: 0.13+    jupytext_version: 1.10.0+kernelspec:+  display_name: Python 3+  name: python3+---+++++ {"id": "Ga0xSM8xhBIm"}++# Stateful computations++[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/07-state.ipynb)++*Authors: Vladimir Mikulik*++This section explores how JAX constrains the implementation of stateful programs.+++++ {"id": "Avjnyrjojo8z"}++## Motivation++In machine learning, program state most often comes in the form of+* model parameters,+* optimizer state, and+* stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization).++Some JAX transformations, most notably `jax.jit`, impose constraints on the functions they transform. In particular, the function transformed by `jax.jit` must have no side-effects. This is because any such side-effects will only be executed once, when the python version of the function is run during compilation. These side-effects will not be executed by the compiled function on subsequent runs.++Changing program state is one kind of side-effect. So, if we can't have side effects, how do we update model parameters, the optimizer state, and use stateful layers in our models? This colab will explain this in detail, but the short answer is: with [functional programming](https://en.wikipedia.org/wiki/Functional_programming).+++```{code-cell} ipython3+:id: S-n8XRC2kTnI++import jax+import jax.numpy as jnp+```+++++ {"id": "s_-6semKkSzp"}++++## A simple example: Counter++Let's start by looking at a simple stateful program: a counter.++```{code-cell} ipython3+---+id: B3aoCHpjg8gm+outputId: 5cbcfbf5-5c42-498f-a175-050438518337+---+class Counter:+  """A simple counter."""++  def __init__(self):+    self.n = 0++  def count(self) -> int:+    """Increments the counter and returns the new value."""+    self.n += 1+    return self.n++  def reset(self):+    """Resets the counter to zero."""+    self.n = 0+++counter = Counter()++for _ in range(3):+  print(counter.count())+```+++++ {"id": "SQ-RNLfdiw04"}++The `n` attribute maintains the counter's _state_ between successive calls of `count`. It is modified as a side effect of calling `count`.++Let's say we want to count fast, so we `jax.jit` the `count` method. (In this example, this wouldn't actually help speed anyway, for many reasons, but treat this as a toy model of wanting to jit-compile the update of model parameters, where `jax.jit` makes an enormous difference).++```{code-cell} ipython3+---+id: 5jSjmJMon03W+outputId: d952f16b-9b30-4753-ed94-cc914a929a36+---+counter.reset()+fast_count = jax.jit(counter.count)++for _ in range(3):+  print(fast_count())+```+++++ {"id": "weiI0V7_pKGv"}++Oh no! Our counter isn't working. This is because the line+```+self.n += 1+```+in `count` was only called once, when JAX was compiling the method call. Moreover, since the return value doesn't depend on the arguments to `count`, once it returned the first 1, subsequent calls to `fast_count` will always return 1. This won't do. So, how do we fix it?++## The solution: explicit state++Part of the problem with our counter was that the returned value didn't depend on the arguments, meaning a constant was "baked into" the compiled output. But it shouldn't be a constant -- it should depend on the state. 
Well, then why don't we make the state into an argument?++```{code-cell} ipython3+---+id: 53pSdK4KoOEZ+outputId: 5ac72b9c-7029-4bf2-de8d-1d412bd74c79+---+from typing import Tuple++CounterState = int++class CounterV2:++  def count(self, n: CounterState) -> Tuple[int, CounterState]:+    # You could just return n+1, but here we separate its role as +    # the output and as the counter state for didactic purposes.+    return n+1, n+1++  def reset(self) -> CounterState:+    return 0++counter = CounterV2()+state = counter.reset()++for _ in range(3):+  value, state = counter.count(state)+  print(value)+```+++++ {"id": "PrBjmgZtq89b"}++In this new version of Counter, we moved `n` to be an argument of `count`, and added another return value that represents the new, updated, state. To use this counter, we now need to keep track of the state explicitly. But in return, we can now safely `jax.jit` this counter:++```{code-cell} ipython3+---+id: LO4Xzcq_q8PH+outputId: 25c06a56-f2bf-4c54-a3c3-6e093d484362+---+state = counter.reset()+fast_count = jax.jit(counter.count)++for _ in range(3):+  value, state = fast_count(state)+  print(value)+```+++++ {"id": "MzMSWD2_sgnh"}++## A general strategy++We can apply the same process to any stateful method to convert it into a stateless one. We took a class of the form++```python+class StatefulClass++  state: State++  def stateful_method(*args, **kwargs) -> Output:+```++and turned it into a class of the form++```python+class StatelessClass++  def stateless_method(state: State, *args, **kwargs) -> (Output, State):+```++This is a common [functional programming](https://en.wikipedia.org/wiki/Functional_programming) pattern, and, essentially, is the way that state is handled in all JAX programs.++Notice that the need for a class becomes less clear once we have rewritten it this way. We could just keep `stateless_method`, since the class is no longer doing any work. This is because, like the strategy we just applied, object-oriented programming (OOP) is a way to help programmers understand program state. ++In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?++Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey.+++++ {"id": "I2SqRx14_z98"}++## Simple worked example: Linear Regression++Let's apply this strategy to a simple machine learning model: linear regression via gradient descent.++Here, we only deal with one kind of state: the model parameters. 
But generally, you'll see many kinds of state being threaded in and out of JAX functions, like optimizer state, layer statistics for batchnorm, and others.++The function to look at carefully is `update`.++```{code-cell} ipython3+:id: wQdU7DoAseW6++from typing import NamedTuple++class Params(NamedTuple):+  weight: jnp.ndarray+  bias: jnp.ndarray+++def init(rng) -> Params:+  """Returns the initial model params."""+  weights_key, bias_key = jax.random.split(rng)+  weight = jax.random.normal(weights_key, ())+  bias = jax.random.normal(bias_key, ())+  return Params(weight, bias)+++def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:+  """Computes the least squares error of the model's predictions on x against y."""+  pred = params.weight * x + params.bias+  return jnp.mean((pred - y) ** 2)+++LEARNING_RATE = 0.005++@jax.jit+def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:+  """Performs one SGD update step on params using the given data."""+  grad = jax.grad(loss)(params, x, y)++  # If we were using Adam or another stateful optimizer,+  # we would also do something like+  # ```+  # updates, new_optimizer_state = optimizer(grad, optimizer_state)+  # ```+  # and then use `updates` instead of `grad` to actually update the params.+  # (And we'd include `new_optimizer_state` in the output, naturally.)++  new_params = jax.tree_multimap(+      lambda param, g: param - g * LEARNING_RATE, params, grad)++  return new_params+```+++++ {"id": "dKySWouu2-Hu"}++Notice that we manually pipe the params in and out of the update function.++```{code-cell} ipython3+---+id: jQCYYy0yxO6K+outputId: 1f3b69d2-e90b-4065-cbcc-6422978d25c2+---+import matplotlib.pyplot as plt++rng = jax.random.PRNGKey(42)++# Generate true data from y = w*x + b + noise+true_w, true_b = 2, -1+x_rng, noise_rng = jax.random.split(rng)+xs = jax.random.normal(x_rng, (128, 1))+noise = jax.random.normal(noise_rng, (128, 1)) * 0.5+ys = xs * true_w + true_b + noise++# Fit regression+params = init(rng)+for _ in range(1000):+  params = update(params, xs, ys)++plt.scatter(xs, ys)+plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')+plt.legend()+plt.show()+```+++++ {"id": "1wq3L6Xg1UHP"}++## Taking it further++The strategy described above is how any (jitted) JAX program must handle state. ++Handling parameters manually seems fine if you're dealing with two parameters, but what if it's a neural net with dozens of layers? You might already be getting worried about two things:++1) Are we supposed to initialize them all manually, essentialy repeating what we already write in the forward pass definition?

Suggestion: nit ("essentially")

1) Are we supposed to initialize them all manually, essentially repeating what we already write in the forward pass definition?
jakevdp

comment created time in 5 hours

Pull request review comment google/jax

Initial commit of seventh JAX-101 notebook

+---+jupytext:+  formats: ipynb,md:myst+  text_representation:+    extension: .md+    format_name: myst+    format_version: 0.13+    jupytext_version: 1.10.0+kernelspec:+  display_name: Python 3+  name: python3+---+++++ {"id": "Ga0xSM8xhBIm"}++# Stateful computations++[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/07-state.ipynb)++*Authors: Vladimir Mikulik*++This section explores how JAX constrains the implementation of stateful programs.+++++ {"id": "Avjnyrjojo8z"}++## Motivation++In machine learning, program state most often comes in the form of+* model parameters,+* optimizer state, and+* stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization).++Some JAX transformations, most notably `jax.jit`, impose constraints on the functions they transform. In particular, the function transformed by `jax.jit` must have no side-effects. This is because any such side-effects will only be executed once, when the python version of the function is run during compilation. These side-effects will not be executed by the compiled function on subsequent runs.++Changing program state is one kind of side-effect. So, if we can't have side effects, how do we update model parameters, the optimizer state, and use stateful layers in our models? This colab will explain this in detail, but the short answer is: with [functional programming](https://en.wikipedia.org/wiki/Functional_programming).+++```{code-cell} ipython3+:id: S-n8XRC2kTnI++import jax+import jax.numpy as jnp+```+++++ {"id": "s_-6semKkSzp"}++++## A simple example: Counter++Let's start by looking at a simple stateful program: a counter.++```{code-cell} ipython3+---+id: B3aoCHpjg8gm+outputId: 5cbcfbf5-5c42-498f-a175-050438518337+---+class Counter:+  """A simple counter."""++  def __init__(self):+    self.n = 0++  def count(self) -> int:+    """Increments the counter and returns the new value."""+    self.n += 1+    return self.n++  def reset(self):+    """Resets the counter to zero."""+    self.n = 0+++counter = Counter()++for _ in range(3):+  print(counter.count())+```+++++ {"id": "SQ-RNLfdiw04"}++The `n` attribute maintains the counter's _state_ between successive calls of `count`. It is modified as a side effect of calling `count`.++Let's say we want to count fast, so we `jax.jit` the `count` method. (In this example, this wouldn't actually help speed anyway, for many reasons, but treat this as a toy model of wanting to jit-compile the update of model parameters, where `jax.jit` makes an enormous difference).++```{code-cell} ipython3+---+id: 5jSjmJMon03W+outputId: d952f16b-9b30-4753-ed94-cc914a929a36+---+counter.reset()+fast_count = jax.jit(counter.count)++for _ in range(3):+  print(fast_count())+```+++++ {"id": "weiI0V7_pKGv"}++Oh no! Our counter isn't working. This is because the line+```+self.n += 1+```+in `count` was only called once, when JAX was compiling the method call. Moreover, since the return value doesn't depend on the arguments to `count`, once it returned the first 1, subsequent calls to `fast_count` will always return 1. This won't do. So, how do we fix it?

Suggestion: nit (present tense, maybe better formatting)

in `count` was only called once, when JAX was compiling the method call. Moreover, since the return value doesn't depend on the arguments to `count`, once it returns the first `1`, subsequent calls to `fast_count` will always return `1`. This won't do. So, how do we fix it?
jakevdp

comment created time in 5 hours

started ChrisWaites/jax-flows

started time in 5 hours

issue comment google/jax

Provide Windows binaries on PyPI

Will this ticket cover conda installs on windows as well?

cool-RR

comment created time in 5 hours

issue opened google/jax

Failure to build jaxlib v0.1.62 on Windows

I have just tried to build the newly released jaxlib-v0.1.62 using the following command on a Windows 10 machine with VS2019.

v0.1.61 builds fine with the same command using Bazel v3.7. I have also tried with Bazel v4.0 and receive the same error.

python .\build\build.py --enable_cuda --cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0" --cudnn_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0" --cuda_compute_capabilities="7.5" --cuda_version="11.0" --cudnn_version="8.0.5"

and received the output as listed below.

    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\


Bazel binary path: C:\bazel\bazel.EXE
Python binary path: C:/Users/Adam/anaconda3/python.exe
Python version: 3.8
MKL-DNN enabled: yes
Target CPU features: release
CUDA enabled: yes
CUDA toolkit path: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0
CUDNN library path: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0
CUDA compute capabilities: 7.5
CUDA version: 11.0
CUDNN version: 8.0.5
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
C:\bazel\bazel.EXE run --verbose_failures=true --config=short_logs --config=mkl_open_source_only --config=cuda --define=xla_python_enable_gpu=true :build_wheel -- --output_path=C:\sdks\jax-jaxlib-v0.1.62\dist
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=1 --terminal_columns=80
INFO: Reading rc options for 'run' from c:\sdks\jax-jaxlib-v0.1.62\.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Options provided by the client:
  Inherited 'build' options: --python_path=C:/Users/Adam/anaconda3/python.exe
INFO: Reading rc options for 'run' from c:\sdks\jax-jaxlib-v0.1.62\.bazelrc:
  Inherited 'build' options: --repo_env PYTHON_BIN_PATH=C:/Users/Adam/anaconda3/python.exe --action_env=PYENV_ROOT --python_path=C:/Users/Adam/anaconda3/python.exe --repo_env TF_NEED_CUDA=1 --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 --repo_env TF_NEED_ROCM=0 --action_env TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 --distinct_host_configuration=false -c opt --apple_platform_type=macos --macos_minimum_os=10.9 --announce_rc --define open_source_build=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --spawn_strategy=standalone --strategy=Genrule=standalone --enable_platform_specific_config --define=with_tpu_support=true --action_env CUDA_TOOLKIT_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0 --action_env CUDNN_INSTALL_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0 --action_env TF_CUDA_VERSION=11.0 --action_env TF_CUDNN_VERSION=8.0.5
INFO: Found applicable config definition build:short_logs in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:mkl_open_source_only in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:cuda in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --crosstool_top=@local_config_cuda//crosstool:toolchain --@local_config_cuda//:enable_cuda
INFO: Found applicable config definition build:windows in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --copt=/D_USE_MATH_DEFINES --host_copt=/D_USE_MATH_DEFINES --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN --copt=-DNOGDI --host_copt=-DNOGDI --copt=/Zc:preprocessor --cxxopt=/std:c++14 --host_cxxopt=/std:c++14 --linkopt=/DEBUG --host_linkopt=/DEBUG --linkopt=/OPT:REF --host_linkopt=/OPT:REF --linkopt=/OPT:ICF --host_linkopt=/OPT:ICF --experimental_strict_action_env=true
ERROR: C:/users/adam/_bazel_adam/nzquhzn2/external/org_tensorflow/tensorflow/core/common_runtime/BUILD:1647:16: Illegal ambiguous match on configurable attribute "deps" in @org_tensorflow//tensorflow/core/common_runtime:core_cpu_internal:
@org_tensorflow//tensorflow:windows
@org_tensorflow//tensorflow:with_tpu_support
Multiple matches are not allowed unless one is unambiguously more specialized.
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: C:/users/adam/_bazel_adam/nzquhzn2/external/org_tensorflow/tensorflow/core/common_runtime/BUILD:1647:16: Illegal ambiguous match on configurable attribute "deps" in @org_tensorflow//tensorflow/core/common_runtime:core_cpu_internal:
@org_tensorflow//tensorflow:windows
@org_tensorflow//tensorflow:with_tpu_support
Multiple matches are not allowed unless one is unambiguously more specialized.
INFO: Elapsed time: 1.698s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (1 packages loaded, 184 targets configured)
Traceback (most recent call last):
  File ".\build\build.py", line 516, in <module>
    main()
  File ".\build\build.py", line 511, in main
    shell(command)
  File ".\build\build.py", line 51, in shell
    output = subprocess.check_output(cmd)
  File "C:\Users\Adam\anaconda3\lib\subprocess.py", line 411, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "C:\Users\Adam\anaconda3\lib\subprocess.py", line 512, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['C:\\bazel\\bazel.EXE', 'run', '--verbose_failures=true', '--config=short_logs', '--config=mkl_open_source_only', '--config=cuda', '--define=xla_python_enable_gpu=true', ':build_wheel', '--', '--output_path=C:\\sdks\\jax-jaxlib-v0.1.62\\dist']' returned non-zero exit status 1.

created time in 5 hours

PR opened google/jax

Initial commit of seventh JAX-101 notebook
+724 -0

0 comment

3 changed files

pr created time in 5 hours

PR merged google/jax

Initial commit of sixth JAX-101 notebook cla: yes pull ready

Replaces #5957. I removed the problematic URL.

+1350 -0

0 comment

3 changed files

jakevdp

pr closed time in 5 hours

push event google/jax

Jake VanderPlas

commit sha 59790b3b260149aa7f50626ccaa6fbb2f0811a70

Initial commit of sixth JAX-101 notebook

view details

jax authors

commit sha a68e8b3c76bff6d65cc142a3155357d43e7c568a

Merge pull request #5963 from jakevdp:jax-101-part6 PiperOrigin-RevId: 361626184

view details

push time in 5 hours