py-torch: set TORCH_CUDA_ARCH_LIST globally for dependents (#43962)

This commit is contained in:
Adam J. Stewart 2024-05-03 09:11:59 +02:00 committed by GitHub
parent 89bf1edb6e
commit 14561fafff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 12 additions and 45 deletions

View file

@ -35,11 +35,6 @@ class PyTorchCluster(PythonPackage):
def setup_build_environment(self, env):
if "+cuda" in self.spec:
cuda_arches = list(self.spec["py-torch"].variants["cuda_arch"].value)
for i, x in enumerate(cuda_arches):
cuda_arches[i] = "{0}.{1}".format(x[0:-1], x[-1])
env.set("TORCH_CUDA_ARCH_LIST", str.join(" ", cuda_arches))
env.set("FORCE_CUDA", "1")
env.set("CUDA_HOME", self.spec["cuda"].prefix)
else:

View file

@ -72,11 +72,6 @@ class PyTorchGeometric(PythonPackage):
def setup_build_environment(self, env):
if "+cuda" in self.spec:
cuda_arches = list(self.spec["py-torch"].variants["cuda_arch"].value)
for i, x in enumerate(cuda_arches):
cuda_arches[i] = "{0}.{1}".format(x[0:-1], x[-1])
env.set("TORCH_CUDA_ARCH_LIST", str.join(" ", cuda_arches))
env.set("FORCE_CUDA", "1")
env.set("CUDA_HOME", self.spec["cuda"].prefix)
else:

View file

@ -35,12 +35,6 @@ class PyTorchNvidiaApex(PythonPackage, CudaPackage):
def setup_build_environment(self, env):
if "+cuda" in self.spec:
env.set("CUDA_HOME", self.spec["cuda"].prefix)
if self.spec.variants["cuda_arch"].value[0] != "none":
torch_cuda_arch = ";".join(
"{0:.1f}".format(float(i) / 10.0)
for i in self.spec.variants["cuda_arch"].value
)
env.set("TORCH_CUDA_ARCH_LIST", torch_cuda_arch)
else:
env.unset("CUDA_HOME")

View file

@ -29,11 +29,6 @@ class PyTorchScatter(PythonPackage):
def setup_build_environment(self, env):
if "+cuda" in self.spec:
cuda_arches = list(self.spec["py-torch"].variants["cuda_arch"].value)
for i, x in enumerate(cuda_arches):
cuda_arches[i] = "{0}.{1}".format(x[0:-1], x[-1])
env.set("TORCH_CUDA_ARCH_LIST", str.join(" ", cuda_arches))
env.set("FORCE_CUDA", "1")
env.set("CUDA_HOME", self.spec["cuda"].prefix)
else:

View file

@ -31,11 +31,6 @@ class PyTorchSparse(PythonPackage):
def setup_build_environment(self, env):
if "+cuda" in self.spec:
cuda_arches = list(self.spec["py-torch"].variants["cuda_arch"].value)
for i, x in enumerate(cuda_arches):
cuda_arches[i] = "{0}.{1}".format(x[0:-1], x[-1])
env.set("TORCH_CUDA_ARCH_LIST", str.join(" ", cuda_arches))
env.set("FORCE_CUDA", "1")
env.set("CUDA_HOME", self.spec["cuda"].prefix)
else:

View file

@ -27,11 +27,6 @@ class PyTorchSplineConv(PythonPackage):
def setup_build_environment(self, env):
if "+cuda" in self.spec:
cuda_arches = list(self.spec["py-torch"].variants["cuda_arch"].value)
for i, x in enumerate(cuda_arches):
cuda_arches[i] = "{0}.{1}".format(x[0:-1], x[-1])
env.set("TORCH_CUDA_ARCH_LIST", str.join(" ", cuda_arches))
env.set("FORCE_CUDA", "1")
env.set("CUDA_HOME", self.spec["cuda"].prefix)
else:

View file

@ -473,6 +473,13 @@ def patch(self):
"caffe2/CMakeLists.txt",
)
def torch_cuda_arch_list(self, env):
if "+cuda" in self.spec:
torch_cuda_arch = ";".join(
"{0:.1f}".format(float(i) / 10.0) for i in self.spec.variants["cuda_arch"].value
)
env.set("TORCH_CUDA_ARCH_LIST", torch_cuda_arch)
def setup_build_environment(self, env):
"""Set environment variables used to control the build.
@ -515,10 +522,8 @@ def enable_or_disable(variant, keyword="USE", var=None):
if "+cuda" in self.spec:
env.set("CUDA_HOME", self.spec["cuda"].prefix) # Linux/macOS
env.set("CUDA_PATH", self.spec["cuda"].prefix) # Windows
torch_cuda_arch = ";".join(
"{0:.1f}".format(float(i) / 10.0) for i in self.spec.variants["cuda_arch"].value
)
env.set("TORCH_CUDA_ARCH_LIST", torch_cuda_arch)
self.torch_cuda_arch_list(env)
if self.spec.satisfies("%clang"):
for flag in self.spec.compiler_flags["cxxflags"]:
if "gcc-toolchain" in flag:
@ -667,6 +672,9 @@ def enable_or_disable(variant, keyword="USE", var=None):
if self.spec.satisfies("%apple-clang@15:"):
env.append_flags("LDFLAGS", "-Wl,-ld_classic")
def setup_run_environment(self, env):
self.torch_cuda_arch_list(env)
@run_before("install")
def build_amd(self):
if "+rocm" in self.spec:

View file

@ -102,11 +102,6 @@ def setup_build_environment(self, env):
if "+cuda" in self.spec["py-torch"]:
env.set("USE_CUDA", 1)
torch_cuda_arch_list = ";".join(
"{0:.1f}".format(float(i) / 10.0)
for i in self.spec["py-torch"].variants["cuda_arch"].value
)
env.set("TORCH_CUDA_ARCH_LIST", torch_cuda_arch_list)
else:
env.set("USE_CUDA", 0)

View file

@ -150,11 +150,6 @@ def setup_build_environment(self, env):
if "^cuda" in self.spec:
env.set("CUDA_HOME", self.spec["cuda"].prefix)
torch_cuda_arch_list = ";".join(
"{0:.1f}".format(float(i) / 10.0)
for i in self.spec["py-torch"].variants["cuda_arch"].value
)
env.set("TORCH_CUDA_ARCH_LIST", torch_cuda_arch_list)
for gpu in ["cuda", "mps"]:
env.set(f"FORCE_{gpu.upper()}", int(f"+{gpu}" in self.spec["py-torch"]))