Use consistent method of checking for presence of info in connection (#27694)

Fixes #27652

Ensure that mirror's to_dict function returns a syaml_dict object for all code
paths.

Switch to using the .get function for accessing the potential information from
the S3 mirror objects.  If the key is not there, it will gracefully return
None instead of failing with a KeyError

Additionally, check that the connection object is a dictionary before trying
to "get" from it.

Add a test for the capturing of the new S3 information.
This commit is contained in:
Joseph Snyder 2021-12-22 10:15:49 -05:00 committed by GitHub
parent 522a7c8ee0
commit 34873f5fe7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 29 deletions

View file

@ -90,7 +90,9 @@ def from_json(stream, name=None):
def to_dict(self): def to_dict(self):
if self._push_url is None: if self._push_url is None:
return self._fetch_url return syaml_dict([
('fetch', self._fetch_url),
('push', self._fetch_url)])
else: else:
return syaml_dict([ return syaml_dict([
('fetch', self._fetch_url), ('fetch', self._fetch_url),
@ -105,12 +107,12 @@ def from_dict(d, name=None):
def display(self, max_len=0): def display(self, max_len=0):
if self._push_url is None: if self._push_url is None:
_display_mirror_entry(max_len, self._name, self._fetch_url) _display_mirror_entry(max_len, self._name, self.fetch_url)
else: else:
_display_mirror_entry( _display_mirror_entry(
max_len, self._name, self._fetch_url, "fetch") max_len, self._name, self.fetch_url, "fetch")
_display_mirror_entry( _display_mirror_entry(
max_len, self._name, self._push_url, "push") max_len, self._name, self.push_url, "push")
def __str__(self): def __str__(self):
name = self._name name = self._name
@ -145,8 +147,8 @@ def name(self):
def get_profile(self, url_type): def get_profile(self, url_type):
if isinstance(self._fetch_url, dict): if isinstance(self._fetch_url, dict):
if url_type == "push": if url_type == "push":
return self._push_url['profile'] return self._push_url.get('profile', None)
return self._fetch_url['profile'] return self._fetch_url.get('profile', None)
else: else:
return None return None
@ -159,8 +161,8 @@ def set_profile(self, url_type, profile):
def get_access_pair(self, url_type): def get_access_pair(self, url_type):
if isinstance(self._fetch_url, dict): if isinstance(self._fetch_url, dict):
if url_type == "push": if url_type == "push":
return self._push_url['access_pair'] return self._push_url.get('access_pair', None)
return self._fetch_url['access_pair'] return self._fetch_url.get('access_pair', None)
else: else:
return None return None
@ -173,8 +175,8 @@ def set_access_pair(self, url_type, connection_tuple):
def get_endpoint_url(self, url_type): def get_endpoint_url(self, url_type):
if isinstance(self._fetch_url, dict): if isinstance(self._fetch_url, dict):
if url_type == "push": if url_type == "push":
return self._push_url['endpoint_url'] return self._push_url.get('endpoint_url', None)
return self._fetch_url['endpoint_url'] return self._fetch_url.get('endpoint_url', None)
else: else:
return None return None
@ -187,8 +189,8 @@ def set_endpoint_url(self, url_type, url):
def get_access_token(self, url_type): def get_access_token(self, url_type):
if isinstance(self._fetch_url, dict): if isinstance(self._fetch_url, dict):
if url_type == "push": if url_type == "push":
return self._push_url['access_token'] return self._push_url.get('access_token', None)
return self._fetch_url['access_token'] return self._fetch_url.get('access_token', None)
else: else:
return None return None

View file

@ -246,6 +246,29 @@ def get_object(self, Bucket=None, Key=None):
raise self.ClientError raise self.ClientError
def test_gather_s3_information(monkeypatch, capfd):
mock_connection_data = {"access_token": "AAAAAAA",
"profile": "SPacKDeV",
"access_pair": ("SPA", "CK"),
"endpoint_url": "https://127.0.0.1:8888"}
session_args, client_args = spack.util.s3.get_mirror_s3_connection_info(mock_connection_data) # noqa: E501
# Session args are used to create the S3 Session object
assert "aws_session_token" in session_args
assert session_args.get("aws_session_token") == "AAAAAAA"
assert "aws_access_key_id" in session_args
assert session_args.get("aws_access_key_id") == "SPA"
assert "aws_secret_access_key" in session_args
assert session_args.get("aws_secret_access_key") == "CK"
assert "profile_name" in session_args
assert session_args.get("profile_name") == "SPacKDeV"
# In addition to the session object, use the client_args to create the s3
# Client object
assert "endpoint_url" in client_args
def test_remove_s3_url(monkeypatch, capfd): def test_remove_s3_url(monkeypatch, capfd):
fake_s3_url = 's3://my-bucket/subdirectory/mirror' fake_s3_url = 's3://my-bucket/subdirectory/mirror'

View file

@ -27,6 +27,30 @@ def _parse_s3_endpoint_url(endpoint_url):
return endpoint_url return endpoint_url
def get_mirror_s3_connection_info(connection):
s3_connection = {}
s3_connection_is_dict = connection and isinstance(connection, dict)
if s3_connection_is_dict:
if connection.get("access_token"):
s3_connection["aws_session_token"] = connection["access_token"]
if connection.get("access_pair"):
s3_connection["aws_access_key_id"] = connection["access_pair"][0]
s3_connection["aws_secret_access_key"] = connection["access_pair"][1]
if connection.get("profile"):
s3_connection["profile_name"] = connection["profile"]
s3_client_args = {"use_ssl": spack.config.get('config:verify_ssl')}
endpoint_url = os.environ.get('S3_ENDPOINT_URL')
if endpoint_url:
s3_client_args['endpoint_url'] = _parse_s3_endpoint_url(endpoint_url)
elif s3_connection_is_dict and connection.get("endpoint_url"):
s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(connection["endpoint_url"]) # noqa: E501
return (s3_connection, s3_client_args)
def create_s3_session(url, connection={}): def create_s3_session(url, connection={}):
url = url_util.parse(url) url = url_util.parse(url)
if url.scheme != 's3': if url.scheme != 's3':
@ -40,25 +64,9 @@ def create_s3_session(url, connection={}):
from boto3 import Session from boto3 import Session
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
s3_connection = {} s3_connection, s3_client_args = get_mirror_s3_connection_info(connection)
if connection:
if connection['access_token']:
s3_connection["aws_session_token"] = connection["access_token"]
if connection["access_pair"][0]:
s3_connection["aws_access_key_id"] = connection["access_pair"][0]
s3_connection["aws_secret_access_key"] = connection["access_pair"][1]
if connection["profile"]:
s3_connection["profile_name"] = connection["profile"]
session = Session(**s3_connection) session = Session(**s3_connection)
s3_client_args = {"use_ssl": spack.config.get('config:verify_ssl')}
endpoint_url = os.environ.get('S3_ENDPOINT_URL')
if endpoint_url:
s3_client_args['endpoint_url'] = _parse_s3_endpoint_url(endpoint_url)
elif connection and 'endpoint_url' in connection:
s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(connection["endpoint_url"]) # noqa: E501
# if no access credentials provided above, then access anonymously # if no access credentials provided above, then access anonymously
if not session.get_credentials(): if not session.get_credentials():
from botocore import UNSIGNED from botocore import UNSIGNED