Clear GPU memory

Dear jax team,

I'd like to use jax alongside other tools running on the GPU in the same pipeline. Is there a way to "encapsulate" the use of jax/XLA so that the GPU memory is freed afterwards? Even if I had to copy the DeviceArrays over into numpy manually.

Maybe something like:

with jax.Block():
    result = some_jitted_fun(a, b, c)
    result = onp.copy(result)

I can imagine that designing the handling of objects and their GPU memory is not straightforward, if not practically impossible. Could I at least tell jax to allocate GPU memory incrementally instead of grabbing most of it on import?
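
For reference, I know the pre-allocation behavior can be tuned with JAX's documented allocator environment variables, set before the first jax import, e.g.:

import os

# Allocate GPU memory on demand instead of pre-allocating most of it up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternatively, use cudaMalloc/cudaFree directly; slower, but freed buffers
# are actually returned to the driver:
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax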

Answer from froystig:

> But I assume this only affects pre-allocation, not freeing the memory afterwards?

Device memory for an array ought to be freed once all Python references to it drop, i.e. upon destruction of any corresponding DeviceArray. You could encourage this explicitly with del my_device_array, if Python scope isn't already lined up with your pipeline "blocks."
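
A minimal sketch of that (f here is just a stand-in jitted function; the point is that the buffer behind x becomes reclaimable once nothing references it):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x * 2.0

x = jnp.ones((4096, 4096))  # backed by device memory on a GPU
y = f(x)                    # a second device buffer
del x                       # last Python reference gone: the device
                            # buffer behind x can now be reclaimed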

In your example, the line

result = onp.copy(result)

will drop the only reference to the DeviceArray produced on the previous line, and so should free the device memory holding the value of some_jitted_fun(a, b, c), for the same reason.
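
Putting this together, a sketch of your "block" without any new context manager, assuming a stand-in body for some_jitted_fun; only a host-side numpy array escapes the function:

import numpy as onp
import jax
import jax.numpy as jnp

@jax.jit
def some_jitted_fun(a, b, c):  # stand-in body for illustration
    return a @ b + c

def run_block(a, b, c):
    result = some_jitted_fun(a, b, c)  # result is a DeviceArray
    return onp.copy(result)            # host copy; the DeviceArray becomes
                                       # unreferenced when the function returns

out = run_block(jnp.ones((512, 512)), jnp.ones((512, 512)), 1.0)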
