mxnet: Add optional cuda_arch spec support, enable CUDA by default (#21266)

This commit is contained in:
Baptiste Jonglez 2021-01-26 14:58:41 +01:00 committed by GitHub
parent b45a31aefe
commit 79afe20bb0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,7 +6,7 @@
from spack import * from spack import *
class Mxnet(MakefilePackage): class Mxnet(MakefilePackage, CudaPackage):
"""MXNet is a deep learning framework """MXNet is a deep learning framework
designed for both efficiency and flexibility.""" designed for both efficiency and flexibility."""
@ -18,7 +18,7 @@ class Mxnet(MakefilePackage):
version('1.6.0', sha256='01eb06069c90f33469c7354946261b0a94824bbaf819fd5d5a7318e8ee596def') version('1.6.0', sha256='01eb06069c90f33469c7354946261b0a94824bbaf819fd5d5a7318e8ee596def')
version('1.3.0', sha256='c00d6fbb2947144ce36c835308e603f002c1eb90a9f4c5a62f4d398154eed4d2') version('1.3.0', sha256='c00d6fbb2947144ce36c835308e603f002c1eb90a9f4c5a62f4d398154eed4d2')
variant('cuda', default=False, description='Enable CUDA support') variant('cuda', default=True, description='Enable CUDA support')
variant('opencv', default=True, description='Enable OpenCV support') variant('opencv', default=True, description='Enable OpenCV support')
variant('openmp', default=False, description='Enable OpenMP support') variant('openmp', default=False, description='Enable OpenMP support')
variant('profiler', default=False, description='Enable Profiler (for verification and debug only).') variant('profiler', default=False, description='Enable Profiler (for verification and debug only).')
@ -111,6 +111,11 @@ def build(self, spec, prefix):
args.extend(['USE_CUDA_PATH=%s' % spec['cuda'].prefix, args.extend(['USE_CUDA_PATH=%s' % spec['cuda'].prefix,
'CUDNN_PATH=%s' % spec['cudnn'].prefix, 'CUDNN_PATH=%s' % spec['cudnn'].prefix,
'CUB_INCLUDE=%s' % spec['cub'].prefix.include]) 'CUB_INCLUDE=%s' % spec['cub'].prefix.include])
# By default, all cuda architectures are built. Restrict only
# if a specific list of architectures is specified in cuda_arch.
if 'cuda_arch=none' not in spec:
cuda_flags = self.cuda_flags(self.spec.variants['cuda_arch'].value)
args.append('CUDA_ARCH={0}'.format(' '.join(cuda_flags)))
make(*args) make(*args)