#!/usr/bin/env python

#  Copyright (c) 2016-2019, Broad Institute, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
#  * Neither the name Broad Institute, Inc. nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import sys
import argparse
# Workaround for machines without an X display: select the non-interactive Agg backend up front.
import matplotlib as mp
mp.use("Agg")
import kmertools


def parseMegaGiga(str, mult = 1):
    """Parse a count with an optional M/G suffix, e.g. '100M' --> 100,000,000.

    Args:
        str: Text to parse (the name shadows the builtin but is kept so
            existing keyword callers keep working). A falsy value such as
            None or '' means "not specified".
        mult: Extra multiplier applied to the parsed value, with or without
            a suffix.

    Returns:
        The integer value times the suffix scale (M/m = 10**6, G/g = 10**9)
        times ``mult``. Returns 0 when no number is given (empty input or a
        bare suffix).

    Raises:
        ValueError: If the numeric part is not a valid integer; this includes
            stacked suffixes such as '1MG'.
    """
    if not str:
        return 0
    text = str.strip()
    # text[-1:] (rather than text[-1]) stays safe when only whitespace was given.
    scale = {'m': 1000000, 'g': 1000000000}.get(text[-1:].lower())
    if scale is None:
        scale = 1
        digits = text
    else:
        digits = text[:-1]
    if not digits.strip():
        # Nothing numeric left: treat like an unspecified value.
        return 0
    # Compose the multipliers instead of letting the suffix replace `mult`.
    return int(digits) * scale * mult


# Command-line interface. Defaults for the kmer size and the fingerprint
# fraction come from kmertools so the help text always shows the values
# actually in effect.
parser = argparse.ArgumentParser()
parser.add_argument('sequences', nargs='+',
                    help='Input sequence files (fasta or fastq by default; optionally compressed with gz or bz2)')
parser.add_argument("-k", "--K", help="Kmer size (default {})".format(kmertools.DEFAULT_K),
                    type=int, default=kmertools.DEFAULT_K)
parser.add_argument("-o", "--output", help="Output file (hdf5 or npz)")
parser.add_argument("-f", "--fingerprint", help="Compute and save minhash fingerprint (sketch)",
                    action="store_true")
parser.add_argument("--fingerprint_fraction", "--ff", type=float, default=kmertools.DEFAULT_FINGERPRINT_FRACTION,
                    help="Fraction of kmers to include in fingerprint (default {})".format(kmertools.DEFAULT_FINGERPRINT_FRACTION))
# Still accepted so existing command lines keep parsing; the value is not read.
parser.add_argument("-c", "--compress", help="No longer used", action="store_true")
parser.add_argument("-F", "--filter", help="Filter output kmers based on kmer spectrum (to prune sequencing errors)",
                    action="store_true")
parser.add_argument("-s", "--spectrum", help="File to output kmer spectrum graph (.png best)")
parser.add_argument("-e", "--entropy", action="store_true", help="Calculate Shannon entropy (in bases)")
# --limit and --prune stay as raw strings here; they are converted with
# parseMegaGiga at the point of use.
parser.add_argument("-l", "--limit", help="Only process about this many kmers (can have suffix of M or G)")
parser.add_argument("-p", "--prune", help="Prune singletons after accumulating this many (can have suffix of M or G)")
parser.add_argument("--histogram", help="Histogram output file (count of kmers by frequency)")
parser.add_argument("--kmerset", help="Input is kmerset in hdf5 or npz format", action="store_true")
args = parser.parse_args()

# Build the kmer set: either load a previously saved one or kmerize the
# given sequence files.
if args.kmerset:
    # Use parser.error rather than assert: asserts are stripped under
    # `python -O`, which would silently ignore any extra input files.
    if len(args.sequences) != 1:
        parser.error("Sorry, you can only specify one kmerset file at a time!!")
    print('Loading', args.sequences[0])
    kset = kmertools.kmerSetFromFile(args.sequences[0])
else:
    # Parse the size options once, before any file is read, and report
    # malformed values as a usage error instead of a traceback.
    try:
        limit = parseMegaGiga(args.limit)
        prune = parseMegaGiga(args.prune)
    except ValueError as err:
        parser.error('--limit and --prune take an integer with an optional M or G suffix ({})'.format(err))
    kset = kmertools.KmerSet(args.K)
    for arg in args.sequences:
        print('Kmerizing', arg)
        kset.kmerizeFile(arg, limit=limit, prune=prune)

if args.entropy:
    print("Entropy:", round(kset.entropy(), 2))

if args.histogram:
    kset.writeHistogram(args.histogram)

if args.spectrum:
    thresholds = kset.spectrumMinMax()
    if thresholds:
        # NOTE(review): the third element is forwarded to plotSpectrum;
        # presumably an upper bound for the plot -- confirm in kmertools.
        kset.plotSpectrum(args.spectrum, thresholds[2])
    else:
        kset.plotSpectrum(args.spectrum)

if args.filter:
    thresholds = kset.spectrumFilter()
    if thresholds:
        print('Keeping kmers with frequency in the range', thresholds)
    # Stats are printed whether or not the filter returned thresholds.
    kset.printStats()

if args.fingerprint:
    kset.minHash(args.fingerprint_fraction)

if args.output:
    print('Writing output to', args.output)
    # Output is always written compressed; the --compress flag is ignored.
    kset.save(args.output, compress=True)

