#!/usr/bin/env python 
from __future__ import print_function

from extasycoco._version import __version__
from extasycoco import dm_complement

import logging as log
import sys
import os
import os.path as op
import numpy as np
import argparse
import glob
import time
import mdio

# Default console logging for the command-line tool, configured at import
# time: INFO level, "LEVEL: message" lines. Because this attaches a handler to
# the root logger, any later logging.basicConfig() call in this process is a
# no-op (on Python 3.8+, unless force=True is passed).
log.basicConfig(level=log.INFO, format="%(levelname)s: %(message)s")


def _configure_logging(verbosity):
    """Apply the -v/--verbosity count to the root logger.

    The module-level ``log.basicConfig`` call attaches a handler to the root
    logger at import time, after which further ``basicConfig`` calls are
    silent no-ops. The level and format are therefore applied to the existing
    root logger and its handlers directly; ``basicConfig`` is only used if no
    handler exists yet.

    verbosity: None/0 leaves logging untouched, 1 selects INFO, and 2 or more
    selects DEBUG with timestamps.
    """
    if not verbosity:
        return
    if verbosity >= 2:
        fmt, level = "%(asctime)s: %(levelname)s: %(message)s", log.DEBUG
    else:
        fmt, level = "%(levelname)s: %(message)s", log.INFO
    root = log.getLogger()
    if root.handlers:
        root.setLevel(level)
        for handler in root.handlers:
            handler.setFormatter(log.Formatter(fmt))
    else:
        log.basicConfig(format=fmt, level=level)
    if verbosity >= 2:
        log.debug("Debug output.")
    else:
        log.info("Verbose output.")


def _resolve_mdfiles(mdfile_args):
    """Return the list of trajectory files to load.

    A single argument containing '*' or '?' is expanded as a glob pattern.
    Matches are sorted because glob order is arbitrary, and the order
    determines the order in which frames are loaded. Any other argument list
    is returned unchanged (as a list).

    Raises:
        ValueError: if a glob pattern matches no files.
    """
    if len(mdfile_args) == 1 and ('*' in mdfile_args[0]
                                  or '?' in mdfile_args[0]):
        pattern = mdfile_args[0]
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise ValueError('Error: no files match {}'.format(pattern))
        return matches
    return list(mdfile_args)


def _output_names(outputs, npoints, fmt=None):
    """Return the file names the new structures are written to.

    Several names are used as given (their extensions define the format).
    A single name is treated as a basename and expanded to npoints names
    '<root>0<ext>' ... '<root>(npoints-1)<ext>'. The extension is '.<fmt>'
    when fmt is given; otherwise it is the basename's own extension, or
    '.pdb' if the basename has none.
    """
    if len(outputs) > 1:
        return list(outputs)
    root, ext = op.splitext(outputs[0])
    if fmt is not None:
        ext = '.' + fmt.lstrip('.')
    elif not ext:
        ext = '.pdb'
    return ['{}{}{}'.format(root, rep, ext) for rep in range(npoints)]


def _mpi_setup(nompi):
    """Return (comm, rank) for this process.

    Returns (MPI.COMM_WORLD, its rank) when MPI is wanted and mpi4py can be
    imported. Otherwise returns (None, 0) so that the run proceeds serially.
    """
    if not nompi:
        try:
            from mpi4py import MPI
        except ImportError:
            pass
        else:
            comm = MPI.COMM_WORLD
            return comm, comm.Get_rank()
    return None, 0


def cocodm_ui(args):
    '''
    The command line implementation of the distance-matrix version of the CoCo
    procedure. Should be invoked as:
    pyCoCoDM -i mdfiles -t topfile -o outname [-d ndims -n npoints -g gridsize
     --nsamples nsamples -l logfile -s selection --nompi --fmt format -v]
    where:
        mdfiles  is a list of one or more trajectory files. A single argument
                 containing '*' or '?' is expanded as a glob pattern (matches
                 are sorted); it is an error if nothing matches.
        topfile  is a compatible topology file. Its coordinates supply the
                 unselected atoms of every output structure.
        outname  two options here. If a single name is given it is assumed it
                 is the basename for the structure files generated by CoCo.
                 The default format is pdb but this can be overridden by
                 the --fmt argument (or, when --fmt is not given, by an
                 extension on the basename). There will be npoints of these;
                 if outname='out' then they will be called
                 'out0.pdb', 'out1.pdb'... etc up to 'out(npoints-1).pdb'.
                 Option two is that multiple file names are given here. In
                 that case the number of them defines the number of new
                 structures created (any npoints argument is ignored) and
                 their extensions define the format.
        format   is the output file format. Accepted options are 'pdb'
                 (default), 'rst', and 'gro'.
        ndims    specifies the number of dimensions (PCs) in the CoCo mapping
                 (default=3).
        npoints  specifies the number of frontier points to return structures
                 from (default=1)
        gridsize specifies the number of grid points per dimension in the CoCo
                 histogram (default=10)
        nsamples specifies the number of sample distance matrices to generate
                 (default=10,000).
        logfile  is an optional file with detailed analysis data; it is only
                 written by MPI rank 0.
        selection is an optional MDTraj style selection string. Only
                 selected atoms will be used in the CoCo procedure, however
                 ALL atoms will be included in the output files (all unselected
                 ones having coordinates taken from the topology file).
                 Such structures are, obviously, only useful as targets for
                 restrained MD or EM procedures.
        nompi    specifies that CoCo should not be run in parallel (MPI is
                 also skipped if mpi4py cannot be imported).
        -v       once for INFO output, twice or more for DEBUG output.

    Raises:
        ValueError: if gridsize, frontpoints, dims or nsamples is < 1, if no
        trajectory or output names are given, if a glob pattern matches no
        files, or if the selection matches no atoms.
    '''
    # Validate everything that can be checked before any file I/O.
    if args.grid < 1:
        raise ValueError('Error: gridsize must be > 0')
    if args.frontpoints < 1:
        raise ValueError('Error - frontpoints must be > 0')
    if args.dims < 1:
        raise ValueError('Error: dims must be > 0')
    if args.nsamples < 1:
        raise ValueError('Error: nsamples must be > 0')
    if not args.mdfile:
        raise ValueError('Error: no trajectory files given')
    if not args.output:
        raise ValueError('Error: no output file names given')

    _configure_logging(args.verbosity)

    topfile = args.topfile
    mdfiles = _resolve_mdfiles(args.mdfile)
    ndims = args.dims
    gridsize = args.grid
    selection = args.selection
    nsamples = args.nsamples

    # args built by other callers may lack 'fmt'; None means "use the
    # basename's extension, else pdb".
    outnames = _output_names(args.output, args.frontpoints,
                             getattr(args, 'fmt', None))
    # With several explicit output names, their count overrides -n.
    npoints = len(outnames)

    comm, rank = _mpi_setup(args.nompi)

    # Only rank 0 writes the log file.
    logfile = None
    if args.logfile is not None and rank == 0:
        try:
            logfile = open(args.logfile, 'w')
        except IOError as e:
            print(e)
            # NOTE(review): under MPI only rank 0 exits here; confirm the
            # launcher tears down the remaining ranks.
            sys.exit(-1)

    try:
        if logfile is not None:
            logfile.write("*** pyCoCo ***\n\n")

        if rank == 0:
            log.info('Loading trajectory files')
        loadstart = time.time()
        traj = mdio.mpi_load(mdfiles, top=topfile, selection=selection,
                             comm=comm)
        loadtime = time.time() - loadstart
        natoms = traj.n_atoms
        if natoms == 0:
            raise ValueError('Error: the selection matches no atoms.')

        if rank == 0:
            # Full-system reference structure (from the topology file) plus
            # the indices of the selected subset within it. Output structures
            # are built by overwriting the coordinates of those atoms.
            traj_ref = mdio.load(topfile)
            atom_indices = traj_ref.topology.select(selection)

            if logfile is not None:
                logfile.write("Trajectory files to be analysed:\n")
                for mdfile in mdfiles:
                    logfile.write("{} \n".format(mdfile))
                logfile.write('\n')
            log.info('trajectories contain {0} atoms and {1} frames'.format(
                natoms, len(traj)))
            log.info('time to load trajectory data: {:.2f} s.'.format(
                loadtime))

        if len(traj) == 1:
            # CoCo needs more than one input structure; fall back to copies.
            if rank == 0:
                if logfile is not None:
                    logfile.write("WARNING: Only one input structure given, CoCo\n")
                    logfile.write("procedure not possible, new structures will be\n")
                    logfile.write("copies of the input structure.\n")
                log.info('Warning: only one input structure!')
                # NOTE(review): this saves the structure read from topfile,
                # not frame 0 of the input trajectory; confirm intended.
                for outname in outnames:
                    traj_ref[0].save(outname)
        else:
            new_traj = dm_complement(traj, logfile=logfile,
                                     n_samples=nsamples, n_points=npoints,
                                     grid_dims=ndims, grid_bins=gridsize,
                                     pca_dims=20, comm=comm)
            if rank == 0:
                for irep, xyz in enumerate(new_traj.xyz):
                    # Merge the optimised subset into the full coordinates.
                    traj_ref.xyz[:, atom_indices] = xyz
                    traj_ref.save(outnames[irep])
    finally:
        # Close the log even if loading or CoCo raised.
        if logfile is not None:
            logfile.close()

################################################################################
#                                                                              #
#                                    ENTRY POINT                               #
#                                                                              #
################################################################################

if __name__ == '__main__':
    # Command-line entry point: build the argument parser (options mirror
    # those documented in cocodm_ui) and run the procedure.
    parser = argparse.ArgumentParser(
        description='Distance-matrix version of the CoCo procedure.')
    parser.add_argument('-g', '--grid', type=int, default=10,
                        help="Number of points along each dimension of the CoCo histogram")
    parser.add_argument('-d', '--dims', type=int, default=3,
                        help='The number of dimensions (PCs) in the CoCo mapping; this is also the number of dimensions of the histogram.')
    parser.add_argument('-n', '--frontpoints', type=int, default=1,
                        help="The number of new frontier points to select through CoCo.")
    parser.add_argument('--nsamples', type=int, default=10000,
                        help="The number of sample distance matrices to generate.")
    parser.add_argument('-i', '--mdfile', type=str, nargs='*', required=True,
                        help='The MD files to process.')
    parser.add_argument('-o', '--output', type=str, nargs='*', required=True,
                        help='Output file name(s): a single basename, or one name per new structure.')
    # Documented in cocodm_ui but previously missing from the parser.
    parser.add_argument('--fmt', type=str, default=None,
                        choices=['pdb', 'rst', 'gro'],
                        help="Output format when a single output basename is given (default: the basename's extension, or pdb).")
    parser.add_argument('-t', '--topfile', type=str, required=True,
                        help='Topology file.')
    parser.add_argument('-v', '--verbosity', action="count",
                        help="Increase output verbosity.")
    parser.add_argument('-l', '--logfile', type=str, default=None,
                        help='Optional log file.')
    parser.add_argument('-s', '--selection', type=str, default='all',
                        help='Optional atom selection string.')
    parser.add_argument('--nompi', action='store_true',
                        help='Disables any attempt to use MPI.')
    parser.add_argument('-V', '--version', action='version',
                        version=__version__)

    args = parser.parse_args()
    cocodm_ui(args)
