diff --git a/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py index 45f9b279cb..00a357603f 100644 --- a/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py +++ b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py @@ -21,6 +21,7 @@ class PyTorchNvidiaApex(PythonPackage, CudaPackage): depends_on("python@3:", type=("build", "run")) depends_on("py-setuptools", type="build") + depends_on("py-packaging", type="build") depends_on("py-torch@0.4:", type=("build", "run")) depends_on("cuda@9:", when="+cuda") depends_on("py-pybind11", type=("build", "link", "run")) @@ -43,6 +44,7 @@ def setup_build_environment(self, env): else: env.unset("CUDA_HOME") + @when("^python@:3.10") def global_options(self, spec, prefix): args = [] if spec.satisfies("^py-torch@1.0:"): @@ -50,3 +52,11 @@ def global_options(self, spec, prefix): if "+cuda" in spec: args.append("--cuda_ext") return args + + @when("^python@3.11:") + def config_settings(self, spec, prefix): + return { + "builddir": "build", + "compile-args": f"-j{make_jobs}", + "--global-option": "--cpp_ext --cuda_ext", + }