#!/usr/bin/env python
# coding: utf-8

# Copyright (c) Mito.
# Distributed under the terms of the Modified BSD License.

from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from mitosheet.errors import make_cast_value_to_type_error, make_no_column_error
from mitosheet.sheet_functions.types import SERIES_CONVERSION_FUNCTIONS
from mitosheet.sheet_functions.types.utils import (BOOLEAN_SERIES,
                                                   NUMBER_SERIES,
                                                   STRING_SERIES,
                                                   is_int_dtype,
                                                   is_none_type)
from mitosheet.state import State
from mitosheet.step_performers.column_steps.set_column_formula import (
    refresh_dependant_columns, transpile_dependant_columns)
from mitosheet.step_performers.step_performer import StepPerformer


class SetCellValueStepPerformer(StepPerformer):
    """
    A set_cell_value step, allows you to set the value
    of a given cell in the sheet and then recalculates its dependents.
    """

    @classmethod
    def step_version(cls) -> int:
        return 1

    @classmethod
    def step_type(cls) -> str:
        return 'set_cell_value'

    @classmethod
    def step_display_name(cls) -> str:
        return 'Set Cell Value'
    
    @classmethod
    def step_event_type(cls) -> str:
        return 'set_cell_value_edit'

    @classmethod
    def saturate(cls, prev_state: State, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Fills in the params that are derived from prev_state.

        An empty new_value is normalized to None. The cell's current value is
        recorded, as a string, under 'old_value'. execute and transpile compare
        it against new_value to detect a no-op edit.

        Mutates params in place and returns it.
        """
        # Mito doesn't allow empty cells, so if the new value is empty, change it to None.
        if params['new_value'] == '':
            params['new_value'] = None

        # Get the old value so we can check if the new value is different
        sheet_index = params['sheet_index']
        column_id = params['column_id']
        row_index = params['row_index']
        column_header = prev_state.column_ids.get_column_header_by_id(sheet_index, column_id)

        # Cast the old value to a string to avoid errors while writing the saved analysis
        params['old_value'] = str(prev_state.dfs[sheet_index].at[row_index, column_header])
        
        return params

    @classmethod
    def execute(
        cls,
        prev_state: State,
        sheet_index: int,
        column_id: str,
        row_index: int,
        old_value: str,
        new_value: Union[str, None],
        **params
    ) -> Optional[Tuple[State, Optional[Dict[str, Any]]]]:
        """
        Writes new_value into the cell at (row_index, column_id) of the sheet at
        sheet_index, then refreshes the columns that depend on that sheet.

        Returns None when old_value == new_value, since there is nothing to do.
        Otherwise returns a deep copy of prev_state with the edit applied, plus
        execution data holding the column's mito type (used for logging).

        Raises the no-column error if column_id is not in the sheet. Raises the
        cast error from cast_value_to_type if new_value cannot be converted to
        the column's mito type.
        """
        if column_id not in prev_state.column_metatype[sheet_index]:
            raise make_no_column_error([column_id], error_modal=False)

        # If nothing's changed, there's no work to do
        if old_value == new_value:
            return None

        post_state = deepcopy(prev_state)

        column_header = post_state.column_ids.get_column_header_by_id(sheet_index, column_id)

        # Update the value of the cell, we handle it differently depending on the type of the column
        column_mito_type = post_state.column_type[sheet_index][column_id]
        type_corrected_new_value = cast_value_to_type(new_value, column_mito_type)

        # If the series is an int, but the new value is a float, convert the series to floats before adding the new value
        column_dtype = str(post_state.dfs[sheet_index][column_header].dtype)
        if new_value is not None and '.' in new_value and is_int_dtype(column_dtype):
            post_state.dfs[sheet_index][column_header] = post_state.dfs[sheet_index][column_header].astype('float')
        
        # Actually update the cell's value
        post_state.dfs[sheet_index].at[row_index, column_header] = type_corrected_new_value

        # Update the column formula, and then execute the new formula graph
        refresh_dependant_columns(post_state, post_state.dfs[sheet_index], sheet_index)

        return post_state, {
            'column_mito_type': column_mito_type # for logging
        }

    @classmethod
    def transpile(
        cls,
        prev_state: State,
        post_state: State,
        execution_data: Optional[Dict[str, Any]],
        sheet_index: int,
        column_id: str,
        row_index: int,
        old_value: str,
        new_value: Union[str, None],
    ) -> List[str]:
        """
        Returns the Python lines that reproduce execute. These are:
        - an astype('float') of the column, when needed;
        - the .at cell assignment;
        - the code that recalculates the dependent columns.

        Returns an empty list when old_value == new_value.
        """
        # If nothing's changed, we don't write any code
        if old_value == new_value:
            return []

        df_name = post_state.df_names[sheet_index]
        column_header = post_state.column_ids.get_column_header_by_id(sheet_index, column_id)
        # repr(str(...)) emits a valid string literal even if the header contains quotes or
        # backslashes; for ordinary headers it is identical to the old '...' wrapping.
        transpiled_column_header = repr(str(column_header))

        # Cast using the column's type from before the edit, exactly as execute does, so the
        # generated code writes the same value that was written into the sheet.
        column_mito_type = prev_state.column_type[sheet_index][column_id]
        type_corrected_new_value = cast_value_to_type(new_value, column_mito_type)

        code = []

        # If the series is an int, but the new value is a float, convert the series to floats before adding the new value
        column_dtype = str(prev_state.dfs[sheet_index][column_header].dtype)
        if new_value is not None and '.' in new_value and is_int_dtype(column_dtype):
            code.append(f'{df_name}[{transpiled_column_header}] = {df_name}[{transpiled_column_header}].astype(\'float\')')

        # None, booleans and numbers are written as bare literals; everything else is written
        # as a string literal. repr escapes quotes, backslashes and newlines, so the literal
        # round-trips to exactly the string that execute stored (a plain "..." wrap did not).
        if type_corrected_new_value is None or column_mito_type in (BOOLEAN_SERIES, NUMBER_SERIES):
            transpiled_value = f'{type_corrected_new_value}'
        else:
            transpiled_value = repr(str(type_corrected_new_value))
        code.append(f'{df_name}.at[{row_index}, {transpiled_column_header}] = {transpiled_value}')

        # Add the transpiled code for all of the dependant columns in order to refresh the dependant cells
        code.extend(transpile_dependant_columns(post_state, sheet_index, column_id))
        return code


    @classmethod
    def describe(
        cls,
        sheet_index: int,
        column_id: str,
        row_index: int,
        old_value: str,
        new_value: Union[str, None],
        df_names=None,
        **params
    ) -> str:
        """
        Returns a one-line human readable description of the edit, naming the
        dataframe when df_names is provided.
        """
        # Note: Since we don't have access to the dataframes, we can't run the new_value
        # through cast_value_to_type which might change the actual value. Therefore, the new_value
        # that is used in the comment might be incorrect.
        if df_names is not None:
            df_name = df_names[sheet_index]
            return f'Set column {column_id} at index {row_index} in {df_name} to {new_value}'
        return f'Set column {column_id} at index {row_index} to {new_value}'


def cast_value_to_type(value: Union[str, None], mito_series_type: str) -> Any:
    """
    Helper function for converting a value into the correct type for the 
    series that it is going to be added to.

    Args:
        value: the raw value entered by the user, or None.
        mito_series_type: the mito type of the destination series; used as the
            key into SERIES_CONVERSION_FUNCTIONS.

    Returns:
        None if is_none_type(value) is true. Otherwise returns the converted value,
        with two adjustments:
        - string values have every " replaced with ';
        - number values entered without a '.' are rounded to an int.

    Raises:
        The error built by make_cast_value_to_type_error if mito_series_type has no
        conversion function, if the conversion fails, or if a value for a number
        series does not parse as a number.
    """
    # If the user is trying to make the value None, let them.
    if is_none_type(value):
        return None

    try:
        conversion_function = SERIES_CONVERSION_FUNCTIONS[mito_series_type]
        # np.nan (not np.NaN, which NumPy 2.0 removed) is the sentinel the conversion
        # function presumably substitutes for elements it cannot cast.
        casted_value_series = conversion_function(value, on_uncastable_arg_element=np.nan)

        type_corrected_new_value = casted_value_series.iat[0]

        # If the value is a string and it has a " in it, replace it with a ' so the transpiled code does not error
        if mito_series_type == STRING_SERIES and '"' in type_corrected_new_value:
            type_corrected_new_value = type_corrected_new_value.replace('"', "'")

        if mito_series_type == NUMBER_SERIES:
            # An unparsable number comes back as NaN. Reject it regardless of whether the
            # input contains a '.'; otherwise values like 'a.b' would be silently written
            # as NaN (and transpiled to the invalid `= nan`).
            if np.isnan(type_corrected_new_value):
                raise ValueError(f'{value!r} is not a number')

            # If the typed value is not a float, then we do not make it one
            if '.' not in value:
                return round(type_corrected_new_value)

        return type_corrected_new_value
    except Exception as e:
        raise make_cast_value_to_type_error(value, mito_series_type.replace('_', ' '), error_modal=False) from e
