#!/usr/bin/env python
"""Path scoring tool"""
#  Copyright (c) 2016-2019, Broad Institute, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
#  * Neither the name Broad Institute, Inc. nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import h5py
import kmertools
import argparse
import kmerizer
import numpy as np
from itertools import chain
import random
import pydot
import copy

# globals
# Print per-path details while scoring paths (set by -v/--verbose).
verbose = False
# Node scoring method handed to scoreSample: 'adjusted', 'weighted', 'unique' or 'all'.
method = 'adjusted'
# In 'adjusted' scoring, nodes with fewer kmers than this get a randomization test.
calc_p = 1000
# In 'adjusted' scoring, nodes with at least this many unique kmers use the unique score.
min_unique = 100
# Fraction of nodes in a path that must reach min_score for the path to be kept.
min_nodes = 0.7
# Minimum node score for a node to count as supported.
min_score = 0.1
# When True, only complete paths are reported.
complete = False
# Stop extending a path when this many trailing nodes are below min_score (None: disabled).
max_gap = None
# Number of score-shuffled trees used to estimate a significance threshold (0: disabled).
calc_random = 0


class RandomTreeNode:
    """Lightweight copy of a scored tree whose node scores can be shuffled.

    The node constructed without a ``root`` becomes the root of the copy and
    owns ``all_scores``, the pool of every score found in the source tree
    (collected in pre-order). ``randomize`` shuffles that pool and deals the
    scores back out to the nodes, again in pre-order.
    """
    def __init__(self, tree=None, root=None):
        self.score = None
        self.name = None
        self.children = []

        # No root supplied: this node is the root and holds the score pool.
        if not root:
            root = self
            root.all_scores = []
        if tree:
            self.name = tree.name
            root.all_scores.append(tree.score)
            # Children are built in order, so scores are pooled in pre-order.
            self.children.extend(RandomTreeNode(sub, root) for sub in tree.children)

    def isLeaf(self):
        """Return True when this node has no children."""
        return not self.children

    def randomize(self, root=None):
        """Assign shuffled scores to this node and all of its descendants.

        Call with no argument on the root; recursive calls pass the root so
        every node draws from the same shuffled pool via the cursor ``root.i``.
        """
        top = root if root else self
        if not root:
            # Top-level call: reset the cursor and reshuffle the pool once.
            top.i = 0
            random.shuffle(top.all_scores)

        self.score = top.all_scores[top.i]
        top.i += 1
        for sub in self.children:
            sub.randomize(top)
            


class KmerTreeNode:
    """Node of a kmer tree: holds the node's kmers and its score against a sample.

    Attributes:
        name: node name; auto-generated as ``node-<n>`` when not supplied.
        kmers: array of kmers assigned to this node.
        counts: per-kmer counts aligned with ``kmers`` (a count of 1 marks a
            unique kmer), or None when every kmer is treated as unique.
        parent: parent node, or None for the root.
        children: list of child nodes (empty for a leaf).
        score / common / denominator: results of the last scoring call.
    """
    def __init__(self, kmers, name=None, children=None, counts=None, parent=None, score=None, common=None, denominator=None):
        global treeNodeCount
        if name:
            self.name = name
        else:
            # The counter is created lazily; nothing else in the module defines it.
            treeNodeCount = globals().get("treeNodeCount", 0) + 1
            self.name = "node-" + str(treeNodeCount)
        self.kmers = kmers
        self.counts = counts
        self.parent = parent
        # A fresh list per node: a shared default list would link unrelated nodes.
        self.children = [] if children is None else children
        self.score = score
        self.common = common
        self.denominator = denominator

    def isLeaf(self):
        """Return True when this node has no children."""
        return not self.children

    def label(self):
        """Return a display label: ``<unique>/<total>`` kmers (or just total) plus the name."""
        s = ""
        if self.counts is not None:
            s += str((self.counts == 1).sum()) + "/"
        s += str(self.kmers.size)
        s += "\n" + self.name
        return s

    def _uniqueKmers(self):
        """Return the kmers with count 1, or all kmers when no counts are stored."""
        if self.counts is not None:
            return self.kmers[self.counts == 1]
        return self.kmers

    def uniqueKmersPerNode(self):
        """Return the number of unique kmers of this node and every descendant (pre-order)."""
        sizes = [self._uniqueKmers().size]
        for child in self.children:
            sizes.extend(child.uniqueKmersPerNode())
        return sizes

    def saveSampleHdf5(self, h5node, compress=True):
        """Recursively write this scored subtree into the HDF5 group ``h5node``.

        Stores ``type``, ``name``, ``score`` and ``denominator`` as attributes
        (missing score/denominator are written as 0), the ``common`` array as
        a dataset when present, and one sub-group per child.
        """
        if compress is True:
            compress = "gzip"
        # np.bytes_ is the same type as the np.string_ alias removed in NumPy 2.0.
        h5node.attrs["type"] = np.bytes_("SampleScoreNode")
        h5node.attrs["name"] = np.bytes_(self.name)

        h5node.attrs["score"] = self.score if self.score is not None else 0
        h5node.attrs["denominator"] = self.denominator if self.denominator is not None else 0

        if self.common is not None:
            h5node.create_dataset("common", data=self.common, compression=compress)

        for child in self.children:
            childGroup = h5node.create_group(child.name)
            child.saveSampleHdf5(childGroup, compress=compress)

    def computeScoreWeighted(self, kmers):
        """Score this node vs a set of kmers, weighting each kmer by 1/count^2.

        Falls back to the unique score when the node carries no counts.
        """
        if self.counts is None:
            print("Counts missing from node. Using unique score...")
            self.computeScoreUnique(kmers)
            # Without counts there is nothing to weight; the fallback is the result.
            return
        weights = 1.0 / (self.counts * self.counts)
        self.denominator = weights.sum()
        if self.denominator > 0:
            # presumably the counts of the node kmers also present in `kmers` -- confirm in kmerizer
            subCounts = kmerizer.intersect_counts(self.kmers, self.counts, kmers)
            self.score = (1.0 / (subCounts * subCounts)).sum() / self.denominator
            self.common = kmerizer.intersect(self.kmers, kmers)
        else:
            self.score = 0

    def computeScoreWeighted2(self, kmers):
        """Return the score vs ``kmers`` using the stored denominator, without changing the node."""
        if self.denominator is None or self.denominator <= 0:
            return 0
        if self.counts is None:
            # unique tree
            return float(kmerizer.intersect(self.kmers, kmers).size) / self.denominator
        subCounts = kmerizer.intersect_counts(self.kmers, self.counts, kmers)
        return (1.0 / (subCounts * subCounts)).sum() / self.denominator

    def computeScoreUnique(self, kmers):
        """Score this node vs a set of kmers using only its unique kmers."""
        uniqueKmers = self._uniqueKmers()
        if uniqueKmers.size > 0:
            self.common = kmerizer.intersect(uniqueKmers, kmers)
            self.denominator = float(uniqueKmers.size)
            self.score = float(self.common.size) / self.denominator
        else:
            self.score = 0

    def computeScore(self, kmers):
        """Score this node vs a set of kmers using all of its kmers."""
        if self.kmers.size > 0:
            self.common = kmerizer.intersect(self.kmers, kmers)
            self.denominator = float(self.kmers.size)
            self.score = float(self.common.size) / self.denominator
        else:
            self.score = 0

    def computeScoreAdjusted(self, kmers, min_unique=100, calc_p=1000):
        """Score this node vs a set of kmers, choosing the method per node.

        Nodes with at least ``min_unique`` unique kmers (or only unique kmers)
        use the unique score. Other nodes use the weighted score; for nodes
        with fewer than ``calc_p`` kmers the score is zeroed unless it beats
        random kmer samples at p < 0.05.
        """
        uniqueKmers = self._uniqueKmers()
        if uniqueKmers.size >= min_unique or uniqueKmers.size == self.kmers.size:
            self.computeScoreUnique(kmers)
        else:
            self.computeScoreWeighted(kmers)
            if self.score > 0 and self.kmers.size < calc_p:
                randscores = self.computeRandomScore(kmers)
                # Empirical p-value: share of random scores at least as high as ours.
                p_value = np.count_nonzero(randscores >= self.score) / float(randscores.size)
                if p_value >= .05:
                    self.score = 0  # assign no score since p >= .05

    def computeRandomScore(self, kmers, n=100):  # n is number of iterations needs to be properly chosen
        """Return ``n`` weighted scores of this node against random subsets of ``kmers``.

        Each subset has as many kmers as ``kmerizer.count_common`` reports for
        this node and the sample, and keeps the original ordering of ``kmers``.
        """
        n_kmers = kmerizer.count_common(self.kmers, kmers)
        res = []
        for _ in range(n):
            # randomly sample, keep order
            sampled = sorted(random.sample(range(kmers.size), n_kmers))
            res.append(self.computeScoreWeighted2(kmers[sampled]))
        return np.asarray(res)

    def computenodecov(self, samplekmers, kmercounts):
        """Set ``self.cov`` to the mean sample count of this node's kmers (NaN if none occur)."""
        sect = np.in1d(samplekmers, self.kmers)
        sectcounts = kmercounts[sect]
        if len(sectcounts) > 0:
            self.cov = np.mean(sectcounts)
        else:
            self.cov = np.nan

def _hdf5_text(value):
    """Return an HDF5 attribute value as str, decoding it when it is bytes."""
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return value


def treeFromHdf5(h5node, parent=None):
    """Construct a KmerTreeNode tree from an HDF5 group.

    The group must carry a ``type`` attribute of "KmerTreeNode" or
    "SampleScoreNode"; otherwise a message is printed and None is returned.
    Datasets named "kmers", "counts" and "common" become node arrays; every
    other member is loaded recursively as a child. Children that fail to load
    are skipped.
    """
    # Attributes written as bytes come back as bytes under Python 3 and would
    # never compare equal to the str literals below.
    node_type = _hdf5_text(h5node.attrs["type"])
    if node_type not in ("KmerTreeNode", "SampleScoreNode"):
        print("HDF5 not a valid class", h5node.attrs["type"])
        return
    name = _hdf5_text(h5node.attrs["name"])
    score = h5node.attrs.get('score')
    denominator = h5node.attrs.get('denominator')
    child_groups = []
    kmers = None
    counts = None
    common = None

    for key, value in list(h5node.items()):
        if key == "kmers":
            kmers = np.array(value)
        elif key == "counts":
            counts = np.array(value)
        elif key == "common":
            common = np.array(value)
        else:
            child_groups.append(value)
    node = KmerTreeNode(kmers, name = name, counts = counts, parent = parent, score = score, common = common, denominator = denominator)
    loaded = (treeFromHdf5(group, node) for group in child_groups)
    # A None child would crash every later traversal of the tree.
    node.children = [child for child in loaded if child is not None]
    return node


def scoreSample(tree, sampleKmers, samplecounts, scoring='adjusted', min_unique=100, calc_p=1000):
    """Score every node of a tree against the sample kmers.

    ``scoring`` selects the node method: 'adjusted' (uses ``min_unique`` and
    ``calc_p``), 'unique', 'weighted', or anything else for the plain score.
    ``samplecounts`` is only passed along to the recursive calls.
    """
    single_arg_scorers = {"unique": "computeScoreUnique", "weighted": "computeScoreWeighted"}
    if scoring == 'adjusted':
        tree.computeScoreAdjusted(sampleKmers, min_unique, calc_p)
    else:
        # Unknown method names fall back to the plain score.
        scorer_name = single_arg_scorers.get(scoring, "computeScore")
        getattr(tree, scorer_name)(sampleKmers)

    for subtree in tree.children:
        scoreSample(subtree, sampleKmers, samplecounts, scoring, min_unique = min_unique, calc_p = calc_p)


def get_path_scores(tree, scores=None, leaves=False, min_score=0.1, max_gap=None):
    """Collect every root-to-node path below ``tree`` together with its node scores.

    Args:
        tree: scored node to start from; a missing (falsy) node yields no paths.
        scores: scores of the ancestors already visited (not modified).
        leaves: when True, only paths ending in a leaf are reported.
        min_score: score a node needs to count as supported.
        max_gap: stop descending once none of the last ``max_gap`` scores
            reaches ``min_score``; None disables the check.

    Returns:
        A list of ``[names, scores, node]`` entries, where ``names`` and
        ``scores`` run from ``tree`` down to ``node``.
    """
    if not tree:
        # No node means no paths; callers iterate the result as a list of entries.
        return []
    scores = list(scores) if scores else []

    name = tree.name
    scores.append(tree.score)
    paths = []
    within_gap = (
        max_gap is None
        or min_score <= 0
        or any(x >= min_score for x in scores[-max_gap:])
    )
    if within_gap:
        if not leaves or tree.isLeaf():
            paths.append([[name], scores[:], tree])
        for child in tree.children:
            for sub_names, sub_scores, end_node in get_path_scores(child, scores, leaves=leaves, min_score=min_score, max_gap=max_gap):
                paths.append([[name] + sub_names, sub_scores, end_node])

    return paths


def filter_paths(paths, min_nodes=0.7, min_score=0.1):
    """Drop paths in which too few nodes reach ``min_score``.

    A path is kept when the number of its scores that are >= ``min_score`` is
    at least ``ceil(min_nodes * path_length)``. When either threshold is <= 0
    the input list is returned unchanged.
    """
    if min_nodes <= 0 or min_score <= 0:
        return paths

    def has_enough_support(entry):
        node_scores = entry[1]
        required = np.ceil(min_nodes * len(node_scores))
        supported = sum(1 for value in node_scores if value >= min_score)
        return supported >= required

    return [entry for entry in paths if has_enough_support(entry)]


def rank_and_reduce_paths(paths, min_score=0.1):
    """Rank paths by node score and drop paths made redundant by better ones.

    Each path is a ``[names, scores, end_node]`` entry. The result lists
    paths ending in a leaf first, then orders by mean node score (descending).
    """

    def mean_score(entry):
        return float(sum(entry[1])) / len(entry[1])

    def supported_count(entry):
        return len([value for value in entry[1] if value >= min_score])

    # Pass 1: best mean score first (ties broken by number of supported nodes).
    by_mean = sorted(paths, key=lambda entry: (mean_score(entry), supported_count(entry)), reverse=True)

    # A path extending an already kept path is only kept when the extension
    # contains a supported node; paths extending nothing are always kept.
    reduced = []
    for candidate in by_mean:
        names = set(candidate[0])
        longest_first = sorted(reduced, key=lambda entry: (len(entry[0]), sum(entry[1])), reverse=True)
        base = next((kept for kept in longest_first if names.issuperset(kept[0])), None)
        if base is None:
            reduced.append(candidate)
            continue
        shared = len(names.intersection(base[0]))
        if min_score <= 0 or any([value >= min_score for value in candidate[1][shared:]]):
            reduced.append(candidate)

    # Pass 2: prefer complete paths, then longer paths, then higher score sums;
    # drop any path whose nodes all lie on a path that is already kept.
    final = []
    preferred = sorted(reduced, key=lambda entry: (entry[2].isLeaf(), len(entry[0]), sum(entry[1])), reverse=True)
    for candidate in preferred:
        names = set(candidate[0])
        if not any(names.issubset(kept[0]) for kept in final):
            final.append(candidate)

    # Complete paths first, then by mean node score.
    return sorted(final, key=lambda entry: (entry[2].isLeaf(), mean_score(entry)), reverse=True)


def get_strains(tree):
    """Return ``[name, score]`` for every leaf at or below ``tree``, left to right."""
    if tree.isLeaf():
        return [[tree.name, tree.score]]
    return [leaf for child in tree.children for leaf in get_strains(child)]


def scorePathsInSample(tree, min_nodes=0.7, min_score = 0.1, max_gap=None, sampleKmers =  None, samplecounts = None ):
    """Find, filter and rank the supported paths of a scored tree.

    When ``sampleKmers`` is given, a relative abundance (percent) is estimated
    for each ranked path from the mean kmer coverage of the nodes that belong
    to that path only; otherwise the abundance is reported as "NA".

    Returns:
        dict mapping ``"name|name|..."`` to a tuple
        ``((mean_score, score_sum), strains, abundance)`` where ``strains`` is
        the list of ``[leaf_name, leaf_score]`` below the path's last node,
        sorted by descending score.
    """
    scores = {}
    paths = filter_paths(get_path_scores(tree, [], False, min_score, max_gap), min_nodes, min_score)
    ranked = rank_and_reduce_paths(paths, min_score)

    if sampleKmers is not None and len(ranked) > 0:
        name_groups = [entry[0] for entry in ranked]
        # Names that occur more than once are shared between paths; collect
        # them in a single pass instead of counting each name in a flat list.
        seen = set()
        shared = set()
        for node_name in chain.from_iterable(name_groups):
            if node_name in seen:
                shared.add(node_name)
            else:
                seen.add(node_name)
        keep = [np.array([node_name not in shared for node_name in group]) for group in name_groups]
        pathcovs = [getpathcoverage(entry[2], sampleKmers, samplecounts) for entry in ranked]
        # nanmean skips nodes with no coverage overlap
        uniquenodecovs = np.array([np.nanmean(cov[mask]) for cov, mask in zip(pathcovs, keep)])
        # NOTE(review): one NaN coverage (e.g. a path with only shared nodes)
        # makes every abundance NaN -- confirm whether that is intended.
        relabundance = (uniquenodecovs / sum(uniquenodecovs)) * 100
        for entry, abundance in zip(ranked, relabundance):
            entry.append(abundance)
    else:
        for entry in ranked:
            entry.append("NA")

    for path in ranked:
        name = '|'.join(path[0])
        score = (float(sum(path[1]))/len(path[1]), sum(path[1]))
        scores[name] = (score, sorted(get_strains(path[2]), key = lambda strain: strain[1], reverse=True), path[3])

        if verbose:
            print(name, ':', score)
            print(' | '.join(['%.3f' % x for x in path[1]]))
            if path[2].isLeaf():
                print('Complete!')
            else:
                print(scores[name][1])

    return scores


def get_p_value_on_scores(tree, min_nodes=0.7, min_score=0.1, max_gap=None):
    """Get 99%ile of the sum of node scores of random paths (p = 0.01).

    The node scores of ``tree`` are shuffled ``calc_random`` times; for each
    shuffle the highest path score sum is recorded (0 when no path survives).
    Verbose output is suppressed while the random trees are scored.
    """
    global verbose
    random_tree = RandomTreeNode(tree, root=None)
    best_random_sums = []
    saved_verbose = verbose
    verbose = False  # this is just too much info...
    try:
        for _ in range(calc_random):
            random_tree.randomize(None)
            rand_scores = scorePathsInSample(random_tree, min_nodes, min_score, max_gap)
            if rand_scores:
                best_random_sums.append(max(entry[0][1] for entry in rand_scores.values()))
            else:
                best_random_sums.append(0)
    finally:
        # Restore the caller's setting even when scoring raises.
        verbose = saved_verbose
    return np.percentile(best_random_sums, 99)
    

def save_text(filein, pathScores, out, random_score=None, only_complete=False, overwrite=False):
    """Save path scores to a CSV text file.

    Args:
        filein: label written in the "Input" column.
        pathScores: dict of path name -> ``((mean, total), strains, abundance)``
            as produced by scorePathsInSample.
        out: output file name.
        random_score: significance threshold on the total score; None writes
            "NA" in the "Significant" column.
        only_complete: when True, only complete paths (exactly one strain) are written.
        overwrite: when True the file is truncated and a header is written;
            otherwise rows are appended without a header.
    """
    # Complete paths first, then by descending mean and total score.
    ordered = sorted(pathScores, key = lambda x: (len(pathScores[x][1]) > 1, -pathScores[x][0][0], -pathScores[x][0][1]))

    # Text mode: the rows are str, which a binary-mode file rejects on Python 3.
    with open(out, 'w' if overwrite else 'a') as w:
        if overwrite:
            w.write("Input,Path,Score,Complete,Significant,Strains,RelAbundance\n")
        for path in ordered:
            entry = pathScores[path]
            if len(entry[1]) == 1:
                complete = 'complete'
            elif only_complete:
                break  # keep only complete; incomplete paths sort last
            else:
                complete = 'incomplete'

            # List strains up to the first near-zero score that follows a
            # non-zero one (all strains are listed when none is non-zero).
            strain_parts = []
            non_zero = False
            for strain in entry[1]:
                if strain[1] >= 0.001:
                    non_zero = True
                elif non_zero:
                    break
                strain_parts.append("%s:%.3f" % (strain[0], strain[1]))
            strains = " ".join(strain_parts)

            if random_score is not None:
                significance = 'True' if entry[0][1] > random_score else 'False'
            else:
                significance = "NA"

            w.write("%s,%s,%.3f,%s,%s,%s,%s\n" % (filein, path, entry[0][0], complete, significance, strains, entry[2]))
    

def graphTree(graph, tree, parentNode):
    """Add ``tree`` and all its descendants to a pydot graph.

    Each node is drawn filled with a 'blues9' colour index derived from its
    score (index 1 when unscored) and is linked to ``parentNode`` when one is
    given. Scored nodes get their score as a percentage prefixed to the label.
    """
    is_leaf = tree.isLeaf()
    text = tree.label()
    fill_index = 1
    font_index = 9
    node_id = tree.name
    if tree.score is not None:
        fraction = tree.score
        text = "%.0f%% %s" % (100.0 * fraction, text)
        fill_index = int(round(fraction * 8.0)) + 1
        # Switch the font colour index on the darkest fills.
        if fill_index >= 7:
            font_index = 1
    dot_node = pydot.Node(node_id, label=text, style="filled", colorscheme='blues9', fillcolor=fill_index, fontcolor=font_index)
    graph.add_node(dot_node)
    if parentNode:
        graph.add_edge(pydot.Edge(parentNode, dot_node))
    if is_leaf:
        return
    for subtree in tree.children:
        graphTree(graph, subtree, dot_node)


def process_sample(tree, samplekmers, samplecounts):
    """Score a sample to a tree using the globally selected scoring method.

    Args:
        tree: root node of the tree to score (scored in place).
        samplekmers: kmers of the sample.
        samplecounts: counts belonging to ``samplekmers``.
    """
    print("Scoring sample kmers using:", method)
    # Use the argument itself rather than a like-named global set by the script.
    scoreSample(tree, samplekmers, samplecounts, method, min_unique, calc_p)


def process_tree(filein, tree, sampleKmers=None, samplecounts=None, out=None, dot=None, overwrite=False):
    """After sample scored to tree, calculate paths and output results.

    Args:
        filein: label used for this sample in the output file.
        tree: scored tree root.
        sampleKmers, samplecounts: sample kmers and counts used to estimate
            relative abundance; may be None for a previously scored tree.
        out: CSV file to write, or None.
        dot: PNG file for the tree graph, or None.
        overwrite: truncate ``out`` and write a header instead of appending.
    """
    if calc_random:
        print("Estimating threshold for significant path...")
        random_score = get_p_value_on_scores(tree, min_nodes, min_score, max_gap)
        print("For this tree, sum of node scores must be > %.3f for p < 0.01" % random_score)
    else:
        random_score = None

    print("Scoring paths in tree...")
    pathScores = scorePathsInSample(tree, min_nodes, min_score, max_gap, sampleKmers, samplecounts)

    # Complete paths (a single strain) first, then by descending mean and total score.
    ordered = sorted(pathScores, key = lambda x: (len(pathScores[x][1]) > 1, -pathScores[x][0][0], -pathScores[x][0][1]))
    for path in ordered:
        if len(pathScores[path][1]) == 1:
            print(path, "complete", end=' ')
        elif complete:
            break  # keep only complete
        else:
            print(path, "incomplete", end=' ')

        # A threshold of 0.0 is still a computed threshold.
        if random_score is not None:
            if pathScores[path][0][1] > random_score:
                print("significant", end=' ')
            else:
                print("p > 0.01", end=' ')

        print("%.3f" % pathScores[path][0][0])

    if out:
        save_text(filein, pathScores, out, random_score, complete, overwrite=overwrite)

    if dot:
        graph = pydot.Dot(graph_type='graph', rankdir='LR')
        graphTree(graph, tree, None)
        graph.write_png(dot)


###
### Christine's methods
###
def get_parent(node, amountup=1):
    """Return the ancestor reached by following ``parent`` ``amountup`` times."""
    ancestor = node
    remaining = amountup
    while remaining != 0:
        ancestor = ancestor.parent
        remaining -= 1
    return ancestor
    

def get_children(node, childrenlist, limit, steps):
    """Append the names of leaves below ``node`` to ``childrenlist`` and return it.

    Direct leaf children are always included; a non-leaf child is only
    descended into while ``steps`` differs from ``limit`` (each level down
    adds one step).
    """
    for child in node.children:
        if child.isLeaf():
            childrenlist.append(child.name)
        elif steps != limit:
            get_children(child, childrenlist, limit, steps + 1)
    return childrenlist


def is_leaf_there(node, scorelist):
    """Check whether the root-to-leaf path ending at ``node`` passes the suitability threshold.

    ``scorelist`` holds the node scores from the root down to ``node``.
    The first ``round(len(scorelist) * min_nodes)`` scores must all exceed 0.1,
    otherwise None is returned.

    Returns:
        * ``{leaf_name: ancestor_names}`` when the leaf itself scores above 0.1
          (ancestor names come from get_path, leaf-to-root order);
        * ``{last_supported_name: leaf_names}`` when support ends at an inner
          node, listing the leaves found by get_children within ``limit`` steps;
        * None when the leading scores fail the threshold, or (after printing a
          message) when the last supported node is the root.
    """
    nodethresh = int(round(len(scorelist) * min_nodes)) #needs to have at least this many nodes consecutively with a score

    # NOTE(review): 0.1 is hard-coded here rather than taken from min_score.
    isgood = all(i > 0.1 for i in scorelist[0:nodethresh]) #low thresh of 0.1 for the node to be supported)

    if isgood:
        # Find the last supported node: walk the scores backwards from the leaf
        # down to index nodethresh; idx is the number of steps above the leaf.

        firstabove = next(((idx, score) for idx,score in enumerate(scorelist[len(scorelist): nodethresh-1: -1]) if score > 0.1), None)
        # Depth limit handed to get_children below.
        limit =  len(scorelist) - nodethresh + 1

        if firstabove is None:
            # No supported node in the tail: step up past the whole tail.
            lastsupport = get_parent(node, len(scorelist) - nodethresh)
        elif firstabove[0] == 0:
            # If the leaf node is the last supported one, return its name along
            # with the other nodes (its ancestors) that shouldn't be considered.
            thisnode = node
            nixlist = get_path(node)
            return ({node.name:nixlist[1:]})
        else:
            lastsupport = get_parent(node, firstabove[0])
        if lastsupport.parent is None:
              # Support ends at the root: nothing is returned (implicit None).
              print(node.name + " has last support of root node")
        else:
            chilren = get_children(lastsupport, [], limit , 0)

            # Return the node of last support and all the leaves within the
            # limit/distance threshold of that node.
            return ({lastsupport.name:chilren})

    else:
        return (None)


def get_supported_paths(node, currentscore, scoredict):
    """Collect the supported paths of a scored tree into ``scoredict``.

    ``currentscore`` holds the scores of the ancestors of ``node`` and is not
    modified. At each leaf the mean path score is stored as ``pathscore`` and
    the result of is_leaf_there (when not None) is merged into ``scoredict``,
    which is also returned.
    """
    trail = copy.deepcopy(currentscore)
    trail.append(node.score)

    if not node.isLeaf():
        for child in node.children:
            get_supported_paths(child, trail, scoredict)
        return scoredict

    node.pathscore = sum(trail) / len(trail)
    # is_leaf_there decides whether this root-to-leaf path is suitable.
    verdict = is_leaf_there(node, trail)
    if verdict is not None:
        scoredict.update(verdict)
    return scoredict


def get_path(node):
    """Return the names from ``node`` up to the root (node first, root last)."""
    names = [node.name]
    cursor = node.parent
    while cursor is not None:
        names.append(cursor.name)
        cursor = cursor.parent
    return names


def get_node(node, whichnode, correct=None):
    """Return the first node named ``whichnode`` at or below ``node`` (depth-first), else None.

    ``correct`` is only passed through to the recursive calls.
    """
    if node.name == whichnode:
        return node
    for child in node.children:
        found = get_node(child, whichnode, correct)
        if found is not None:
            return found
    return None


def clean_paths(res, tree):
    """Reduce the raw supported-path results to one record per final path.

    ``res`` maps a node name to a list of names (see is_leaf_there) and is
    pruned in place: inner-node keys that appear inside another entry's list
    are removed, and so are the ancestors of the remaining inner-node keys.

    Returns:
        dict keyed by node name. Inner-node keys map to
        ``{"path": "root|...|node", "strains": {leaf_name: pathscore}}``;
        leaf keys map to ``{"score": pathscore, "path": "root|...|leaf"}``.
    """
    # NOTE(review): inner nodes are recognised by "node" occurring in their
    # name, so a strain whose name contains "node" is treated as an inner node.
    for key in list(res.keys()):
        if key not in res:
            # Already removed as the ancestor of an earlier key.
            continue
        if "node" in key:
            if any(key in listed for listed in res.values()):
                del res[key]
            else:
                thenode = get_node(tree, key)
                for ancestor in get_path(thenode)[1:]:
                    if ancestor in res:
                        del res[ancestor]

    finaldict = {}
    for key in list(res.keys()):
        if "node" in key:
            path = get_path(get_node(tree, key))[::-1]
            strains = {}
            for child in res[key]:
                strains[child] = get_node(tree, child).pathscore
            finaldict[key] = {"path": "|".join(path), "strains": strains}
        else:
            thenode = get_node(tree, key)
            # Reverse the list of names (root first) before joining; reversing
            # the joined string would scramble the names themselves.
            finaldict[key] = dict(score = thenode.pathscore, path = "|".join(get_path(thenode)[::-1]))

    return finaldict


def output_consecutive_paths (pathdict, filein, out):
    """Append the consecutive-path results to the CSV file ``out``.

    One row is written per strain of an incomplete path (``Complete`` = "no")
    and one row per complete path ("yes"). The last column (relative
    abundance) is not computed by this method and is written as "NA".
    """
    with open(out, 'a') as f:
        for key, value in pathdict.items():
            if "node" in key:
                for strain, score in value["strains"].items():
                    # Six placeholders need six values; the sixth is the abundance.
                    f.write('{0},{1},{2},{3}, {4}, {5}\n'.format(filein, value["path"], "no", strain, score, "NA"))
            else:
                f.write('{0},{1},{2},{3}, {4}, {5}\n'.format(filein, value['path'], "yes", key, value['score'], "NA"))


def consecutive_paths(filein, tree, out=None, dot=None, overwrite=False):
    """Run the consecutive-support path detection (Christine's method) on a scored tree.

    Results are appended to ``out`` when given (after writing a fresh header
    if ``overwrite`` is set), and the tree is drawn to the PNG ``dot`` when given.
    """
    supported = get_supported_paths(tree, [] , {})
    report = clean_paths(supported, tree)

    if out:
        if overwrite:
            # Start a new file containing only the header row.
            with open(out, 'w') as header_file:
                header_file.write("Input,Path,Complete,Strains,score,RelAbundance\n")
        output_consecutive_paths(report, filein, out)

    if dot:
        graph = pydot.Dot(graph_type='graph', rankdir='LR')
        graphTree(graph, tree, None)
        graph.write_png(dot)



###
### StrainSeeker method
###
def _strain_seeker_score_node(node, sampleKmers):
    """Score one node (no recursion) with the globally selected scoring method."""
    if method == 'adjusted':
        node.computeScoreAdjusted(sampleKmers)
    elif method == 'unique':
        node.computeScoreUnique(sampleKmers)
    elif method == 'weighted':
        node.computeScoreWeighted(sampleKmers)
    else:
        node.computeScore(sampleKmers)


def strain_seeker_score(tree, sampleKmers, thresh, currentpath, finalpaths):
    """Score sample for strain seeker method.

    Walks down from ``tree`` while the node score exceeds ``thresh``. At an
    inner node the observed score is compared with the expected score derived
    from its children (sum of child scores minus their product):
    a ratio within [0.8, 1.2] descends into the children, a ratio above 1.2
    records the leaves within two levels below the node, and a lower ratio is
    reported as "Not in db".

    Args:
        tree: node to score and examine.
        sampleKmers: kmers of the sample.
        thresh: minimum node score to keep following a path.
        currentpath: names of the ancestors of ``tree`` (not modified).
        finalpaths: dict collecting ``"a|b|c"`` -> leaf name or list of leaf names.

    Returns:
        ``finalpaths``.
    """
    _strain_seeker_score_node(tree, sampleKmers)

    cp = list(currentpath)
    cp.append(tree.name)
    if tree.score > thresh :
        if tree.isLeaf():
            finalpaths["|".join(cp)] = tree.name
        else:
            childsum = 0.0
            childprod = 1.0  # a product must start at 1, otherwise it is always 0
            for child in tree.children:
                _strain_seeker_score_node(child, sampleKmers)
                childsum += child.score
                childprod *= child.score

            expected = childsum - childprod
            # NOTE(review): the expected score is zero when all children score
            # 0 and also for a node with a single child -- confirm how
            # single-child nodes should be handled.
            if expected == 0:
                print("nores")
            else:
                oe = tree.score / float(expected)
                if 0.8 <= oe <= 1.20 :
                    for child in tree.children:
                        strain_seeker_score(child, sampleKmers, thresh, cp, finalpaths)
                elif oe > 1.2 :
                    finalpaths["|".join(cp)] = get_children(tree, [], 2, 0)
                else:
                    print("Not in db")

    return (finalpaths)
        
        
def strain_seeker(filein, tree, out=None, dot=None, overwrite=False):
    """Run the Strain Seeker detection for one sample.

    ``filein`` is used both to load the sample kmers and as the label written
    to the output. Results go to ``out`` when given and the tree is drawn to
    the PNG ``dot`` when given.
    """
    # NOTE(review): filein is passed straight to kmertools.loadKmers -- confirm
    # that callers supply a loadable path here and not only a display name.
    found_paths = strain_seeker_score(tree, kmertools.loadKmers(filein), 0.1, [], {})

    if out:
        output_strain_seeker(filein, found_paths, out, overwrite=overwrite)

    if dot:
        graph = pydot.Dot(graph_type='graph', rankdir='LR')
        graphTree(graph, tree, None)
        graph.write_png(dot)
    

def output_strain_seeker (filein, pathdict, out, overwrite=False):
    """Output results of strain seeker method to the CSV file ``out``.

    Args:
        filein: label written in the first column.
        pathdict: dict of path string -> strain name, or list of candidate
            strain names for a path that did not resolve to a single strain.
        out: output file name.
        overwrite: truncate the file and write a header; otherwise append.
    """
    # Text mode: the rows are str, which a binary-mode file rejects on Python 3.
    with open(out, 'w' if overwrite else 'a') as w:
        if overwrite:
            # Rows carry four fields, so the header names four columns.
            w.write("Input,Path,Complete,Strains\n")
        for key, value in pathdict.items():
            if isinstance(value, list) and len(value) > 1:
                w.write('{0},{1},{2},{3}\n'.format(filein, key, "no", ' |'.join(value)))
            else:
                w.write('{0},{1},{2},{3}\n'.format(filein, key, "yes", value))

def getpathcoverage(node, samplekmers, kmercounts):
    """Return the coverage of every node on the path from the root down to ``node``.

    Each node on the way up has its ``cov`` recomputed with computenodecov;
    the result is a numpy array ordered root first, ``node`` last.
    """
    leaf_to_root = []
    cursor = node
    while True:
        cursor.computenodecov(samplekmers, kmercounts)
        leaf_to_root.append(cursor.cov)
        if cursor.parent is None:
            break
        cursor = cursor.parent

    return np.array(leaf_to_root[::-1])
###
### Main
###

parser = argparse.ArgumentParser()
parser.add_argument("-v", "--verbose", help="increase output verbosity",
                    action="store_true")
parser.add_argument("input", help="HDF5 tree input file (can be previously scored tree!)")
parser.add_argument("sample", nargs='*', help="sample hdf5 to calculate paths in tree")
parser.add_argument("-o", "--out", help="Output results to this csv file")
parser.add_argument("--each", action='store_true', help="Store results into one csv file for each sample")
parser.add_argument("--dot", action='store_true', help="Save graphviz dot-style tree graph")
parser.add_argument("--hdf5", action='store_true', help="HDF5 tree score output (useful to re-analyze paths with different options)")
parser.add_argument("-d", "--detection", help="Specify path detection method", 
                    default="allpaths", choices=['allpaths', 'consecutive', 'strainseeker'])
parser.add_argument("-m", "--method", help='Path scoring method', choices=['adjusted', 'weighted', 'unique', 'all'])
parser.add_argument("-u", "--min_unique", help="Minimum number of unique kmers per node to use unique score (default: 100)", type=int)
parser.add_argument("-p", "--calc_p", help="Calculate p value for nodes fewer than this many kmers (default: 1000)", type=int)
parser.add_argument("-n", "--min_nodes", type=float, help="fraction of nodes required in a path (default: 0.7)")
parser.add_argument("-s", "--min_score", type=float, help="minimum score of a node in a tree to keep (default: 0.1)")
parser.add_argument("-c", "--complete", action="store_true", help="Keep only complete paths")
parser.add_argument("-g", "--max_gap", type=int, help="Short circuit path detection if this many nodes are below min_score (default: disabled)")
parser.add_argument("-r", "--random", type=int, default=0, help='Calculate paths in n random trees to estimate significance of path scores (default: disabled)')
args = parser.parse_args()

# Command-line values override the module-level defaults. Numeric options are
# compared with None so that an explicit 0 (e.g. "-s 0", which disables score
# filtering) is honoured instead of being silently ignored.
if args.verbose:
    verbose = args.verbose
if args.method:
    method = args.method
if args.min_unique is not None:
    min_unique = args.min_unique
if args.calc_p is not None:
    calc_p = args.calc_p
if args.min_nodes is not None:
    min_nodes = args.min_nodes
if args.min_score is not None:
    min_score = args.min_score
if args.complete:
    complete = args.complete
if args.max_gap:
    # 0 keeps the gap check disabled
    max_gap = args.max_gap
if args.random:
    calc_random = args.random

dot = None
out = None


print("Loading tree file:", args.input)
with h5py.File(args.input, 'r') as h5:
    tree = treeFromHdf5(h5)
if tree is None:
    raise SystemExit('Could not load tree!')

# One shared output file: write its header once, in text mode. Nothing is
# written when no output file was requested.
if not args.each and args.out:
    with open(args.out, 'w') as w:
        if args.detection == 'allpaths':
            w.write("Input,Path,Score,Complete,Significant,Strains,RelAbundance\n")
        elif args.detection == 'consecutive':
            w.write("Input,Path,Complete,Strains,score,RelAbundance\n")
        else:
            # strain seeker rows carry four fields
            w.write("Input,Path,Complete,Strains\n")


if args.sample:
    for sample in args.sample:
        name = os.path.basename(args.input) + '_' + os.path.basename(sample)
        samplename =  os.path.basename(sample)
        if args.dot:
            dot = name + '.png'
        if args.each:
            out = name + '.csv'
        elif args.out:
            out = args.out
        if args.detection != 'strainseeker':
            print("Loading kmers from sample:", sample)
            sampleset = kmertools.kmerSetFromFile(sample)
            sampleKmers = sampleset.kmers
            samplecounts = sampleset.counts
            process_sample(tree, sampleKmers, samplecounts)
        if args.hdf5:
            print("Saving scored tree to hdf5 file...")
            with h5py.File(name+'_scored.hdf5', 'w') as h5:
                tree.saveSampleHdf5(h5)
        if args.detection == 'allpaths':
            process_tree(samplename, tree, sampleKmers, samplecounts, out=out, dot=dot, overwrite=args.each)
        elif args.detection == 'consecutive':
            consecutive_paths(samplename, tree, out=out, dot=dot, overwrite=args.each)
        else:
            # NOTE(review): strain_seeker loads the kmers from the name it is
            # given; only the base name of the sample is passed here -- confirm.
            strain_seeker(samplename, tree, out=out, dot=dot, overwrite=args.each)


else:
    # No samples given: the input must be a tree that was scored earlier.
    if tree.score is None:
        raise SystemExit("No score value for this tree. Did you forget to specify samples?")
    label = os.path.basename(args.input)
    if args.each:
        out = args.input + '.csv'
    elif args.out:
        out = args.out
    if args.dot:
        dot = args.input + '.png'
    if args.detection == 'allpaths':
        # No sample kmers are available here, so abundance is reported as "NA".
        process_tree(label, tree, None, None, out=out, dot=dot, overwrite=True)
    elif args.detection == 'consecutive':
        consecutive_paths(label, tree, out=out, dot=dot, overwrite=True)
    else:
        raise Exception("strainseeker NYI on pre-generated tree")




