Installing from source using Conda and CUDA could be improved

Thanks to all contributors for their efforts in creating and open-sourcing the library. For whatever it's worth, I would like to add my two cents on the installation process when building from source.

I always like to install things in conda envs so that there are no clashes between different software versions or required libraries.


conda create -n jax python scipy cudnn cudatoolkit
conda list


Now the installation process:

python build/ --enable_cuda --cuda_path  ~/miniconda3/envs/jax/lib/ --cudnn_path ~/miniconda3/envs/jax/include

Two problems arise:

1. nvcc cannot be found in ~/miniconda3/envs/jax/lib/bin.
That path is actually wrong; it should have been ~/miniconda3/envs/jax/bin.
Anyway, I copied nvcc from the system-wide installation (/opt/cuda/bin/nvcc) into ~/miniconda3/envs/jax/lib/bin.
So far so good.

2. Re-running the build, it complains about cuda.h:
Cuda Configuration Error: Cannot find cuda.h under ~/miniconda3/envs/jax/lib 
FAILED: Build did NOT complete successfully (4 packages loaded, 16 targets

OK, let's copy /opt/cuda/include/cuda.h into ~/miniconda3/envs/jax/lib.
Re-running the build after completely removing the Bazel cache (rm -rf ~/.cache/bazel)
again gives the same error about not being able to find cuda.h.
At this point I am out of ideas.
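
Both problems come down to which files sit under the directory passed as the CUDA path. Below is a minimal sketch for checking a candidate directory before starting a build. It assumes the conventional CUDA toolkit layout, with nvcc under bin/ and cuda.h under include/ of the same root. The helper name missing_cuda_files is hypothetical and not part of JAX or its build scripts, and the build's own search logic may look in more places than this.

```python
from pathlib import Path

# Files a conventional CUDA toolkit root is expected to contain.
# This layout is an assumption, not taken from the JAX build scripts.
EXPECTED_FILES = ("bin/nvcc", "include/cuda.h")


def missing_cuda_files(cuda_root):
    """Return the expected files that are absent under cuda_root."""
    root = Path(cuda_root).expanduser()
    return [rel for rel in EXPECTED_FILES if not (root / rel).is_file()]


if __name__ == "__main__":
    # Candidate roots mentioned in this thread.
    for candidate in ("~/miniconda3/envs/jax", "/opt/cuda"):
        missing = missing_cuda_files(candidate)
        status = "looks complete" if not missing else "missing " + ", ".join(missing)
        print(f"{candidate}: {status}")
```

If a directory is reported as missing files, copying single files into it (as tried above) may not be enough, since the toolkit headers and binaries depend on other files from the same installation.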

Does anyone have ideas on how to resolve this?


Answer questions ksquarekumar

This particular error comes from XLA (which is included as part of TF) so this is indeed a common issue across TF and JAX, most likely.

I think you should be able to work around it by setting the environment variable:

export XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/my/cuda/installation
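
The same variable can also be set from inside a Python script, as long as that happens before jax is imported. The following is a minimal sketch rather than an official recipe. The helper add_cuda_data_dir is a hypothetical name, the path is the placeholder from the answer, and splitting on whitespace assumes that no flag value contains spaces.

```python
import os

CUDA_DIR_FLAG = "--xla_gpu_cuda_data_dir"


def add_cuda_data_dir(existing_flags, cuda_dir):
    """Return an XLA_FLAGS string that points XLA at cuda_dir.

    Any previous --xla_gpu_cuda_data_dir entry is dropped; other flags are kept.
    Assumes flags are separated by whitespace and contain no spaces themselves.
    """
    kept = [f for f in existing_flags.split() if not f.startswith(CUDA_DIR_FLAG + "=")]
    return " ".join(kept + [f"{CUDA_DIR_FLAG}={cuda_dir}"])


# Placeholder path from the answer; replace it with a real CUDA root.
os.environ["XLA_FLAGS"] = add_cuda_data_dir(
    os.environ.get("XLA_FLAGS", ""), "/path/to/my/cuda/installation"
)
# Import jax only after this point so the flag is already set when XLA starts.
```

Merging into the existing value instead of overwriting it keeps any other XLA flags that were already exported in the shell.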

Out of curiosity, where is your CUDA installation? One thing we could do is add more paths to the list of search paths XLA uses to find ptxas. Are there any CUDA... environment variables set that point to it?

Thank you, this lets me install JAX with GPU support on my conda environment!

Would providing prebuilt Conda packages solve your problem adequately?

Yes please, especially if there are pip-related regressions like here.

