From 7762b8acddf62806607adb1c246f7ca5536450b5 Mon Sep 17 00:00:00 2001 From: mic84 Date: Fri, 15 Jan 2021 01:14:55 -0800 Subject: [PATCH] amrex: add ROCm support (#20809) --- var/spack/repos/builtin/packages/amrex/package.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/var/spack/repos/builtin/packages/amrex/package.py b/var/spack/repos/builtin/packages/amrex/package.py index d3414e575e..8c077e7921 100644 --- a/var/spack/repos/builtin/packages/amrex/package.py +++ b/var/spack/repos/builtin/packages/amrex/package.py @@ -6,7 +6,7 @@ from spack import * -class Amrex(CMakePackage, CudaPackage): +class Amrex(CMakePackage, CudaPackage, ROCmPackage): """AMReX is a publicly available software framework designed for building massively parallel block- structured adaptive mesh refinement (AMR) applications.""" @@ -81,6 +81,7 @@ class Amrex(CMakePackage, CudaPackage): depends_on('cmake@3.14:', type='build', when='@19.04:') # cmake @3.17: is necessary to handle cuda @11: correctly depends_on('cmake@3.17:', type='build', when='^cuda @11:') + depends_on('rocrand', type='build', when='+rocm') conflicts('%apple-clang') conflicts('%clang') @@ -113,6 +114,8 @@ class Amrex(CMakePackage, CudaPackage): conflicts('cuda_arch=21', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5') conflicts('cuda_arch=30', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5') conflicts('cuda_arch=32', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5') + conflicts('+rocm', when='@:20.11', msg='AMReX HIP support needs AMReX newer than version 20.11') + conflicts('+cuda', when='+rocm', msg='CUDA and HIP support are exclusive') def url_for_version(self, version): if version >= Version('20.05'): @@ -200,4 +203,10 @@ def cmake_args(self): cuda_arch = self.spec.variants['cuda_arch'].value args.append('-DCUDA_ARCH=' + self.get_cuda_arch_string(cuda_arch)) + if '+rocm' in self.spec: + args.append('-DCMAKE_CXX_COMPILER={0}'.format(self.spec['hip'].hipcc)) + args.append('-DAMReX_GPU_BACKEND=HIP') + targets = self.spec.variants['amdgpu_target'].value + args.append('-DAMReX_AMD_ARCH=' + ';'.join(str(x) for x in targets)) + return args