diff --git a/var/spack/repos/builtin/packages/lbann/package.py b/var/spack/repos/builtin/packages/lbann/package.py index b0ca29b709..7dee4809ae 100644 --- a/var/spack/repos/builtin/packages/lbann/package.py +++ b/var/spack/repos/builtin/packages/lbann/package.py @@ -62,6 +62,7 @@ class Lbann(CMakePackage, CudaPackage): variant('vision', default=False, description='Builds with support for image processing data with OpenCV') variant('vtune', default=False, description='Builds with support for Intel VTune') + variant('onednn', default=False, description='Support for OneDNN') variant('nvshmem', default=False, description='Support for NVSHMEM') # Variant Conflicts @@ -174,6 +175,7 @@ class Lbann(CMakePackage, CudaPackage): depends_on('llvm-openmp', when='%apple-clang') + depends_on('onednn cpu_runtime=omp gpu_runtime=none', when='+onednn') depends_on('nvshmem', when='+nvshmem') generator = 'Ninja' @@ -227,6 +229,7 @@ def cmake_args(self): '-DLBANN_WITH_CUDNN:BOOL=%s' % ('+cuda' in spec), '-DLBANN_WITH_NVSHMEM:BOOL=%s' % ('+nvshmem' in spec), '-DLBANN_WITH_FFT:BOOL=%s' % ('+fft' in spec), + '-DLBANN_WITH_ONEDNN:BOOL=%s' % ('+onednn' in spec), '-DLBANN_WITH_TBINF=OFF', '-DLBANN_WITH_UNIT_TESTING:BOOL=%s' % (self.run_tests), '-DLBANN_WITH_VISION:BOOL=%s' % ('+vision' in spec),