diff --git a/var/spack/repos/builtin/packages/mxnet/package.py b/var/spack/repos/builtin/packages/mxnet/package.py index 6f42d78f50..5cc9977ef3 100644 --- a/var/spack/repos/builtin/packages/mxnet/package.py +++ b/var/spack/repos/builtin/packages/mxnet/package.py @@ -6,7 +6,7 @@ from spack import * -class Mxnet(MakefilePackage): +class Mxnet(MakefilePackage, CudaPackage): """MXNet is a deep learning framework designed for both efficiency and flexibility.""" @@ -18,7 +18,7 @@ class Mxnet(MakefilePackage): version('1.6.0', sha256='01eb06069c90f33469c7354946261b0a94824bbaf819fd5d5a7318e8ee596def') version('1.3.0', sha256='c00d6fbb2947144ce36c835308e603f002c1eb90a9f4c5a62f4d398154eed4d2') - variant('cuda', default=False, description='Enable CUDA support') + variant('cuda', default=True, description='Enable CUDA support') variant('opencv', default=True, description='Enable OpenCV support') variant('openmp', default=False, description='Enable OpenMP support') variant('profiler', default=False, description='Enable Profiler (for verification and debug only).') @@ -111,6 +111,11 @@ def build(self, spec, prefix): args.extend(['USE_CUDA_PATH=%s' % spec['cuda'].prefix, 'CUDNN_PATH=%s' % spec['cudnn'].prefix, 'CUB_INCLUDE=%s' % spec['cub'].prefix.include]) + # By default, all cuda architectures are built. Restrict only + # if a specific list of architectures is specified in cuda_arch. + if 'cuda_arch=none' not in spec: + cuda_flags = self.cuda_flags(self.spec.variants['cuda_arch'].value) + args.append('CUDA_ARCH={0}'.format(' '.join(cuda_flags))) make(*args)