amrex: add ROCm support (#20809)

This commit is contained in:
mic84 2021-01-15 01:14:55 -08:00 committed by GitHub
parent 3b9144a4a4
commit 7762b8acdd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,7 +6,7 @@
from spack import * from spack import *
class Amrex(CMakePackage, CudaPackage): class Amrex(CMakePackage, CudaPackage, ROCmPackage):
"""AMReX is a publicly available software framework designed """AMReX is a publicly available software framework designed
for building massively parallel block- structured adaptive for building massively parallel block- structured adaptive
mesh refinement (AMR) applications.""" mesh refinement (AMR) applications."""
@ -81,6 +81,7 @@ class Amrex(CMakePackage, CudaPackage):
depends_on('cmake@3.14:', type='build', when='@19.04:') depends_on('cmake@3.14:', type='build', when='@19.04:')
# cmake @3.17: is necessary to handle cuda @11: correctly # cmake @3.17: is necessary to handle cuda @11: correctly
depends_on('cmake@3.17:', type='build', when='^cuda @11:') depends_on('cmake@3.17:', type='build', when='^cuda @11:')
depends_on('rocrand', type='build', when='+rocm')
conflicts('%apple-clang') conflicts('%apple-clang')
conflicts('%clang') conflicts('%clang')
@ -113,6 +114,8 @@ class Amrex(CMakePackage, CudaPackage):
conflicts('cuda_arch=21', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5') conflicts('cuda_arch=21', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5')
conflicts('cuda_arch=30', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5') conflicts('cuda_arch=30', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5')
conflicts('cuda_arch=32', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5') conflicts('cuda_arch=32', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5')
conflicts('+rocm', when='@:20.11', msg='AMReX HIP support needs AMReX newer than version 20.11')
conflicts('+cuda', when='+rocm', msg='CUDA and HIP support are exclusive')
def url_for_version(self, version): def url_for_version(self, version):
if version >= Version('20.05'): if version >= Version('20.05'):
@ -200,4 +203,10 @@ def cmake_args(self):
cuda_arch = self.spec.variants['cuda_arch'].value cuda_arch = self.spec.variants['cuda_arch'].value
args.append('-DCUDA_ARCH=' + self.get_cuda_arch_string(cuda_arch)) args.append('-DCUDA_ARCH=' + self.get_cuda_arch_string(cuda_arch))
if '+rocm' in self.spec:
args.append('-DCMAKE_CXX_COMPILER={0}'.format(self.spec['hip'].hipcc))
args.append('-DAMReX_GPU_BACKEND=HIP')
targets = self.spec.variants['amdgpu_target'].value
args.append('-DAMReX_AMD_ARCH=' + ';'.join(str(x) for x in targets))
return args return args