import functools
import logging
import os
from typing import List, Optional, Tuple, Union

from web3 import Web3
from web3 import contract as web3_contract_module
from web3.contract import Contract
from web3.types import BlockIdentifier

# Guard flag flipped to True by patch_w3() so the monkeypatch is applied once.
_web3_patched = False
# Block identifier easy_call() falls back to when the caller gives no block.
LATEST_BLOCK: BlockIdentifier = "latest"

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Alias of the process environment; get_eth_client() reads API keys and node
# URLs from it.
env_variables = os.environ


def rate_limit_web3(func):
    """
    Rate limiting decorator for WEB3.
    Rate-limit can differ using multiple providers.

    Args:
        func: callable which takes the web3 client as its first positional
            argument

    Returns:
        Wrapper which applies the provider specific rate limit (currently
        only for anyblock endpoints) and then delegates to ``func`` with
        unchanged arguments.
    """

    @functools.wraps(func)
    def inner(web3: "Web3", *args, **kwargs):
        # The provider may have no `endpoint_uri` attribute (or have it set
        # to None); treat that as "no known endpoint" instead of crashing
        # every contract call made through the patched function.
        endpoint = str(getattr(web3.provider, 'endpoint_uri', None) or '')

        # TODO move rate-limits somewhere
        # NOTE(review): `current_env`, `Env` and `apply_rate_limit` are not
        # imported in the visible part of this module - confirm where they
        # come from.
        if 'api.anyblock' in endpoint:
            # set rate limit for:
            #   - 0.12s (500 reqs/min) in case of dev and prod env
            #   - 30s (2 reqs/min) otherwise
            # NOTE(review): an earlier version of this comment listed dev
            # under the 30s tier, while the condition below gives DEV the
            # 0.12s limit - confirm which one is intended.
            rl = 0.12 if current_env in (Env.DEV, Env.PROD) else 30
            apply_rate_limit(rl, 'anyblock')

        return func(web3, *args, **kwargs)

    return inner


def patch_w3():
    """
    Wrap web3's `call_contract_function` with the rate limiting decorator.

    `call_contract_function` is patched because it is the nearest point
    before the RPC endpoint is called. The module-level `_web3_patched`
    flag makes repeated calls a no-op, so the function is wrapped only once.
    """
    global _web3_patched

    if not _web3_patched:
        original_call = web3_contract_module.call_contract_function
        web3_contract_module.call_contract_function = rate_limit_web3(original_call)
        _web3_patched = True


# Import-time side effect: the rate-limiting patch is applied as soon as this
# module is imported.
patch_w3()


def easy_call(
    contract: Contract,
    function_name: str,
    *f_args: Union[bytes, int, str, List[str], Tuple[str]],
    block: Optional[BlockIdentifier] = None,
) -> Union[int, str, dict, List[dict]]:
    """
    Call smart contract function, log failures and structure results.
    Args:
        contract: web3 smart contract object
        function_name: name of function to call
        f_args: function arguments
        block: block for which function is called; ``None`` means
            ``LATEST_BLOCK``. Block number ``0`` is passed through as is.

    Returns (int, str, dict or List[dict]):
        Parsed result.

    Raises:
        AttributeError: if ``contract.functions`` has no ``function_name``
        Exception: whatever the contract call raises; it is logged and
            re-raised unchanged
    """
    # Compare with None explicitly: block number 0 is falsy, but it is a
    # valid block identifier and must not be replaced by "latest".
    if block is None:
        block = LATEST_BLOCK

    # TODO add errors handling
    f = getattr(contract.functions, function_name)
    try:
        raw = f(*f_args).call(block_identifier=block)
    except Exception:
        # Use the module logger (not the root logger) and re-raise with the
        # original traceback untouched.
        logger.error(
            "Failed to call function: %s, on contract: %s",
            function_name,
            contract.address,
        )
        raise

    return map_struct(raw, contract.abi, function_name)


def get_eth_client(network: str = 'mainnet', provider_name: str = 'anyblock') -> Web3:
    """
    Build a web3 client on top of an HTTP provider.

    The `optimism` network always uses its fixed public endpoint and the
    provider name is not looked at in that case. Otherwise the endpoint is
    derived from the provider name and environment variables
    (`ANYBLOCK_APIKEY`, `INFURA_APIKEY`, `LOCAL_ETH_NODE`, `POKT_NODE`).
    A missing variable is not validated here: for anyblock/infura it ends up
    as the text `None` inside the URL, for local/pokt `None` is handed to
    the HTTP provider.

    Args:
        network (string): name of the network (mainnet, ropsten, ...)
        provider_name (string): name of the provider
            (anyblock, infura, local, pokt)

    Raises:
        ValueError: for an unknown provider name (when the network is not
            `optimism`)
    """
    if network == 'optimism':
        endpoint = 'https://mainnet.optimism.io'
    elif provider_name == 'anyblock':
        api_key = env_variables.get("ANYBLOCK_APIKEY")
        endpoint = (
            'https://api.anyblock.tools/ethereum/ethereum/'
            f'{network}/rpc/{api_key}/'
        )
    elif provider_name == 'infura':
        project_key = env_variables.get('INFURA_APIKEY')
        endpoint = f'https://{network}.infura.io/v3/{project_key}'
    elif provider_name in ('local', 'pokt'):
        # Both providers take the complete node URL from the environment.
        variable = 'LOCAL_ETH_NODE' if provider_name == 'local' else 'POKT_NODE'
        endpoint = env_variables.get(variable)
    else:
        raise ValueError(f'Invalid provider name: {provider_name}')

    http_provider = Web3.HTTPProvider(endpoint)
    return Web3(http_provider)


def map_struct(raw_result, abi, func_name):
    """
    Map raw result from smart contract to structured dict (or list
    of dicts).
    Args:
        raw_result (tuple or [tuple]): result from smart contract's call
        abi (list of dicts): ABI definition of whole smart contract
        func_name (str): name of used func

    Returns (dict or [dict]):
        Structured result/s. Values are returned unchanged when the output
        has no named nested components.

    Raises:
        ValueError: if `abi` has no function entry named `func_name`
    """
    # Only function entries are considered, so an event or error sharing the
    # name is skipped. An entry without "type" counts as a function.
    func_abi = next(
        (
            a
            for a in abi
            if a.get('name') == func_name and a.get('type', 'function') == 'function'
        ),
        None,
    )
    if func_abi is None:
        raise ValueError(f'Function {func_name!r} not found in ABI')

    outputs = func_abi.get('outputs') or []
    # Names are taken from the first output only. A function without outputs
    # has nothing to name, so its raw result is passed through untouched.
    component = create_components(outputs)[0] if outputs else None

    if isinstance(raw_result, list):
        return [map_sub_item(i, component) for i in raw_result]
    return map_sub_item(raw_result, component)


def map_sub_item(item, component):
    """
    Pair the values of `item` with the names stored in `component`.

    An empty/None component means there is nothing to name, so `item` is
    returned as is. Otherwise values are matched to component keys by
    position; a list value is mapped element by element with the nested
    component, any other value is mapped recursively.
    """
    if not component:
        return item

    return {
        name: (
            [map_sub_item(element, nested) for element in value]
            if isinstance(value, list)
            else map_sub_item(value, nested)
        )
        for value, (name, nested) in zip(item, component.items())
    }


def create_components(abi_outputs):
    """
    Build one component per ABI output, in the same order.

    Each component is a (possibly nested) dict keyed by attribute names, or
    None when the output has no nested items.
    """
    components = []
    for output in abi_outputs:
        components.append(create_component(output))
    return components


def create_component(item):
    """
    Build the component of a single ABI item.

    Returns a dict which maps every nested component's name to its own
    component, or None when the item has no (or empty) `components`.
    """
    nested_items = item.get('components')
    if not nested_items:
        return None

    component = {}
    for nested in nested_items:
        component[nested['name']] = create_component(nested)
    return component


def to_checksum_address(func):
    """
    Decorator for class methods whose first positional argument (after
    `self`) is an Ethereum address. The address is replaced by the result of
    `Web3.toChecksumAddress` before the wrapped method is called; all other
    arguments are passed through untouched.
    """

    @functools.wraps(func)
    def wrapper(self, address, *args, **kwargs):
        checksummed = Web3.toChecksumAddress(address)
        return func(self, checksummed, *args, **kwargs)

    return wrapper


def ensure_checksum_address(address: Optional[str]) -> Optional[str]:
    """
    Return `Web3.toChecksumAddress(address)`, or None when the address
    is None.
    """
    if address is None:
        return None
    return Web3.toChecksumAddress(address)
