#!/usr/bin/env python

#  Copyright (c) 2016-2019, Broad Institute, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
#  * Neither the name Broad Institute, Inc. nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import sys
import argparse
import csv
import math
import multiprocessing
from itertools import combinations
import kmertools

# These names are assigned further down at module level and read by
# _compare_seqs. A `global` statement at module scope has no runtime effect;
# it only flags the names as shared state.
global args, strains, loader

def _compare_seqs(pair):
    """Score one pair of strains; used as the multiprocessing worker function.

    Reads the module-level ``args``, ``strains`` and ``loader``.

    Args:
        pair: ``(i, j)`` indexes into ``strains``.

    Returns:
        ``(name1, name2, result)`` where ``result`` is whatever the selected
        kmertools scoring function returns, or ``None`` when the comparison
        raised an exception (the error is reported on stderr).
    """
    # Placeholders so the error report below never hits an unbound name, even
    # when the failure happens before the real names have been resolved.
    s1name = s2name = '<unknown>'
    try:
        i, j = pair
        if args.lowmem:
            # strains holds file paths here; resolve both names before loading
            # so a failed load is still reported against the right pair.
            s1name = kmertools.nameFromPath(strains[i])
            s2name = kmertools.nameFromPath(strains[j])
            s1kmers = loader(strains[i])
            s2kmers = loader(strains[j])
        else:
            s1name, s1kmers = strains[i]
            s2name, s2kmers = strains[j]
        if args.fraction:
            similarityScore = kmertools.similarityNumeratorDenominator
        else:
            similarityScore = kmertools.similarityScore
        return (s1name, s2name, similarityScore(s1kmers, s2kmers, args.simscore))
    except KeyboardInterrupt:
        # Bare raise re-raises the active exception unchanged.
        raise
    except Exception as e:
        # Diagnostics go to stderr so they can never end up inside the
        # tab-separated results when those are written to stdout.
        print('\nError: failed at processing %s vs %s\n' % (s1name, s2name), file=sys.stderr)
        print(e, file=sys.stderr)
    return None

def ani(j, k=None):
    """Convert a similarity score into the value written to the ANI column.

    Computes ``1 + ln(2j / (1 + j)) / k``.
    NOTE(review): this looks like one minus the Mash distance for a Jaccard
    index ``j``, i.e. an average-nucleotide-identity estimate -- confirm.

    Args:
        j: similarity score; expected to be non-negative.
        k: k-mer length used in the formula. Defaults to
            ``kmertools.DEFAULT_K`` when omitted.

    Returns:
        The converted value as a float, or ``0.0`` when ``j`` is zero
        (the logarithm is undefined there).
    """
    if k is None:
        k = kmertools.DEFAULT_K
    if not j:
        return 0.0
    return 1 + math.log(2 * j / (1.0 + j)) / k


# Command-line interface. Parsing runs at module level, so the resulting
# namespace is the module-level `args` that _compare_seqs reads.
parser = argparse.ArgumentParser()
parser.add_argument("--fingerprint", action="store_true",
                    help="use minhash fingerprint instead of full kmer set to build graph")
parser.add_argument("--sample",
                    help="Compare similarity of this vs the other strains instead of all vs all")
parser.add_argument("--output", "-o",
                    help="Similarity score output file (default stdout)")
parser.add_argument("--simscore", default="jaccard",
                    choices=["jaccard", "minsize", "maxsize", "reference"],
                    help="similarity scoring method")
parser.add_argument("--fraction", action="store_true",
                    help="Output numerator and denominator in fraction instead of calculating score")
# The value is passed to multiprocessing.Pool when it is greater than 1.
parser.add_argument("-t", "--threads", type=int, default=1,
                    help="Number of threads to use for pairwise similarity")
parser.add_argument("--lowmem", action='store_true',
                    help="Load kmers on the fly for low memory (MUCH slower!)")
parser.add_argument('strains', nargs='+',
                    help='kmerized strain hdf5 or npz files')
args = parser.parse_args()


# Choose how each input file is read: the fingerprint loader when
# --fingerprint is given, otherwise the full k-mer loader.
loader = kmertools.loadFingerprint if args.fingerprint else kmertools.loadKmers

# Both modes score with the same callable. With --fraction the third element
# of every result is a (numerator, denominator) pair instead of a single score.
if args.fraction:
    similarity = kmertools.similarityNumeratorDenominator
else:
    similarity = kmertools.similarityScore

if args.sample:
    # One-vs-all: compare the --sample file against every listed strain.
    name = kmertools.nameFromPath(args.sample)
    print('Computing similarity of', name, 'vs. strains', file=sys.stderr)
    kmers = loader(args.sample)
    scores = [(name, kmertools.nameFromPath(strain), similarity(kmers, loader(strain), args.simscore))
              for strain in args.strains]
else:
    # All-vs-all: every unordered pair of strains is compared once.
    print('Computing pairwise similarity', file=sys.stderr)
    if args.simscore == 'reference':
        print("'reference' similarity score meaningless in pairwise similarity scoring", file=sys.stderr)
        sys.exit(1)

    if args.lowmem:
        # Keep only the paths; k-mers are loaded again for every comparison.
        strains = args.strains
    else:
        print('Loading kmers from files', file=sys.stderr)
        strains = [(kmertools.nameFromPath(arg), loader(arg)) for arg in args.strains]
    print('Comparing kmers to each other', file=sys.stderr)
    if args.threads > 1:
        # Workers receive index pairs and look the strains up themselves, so
        # only two integers are sent per task.
        cmds = list(combinations(range(len(strains)), 2))
        # The context manager releases the pool even if a worker call raises.
        with multiprocessing.Pool(args.threads) as pool:
            results = pool.map(_compare_seqs, cmds)
        # _compare_seqs reports a failed pair itself and returns None for it;
        # drop those entries so sorting and output only see real results.
        scores = [result for result in results if result is not None]
    else:
        scores = []
        for strain1, strain2 in combinations(strains, 2):
            if args.lowmem:
                s1name = kmertools.nameFromPath(strain1)
                s1kmers = loader(strain1)
                s2name = kmertools.nameFromPath(strain2)
                s2kmers = loader(strain2)
            else:
                s1name, s1kmers = strain1
                s2name, s2kmers = strain2
            scores.append((s1name, s2name, similarity(s1kmers, s2kmers, args.simscore)))

print('Sorting scores', file=sys.stderr)
# Highest first. Python 3's list.sort() accepts no comparator function and
# cmp() no longer exists, so the ordering is expressed as key + reverse;
# entries with equal keys keep their original relative order.
if args.fraction:
    # Fraction mode orders by the raw numerator, not by the
    # numerator/denominator ratio.
    scores.sort(key=lambda entry: entry[2][0], reverse=True)
else:
    scores.sort(key=lambda entry: entry[2], reverse=True)

print('Writing output', file=sys.stderr)
# Results are written as tab-separated rows to --output, or to stdout when
# no output file was given.
output = open(args.output, 'w') if args.output else sys.stdout
try:
    csvwriter = csv.writer(output, delimiter="\t", lineterminator="\n")

    if args.fraction:
        # Columns: name1, name2, numerator, denominator, score, ani(score).
        for name1, name2, (numerator, denominator) in scores:
            # A zero denominator is reported as score 0 instead of aborting
            # the whole run with ZeroDivisionError.
            score = float(numerator) / denominator if denominator else 0.0
            csvwriter.writerow((name1, name2, "%d" % numerator, "%d" % denominator,
                                "%.5f" % score, "%.5f" % ani(score)))
    else:
        # Columns: name1, name2, score, ani(score).
        for name1, name2, score in scores:
            csvwriter.writerow((name1, name2, "%.5f" % score, "%.5f" % ani(score)))
finally:
    # Close (and thereby flush) a file we opened ourselves; leave stdout alone.
    if output is not sys.stdout:
        output.close()
