# AUTOGENERATED FILE! PLEASE DON'T EDIT
from k1lib.callbacks import Callback, Callbacks
import k1lib, numpy as np, math
from functools import partial
import matplotlib.pyplot as plt
from typing import Callable
# Public API of this module: only the two callback classes are exported by
# `from ... import *`; the plotting helpers below stay module-private.
__all__ = ["Loss", "Accuracy"]
def plotF(losses, _slice): # actual function stored by the sliceable plot
    """Draw train and valid losses side by side for the requested slice.

    :param losses: object exposing `.train` and `.valid` sequences that
        support `len()` and slicing
    :param _slice: slice applied to the training losses. Its `.step`
        (defaulting to 1) is used to subsample both plots. The matching
        window for the valid losses is computed by
        `k1lib.Range.proportionalSlice` from the two sequence lengths.

    Plotting is best-effort: if slicing or plotting raises, the error is
    dropped and whatever has been drawn so far is left in the figure.
    """
    plt.figure(figsize=(10, 3), dpi=100)
    step = _slice.step or 1
    tR, vR = k1lib.Range.proportionalSlice(len(losses.train), len(losses.valid), _slice)
    try:
        plt.subplot(1, 2, 1)
        plt.plot(tR.range_[::step], losses.train[tR.slice_][::step])
        plt.title("Train loss")
        plt.subplot(1, 2, 2)
        plt.plot(vR.range_[::step], losses.valid[vR.slice_][::step])
        plt.title("Valid loss")
    except Exception:
        # Deliberately best-effort, but only for ordinary errors: a bare
        # `except:` would also swallow KeyboardInterrupt / SystemExit.
        pass
def commonPlot(obj):
    """Build a `k1lib.viz.SliceablePlot` that plots `obj`'s losses.

    :param obj: object with `.train` and `.valid` sequences; it is bound
        as the first argument of `plotF`
    :return: the sliceable plot, carrying a reminder that the slice
        refers to the training plot
    """
    reminder = "\n\nReminder: the actual slice you put in is for the training plot. The valid loss's plot will update automatically to be in the same time frame"
    boundPlot = partial(plotF, obj)
    return k1lib.viz.SliceablePlot(boundPlot, docs=reminder)
def nonEmptyList(_list):
    """Return `[0]` if `_list` equals the empty list, else `_list` itself.

    Lets callers take a mean without hitting the empty-sequence case.
    The input is returned as-is (same object) when it is not `[]`.
    """
    if _list == []:
        return [0]
    return _list
@k1lib.patch(Callback.cls)
class Loss(Callback):
    """Records losses after each batch.

    Attributes:

    - `.train` / `.valid`: every recorded loss, across all epochs
    - `.epoch.train` / `.epoch.valid`: one mean loss per epoch. If no loss
      was recorded for a split during an epoch, the mean is taken over
      `[0]`, so 0 is stored for that epoch
    - `.plot()` / `.epoch.plot()`: sliceable plots of the above"""
    def __init__(self):
        super().__init__()
        self.order = 20
        # complete per-batch history
        self.train = []
        self.valid = []
        # per-epoch averages, held in an object with its own help text
        epochHelp = ("Use...\n"
                     "- `.train` for epoch-averaged training losses\n"
                     "- `.valid` for epoch-averaged validation losses\n"
                     "- `.plot()` to plot the 2 above")
        epochStats = k1lib.Object.fromDict({"train": [], "valid": []})
        self.epoch = epochStats.withRepr(epochHelp)
        self.plot = partial(commonPlot, self)
        self.epoch.plot = partial(commonPlot, self.epoch)
        # losses of the epoch currently in progress; flushed in endEpoch()
        self._trainLosses = []
        self._validLosses = []
    def endLoss(self):
        # file the current loss under the split the model is currently in
        bucket = self._trainLosses if self.model.training else self._validLosses
        bucket.append(self.loss)
    def endEpoch(self):
        pendingTrain, pendingValid = self._trainLosses, self._validLosses
        self.train.extend(pendingTrain)
        self.epoch.train.append(np.mean(nonEmptyList(pendingTrain)))
        self.valid.extend(pendingValid)
        self.epoch.valid.append(np.mean(nonEmptyList(pendingValid)))
        # start the next epoch with fresh buffers
        self._trainLosses = []
        self._validLosses = []
    def __repr__(self):
        head = super()._reprHead
        tail = super()._reprCan
        return "\n".join([
            f"{head}, use...",
            "- cb.train: for all training losses over all epochs and batches (#epochs * #batches)",
            "- cb.valid: for all validation losses over all epochs and batches (#epochs * #batches)",
            "- cb.plot(): to plot the 2 above",
            "- cb.epoch: for average losses of each epochs",
            f"{tail}",
        ])
@k1lib.patch(Callbacks, docs=Loss)
def withLoss(self):
    # Create a fresh Loss callback, append it to this collection and
    # hand back whatever `append` returns.
    lossCb = Loss()
    return self.append(lossCb)
@k1lib.patch(Callback.cls)
class Accuracy(Callback):
    """Records accuracies after each batch. Have to define an .accuracyF() function to use first.

    :param accuracyF: callable that takes the learner and returns the
        accuracy of the current batch

    Attributes:

    - `.train` / `.valid`: recorded accuracies. Both start as `[0]`, are
      plain lists during a run and are converted to numpy arrays when
      the run ends
    - `.plot()`: sliceable plot of both histories, in percent"""
    def __init__(self, accuracyF):
        super().__init__()
        self.order = 20
        self.accuracyF = accuracyF
        # histories start with a single 0 (presumably so they are never empty)
        self.train = [0]
        self.valid = [0]
    def startRun(self):
        # a previous run leaves numpy arrays behind; make them appendable again
        self.train = list(self.train)
        self.valid = list(self.valid)
    def endRun(self):
        self.train = np.array(self.train)
        self.valid = np.array(self.valid)
    def endLoss(self):
        history = self.train if self.model.training else self.valid
        history.append(self.accuracyF(self.learner))
    def plot(self):
        """Return a `k1lib.viz.SliceablePlot` of train and valid accuracy (in %).

        The slice applies to the training history; the window for the
        valid history is computed by `k1lib.Range.proportionalSlice`."""
        def plotF(_slice):
            plt.figure(figsize=(10, 3), dpi=100)
            step = _slice.step or 1
            tR, vR = k1lib.Range.proportionalSlice(len(self.train), len(self.valid), _slice)
            # The histories are plain lists until endRun(); `100 * list`
            # would repeat the list instead of scaling it, so coerce first.
            train = np.asarray(self.train)
            valid = np.asarray(self.valid)
            try:
                # `range_` / `slice_` are the same accessors the loss plot
                # in this module uses on proportionalSlice's results.
                plt.subplot(1, 2, 1)
                plt.plot(tR.range_[::step], 100*train[tR.slice_][::step])
                plt.title("Train accuracy")
                plt.subplot(1, 2, 2)
                plt.plot(vR.range_[::step], 100*valid[vR.slice_][::step])
                plt.title("Valid accuracy")
            except Exception:
                # best-effort plotting, but let KeyboardInterrupt / SystemExit through
                pass
        return k1lib.viz.SliceablePlot(plotF)
    def __repr__(self):
        head = super()._reprHead
        missing = " (.accuracyF not defined yet)" if self.accuracyF is None else ""
        return f"""{head}{missing}, use...
- a.train: for train accuracies over all batches
- a.valid: for valid accuracies over all batches
- a.plot(): to plot the 2 above
{super()._reprCan}"""
@k1lib.patch(Callbacks, docs=Accuracy)
def withAccuracy(self, accuracyF:Callable[["Learner"], float]):
    # Wrap `accuracyF` in an Accuracy callback, append it to this
    # collection and hand back whatever `append` returns.
    accuracyCb = Accuracy(accuracyF)
    return self.append(accuracyCb)