Enable Tensorflow for ROCm. Add ROCm dependencies. (#32248)

* Build Tensorflow using the fork for rocm. Initial commit

* re-order the versions

* fix style errors

* address review comments

* add conflicts for rocm version

* address review comments

* remove rocm variant as its added by ROCmPackage
This commit is contained in:
Sreenivasa Murthy Kolam 2022-08-19 18:50:38 -07:00 committed by GitHub
parent 5590cad1ef
commit 11a4f5e25d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -10,7 +10,7 @@
from spack.package import * from spack.package import *
class PyTensorflow(Package, CudaPackage): class PyTensorflow(Package, CudaPackage, ROCmPackage):
"""An Open Source Machine Learning Framework for Everyone. """An Open Source Machine Learning Framework for Everyone.
TensorFlow is an end-to-end open source platform for machine learning. It has a TensorFlow is an end-to-end open source platform for machine learning. It has a
@ -35,6 +35,11 @@ class PyTensorflow(Package, CudaPackage):
version("2.8.2", sha256="b3f860c02c22a30e9787e2548ca252ab289a76b7778af6e9fa763d4aafd904c7") version("2.8.2", sha256="b3f860c02c22a30e9787e2548ca252ab289a76b7778af6e9fa763d4aafd904c7")
version("2.8.1", sha256="4b487a63d6f0c1ca46a2ac37ba4687eabdc3a260c222616fa414f6df73228cec") version("2.8.1", sha256="4b487a63d6f0c1ca46a2ac37ba4687eabdc3a260c222616fa414f6df73228cec")
version("2.8.0", sha256="66b953ae7fba61fd78969a2e24e350b26ec116cf2e6a7eb93d02c63939c6f9f7") version("2.8.0", sha256="66b953ae7fba61fd78969a2e24e350b26ec116cf2e6a7eb93d02c63939c6f9f7")
version(
"2.7.4-rocm-enhanced",
sha256="45b79c125edfdc008274f1b150d8b5a53b3ff4713fd1ad1ff4738f515aad8191",
url="https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/archive/refs/tags/v2.7.4-rocm-enhanced.tar.gz",
)
version("2.7.3", sha256="b576c2e124cd6d4d04cbfe985430a0d955614e882172b2258217f0ec9b61f39b") version("2.7.3", sha256="b576c2e124cd6d4d04cbfe985430a0d955614e882172b2258217f0ec9b61f39b")
version("2.7.2", sha256="b3c8577f3b7cc82368ff7f9315821d506abd2f716ea6692977d255b7d8bc54c0") version("2.7.2", sha256="b3c8577f3b7cc82368ff7f9315821d506abd2f716ea6692977d255b7d8bc54c0")
version("2.7.1", sha256="abebe2cf5ca379e18071693ca5f45b88ade941b16258a21cc1f12d77d5387a21") version("2.7.1", sha256="abebe2cf5ca379e18071693ca5f45b88ade941b16258a21cc1f12d77d5387a21")
@ -128,7 +133,6 @@ class PyTensorflow(Package, CudaPackage):
variant("ngraph", default=False, description="Build with Intel nGraph support") variant("ngraph", default=False, description="Build with Intel nGraph support")
variant("opencl", default=False, description="Build with OpenCL SYCL support") variant("opencl", default=False, description="Build with OpenCL SYCL support")
variant("computecpp", default=False, description="Build with ComputeCPP support") variant("computecpp", default=False, description="Build with ComputeCPP support")
variant("rocm", default=False, description="Build with ROCm support")
variant("tensorrt", default=False, description="Build with TensorRT support") variant("tensorrt", default=False, description="Build with TensorRT support")
variant("cuda", default=sys.platform != "darwin", description="Build with CUDA support") variant("cuda", default=sys.platform != "darwin", description="Build with CUDA support")
variant( variant(
@ -279,6 +283,21 @@ class PyTensorflow(Package, CudaPackage):
# type=('build', 'run'), when='@2.8:') # type=('build', 'run'), when='@2.8:')
# depends_on('py-tensorflow-io-gcs-filesystem@0.21:', # depends_on('py-tensorflow-io-gcs-filesystem@0.21:',
# type=('build', 'run'), when='@2.7') # type=('build', 'run'), when='@2.7')
with when("+rocm"):
depends_on("hip")
depends_on("rocrand")
depends_on("rocblas")
depends_on("rocfft")
depends_on("hipfft")
depends_on("rccl")
depends_on("hipsparse")
depends_on("hipcub")
depends_on("rocsolver")
depends_on("rocprim")
depends_on("miopen-hip")
depends_on("llvm-amdgpu")
depends_on("hsa-rocr-dev")
depends_on("rocminfo")
if sys.byteorder == "little": if sys.byteorder == "little":
# Only builds correctly on little-endian machines # Only builds correctly on little-endian machines
@ -357,7 +376,6 @@ class PyTensorflow(Package, CudaPackage):
conflicts("+opencl", when="@:0.11") conflicts("+opencl", when="@:0.11")
conflicts("+computecpp", when="@:0.11") conflicts("+computecpp", when="@:0.11")
conflicts("+computecpp", when="~opencl") conflicts("+computecpp", when="~opencl")
conflicts("+rocm", when="@:1.11")
conflicts("+cuda", when="platform=darwin", msg="There is no GPU support for macOS") conflicts("+cuda", when="platform=darwin", msg="There is no GPU support for macOS")
conflicts( conflicts(
"cuda_arch=none", "cuda_arch=none",
@ -416,6 +434,8 @@ class PyTensorflow(Package, CudaPackage):
conflicts("platform=darwin target=aarch64:", when="@:2.4") conflicts("platform=darwin target=aarch64:", when="@:2.4")
# https://github.com/tensorflow/tensorflow/pull/39225 # https://github.com/tensorflow/tensorflow/pull/39225
conflicts("target=aarch64:", when="@:2.2") conflicts("target=aarch64:", when="@:2.2")
conflicts("~rocm", when="@2.7.4-rocm-enhanced")
conflicts("+rocm", when="@:2.7.4-a,2.7.4.0:")
# TODO: why is this needed? # TODO: why is this needed?
patch("url-zlib.patch", when="@0.10.0") patch("url-zlib.patch", when="@0.10.0")
@ -720,6 +740,11 @@ def setup_build_environment(self, env):
env.set("INCLUDEDIR", spec["protobuf"].prefix.include) env.set("INCLUDEDIR", spec["protobuf"].prefix.include)
def patch(self): def patch(self):
filter_file(
'"-U_FORTIFY_SOURCE",',
'"-U_FORTIFY_SOURCE", "-I%s",' % self.spec["protobuf"].prefix.include,
"third_party/gpus/crosstool/BUILD.rocm.tpl",
)
if self.spec.satisfies("@2.3.0:"): if self.spec.satisfies("@2.3.0:"):
filter_file( filter_file(
"deps = protodeps + well_known_proto_libs(),", "deps = protodeps + well_known_proto_libs(),",
@ -976,6 +1001,9 @@ def build(self, spec, prefix):
if "+cuda" in spec: if "+cuda" in spec:
args.append("--config=cuda") args.append("--config=cuda")
if "+rocm" in spec:
args.append("--config=rocm")
if "~aws" in spec: if "~aws" in spec:
args.append("--config=noaws") args.append("--config=noaws")