#!/usr/bin/env python

#  Copyright (c) 2016-2019, Broad Institute, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
#  * Neither the name Broad Institute, Inc. nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import sys
import argparse
import csv
from itertools import combinations
import numpy as np
import kmertools
import kmerizer
import time

# Command-line interface: two optional flags/values, an optional output path,
# and one or more positional strain k-mer files.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-f", "--fingerprint",
    action="store_true",
    help="use minhash fingerprint instead of full kmer set to build graph",
)
parser.add_argument(
    "-s", "--sample",
    help="Compare similarity of this vs the other strains instead of all vs all",
)
parser.add_argument(
    "-o", "--output",
    help="output text file",
)
parser.add_argument(
    'strains',
    nargs='+',
    help='kmerized strain hdf5 or npz files',
)
# Parsed once at import time; the rest of the script reads options from `args`.
args = parser.parse_args()

# argparse does not mark -s/--sample as required, but every score below is
# computed against the sample, so fail early with a usage message instead of
# handing None to the loader.
if args.sample is None:
    parser.error("a sample kmer set is required (-s/--sample)")

# Load the k-mer set of the sample that each strain is scored against.
print("Loading sample", args.sample)
sample = kmertools.kmerSetFromFile(args.sample)

print("Scoring strains")
# One (name, covered, coverage, medianCov) tuple is appended per strain.
results = []

# Low 32 bits of a packed count hold the sample count; the strain (genomic)
# count is shifted into the bits above. Loop-invariant, so defined once.
mask = (1 << 32) - 1

# Score every strain against the sample and collect one result tuple each:
#   covered   - fraction of the strain's kmers that also occur in the sample
#   coverage  - total sample count / total strain count over the shared kmers
#   medianCov - median sample count over the strain's kmers
for s in args.strains:
    name = kmertools.nameFromPath(s)
    kset = kmertools.kmerSetFromFile(s)
    if args.fingerprint:
        strainKmers = kset.fingerprint
        # Fingerprint entries carry no counts here; treat each as seen once.
        strainCounts = np.ones_like(kset.fingerprint, dtype=np.int64)
    else:
        strainKmers = kset.kmers
        strainCounts = kset.counts
    # play a trick here to put the genomic kmer counts in the upper 32 bits and the sample kmer count in the lower.
    # NOTE(review): shifts in place, so kset.counts is modified in the non-fingerprint
    # case; assumes the count dtype is 64 bits wide -- TODO confirm.
    strainCounts <<= 32
    # presumably merge_counts returns (kmers, counts) for the union of both inputs with
    # counts of shared kmers added together -- verify against kmerizer.
    merged = kmerizer.merge_counts(strainKmers, strainCounts, sample.kmers, sample.counts)
    # Keep only entries whose packed count has a strain component.
    genomic = merged[1] >= (1 << 32)
    mergedKmers = merged[0][genomic]
    mergedCounts = merged[1][genomic]
    merged = None  # drop the reference to the full merge result before further work
    genomicCounts = mergedCounts >> 32
    mergedCounts &= mask
    # Restrict to kmers that were also counted in the sample. The kmer array is
    # filtered together with the counts so the two stay parallel for the merge below.
    inBoth = mergedCounts > 0
    mergedKmers = mergedKmers[inBoth]
    genomicCounts = genomicCounts[inBoth]
    mergedCounts = mergedCounts[inBoth]
    if strainKmers.size and mergedCounts.size:
        covered = float(mergedCounts.size) / float(strainKmers.size)
        coverage = float(mergedCounts.sum()) / float(genomicCounts.sum())
        # Merge against an all-zero count for every strain kmer so that, presumably,
        # strain kmers missing from the sample enter the median with a count of 0.
        strainMergedCounts = np.zeros_like(strainCounts)
        sampleCountsPerKmer = kmerizer.merge_counts(strainKmers, strainMergedCounts, mergedKmers, mergedCounts)[1]
        medianCov = int(round(np.median(sampleCountsPerKmer)))
    else:
        # Empty strain set, or nothing shared with the sample: every ratio would be
        # a division by zero, so report an all-zero score instead of aborting the run.
        covered = 0.0
        coverage = 0.0
        medianCov = 0
    result = (name, covered, coverage, medianCov)
    results.append(result)
    print(name, "covered: %.3f coverage: %.2f median: %d score: %.2f" % (covered, coverage, medianCov, covered * coverage))

# Rank strains by descending score (covered * coverage). A key function with
# reverse=True replaces the Python-2-only cmp comparator; ties keep their
# original relative order.
results.sort(key=lambda r: r[1] * r[2], reverse=True)

# Write the table to the requested file, or to stdout when no -o was given.
output = open(args.output, 'w') if args.output else sys.stdout
try:
    # The header goes to the same destination as the rows.
    print("name\tcovered\tcoverage\tmedian\tscore", file=output)
    for r in results:
        print("%s\t%.3f\t%.2f\t%d\t%.2f" % (r[0], r[1], r[2], r[3], r[1] * r[2]), file=output)
finally:
    # Close only a file this script opened; leave stdout alone.
    if output is not sys.stdout:
        output.close()



