# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains code to create and manage SageMaker ``Context``."""
from __future__ import absolute_import

from sagemaker.apiutils import _base_types
from sagemaker.lineage import (
    _api_types,
    _utils,
    association,
)


class Context(_base_types.Record):
    """An Amazon SageMaker context, which is part of a SageMaker lineage.

    Attributes:
        context_arn (str): The ARN of the context.
        context_name (str): The name of the context.
        context_type (str): The type of the context.
        description (str): A description of the context.
        source (obj): The source of the context with a URI and type.
        properties (dict): Dictionary of properties.
        tags (List[dict[str, str]]): A list of tags to associate with the context.
        creation_time (datetime): When the context was created.
        created_by (obj): Contextual info on which account created the context.
        last_modified_time (datetime): When the context was last modified.
        last_modified_by (obj): Contextual info on which account last modified the context.
    """

    # Every documented attribute gets a class-level ``None`` default so that
    # reading it on an instance where it was never populated yields ``None``
    # instead of raising ``AttributeError``.
    context_arn = None
    context_name = None
    context_type = None
    description = None
    source = None
    properties = None
    tags = None
    creation_time = None
    created_by = None
    last_modified_time = None
    last_modified_by = None

    # Names of the SageMaker client operations backing each lifecycle action.
    _boto_load_method = "describe_context"
    _boto_create_method = "create_context"
    _boto_update_method = "update_context"
    _boto_delete_method = "delete_context"

    # ``source`` is a structured member that maps to ``ContextSource`` rather
    # than a plain value.
    # NOTE(review): the meaning of the boolean flag is defined by the base
    # ``Record``/``ApiObject`` machinery, which is not visible here.
    _custom_boto_types = {
        "source": (_api_types.ContextSource, False),
    }

    # Members passed along when updating (``save``) and deleting (``delete``).
    _boto_update_members = [
        "context_name",
        "description",
        "properties",
        "properties_to_remove",
    ]
    _boto_delete_members = ["context_name"]

    def save(self):
        """Save the state of this Context to SageMaker.

        Invokes the ``update_context`` operation restricted to the members
        listed in ``_boto_update_members``.

        Returns:
            obj: boto API response.
        """
        return self._invoke_api(self._boto_update_method, self._boto_update_members)

    def delete(self, disassociate=False):
        """Delete the context object.

        Args:
            disassociate (bool): When set to true, disassociate incoming and outgoing
                associations before the context itself is deleted.

        Returns:
            obj: boto API response.
        """
        if disassociate:
            # Remove associations in both directions: first those where this
            # context is the source, then those where it is the destination.
            _utils._disassociate(
                source_arn=self.context_arn, sagemaker_session=self.sagemaker_session
            )
            _utils._disassociate(
                destination_arn=self.context_arn, sagemaker_session=self.sagemaker_session
            )
        return self._invoke_api(self._boto_delete_method, self._boto_delete_members)

    def set_tag(self, tag=None):
        """Add a tag to the object.

        The single tag is wrapped in a one-element list and handed to the same
        helper used by ``set_tags``.

        Args:
            tag (obj): Key value pair to set tag.

        Returns:
            list({str:str}): a list of key value pairs
        """
        return self._set_tags(resource_arn=self.context_arn, tags=[tag])

    def set_tags(self, tags=None):
        """Add tags to the object.

        Args:
            tags ([{key:value}]): list of key value pairs.

        Returns:
            list({str:str}): a list of key value pairs
        """
        return self._set_tags(resource_arn=self.context_arn, tags=tags)

    @classmethod
    def load(cls, context_name, sagemaker_session=None):
        """Load an existing context and return a ``Context`` object representing it.

        Args:
            context_name (str): Name of the context
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.

        Returns:
            Context: A SageMaker ``Context`` object

        Examples:
            .. code-block:: python

                from sagemaker.lineage import context

                my_context = context.Context.create(
                    context_name='MyContext',
                    context_type='Endpoint',
                    source_uri='arn:aws:...')

                loaded_context = context.Context.load(context_name='MyContext')

                my_context.properties["added"] = "property"
                my_context.save()

                for ctx in context.Context.list():
                    print(ctx)

                my_context.delete()
        """
        context = cls._construct(
            cls._boto_load_method,
            context_name=context_name,
            sagemaker_session=sagemaker_session,
        )
        return context

    @classmethod
    def create(
        cls,
        context_name=None,
        source_uri=None,
        source_type=None,
        context_type=None,
        description=None,
        properties=None,
        tags=None,
        sagemaker_session=None,
    ):
        """Create a context and return a ``Context`` object representing it.

        Args:
            context_name (str): The name of the context.
            source_uri (str): The source URI of the context.
            source_type (str): The type of the source.
            context_type (str): The type of the context.
            description (str): Description of the context.
            properties (dict): Metadata associated with the context.
            tags (List[dict[str, str]]): Tags to add to the context.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.

        Returns:
            Context: A SageMaker ``Context`` object.
        """
        # ``source_uri`` and ``source_type`` are bundled into a single
        # ``ContextSource`` value and passed as the ``source`` member.
        return super(Context, cls)._construct(
            cls._boto_create_method,
            context_name=context_name,
            source=_api_types.ContextSource(source_uri=source_uri, source_type=source_type),
            context_type=context_type,
            description=description,
            properties=properties,
            tags=tags,
            sagemaker_session=sagemaker_session,
        )

    @classmethod
    def list(
        cls,
        source_uri=None,
        context_type=None,
        created_after=None,
        created_before=None,
        sort_by=None,
        sort_order=None,
        max_results=None,
        next_token=None,
        sagemaker_session=None,
    ):
        """Return a list of context summaries.

        Args:
            source_uri (str, optional): A source URI.
            context_type (str, optional): A context type.
            created_before (datetime.datetime, optional): Return contexts created before this
                instant.
            created_after (datetime.datetime, optional): Return contexts created after this instant.
            sort_by (str, optional): Which property to sort results by.
                One of 'Name' or 'CreationTime'.
            sort_order (str, optional): One of 'Ascending', or 'Descending'.
            max_results (int, optional): maximum number of contexts to retrieve
            next_token (str, optional): token for next page of results
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.

        Returns:
            collections.Iterator[ContextSummary]: An iterator
                over ``ContextSummary`` objects.
        """
        # Positional arguments: the client operation name, the converter for
        # each returned item, and the response key holding the summaries.
        return super(Context, cls)._list(
            "list_contexts",
            _api_types.ContextSummary.from_boto,
            "ContextSummaries",
            source_uri=source_uri,
            context_type=context_type,
            created_before=created_before,
            created_after=created_after,
            sort_by=sort_by,
            sort_order=sort_order,
            max_results=max_results,
            next_token=next_token,
            sagemaker_session=sagemaker_session,
        )


class EndpointContext(Context):
    """A lineage context that represents an Amazon SageMaker endpoint."""

    def models(self):
        """Collect the model associations reachable from this endpoint context.

        First lists the associations whose source is this context and whose
        destination type is ``ModelDeployment``; then, for each of those
        destinations, lists the associations whose destination type is
        ``Model``.

        Returns:
            list of Associations: The ``Model`` associations found, in the
                order they were returned.
        """
        session = self.sagemaker_session

        deployment_links = association.Association.list(
            source_arn=self.context_arn,
            destination_type="ModelDeployment",
            sagemaker_session=session,
        )

        found_models = []
        for deployment_link in deployment_links:
            # Each deployment action is the source of the model association(s).
            model_links = association.Association.list(
                source_arn=deployment_link.destination_arn,
                destination_type="Model",
                sagemaker_session=session,
            )
            found_models.extend(model_links)
        return found_models
