Put a module object in sys.modules before executing module code (#23269)
The loading protocol mandates that the the module we are going to import needs to be already in sys.modules before its code is executed, so to prevent unbounded recursions and multiple loading. Loading a module from file exits early if the module is already in sys.modules
This commit is contained in:
parent
8f1b701660
commit
219eb09e59
2 changed files with 39 additions and 2 deletions
|
@ -804,6 +804,9 @@ def __repr__(self):
|
||||||
def load_module_from_file(module_name, module_path):
|
def load_module_from_file(module_name, module_path):
|
||||||
"""Loads a python module from the path of the corresponding file.
|
"""Loads a python module from the path of the corresponding file.
|
||||||
|
|
||||||
|
If the module is already in ``sys.modules`` it will be returned as
|
||||||
|
is and not reloaded.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module_name (str): namespace where the python module will be loaded,
|
module_name (str): namespace where the python module will be loaded,
|
||||||
e.g. ``foo.bar``
|
e.g. ``foo.bar``
|
||||||
|
@ -816,12 +819,28 @@ def load_module_from_file(module_name, module_path):
|
||||||
ImportError: when the module can't be loaded
|
ImportError: when the module can't be loaded
|
||||||
FileNotFoundError: when module_path doesn't exist
|
FileNotFoundError: when module_path doesn't exist
|
||||||
"""
|
"""
|
||||||
|
if module_name in sys.modules:
|
||||||
|
return sys.modules[module_name]
|
||||||
|
|
||||||
|
# This recipe is adapted from https://stackoverflow.com/a/67692/771663
|
||||||
if sys.version_info[0] == 3 and sys.version_info[1] >= 5:
|
if sys.version_info[0] == 3 and sys.version_info[1] >= 5:
|
||||||
import importlib.util
|
import importlib.util
|
||||||
spec = importlib.util.spec_from_file_location( # novm
|
spec = importlib.util.spec_from_file_location( # novm
|
||||||
module_name, module_path)
|
module_name, module_path)
|
||||||
module = importlib.util.module_from_spec(spec) # novm
|
module = importlib.util.module_from_spec(spec) # novm
|
||||||
spec.loader.exec_module(module)
|
# The module object needs to exist in sys.modules before the
|
||||||
|
# loader executes the module code.
|
||||||
|
#
|
||||||
|
# See https://docs.python.org/3/reference/import.html#loading
|
||||||
|
sys.modules[spec.name] = module
|
||||||
|
try:
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
except BaseException:
|
||||||
|
try:
|
||||||
|
del sys.modules[spec.name]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
elif sys.version_info[0] == 3 and sys.version_info[1] < 5:
|
elif sys.version_info[0] == 3 and sys.version_info[1] < 5:
|
||||||
import importlib.machinery
|
import importlib.machinery
|
||||||
loader = importlib.machinery.SourceFileLoader( # novm
|
loader = importlib.machinery.SourceFileLoader( # novm
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import os.path
|
import os.path
|
||||||
|
import sys
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
import llnl.util.lang
|
import llnl.util.lang
|
||||||
|
@ -27,7 +28,12 @@ def module_path(tmpdir):
|
||||||
path = os.path.join('/usr', 'bin')
|
path = os.path.join('/usr', 'bin')
|
||||||
"""
|
"""
|
||||||
m.write(content)
|
m.write(content)
|
||||||
return str(m)
|
|
||||||
|
yield str(m)
|
||||||
|
|
||||||
|
# Don't leave garbage in the module system
|
||||||
|
if 'foo' in sys.modules:
|
||||||
|
del sys.modules['foo']
|
||||||
|
|
||||||
|
|
||||||
def test_pretty_date():
|
def test_pretty_date():
|
||||||
|
@ -127,10 +133,22 @@ def test_match_predicate():
|
||||||
|
|
||||||
|
|
||||||
def test_load_modules_from_file(module_path):
|
def test_load_modules_from_file(module_path):
|
||||||
|
# Check prerequisites
|
||||||
|
assert 'foo' not in sys.modules
|
||||||
|
|
||||||
|
# Check that the module is loaded correctly from file
|
||||||
foo = llnl.util.lang.load_module_from_file('foo', module_path)
|
foo = llnl.util.lang.load_module_from_file('foo', module_path)
|
||||||
|
assert 'foo' in sys.modules
|
||||||
assert foo.value == 1
|
assert foo.value == 1
|
||||||
assert foo.path == os.path.join('/usr', 'bin')
|
assert foo.path == os.path.join('/usr', 'bin')
|
||||||
|
|
||||||
|
# Check that the module is not reloaded a second time on subsequent calls
|
||||||
|
foo.value = 2
|
||||||
|
foo = llnl.util.lang.load_module_from_file('foo', module_path)
|
||||||
|
assert 'foo' in sys.modules
|
||||||
|
assert foo.value == 2
|
||||||
|
assert foo.path == os.path.join('/usr', 'bin')
|
||||||
|
|
||||||
|
|
||||||
def test_uniq():
|
def test_uniq():
|
||||||
assert [1, 2, 3] == llnl.util.lang.uniq([1, 2, 3])
|
assert [1, 2, 3] == llnl.util.lang.uniq([1, 2, 3])
|
||||||
|
|
Loading…
Reference in a new issue