#!/usr/bin/env python
#
# launch script for TACC
# deals with both command files for parametric launcher
# and with single commands

import argparse
import os
import sys
from tempfile import *
import subprocess
import math
from datetime import datetime
from launch import config


# Per-queue limits are read from the launch configuration file. If that
# file cannot be read, fall back to a single 'default' queue whose
# 'cores', 'max-nodes' and 'max-cores' entries are all undefined (None).
config_file = config.find_config()
try:
    queues = config.load_config(config_file)
except IOError:
    queues = {'default': dict.fromkeys(('cores', 'max-nodes', 'max-cores'))}


def _read_commands(script_name):
    """Return the lines of a command file, exiting if it cannot be used."""
    try:
        with open(script_name, 'r') as f:
            script_cmds = f.readlines()
    except IOError:
        print('%s does not exist!' % script_name)
        sys.exit(1)

    if not script_cmds:
        print('%s contains no commands' % script_name)
        sys.exit(1)

    # every line of the file is treated as one command, so a blank line
    # cannot be accepted
    if any(s.strip() == '' for s in script_cmds):
        print('command file contains empty lines - please remove them first')
        sys.exit(1)
    return script_cmds


def _queue_cores(limits):
    """Return the 'cores' entry of a queue, raising if it is undefined."""
    cores = limits['cores']
    if cores is None:
        raise ValueError('Must define cores.')
    return cores


def _resolve_resources(ncmds, limits, nnodes, ntasks, tpn):
    """Fill in missing node/task counts and clamp them to the queue limits."""
    if tpn is not None:
        # user specified the number of tasks per node; get the number
        # of nodes given that, evenly splitting tasks by node
        nnodes = int(math.ceil(float(ncmds) / float(tpn)))
        ntasks = nnodes * int(tpn)
    elif nnodes is None and ntasks is None:
        # nothing is explicitly specified; use one task per core based
        # on the queue, with the minimum number of nodes given that
        cores = _queue_cores(limits)
        nnodes = int(math.ceil(float(ncmds) / cores))
        ntasks = nnodes * cores
        print("Estimated %d nodes and %d tasks" % (nnodes, ntasks))
    elif nnodes is None:
        # number of total tasks is specified, but not nodes or tpn;
        # cannot calculate tpn, so just minimize the number of nodes
        cores = _queue_cores(limits)
        ntasks = int(ntasks)
        nnodes = int(math.ceil(float(ntasks) / cores))
        print("Number of nodes not specified; estimated as %d" % nnodes)
    elif ntasks is None:
        # nodes specified, but not tasks or tpn; use one task per core
        # on every requested node
        cores = _queue_cores(limits)
        nnodes = int(nnodes)
        ntasks = int(nnodes * cores)
        print("Number of tasks not specified; estimated as %d" % ntasks)
    else:
        nnodes = int(nnodes)
        ntasks = int(ntasks)

    if limits['max-nodes'] is not None and nnodes > limits['max-nodes']:
        nnodes = limits['max-nodes']
    # the core limit caps the number of tasks, not the number of nodes
    if limits['max-cores'] is not None and ntasks > limits['max-cores']:
        ntasks = limits['max-cores']
    return nnodes, ntasks


def _submit_batch_file(qsubfilepath):
    """Run sbatch on the batch file, echo its output, and return the job id.

    Returns None when sbatch could not be run or did not print a
    'Submitted batch job' line.
    """
    jobid = None
    try:
        # an argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact
        process = subprocess.Popen(
            ['sbatch', qsubfilepath],
            stdout=subprocess.PIPE,
            encoding='utf-8',
        )
    except OSError as err:
        print('Could not run sbatch: %s' % err)
        return None

    # the context manager closes the pipe and waits for sbatch to exit
    with process:
        for line in process.stdout:
            print(line.strip())
            if line.startswith('Submitted batch job'):
                jobid = int(line.split()[3])
    return jobid


def launch_slurm(
    cmd='',
    script_name=None,
    runtime='01:00:00',
    jobname='launch',
    outfile=None,
    projname=None,
    queue='normal',
    email=None,
    qsubfile=None,
    keepqsubfile=False,
    test=False,
    compiler='intel',
    hold=None,
    cwd=None,
    nnodes=None,
    ntasks=None,
    tpn=None,
    antsproc=None,
    schedule='dynamic',
    remora=None,
):
    """Write a SLURM batch file for a command or command file and submit it.

    Either `cmd` (a single command) or `script_name` (a file with one
    command per line) must be given; `cmd` takes precedence. A single
    command, or a command file with exactly one line, is run directly
    on one node with one task. A command file with several lines is run
    through the parametric launcher ($LAUNCHER_DIR/paramrun).

    Parameters
    ----------
    cmd : str
        Command to run as a serial job.
    script_name : str or None
        Path to a file of commands, one per line, with no blank lines.
    runtime : str
        Value for '#SBATCH -t'.
    jobname : str
        Value for '#SBATCH -J'; also the prefix of the generated batch
        file name and of the default output file.
    outfile : str or None
        Value for '#SBATCH -o'; defaults to '<jobname>.o%j'.
    projname : str or None
        Value for '#SBATCH -A', written only when given.
    queue : str
        Value for '#SBATCH -p'. Also selects the entry of the
        module-level `queues` mapping used for limits; the 'default'
        entry is used when the queue is not listed.
    email : str or None
        Address for '#SBATCH --mail-user' (with --mail-type=ALL).
    qsubfile : str or None
        Path of the batch file to write. When None, a temporary
        '<jobname>_*.slurm' file is created in the current directory.
    keepqsubfile : bool
        If False, the batch file is deleted before returning.
    test : bool
        If True, the batch file is written but sbatch is not run.
    compiler : str
        If 'gcc', 'module swap intel gcc' is added to the batch file.
    hold : str, int or None
        Job id this job must wait for ('#SBATCH -d afterok:<id>').
    cwd : str or None
        Working directory; written as '#SBATCH -D' only when given.
        The current directory is used in the job body otherwise.
    nnodes, ntasks, tpn : int, str or None
        Requested nodes, total tasks, and tasks per node for parametric
        jobs. Missing values are estimated from the number of commands
        and the queue's 'cores' entry, then clamped to the queue's
        'max-nodes' and 'max-cores' limits.
    antsproc : int or None
        Value exported as ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS.
    schedule : str
        Value exported as LAUNCHER_SCHED for parametric jobs.
    remora : str or None
        If given, the remora module is loaded; for serial jobs the
        command is run under remora and its output directory is moved
        to this path.

    Returns
    -------
    int or None
        Job id parsed from the sbatch output, or None when `test` is
        True or no job id was reported.

    Raises
    ------
    ValueError
        If node or task counts must be estimated but the queue has no
        'cores' value.
    SystemExit
        With status 1 if neither a command nor a usable command file
        was supplied.
    """
    limits = queues[queue] if queue in queues else queues['default']

    if cmd:
        parametric = False
        print('Running serial command: ' + cmd)
    elif script_name is not None:
        script_cmds = _read_commands(script_name)
        ncmds = len(script_cmds)
        print('found %d commands' % ncmds)
        if ncmds == 1:
            # if only one, do not use launcher, which fails sometimes
            parametric = False
            # drop the line terminator so it does not end up inside the
            # batch file header
            cmd = script_cmds[0].strip()
            print('Running serial command: ' + cmd)
        else:
            parametric = True
            print('Submitting parametric job file: ' + script_name)
    else:
        print(
            'ERROR: you must either specify a script name (using -s) or a command to run\n\n'
        )
        sys.exit(1)

    # resolve resources before touching the file system, so that a
    # configuration error does not leave a half-written batch file
    if parametric:
        nnodes, ntasks = _resolve_resources(ncmds, limits, nnodes, ntasks, tpn)
    else:
        nnodes, ntasks = 1, 1

    lines = [
        '#!/bin/bash\n#\n',
        '# SLURM control file automatically created by launch\n#\n',
        '# Created on: {}\n'.format(datetime.now()),
    ]
    if parametric:
        lines.append(
            '# Using parametric launcher with control file: %s\n#\n' % script_name
        )
    else:
        lines.append('# Launching single command: %s\n#\n' % cmd)
    lines.append('#SBATCH -N %d\n' % nnodes)
    lines.append('#SBATCH -n %d\n' % ntasks)

    if cwd is not None:
        lines.append('#SBATCH -D %s\n' % cwd)
    lines.append('#SBATCH -J %s\n' % jobname)
    if outfile is not None:
        lines.append('#SBATCH -o {0}\n'.format(outfile))
    else:
        lines.append('#SBATCH -o {0}.o%j\n'.format(jobname))
    lines.append('#SBATCH -p %s\n' % queue)
    lines.append('#SBATCH -t %s\n' % runtime)

    # accept the job id as either a string (command line) or an int
    # (as returned by a previous call to this function)
    if hold is not None:
        lines.append('#SBATCH -d afterok:{0}\n'.format(int(hold)))

    if projname is not None:
        lines.append('#SBATCH -A {0}\n'.format(projname))

    if email is not None:
        lines.append('#SBATCH --mail-type=ALL\n')
        lines.append('#SBATCH --mail-user=%s\n' % email)

    lines.append('\numask 2\n\n')

    if cwd is None:
        cwd = os.getcwd()

    lines.extend(
        [
            'echo " Starting at $(date)"\n',
            'start=$(date +%s)\n',
            'echo " WORKING DIR: %s/"\n' % cwd,
            'echo " JOB ID:      $SLURM_JOB_ID"\n',
            'echo " JOB NAME:    $SLURM_JOB_NAME"\n',
            'echo " NODES:       $SLURM_NODELIST"\n',
            'echo " N NODES:     $SLURM_NNODES"\n',
            'echo " N TASKS:     $SLURM_NTASKS"\n',
        ]
    )

    if compiler == "gcc":
        lines.append('module swap intel gcc\n')

    if remora is not None:
        lines.append('module load remora\n')

    if antsproc is not None:
        lines.append(
            'export ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS=%d\n' % int(antsproc)
        )

    if parametric:
        lines.append('export LAUNCHER_SCHED=%s\n' % schedule)
        lines.append('export LAUNCHER_JOB_FILE=%s\n' % script_name)
        lines.append('export LAUNCHER_WORKDIR=%s\n' % cwd)
        lines.append('$LAUNCHER_DIR/paramrun\n')
    else:
        lines.append('set -x\n')
        if remora is not None:
            lines.append('remora ' + cmd + '\n')
        else:
            lines.append(cmd + '\n')
        lines.append('set +x\n')
        if remora is not None:
            lines.append('mv %s/remora_$SLURM_JOB_ID %s\n' % (cwd, remora))

    lines.append('echo " "\necho " Job complete at $(date)"\necho " "\n')
    lines.append('finish=$(date +%s)\n')
    # the first '\n' is written as a literal newline inside the quoted
    # printf format string
    lines.append(
        'printf "Job duration: %02d:%02d:%02d (%d s)\n" '
        '$(((finish-start)/3600)) $(((finish-start)%3600/60)) '
        '$(((finish-start)%60)) $((finish-start))\n'
    )

    if qsubfile is None:
        handle, qsubfilepath = mkstemp(
            prefix=jobname + "_", dir='.', suffix='.slurm', text=True
        )
        os.close(handle)
    else:
        qsubfilepath = qsubfile

    print('Outputting SLURM commands to %s' % qsubfilepath)
    with open(qsubfilepath, 'w') as batch:
        batch.writelines(lines)

    jobid = None
    if not test:
        jobid = _submit_batch_file(qsubfilepath)

    if not keepqsubfile:
        print('Deleting qsubfile: %s' % qsubfilepath)
        os.remove(qsubfilepath)
    return jobid


if __name__ == "__main__":
    # Command-line entry point: recognised options are forwarded to
    # launch_slurm; any remaining arguments form the command to run.

    # (flags, add_argument keyword settings), in the order they are
    # registered with the parser
    option_table = [
        (
            ('-N', '--nodes'),
            {'help': 'minimum number of nodes', 'dest': 'nodes', 'default': None},
        ),
        (
            ('-n', '--ntasks'),
            {'help': 'number of tasks to run', 'dest': 'ntasks', 'default': None},
        ),
        (
            ('-e', '--tasks-per-node'),
            {'help': 'number of tasks per node', 'dest': 'tpn', 'default': None},
        ),
        (
            ('-s', '--script'),
            {'help': 'name of parallel script to run', 'dest': 'script_name'},
        ),
        (
            ('-r', '--runtime'),
            {
                'help': 'maximum runtime for job',
                'default': '01:00:00',
                'dest': 'runtime',
            },
        ),
        (
            ('-J', '--jobname'),
            {'help': 'job name', 'default': 'launch', 'dest': 'jobname'},
        ),
        (
            ('-o', '--outfile'),
            {'help': 'output file', 'default': None, 'dest': 'outfile'},
        ),
        (
            ('-p', '-q', '--queue'),
            {'help': 'name of queue', 'default': 'normal', 'dest': 'queue'},
        ),
        (
            ('-A', '--projname'),
            {'help': 'name of project', 'dest': 'projname', 'default': 'ANTS'},
        ),
        (
            ('-m', '--email'),
            {'help': 'email address for notification', 'dest': 'email'},
        ),
        (
            ('-D', '--cwd'),
            {'help': 'name of working directory', 'dest': 'directory'},
        ),
        (
            ('-f', '--qsubfile'),
            {'help': 'name of batch file', 'dest': 'qsubfile'},
        ),
        (
            ('-w', '--waitproc'),
            {'help': 'process to wait for', 'dest': 'waitproc'},
        ),
        (
            ('-k', '--keepqsubfile'),
            {
                'help': 'keep qsub file',
                'dest': 'keepqsubfile',
                'action': "store_true",
                'default': False,
            },
        ),
        (
            ('-t', '--test'),
            {
                'help': 'do not actually launch job',
                'dest': 'test',
                'action': "store_true",
                'default': False,
            },
        ),
        (
            ('-c', '--compiler'),
            {
                'help': 'compiler (default=intel)',
                'dest': 'compiler',
                'default': 'intel',
            },
        ),
        (
            ('-a', '--antsproc'),
            {
                'help': 'number of processes for ANTS',
                'dest': 'antsproc',
                'type': int,
            },
        ),
        (
            ('-x', '--remora'),
            {'help': 'directory to save resource usage info using remora'},
        ),
        (
            ('-d', '-i', '--hold_jid'),
            {
                'help': 'wait for this job id to complete before running',
                'dest': 'hold',
                'default': None,
            },
        ),
        (
            ('-b', '--schedule'),
            {
                'default': 'interleaved',
                'help': "schedule type (default: interleaved)",
            },
        ),
    ]

    parser = argparse.ArgumentParser(description='process SLURM job.')
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    # unrecognised arguments are collected separately and treated as the
    # command to run
    args, command = parser.parse_known_args(sys.argv[1:])

    launch_slurm(
        # joining an empty list yields '', i.e. "no command given"
        cmd=' '.join(command),
        script_name=args.script_name,
        runtime=args.runtime,
        jobname=args.jobname,
        outfile=args.outfile,
        projname=args.projname,
        queue=args.queue,
        email=args.email,
        qsubfile=args.qsubfile,
        keepqsubfile=args.keepqsubfile,
        test=args.test,
        compiler=args.compiler,
        hold=args.hold,
        cwd=args.directory,
        nnodes=args.nodes,
        ntasks=args.ntasks,
        tpn=args.tpn,
        antsproc=args.antsproc,
        schedule=args.schedule,
        remora=args.remora,
    )
