-
Notifications
You must be signed in to change notification settings - Fork 1.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(aws): Add a test_connection method #4563
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #4563 /- ##
==========================================
Coverage 88.89% 89.11% 0.21%
==========================================
Files 907 913 6
Lines 27612 27866 254
==========================================
Hits 24547 24834 287
Misses 3065 3032 -33 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good! I left some comments and I also wondered how is the API going to use this?
As far as I know, we only store the provider_id
(relevant to what I see in the test_connection
expected arguments). Would that be the external_id
field? Would that be enough to test the connection?
It mostly depends on the AWS Credentials method used, you can see more about it here https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html. Currently Prowler supports out of the box methods 3, 4, 5 and 6. For example using the environment variables or the profile/config will make the SDK to automatically look for them and then try to authenticate with AWS.
The
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for the explanation too!
We are still discussing about the format of the @staticmethod
def test_connection(
session: Session = None,
profile: str = None,
aws_region: str = AWS_STS_GLOBAL_ENDPOINT_REGION,
role_arn: str = None,
role_session_name: str = ROLE_SESSION_NAME,
session_duration: int = 3600,
external_id: str = None,
mfa_enabled: bool = False,
raise_on_exception: bool = True,
) -> TestConnection:
"""
Validates AWS credentials using the provided session and AWS region.
If no session is provided, the method will create a new session using the Boto3 default session.
Args:
session (Session): The AWS session object.
profile (str): The AWS profile to use for the session.
aws_region (str): The AWS region to validate the credentials in.
role_arn (str): The ARN of the IAM role to assume.
role_session_name (str): The name of the role session.
session_duration (int): The duration of the assumed role session in seconds.
external_id (str): The external ID to use when assuming the role.
mfa_enabled (bool): Whether MFA (Multi-Factor Authentication) is enabled.
raise_on_exception (bool): Whether to raise an exception if an error occurs.
Returns:
TestConnection: A named tuple containing the result of the validation.
- connected (bool): Indicates whether the validation was successful.
- result (AWSCallerIdentity): An object representing the caller's identity if the validation was successful.
- error (Exception): An exception object if an error occurs during the validation.
Raises:
Exception: If an error occurs during the validation process.
Examples:
>>> AwsProvider.test_connection(
role_arn="arn:aws:iam::111122223333:role/ProwlerRole",
external_id="67f7a641-ecb0-4f6d-921d-3587febd379c"
)
AWSCallerIdentity(user_id='AROAAAAAAAAAAAAAAAAAA:ProwlerAssessmentSession', account='111122223333', arn=ARN(arn='arn:aws:sts::111122223333:assumed-role/ProwlerRole/ProwlerAssessmentSession', partition='aws', service='sts', region=None, account_id='111122223333', resource='ProwlerRole/ProwlerAssessmentSession', resource_type='assumed-role'), region='us-east-1')
>>> AwsProvider.test_connection(profile="test")
AWSCallerIdentity(user_id='AROAAAAAAAAAAAAAAAAAA:test-user', account='111122223333', arn=ARN(arn='arn:aws:sts::111122223333:user/test-user', partition='aws', service='sts', region=None, account_id='111122223333', resource='test-user', resource_type='user'), region='us-east-1')
"""
try:
# Create the default session if no session is given
session = (
session if session else AwsProvider.setup_session(mfa_enabled, profile)
)
# Test Connection using the IAM Role
if role_arn:
session_duration = validate_session_duration(session_duration)
role_session_name = validate_role_session_name(role_session_name)
role_arn = parse_iam_credentials_arn(role_arn)
assumed_role_information = AWSAssumeRoleInfo(
role_arn=role_arn,
session_duration=session_duration,
external_id=external_id,
mfa_enabled=mfa_enabled,
role_session_name=role_session_name,
)
assumed_role_credentials = AwsProvider.assume_role(
session,
assumed_role_information,
)
session = Session(
aws_access_key_id=assumed_role_credentials.aws_access_key_id,
aws_secret_access_key=assumed_role_credentials.aws_secret_access_key,
aws_session_token=assumed_role_credentials.aws_session_token,
region_name=aws_region,
profile_name=profile,
)
sts_client = AwsProvider.create_sts_session(session, aws_region)
caller_identity = sts_client.get_caller_identity()
# Include the region where the caller_identity has validated the credentials
return TestConnection(
connected=True,
result=AWSCallerIdentity(
user_id=caller_identity.get("UserId"),
account=caller_identity.get("Account"),
arn=ARN(caller_identity.get("Arn")),
region=aws_region,
),
)
except (ClientError, ProfileNotFound) as credentials_error:
logger.error(
f"{credentials_error.__class__.__name__}[{credentials_error.__traceback__.tb_lineno}]: {credentials_error}"
)
if raise_on_exception:
raise credentials_error
else:
return TestConnection(error=credentials_error)
except ArgumentTypeError as validation_error:
logger.error(
f"{validation_error.__class__.__name__}[{validation_error.__traceback__.tb_lineno}]: {validation_error}"
)
if raise_on_exception:
raise validation_error
else:
return TestConnection(error=validation_error)
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
raise error and the @dataclass
class TestConnection:
_connected: bool = False
_error: Exception = None
_result: Any = None
@property
def connected(self) -> bool:
return self._connected
@property
def error(self) -> Exception:
return self._error
@property
def result(self) -> Any:
return self._result This PR needs to be updated accordingly. Also the init should call it with:
|
Do not merge until the above change is addressed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This implementation looks better, great job. I left some small requested changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left some comments but 😍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🥇 🐐
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great job!
Context
We need to be able to test the connection to each provider using one of the available authentication methods, starting from the environment variables.
Description
test_connection
method for the AWS provider.test_connection
abstract method in theProvider
class but not enforced for now not to break anything.Notes for reviewers
sys.exit()
in favor of catching the exceptions during the__init__
.License
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.