#!/usr/bin/env python 
from extasycoco._version import __version__

import logging as log
import sys
import os
import os.path as op
import argparse
import glob
import time
import mdtraj as mdt

from extasycoco import complement

def _resolve_mdfiles(mdfile_args):
    '''
    Return the list of trajectory files named by the -i argument(s).

    A single argument containing '*' or '?' is treated as a glob pattern
    (useful when the shell has not expanded it); the matches are sorted so
    that the order in which trajectories are loaded is reproducible. Any
    other input is returned unchanged (as a new list).
    '''
    if len(mdfile_args) == 1 and ('*' in mdfile_args[0]
                                  or '?' in mdfile_args[0]):
        return sorted(glob.glob(mdfile_args[0]))
    return list(mdfile_args)


def _output_names(outputs, npoints, fmt=None):
    '''
    Return the list of output file names.

    If several names are given they are used as they are. If a single name
    is given it is expanded to npoints names by inserting the structure index
    before the extension; when the name has no extension, fmt (default 'pdb')
    is used as the extension.
    '''
    if len(outputs) > 1:
        return list(outputs)
    root, ext = op.splitext(outputs[0])
    if not ext:
        # An explicit extension on outname wins; fmt only fills the gap.
        ext = '.' + (fmt or 'pdb').lstrip('.')
    return ['{}{}{}'.format(root, rep, ext) for rep in range(npoints)]


def _mpi_rank(nompi):
    '''
    Return the MPI rank of this process, or 0 if MPI is disabled (nompi) or
    mpi4py cannot be imported.
    '''
    if nompi:
        return 0
    try:
        from mpi4py import MPI
        return MPI.COMM_WORLD.Get_rank()
    except ImportError:
        return 0


def coco_ui(args):
    '''
    The command line implementation of the CoCo procedure. Should be invoked
    as:
    pyCoCo -i mdfiles -t topfile -o outname [-d ndims -n npoints -g gridsize -l
    logfile -s selection --nompi --regularize --fmt format --skip nskip]
    where:
        mdfiles  is a list of one or more trajectory files. A single
                 argument containing '*' or '?' is expanded as a glob
                 pattern (matches are used in sorted order).
        topfile  is a compatible topology file
        outname  two options here. If a single name is given it is assumed it
                 is the basename for the structure files generated by CoCo.
                 There will be npoints of these; if outname='out' then they
                 will be called 'out0.pdb', 'out1.pdb'... etc up to
                 'out(npoints-1).pdb'. If the name carries an extension
                 (e.g. 'out.gro') that extension is kept: 'out0.gro', ...
                 Option two is that multiple file names are given here. In
                 that case the number of them defines the number of new
                 structures created (any npoints argument is ignored) and
                 their extensions define the format.
        format   is the extension used for the output files when a single
                 outname without an extension is given (default 'pdb').
        ndims    specifies the number of dimensions (PCs) in the CoCo mapping
                 (default=3).
        npoints  specifies the number of frontier points to return structures
                 from (default=1)
        gridsize specifies the number of grid points per dimension in the CoCo
                 histogram (default=10)
        logfile  is an optional file with detailed analysis data. It is
                 only written by the rank 0 process.
        selection is an optional MDTraj style selection string, passed on to
                 complement().
        cache    is accepted for backward compatibility but cache updating
                 is not currently supported: a warning is issued and the
                 option is otherwise ignored.

        nompi    specifies that CoCo should not be run in parallel
        regularize is passed to complement() as its 'refine' argument.
        skip     specifies the number of top eigenvectors to skip over in the
                 CoCo process, e.g. if nskip is 1, and ndims is 3, the process
                 will use the distributions in PCs 2-4.

    args is the argparse namespace produced by the command line parser.
    Invalid arguments terminate the program with exit status 1; failure to
    open the log file terminates it with exit status -1.
    '''

    if args.grid < 1:
        print('Error: gridsize must be > 0')
        sys.exit(1)
    if args.frontpoints < 1:
        print('Error - frontpoints must be > 0')
        sys.exit(1)
    if args.dims < 1:
        print('Error: dims must be > 0')
        sys.exit(1)

    # -v is a counter that is None when the flag is absent; anything above
    # two is treated like two rather than silently switching logging off.
    verbosity = args.verbosity or 0
    if verbosity >= 2:
        log.basicConfig(format="%(asctime)s: %(levelname)s: %(message)s",
                        level=log.DEBUG)
        log.debug("Debug output.")
    elif verbosity == 1:
        log.basicConfig(format="%(levelname)s: %(message)s",
                        level=log.INFO)
        log.info("Verbose output.")

    mdfiles = _resolve_mdfiles(args.mdfile)
    if not mdfiles:
        print('Error: no trajectory files found for {}'.format(args.mdfile))
        sys.exit(1)
    if not args.output:
        print('Error: at least one output name is required')
        sys.exit(1)

    # Several explicit output names override any npoints argument.
    if len(args.output) > 1:
        npoints = len(args.output)
    else:
        npoints = args.frontpoints
    # getattr: callers that build the namespace by hand may predate --fmt.
    outnames = _output_names(args.output, npoints, getattr(args, 'fmt', None))

    rank = _mpi_rank(args.nompi)

    if args.cache and rank == 0:
        log.warning("Cache directory %s ignored: cache updating is not "
                    "currently supported.", args.cache)

    # Only rank 0 writes the log; every other process passes logfile=None
    # on to complement().
    logfile = None
    if args.logfile is not None and rank == 0:
        try:
            logfile = open(args.logfile, 'w')
        except IOError as e:
            print(e)
            sys.exit(-1)

    try:
        if logfile is not None:
            logfile.write("*** pyCoCo ***\n\n")

        traj = mdt.load(mdfiles, top=args.topfile)
        newtraj = complement(traj, selection=args.selection, npoints=npoints,
                             gridsize=args.grid, ndims=args.dims,
                             refine=args.regularize, logfile=logfile,
                             nskip=args.skip, rank=rank,
                             currentpoints=args.currentpoints,
                             newpoints=args.newpoints)

        if rank == 0:
            for rep, t in enumerate(newtraj):
                t.save(outnames[rep])

            if logfile is not None:
                logfile.write("\nRMSD matrix for new structures:\n")
                for i in range(npoints):
                    for j in range(npoints):
                        logfile.write("{:6.2f}".format(
                            mdt.rmsd(newtraj[i], newtraj[j])[0]))
                    logfile.write("\n")
                if args.cache:
                    logfile.write("Cache {} not updated: cache updating is "
                                  "not currently supported.\n"
                                  .format(args.cache))
    finally:
        # Close the log even if loading, CoCo or saving raised.
        if logfile is not None:
            logfile.close()
################################################################################
#                                                                              #
#                                    ENTRY POINT                               #
#                                                                              #
################################################################################

if __name__ == '__main__':
    # Build the command line interface and hand the parsed options to
    # coco_ui(); see its docstring for what each option means.
    parser=argparse.ArgumentParser()
    parser.add_argument('-g','--grid', type=int, default=10, help="Number of points along each dimension of the CoCo histogram")
    parser.add_argument('-d','--dims', type=int, default=3, help='The number of projections to consider from the input pcz file in CoCo; this will also correspond to the number of dimensions of the histogram.')
    parser.add_argument('-n','--frontpoints', type=int, default=1, help="The number of new frontier points to select through CoCo.")
    # nargs='+' (not '*'): a bare '-i' or '-o' with no file names would
    # otherwise be accepted here and only fail later inside coco_ui().
    parser.add_argument('-i','--mdfile', type=str, nargs='+', help='The MD files to process.', required=True)
    parser.add_argument('-o','--output', type=str, nargs='+', help='Basename of the pdb files that will be produced.', required=True)
    parser.add_argument('-t','--topfile', type=str, help='Topology file.', required=True)
    # action="count" without a default leaves args.verbosity as None when
    # -v is not given.
    parser.add_argument('-v','--verbosity', action="count", help="Increase output verbosity.")
    parser.add_argument('-l','--logfile', type=str, default=None, help='Optional log file.')
    parser.add_argument('-c','--cache', type=str, default=None, help='Optional cache directory.')
    parser.add_argument('-s','--selection', type=str, default='all', help='Optional atom selection string.')
    parser.add_argument('--nompi', action='store_true', help='Disables any attempt to use MPI.')
    parser.add_argument('-r', '--regularize', action='store_true', help='Regularize structures.')
    parser.add_argument('-V','--version', action='version', version=__version__)
    parser.add_argument('-f','--fmt', type=str, default=None, help='Optional output format.')
    parser.add_argument('--currentpoints', type=str, default=None,
    help='Optional file with coordinates of current points.')
    parser.add_argument('--newpoints', type=str, default=None,
    help='Optional file with coordinates of CoCo-generated points.')
    parser.add_argument('--skip', type=int, default=0, help='The number of top eigenvectors to skip over (default=0)')
    
    args=parser.parse_args()
    coco_ui(args)
