#! /usr/bin/env python

from __future__ import absolute_import
from __future__ import print_function
import re
import six


# Patterns used by parse_dir_feature to decode one directory-name component of
# the form <lowercase name><value>.  Group 1 is the feature name, group 2 the
# raw value text.
bool_feat_re = re.compile(r"^([a-z]+)(True|False)$")
int_feat_re = re.compile(r"^([a-z]+)([0-9]+)$")
real_feat_re = re.compile(r"^([a-z]+)([0-9]+\.?[0-9]*)$")
str_feat_re = re.compile(r"^([a-z]+)([A-Z][A-Za-z_0-9]+)$")


# Words that gather_multi_file refuses to use verbatim as column names
# (it appends an underscore to any feature name found in this list).
sqlite_keywords = """
    abort action add after all alter analyze and as asc attach
    autoincrement before begin between by cascade case cast check
    collate column commit conflict constraint create cross current_date
    current_time current_timestamp database default deferrable deferred
    delete desc detach distinct drop each else end escape except
    exclusive exists explain fail for foreign from full glob group
    having if ignore immediate in index indexed initially inner insert
    instead intersect into is isnull join key left like limit match
    natural no not notnull null of offset on or order outer plan pragma
    primary query raise references regexp reindex release rename
    replace restrict right rollback row savepoint select set table temp
    temporary then to transaction trigger union unique update using
    vacuum values view virtual when where""".split()


def parse_dir_feature(feat, number):
    """Decode one directory-name component into a feature triple.

    The patterns are tried in the order boolean, integer, real, string; the
    first one that matches wins.

    :arg feat: the component text, e.g. ``"n12"``, ``"dt0.5"``,
        ``"upwindTrue"`` or ``"meshFine"``.
    :arg number: position of the component; only used to build the fallback
        name ``dirfeat<number>`` when no pattern matches.
    :returns: a tuple ``(name, sql_type, value)`` where *sql_type* is one of
        ``"integer"``, ``"real"`` or ``"text"``.
    """
    bool_match = bool_feat_re.match(feat)
    if bool_match is not None:
        # Booleans are stored as the integers 0 and 1.
        return (bool_match.group(1), "integer", int(bool_match.group(2) == "True"))

    int_match = int_feat_re.match(feat)
    if int_match is not None:
        # Convert with int(), not float(): the value is declared "integer",
        # and a float would silently lose precision for large numbers.
        return (int_match.group(1), "integer", int(int_match.group(2)))

    real_match = real_feat_re.match(feat)
    if real_match is not None:
        return (real_match.group(1), "real", float(real_match.group(2)))

    str_match = str_feat_re.match(feat)
    if str_match is not None:
        return (str_match.group(1), "text", str_match.group(2))

    # Nothing matched: keep the whole component as text under a positional name.
    return ("dirfeat%d" % number, "text", feat)


def larger_sql_type(type_a, type_b):
    """Return the SQL type able to hold values of both given types.

    Each argument is ``None`` (no type known), ``"text"``, ``"real"`` or
    ``"integer"``; anything else trips an assertion.  ``None`` yields to the
    other argument, ``"text"`` beats ``"real"``, and ``"real"`` beats
    ``"integer"``.
    """
    known_types = (None, "text", "real", "integer")
    assert type_a in known_types
    assert type_b in known_types

    if type_a is None:
        return type_b
    if type_b is None:
        return type_a

    both = (type_a, type_b)
    # Walk the types from widest to narrowest; the first one present wins.
    for candidate in ("text", "real"):
        if candidate in both:
            return candidate

    assert type_a == type_b == "integer"
    return "integer"


def sql_type_and_value(value):
    """Map a Python value to a ``(sql_type, storable_value)`` pair.

    ``None`` gives ``(None, None)``.  Booleans and ints give ``"integer"``
    (booleans converted to 0/1), floats give ``"real"``, and everything else
    is stringified and typed ``"text"``.
    """
    if value is None:
        return None, None

    # bool has to be checked ahead of int because bool is a subclass of int.
    if isinstance(value, bool):
        return "integer", int(value)

    if isinstance(value, int):
        return "integer", value

    if isinstance(value, float):
        return "real", value

    return "text", str(value)


def sql_type_and_value_from_str(value):
    """Guess the SQL type of a value given as text and convert it.

    :arg value: the textual value, e.g. as read from a features file.
    :returns: ``(sql_type, converted_value)``:

        * ``"None"`` gives ``(None, None)``;
        * ``"True"``/``"False"`` give ``("integer", 1)``/``("integer", 0)``;
        * text accepted by :func:`int` gives ``"integer"``;
        * otherwise, text accepted by :func:`float` gives ``"real"``;
        * anything else is returned unchanged as ``"text"``.
    """
    if value == "None":
        return None, None

    if value in ("True", "False"):
        # Return a plain int (not a bool) so the result has the same Python
        # type as what sql_type_and_value and parse_dir_feature produce for
        # booleans.
        return "integer", int(value == "True")

    # Try the narrower numeric type first so "3" is an integer, not a real.
    for sql_type, convert in (("integer", int), ("real", float)):
        try:
            return sql_type, convert(value)
        except ValueError:
            continue

    return "text", str(value)


class FeatureGatherer:
    """Collect the features (name, SQL type, value) that describe a run.

    Features may come from a features file keyed by directory name, from the
    directory name itself, and from the constants stored in the log.
    """

    def __init__(self, features_from_dir=False, features_file=None):
        """
        :arg features_from_dir: if true, :meth:`get_db_features` also derives
            features from the dash-separated parts of the directory name.
        :arg features_file: optional path of a text file with lines of the
            form ``dirname: key=value, key=value``.  Blank lines are skipped.
        """
        self.features_from_dir = features_from_dir

        # Maps a directory name to a list of (name, sql_type, value) tuples.
        self.dir_to_features = {}
        if features_file is not None:
            # Use a context manager so the file is closed even if a line
            # turns out to be malformed.
            with open(features_file, "r") as inf:
                for line in inf:
                    if not line.strip():
                        # Tolerate empty lines, e.g. at the end of the file.
                        continue

                    colon_idx = line.find(":")
                    assert colon_idx != -1, \
                            "missing ':' in features line %r" % line

                    entries = [val.strip()
                            for val in line[colon_idx+1:].split(",")]
                    features = []
                    for entry in entries:
                        equal_idx = entry.find("=")
                        assert equal_idx != -1, \
                                "missing '=' in feature entry %r" % entry
                        features.append((entry[:equal_idx],) +
                                sql_type_and_value_from_str(
                                    entry[equal_idx+1:]))

                    self.dir_to_features[line[:colon_idx]] = features

    def get_db_features(self, dbname, logmgr):
        """Return the list of ``(name, sql_type, value)`` features for a log.

        :arg dbname: path of the log database; only its directory part is
            used, both as the key into the features-file data and (when
            enabled) as the source of directory-name features.
        :arg logmgr: object whose ``constants`` mapping is turned into
            additional features.
        """
        from os.path import dirname
        dn = dirname(dbname)

        # Copy so that extending the result never mutates the stored list.
        features = self.dir_to_features.get(dn, [])[:]

        if self.features_from_dir:
            features.extend(parse_dir_feature(feat, i)
                    for i, feat in enumerate(dn.split("-")))

        for name, value in six.iteritems(logmgr.constants):
            features.append((name,) + sql_type_and_value(value))

        return features


def scan(fg, dbnames, progress=True):
    """Scan the given log databases for feature types and run ids.

    :arg fg: object providing ``get_db_features(dbname, logmgr)``, which
        yields ``(name, sql_type, value)`` tuples.
    :arg dbnames: sequence of log database file names.
    :arg progress: if true, show a progress bar while scanning.
    :returns: a tuple ``(features, dbname_to_run_id)``. *features* maps each
        feature name to the widest SQL type seen for it (``"text"`` if the
        first occurrence had no type). *dbname_to_run_id* maps each file name
        to a 1-based run id; files sharing a ``unique_run_id`` constant share
        a run id.
    """
    features = {}
    dbname_to_run_id = {}
    uid_to_run_id = {}
    next_run_id = 1

    pb = None
    if progress:
        # Only needed (and therefore only imported) when a bar is shown.
        from pytools import ProgressBar
        pb = ProgressBar("Scanning...", len(dbnames))

    for dbname in dbnames:
        from logpyle import LogManager
        try:
            logmgr = LogManager(dbname, "r")
        except Exception:
            print("Trouble with file '%s'" % dbname)
            raise

        try:
            unique_run_id = logmgr.constants.get("unique_run_id")
            run_id = uid_to_run_id.get(unique_run_id)

            if run_id is None:
                run_id = next_run_id
                next_run_id += 1

                # Files without a unique_run_id each get their own run id
                # and are never entered into the lookup table.
                if unique_run_id is not None:
                    uid_to_run_id[unique_run_id] = run_id

            dbname_to_run_id[dbname] = run_id

            if pb is not None:
                pb.progress()

            for fname, ftype, fvalue in fg.get_db_features(dbname, logmgr):
                if fname in features:
                    features[fname] = larger_sql_type(ftype, features[fname])
                else:
                    if ftype is None:
                        ftype = "text"
                    features[fname] = ftype
        finally:
            # Close the log even if feature gathering fails.
            logmgr.close()

    if pb is not None:
        pb.finished()

    return features, dbname_to_run_id


def make_name_map(map_str):
    """Parse a string like ``"A=X,B=Y"`` into the dict ``{"A": "X", "B": "Y"}``.

    A false *map_str* (``None`` or empty) yields an empty dict.  Both sides of
    every comma-separated entry must consist of letters, digits and
    underscores only; any other entry trips an assertion.
    """
    import re

    if not map_str:
        return {}

    entry_pattern = re.compile(r"^([a-z_A-Z0-9]+)=([a-z_A-Z0-9]+)$")

    name_map = {}
    for piece in map_str.split(","):
        parsed = entry_pattern.match(piece)
        assert parsed is not None
        source_name, target_name = parsed.groups()
        name_map[source_name] = target_name

    return name_map


def transfer_data_table(db_conn, tbl_name, data_table):
    """Insert every row of *data_table* into the table *tbl_name*.

    *data_table* must expose ``column_names`` (the target columns) and
    ``data`` (an iterable of rows).  The table name and column names are
    pasted into the SQL text; the row values are passed as bound parameters.
    """
    columns = data_table.column_names
    column_list = ", ".join(columns)
    placeholders = ", ".join(["?"] * len(columns))

    statement = "insert into %s (%s) values (%s)" % (
            tbl_name, column_list, placeholders)
    db_conn.executemany(statement, data_table.data)


def gather_single_file(outfile, infiles):
    """Merge several log files belonging to one run into a single log file.

    The output database gets its tables from ``logpyle._set_up_schema``
    (presumably ``warnings``, ``constants`` and ``quantities``, since those
    are the tables written here -- confirm against logpyle).  Warnings and
    per-quantity values of every input are copied; each constant and each
    quantity description is written only the first time it is seen.

    :arg outfile: file name of the SQLite database to write.
    :arg infiles: file names of the log databases to read.
    """
    from pytools import ProgressBar
    pb = ProgressBar("Importing...", len(infiles))

    import sqlite3
    from pickle import dumps
    from logpyle import _set_up_schema, LogManager

    create_quantity_table = """create table %s
                  (step integer, rank integer, value real)"""

    db_conn = sqlite3.connect(outfile)
    try:
        _set_up_schema(db_conn)

        seen_constants = set()
        seen_quantities = set()

        for dbname in infiles:
            pb.progress()

            logmgr = LogManager(dbname, "r")
            try:
                # transfer warnings
                transfer_data_table(db_conn, "warnings", logmgr.get_warnings())

                # transfer constants
                for key, val in six.iteritems(logmgr.constants):
                    if key not in seen_constants:
                        db_conn.execute("insert into constants values (?,?)",
                                (key, memoryview(dumps(val))))
                        seen_constants.add(key)

                # transfer quantities
                for qname, qdata in six.iteritems(logmgr.quantity_data):
                    if qname not in seen_quantities:
                        # Describe and create each quantity once, like the
                        # constants above, rather than once per input file.
                        db_conn.execute(
                                """insert into quantities values (?,?,?,?)""", (
                                    qname, qdata.unit, qdata.description,
                                    memoryview(dumps(qdata.default_aggregator))))
                        db_conn.execute(create_quantity_table % qname)
                        seen_quantities.add(qname)

                    transfer_data_table(db_conn, qname, logmgr.get_table(qname))
            finally:
                # Release the input log even if copying fails.
                logmgr.close()

        pb.finished()

        db_conn.commit()
    finally:
        # On error nothing is committed, but the connection is still closed.
        db_conn.close()


def _normalize_types(x):
    # get rid of numpy types
    if isinstance(x, int):
        return int(x)
    if isinstance(x, float):
        return float(x)
    return x


def gather_multi_file(outfile, infiles, fmap, qmap, fg, features,
        dbname_to_run_id):
    """Gather logs from (possibly) several runs into one summary database.

    The output holds a ``runs`` table (one row per run id, one column per
    feature), a ``quantities`` table describing every quantity, and one table
    per quantity with columns ``run_id, step, rank, value``.

    :arg outfile: file name of the SQLite database to write.
    :arg infiles: file names of the log databases to read.
    :arg fmap: dict renaming features (feature name -> column name).
    :arg qmap: dict renaming quantities (quantity name -> table name).
    :arg fg: object providing ``get_db_features(dbname, logmgr)``.
    :arg features: dict mapping feature name to SQL type, as built by
        :func:`scan`.
    :arg dbname_to_run_id: dict mapping each entry of *infiles* to its run id.
    """
    from pytools import ProgressBar
    pb = ProgressBar("Importing...", len(infiles))

    feature_col_name_map = {}
    for fname in features:
        tgt_name = fmap.get(fname, fname)

        # Column names that collide with an SQLite keyword get a trailing
        # underscore.
        if tgt_name.lower() in sqlite_keywords:
            feature_col_name_map[fname] = tgt_name+"_"
        else:
            feature_col_name_map[fname] = tgt_name

    import sqlite3
    from os.path import dirname, basename

    db_conn = sqlite3.connect(outfile)
    try:
        run_columns = [
                "id integer primary key",
                "dirname text",
                "filename text",
                ] + ["%s %s" % (feature_col_name_map[fname], ftype)
                        for fname, ftype in six.iteritems(features)]
        db_conn.execute("create table runs (%s)" % ",".join(run_columns))
        db_conn.execute("create index runs_id on runs (id)")

        db_conn.execute("""create table quantities (
            id integer primary key,
            name text,
            unit text,
            description text,
            rank_aggregator text
            )""")

        created_tables = set()
        written_run_ids = set()

        for dbname in infiles:
            pb.progress()

            run_id = dbname_to_run_id[dbname]

            from logpyle import LogManager
            logmgr = LogManager(dbname, "r")
            try:
                # Only the first file of each run contributes the row in
                # 'runs'.
                if run_id not in written_run_ids:
                    dbfeatures = fg.get_db_features(dbname, logmgr)
                    qry = "insert into runs (%s) values (%s)" % (
                        ",".join(["id", "dirname", "filename"]
                            + [feature_col_name_map[f[0]] for f in dbfeatures]),
                        ",".join("?" * (len(dbfeatures)+3)))
                    db_conn.execute(qry,
                            [run_id, dirname(dbname), basename(dbname)]
                            + [_normalize_types(f[2]) for f in dbfeatures])

                    written_run_ids.add(run_id)

                for qname, qdat in six.iteritems(logmgr.quantity_data):
                    tgt_qname = qmap.get(qname, qname)

                    if tgt_qname not in created_tables:
                        created_tables.add(tgt_qname)
                        db_conn.execute("create table %s ("
                          "run_id integer, step integer, rank integer, value real)"
                          % tgt_qname)

                        db_conn.execute(
                                "create index %s_main on %s (run_id,step,rank)" % (
                                    tgt_qname, tgt_qname))

                        # Store the aggregator by name if it has one,
                        # otherwise as its string form (None stays None).
                        agg = qdat.default_aggregator
                        try:
                            agg = agg.__name__
                        except AttributeError:
                            if agg is not None:
                                agg = str(agg)

                        db_conn.execute("insert into quantities "
                                "(name,unit,description,rank_aggregator)"
                                "values (?,?,?,?)",
                                (tgt_qname, qdat.unit, qdat.description, agg))

                    # Select the run id as a constant first column so the
                    # rows can be fed straight into the output table.
                    cursor = logmgr.db_conn.execute(
                            "select %s,step,rank,value from %s" % (
                                run_id, qname))
                    db_conn.executemany(
                            "insert into %s values (?,?,?,?)" % tgt_qname,
                            cursor)
            finally:
                # Release the input log even if copying fails.
                logmgr.close()

        pb.finished()

        db_conn.commit()
    finally:
        # On error nothing is committed, but the connection is still closed.
        db_conn.close()


def main():
    """Command-line entry point: ``OUTDB DBFILES ...`` plus options.

    Parses the command line, scans the existing input files for features and
    run ids, and then either prints the features (``--show-features``),
    writes a single-run file (``--single``) or writes a multi-run summary
    file (the default).  Exits with status 1 after printing the help text if
    fewer than two positional arguments are given.

    :raises ValueError: if ``--single`` is requested but the inputs belong to
        more than one run.
    """
    import sys
    from optparse import OptionParser

    parser = OptionParser(usage="%prog OUTDB DBFILES ...")
    parser.add_option("-1", "--single", action="store_true",
            help="Gather single-run instead of multi-run file")
    parser.add_option("-s", "--show-features", action="store_true",
            help="Only print the features found and quit")
    parser.add_option("-d", "--dir-features", action="store_true",
            help="Extract features from directory names")
    parser.add_option("-f", "--file-features", default=None,
            metavar="FILENAME",
            help="Read additional features from file, with lines like: "
            "'dirname: key=value, key=value'")
    parser.add_option("-m", "--feature-map", default=None,
            help="Specify a feature name map.",
            metavar="F1=FNAME1,F2=FNAME2")
    parser.add_option("-q", "--quantity-map", default=None,
            help="Specify a quantity name map.",
            metavar="Q1=QNAME1,Q2=QNAME2")
    options, args = parser.parse_args()

    if len(args) < 2:
        parser.print_help()
        sys.exit(1)

    outfile = args[0]
    from os.path import exists

    # Split the inputs in one pass so each file is checked exactly once.
    infiles = []
    not_found_files = []
    for fn in args[1:]:
        if exists(fn):
            infiles.append(fn)
        else:
            not_found_files.append(fn)

    if not_found_files:
        print("Warning: The following files were not found and are being ignored:")
        for fn in not_found_files:
            print("  ", fn)

    # list of run features as {name: sql_type}
    fg = FeatureGatherer(options.dir_features, options.file_features)
    features, dbname_to_run_id = scan(fg, infiles)

    fmap = make_name_map(options.feature_map)
    qmap = make_name_map(options.quantity_map)

    if options.show_features:
        for feat_name, feat_type in six.iteritems(features):
            print(fmap.get(feat_name, feat_name), feat_type)
        sys.exit(0)

    if options.single:
        if len(set(dbname_to_run_id.values())) > 1:
            raise ValueError(
                    "data seems to come from more than one run--"
                    "can't write single-run file")
        gather_single_file(outfile, infiles)
    else:
        gather_multi_file(outfile, infiles, fmap, qmap, fg, features,
                dbname_to_run_id)


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
