# AUTOGENERATED FILE! PLEASE DON'T EDIT
from k1lib.callbacks import Callback
import k1lib, torch, math, numpy as np; from functools import partial
import matplotlib.pyplot as plt
def allocated() -> int:
    """Return the bytes currently occupied by tensors in PyTorch's CUDA caching allocator.

    Thin wrapper over ``torch.cuda.memory_allocated()`` for the current device.
    """
    currentBytes = torch.cuda.memory_allocated()
    return currentBytes
class MemoryData:
    """Per-module memory record that hooks one module and accumulates CUDA memory deltas.

    On every forward-pre ("fp"), forward ("f") and backward ("b") hook call, the bytes
    allocated relative to ``startMemory`` are added to ``values[kind]``. Forward and
    backward readings are also appended to the owning profiler's ``linear`` /
    ``linState`` / ``linSignature`` timelines.
    """
    def __init__(self, mProfiler, mS:k1lib.selector.ModuleSelector):
        self.mProfiler = mProfiler; self.mS = mS
        self.handles = k1lib.Object.fromDict({"fp":0,"f":0,"b":0}) # hook handles, filled by hook()
        self.values = k1lib.Object.fromDict({"fp":0,"f":0,"b":0}) # accumulated byte deltas per hook kind
        # Hooks only fire during a later pass, so startMemory is set before any reading is taken.
        self.hook(); self.startMemory = allocated()
    def hook(self):
        """Register forward-pre, forward and backward hooks on ``mS.nnModule``."""
        def hk(v, m, i, o=None):
            # v: hook kind ("fp", "f" or "b"); m/i/o are the arguments torch passes to the hook
            # (forward-pre hooks pass no output, hence the default).
            self.values[v] += (value := allocated() - self.startMemory)
            if v == "f" or v == "b":
                self.mProfiler.linear.append(value)
                self.mProfiler.linState.append(0) # not highlighted until the profiler updates it
                self.mProfiler.linSignature.append(self.mS.signature)
        self.handles.fp = self.mS.nnModule.register_forward_pre_hook(partial(hk, "fp"))
        self.handles.f = self.mS.nnModule.register_forward_hook(partial(hk, "f"))
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch in favour of
        # register_full_backward_hook; switching changes when the hook fires, so it is kept as-is.
        self.handles.b = self.mS.nnModule.register_backward_hook(partial(hk, "b"))
    def unhook(self):
        """Remove all three hooks registered by ``hook()``."""
        self.handles.fp.remove(); self.handles.f.remove(); self.handles.b.remove()
    def __getstate__(self):
        """Return pickle state without the ``mS`` / ``mProfiler`` back-references.

        ``pop`` (rather than ``del``) lets an already-unpickled instance, which no
        longer has these keys, be pickled again instead of raising ``KeyError``.
        """
        answer = dict(self.__dict__)
        answer.pop("mS", None); answer.pop("mProfiler", None); return answer
    def __setstate__(self, state): self.__dict__.update(dict(state))
    def __str__(self):
        """One-line summary: backward total, forward-minus-pre delta, forward-pre and forward totals."""
        fp = f"fp({k1lib.format.size(self.values.fp)})".ljust(14)
        f = f"f({k1lib.format.size(self.values.f)})".ljust(13)
        b = f"b({k1lib.format.size(self.values.b)})".ljust(13)
        delta = f"delta({k1lib.format.size(self.values.f - self.values.fp)})".ljust(17)
        return f"{b} {delta} {fp} {f}"
class MemoryProfiler(Callback):
    """Callback that profiles CUDA memory per module over a single one-batch run.

    ``run()`` drives the learner for 1 epoch of 1 batch with every module in
    ``self.selector`` hooked by a ``MemoryData``. The results are shown by
    ``__repr__`` (a plot plus the annotated module tree) and by ``css()``.
    """
    def startRun(self):
        """Attach a ``MemoryData`` to every selected module and reset the timelines."""
        if self.selector == self.learner.selector: # if no selectors found
            # Work on a copy with props cleared, so highlighting uses a clean selector.
            self.selector = self.learner.selector.copy().clearProps()
        for m in self.selector.modules(): m.data = MemoryData(self, m)
        # Render each module's MemoryData, in red when it is highlighted via css().
        self.selector.displayF = lambda m: (k1lib.format.red if m.selected("_memProf_") else k1lib.format.identity)(m.data)
        self.linear: list = [] # bytes of each mS's forward/backward pass, in call order
        self.linState: list = [] # highlight flag per entry of linear, used in plot
        self.linSignature: list = [] # signature of the module that produced each entry
    def startStep(self):
        # NOTE(review): presumably returning True tells the learner to cancel the rest of
        # the step; confirm against k1lib's Callback semantics.
        return True
    def endRun(self):
        """Freeze the recorded timelines into numpy arrays and refresh highlight state."""
        self.linear = np.array(self.linear)
        self.linState = np.array(self.linState); self._updateLinear()
    def run(self):
        """Profile 1 epoch of 1 batch on CUDA, then detach every hook even if the run fails."""
        try:
            with self.cbs.context(), self.cbs.suspendEvaluation():
                self.cbs.withCuda(); self.learner.run(1, 1)
        finally:
            self._unhookAll()
    def _unhookAll(self):
        """Remove the hooks of every module that received a MemoryData in startRun."""
        for m in self.selector.modules():
            data = getattr(m, "data", None)
            # On a failed run startRun may not have been reached; skip such modules.
            if isinstance(data, MemoryData): data.unhook()
    def _updateLinear(self):
        """Copy each module's "_memProf_" selection onto every timeline entry it produced."""
        def applyF(m):
            sig = m.signature
            for i, entrySig in enumerate(self.linSignature):
                if entrySig == sig:
                    self.linState[i] = m.selected("_memProf_")
        self.selector.apply(applyF)
    def css(self, css:str):
        """Highlight modules matching ``css``, print the report, then clear the highlight."""
        try:
            self.selector.parse(k1lib.selector.filter(css, "_memProf_"))
            self._updateLinear(); print(self.__repr__())
        finally:
            # Always reset, so a failed render does not leave modules highlighted.
            self.selector.clearProps(); self._updateLinear()
    def __repr__(self):
        """Show a plot of memory over time (highlighted segments from linState) and return a text report.

        Note: this creates and shows a matplotlib figure as a side effect.
        """
        plt.figure(dpi=120); plt.grid(True)
        l = np.asarray(self.linear); s = self.linState; plt.xlabel("Time")
        # Scale into 1000-based units chosen from the largest magnitude. Deltas can be
        # negative (memory freed) or all zero, and the timeline can be empty; math.log10
        # and max() would raise ValueError on those, so fall back to the base unit.
        peak = float(np.abs(l).max()) if l.size else 0.0
        idx = math.floor(math.log10(peak)/3) if peak >= 1 else 0
        l = l/1000**idx
        plt.ylabel(k1lib.format.sizes[idx])
        k1lib.viz.plotSegments(range(len(l)), l, s); plt.show()
        params = k1lib.format.item(sum([p.numel() for p in self.model.parameters()]))
        return f"""MemoryProfiler (params: {params}):
{k1lib.tab(self.selector.__repr__(intro=False))}

Can...
- mp.css("..."): highlights a particular part of the network
- mp.selector: to get internal k1lib.ModuleSelector object"""