Add new dependencies for rocm variant for py-torch recipe (#32100)

* Cmake module path updated for ROCm 5.2

* nccl is already set below for PyTorch 1.6+

* Threadpool is set below for PyTorch 1.6+
This commit is contained in:
renjithravindrankannath 2022-08-12 23:17:20 -07:00 committed by GitHub
parent 4ec31003aa
commit b32cb5765c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -195,6 +195,7 @@ class PyTorch(PythonPackage, CudaPackage):
depends_on("rocfft")
depends_on("rocblas")
depends_on("miopen-hip")
depends_on("rocminfo")
# https://github.com/pytorch/pytorch/issues/60332
# depends_on('xnnpack@2022-02-16', when='@1.12:+xnnpack')
# depends_on('xnnpack@2021-06-21', when='@1.10:1.11+xnnpack')
@ -427,7 +428,6 @@ def enable_or_disable(variant, keyword="USE", var=None, newer=False):
env.set("ROCFFT_PATH", self.spec["rocfft"].prefix)
env.set("HIPFFT_PATH", self.spec["hipfft"].prefix)
env.set("HIPSPARSE_PATH", self.spec["hipsparse"].prefix)
env.set("THRUST_PATH", self.spec["rocthrust"].prefix.include)
env.set("HIP_PATH", self.spec["hip"].prefix)
env.set("HIPRAND_PATH", self.spec["rocrand"].prefix)
env.set("ROCRAND_PATH", self.spec["rocrand"].prefix)
@ -437,6 +437,8 @@ def enable_or_disable(variant, keyword="USE", var=None, newer=False):
env.set("HIPCUB_PATH", self.spec["hipcub"].prefix)
env.set("ROCTHRUST_PATH", self.spec["rocthrust"].prefix)
env.set("ROCTRACER_PATH", self.spec["roctracer-dev"].prefix)
if self.spec.satisfies("^hip@5.2.0:"):
env.set("CMAKE_MODULE_PATH", self.spec["hip"].prefix.lib.cmake.hip)
enable_or_disable("cudnn")
if "+cudnn" in self.spec: