#!/usr/bin/env python 
from extasycoco._version import __version__

import argparse
import mdtraj as mdt
import numpy as np
import os.path as op

from extasycoco.mc import mmc_series
from extasycoco.refinement import procrustes, Refiner
from sklearn.decomposition import PCA

def coco_mc(args):
    '''
    The command line implementation of the CoCoMC procedure.

    Command line usage:
        -c mdfiles -o outtraj [-t topfile -d ndims -n nsamples
        -i interval -s scalefac --regularize]
    where:
        mdfiles  is a list of one or more trajectory files
        topfile  is a compatible topology file. May be omitted only when
                 the first trajectory file has a '.pdb' extension.
        outtraj  the samples, in a trajectory file format
        ndims    specifies the number of dimensions (PCs) in the CoCo mapping
                 (default=3).
        nsamples the number of samples to generate (default=1)
        interval the number of MC steps between samples (default=1)
        scalefac A scaling factor for the sampling. Useful when the model
                 is built from NMR data whose variance may be exaggerated.
        regularize specifies that generated structures should (if possible)
                 have their bond lengths and angles regularised.

    args is the parsed argument namespace carrying the attributes above.
    Nothing is returned; the generated structures are written to
    args.outtraj.

    Raises ValueError if ndims or nsamples is less than 1, if no trajectory
    files are given, or if no topology file is given and the first
    trajectory file is not a '.pdb' file.
    '''

    # Validate all the cheap-to-check inputs up front so that bad arguments
    # fail with a clear message before any trajectory data is read.
    if args.ndims < 1:
        raise ValueError('Error: dims must be > 0')
    if args.nsamples < 1:
        raise ValueError('Error: nsamples must be > 0')
    if not args.mdfiles:
        raise ValueError('Error - at least one trajectory file is required.')

    if args.topfile is None:
        # Without a separate topology, only a .pdb input is accepted here.
        if op.splitext(args.mdfiles[0])[1] != '.pdb':
            raise ValueError('Error - a topology file is required.')
        traj = mdt.load(args.mdfiles)
    else:
        traj = mdt.load(args.mdfiles, top=args.topfile)

    # Fit the frames with the project's procrustes helper, then build the
    # PCA model on the flattened coordinates (one row per frame).
    x_fitted = procrustes(traj.xyz)
    n_frames = len(x_fitted)
    pca = PCA(n_components=args.ndims)
    pca.fit(x_fitted.reshape((n_frames, -1)))

    # Per-component standard deviations, scaled by the user's factor.
    pstd = np.sqrt(pca.explained_variance_) * args.scalefac

    # Generate the MC series in PC space and map it back to Cartesian
    # coordinates, one (n_atoms, 3) structure per sample.
    # NOTE(review): the sqrt(2) factor presumably matches the scaling
    # convention of mmc_series - confirm against extasycoco.mc.
    series = mmc_series(args.nsamples, args.interval, args.ndims)
    xmc = pca.inverse_transform(series * pstd * np.sqrt(2))
    xmc = xmc.reshape((args.nsamples, -1, 3))

    if args.regularize:
        # The refiner is fitted on the input ensemble, then applied to the
        # newly generated structures.
        refiner = Refiner()
        refiner.fit(x_fitted)
        xmc = refiner.transform(xmc)

    tmc = mdt.Trajectory(xmc, traj.topology)
    tmc.save(args.outtraj)

################################################################################
#                                                                              #
#                                    ENTRY POINT                               #
#                                                                              #
################################################################################

if __name__ == '__main__':
    # Build the command line interface and hand the parsed arguments to
    # coco_mc().
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--ndims', type=int, default=3,
                        help='The number of PCs to sample.')
    parser.add_argument('-n', '--nsamples', type=int, default=1,
                        help='The number of samples to generate.')
    parser.add_argument('-i', '--interval', type=int, default=1,
                        help='The number of MC steps between samples.')
    parser.add_argument('-s', '--scalefac', type=float, default=1.0,
                        help='Scaling factor for perturbations.')
    # nargs='+' (not '*'): a bare "-c" with no file names must be rejected
    # by the parser rather than producing an empty list that fails later.
    parser.add_argument('-c', '--mdfiles', type=str, nargs='+', required=True,
                        help='The input MD files to build the PCA model.')
    parser.add_argument('-o', '--outtraj', type=str, required=True,
                        help='The output trajectory.')
    parser.add_argument('-t', '--topfile', type=str, required=False,
                        help='Topology file.')
    parser.add_argument('-r', '--regularize', action='store_true',
                        help='Regularize structures.')
    parser.add_argument('-V', '--version', action='version',
                        version=__version__)

    args = parser.parse_args()
    coco_mc(args)
