Clear GPU memory
Dear jax team,
I'd like to use jax alongside other tools that run on the GPU in the same pipeline. Is there a way to "encapsulate" the use of jax/XLA so that the GPU memory is freed afterwards, even if I have to copy the DeviceArrays into numpy manually?
Maybe something like:

    with jax.Block():  # proposed context manager; not an existing jax API
        result = some_jitted_fun(a, b, c)
        result = onp.copy(result)
I imagine that designing how objects and their GPU memory are handled is not straightforward, if not practically impossible. Could I at least tell jax to allocate GPU memory incrementally instead of filling it completely on import?
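On the incremental-allocation part of the question: JAX's GPU memory documentation describes environment variables that control pre-allocation. The sketch below assumes a JAX version that honors them; they must be set before JAX initializes its GPU backend, so the safest place is before importing jax. XLA_PYTHON_CLIENT_PREALLOCATE=false makes JAX grow its memory pool on demand, XLA_PYTHON_CLIENT_MEM_FRACTION caps the pre-allocated share, and XLA_PYTHON_CLIENT_ALLOCATOR=platform allocates exactly what is needed and releases freed memory back to the GPU instead of keeping it for reuse (at a significant speed cost). The array and the sum are only there to show that computation works as usual.

```python
import os

# Must be set before JAX initializes its GPU backend,
# so set it before importing jax.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # grow on demand, no up-front grab

# Alternatives (pick one instead of the line above):
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"    # pre-allocate only 50% of GPU memory
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate/free exactly; slow

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))  # float32 array on the default device
total = float(x.sum())
print(jax.devices(), total)
```

On a machine without a GPU these variables are simply ignored and the code runs on the CPU.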
Answer from froystig:
> But I assume this only affects pre-allocation, not freeing the memory afterwards?
Device memory for an array ought to be freed once all Python references to it are dropped, i.e. when the corresponding DeviceArray is destroyed. You can encourage this explicitly with del my_device_array if Python scoping doesn't already line up with your pipeline "blocks."
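A minimal sketch of the explicit del pattern, assuming a recent JAX where DeviceArray has been replaced by jax.Array (the reference-counting behavior described above is the same). The names x and x_host are illustrative; onp.array makes an independent host copy, so deleting x afterwards leaves the NumPy data intact. Newer jax.Array objects also expose a delete() method that frees the device buffer immediately, even if other references exist; check the API reference for your version before relying on it.

```python
import numpy as onp
import jax.numpy as jnp

x = jnp.ones((1024, 1024))  # float32 array on the default device (GPU if available)
x_host = onp.array(x)       # onp.array copies the data into host (NumPy) memory
del x                       # drop the last Python reference so the device buffer can be freed

print(type(x_host), x_host.shape)
```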
In your example, the line

    result = onp.copy(result)

drops the only reference to the DeviceArray created on the previous line, so for the same reason it should free the device memory holding the value of some_jitted_fun(a, b, c).
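Because jax offers no Block() context manager, an ordinary function scope can play that role. The sketch below is an illustration, not an official pattern: the body of some_jitted_fun and the wrapper name run_block are hypothetical. Inputs are passed as NumPy arrays, so the only device arrays involved are temporaries inside the function; rebinding result to the onp.copy output drops the device array, and only host data leaves the function.

```python
import numpy as onp
import jax

@jax.jit
def some_jitted_fun(a, b, c):
    # Hypothetical stand-in for the jitted function in the question.
    return a * b + c

def run_block(a, b, c):
    """Plain function scope standing in for the proposed jax.Block()."""
    result = some_jitted_fun(a, b, c)  # device array
    result = onp.copy(result)          # host copy; rebinding drops the device array
    return result                      # only NumPy data leaves the function

a = onp.arange(3, dtype=onp.float32)   # host input; transferred to the device only inside the call
out = run_block(a, 2.0, 1.0)
print(out)  # [1. 3. 5.]
```

The design choice here is simply to let Python's own scoping do the cleanup: when run_block returns, nothing inside it is referenced any more, so no explicit del is needed.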