#!/usr/bin/env python
#
# launch script for TACC
# deals with both command files for parametric launcher
# and with single commands

import argparse
import logging
import os
import sys
from tempfile import *
import subprocess
import math
from datetime import datetime
from launch import config


# Build the module-level ``queues`` table mapping a queue name to its
# settings dict (keys: 'account', 'cores', 'max-nodes', 'max-cores').
# How the config file is located and parsed is up to the project-local
# ``launch.config`` module, which is not shown here.
config_file = config.find_config()
# Baseline settings for every queue; None means "not configured".
defaults = {'account': None, 'cores': None, 'max-nodes': None, 'max-cores': None}
try:
    queues = config.load_config(config_file)
    # NOTE(review): this assignment replaces any 'default' entry that
    # load_config may have returned - confirm that is intended.
    queues['default'] = defaults
    # Rebuild each entry in increasing priority: defaults first, then the
    # optional 'global' entry, then the queue's own settings. Only existing
    # keys are reassigned, so the mapping does not change size while it is
    # being iterated.
    for queue, settings in queues.items():
        queues[queue] = defaults.copy()
        if 'global' in queues:
            queues[queue].update(queues['global'])
        queues[queue].update(settings)
except IOError as err:
    # Any IOError raised in the try body (presumably a missing or unreadable
    # config file - verify against launch.config) falls back to a single
    # unconfigured 'default' queue.
    queues = {'default': defaults}


def _read_commands(script_name, logger):
    """Return the commands listed in *script_name*, one per line.

    Line endings are stripped. Logs an error and exits with status 1 when
    the file cannot be opened, contains no commands, or contains blank lines.
    """
    try:
        with open(script_name, 'r') as script:
            script_commands = [line.rstrip('\n') for line in script]
    except OSError:
        logger.error(f'Script does not exist: {script_name}')
        sys.exit(1)

    if not script_commands:
        logger.error(f'Command file is empty: {script_name}')
        sys.exit(1)

    if any(s.strip() == '' for s in script_commands):
        logger.error('Command file contains empty lines - please remove them')
        sys.exit(1)

    return script_commands


def _queue_cores(nodes):
    """Return the queue's cores-per-node setting, or raise ValueError if unset."""
    cores = nodes.get('cores')
    if cores is None:
        raise ValueError('Must define cores.')
    return int(cores)


def _resolve_resources(n_commands, nodes, n_nodes, n_tasks, tpn, logger):
    """Fill in the node and task counts for a parametric job.

    Whatever the caller left as None is derived from the number of commands
    and the queue settings in *nodes*; the queue's 'max-nodes' and
    'max-cores' limits are then applied. Returns ``(n_nodes, n_tasks)`` as
    ints.
    """
    if tpn is not None:
        # Tasks per node given: split the commands evenly across nodes.
        n_nodes = int(math.ceil(n_commands / float(tpn)))
        n_tasks = n_nodes * int(tpn)
    elif n_nodes is None and n_tasks is None:
        # Nothing given: one task per core on the minimum number of nodes.
        cores = _queue_cores(nodes)
        n_nodes = int(math.ceil(n_commands / cores))
        n_tasks = n_nodes * cores
        logger.info(f'Estimated {n_nodes} nodes and {n_tasks} tasks')
    elif n_nodes is None:
        # Only the task total given: use the minimum number of nodes.
        cores = _queue_cores(nodes)
        n_tasks = int(n_tasks)
        n_nodes = int(math.ceil(n_tasks / cores))
        logger.info(f'Number of nodes not specified; estimated as {n_nodes}')
    elif n_tasks is None:
        # Only the node count given: fill every core on those nodes.
        cores = _queue_cores(nodes)
        n_nodes = int(n_nodes)
        n_tasks = n_nodes * cores
        logger.info(f'Number of tasks not specified; estimated as {n_tasks}')
    else:
        n_nodes = int(n_nodes)
        n_tasks = int(n_tasks)

    max_nodes = nodes.get('max-nodes')
    if max_nodes is not None and n_nodes > int(max_nodes):
        n_nodes = int(max_nodes)
    # The core limit applies to the task count, not to the node count.
    max_cores = nodes.get('max-cores')
    if max_cores is not None and n_tasks > int(max_cores):
        n_tasks = int(max_cores)

    return n_nodes, n_tasks


def _submit(qsub_file_path, logger):
    """Run ``sbatch`` on the batch file and return the job id, or None.

    Each line of sbatch's standard output is logged. The job id is parsed
    from a line starting with 'Submitted batch job'.
    """
    try:
        # Argument list (no shell), so unusual characters in the file name
        # cannot be interpreted by a shell.
        process = subprocess.run(
            ['sbatch', qsub_file_path],
            stdout=subprocess.PIPE,
            encoding='utf-8',
        )
    except OSError as err:
        logger.error(f'Could not run sbatch: {err}')
        return None

    job_id = None
    for line in process.stdout.splitlines():
        logger.info(line.strip())
        if line.startswith('Submitted batch job'):
            job_id = int(line.split()[3])
    return job_id


def launch_slurm(
    command='',
    script_name=None,
    runtime='01:00:00',
    job_name='launch',
    out_file=None,
    account=None,
    queue_name='normal',
    email=None,
    qsub_file=None,
    keep_qsub_file=False,
    test=False,
    compiler='intel',
    hold=None,
    cwd=None,
    n_nodes=None,
    n_tasks=None,
    tpn=None,
    n_ants_proc=None,
    schedule='dynamic',
    remora=None,
):
    """Write a SLURM batch file for a command or command file and submit it.

    A non-empty *command* is run as a single serial job on one node. Otherwise
    *script_name* must name a file with one command per line; a one-line file
    is run as a single serial job, and a longer file is run through the
    parametric launcher (``$LAUNCHER_DIR/paramrun``).

    Parameters
    ----------
    command : str
        Single command to run; takes precedence over *script_name*.
    script_name : str or None
        Path to a file of commands, one per line, without blank lines.
    runtime : str
        Value for ``#SBATCH -t``.
    job_name : str
        Value for ``#SBATCH -J``; also the prefix of the generated batch file
        and of the default output file name.
    out_file : str or None
        Value for ``#SBATCH -o``; defaults to ``<job_name>.o%j``.
    account : str or None
        Value for ``#SBATCH -A``; defaults to the queue's 'account' setting.
    queue_name : str
        Value for ``#SBATCH -p``; also selects the entry of the module-level
        ``queues`` table ('default' is used for unknown names).
    email : str or None
        If given, request mail for all job events to this address.
    qsub_file : str or None
        Path of the batch file to write; a temporary ``.slurm`` file in the
        current directory is created when None.
    keep_qsub_file : bool
        If false, the batch file is deleted before returning.
    test : bool
        If true, write the batch file but do not call ``sbatch``.
    compiler : str
        If 'gcc', the batch file runs ``module swap intel gcc``.
    hold : str, int or None
        Job id this job depends on (``#SBATCH -d afterok:<hold>``).
    cwd : str or None
        Working directory; written as ``#SBATCH -D`` when given, otherwise
        the current directory is used for the in-script working directory.
    n_nodes, n_tasks, tpn : int, str or None
        Node count, task count and tasks per node for parametric jobs; any
        left as None are derived from the number of commands and the queue's
        'cores' setting.
    n_ants_proc : int or None
        If given, exported as ``ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS``.
    schedule : str
        Exported as ``LAUNCHER_SCHED`` for parametric jobs.
    remora : str or None
        If given, load the remora module; single commands are additionally
        run under ``remora`` and the results moved to this directory.

    Returns
    -------
    int or None
        The job id reported by ``sbatch``, or None if nothing was submitted
        or no id could be parsed.

    Raises
    ------
    SystemExit
        With status 1 if neither a command nor a usable command file is given.
    ValueError
        If a node or task count must be derived but the queue has no 'cores'.
    """
    logger = logging.getLogger()
    nodes = queues[queue_name] if queue_name in queues else queues['default']

    if command:
        parametric = False
        logger.info(f'Running serial command: {command}')
    elif script_name is not None:
        script_commands = _read_commands(script_name, logger)
        n_commands = len(script_commands)
        logger.info(f'Found {n_commands} commands')
        if n_commands == 1:
            # if only one, do not use launcher, which fails sometimes
            parametric = False
            command = script_commands[0].strip()
            logger.info(f'Running serial command: {command}')
        else:
            parametric = True
            logger.info(f'Submitting parametric job file: {script_name}')
    else:
        logger.error(
            'You must either specify a script name (using -s) or a command to run'
        )
        sys.exit(1)

    # Resolve resources before touching the file system, so a configuration
    # error does not leave a half-written batch file behind.
    if parametric:
        n_nodes, n_tasks = _resolve_resources(
            n_commands, nodes, n_nodes, n_tasks, tpn, logger
        )

    lines = [
        '#!/bin/bash',
        '#',
        '# SLURM control file automatically created by launch',
        '#',
        f'# Created on: {datetime.now()}',
    ]
    if parametric:
        lines += [
            f'# Using parametric launcher with control file: {script_name}',
            '#',
            f'#SBATCH -N {n_nodes:d}',
            f'#SBATCH -n {n_tasks:d}',
        ]
    else:
        lines += [
            f'# Launching single command: {command}',
            '#',
            '#SBATCH -N 1',
            '#SBATCH -n 1',
        ]

    if cwd is not None:
        lines.append(f'#SBATCH -D {cwd}')
    lines.append(f'#SBATCH -J {job_name}')
    if out_file is not None:
        lines.append(f'#SBATCH -o {out_file}')
    else:
        lines.append(f'#SBATCH -o {job_name}.o%j')
    lines.append(f'#SBATCH -p {queue_name}')
    lines.append(f'#SBATCH -t {runtime}')

    # Accept the job id as either a string (command line) or an int.
    if hold not in (None, ''):
        lines.append(f'#SBATCH -d afterok:{hold}')

    account = account if account is not None else nodes.get('account')
    if account is not None:
        lines.append(f'#SBATCH -A {account}')

    if email is not None:
        lines.append('#SBATCH --mail-type=ALL')
        lines.append(f'#SBATCH --mail-user={email}')

    lines += ['', 'umask 2', '']

    if cwd is None:
        cwd = os.getcwd()

    lines += [
        'echo " Starting at $(date)"',
        'start=$(date +%s)',
        f'echo " WORKING DIR: {cwd}/"',
        'echo " JOB ID:      $SLURM_JOB_ID"',
        'echo " JOB NAME:    $SLURM_JOB_NAME"',
        'echo " NODES:       $SLURM_NODELIST"',
        'echo " N NODES:     $SLURM_NNODES"',
        'echo " N TASKS:     $SLURM_NTASKS"',
    ]

    if compiler == "gcc":
        lines.append('module swap intel gcc')

    if remora is not None:
        lines.append('module load remora')

    if n_ants_proc is not None:
        lines.append(f'export ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS={n_ants_proc:d}')

    if not parametric:
        lines.append('set -x')
        if remora is not None:
            lines.append(f'remora {command}')
        else:
            lines.append(f'{command}')
        lines.append('set +x')
        if remora is not None:
            lines.append(f'mv {cwd}/remora_$SLURM_JOB_ID {remora}')
    else:
        lines += [
            f'export LAUNCHER_SCHED={schedule}',
            f'export LAUNCHER_JOB_FILE={script_name}',
            f'export LAUNCHER_WORKDIR={cwd}',
            '$LAUNCHER_DIR/paramrun',
        ]

    lines += [
        'echo " "',
        'echo " Job complete at $(date)"',
        'echo " "',
        'finish=$(date +%s)',
        # "\\n" keeps printf's newline escape on one line of the script.
        'printf "Job duration: %02d:%02d:%02d (%d s)\\n" $(((finish-start)/3600)) $(((finish-start)%3600/60)) $(((finish-start)%60)) $((finish-start))',
    ]

    if qsub_file is None:
        handle, qsub_file_path = mkstemp(
            prefix=job_name + "_", dir='.', suffix='.slurm', text=True
        )
        os.close(handle)
    else:
        qsub_file_path = qsub_file

    logger.info(f'Outputting SLURM commands to {qsub_file_path}')
    with open(qsub_file_path, 'w') as batch_file:
        batch_file.write('\n'.join(lines) + '\n')

    job_id = None
    try:
        if not test:
            job_id = _submit(qsub_file_path, logger)
    finally:
        # Clean up the batch file even if submission raised.
        if not keep_qsub_file:
            logger.info(f'Deleting qsubfile: {qsub_file_path}')
            os.remove(qsub_file_path)
    return job_id


if __name__ == "__main__":
    # Command-line entry point: parse the launcher options, treat every
    # argument the parser does not recognise as the command to run, and hand
    # everything to launch_slurm.

    parser = argparse.ArgumentParser(description='process SLURM job.')
    parser.add_argument(
        '-N', '--nodes', help='minimum number of nodes', dest='nodes', default=None
    )
    parser.add_argument(
        '-n', '--ntasks', help='number of tasks to run', dest='ntasks', default=None
    )
    parser.add_argument(
        '-e',
        '--tasks-per-node',
        help='number of tasks per node',
        dest='tpn',
        default=None,
    )
    parser.add_argument(
        '-s', '--script', help='name of parallel script to run', dest='script_name'
    )
    parser.add_argument(
        '-r',
        '--runtime',
        help='maximum runtime for job',
        default='01:00:00',
        dest='runtime',
    )
    parser.add_argument(
        '-J', '--jobname', help='job name', default='launch', dest='jobname'
    )
    parser.add_argument(
        '-o', '--outfile', help='output file', default=None, dest='outfile'
    )
    parser.add_argument(
        '-p', '-q', '--queue', help='name of queue', default='normal', dest='queue'
    )
    parser.add_argument(
        '-A', '--projname', help='name of project', dest='projname', default=None
    )
    parser.add_argument(
        '-m', '--email', help='email address for notification', dest='email'
    )
    parser.add_argument(
        '-D', '--cwd', help='name of working directory', dest='directory'
    )
    parser.add_argument('-f', '--qsubfile', help='name of batch file', dest='qsubfile')
    parser.add_argument('-w', '--waitproc', help='process to wait for', dest='waitproc')
    parser.add_argument(
        '-k',
        '--keepqsubfile',
        help='keep qsub file',
        dest='keepqsubfile',
        action="store_true",
        default=False,
    )
    parser.add_argument(
        '-t',
        '--test',
        help='do not actually launch job',
        dest='test',
        action="store_true",
        default=False,
    )
    parser.add_argument(
        '-c',
        '--compiler',
        help='compiler (default=intel)',
        dest='compiler',
        default='intel',
    )
    parser.add_argument(
        '-a',
        '--antsproc',
        help='number of processes for ANTS',
        dest='antsproc',
        type=int,
    )
    parser.add_argument(
        '-x', '--remora', help='directory to save resource usage info using remora'
    )
    parser.add_argument(
        '-d',
        '-i',
        '--hold_jid',
        help='wait for this job id to complete before running',
        dest='hold',
        default=None,
    )
    parser.add_argument(
        '-b',
        '--schedule',
        default='interleaved',
        help="schedule type (default: interleaved)",
    )
    parser.add_argument(
        '-v',
        '--verbosity',
        default=2,
        type=int,
        help='Verbosity level (0 - errors only, 1 - warnings, 2 - information)'
    )

    # Unrecognised arguments are collected rather than rejected: together
    # they form the single command to launch.
    (args, command_args) = parser.parse_known_args(sys.argv[1:])

    if len(command_args) > 0:
        cmd = ' '.join(command_args)
    else:
        cmd = ''

    levels = {0: logging.ERROR, 1: logging.WARN, 2: logging.INFO}
    # Clamp to the supported range so that e.g. "-v 3" means "most verbose"
    # instead of failing with a KeyError.
    verbosity = max(0, min(args.verbosity, 2))
    logging.basicConfig(stream=sys.stdout, level=levels[verbosity])

    launch_slurm(
        command=cmd,
        script_name=args.script_name,
        runtime=args.runtime,
        job_name=args.jobname,
        out_file=args.outfile,
        account=args.projname,
        queue_name=args.queue,
        email=args.email,
        qsub_file=args.qsubfile,
        keep_qsub_file=args.keepqsubfile,
        test=args.test,
        compiler=args.compiler,
        hold=args.hold,
        cwd=args.directory,
        n_nodes=args.nodes,
        n_tasks=args.ntasks,
        tpn=args.tpn,
        n_ants_proc=args.antsproc,
        schedule=args.schedule,
        remora=args.remora,
    )
