# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03_weights.ipynb (unless otherwise specified).

__all__ = ['check_weights', 'add_weights', 'WeightMan', 'weight_radius_plots', 'PointSource']

# Cell
import os, sys,  pickle, healpy
from pathlib import Path
import numpy as np
from scipy.integrate import quad
from astropy.coordinates import SkyCoord, Angle
from .config import *

# Cell
def check_weights(config, source):
    """
    Look up the weight file that belongs to a source.

    - config -- object whose `wtlike_data` attribute is a path-like supporting `/`;
      the weight files are expected in its `weight_files` subdirectory
    - source -- object with a `filename` attribute (e.g. a `PointSource`); the file
      looked for is `<source.filename>_weights.pkl`

    Returns the path of the weight file when it exists. Otherwise prints the list of
    available source names to stderr and returns None.

    Raises AssertionError if the `weight_files` directory does not exist.
    """
    weight_files = config.wtlike_data / 'weight_files'
    assert weight_files.is_dir(), f'Expect {weight_files} to be a directory'

    candidate = weight_files / (source.filename + '_weights.pkl')
    if candidate.exists():
        return candidate

    # Not found: report the source names for which a weight file is present
    stems = [path.name[:path.name.find('_weights')]
             for path in weight_files.glob('*_weights.pkl')]
    available = np.array(stems)
    print(f'{source} not found in list of weight files at\n\t {weight_files}.\n Available:\n{available}',
         file = sys.stderr)
    return None

# Cell
# def load_weights(config, filename, ):
#     """Load the weight information

#     filename: pickled dict with map info

#     """
#     # load a pickle containing weights, generated by pointlike
#     assert os.path.exists(filename),f'File {filename} not found.'
#     with open(filename, 'rb') as file:
#         wtd = pickle.load(file, encoding='latin1')
#     assert type(wtd)==dict, 'Expect a dictionary'
#     test_elements = 'energy_bins pixels weights nside model_name radius order roi_name'.split()
#     assert np.all([x in wtd.keys() for x in test_elements]),f'Dict missing one of the keys {test_elements}'
#     if config.verbose>0:
#         print(f'Load weights from file {os.path.realpath(filename)}')
#         pos = wtd['source_lb']
#         print(f'\tFound: {wtd["source_name"]} at ({pos[0]:.2f}, {pos[1]:.2f})')
#     # extract pixel ids and nside used
#     wt_pix   = wtd['pixels']
#     nside_wt = wtd['nside']

#     # merge the weights into a table, with default nans
#     # indexing is band id rows by weight pixel columns
#     # append one empty column for photons not in a weight pixel
#     # calculated weights are in a dict with band id keys
#     wts = np.full((32, len(wt_pix)+1), np.nan, dtype=np.float32)
#     weight_dict = wtd['weights']
#     for k in weight_dict.keys():
#         t = weight_dict[k]
#         if len(t.shape)==2:
#             t = t.T[0] #???
#         wts[k,:-1] = t
#     return wts , wt_pix , nside_wt

# Cell
def _add_weights(config, wts, wt_pix, nside_wt, photon_data):
    """ get the photon pixel ids, convert to NEST (if not already) and right shift them
        add 'weight', remove 'band', 'pixel'
    """
    if not config.nest:
        # data are RING
        photon_pix = healpy.ring2nest(config.nside, photon_data.nest_index.values)
    else:
        photon_pix = photon_data.nest_index.values
    to_shift = 2*int(np.log2(config.nside/nside_wt));
    shifted_pix =   np.right_shift(photon_pix, to_shift)
    bad = np.logical_not(np.isin(shifted_pix, wt_pix))
    if config.verbose>0 & sum(bad)>0:
        print(f'\tApplying weights: {sum(bad)} / {len(bad)} photon pixels are outside weight region')
    if sum(bad)==len(bad):
        a = np.array(healpy.pix2ang(nside_wt, wt_pix, nest=True, lonlat=True)).mean(axis=1).round(1)
        b = np.array(healpy.pix2ang(nside_wt, shifted_pix, nest=True, lonlat=True)).mean(axis=1).round(1)

        raise Exception(f'There was no overlap of the photon data at {b} and the weights at {a}')
    shifted_pix[bad] = 12*nside_wt**2 # set index to be beyond pixel indices

    # find indices with search and add a "weights" column
    # (expect that wt_pix are NEST ordering and sorted)
    weight_index = np.searchsorted(wt_pix,shifted_pix)
    band_index = np.fmin(31, photon_data.band.values) #all above 1 TeV into last bin

    # final grand lookup -- isn't numpy wonderful!
    photon_data.loc[:,'weight'] = self.wts[tuple([band_index, weight_index])]

    # don't need these columns now (add flag to config to control??)
#     photon_data.drop(['band', 'pixel'], axis=1)

    if config.verbose>1:
        print(f'\t{sum(np.isnan(photon_data.weight.values))} events without weight')

# Cell
def add_weights(config,  photon_data, source):
    """ Add weights for the source to the photon data

    - config -- configuration object, passed on to `check_weights` and `WeightMan`

    - photon_data -- DataFrame with photon data; it is modified in place by
      `WeightMan.add_weights` (a 'weight' column is added, 'nest_index' is dropped)

    - source -- `PointSource` object

    Returns the weighted DataFrame (the value returned by `WeightMan.add_weights`),
    so the call can be chained; callers that ignore the return value still see the
    in-place changes.

    Raises Exception if no weight file is found for the source.
    """
    weight_file = check_weights(config, source)
    if weight_file is None:
        raise Exception(f'Weight file not found for {source}')

    wtman = WeightMan(config, filename=weight_file)
    return wtman.add_weights(photon_data)

# Cell
class WeightMan(dict):
    """ Weight Management

    * Load weight tables
    * Assign weights to photons

    The pickled dictionary read from the weight file is stored twice: as the
    dict content of this object and as instance attributes.
    Two layouts are recognized:

    * format 0 ("old"): the dictionary has an `nside` entry; one set of pixels
      is shared by all bands
    * format 1 ("new"): otherwise; `wt_dict` maps a band id to a dict with
      `nside`, `pixels` and `wts` entries
    """

    def __init__(self, config,  source=None, filename=None,):
        """
        - config -- object providing `wtlike_data`, `verbose`, and for
          `add_weights` also `nest` and `nside`
        - source -- [None] a `PointSource`; the file name is derived from `source.filename`
        - filename -- [None] file name, relative to the `weight_files` folder or absolute;
          only used if `source` is not a `PointSource`

        Raises Exception if neither is given, AssertionError if the folder or the file
        is missing or the file content is not as expected.
        """
        wtpath = Path(config.wtlike_data)/'weight_files'
        assert wtpath.is_dir(), f' {wtpath} not an existing file path'

        if isinstance(source, PointSource):
            full_filename = wtpath/(source.filename+'_weights.pkl')
            self.source = source

        elif filename is not None:
            full_filename = wtpath/filename
            self.source = None

        else:
            raise Exception('WeightMan: expected source or filename')

        # when the name came from the source, `filename` may be None: report the derived name
        shown = filename if self.source is None else full_filename.name
        assert full_filename.is_file(), f'File {shown} not found at {wtpath}'

        # load a pickle containing weights, generated by pointlike
        # NOTE: unpickling executes code from the file -- weight files must come from a trusted place
        with open(full_filename, 'rb') as file:
            wtd = pickle.load(file, encoding='latin1')
        assert isinstance(wtd, dict), 'Expect a dictionary'
        self.update(wtd)
        self.__dict__.update(wtd)
        self.filename = filename
        self.config = config

        # description of the origin, for the messages below
        srcfile = f'file "{self.filename}"' if self.source is None else f'file from source "{source.filename}_weights.pkl"'

        # check format -- the old one has nside, pixels, weights at the top level
        if hasattr(self, 'nside'):
            self.format = 0
            if config.verbose>0:
                print(f'WeightMan: {srcfile} old format, nside={self.nside}')

            test_elements = 'energy_bins pixels weights nside model_name radius order roi_name'.split()
            assert np.all([x in wtd.keys() for x in test_elements]),f'Dict missing one of the keys {test_elements}'
            if config.verbose>0:
                # use the resolved path: `filename` is None when constructed from a source
                print(f'Load weights from file {os.path.realpath(full_filename)}')
                pos = self['source_lb']
                print(f'\tFound: {self["source_name"]} at ({pos[0]:.2f}, {pos[1]:.2f})')
            # extract pixel ids and nside used
            self.wt_pix   = self['pixels']
            self.nside_wt = self['nside']

            # merge the weights into a table, with default nans
            # indexing is band id rows by weight pixel columns
            # append one empty column for photons not in a weight pixel
            # calculated weights are in a dict with band id keys
            self.wts = np.full((32, len(self.wt_pix)+1), np.nan, dtype=np.float32)
            for k, t in self['weights'].items():
                if len(t.shape)==2:
                    # NOTE(review): a 2-d entry is reduced to its first column -- reason not recorded
                    t = t.T[0]
                self.wts[k,:-1] = t

        else:
            self.format = 1
            nsides = [v['nside'] for v in self.wt_dict.values()]

            if config.verbose>1:
                print(f'WeightMan: {srcfile} : new format, {len(nsides)} bands'\
                      f' with nsides {nsides[0]} to {nsides[-1]}')
            if self.source is not None:
                # pass the fit information stored with the weights on to the source
                self.source.fit_info = self.fitinfo
                if config.verbose>1:
                    print(f'\tAdded fit info {self.fitinfo} to source')

    def _old_format(self, photons):
        """Add a 'weight' column (float16) to `photons`, in place, from the single table
        of the old format.

        Raises Exception if none of the photon pixels lies inside the weight region.
        """
        config = self.config
        if not config.nest:
            # data are RING
            photon_pix = healpy.ring2nest(config.nside, photons.nest_index.values)
        else:
            photon_pix = photons.nest_index.values
        nside = self.nside_wt

        # In NEST ordering each factor of two in nside corresponds to two bits of the index
        to_shift = 2*int(np.log2(config.nside//nside))
        shifted_pix = np.right_shift(photon_pix, to_shift)

        bad = np.logical_not(np.isin(shifted_pix, self.wt_pix))
        n_bad = int(bad.sum())
        # `and`, not `&`: the bitwise operator binds tighter than the comparisons,
        # which turned the original test into one that could never be true
        if config.verbose>0 and n_bad>0:
            print(f'\tApplying weights: {n_bad} / {len(bad)} photon pixels are outside weight region')
        if n_bad==len(bad):
            a = np.array(healpy.pix2ang(nside, self.wt_pix, nest=True, lonlat=True)).mean(axis=1).round(1)
            b = np.array(healpy.pix2ang(nside, shifted_pix, nest=True, lonlat=True)).mean(axis=1).round(1)

            raise Exception(f'There was no overlap of the photon data at {b} and the weights at {a}')
        shifted_pix[bad] = 12*nside**2 # set index to be beyond pixel indices

        # find indices with search and add a "weights" column
        # (expect that wt_pix are NEST ordering and sorted)
        weight_index = np.searchsorted(self.wt_pix, shifted_pix)
        band_index = np.fmin(31, photons.band.values) # band ids above 31 share the last row

        # photons outside the weight region hit the extra NaN column
        photons.loc[:,'weight'] = self.wts[band_index, weight_index].astype(np.float16)

    def _new_format(self, photons):
        """Add a 'weight' column to `photons`, in place, from the per-band tables of the
        new format, and return `photons`.

        Photons in a band without a table, or in a pixel not in the band's table,
        get NaN. Weights are rounded to float16 precision.
        """
        wt_tables = self.wt_dict
        # NOTE(review): the nside of the photon pixel ids is fixed here rather than taken
        # from self.config.nside, and no RING -> NEST conversion is made -- confirm
        data_nside = 1024

        photons.loc[:,'weight'] = np.nan

        if self.config.verbose>1:
            print(f'WeightMan: processing {len(photons):,} photons')

        band_ids = photons.band.values
        nest_index = photons.nest_index.values

        for band_id in range(16):
            if band_id not in wt_tables:
                continue
            wt_table = wt_tables[band_id]
            wt_pixels = wt_table['pixels']
            new_weights = wt_table['wts'].astype(np.float16)

            # photons of this band, with pixel ids degraded to the nside of the band's table
            in_band = band_ids==band_id
            to_shift = int(2*np.log2(data_nside//wt_table['nside']))
            data_pixels = np.right_shift(nest_index[in_band], to_shift)
            good = np.isin(data_pixels, wt_pixels)
            if self.config.verbose>2:
                print(f'\t {band_id:2}: {len(data_pixels):8,} -> {int(good.sum()):8,}')
            if not in_band.any():
                continue

            # (expect that the pixels are sorted); positions without a table entry stay NaN
            band_weights = np.full(len(data_pixels), np.nan, dtype=np.float16)
            band_weights[good] = new_weights[np.searchsorted(wt_pixels, data_pixels[good])]
            # positional assignment through the mask, independent of the frame's index
            photons.loc[in_band, 'weight'] = band_weights

        return photons

    def add_weights(self, photons):
        """
        Add a 'weight' column to the photon data and remove the 'nest_index' column.

        - photons -- DataFrame with `nest_index` and `band` columns; modified in place

        Returns the DataFrame.
        """
        if self.format==0:
            self._old_format(photons)
        else:
            photons = self._new_format(photons)

        # don't need this column now (add flag to config to control??)
        photons.drop(['nest_index'], axis=1, inplace=True)

        if self.config.verbose>1:
            print(f'\t{int(np.isnan(photons.weight.values).sum()):,} events without weight')
        return photons

def weight_radius_plots(photons):
    """Plot weight vs. radius for each of the first 16 bands, one panel per band.

    - photons -- DataFrame with `band`, `weight` and `radius` columns; only rows
      with weight>0 are shown

    Creates a new 2 x 8 matplotlib figure with shared axes; returns None.
    """
    import matplotlib.pyplot as plt

    fig, axx = plt.subplots(2,8, figsize=(16,5), sharex=True, sharey=True)
    fig.subplots_adjust(hspace=0.02, wspace=0)
    # band_id (not `id`, which would shadow the builtin) is referenced by the query as @band_id
    for band_id, ax in enumerate(axx.flatten()):
        subset = photons.query('band==@band_id & weight>0')
        ax.semilogy(subset.radius, subset.weight, '.', label=f'{band_id}')
        ax.legend(loc='upper right', fontsize=10)
        ax.grid(alpha=0.5)
    # set on the last panel; the axes were created with sharex/sharey
    ax.set(ylim=(8e-4, 1.2), xlim=(0,4.9))
    fig.suptitle('Weights vs. radius per band')

# Cell

class PointSource():
    """Manage the position and name of a point source

    Attributes set by the constructor: `name`, `nickname`, `l`, `b` (galactic
    coordinates in degrees) and `skycoord` (an astropy `SkyCoord`).
    A `fit_info` attribute may be attached later (`WeightMan` does this); it is
    used by `spectral_model`.
    """
    def __init__(self, name, position=None, nickname=None):
        """
           * name: source name
           * position: [None] (l,b) tuple or None. if None, expect to be found by lookup:
             a name starting with 'J' is parsed as J + hhmm.m + sign ddmm, any other name
             is passed to `SkyCoord.from_name`
           * nickname [None] -- use for file name, especially weights (in config.weight_folder)

        Re-raises the exception from `SkyCoord.from_name` if the name is not recognized,
        after printing a message to stderr.
        """
        self.name = name
        if position is None:
            if name.startswith('J'):
                # parse the name for (ra,dec)
                ra = name[1:3]+'h'+name[3:7]+'m'
                dec = name[7:10]+'d'+name[10:12]+'m'
                # the unit is given as a string: no units module is imported in this file
                ra, dec = (float(Angle(a, unit='deg').to_string(decimal=True)) for a in (ra, dec))
                skycoord = SkyCoord(ra, dec, unit='deg', frame='fk5')
            else:
                try:
                    skycoord = SkyCoord.from_name(name)
                except Exception as e:
                    print(f'PointSource: a source "{name}" was not recognized by SkyCoord: {e}', file=sys.stderr)
                    raise
            gal = skycoord.galactic
            self.l, self.b = (gal.l.value, gal.b.value)
        else:
            self.l, self.b = position
            skycoord = SkyCoord(self.l, self.b, unit='deg', frame='galactic')
        self.skycoord = skycoord
        self.nickname = nickname

    def __str__(self):
        return f'Source "{self.name}" at: (l,b)=({self.l:.3f},{self.b:.3f})'

    def __repr__(self): return str(self)

    @property
    def ra(self):
        """Right ascension (fk5 frame), in degrees"""
        sk = self.skycoord.transform_to('fk5')
        return sk.ra.value

    @property
    def dec(self):
        """Declination (fk5 frame), in degrees"""
        sk = self.skycoord.transform_to('fk5')
        return sk.dec.value

    @property
    def filename(self):
        """Modified name for file system: the nickname if there is one, otherwise
        the name with blanks replaced by '_' and '+' by 'p'"""
        if self.nickname is not None:
            return self.nickname
        return self.name.replace(' ', '_').replace('+','p')

    @classmethod
    def fk5(cls, name, position, nickname=None):
        """Alternate constructor from an fk5 position

        * position: (ra,dec) tuple, in degrees
        * nickname [None] -- passed on to the constructor
        """
        ra, dec = position
        sk = SkyCoord(ra, dec, unit='deg',  frame='fk5').transform_to('galactic')
        return cls(name, (sk.l.value, sk.b.value), nickname=nickname)

    @property
    def spectral_model(self):
        """The spectral model described by `fit_info`, or None if there is no fit info

        `fit_info` is expected to have entries 'modelname' and 'pars'; only the
        model name 'LogParabola' is recognized, any other raises Exception.
        """
        fi = getattr(self, 'fit_info', None)
        if fi is None: return None

        if fi['modelname']=='LogParabola':
            return self.LogParabola(fi['pars'])
        raise Exception(f'PointSource: Unrecognized spectral model name {fi["modelname"]}')


    class FluxModel():
        """Base class for spectral models: a subclass must be callable with an energy.
        The integrals are evaluated from `emin` to `emax`."""

        emin, emax = 1e2, 1e5

        def photon_flux(self):
            """Integral of the model from emin to emax"""
            return quad(self, self.emin, self.emax)[0]

        def energy_flux(self):
            """Integral of the model times energy squared from emin to emax"""
            # NOTE(review): an energy flux would normally weight by e, not e**2 -- confirm the intent
            func = lambda e: self(e) * e**2
            return quad(func, self.emin, self.emax)[0]

    class LogParabola(FluxModel):
        """Log-parabola model with parameters (n0, alpha, beta, e_break)"""

        def __init__(self, pars):
            self.pars = pars

        def __call__(self, e):
            """Value at energy e: n0 * (e_break/e)**(alpha - beta*log(e_break/e))"""
            n0, alpha, beta, e_break = self.pars
            x = np.log(e_break/e)
            y = (alpha - beta*x)*x
            return n0*np.exp(y)