# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/90_main.ipynb.

# %% auto 0
__all__ = ['WtLike']

# %% ../nbs/90_main.ipynb 2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from .bayesian import get_bb_partition
from .lightcurve import fit_cells, LightCurve, flux_plot
from .cell_data import partition_cells
from .config import *


class WtLike(LightCurve):
    """
    Top-level light-curve class, adding Bayesian-block (BB) views, phase plots
    and neighbor reweighting to `LightCurve`.

    Initialization
    --------------
    There are three layers of initialization, implemented in superclasses,
    each with parameters. The classnames, associated parameters and data members set:

    SourceData -- load photons and exposure
        parameters:
          - source : name, a PointSource object, or a Simulation object
          - config [Config()] : basic configuration
          - week_range [None] : range of weeks to load
          - key [''] : the cache key: '' means construct one with the source name, None to disable
          - clear [False] : if using cache, clear the contents first
        sets:
          - photons
          - exposure

    CellData -- create cells
        parameters:
          - time_bins [Config().time_bins] : binning: start, stop, binsize
        sets:
          - cells
        creates copies with new cells:
          - view
          - phase_view

    LightCurve -- likelihood analysis of the cells
        parameters:
          - e_min [10] -- threshold for exposure (cm^2 units)
          - n_min [2]  -- likelihood has trouble with this few
          - lc_key [None] -- possible cache for light curve
        sets:
          - fits, fluxes

    WtLike (this class) -- no parameters (may add BB-specific ones)
        Implements:  bb_view, plot, plot_bb, plot_phase, reweighted_view
        sets, on the view returned by bb_view only:
          - isBB, bayes_p0

    """
    def bb_view(self, p0=0.05, key=None, clear=False):
        """Return a view with the BB analysis applied

        - p0 -- false positive probability parameter
        - key, clear -- passed unchanged to `get_bb_partition`
          (presumably cache controls, as for the other layers -- confirm there)

        The returned view has its cells replaced by the BB partition of this
        object's cells, new fits of those cells, and is tagged with
        `isBB=True` and `bayes_p0=p0`.

        Its `plot` function will by default show an overplot on the parent's data points.
        """
        #  a new instance
        bb = self.view()

        # BB analysis of the current fits gives the edges of the new partition
        edges = get_bb_partition(self.config, self.fits, p0=p0, key=key, clear=clear)

        # merge the existing cells according to those edges, then refit
        bb.cells = partition_cells(self.config, self.cells, edges)
        bb.fits = fit_cells(self.config, bb.cells)

        # tags used by `plot` (dispatch) and `plot_bb` (legend label)
        bb.isBB = True
        bb.bayes_p0 = p0
        return bb

    def plot(self, *pars, **kwargs):
        """Plot this view, choosing the plot type from the view's tags.

        A view tagged `isBB` uses `plot_bb`, one tagged `is_phase` uses
        `plot_phase`; anything else is passed to the superclass `plot`.
        All arguments are forwarded unchanged. Returns whatever the selected
        plot method returns.
        """
        # which view type is this?
        if getattr(self, 'isBB', False):
            return self.plot_bb(*pars, **kwargs)
        if getattr(self, 'is_phase', False):
            return self.plot_phase(*pars, **kwargs)
        return super().plot(*pars, **kwargs)

    def plot_bb(self, ax=None, **kwargs):
        """Plot the light curve with BB overplot

        - ax -- Axes to draw on; if None a new figure is created
        - kwargs -- checked by `check_plot_kwargs`. The keys `figsize`
          [(12,4)], `fignum` [1], `ts_min`, `source_name` [self.source_name]
          and `colors` are consumed here; the remainder is forwarded to both
          `flux_plot` calls.

        Returns the figure.
        """
        self.check_plot_kwargs(kwargs)
        figsize = kwargs.pop('figsize', (12, 4))
        fignum = kwargs.pop('fignum', 1)
        # ts_min is accepted but not used here; pop it so that it is not
        # forwarded to flux_plot
        kwargs.pop('ts_min', -1)
        source_name = kwargs.pop('source_name', self.source_name)
        colors = kwargs.pop('colors', ('lightblue', 'wheat', 'blue'))

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize, num=fignum)
        else:
            fig = ax.figure

        # the parent's cells as the underlying data points
        flux_plot(self.parent.fits, ax=ax, colors=colors, source_name=source_name,
                  label=self.step_name+' bins', **kwargs)
        # the BB fits as a step plot, drawn on top
        flux_plot(self.fits, ax=ax, step=True,
                  label=f'BB (p0={100*self.bayes_p0:.0f}%)', zorder=10, **kwargs)
        ax.grid(alpha=0.5)
        fig.set_facecolor('white')
        return fig

    def plot_phase(self, ax=None, **kwargs):
        """Plot a phase lightcurve

        - ax -- Axes to draw on; if None a new (10,5) figure is created
        - kwargs -- Axes properties applied with `ax.set`, overriding the
          defaults ylim=(0.975, 1.025), xlim=(0,1)

        Returns the figure from the superclass `plot`.
        """
        # defaults first, so that the caller's settings take precedence
        axis_settings = dict(ylim=(0.975, 1.025), xlim=(0, 1))
        axis_settings.update(kwargs)

        fig, ax = plt.subplots(figsize=(10, 5)) if ax is None else (ax.figure, ax)
        fig = super().plot(ax=ax, xlabel=f'phase for {self.period:.3f}-day period')
        # apply the merged settings: passing only `kwargs` here would
        # silently drop the default limits
        ax.set(**axis_settings)
        ax.axhline(1.0, color='grey')
        return fig

    def reweighted_view(self, other):
        """
        Return a view in which the weights have been modified to account for a variable neighbor

        * other -- the WtLike BB view of the neighbor.

        Each photon weight w1 is replaced by

            w1' = w1 / (1 + alpha2 * w2)

        where w2 is the weight assigned to the photon by the neighbor's
        weight manager and alpha2 is the neighbor's BB flux, minus 1, for the
        block containing the photon time. Photons for which either w2 or
        alpha2 is NaN (e.g. outside the range of the neighbor's blocks) keep
        their original weight.

        The returned view has columns `w2` and `alpha2` added to its photons,
        and is rebinned with this object's time_bins and updated.
        """
        import copy
        assert getattr(other, 'isBB', False), 'Expected a bb_view'
        r = copy.copy(self)
        r.parent = self
        # work on a copy so this object's photons are left untouched
        r.photons = photons = self.photons.copy()

        # w2: weights for these photons according to the neighbor
        neighbor_weights = other.source.wtman.add_weights(photons.copy()).weight
        photons.loc[:, 'w2'] = neighbor_weights.astype(np.float32)

        # alpha2: relative flux offset of the neighbor's block at each photon time
        bbf = other.fluxes
        t, tw = bbf.t.values, bbf.tw.values
        bb_edges = np.append(t - tw/2, t[-1] + tw[-1]/2)
        block_index = np.searchsorted(bb_edges, photons.time)
        # insert NaNs at both ends of the BB flux, for times outside all blocks
        padded_flux = np.append(np.insert(bbf.flux.values, 0, np.nan), np.nan)
        photons.loc[:, 'alpha2'] = (padded_flux[block_index] - 1).astype(np.float32)

        # modified weights, unchanged where the correction is undefined
        w1, w2, a2 = photons.weight, photons.w2, photons.alpha2
        no_correction = np.isnan(w2) | np.isnan(a2)
        new_weight = np.where(no_correction, w1, w1/(1 + a2*w2))
        photons.loc[:, 'weight'] = new_weight.astype(np.float32)

        # new weights, so must rebin, then fit them
        r.rebin(self.time_bins)
        r.update()

        return r
