#!/usr/bin/env python

#  Copyright (c) 2016-2019, Broad Institute, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
#  * Neither the name Broad Institute, Inc. nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import sys
import argparse
import csv
import h5py
import random
import re
import string
from itertools import combinations
import kmertools
import kmerizer
import numpy as np
import pydot
from functools import reduce

# Number of auto-named nodes created so far; gives unnamed nodes unique names.
treeNodeCount = 0

class KmerTreeNode:
    """One node of a kmer tree, holding its own kmers and its child nodes.

    Attributes set by the constructor:
      name      -- node name; auto-generated as "node-<n>" when not supplied
      kmers     -- array of kmers held by this node
      counts    -- optional array parallel to ``kmers`` (see ``countTree``)
      children  -- list of child KmerTreeNode objects (empty for a leaf)
      distance  -- ``1 - similarity`` when a similarity was given, otherwise 1
    Attributes filled in by the ``computeScore*`` methods:
      score, common, denominator
    """
    def __init__(self, kmers, name=None, children=None, counts=None, similarity=None):
        """Create a node.

        :param kmers: array of kmers for this node
        :param name: node name; when falsy a unique "node-<n>" name is generated
        :param children: list of child nodes (stored as given, not copied);
                         None means "no children"
        :param counts: optional array of per-kmer counts
        :param similarity: similarity of the children; stored as distance = 1 - similarity
        """
        global treeNodeCount
        if name:
            self.name = name
        else:
            treeNodeCount += 1
            self.name = "node-" + str(treeNodeCount)
        self.kmers = kmers
        self.counts = counts
        # A fresh list per node: a shared default list would be visible to (and
        # mutated through) every node created without explicit children.
        self.children = [] if children is None else children
        self.score = None
        self.common = None
        self.denominator = None
        # A similarity of None or 0 both give the maximum distance of 1.
        self.distance = 1 - similarity if similarity else 1

    def isLeaf(self):
        """Return True when this node has no children."""
        return not self.children

    def label(self):
        """Return a display label: "[<#kmers with count 1>/]<#kmers>\\n<name>"."""
        s = ""
        if self.counts is not None:
            s += str((self.counts == 1).sum()) + "/"
        s += str(self.kmers.size)
        s += "\n" + self.name
        return s

    def nwk(self):
        """Return the subtree in Newick-style notation (without trailing ';').

        Leaves are written as their name; every child of an internal node is
        written with this node's ``distance`` as its branch length.
        """
        if self.isLeaf():
            return self.name
        return "(" + ",".join("%s:%g" % (c.nwk(), self.distance) for c in self.children) + ")"

    def saveHdf5(self, h5node, compress = True):
        """Recursively write this subtree into the HDF5 group ``h5node``.

        :param h5node: h5py group (or file) receiving attributes and datasets
        :param compress: True selects "gzip"; any other value is passed to
                         h5py as the ``compression`` argument unchanged
        """
        if compress is True:
            compress = "gzip"
        # np.bytes_ is the NumPy-2-compatible spelling of the removed np.string_.
        h5node.attrs["type"] = np.bytes_("KmerTreeNode")
        h5node.attrs["name"] = np.bytes_(self.name)
        if self.kmers is not None:
            h5node.create_dataset("kmers", data=self.kmers, compression=compress)
        if self.counts is not None:
            h5node.create_dataset("counts", data=self.counts, compression=compress)
        for child in self.children:
            childGroup = h5node.create_group(child.name)
            child.saveHdf5(childGroup, compress=compress)

    def saveSampleHdf5(self, h5node, compress = True):
        """Recursively write sample-scoring results into the HDF5 group ``h5node``.

        Writes ``score`` and ``denominator`` attributes (0 when unset) and a
        "common" dataset when ``self.common`` is set.

        :param h5node: h5py group (or file) receiving the results
        :param compress: True selects "gzip"; otherwise passed through to h5py
        """
        if compress is True:
            compress = "gzip"
        # np.bytes_ is the NumPy-2-compatible spelling of the removed np.string_.
        h5node.attrs["type"] = np.bytes_("SampleScoreNode")
        h5node.attrs["name"] = np.bytes_(self.name)

        h5node.attrs["score"] = self.score if self.score is not None else 0
        h5node.attrs["denominator"] = self.denominator if self.denominator is not None else 0

        if self.common is not None:
            h5node.create_dataset("common", data=self.common, compression=compress)

        for child in self.children:
            childGroup = h5node.create_group(child.name)
            child.saveSampleHdf5(childGroup, compress=compress)

    def saveText(self, outfile, k=None):
        """Recursively write this subtree as text to ``outfile``.

        Each node produces a tab-separated header line ("%", node name, child
        names) followed by one line per kmer (kmer string, plus a tab and the
        count when counts are present).

        :param outfile: writable text file object
        :param k: kmer size used to render kmers; defaults to the module-level
                  ``args.K`` command-line setting
        """
        # int(): the command-line value may arrive as a string.
        k = int(args.K if k is None else k)
        header = ["%", self.name] + [child.name for child in self.children]
        print("\t".join(header), file=outfile)

        if self.counts is None:
            for kmer in self.kmers:
                print(kmertools.kmerString(k, int(kmer)), file=outfile)
        else:
            for kmer, count in zip(self.kmers, self.counts):
                print("%s\t%d" % (kmertools.kmerString(k, int(kmer)), count), file=outfile)
        for child in self.children:
            child.saveText(outfile, k)

    def computeScoreWeighted(self, kmers):
        """Score this node vs a set of kmers, weighting each kmer by 1/count^2.

        Sets ``denominator`` to the sum of weights over all of this node's
        kmers; when that is positive, sets ``score`` to the weighted fraction
        for the counts returned by ``kmerizer.intersect_counts`` and ``common``
        to ``kmerizer.intersect(self.kmers, kmers)``. Otherwise ``score`` is 0.
        Requires ``self.counts`` to be set.
        """
        weights = 1.0 / (self.counts * self.counts)
        self.denominator = weights.sum()
        if self.denominator > 0:
            # presumably the counts of this node's kmers that also occur in `kmers`
            subCounts = kmerizer.intersect_counts(self.kmers, self.counts, kmers)
            self.score = (1.0 / (subCounts * subCounts)).sum() / self.denominator
            self.common = kmerizer.intersect(self.kmers, kmers)
        else:
            self.score = 0

    def computeScoreUnique(self, kmers):
        """Score this node vs a set of kmers using only kmers whose count is 1.

        ``score`` is the fraction of those kmers found in ``kmers`` (0 when the
        node has none); ``common`` and ``denominator`` are only updated when
        the node has at least one such kmer.
        """
        uniqueKmers = self.kmers[self.counts == 1]
        if uniqueKmers.size > 0:
            self.common = kmerizer.intersect(uniqueKmers, kmers)
            self.denominator = float(uniqueKmers.size)
            self.score = float(self.common.size) / self.denominator
        else:
            self.score = 0

    def computeScore(self, kmers):
        """Score this node vs a set of kmers using all of this node's kmers.

        ``score`` is the fraction of this node's kmers found in ``kmers`` (0
        when the node has none); ``common`` and ``denominator`` are only
        updated when the node has at least one kmer.
        """
        if self.kmers.size > 0:
            self.common = kmerizer.intersect(self.kmers, kmers)
            self.denominator = float(self.kmers.size)
            self.score = float(self.common.size) / self.denominator
        else:
            self.score = 0



def getScore(nodePair, scores, simscore="jaccard"):
    """Return the similarity score for a pair of nodes, memoised in ``scores``.

    :param nodePair: 2-tuple of nodes (each exposing ``kmers``); used as the cache key
    :param scores: dict mapping node pairs to scores; filled in on a cache miss
    :param simscore: scoring method name handed to ``kmertools.similarityScore``
    :return: the cached or newly computed score
    """
    if nodePair not in scores:
        # Cache miss: compute from the two nodes' kmers and remember the result.
        left, right = nodePair
        scores[nodePair] = kmertools.similarityScore(left.kmers, right.kmers, simscore)
    return scores[nodePair]

def similarityScores(tree, scores, simscore="jaccard"):
    """Score every pair of nodes in ``tree`` and return them best-first.

    :param tree: sequence of nodes to pair up
    :param scores: dict used by ``getScore`` to cache pair scores
    :param simscore: scoring method name passed through to ``getScore``
    :return: list of ``(pair, score)`` tuples sorted by descending score;
             pairs with equal scores keep their ``combinations`` order
    """
    # calculate the intersection stats of kmers for all tree node pairs
    pairScores = [(pair, getScore(pair, scores, simscore)) for pair in combinations(tree, 2)]
    # Python 3 removed cmp-style sorting; sort on the score with a key instead.
    pairScores.sort(key=lambda pairScore: pairScore[1], reverse=True)
    return pairScores

def newTreeNode(children, similarity=None):
    """
    Make a new node, propagating up the kmers in common among the children, leaving the children with
    their kmers not in common.
    :param children: non-empty list of child nodes (any number, not just two)
    :param similarity: similarity of the children, forwarded to KmerTreeNode
    :return: the new node
    """
    # Fold over the kmer arrays themselves: the accumulator is an array after
    # the first step, so reading `.kmers` from it (as a fold over nodes would)
    # fails for three or more children and returns a node for a single child.
    intersection = reduce(kmerizer.intersect, (child.kmers for child in children))
    node = KmerTreeNode(intersection, children=children, similarity=similarity)
    for child in children:
        child.kmers = np.setdiff1d(child.kmers, intersection, assume_unique=True)
    return node


def buildTree(tree, scores, simscore="jaccard"):
    """Perform one merge step: join the most similar pair of nodes under a new parent.

    :param tree: list of current top-level nodes (needs at least two)
    :param scores: pair-score cache shared across calls
    :param simscore: similarity scoring method name
    :return: a new list holding the untouched nodes followed by the merged node
    """
    ranked = similarityScores(tree, scores, simscore)
    bestPair, bestScore = ranked[0]
    if bestScore == 0:
        print("Warning! Zero score detected. No shared k-mers left.", file=sys.stderr)
    merged = newTreeNode(list(bestPair), similarity=bestScore)
    # Keep every node that was not merged, then add the new parent at the end.
    remaining = [candidate for candidate in tree if candidate not in bestPair]
    remaining.append(merged)
    # Progress output: the caller completes the line with the node count.
    print("%.5f" % bestScore, end=' ')
    return remaining


def graphTree(graph, tree, parentNode):
    """Recursively add ``tree`` to a pydot graph as filled, labelled nodes.

    Nodes with a score get the percentage prepended to their label and a fill
    colour index from 1 to 9 (blues9 scheme) proportional to the score; the
    font colour index switches from 9 to 1 once the fill index reaches 7.

    :param graph: pydot graph receiving nodes and edges
    :param tree: node to draw (with its descendants)
    :param parentNode: pydot node of the parent, or a falsy value for the root
    """
    hasChildren = not tree.isLeaf()
    caption = tree.label()
    fillIndex = 1
    fontIndex = 9
    if tree.score is not None:
        fraction = tree.score
        caption = "%.0f%% %s" % (100.0 * fraction, caption)
        fillIndex = int(round(fraction * 8.0)) + 1
        if fillIndex >= 7:
            fontIndex = 1
    vertex = pydot.Node(tree.name, label=caption, style="filled", colorscheme='blues9',
                        fillcolor=fillIndex, fontcolor=fontIndex)
    graph.add_node(vertex)
    if parentNode:
        graph.add_edge(pydot.Edge(parentNode, vertex))
    if hasChildren:
        for descendant in tree.children:
            graphTree(graph, descendant, vertex)



def gatherKmers(node):
    """Return the kmer arrays of a subtree as a flat list.

    Descendants come first (children in order, each subtree complete), and the
    node's own kmers come last.
    """
    collected = []
    if not node.isLeaf():
        for child in node.children:
            collected.extend(gatherKmers(child))
    collected.append(node.kmers)
    return collected

def dedupeTree(tree):
    """Keep only kmers with count 1 in every node, then prune emptied internal nodes.

    After the call every node's ``counts`` is None. Internal nodes left with
    no kmers are removed and their children are adopted by the nearest
    surviving ancestor (appended after that ancestor's surviving children).
    The root itself is never removed. Child lists are modified in place.

    :param tree: root node of the tree to dedupe
    """
    def dedupe(node):
        # A node without counts has nothing to filter on (for example a tree
        # that was already deduped). Indexing with `None == 1` would silently
        # discard all of its kmers, so leave such a node untouched.
        if node.counts is not None:
            node.kmers = node.kmers[node.counts == 1]
            node.counts = None
        for child in node.children:
            dedupe(child)

    def simplifyTree(node):
        # first simplify tree below us
        for child in node.children:
            simplifyTree(child)
        # adopt grandchildren where our child is dead
        survivors = []
        adopted = []
        for child in node.children:
            if child.kmers.size == 0 and not child.isLeaf():
                adopted.extend(child.children)
            else:
                survivors.append(child)
        node.children[:] = survivors + adopted

    dedupe(tree)
    simplifyTree(tree)


def countTree(tree):
    """Attach to every node a ``counts`` array parallel to its ``kmers``.

    The counts come from ``np.unique(..., return_counts=True)`` over the
    concatenation of all kmer arrays in the tree, i.e. how many times each
    kmer value occurs across all nodes' arrays.

    :param tree: root node; ``kmers`` and ``counts`` of every node are replaced
    """
    # Large marker given as the count of each of a node's own kmers before
    # merging, so that the node's rows can be recognised in the merged result
    # (count >= marker) and the marker subtracted again.
    # NOTE(review): relies on kmerizer.merge_counts summing the counts of
    # kmers present in both inputs -- confirm against kmerizer.
    marker = 1000000000

    def assignCounts(node, distinctKmers, occurrences):
        tagged = np.full(node.kmers.size, marker, dtype=np.int64)
        merged = kmerizer.merge_counts(node.kmers, tagged, distinctKmers, occurrences)
        mergedKmers, mergedCounts = merged[0], merged[1]
        ownRows = mergedCounts >= marker
        ownCounts = mergedCounts[ownRows]
        ownCounts -= marker
        node.kmers = mergedKmers[ownRows]
        node.counts = ownCounts
        for child in node.children:
            assignCounts(child, distinctKmers, occurrences)

    print('Setting node kmer counts')
    everyKmer = np.concatenate(gatherKmers(tree))
    distinct = np.unique(everyKmer, return_counts=True)
    assignCounts(tree, distinct[0], distinct[1])


def treeFromNWK(nwkFile, nodes):
    """Build a tree of existing leaf nodes following the topology in a NWK file.

    The file is expected to contain nested parentheses with comma-separated
    node names and an optional ';' (for example ``((a,b),c);``). Branch lengths
    are not interpreted: anything between the delimiters is taken as the name.

    :param nwkFile: path of the NWK file
    :param nodes: leaf nodes, looked up by their ``name``
    :return: root of the tree (a leaf from ``nodes`` if the file names just one)
    :raises KeyError: when the file names a node that is not in ``nodes``
    :raises ValueError, SyntaxError: when the file is not a plain nesting of names
    """
    import ast

    with open(nwkFile, 'r') as nwk:
        text = nwk.read().strip().replace(";", "")
    # Quote every name so the text reads as nested tuples of strings.
    quoted = re.sub(r'([^();,]+)', '"\\1"', text)
    # The file is external input, so it is parsed with ast.literal_eval (which
    # only accepts literals) rather than eval, which would execute any
    # expression that the quoting above fails to neutralise.
    treeStructure = ast.literal_eval(quoted)

    nodesByName = {node.name: node for node in nodes}

    def makeTree(node):
        if isinstance(node, str):
            return nodesByName[node]
        return newTreeNode([makeTree(child) for child in node])

    return makeTree(treeStructure)

def _attrText(value):
    """Return an HDF5 attribute value as ``str``, decoding byte strings."""
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return str(value)


def treeFromHdf5(h5node):
    """Recursively rebuild a KmerTreeNode tree from an HDF5 group.

    The group must carry a "type" attribute equal to "KmerTreeNode" and a
    "name" attribute. Its "kmers" and "counts" members become the node's
    arrays; every other member is treated as a child group.

    :param h5node: h5py group (or file) to read
    :return: the reconstructed node
    :raises ValueError: when the group is not marked as a KmerTreeNode
    """
    # The attributes are written as byte strings, which never compare equal to
    # a str on Python 3, so decode before comparing / using them as names.
    if _attrText(h5node.attrs["type"]) != "KmerTreeNode":
        raise ValueError("HDF5 group not a KmerTreeNode")
    name = _attrText(h5node.attrs["name"])
    children = []
    kmers = None
    counts = None
    for key, value in h5node.items():
        if key == "kmers":
            kmers = np.array(value)
        elif key == "counts":
            counts = np.array(value)
        else:
            children.append(value)
    node = KmerTreeNode(kmers, name = name, counts = counts)
    node.children = [treeFromHdf5(child) for child in children]
    return node


def loadLeafNode(fileName, fingerprint):
    """Load one genome file as a leaf node.

    :param fileName: path of the kmerized genome file; the node name is derived
                     from it with ``kmertools.nameFromPath``
    :param fingerprint: when true load with ``kmertools.loadFingerprint``,
                        otherwise with ``kmertools.loadKmers``
    :return: a childless KmerTreeNode holding the loaded kmers
    """
    name = kmertools.nameFromPath(fileName)
    # Use this function's own parameters: the caller's comprehension variable
    # is not visible here on Python 3, and the module-level `args` should not
    # override the `fingerprint` argument.
    loader = kmertools.loadFingerprint if fingerprint else kmertools.loadKmers
    return KmerTreeNode(loader(fileName), name=name)

def scoreSample(tree, sampleKmers, scoring):
    """Score every node of a tree against a sample's kmers.

    Nodes are visited parent first, then children in order.

    :param tree: root node of the tree to score
    :param sampleKmers: kmers of the sample, handed to the node scoring method
    :param scoring: "unique" -> computeScoreUnique, "weighted" ->
                    computeScoreWeighted, anything else -> computeScore
    """
    pending = [tree]
    while pending:
        node = pending.pop()
        if scoring == "unique":
            scorer = node.computeScoreUnique
        elif scoring == "weighted":
            scorer = node.computeScoreWeighted
        else:
            scorer = node.computeScore
        scorer(sampleKmers)
        # Push children reversed so the leftmost child is popped (visited) next.
        pending.extend(reversed(node.children))
    
    

###
### Main
###

# Command-line interface. `args` stays a module-level name because
# KmerTreeNode.saveText reads `args.K`.
parser = argparse.ArgumentParser()
parser.add_argument("--verbose", help="increase output verbosity",
                    action="store_true")
parser.add_argument("--fingerprint", help="use minhash fingerprint instead of full kmer set to build graph",
                    action="store_true")
parser.add_argument("--dedupe", action="store_true", help="Remove duplicate kmers from tree")
parser.add_argument("--nwk", help="File to receive NWK format tree")
# type=int: without it a value given on the command line arrives as a string,
# while the default is an int.
parser.add_argument("-k", "--K", type=int, help="Kmer size (default 23)", default=23)
parser.add_argument("--input", help="HDF5 tree input file")
parser.add_argument("--output", help="HDF5 tree output file")
parser.add_argument("--text", help="Text output file")
parser.add_argument("--dot", help="File to receive graphviz dot-style tree graph")
#parser.add_argument("--sample", help="test coverage of tree vs kmers in this sample hdf5")
#parser.add_argument("--sampleout", help="HDF5 sample scoring output file")
#parser.add_argument("--samplescore", help="sample matching scoring method", choices=["all", "unique", "weighted"], default="all")
parser.add_argument("--simscore", help="similarity scoring method", choices=["jaccard", "minsize", "maxsize"], default="jaccard")
parser.add_argument('genomes', nargs='*',
                    help='kmerized genome hdf5 or npz files')
args = parser.parse_args()

# Obtain the tree: either load a previously saved one, or build it by
# repeatedly merging the most similar pair of nodes until one root remains.
if args.input:
    print('Generating kmer tree from', args.input)
    #tree = treeFromNWK(args.input, nodes)
    with h5py.File(args.input, 'r') as h5:
        tree = treeFromHdf5(h5)
else:
    if not args.genomes:
        # Nothing to build from: report a usage error instead of failing
        # later with an IndexError on the empty node list.
        parser.error("no input: give --input or at least one genome file")
    scores = {}
    nodes = [loadLeafNode(arg, args.fingerprint) for arg in args.genomes]
    nleaves = len(nodes)
    print('Loaded', nleaves, 'genomes')
    while len(nodes) > 1:
        nodes = buildTree(nodes, scores, args.simscore)
        # buildTree printed the merge score; finish the line with the node count.
        print(len(nodes))
        sys.stdout.flush()
    tree = nodes[0]
    countTree(tree)

if args.dedupe:
    print('Deduping tree')
    dedupeTree(tree)

# moved to treepath
# if args.sample:
#     print 'Scoring sample'
#     sampleKmers = kmertools.loadKmers(args.sample)
#     scoreSample(tree, sampleKmers, args.samplescore)
#     if args.sampleout:
#         print 'Output sample scores to', args.sampleout
#         with h5py.File(args.sampleout, 'w') as h5:
#             tree.saveSampleHdf5(h5, compress = True)

if args.output:
    fileName = args.output
    if not fileName.endswith(".hdf5"):
        fileName += ".hdf5"
    print('Saving tree to', fileName)
    with h5py.File(fileName, 'w') as h5:
        tree.saveHdf5(h5, compress = True)

if args.text:
    with open(args.text, 'w') as outtext:
        tree.saveText(outtext)

if args.dot:
    graph = pydot.Dot(graph_type='graph', rankdir='LR')
    graphTree(graph, tree, None)
    # NOTE(review): this writes a PNG image to the --dot path, although the
    # option's help text describes a dot-style graph -- confirm which is intended.
    graph.write_png(args.dot)

if args.nwk:
    with open(args.nwk, 'w') as nwk:
        print(tree.nwk() + ';', file=nwk)
