#!/usr/bin/env python

#  Copyright (c) 2016-2019, Broad Institute, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
#  * Neither the name Broad Institute, Inc. nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import argparse
import os
import sys

#workaround for no X windows
import matplotlib as mp
mp.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

import kmerizer
import kmertools

# Command-line interface: one or more kmer count files are required; every
# option is optional and defaults to None (or False for --filter).
parser = argparse.ArgumentParser(
    description="Plot kmer frequency spectra for one or more kmer count files.")
parser.add_argument('kmerfiles', nargs='+',
                    help='Input kmer files (npz with counts)')
parser.add_argument("-s", "--spectrum", help="File to output kmer spectrum graph (.png best)")
parser.add_argument("--histogram", help="Text output file for histogram (count of kmers by frequency)")
parser.add_argument("-f", "--filter", action='store_true', help="Try to determine cutoffs")
parser.add_argument("-m", "--max", type=int, help="Max frequency for kmer spectrum plot")
parser.add_argument("-r", "--ref", help="Reference kmer file for comparison")
# Parsed once at start-up; the rest of the script reads the result via `args`.
args = parser.parse_args()

def _plotLabel(path):
    """Return the legend label for a kmer file path.

    The label is the path with everything from the first '.' of the file
    name onward removed.  Only the final path component is examined, so a
    dot in the directory part (e.g. './sample.npz' or '../run.1/x.npz') no
    longer truncates the label to an empty or partial string.
    """
    filename = os.path.basename(path)
    # Everything before the final component is kept verbatim.
    directory = path[:len(path) - len(filename)]
    return directory + filename.split('.')[0]


# A single KmerSet receives every input file in turn, and one spectrum curve
# is drawn after each load.
# NOTE(review): whether load() replaces or accumulates the previous contents
# is not visible here -- confirm against kmertools.KmerSet.
kset = kmertools.KmerSet()

for kmerfile in args.kmerfiles:
    kset.load(kmerfile)
    spectrum = kset.spectrum()
    # spectrum[0] goes on the x axis, spectrum[1] on the log-scaled y axis.
    plt.semilogy(spectrum[0], spectrum[1], label=_plotLabel(kmerfile))

# The histogram is written once, from kset as it stands after the last load.
if args.histogram:
    kset.writeHistogram(args.histogram)

if args.ref:
    # Merged counts at or above OFFSET are treated as belonging to kmers that
    # occur in the reference.
    OFFSET = 1000000
    refSet = kmertools.KmerSet()
    refSet.load(args.ref)
    # Lift every reference count to at least OFFSET before merging.
    liftedRefCounts = np.maximum(refSet.counts, OFFSET)
    # NOTE(review): presumably merge_counts adds the counts of kmers present
    # in both inputs -- confirm against kmerizer.merge_counts.
    mergedKmers, mergedCounts = kmerizer.merge_counts(
        refSet.kmers, liftedRefCounts, kset.kmers, kset.counts)

    inReference = mergedCounts >= OFFSET
    # Boolean-mask indexing yields new arrays, so the in-place subtraction
    # below leaves mergedCounts untouched.
    genomicCounts = mergedCounts[inReference]
    genomicCounts -= OFFSET
    bogusCounts = mergedCounts[~inReference]

    # Spectrum of kmers found in the reference.
    genomicValues, genomicTally = np.unique(genomicCounts, return_counts=True)
    plt.semilogy(genomicValues, genomicTally, label='genomic')

    # Spectrum of kmers absent from the reference.
    bogusValues, bogusTally = np.unique(bogusCounts, return_counts=True)
    plt.semilogy(bogusValues, bogusTally, label='bogus')

    # For each candidate cutoff, total the kmers it would misclassify:
    # non-reference kmers at or above the cutoff plus reference kmers below it.
    freq = list(range(1, 200))
    errors = [
        np.count_nonzero(bogusCounts >= cutoff)
        + np.count_nonzero(genomicCounts < cutoff)
        for cutoff in freq
    ]
    print('Min errors:', min(errors))
    plt.semilogy(freq, errors, label='total errors')


# With --filter, print whatever thresholds KmerSet.spectrumFilter() returns.
if args.filter:
    print('Kmer spectrum thresholds:', kset.spectrumFilter())

# Label the axes that the spectrum curves were drawn on and add the legend.
axes = plt.gca()
axes.set_xlabel("Kmer Frequency")
axes.set_ylabel("Number of Kmers")
axes.legend()

# --max clips the x axis to [0, max]; an omitted (or zero) value leaves the
# automatic limits in place.
if args.max:
    axes.set_xlim(0, args.max)

# Save to the --spectrum file when one was given, otherwise fall back to show().
# NOTE(review): this script selects the Agg backend at import time, so show()
# likely cannot open a window -- confirm whether the fallback is still wanted.
if not args.spectrum:
    plt.show()
else:
    plt.savefig(args.spectrum)

