#!/usr/bin/env python
import os
import time
import multiprocessing
import dddm
import wimprates as wr
import numpy as np
import argparse
import json

# Refuse to continue when the installed wimprates reports version 0.2.2
# (the reason for excluding that release is not recorded in this file).
assert wr.__version__ != '0.2.2'


def main():
    """Run a combined-target DDDM fit configured from the command line.

    Steps: parse the arguments, optionally update the dddm context from a
    JSON file, build a ``dddm.CombinedInference`` for the requested
    sub-experiments, configure it via ``set_config``, run the selected
    sampler and save the results.

    Raises:
        ValueError: if ``-sub_experiments`` was not given, or if the
            configured sampler is neither 'multinest' nor 'nestle'.
        ModuleNotFoundError: if the multinest sampler is requested but
            mpi4py cannot be imported (see ``get_multiprocessing_rank``).
        AssertionError: if the inference object does not report
            ``did_run`` after the sampler returned.
    """
    print("run_dddm_multinest.py::\tstart")
    args = parse_arguments()

    # argparse leaves ``-sub_experiments`` as None when the flag is absent;
    # fail with a clear message instead of ``tuple(None)`` raising TypeError.
    if args.sub_experiments is None:
        raise ValueError('No sub experiments specified, use -sub_experiments '
                         'to list the targets to combine')

    # The MPI rank is only queried for multinest; keep the name bound for
    # every sampler so the save condition below can never hit a NameError.
    rank = None
    if args.sampler == 'multinest':
        rank = get_multiprocessing_rank()
    # NOTE(review): unconditional pause; presumably lets parallel processes
    # start up -- confirm whether it is still needed.
    time.sleep(5)

    print(f"run_dddm_multinest.py::\tstart for mw = {args.mw}, log-cross-section = "
          f"{args.cross_section}. Fitting {args.nparams} parameters")
    set_context(args)

    if args.target not in dddm.experiment:
        # Add a placeholder
        dddm.experiment[args.target] = {'type': 'combined'}

    stats = dddm.CombinedInference(
        targets=tuple(args.sub_experiments),
        detector_name=args.target,
        verbose=args.verbose)

    stats = set_config(stats, args)
    log = stats.log

    if args.multicore_hash != "":
        stats.get_save_dir(_hash=args.multicore_hash)

    sampler = stats.config['sampler']
    if sampler == 'multinest':
        stats.run_multinest()
    elif sampler == 'nestle':
        stats.run_nestle()
    else:
        # Previously an unsupported sampler skipped the fit silently and
        # still went on to the save step.
        raise ValueError(f"Unknown sampler {sampler!r}, "
                         f"choose 'multinest' or 'nestle'")

    # Without a multicore hash this process always saves; with one, only
    # rank 0 of a multinest run does.
    if args.multicore_hash == "" or (args.sampler == 'multinest' and rank == 0):
        log.info('saving the results')
        stats.save_results()
        stats.save_sub_configs()
    assert stats.log_dict['did_run']

    log.info(f"run_dddm_multinest.py::\t"
             f"finished for mw = {args.mw}, sigma = {args.cross_section}")
    print("finished, bye bye")


def set_context(args):
    """Merge a JSON-defined context into ``dddm.context.context``.

    Does nothing when ``args.context_from_json`` is None; otherwise the file
    at that path is loaded with ``open_json`` and its content is passed to
    the ``update`` method of ``dddm.context.context``.
    """
    json_path = args.context_from_json
    if json_path is None:
        return
    dddm.context.context.update(open_json(json_path))


def get_multiprocessing_rank():
    """Return the MPI rank of this process and print some process info.

    MPI functionality is only used for multinest.

    Raises:
        ModuleNotFoundError: if mpi4py cannot be imported.
    """
    try:
        from mpi4py import MPI
        mpi_rank = MPI.COMM_WORLD.Get_rank()
    except ImportError as import_error:
        # ModuleNotFoundError is a subclass of ImportError, so this single
        # clause handles both.
        message = ('Cannot run in multicore mode as'
                   ' mpi4py is not installed')
        raise ModuleNotFoundError(message) from import_error

    info_fields = (
        "MPI-info",
        f"n_cores: {multiprocessing.cpu_count()}",
        f"pid: {os.getpid()}",
        f"ppid: {os.getppid()}",
        f"rank{mpi_rank}",
    )
    print("\t".join(info_fields))
    return mpi_rank


def set_config(model_simulator, args):
    """Apply the command-line settings to ``model_simulator`` and return it.

    The priors named by ``args.priors_from`` are loaded first. Unless that
    name contains 'fixed', the ``log_mass`` and ``log_cross_section`` priors
    are replaced by flat priors whose ranges are built from the integer part
    of log10(``args.mw``) and of ``args.cross_section``. For both priors the
    'param' entry is set to the 'range' entry. Finally the collected settings
    (including the halo model chosen by ``args.shielding``) are written into
    ``model_simulator.config`` and handed to ``copy_config``.
    """
    log_mw = np.log10(args.mw)
    settings = {
        'mw': log_mw,
        'sigma': args.cross_section,
        'sampler': args.sampler,
        'poisson': args.poisson,
        'notes': args.notes,
        'earth_shielding': args.shielding,
        'n_energy_bins': args.bins,
        'fit_parameters': model_simulator.known_parameters[:args.nparams],
        'nlive': args.nlive,
        'tol': args.tol,
    }

    model_simulator.set_prior(args.priors_from)
    priors = model_simulator.config['prior']
    if 'fixed' not in args.priors_from:
        # int() truncates towards zero, exactly like the per-bound calls.
        mass_centre = int(log_mw)
        priors['log_mass'] = {
            'range': [mass_centre - 2.5, mass_centre + 3.5],
            'prior_type': 'flat'}
        sigma_centre = int(args.cross_section)
        priors['log_cross_section'] = {
            'range': [sigma_centre - 7, sigma_centre + 5],
            'prior_type': 'flat'}
    for prior_name in ('log_mass', 'log_cross_section'):
        priors[prior_name]['param'] = priors[prior_name]['range']

    halo_class = dddm.VerneSHM if args.shielding else dddm.SHM
    settings['halo_model'] = halo_class()
    model_simulator.config.update(settings)
    # 'halo_model' is listed twice on purpose: once from the settings keys
    # and once explicitly, matching what has always been passed on.
    model_simulator.copy_config([*settings, 'prior', 'halo_model'])
    return model_simulator


def parse_arguments():
    """Build the command-line parser and parse ``sys.argv``.

    The numeric options ``-mw`` and ``-cross_section`` are converted with the
    builtin ``float``. The formerly used ``np.float`` was only an alias of it
    and no longer exists in NumPy >= 1.24, where merely building the parser
    raised AttributeError.

    Returns:
        argparse.Namespace with the attributes ``sampler``, ``mw``,
        ``context_from_json``, ``cross_section``, ``nlive``, ``tol``,
        ``notes``, ``bins``, ``nparams``, ``priors_from``, ``verbose``,
        ``multicore_hash``, ``target``, ``sub_experiments`` (None when the
        flag is not given), ``poisson`` and ``shielding``.
    """
    parser = argparse.ArgumentParser(description="DDDM run fit for combined targets")
    parser.add_argument(
        '-sampler',
        type=str,
        default='multinest',
        help="sampler (multinest or nestle)")
    parser.add_argument('-mw',
                        type=float,
                        default=50.,
                        help="wimp mass")
    parser.add_argument(
        '-context_from_json',
        type=str,
        default=None,
        help="Get the context from a json file")
    parser.add_argument(
        '-cross_section',
        type=float,
        default=-45,
        help="wimp cross-section")
    parser.add_argument(
        '-nlive',
        type=int,
        default=1024,
        help="live points used by multinest")
    parser.add_argument(
        '-tol',
        type=float,
        default=0.1,
        help="tolerance for optimisation (see multinest option dlogz)")
    parser.add_argument(
        '-notes',
        type=str,
        default="default",
        help="notes on particular settings")
    parser.add_argument(
        '-bins',
        type=int,
        default=10,
        help="the number of energy bins")
    parser.add_argument(
        '-nparams',
        type=int,
        default=2,
        help="Number of parameters to fit")
    parser.add_argument(
        '-priors_from',
        type=str,
        default="Pato_2010",
        help="Obtain priors from paper <priors_from>")
    parser.add_argument(
        '-verbose',
        type=float,
        default=0,
        help="Set to 0 (no print statements), 1 (some print statements) "
             "or >1 (a lot of print statements). Set the level of print "
             "statements while fitting.")
    # NOTE(review): this help text looks copied from another option; the
    # value is used as the ``_hash`` of the save directory -- confirm and
    # reword.
    parser.add_argument(
        '-multicore_hash',
        type=str,
        default="",
        help="no / default, override internal determination if we need "
             "to take into account earth shielding.")

    parser.add_argument(
        '-target',
        type=str,
        default='Combined',
        help="Target material of the detector (Xe/Ge/Ar)")
    parser.add_argument('-sub_experiments', nargs='*',
                        help="Extra directories to look for data")
    parser.add_argument('--poisson', action='store_true',
                        help="add poisson noise to data")
    parser.add_argument('--shielding', action='store_true',
                        help="add shielding to simulation")
    return parser.parse_args()


def open_json(path):
    """Load and return the content of the JSON file at ``path``.

    The file is read as UTF-8, the standard JSON encoding, rather than with
    the platform's default encoding, so non-ASCII content is decoded the
    same way on every system.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f'{path} does not exist!')
    with open(path, encoding='utf-8') as file:
        return json.load(file)


# Start the fit only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
