# Source code for infrahouse_core.aws.ec2_instance

"""
Module for the EC2Instance class - a class that represents an EC2 instance - and the
CommandStatus enum describing the status of SSM commands run on it.
"""

import warnings
from enum import Enum
from logging import getLogger
from time import sleep
from typing import Optional

from boto3 import Session
from botocore.client import BaseClient
from botocore.exceptions import ClientError
from cached_property import cached_property_with_ttl
from ec2_metadata import ec2_metadata

from infrahouse_core.aws import get_client
from infrahouse_core.timeout import timeout
from infrahouse_core.validation import (
    validate_instance_id,
    validate_region,
    validate_role_arn,
)

# Module-wide logger, named after this module so callers can tune its level.
LOG = getLogger(name=__name__)


class CommandStatus(Enum):
    """
    Status of an SSM command invocation.

    Every member's value is the literal string found in the ``Status`` field of an
    SSM command invocation, so ``CommandStatus(invocation["Status"])`` turns an API
    response straight into a member.
    """

    # The command is waiting to be executed.
    PENDING = "Pending"
    # The command is running right now.
    IN_PROGRESS = "InProgress"
    # Execution of the command has been postponed.
    DELAYED = "Delayed"
    # The command finished successfully.
    SUCCESS = "Success"
    # The command was cancelled.
    CANCELLED = "Cancelled"
    # The command ran out of time.
    TIMED_OUT = "TimedOut"
    # The command finished with a failure.
    FAILED = "Failed"
    # A cancellation has been requested and is under way.
    CANCELLING = "Cancelling"
class EC2Instance:
    """
    EC2Instance represents an EC2 instance.

    Instance data (state, tags, addresses, ...) comes from ``describe_instances`` and is cached
    for 10 seconds. Methods that modify the instance (:meth:`add_tag`, :meth:`delete`) drop that
    cache so that subsequent reads reflect the change.

    :param instance_id: Instance id. If omitted, the local instance is read from metadata.
    :type instance_id: str
    """

    def __init__(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        instance_id: Optional[str] = None,
        region: Optional[str] = None,
        ec2_client: Optional[BaseClient] = None,
        ssm_client: Optional[BaseClient] = None,
        role_arn: Optional[str] = None,
        session: Optional[Session] = None,
    ):
        """
        :param instance_id: Instance id. If omitted, the local instance is read from metadata.
        :type instance_id: str
        :param region: AWS region to connect to. If omitted, ``None`` is passed on and the default
            region of ``session`` (or whatever :func:`get_client` chooses) is used.
        :type region: str
        :param ec2_client: Boto3 EC2 client. If omitted, a client is created using the region and credentials.
            Deprecated: passing it emits a ``DeprecationWarning``.
        :type ec2_client: botocore.client.BaseClient
        :param ssm_client: Boto3 SSM client. If omitted, a client is created using the region and credentials.
            Deprecated: passing it emits a ``DeprecationWarning``.
        :type ssm_client: botocore.client.BaseClient
        :param role_arn: Use this IAM role to create boto3 clients.
        :type role_arn: str
        :param session: Pre-configured ``boto3.Session``. When provided, clients are created
            from this session instead of via :func:`get_client`.
        :type session: boto3.Session
        """
        if ec2_client is not None:
            warnings.warn(
                "'ec2_client' is deprecated and will be removed in a future version. Pass role_arn instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        if ssm_client is not None:
            warnings.warn(
                "'ssm_client' is deprecated and will be removed in a future version. Pass role_arn instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        # Validate input parameters
        validate_instance_id(instance_id)
        validate_region(region)
        validate_role_arn(role_arn)

        self._instance_id = instance_id
        # Only when no instance id was given does this object describe the host we run on;
        # only then is the local EC2 metadata service a valid source of facts about it.
        self._is_local = instance_id is None
        self._region = region
        self._ec2_client = ec2_client
        self._ssm_client = ssm_client
        self._role_arn = role_arn
        self._session = session

    @property
    def availability_zone(self) -> str:
        """
        :return: Availability zone where this instance is hosted. For the local instance
            (created without ``instance_id``) it is read from EC2 metadata; for an explicitly
            given instance it is taken from ``describe_instances`` (``Placement.AvailabilityZone``),
            because the metadata service only knows about the host it runs on.
        """
        if self._is_local:
            return ec2_metadata.availability_zone
        return self._describe_instance["Placement"]["AvailabilityZone"]

    @property
    def ec2_client(self) -> BaseClient:
        """
        Boto3 EC2 client, created on first access unless one was passed in.

        :return: Boto3 EC2 client.
        """
        if self._ec2_client is None:
            self._ec2_client = self._make_client("ec2")
        return self._ec2_client

    @property
    def instance_id(self) -> str:
        """
        The instance's instance_id.
        It's read from metadata if the class instance was created w/o specifying it.

        :return: The instance's instance_id.
        """
        if self._instance_id is None:
            # If the instance_id was not given, obtain it from metadata
            self._instance_id = ec2_metadata.instance_id
        return self._instance_id

    @property
    def hostname(self) -> Optional[str]:
        """
        :return: Instance's private hostname, i.e. the first part of the private DNS name.
            For example, if the private DNS name is ip-10-0-0-1.eu-west-1.compute.internal,
            the hostname is ip-10-0-0-1. ``None`` if the instance has no private DNS name.
        """
        dns_name = self.private_dns_name
        return dns_name.split(".")[0] if dns_name else None

    @property
    def private_dns_name(self) -> Optional[str]:
        """
        :return: Instance's private DNS name. This name is for use inside the VPC and
            is not accessible from the public Internet. ``None`` if EC2 does not report one.
        """
        return self._describe_instance.get("PrivateDnsName")

    @property
    def private_ip(self) -> Optional[str]:
        """
        :return: Instance's private IP address.
            Can be None if the instance is in a transitional lifecycle state.
        """
        return self._describe_instance.get("PrivateIpAddress")

    @property
    def public_ip(self) -> Optional[str]:
        """
        :return: Instance's public IP address.
            Can be None if the instance is not configured to have a public IP.
        """
        return self._describe_instance.get("PublicIpAddress")

    @property
    def ssm_client(self) -> BaseClient:
        """
        Boto3 SSM client, created on first access unless one was passed in.

        :return: Boto3 SSM client.
        """
        if self._ssm_client is None:
            self._ssm_client = self._make_client("ssm")
        return self._ssm_client

    @property
    def state(self) -> str:
        """
        :return: The state of the instance. Can be one of the following values:

            - ``pending``: The instance is preparing to launch.
            - ``running``: The instance is running and ready for use.
            - ``shutting-down``: The instance is preparing to be terminated.
            - ``terminated``: The instance has been shut down.
            - ``stopping``: The instance is stopping.
            - ``stopped``: The instance has been stopped.
        """
        return self._describe_instance["State"]["Name"]

    @property
    def tags(self) -> dict:
        """
        :return: A dictionary with the instance tags. Keys are tag names, and values - the tag values.
            Empty if the instance has no tags.
        """
        # EC2 returns tags as a list of {"Key": ..., "Value": ...} dicts and omits the
        # "Tags" key entirely when the instance has no tags.
        return {tag["Key"]: tag["Value"] for tag in self._describe_instance.get("Tags", [])}

    @property
    def exists(self) -> bool:
        """
        Check whether the instance currently exists.

        An instance is considered non-existent if its state is ``terminated``
        or ``shutting-down``, or if the describe call fails with
        ``InvalidInstanceID.NotFound``.

        :return: ``True`` if the instance exists and is not terminated.
        """
        try:
            return self.state not in ("terminated", "shutting-down")
        except ClientError as err:
            if err.response["Error"]["Code"] == "InvalidInstanceID.NotFound":
                return False
            raise

    def delete(self) -> None:
        """
        Terminate the EC2 instance.

        Idempotent — does nothing if the instance is already terminated or does not exist.

        :raises ClientError: On any other API error, including ``OperationNotPermitted`` while the
            instance still exists (e.g. termination protection is enabled).
        """
        try:
            self.ec2_client.terminate_instances(InstanceIds=[self.instance_id])
        except ClientError as err:
            error_code = err.response["Error"]["Code"]
            if error_code == "InvalidInstanceID.NotFound":
                LOG.info("Instance %s does not exist.", self.instance_id)
                return
            if error_code != "OperationNotPermitted":
                raise
            # Don't infer "already terminated" from the error text: the termination-protection
            # error also contains that word ("may not be terminated"). Ask EC2 for the real state.
            self._forget_instance_description()
            if self.exists:
                raise
            LOG.info("Instance %s is already terminated.", self.instance_id)
        else:
            LOG.info("Terminated instance %s", self.instance_id)
        finally:
            # Any cached describe data predates the termination request.
            self._forget_instance_description()

    def add_tag(self, key: str, value: str):
        """
        Add a tag to the EC2 instance.

        :param key: The key of the tag.
        :type key: str
        :param value: The value of the tag.
        :type value: str
        """
        self.ec2_client.create_tags(
            Resources=[
                self.instance_id,
            ],
            Tags=[
                {
                    "Key": key,
                    "Value": value,
                },
            ],
        )
        # Make the new tag visible through ``tags`` immediately instead of after the cache TTL.
        self._forget_instance_description()

    def execute_command(
        self, command: str, send_timeout: int = 600, execution_timeout: int = 60
    ) -> tuple[int, str, str]:
        """
        Execute a command on the EC2 instance via SSM.

        :param command: The command to execute.
        :type command: str
        :param send_timeout: Time in seconds to attempt to send a command.
            Instances coming back from hibernation may take about 5 minutes.
        :type send_timeout: int
        :param execution_timeout: Time in seconds to wait for the command to complete.
        :type execution_timeout: int
        :return: A tuple containing the exit code, standard output, and standard error.

        Example::

            instance = EC2Instance("i-1234567890abcdef0", region="us-east-1")
            exit_code, stdout, stderr = instance.execute_command("hostname")
            if exit_code != 0:
                raise RuntimeError(f"Command failed: {stderr}")
        """
        command_id = self._send_command(command, send_timeout)
        return self._wait_for_command(command_id, execution_timeout)

    def _make_client(self, service_name: str) -> BaseClient:
        """Create a boto3 client for ``service_name`` from ``session`` if given, else via :func:`get_client`."""
        if self._session is not None:
            return self._session.client(service_name, region_name=self._region)
        return get_client(service_name, region=self._region, role_arn=self._role_arn)

    @cached_property_with_ttl(ttl=10)
    def _describe_instance(self) -> dict:
        """
        Describe the instance - fetch instance data from AWS. The result is cached for 10 seconds.

        :return: A dictionary with the instance data as returned
            by the ``describe_instances`` method of the EC2 client.
        """
        return self.ec2_client.describe_instances(
            InstanceIds=[
                self.instance_id,
            ],
        )["Reservations"][0]["Instances"][0]

    def _forget_instance_description(self) -> None:
        """Drop the cached ``_describe_instance`` result so the next read queries EC2 again."""
        # cached_property_with_ttl stores (value, timestamp) in the instance __dict__ under the
        # property's name; popping it (no error if absent) forces a refresh.
        self.__dict__.pop("_describe_instance", None)

    def _send_command(self, command: str, send_timeout: int = 600) -> str:
        """
        Send a shell command to the instance via SSM (``AWS-RunShellScript``).

        While the SSM agent is not reachable yet, SSM answers with an ``InvalidInstanceId`` error.
        The send is then retried with exponential backoff (3 seconds, doubling, capped at 30)
        until it succeeds or ``send_timeout`` elapses and :func:`timeout` interrupts the loop.

        :param command: The command to send.
        :type command: str
        :param send_timeout: Time in seconds to attempt to send a command.
            Instances coming back from hibernation may take about 5 minutes.
        :type send_timeout: int
        :return: The command ID of the sent command.
        :raises RuntimeError: If the instance is terminated or shutting down - SSM can never reach it.
        """
        delay = 3  # initial delay in seconds
        with timeout(send_timeout):  # it takes about 5 minutes to wake SSM agent
            while True:
                try:
                    response = self.ssm_client.send_command(
                        InstanceIds=[self.instance_id],
                        DocumentName="AWS-RunShellScript",
                        Parameters={"commands": [command]},
                    )
                except ClientError as e:
                    if e.response["Error"]["Code"] != "InvalidInstanceId":
                        raise  # Re-raise other unexpected exceptions
                    # Check if the instance is terminated — no point retrying.
                    # Refresh first: the cached state may predate a recent termination.
                    self._forget_instance_description()
                    state = self.state
                    if state in ("terminated", "shutting-down"):
                        raise RuntimeError(
                            f"Instance {self.instance_id} is {state} — SSM will never connect"
                        ) from e
                    LOG.warning("Instance is not ready yet. Retrying in %d seconds.", delay)
                    sleep(delay)
                    delay = min(delay * 2, 30)  # increase delay exponentially, capped at 30 seconds
                    continue

                command_id = response["Command"]["CommandId"]
                LOG.info("Command sent. ID: %s", command_id)
                return command_id

    def _wait_for_command(self, command_id: str, execution_timeout: int = 60) -> tuple[int, str, str]:
        """
        Poll SSM until the command reaches a final status and return its result.

        Final statuses are ``Success``, ``Failed``, ``TimedOut`` and ``Cancelled``. Between polls the
        wait grows exponentially from 1 second up to 30. While SSM reports ``InvocationDoesNotExist``
        the poll is retried after 0.1 seconds. The whole wait is bounded by ``execution_timeout``
        via :func:`timeout`.

        :param command_id: The command ID of the sent command.
        :type command_id: str
        :param execution_timeout: Time in seconds to wait for the command to finish.
        :type execution_timeout: int
        :return: A tuple containing the exit code, standard output, and standard error.
        """
        final_statuses = (
            CommandStatus.SUCCESS,
            CommandStatus.FAILED,
            CommandStatus.TIMED_OUT,
            CommandStatus.CANCELLED,
        )
        delay = 1  # initial delay in seconds
        with timeout(execution_timeout):
            while True:
                try:
                    invocation = self.ssm_client.get_command_invocation(
                        CommandId=command_id,
                        InstanceId=self.instance_id,
                    )
                except ClientError as e:
                    if e.response["Error"]["Code"] != "InvocationDoesNotExist":
                        raise  # Re-raise other unexpected exceptions
                    LOG.warning("Invocation not yet available. Retrying.")
                    sleep(0.1)
                    continue

                status = invocation["Status"]
                LOG.info("Current status: %s", status)
                if CommandStatus(status) in final_statuses:
                    exit_code = int(invocation["ResponseCode"])
                    stdout = invocation["StandardOutputContent"]
                    stderr = invocation["StandardErrorContent"]
                    LOG.debug("Exit code: %d", exit_code)
                    # Output is only interesting at error level when the command failed.
                    log_output = LOG.debug
                    if exit_code != 0:
                        LOG.error("Command failed with exit code %d", exit_code)
                        log_output = LOG.error
                    log_output("STDOUT:\n%s", stdout)
                    log_output("STDERR:\n%s", stderr)
                    return exit_code, stdout, stderr

                sleep(delay)
                delay = min(delay * 2, 30)  # increase delay exponentially, capped at 30 seconds