Put a module object in sys.modules before executing module code (#23269)

The loading protocol mandates that the the module we are going
to import needs to be already in sys.modules before its code is
executed, so to prevent unbounded recursions and multiple loading.

Loading a module from file exits early if the module is already
in sys.modules
This commit is contained in:
Massimiliano Culpo 2021-05-06 11:53:40 +02:00 committed by GitHub
parent 8f1b701660
commit 219eb09e59
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 2 deletions

View file

@ -804,6 +804,9 @@ def __repr__(self):
def load_module_from_file(module_name, module_path):
"""Loads a python module from the path of the corresponding file.
If the module is already in ``sys.modules`` it will be returned as
is and not reloaded.
Args:
module_name (str): namespace where the python module will be loaded,
e.g. ``foo.bar``
@ -816,12 +819,28 @@ def load_module_from_file(module_name, module_path):
ImportError: when the module can't be loaded
FileNotFoundError: when module_path doesn't exist
"""
if module_name in sys.modules:
return sys.modules[module_name]
# This recipe is adapted from https://stackoverflow.com/a/67692/771663
if sys.version_info[0] == 3 and sys.version_info[1] >= 5:
import importlib.util
spec = importlib.util.spec_from_file_location( # novm
module_name, module_path)
module = importlib.util.module_from_spec(spec) # novm
# The module object needs to exist in sys.modules before the
# loader executes the module code.
#
# See https://docs.python.org/3/reference/import.html#loading
sys.modules[spec.name] = module
try:
spec.loader.exec_module(module)
except BaseException:
try:
del sys.modules[spec.name]
except KeyError:
pass
raise
elif sys.version_info[0] == 3 and sys.version_info[1] < 5:
import importlib.machinery
loader = importlib.machinery.SourceFileLoader( # novm

View file

@ -6,6 +6,7 @@
import pytest
import os.path
import sys
from datetime import datetime, timedelta
import llnl.util.lang
@ -27,7 +28,12 @@ def module_path(tmpdir):
path = os.path.join('/usr', 'bin')
"""
m.write(content)
return str(m)
yield str(m)
# Don't leave garbage in the module system
if 'foo' in sys.modules:
del sys.modules['foo']
def test_pretty_date():
@ -127,10 +133,22 @@ def test_match_predicate():
def test_load_modules_from_file(module_path):
# Check prerequisites
assert 'foo' not in sys.modules
# Check that the module is loaded correctly from file
foo = llnl.util.lang.load_module_from_file('foo', module_path)
assert 'foo' in sys.modules
assert foo.value == 1
assert foo.path == os.path.join('/usr', 'bin')
# Check that the module is not reloaded a second time on subsequent calls
foo.value = 2
foo = llnl.util.lang.load_module_from_file('foo', module_path)
assert 'foo' in sys.modules
assert foo.value == 2
assert foo.path == os.path.join('/usr', 'bin')
def test_uniq():
assert [1, 2, 3] == llnl.util.lang.uniq([1, 2, 3])