from os.path import isfile, isdir
from os import makedirs
from typing import List

import pandas as pd
# import geopandas as gpd

import geopy.distance
import shapely
from shapely import wkt
from shapely.ops import unary_union
from shapely.geometry import Point

import matplotlib.pyplot as plt

import pypsa

from epippy.geographics import get_shapes, get_natural_earth_shapes, get_nuts_shapes, replace_iso2_codes
from epippy.technologies import get_costs
from epippy.topologies.core import voronoi_special

from epippy import data_path


def get_ehighway_clusters() -> pd.DataFrame:
    """
    Load the table describing the e-Highways clusters.

    According to the way this table is used in this module, each row describes one cluster: its country,
    the codes of the regions composing it and, optionally, the position of the bus associated to the
    cluster (when no position is given, the centroid of the cluster shape can be used instead).

    Returns
    -------
    pd.DataFrame
        Content of the semicolon-separated file 'clusters_2016.csv', indexed by its 'name' column.
    """
    return pd.read_csv(
        f"{data_path}topologies/e-highways/source/clusters_2016.csv",
        sep=";",
        index_col="name",
    )


def get_ehighway_shapes() -> pd.Series:
    """
    Return e-Highways cluster shapes.

    Each cluster of 'clusters_2016.csv' lists, in its 'codes' column, a comma-separated set of region codes.
    Codes of length 5 are fetched as NUTS3 shapes, all other codes are fetched from naturalearth.

    Returns
    -------
    shapes : pd.Series
        Series indexed like the cluster table and containing one shapely geometry per cluster.
    """

    clusters_fn = f"{data_path}topologies/e-highways/source/clusters_2016.csv"
    clusters = pd.read_csv(clusters_fn, delimiter=";", index_col=0)

    # Split the code lists once: they are needed both to fetch the shapes and to assemble them per cluster
    codes_per_cluster = [codes.split(',') for codes in clusters['codes']]

    all_codes = [code for codes in codes_per_cluster for code in codes]
    nuts_codes = [code for code in all_codes if len(code) == 5]
    iso_codes = [code for code in all_codes if len(code) != 5]
    nuts3_shapes = get_nuts_shapes("3", nuts_codes)
    iso_shapes = get_natural_earth_shapes(iso_codes)

    # Explicit object dtype: the series stores geometries, not numbers
    shapes = pd.Series(index=clusters.index, dtype=object)

    for node, codes in zip(clusters.index, codes_per_cluster):
        # The first code decides how the whole cluster is interpreted.
        if len(codes[0]) > 2:
            # Cluster described by NUTS3 codes: union of all the corresponding shapes
            shapes.at[node] = unary_union(nuts3_shapes.loc[codes].values)
        else:
            # Cluster described by a country ISO2 code: shape taken from naturalearth
            # (only the shape of the first code is used)
            shapes.at[node] = iso_shapes.loc[codes].values[0]

    return shapes


def _offshore_buses() -> pd.DataFrame:
    """Return the manually defined offshore buses, indexed by 'id', with the same columns as onshore buses."""
    # NOTE(review): for OFF8 and OFFD the Point stored as 'offshore_region' does not match the (x, y) columns.
    # Values are kept unchanged - confirm which coordinates are the intended ones.
    offshore_buses = pd.DataFrame(
        [["OFF1", -6.5, 49.5, None, None, Point(-6.5, 49.5)],  # England south-west
         ["OFF2", 3.5, 55.5, None, None, Point(3.5, 55.5)],  # England East
         ["OFF3", 30.0, 43.5, None, None, Point(30.0, 43.5)],  # Black Sea
         ["OFF4", 18.5, 56.5, None, None, Point(18.5, 56.5)],  # Sweden South-east
         ["OFF5", 19.5, 62.0, None, None, Point(19.5, 62.0)],  # Sweden North-east
         ["OFF6", -3.0, 46.5, None, None, Point(-3.0, 46.5)],  # France west
         ["OFF7", -5.0, 54.0, None, None, Point(-5.0, 54.0)],  # Isle of Man
         ["OFF8", -7.5, 56.5, None, None, Point(-7.5, 56.0)],  # Uk North
         ["OFF9", 15.0, 43.0, None, None, Point(15.0, 43.0)],  # Italy east
         ["OFFA", 25.0, 39.0, None, None, Point(25.0, 39.0)],  # Greece East
         ["OFFB", 1.5, 40.0, None, None, Point(1.5, 40.0)],  # Spain east
         ["OFFC", 9.0, 65.0, None, None, Point(9.0, 65.0)],  # Norway South-West
         ["OFFD", 14.5, 69.0, None, None, Point(14.0, 68.5)],  # Norway North-West
         # ["OFFE", 26.0, 72.0, Point(26.0, 72.0)],  # Norway North-West Norther
         ["OFFF", 11.5, 57.0, None, None, Point(11.5, 57.0)],  # East Denmark
         ["OFFG", -1.0, 50.0, None, None, Point(-1.0, 50.0)],  # France North
         ["OFFI", -9.5, 41.0, None, None, Point(-9.5, 41.0)]],  # Portugal West
        columns=["id", "x", "y", "country", "onshore_region", "offshore_region"])
    return offshore_buses.set_index("id")


def _offshore_lines() -> pd.DataFrame:
    """Return the manually defined DC lines (indexed by 'id', zero capacity) attached to the offshore buses."""
    # NOTE(review): the id "OFF8-21FR" does not match its bus1 ("95UK"). The id is kept unchanged because it
    # ends up in the generated lines.csv - confirm whether it should read "OFF8-95UK".
    offshore_lines = pd.DataFrame(
        [["OFF1-96IE", "OFF1", "96IE", "DC", 0],
         ["OFF1-91UK", "OFF1", "91UK", "DC", 0],
         ["OFF1-21FR", "OFF1", "21FR", "DC", 0],
         ["OFF2-79NO", "OFF2", "79NO", "DC", 0],
         ["OFF2-30NL", "OFF2", "30NL", "DC", 0],
         ["OFF2-38DK", "OFF2", "38DK", "DC", 0],
         ["OFF2-90UK", "OFF2", "90UK", "DC", 0],
         ["OFF2-28BE", "OFF2", "28BE", "DC", 0],
         ["OFF3-61RO", "OFF3", "61RO", "DC", 0],
         ["OFF3-66BG", "OFF3", "66BG", "DC", 0],
         ["OFF4-73EE", "OFF4", "73EE", "DC", 0],
         ["OFF4-77LT", "OFF4", "77LT", "DC", 0],
         ["OFF4-78LV", "OFF4", "78LV", "DC", 0],
         ["OFF4-45PL", "OFF4", "45PL", "DC", 0],
         ["OFF4-89SE", "OFF4", "89SE", "DC", 0],
         ["OFF5-87SE", "OFF5", "87SE", "DC", 0],
         ["OFF5-75FI", "OFF5", "75FI", "DC", 0],
         ["OFF6-17FR", "OFF6", "17FR", "DC", 0],
         ["OFF6-21FR", "OFF6", "21FR", "DC", 0],
         ["OFF7-93UK", "OFF7", "93UK", "DC", 0],
         ["OFF7-95UK", "OFF7", "95UK", "DC", 0],
         ["OFF8-94UK", "OFF8", "94UK", "DC", 0],
         ["OFF8-21FR", "OFF8", "95UK", "DC", 0],
         ["OFF9-54IT", "OFF9", "54IT", "DC", 0],
         ["OFF9-62HR", "OFF9", "62HR", "DC", 0],
         ["OFFA-xxTR", "OFFA", "xxTR", "DC", 0],
         ["OFFA-68GR", "OFFA", "68GR", "DC", 0],
         ["OFFA-69GR", "OFFA", "69GR", "DC", 0],
         ["OFFB-06ES", "OFFB", "06ES", "DC", 0],
         ["OFFB-11ES", "OFFB", "11ES", "DC", 0],
         ["OFFC-83NO", "OFFC", "83NO", "DC", 0],
         ["OFFD-84NO", "OFFD", "84NO", "DC", 0],
         # ["OFFE-85NO", "OFFE", "85NO", "DC", 0],
         ["OFFF-38DK", "OFFF", "38DK", "DC", 0],
         ["OFFF-72DK", "OFFF", "72DK", "DC", 0],
         ["OFFF-89SE", "OFFF", "89SE", "DC", 0],
         ["OFFG-22FR", "OFFG", "22FR", "DC", 0],
         ["OFFG-90UK", "OFFG", "90UK", "DC", 0],
         ["OFFG-91UK", "OFFG", "91UK", "DC", 0],
         ["OFFI-12PT", "OFFI", "12PT", "DC", 0]],
        columns=["id", "bus0", "bus1", "carrier", "s_nom"])
    return offshore_lines.set_index("id")


def preprocess(plotting: bool = False):
    """
    Pre-process e-highway buses and lines information.

    Buses are built from the e-highway clusters (one bus per cluster, located at the position given in the
    cluster table or, if none is given, at the centroid of the cluster shape) and complemented with manually
    defined offshore buses. Lines are read from 'Results_GTC_estimation_updated.xlsx', restricted to lines
    whose two end buses are known clusters, complemented with manually defined offshore lines and given a
    geodesic length (in km) computed between their end buses.

    The results are written to 'buses.csv' and 'lines.csv' in the 'generated' e-highways directory.

    Parameters
    ----------
    plotting: bool
        Whether to plot the results
    """

    generated_dir = f"{data_path}topologies/e-highways/generated/"
    makedirs(generated_dir, exist_ok=True)

    eh_clusters = get_ehighway_clusters()

    line_data_fn = f"{data_path}topologies/e-highways/source/Results_GTC_estimation_updated.xlsx"
    lines = pd.read_excel(line_data_fn, usecols="A:D", skiprows=[0], names=["name", "nb_lines", "MVA", "GTC"])
    # The end buses and the carrier are encoded in the line name:
    # bus0 precedes the first '-', bus1 follows it (up to a possible '_') and the carrier is between parentheses
    lines["bus0"] = lines["name"].apply(lambda k: k.split('-')[0])
    lines["bus1"] = lines["name"].apply(lambda k: k.split('-')[1].split("_")[0])
    lines["carrier"] = lines["name"].apply(lambda k: k.split('(')[1].split(')')[0])
    lines["s_nom"] = lines["GTC"]/1000.0
    lines = lines.set_index("name").rename_axis("id")
    lines = lines.drop(["nb_lines", "MVA", "GTC"], axis=1)

    # Drop lines that are associated to buses that are not defined
    known_buses = eh_clusters.index
    lines = lines[lines["bus0"].isin(known_buses) & lines["bus1"].isin(known_buses)]

    # rename_axis returns a new object, so the index of eh_clusters keeps its own name
    buses = pd.DataFrame(columns=["x", "y", "country", "onshore_region", "offshore_region"],
                         index=eh_clusters.index).rename_axis("id")
    buses["country"] = eh_clusters["country"]

    # Assemble the clusters defined in e-highways in order to compute for each bus its region, x and y
    cluster_shapes = get_ehighway_shapes()

    for idx in cluster_shapes.index:

        cluster_shape = cluster_shapes[idx]

        # Some special points are not the centroid of their region: their position is then given as "(x, y)".
        # Otherwise the position is the literal 'None', which - depending on the pandas version - is either
        # read as that string or as a missing value; in both cases the centroid of the shape is used.
        position = eh_clusters.at[idx, "centroid"]
        if not isinstance(position, str) or position == 'None':
            centroid = cluster_shape.centroid
        else:
            coordinates = position.strip("(").strip(")").split(",")
            centroid = shapely.geometry.Point(float(coordinates[0]), float(coordinates[1]))

        # Write cell by cell directly into 'buses' (assigning through buses.loc[idx] is not reliable)
        buses.at[idx, "onshore_region"] = cluster_shape
        buses.at[idx, "x"] = centroid.x
        buses.at[idx, "y"] = centroid.y

    # Offshore nodes and the lines connecting them
    buses = pd.concat([buses, _offshore_buses()])
    lines = pd.concat([lines, _offshore_lines()])

    # Adding length to the lines: geodesic distance (km) between end buses, given as (latitude, longitude)
    lines["length"] = [
        geopy.distance.geodesic((buses.at[bus0_id, "y"], buses.at[bus0_id, "x"]),
                                (buses.at[bus1_id, "y"], buses.at[bus1_id, "x"])).km
        for bus0_id, bus1_id in zip(lines["bus0"], lines["bus1"])
    ]

    if plotting:
        from epippy.topologies.core.plot import plot_topology
        # Only buses with an onshore region are shown
        plot_topology(buses.dropna(subset=["onshore_region"]), lines)
        plt.show()

    buses.to_csv(f"{generated_dir}buses.csv")
    lines.to_csv(f"{generated_dir}lines.csv")


def get_topology(network: pypsa.Network, countries: List[str] = None, add_offshore: bool = True,
                 extend_line_cap: bool = True, use_ex_line_cap: bool = True,
                 plot: bool = False) -> pypsa.Network:
    """
    Load the e-highway network topology (buses and links) using PyPSA.

    Buses and lines are read from the files generated by 'preprocess', restricted to the requested countries,
    and added to the network as 'Bus' and (bi-directional) 'Link' components.

    Parameters
    ----------
    network: pypsa.Network
        Network instance
    countries: List[str] (default: None)
        List of ISO codes of countries for which we want the e-highway topology.
        If None, the full topology is loaded.
    add_offshore: bool (default: True)
        Whether to include offshore nodes
    extend_line_cap: bool (default True)
        Whether line capacity is allowed to be expanded
    use_ex_line_cap: bool (default True)
        Whether to use existing line capacity
    plot: bool (default: False)
        Whether to show loaded topology or not

    Returns
    -------
    network: pypsa.Network
        Updated network

    Raises
    ------
    AssertionError
        If 'countries' is an empty list, if the generated bus or line file is missing, or if no bus
        remains after filtering.
    """

    assert countries is None or len(countries) != 0, "Error: Countries list must not be empty. If you want to " \
                                                     "obtain, the full topology, don't pass anything as argument."

    topology_dir = f"{data_path}topologies/e-highways/generated/"
    buses_fn = f"{topology_dir}buses.csv"
    assert isfile(buses_fn), "Error: Buses are undefined. Please run 'preprocess'."
    buses = pd.read_csv(buses_fn, index_col='id')
    lines_fn = f"{topology_dir}lines.csv"
    assert isfile(lines_fn), "Error: Lines are undefined. Please run 'preprocess'."
    lines = pd.read_csv(lines_fn, index_col='id')

    # Remove offshore buses if not considered
    if not add_offshore:
        buses = buses.dropna(subset=["onshore_region"])

    if countries is not None:
        # In e-highway, GB is referenced as UK
        iso_to_ehighway = {"GB": "UK"}
        ehighway_countries = {iso_to_ehighway.get(c, c) for c in countries}

        # Remove onshore buses that are not in the considered region,
        # keep also buses that are offshore (i.e. with a country name that is not a string).
        # The country of an onshore bus is read from its id (characters following the first two).
        keep_bus = pd.Series(
            [(not isinstance(country, str)) or (bus_id[2:] in ehighway_countries)
             for bus_id, country in zip(buses.index, buses["country"])],
            index=buses.index, dtype=bool)
        buses = buses.loc[keep_bus]
    else:
        onshore_ids = buses.dropna(subset=["onshore_region"]).index
        # Sorted so that the list of countries does not depend on set iteration order
        countries = replace_iso2_codes(sorted({idx[2:] for idx in onshore_ids}))

    # Converting polygons strings to Polygon object.
    # The converted columns are assigned back explicitly instead of mutating the underlying array in place.
    buses = buses.copy()
    for region_type in ["onshore_region", "offshore_region"]:
        buses[region_type] = buses[region_type].apply(
            lambda region: shapely.wkt.loads(region) if isinstance(region, str) else region)

    # Remove lines for which one of the two end buses has been removed
    lines = lines.loc[lines["bus0"].isin(buses.index) & lines["bus1"].isin(buses.index)].copy()

    # Removing offshore buses that are not connected anymore
    connected_buses = sorted(set(lines["bus0"]) | set(lines["bus1"]))
    buses = buses.loc[connected_buses]
    assert len(buses) != 0, "Error: No buses are located in the given list of countries."

    # Add offshore polygons to remaining offshore buses
    if add_offshore:
        offshore_shapes = get_shapes(countries, which='offshore', save=True)["geometry"]
        if len(offshore_shapes) != 0:
            offshore_zones_shape = unary_union(offshore_shapes.values)
            offshore_bus_indexes = buses[buses["onshore_region"].isnull()].index
            offshore_buses = buses.loc[offshore_bus_indexes]
            # Use a home-made 'voronoi' partition to assign a region to each offshore bus
            buses.loc[offshore_bus_indexes, "offshore_region"] = \
                voronoi_special(offshore_zones_shape, offshore_buses[["x", "y"]])

    # Setting line parameters.
    # Lines are modelled as PyPSA links; a DC-opf formulation would instead import them as "Line" components
    # (keeping s_nom and defining a reactance 'x').
    if use_ex_line_cap:
        lines['p_nom'] = lines["s_nom"]
    else:
        lines['p_nom'] = 0
    lines['p_nom_min'] = lines['p_nom']
    lines['p_min_pu'] = -1.  # Making the link bi-directional
    lines = lines.drop('s_nom', axis=1)
    lines['p_nom_extendable'] = extend_line_cap

    # The cost returned by get_costs only depends on the carrier: compute it once per carrier,
    # then scale it by the length of each line.
    total_weight = sum(network.snapshot_weightings['objective'])
    carrier_cap_cost = {}
    for carrier in lines["carrier"].unique():
        carrier_cap_cost[carrier], _ = get_costs(carrier, total_weight)
    lines['capital_cost'] = lines["carrier"].map(carrier_cap_cost) * lines["length"]

    network.import_components_from_dataframe(buses, "Bus")
    network.import_components_from_dataframe(lines, "Link")
    # network.import_components_from_dataframe(lines, "Line") for dc-opf

    if plot:
        from epippy.topologies.core.plot import plot_topology
        plot_topology(buses, lines)
        plt.show()

    return network


if __name__ == "__main__":
    # Running the module as a script regenerates the bus and line files and plots the result.
    preprocess(plotting=True)
