diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py index d8def42008..b4c7daba27 100644 --- a/var/spack/repos/builtin/packages/py-jaxlib/package.py +++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py @@ -47,6 +47,13 @@ class PyJaxlib(PythonPackage, CudaPackage): depends_on("py-absl-py", when="@:0.3", type=("build", "run")) depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run")) + conflicts( + "cuda_arch=none", + when="+cuda", + msg="Must specify CUDA compute capabilities of your GPU, see " + "https://developer.nvidia.com/cuda-gpus", + ) + def patch(self): self.tmp_path = tempfile.mkdtemp(prefix="spack") self.buildtmp = tempfile.mkdtemp(prefix="spack")