# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains accessors related to SageMaker JumpStart."""
from __future__ import absolute_import
from typing import Any, Dict, List, Optional

from sagemaker.deprecations import deprecated
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
from sagemaker.jumpstart import cache
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME


class SageMakerSettings(object):
    """Holder for the parsed SageMaker SDK version, used as a static namespace.

    State lives in a class attribute and is read and written only through the
    static accessor methods below.
    """

    # Parsed SDK version string; empty until ``set_sagemaker_version`` is called.
    _parsed_sagemaker_version = ""

    @staticmethod
    def set_sagemaker_version(version: str) -> None:
        """Store ``version`` as the current SageMaker version.

        Args:
            version (str): The version string to record.
        """
        SageMakerSettings._parsed_sagemaker_version = version

    @staticmethod
    def get_sagemaker_version() -> str:
        """Return the stored SageMaker version (``""`` if it was never set)."""
        stored_version: str = SageMakerSettings._parsed_sagemaker_version
        return stored_version


class JumpStartModelsAccessor(object):
    """Static class for storing the JumpStart models cache.

    A single ``cache.JumpStartModelsCache`` is shared at class level together with the
    region it was built for. The region-specific accessors rebuild the cache whenever a
    region other than the tracked one is requested.
    """

    # Shared cache instance; ``None`` until a cache is first built.
    _cache: Optional[cache.JumpStartModelsCache] = None
    # Region ``_cache`` was built for. ``None`` means the region is unknown because the
    # cache was built without an explicit region; the next region-specific lookup then
    # rebuilds the cache instead of reusing it.
    _curr_region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME

    # Extra kwargs forwarded to every new cache. ``set_cache_kwargs`` stores them
    # without a "region" key; the region is tracked separately in ``_curr_region``.
    _cache_kwargs: Dict[str, Any] = {}

    @staticmethod
    def _validate_and_mutate_region_cache_kwargs(
        cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
    ) -> Dict[str, Any]:
        """Returns a copy of cache_kwargs with region argument removed if present.

        The "region" entry is only removed when ``region`` is also given. The input is
        copied first, so neither caller-owned dictionaries nor the class-level
        ``_cache_kwargs`` are modified as a side effect.

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

        Args:
            cache_kwargs (Optional[Dict[str, Any]]): cache kwargs to validate.
            region (Optional[str]): The region to validate along with the kwargs.

        Returns:
            Dict[str, Any]: A new dictionary holding the validated kwargs.
        """
        cache_kwargs_dict: Dict[str, Any] = {} if cache_kwargs is None else dict(cache_kwargs)
        if region is not None and "region" in cache_kwargs_dict:
            if region != cache_kwargs_dict["region"]:
                raise ValueError(
                    f"Inconsistent region definitions: {region}, {cache_kwargs_dict['region']}"
                )
            del cache_kwargs_dict["region"]
        return cache_kwargs_dict

    @staticmethod
    def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
        """Sets ``JumpStartModelsAccessor._cache`` and ``JumpStartModelsAccessor._curr_region``.

        A new cache is only built when none exists yet or when ``region`` differs from
        the region the current cache is tracked for.

        Args:
            region (str): region for which to retrieve header/spec.
            cache_kwargs (dict): kwargs to pass to ``JumpStartModelsCache``.
        """
        if JumpStartModelsAccessor._cache is None or region != JumpStartModelsAccessor._curr_region:
            JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
                region=region, **cache_kwargs
            )
            JumpStartModelsAccessor._curr_region = region

    @staticmethod
    def _get_cache_for_region(region: str) -> cache.JumpStartModelsCache:
        """Returns the shared cache, (re)building it first if it is not tracked for ``region``.

        Raises:
            ValueError: If a region stored in the cache kwargs is inconsistent with `region`.

        Args:
            region (str): region the returned cache must serve.
        """
        cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
            JumpStartModelsAccessor._cache_kwargs, region
        )
        JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
        return JumpStartModelsAccessor._cache  # type: ignore

    @staticmethod
    def _get_manifest(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> List[JumpStartModelHeader]:
        """Return entire JumpStart models manifest.

        Raises:
            ValueError: If a region in the stored cache kwargs is inconsistent with
                `region` argument.

        Args:
            region (str): Optional. The region to use for the cache.

        Returns:
            List[JumpStartModelHeader]: the manifest returned by the cache.
        """
        return JumpStartModelsAccessor._get_cache_for_region(region).get_manifest()

    @staticmethod
    def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
        """Returns model header from JumpStart models cache.

        Args:
            region (str): region for which to retrieve header.
            model_id (str): model ID to retrieve.
            version (str): semantic version to retrieve for the model ID.

        Returns:
            JumpStartModelHeader: the header returned by the cache.
        """
        return JumpStartModelsAccessor._get_cache_for_region(region).get_header(
            model_id=model_id, semantic_version_str=version
        )

    @staticmethod
    def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs:
        """Returns model specs from JumpStart models cache.

        Args:
            region (str): region for which to retrieve specs.
            model_id (str): model ID to retrieve.
            version (str): semantic version to retrieve for the model ID.

        Returns:
            JumpStartModelSpecs: the specs returned by the cache.
        """
        return JumpStartModelsAccessor._get_cache_for_region(region).get_specs(
            model_id=model_id, semantic_version_str=version
        )

    @staticmethod
    def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: Optional[str] = None) -> None:
        """Sets cache kwargs, clears the cache.

        A "region" entry inside ``cache_kwargs`` is treated like the ``region`` argument:
        it is validated against ``region`` when both are given, and it is never stored
        with the kwargs. The caller's dictionary is not modified.

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

        Args:
            cache_kwargs (Dict[str, Any]): cache kwargs to validate.
            region (Optional[str]): Optional. The region to validate along with the kwargs.
        """
        cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
            cache_kwargs, region
        )
        if region is None and "region" in cache_kwargs:
            # Promote a kwargs-only region to the tracked region so that it is not kept in
            # ``_cache_kwargs`` (where it would clash with later lookups in other regions).
            region = cache_kwargs.pop("region")
        JumpStartModelsAccessor._cache_kwargs = cache_kwargs
        if region is None:
            # The cache chooses its own default region, which is not visible here. Mark the
            # tracked region as unknown so the next region-specific lookup rebuilds the cache
            # instead of reusing one that may have been built for a different region.
            JumpStartModelsAccessor._curr_region = None
            JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
                **JumpStartModelsAccessor._cache_kwargs
            )
        else:
            JumpStartModelsAccessor._curr_region = region
            JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
                region=region, **JumpStartModelsAccessor._cache_kwargs
            )

    @staticmethod
    def reset_cache(
        cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
    ) -> None:
        """Resets cache, optionally allowing cache kwargs to be passed to the new cache.

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

        Args:
            cache_kwargs (Optional[Dict[str, Any]]): cache kwargs to validate.
            region (Optional[str]): The region to validate along with the kwargs.
        """
        cache_kwargs_dict: Dict[str, Any] = {} if cache_kwargs is None else cache_kwargs
        JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)

    @staticmethod
    @deprecated()
    def get_manifest(
        cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
    ) -> List[JumpStartModelHeader]:
        """Return entire JumpStart models manifest.

        Deprecated (decorated with ``@deprecated()``). Note that this replaces the stored
        cache kwargs and rebuilds the cache before reading the manifest.

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

        Args:
            cache_kwargs (Dict[str, Any]): Optional. Cache kwargs to use.
                (Default: None).
            region (str): Optional. The region to use for the cache.
                (Default: None).

        Returns:
            List[JumpStartModelHeader]: the manifest returned by the rebuilt cache.
        """
        cache_kwargs_dict: Dict[str, Any] = {} if cache_kwargs is None else cache_kwargs
        JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)
        return JumpStartModelsAccessor._cache.get_manifest()  # type: ignore
