mxnet: Add optional cuda_arch spec support, enable CUDA by default (#21266)
This commit is contained in:
parent
b45a31aefe
commit
79afe20bb0
1 changed files with 7 additions and 2 deletions
|
@ -6,7 +6,7 @@
|
|||
from spack import *
|
||||
|
||||
|
||||
class Mxnet(MakefilePackage):
|
||||
class Mxnet(MakefilePackage, CudaPackage):
|
||||
"""MXNet is a deep learning framework
|
||||
designed for both efficiency and flexibility."""
|
||||
|
||||
|
@ -18,7 +18,7 @@ class Mxnet(MakefilePackage):
|
|||
version('1.6.0', sha256='01eb06069c90f33469c7354946261b0a94824bbaf819fd5d5a7318e8ee596def')
|
||||
version('1.3.0', sha256='c00d6fbb2947144ce36c835308e603f002c1eb90a9f4c5a62f4d398154eed4d2')
|
||||
|
||||
variant('cuda', default=False, description='Enable CUDA support')
|
||||
variant('cuda', default=True, description='Enable CUDA support')
|
||||
variant('opencv', default=True, description='Enable OpenCV support')
|
||||
variant('openmp', default=False, description='Enable OpenMP support')
|
||||
variant('profiler', default=False, description='Enable Profiler (for verification and debug only).')
|
||||
|
@ -111,6 +111,11 @@ def build(self, spec, prefix):
|
|||
args.extend(['USE_CUDA_PATH=%s' % spec['cuda'].prefix,
|
||||
'CUDNN_PATH=%s' % spec['cudnn'].prefix,
|
||||
'CUB_INCLUDE=%s' % spec['cub'].prefix.include])
|
||||
# By default, all cuda architectures are built. Restrict only
|
||||
# if a specific list of architectures is specified in cuda_arch.
|
||||
if 'cuda_arch=none' not in spec:
|
||||
cuda_flags = self.cuda_flags(self.spec.variants['cuda_arch'].value)
|
||||
args.append('CUDA_ARCH={0}'.format(' '.join(cuda_flags)))
|
||||
|
||||
make(*args)
|
||||
|
||||
|
|
Loading…
Reference in a new issue