# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_load_data.ipynb.

# %% auto 0
__all__ = ['sc_data_selection', 'load_source_data']

# %% ../nbs/04_load_data.ipynb 4
import pickle, healpy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from .config import (Config, UTC, MJD)
from .exposure import binned_exposure, sc_data_selection, sc_process, weighted_aeff
from .data_man import get_week_files

# %% ../nbs/04_load_data.ipynb 5
class ConeSelect:
    """Select HEALPix pixels lying in a cone about a sky direction.

    The cone is centered on longitude/latitude `(l, b)` in degrees (presumably
    Galactic coordinates), with radius `config.radius` degrees, at resolution
    `config.nside` and with the pixel ordering given by `config.nest`.
    """

    def __init__(self, config, l, b):
        # unit vector for the cone center (lonlat=True: l, b are in degrees)
        center_vec = healpy.dir2vec(l, b, lonlat=True)
        radius_rad = np.radians(config.radius)
        self.conepix = healpy.query_disc(config.nside, center_vec, radius_rad, nest=config.nest)

    def __call__(self, pixels, shift=11):
        """Return a boolean mask, one entry per element of `pixels`, True where in the cone.

        Indices are compared after right-shifting by `shift` bits, i.e. in
        groups of 2**shift consecutive indices. This is faster than an exact
        comparison but looser: a pixel is accepted whenever it shares a group
        with any cone pixel. Presumably intended for nested-ordered indices.
        """
        coarse_cone = np.unique(np.right_shift(self.conepix, shift))
        coarse_pixels = np.right_shift(pixels, shift)
        return np.isin(coarse_pixels, coarse_cone)

# %% ../nbs/04_load_data.ipynb 7
def _get_photons_near_source(config, source, week): #tzero, photon_df):
    """
    Select the photons of one week whose pixels fall in a cone about a source.

    - config : uses `nside`, `radius` (deg) and `verbose`
    - source : a PointSource object; uses `l`, `b` (deg) and `name`
    - week : dict with
        - tstart : start time of the week (presumably MET; only used in messages)
        - photons : dict with photon data, including a `nest_index` column
        - runlist : list of run numbers (not used here)

    Returns a copy of the photon DataFrame restricted to the rows in the cone,
    with the original columns unchanged. Returns None if fewer than 2 photons
    are selected.

    The cut compares coarse pixels (nest indices right-shifted by 11 bits), so
    it is looser than `config.radius`; no per-photon radius cut is applied.
    """
    # cone pixels within config.radius of the source, in nested ordering to
    # match the `nest_index` column
    center = healpy.dir2vec(source.l, source.b, lonlat=True)
    conepix = healpy.query_disc(config.nside, center, np.radians(config.radius), nest=True)

    df = pd.DataFrame.from_dict(week['photons'])
    tstart = week['tstart']
    allpix = df.nest_index.values

    # select by comparing high-order pixels (faster)
    shift = 11
    incone = np.isin(np.right_shift(allpix, shift),
                     np.unique(np.right_shift(conepix, shift)))
    # count once, in C, rather than with the builtin sum() over a numpy array
    nselected = np.count_nonzero(incone)

    if nselected < 2:
        if config.verbose > 1:
            print(f'\nWeek starting {UTC(MJD(tstart))} has 0 or 1 photons')
        return None

    if config.verbose > 2:
        ntotal = len(allpix)
        print(f'Select photons for source {source.name}:\n\tPixel cone cut: select {nselected} from {ntotal} ({100*nselected/ntotal:.1f}%)')

    # cut df to entries in the cone
    # NOTE: rows are not sorted by time here; runs may be out of order in the week files
    out_df = df[incone].copy()

    if config.verbose > 2:
        print(f'selected photons:\n{out_df.head()}')

    return out_df

# %% ../nbs/04_load_data.ipynb 8
def sc_data_selection(config, source, sc_data):
    """
    Return a DataFrame with the S/C data for the source direction, with cos theta and zenith cuts

    - config : configuration; if `config.get('full_exp', False)` is true, also add `exp_fract`
    - source : the source, passed to `sc_process` and `weighted_aeff`
    - sc_data : DataFrame of spacecraft entries for one week

    columns:
    - start, stop, livetime -- from the FT2 info
    - cos_theta -- angle between bore and direction
    - exp -- the exposure: effective area at angle weighted by a default spectral function, times livetime (float32)
    - exp_fract -- only with `full_exp`: per-band fraction of `exp` for each entry, stored as a list (float16 values)

    If the selection from `sc_process` is empty, it is returned as-is, without the `exp` columns.
    """
    sc_df = sc_process(config, source, sc_data)
    if len(sc_df) == 0:
        return sc_df

    cos_theta = sc_df.cos_theta.values
    livetime = sc_df.livetime.values
    func = weighted_aeff(config, source)

    exp = (func(cos_theta) * livetime).astype(np.float32)
    sc_df['exp'] = exp

    # add detailed exposure info, the fraction of the total in each band
    if config.get('full_exp', False):
        # per-band effective area; the broadcasting below implies shape (n_entries, n_bands)
        band_aeff = func.binned(cos_theta)
        # NOTE(review): entries with exp == 0 give nan/inf fractions here
        fract = ((band_aeff.T * livetime / exp).T).astype(np.float16)
        # plain column assignment stores one list per row in an object column;
        # `.loc[:, new_col] = list_of_lists` is version-dependent in pandas
        sc_df['exp_fract'] = [list(x) for x in fract]

    return sc_df

# %% ../nbs/04_load_data.ipynb 11
class ProcessWeek(object):
    """
    Process a week's photon and livetime info into the source-related photon and exposure tables.

    Attributes set by the constructor:
    - config  : the configuration object
    - start   : start of the week, `MJD(week['tstart'])`
    - runlist : `week['runlist']`, or None if absent
    - sc_df   : spacecraft/exposure DataFrame from `sc_data_selection`
    - stime   : interleaved start/stop times of the `sc_df` intervals
    - lt, ct  : livetime and cos_theta arrays from `sc_df`
    - photons : selected photon DataFrame with a `time` column, or None
    """

    def __init__(self, config, source, week_file):
        """
        Load a pickled week file and build the exposure and photon tables for `source`.

        - config : configuration (uses `verbose` and `offset_size`, among others)
        - source : source object; must provide `wtman.add_weights`
        - week_file : path-like object for a pickled week dict (its `.name` is used in messages)
        """
        # NOTE(review): pickle.load can execute arbitrary code -- only use trusted week files
        with open(week_file, 'rb') as inp:
            week = pickle.load(inp)

        # convert the photon and spacecraft dicts to DataFrames
        pdf = pd.DataFrame(week['photons'])
        sc_data = edf = pd.DataFrame.from_dict(week['sc_data'])
        self.runlist = week.get('runlist', None)
        self.start = MJD(week['tstart'])
        self.config = config

        if config.verbose > 1:
            print(f'Opened week file "{week_file.name}" of {UTC(self.start)}')
            print(f'\tFound {len(pdf):,} photons, {len(edf):,} SC entries')

        self.sc_df = sc_df = sc_data_selection(config, source, sc_data)

        # interleaved start/stop: even entries are interval starts, odd entries are stops
        self.stime = np.empty(2*len(sc_df.start))
        self.stime[0::2] = sc_df.start.values
        self.stime[1::2] = sc_df.stop.values
        # explicit check rather than `assert`, so it still runs under `python -O`
        if not np.all(np.diff(self.stime) >= 0):
            raise AssertionError('Time-ordering failure')

        self.lt = sc_df.livetime.values
        self.ct = sc_df.cos_theta.values

        pdf = _get_photons_near_source(config, source, week)
        if pdf is None or len(pdf) < 3:
            self.photons = None
        else:
            # set weights from the weight table, removing those with no weight
            pdf = source.wtman.add_weights(pdf)

            # finally set the time and the exposure per remaining photons
            self.photons = self.photon_times(pdf)

    def __str__(self):
        # photons may be None when too few were selected
        nphotons = 0 if self.photons is None else len(self.photons)
        return f'Data for week of {UTC(self.start)}: {nphotons:,} photons'

    def __repr__(self): return self.__str__()

    def photon_times(self, pdf):
        """
        Add a `time` column (MJD) to `pdf` and keep only photons inside an exposure interval.

        - pdf : photon DataFrame with `run_ref` (index into `self.runlist`) and `trun` columns;
          `pdf` itself is modified by the addition of `time`

        Returns a copy restricted to photons with exposure info, without the
        `trun` and `run_ref` columns, or None if no photons remain.
        Raises Exception for old-format data that lacks `run_ref`.
        """
        if 'run_ref' not in pdf:
            raise Exception('Old format data: recreate to insert run_ref')
        # run number for each photon; np.asarray also accepts a plain-list runlist
        run = np.asarray(self.runlist)[np.asarray(pdf.run_ref)]
        ptime = MJD(run + pdf.trun * self.config.offset_size)
        pdf.loc[:, 'time'] = ptime

        # stime alternates start/stop, so an odd insertion index (side='left')
        # means the photon time lies in some interval (start, stop]
        tk = np.searchsorted(self.stime, ptime)
        good_exp = np.mod(tk, 2) == 1
        pdfg = pdf[good_exp].copy()
        if len(pdfg) == 0:
            return None
        pdfg.drop(columns=['trun', 'run_ref'], inplace=True)
        return pdfg

    def hist_spacecraft(self):
        """Histogram the livetime, cos_theta and exp columns of the exposure table."""
        self.sc_df.hist('livetime cos_theta exp'.split(), bins=100, layout=(1,3), figsize=(12,3))

    def hist_photons(self):
        """Histogram the band and time columns of the photon table (requires photons)."""
        self.photons.hist('band time'.split(), bins=100, log=True, figsize=(12,3), layout=(1,3))

    def __call__(self):
        """Return a dict with the week `start`, the `photons` table and the `exposure` table."""
        return dict(
            start=self.start,
            photons=self.photons,
            exposure=self.sc_df,
        )

# %% ../nbs/04_load_data.ipynb 12
class _TWeek():
    # This is a functor wrapping ProcessWeek which needs to be global for multprocessing.
    def __init__(self, config, source):
        self.config=config
        self.source=source

    def __call__(self, wkf):
        print('.', end='')
        eman = ProcessWeek( self.config, self.source, wkf)
        return (eman.photons, eman.sc_df)

def multiprocess_week_data(config, source, week_range, processes=None):
    """ Manage processing of set of week files with multiprocessing

    - config : uses `pool_size` (default number of processes) and `verbose`
    - source : passed to `ProcessWeek` for each week
    - week_range : passed to `get_week_files` to select the week files
    - processes [None] : number of worker processes; if None (or 0) use
      `config.pool_size`. With 1 the weeks are processed serially here.

    Returns a tuple of two lists:
    - photon DataFrames, for weeks with more than 2 selected photons
    - exposure DataFrames, for weeks with at least one entry
    """
    from multiprocessing import Pool

    processes = processes or config.pool_size
    week_files = get_week_files(config, week_range)

    # nothing to do; also avoids indexing week_files[0] in the message below
    if len(week_files) == 0:
        if config.verbose > 0:
            print(f'\tNo week files found for week range {week_range}')
        return [], []

    txt = f', using {processes} processes ' if processes>1 else ''

    if config.verbose>0:
        print(f'\tProcessing {len(week_files)} week files {week_files[0].name} - {week_files[-1].name} {txt}', end='', flush=True)

    process_week = _TWeek(config, source)

    if processes>1:
        with Pool(processes=processes) as pool:
            week_data = pool.map(process_week, week_files)
    else:
        # lazy map: consumed exactly once by the loop below
        week_data = map(process_week, week_files)
    print('\n')

    pp = []
    ee = []

    for pdf, edf in week_data:
        # keep weeks with usable photon and exposure data
        if pdf is not None and len(pdf) > 2:
            pp.append(pdf)
        if len(edf) > 0:
            ee.append(edf)

    return pp, ee

# %% ../nbs/04_load_data.ipynb 13
def load_source_data(config, source, week_range=None, key='', clear=False):
    """
    Generate photon and exposure tables specific to the source.

    - week_range [None] -- if None, select all weeks
    - key ['']   -- key to use for cache, construct from name if not set;
                    if None, bypass the cache and load directly from the week files
    - clear [False] -- passed to the cache as `overwrite`

    For the given source returns a list of
    - photons
    - exposure
    - the key (None if the cache was not used)

    Raises Exception if `config.datapath` is not set and the data is not in the cache.
    Raises ValueError if the week files yield no photon or exposure data.
    """

    def load_from_weekly_data(config, source, week_range=None):

        pp, ee = multiprocess_week_data(config, source, week_range)
        if len(pp) == 0 or len(ee) == 0:
            # pd.concat would fail with a less informative message
            raise ValueError(f'No photon or exposure data found for {source.name} in week range {week_range}')

        # concatenate the two lists of DataFrames
        p_df = pd.concat(pp, ignore_index=True)
        e_df = pd.concat(ee, ignore_index=True)

        return p_df, e_df

    description = f'SourceData:  {source.name}' if config.verbose>0 else ''

    weeks = f'weeks_{week_range[0]}-{week_range[1]}' if week_range is not None else 'data'
    # the cache key actually used; None means load directly, bypassing the cache
    if key is None:
        used_key = None
    else:
        used_key = f'{source.filename}_{weeks}' if key == '' else key

    # Without a data path the week files cannot be read, so the data must already be cached.
    if config.datapath is None and (used_key is None or used_key not in config.cache):
        raise Exception(f'Data for {source.name} is not cached, and config.datapath/"data_files" is not set')

    if used_key is None:
        # key set to None: always load directly from the week files
        if config.verbose>0: print(description)
        r = load_from_weekly_data(config, source, week_range=week_range)
    else:
        # use the cache
        r = config.cache(used_key,
                    load_from_weekly_data, config, source, week_range=week_range,
                    overwrite=clear,
                    description=description)
    # append key used for retrieval
    return list(r) + [used_key]
