PythonPackage: fix libs/headers attributes (#32970)

This commit is contained in:
Adam J. Stewart 2022-10-10 08:26:30 -05:00 committed by GitHub
parent bfbd411091
commit 7cb745b03a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 34 deletions

View file

@ -243,8 +243,8 @@ def headers(self):
"""Discover header files in platlib."""
# Headers may be in either location
include = inspect.getmodule(self).include
platlib = inspect.getmodule(self).platlib
include = self.prefix.join(self.spec["python"].package.include)
platlib = self.prefix.join(self.spec["python"].package.platlib)
headers = find_all_headers(include) + find_all_headers(platlib)
if headers:
@ -259,7 +259,7 @@ def libs(self):
# Remove py- prefix in package name
library = "lib" + self.spec.name[3:].replace("-", "?")
root = inspect.getmodule(self).platlib
root = self.prefix.join(self.spec["python"].package.platlib)
for shared in [True, False]:
libs = find_libraries(library, root, shared=shared, recursive=True)

View file

@ -321,37 +321,6 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
when="@:1.9.1 ^cuda@11.4.100:",
)
@property
def headers(self):
"""Discover header files in platlib."""
# Headers may be in either location
include = join_path(self.prefix, self.spec["python"].package.include)
platlib = join_path(self.prefix, self.spec["python"].package.platlib)
headers = find_all_headers(include) + find_all_headers(platlib)
if headers:
return headers
msg = "Unable to locate {} headers in {} or {}"
raise NoHeadersError(msg.format(self.spec.name, include, platlib))
@property
def libs(self):
"""Discover libraries in platlib."""
# Remove py- prefix in package name
library = "lib" + self.spec.name[3:].replace("-", "?")
root = join_path(self.prefix, self.spec["python"].package.platlib)
for shared in [True, False]:
libs = find_libraries(library, root, shared=shared, recursive=True)
if libs:
return libs
msg = "Unable to recursively locate {} libraries in {}"
raise NoLibrariesError(msg.format(self.spec.name, root))
@when("@1.5.0:")
def patch(self):
# https://github.com/pytorch/pytorch/issues/52208