from abc import abstractmethod
from typing import Callable, List, Tuple, Union
from .tools import pad_to_multiple, crop_to_original, divide_into_regions, combine_regions
from .data import T
import numpy as np

class CombinedModel:
    """
    Compose a base model with a stack of wrapping steps.

    ``steps`` lists the wrappers outermost-first and ends with the base model.
    Each wrapper is called with the model assembled so far, so
    ``CombinedModel([a, b, base])`` evaluates as ``a(b(base))``. ``None``
    entries among the wrappers are skipped, which makes it easy to switch a
    step off. ReversibleTransformation instances can be used as wrappers,
    sandwiching the base model between their forward and backward passes.
    """

    def __init__(self, steps: List[Callable]):
        self.steps = steps

        # The last entry is the innermost (base) model; everything before it
        # wraps around it, the first entry ending up outermost.
        base = self.steps[-1]
        wrappers = [wrapper for wrapper in self.steps[:-1] if wrapper is not None]

        composed = base
        for wrapper in reversed(wrappers):
            composed = wrapper(composed)
        self.model = composed

    def __call__(self, data: T) -> T:
        """Run ``data`` through the fully composed model."""
        composed = self.model
        return composed(data)
        

class ReversibleTransformation:
    """
    A data transformation that can be reversed, wrapped around an intermediary.

    Calling an instance has two modes:

    * ``transform(fn)`` with a callable ``fn`` attaches it as the
      intermediary (the model or next transformation to run between the
      forward and backward passes) and returns ``transform`` itself, so
      instances can be nested as ``outer(inner(model))``.
    * ``transform(data)`` with non-callable data computes
      ``_backward(intermediary(_forward(data)))``. When ``distributed`` is
      true, ``_forward`` must return an iterable; the intermediary is applied
      to each element separately and ``_backward`` receives the list of
      results.

    Until an intermediary is attached, a pass-through is used, so
    ``transform(data)`` is simply ``_backward(_forward(data))``.

    Subclasses must override ``_forward`` and ``_backward``.
    """

    def __init__(self, distributed: bool = False):
        """
        Args:
            distributed: If True, apply the intermediary to each element of
                the iterable returned by ``_forward`` rather than to the
                whole forward output.
        """
        self.distributed = distributed
        # Pass-through until a model is attached. The builtin ``id`` must not
        # be used here: it returns an object's address, not the object.
        self._inter = self._identity

    @staticmethod
    def _identity(data):
        """Return ``data`` unchanged; the default intermediary."""
        return data

    def _attach_intermediary(self, intermediary: Callable[[T], T]):
        """Set the callable that runs between the forward and backward passes."""
        self._inter = intermediary

    # NOTE: this class does not use ABCMeta, so @abstractmethod does not block
    # instantiation; the NotImplementedError raised below is what enforces
    # overriding in subclasses.
    @abstractmethod
    def _forward(self, data):
        """Transform the input before it reaches the intermediary."""
        raise NotImplementedError("forward not implemented")

    @abstractmethod
    def _backward(self, data):
        """Undo the forward transformation on the intermediary's output."""
        raise NotImplementedError("backward not implemented")

    def __call__(self, data_or_inter: Union[T, Callable]) -> Union[T, "ReversibleTransformation"]:
        """
        Attach an intermediary, or run forward -> intermediary -> backward.

        Args:
            data_or_inter: A callable to attach as the intermediary, or the
                data to transform.

        Returns:
            ``self`` when a callable was attached; otherwise the output of
            ``_backward``.
        """
        # Anything callable is treated as an intermediary to attach; plain
        # data such as arrays is not callable.
        if callable(data_or_inter):
            self._attach_intermediary(data_or_inter)
            return self

        forwarded = self._forward(data_or_inter)
        if self.distributed:
            processed = [self._inter(item) for item in forwarded]
        else:
            processed = self._inter(forwarded)
        return self._backward(processed)
        
class BasicTTA(ReversibleTransformation):
    """
    Flip-based test-time augmentation.

    The forward pass produces four views of the input: unchanged, flipped
    along axis 1, flipped along axis 2, and flipped along both axes. Because
    the instance is ``distributed``, the wrapped model runs on each view
    separately. The backward pass undoes each flip and merges the four
    predictions according to ``combine_mode`` (only ``"mean"`` is supported).

    NOTE(review): axes 1 and 2 are presumably the spatial axes of a
    channel-first (C, H, W) input; confirm against callers.
    """

    def __init__(self, combine_mode: str = "mean"):
        """
        Args:
            combine_mode: How the de-augmented predictions are merged;
                validated only when the backward pass runs.
        """
        super().__init__(distributed=True)
        self.combine_mode = combine_mode

    def _forward(self, data):
        """Return the original input followed by its three flipped variants."""
        flipped_views = [np.flip(data, axes) for axes in (1, 2, (1, 2))]
        return [data] + flipped_views

    def _backward(self, data):
        """Undo the flips on each prediction and merge them."""
        # A flip is its own inverse, so re-applying the same flip restores
        # each prediction's original orientation.
        unflipped = [
            np.flip(data[index], axes)
            for index, axes in ((1, 1), (2, 2), (3, (1, 2)))
        ]
        restored = [data[0]] + unflipped

        if self.combine_mode != 'mean':
            raise ValueError(f"Unknown combine mode {self.combine_mode}")
        return np.mean(restored, axis=0)

    def __repr__(self) -> str:
        mode = self.combine_mode
        return f"BasicTTA(combine_mode={mode})"


class PadCrop(ReversibleTransformation):
    """
    Pad on the way in, crop back on the way out.

    The forward pass pads ``data`` to a multiple of ``input_size`` using
    ``pad_to_multiple``. The backward pass crops the intermediary's output
    back to the input shape recorded during the forward pass, using
    ``crop_to_original``. ``pad_mode`` and ``pad_position`` are forwarded to
    those helpers unchanged.
    """

    def __init__(self, input_size, pad_mode="reflect", pad_position="end"):
        super().__init__()
        self.input_size, self.pad_mode, self.pad_position = input_size, pad_mode, pad_position

    def _forward(self, data):
        """Record the input shape, then pad ``data``."""
        # Stored on the instance so the matching _backward call can undo the
        # padding; the instance is therefore stateful between the two passes.
        self.original_shape = data.shape
        padded = pad_to_multiple(
            data,
            self.input_size,
            pad_mode=self.pad_mode,
            pad_position=self.pad_position,
        )
        return padded

    def _backward(self, data):
        """Crop ``data`` back to the shape recorded by ``_forward``."""
        target_shape = self.original_shape
        return crop_to_original(data, target_shape, pad_position=self.pad_position)

    def __repr__(self):
        return (
            f"PadCrop(input_size={self.input_size}, "
            f"pad_mode={self.pad_mode}, pad_position={self.pad_position})"
        )
    
class DivideCombine(ReversibleTransformation):
    """
    Split data into regions on the way in, reassemble them on the way out.

    The forward pass divides ``data`` into regions of ``region_size`` using
    ``divide_into_regions``. The backward pass recombines the intermediary's
    output with ``combine_regions``, using the shape recorded during the
    forward pass. With the default ``distributed=False``, the intermediary
    receives the whole result of ``divide_into_regions`` in a single call.
    """

    def __init__(self, region_size: Tuple):
        """
        Args:
            region_size: Size of each region the data is divided into.
        """
        super().__init__()
        self.region_size = region_size

    def _forward(self, data):
        """Record the input shape, then divide ``data`` into regions."""
        # Kept on the instance because _backward needs it to reassemble.
        self.original_shape = data.shape
        regions = divide_into_regions(data, self.region_size)
        return regions

    def _backward(self, data):
        """Recombine regions using the shape recorded by ``_forward``."""
        size, shape = self.region_size, self.original_shape
        return combine_regions(data, size, shape)

    def __repr__(self):
        region_size = self.region_size
        return f"DivideCombine(region_size={region_size})"