#!/usr/bin/env python3
#
# glimps-train: create "transformers" that can convert between fine-grained
#               and coarse-grained representations of molecular components
#               from a matched pair of trajectory files and an additional
#               molecule specification file.
#
#               Invoke as:
#
#                   glimps-train <cg_traj> <fg_traj> <molspecs>
#
#               where <cg_traj> and <fg_traj> are in any MDTraj-supported
#               format and <molspecs> is a molecule specification file.
#               (Use "glimps-train -h" to see additional options)
#
#               The molecule specification file should contain rows with the
#               format:
#
#                 <molid> <n_copies> <n_fg> <n_cg>
#
#               where:
#                 <molid>:    name of component (your choice)
#                 <n_copies>: copies of this component in the trajectories
#                 <n_fg>:     number of particles per copy, fine-grained form
#                 <n_cg>:     number of particles per copy, coarse-grained form
#
#               The rows must appear in the same order as the corresponding
#               particles in the trajectory files. Lines beginning with "#"
#               are treated as comments and ignored.
#
import argparse
import json
import pickle
import os
import os.path as op
try:
    import mdtraj as mdt
except ImportError:
    print('You need to install MDTraj ("pip install mdtraj") to use glimps-train')
    exit(1)
import numpy as np
from mdplus.multiscale import Glimps

def _load_trajectory(path, top):
    """Load a trajectory with MDTraj, using an explicit topology file if given."""
    if top:
        return mdt.load(path, top=top)
    return mdt.load(path)


def _read_molspecs(path):
    """Parse a molecule specification file.

    Returns a list of ``(molid, spec)`` tuples in file order, where ``spec`` is
    a dict with integer ``n_copies``, ``n_fg`` and ``n_cg`` entries. A list is
    used (rather than a dict keyed by molid) so that a molid appearing on more
    than one row still consumes its own slice of the trajectories and the
    particle offsets of later rows stay correct.

    Blank lines and lines whose first non-blank character is ``#`` are
    skipped. Any other line without exactly four fields is skipped with a
    warning, because dropping a row silently would misalign every later row.
    """
    specs = []
    with open(path) as f:
        for lineno, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped or stripped.startswith('#'):
                continue
            words = stripped.split()
            if len(words) != 4:
                print('Warning: ignoring malformed line {} in {}: {!r}'.format(
                    lineno, path, stripped))
                continue
            specs.append((words[0], {
                "n_copies": int(words[1]),
                "n_fg": int(words[2]),
                "n_cg": int(words[3]),
            }))
    return specs


def _train_molecule(cgtraj, fgtraj, i_fg, i_cg, spec):
    """Fit a Glimps transformer for one molecule type.

    The ``n_copies`` consecutive copies starting at particle ``i_fg`` (FG) and
    ``i_cg`` (CG) are pooled over all frames into a single training set. The
    return value is a dict holding the fitted transformer and the MDTraj atom
    tables (pandas DataFrames) of one copy in FG and CG form.
    """
    n_copies = spec["n_copies"]
    n_fg = spec["n_fg"]
    n_cg = spec["n_cg"]
    n_samples = cgtraj.n_frames * n_copies
    # Each (frame, copy) pair becomes one training sample.
    x_fg = fgtraj.xyz[:, i_fg:i_fg + n_fg * n_copies].reshape((n_samples, n_fg, 3))
    x_cg = cgtraj.xyz[:, i_cg:i_cg + n_cg * n_copies].reshape((n_samples, n_cg, 3))
    # Cap the training set at 6 samples per FG particle, picked evenly
    # spaced across all available samples.
    max_samples = 6 * n_fg
    if n_samples > max_samples:
        indices = np.linspace(0, n_samples - 1, max_samples, dtype=np.int32)
        x_cg = x_cg[indices]
        x_fg = x_fg[indices]
    transformer = Glimps(pca=False)
    transformer.fit(x_cg, x_fg)
    # Topologies of the first copy only: all copies share the same layout.
    fg_topology, _ = fgtraj.topology.subset(range(i_fg, i_fg + n_fg)).to_dataframe()
    cg_topology, _ = cgtraj.topology.subset(range(i_cg, i_cg + n_cg)).to_dataframe()
    return {
        "transformer": transformer,
        "fg_topology": fg_topology,
        "cg_topology": cg_topology,
    }


def _write_transformer_file(data, outfile, use_pickle):
    """Serialise ``data`` and write it to ``outfile``.

    Serialisation happens before the file is opened so that a failure cannot
    leave a partial file behind: ``fit`` skips molecules whose output file
    already exists, so a truncated file would silently block a retry.
    """
    if use_pickle:
        payload = pickle.dumps(data)
        mode = 'wb'
    else:
        payload = json.dumps({
            'transformer': data['transformer'].get_state(),
            'fg_topology': data['fg_topology'].to_dict(),
            'cg_topology': data['cg_topology'].to_dict(),
        })
        mode = 'w'
    with open(outfile, mode) as f:
        f.write(payload)


def fit(args):
    """Train and save one Glimps transformer per molecule type.

    ``args`` is the parsed command line: ``cgtraj``/``fgtraj`` (trajectory
    paths), optional ``cgtop``/``fgtop`` (topology paths), ``molspec`` (the
    molecule specification file), ``datadir`` (output directory) and
    ``pickle`` (write ``<molid>.pkl`` instead of ``<molid>.json``).
    Molecules whose output file already exists in ``datadir`` are skipped.

    Raises:
        ValueError: if ``datadir`` exists but is not a directory, if the two
            trajectories have different numbers of frames, or if the molecule
            specification describes more particles than a trajectory holds.
    """
    # Validate the cheap inputs before loading potentially large trajectories.
    if op.isfile(args.datadir):
        raise ValueError('Error: {} exists but is not a directory'.format(args.datadir))
    molspecs = _read_molspecs(args.molspec)

    cgtraj = _load_trajectory(args.cgtraj, args.cgtop)
    fgtraj = _load_trajectory(args.fgtraj, args.fgtop)
    if cgtraj.n_frames != fgtraj.n_frames:
        raise ValueError('Error: CG and FG trajectories contain different numbers of snapshots')

    # Slicing past the end of xyz would silently return too few particles and
    # fail later inside reshape with an obscure message, so check up front.
    needed_fg = sum(spec["n_fg"] * spec["n_copies"] for _, spec in molspecs)
    needed_cg = sum(spec["n_cg"] * spec["n_copies"] for _, spec in molspecs)
    if needed_fg > fgtraj.n_atoms:
        raise ValueError('Error: molecule specification needs {} FG particles but the FG trajectory has {}'.format(
            needed_fg, fgtraj.n_atoms))
    if needed_cg > cgtraj.n_atoms:
        raise ValueError('Error: molecule specification needs {} CG particles but the CG trajectory has {}'.format(
            needed_cg, cgtraj.n_atoms))

    if args.datadir:
        os.makedirs(args.datadir, exist_ok=True)

    suffix = '.pkl' if args.pickle else '.json'
    i_fg = 0
    i_cg = 0
    for molid, spec in molspecs:
        outfile = op.join(args.datadir, molid + suffix)
        if op.exists(outfile):
            print('Warning: skipping molecule {} as already in data directory'.format(molid))
        else:
            print('Creating transformer for molecule {}'.format(molid))
            data = _train_molecule(cgtraj, fgtraj, i_fg, i_cg, spec)
            _write_transformer_file(data, outfile, args.pickle)
        # Advance past every copy of this molecule, even when it was skipped.
        i_fg += spec["n_fg"] * spec["n_copies"]
        i_cg += spec["n_cg"] * spec["n_copies"]
     
def build_parser():
    """Return the command-line parser for glimps-train.

    ``--pickle`` is a boolean flag. The original definition expected a value,
    so a bare ``--pickle`` was rejected by argparse.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('cgtraj', help='CG trajectory file')
    parser.add_argument('fgtraj', help='FG trajectory file')
    parser.add_argument('molspec', help='molecule specification file')
    parser.add_argument('--cgtop', help='CG topology file')
    parser.add_argument('--fgtop', help='FG topology file')
    parser.add_argument('--datadir', default='.', help='location of transformer files')
    parser.add_argument('--pickle', action='store_true',
                        help='Save in legacy pickle format rather than JSON')
    return parser


if __name__ == '__main__':
    fit(build_parser().parse_args())
