#!/usr/bin/env python3
import os
import sys
import time
import json
import socket
from threading import Thread, Lock

# Release string shown in the startup banner.
version = "2.0.2"
# TCP port the server binds to on 127.0.0.1.
port = 8044
# PID of this process; included in every log line and in the banner.
pid = os.getpid()


def read(file):
    """Load the JSON document stored at *file*.

    Args:
        file: Path of the JSON file to load.

    Returns:
        The decoded JSON content. An empty dict is returned when the file
        does not exist or contains nothing but whitespace.

    Raises:
        json.JSONDecodeError: If the file has content that is not valid JSON.
    """
    # Open directly rather than testing os.path.exists() first: the file can
    # disappear between the check and the open.
    try:
        with open(file, "r") as f:
            content = f.read()
    except FileNotFoundError:
        return dict()
    # A zero-length file (e.g. a write that was interrupted right after the
    # truncate) carries no data; treat it as an empty store instead of
    # failing to start.
    if not content.strip():
        return dict()
    return json.loads(content)


# In-memory key/value table, seeded from keys.db at import time (before the
# command line is inspected, so this happens in both modes).
data = read("keys.db")
# Guards every access to `data` made by parse() and db().
mu = Lock()


def store(file, keys):
    """Write *keys* to *file* as pretty-printed JSON.

    Args:
        file: Destination path; any existing content is overwritten.
        keys: JSON-serializable mapping to persist.
    """
    # Serialization happens before the file is opened, so a value that
    # cannot be encoded raises without touching the existing file.
    payload = json.dumps(keys, sort_keys=True, indent=4)
    with open(file, "w") as handle:
        handle.write(payload)


def log(s):
    """Print *s* to stdout, prefixed with a timestamp and this process's PID."""
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    line = "[{}] [{}] {}".format(stamp, pid, s)
    print(line)


def db():
    """Persistence loop: keep keys.db in sync with the in-memory table.

    Once per second the on-disk content is re-read and compared with `data`;
    when they differ, `data` is written out. Never returns, so it is meant
    to run in its own thread.
    """
    while True:
        time.sleep(1)
        # Disk read happens outside the lock so clients are not blocked by it.
        keys = read("keys.db")
        # `with` guarantees the lock is released even if store() raises;
        # otherwise one failed write would deadlock every client thread.
        with mu:
            if keys != data:
                store("keys.db", data)
                log("* DB saved on disk")


def parse(req):
    """Execute one request line against the in-memory table.

    The request has the form ``<command> [key] [value]`` with single-space
    separators. Supported commands:

    * ``get <key>``          - return the stored value.
    * ``set <key> <value>``  - store *value* (everything after the key).
    * ``del <key>``          - remove the key; succeeds even if it is absent.
    * ``keys [prefix]``      - list all keys, or those starting with *prefix*.

    Args:
        req: The request text, already decoded and stripped by the caller.

    Returns:
        The response for the client: the value for ``get``, ``"OK"`` for
        ``set``/``del``, the ``str()`` of a list for ``keys``, or a string
        starting with ``"(error):"`` when the request cannot be served.
    """
    # maxsplit=2 keeps everything after the key as the value, verbatim.
    # Searching the request for the third token's text instead would match
    # its first occurrence, which may lie inside the command or the key
    # (e.g. "set a a").
    parts = req.split(" ", 2)
    parts += [None] * (3 - len(parts))
    command, key, value = parts

    if command == "get":
        if not key:
            return "(error): no key for get command"
        with mu:
            return data.get(key, "(error): key not found")

    if command == "set":
        if not key:
            return "(error): no key for set command"
        if not value:
            return "(error): no value for set command"
        with mu:
            data[key] = value
        return "OK"

    if command == "del":
        if not key:
            return "(error): no key for del command"
        with mu:
            data.pop(key, None)
        return "OK"

    if command == "keys":
        # An absent prefix matches every key.
        prefix = key or ""
        with mu:
            return str([k for k in data if k.startswith(prefix)])

    return "(error): invalid command"


def recvall(conn):
    """Read from *conn* until a chunk shorter than 1024 bytes arrives.

    Args:
        conn: Object with a socket-style ``recv(bufsize)`` method.

    Returns:
        All bytes received, concatenated. Empty when the first ``recv``
        returns nothing.

    NOTE(review): a short chunk is used as the end-of-message signal, so a
    message whose length is an exact multiple of 1024 keeps this waiting for
    another chunk.
    """
    chunks = []
    while True:
        chunk = conn.recv(1024)
        chunks.append(chunk)
        if len(chunk) < 1024:
            return b"".join(chunks)


def handle_connection(conn, addr):
    """Serve one client: read requests and answer each until it disconnects.

    Every request is decoded as UTF-8, stripped, passed to parse(), and the
    result is sent back followed by a newline.

    Args:
        conn: Connected socket for this client. It is closed when this
            function returns, whatever the reason.
        addr: Peer address as returned by ``accept()``; currently unused.
    """
    # The context manager closes the socket on every exit path, including
    # unexpected exceptions.
    with conn:
        while True:
            try:
                raw = recvall(conn)
            except OSError:
                # Connection reset or otherwise unusable.
                break
            if not raw:
                # Peer closed the connection.
                break
            try:
                req = raw.decode("utf-8").strip()
            except UnicodeDecodeError:
                # Answer instead of letting the thread die on bad input.
                res = "(error): request is not valid UTF-8"
            else:
                res = parse(req)
            try:
                # Values loaded from keys.db may be any JSON type, so format
                # rather than concatenate.
                conn.sendall("{}\n".format(res).encode("utf-8"))
            except OSError:
                # Peer went away before the reply could be delivered.
                break


def listen():
    """Accept clients on 127.0.0.1:<port>, one daemon thread per connection.

    Runs until interrupted. Ctrl-C ends the process with exit status 0; any
    other failure in the accept loop is logged and ends it with status 1.
    The listening socket is closed on the way out.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Allow an immediate restart while old connections are in TIME_WAIT.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(("127.0.0.1", port))
        sock.listen()
        log("Listening at 127.0.0.1:%s" % port)
        while True:
            try:
                conn, addr = sock.accept()
                Thread(
                    target=handle_connection, args=(conn, addr), daemon=True
                ).start()
            except KeyboardInterrupt:
                sys.exit(0)
            except Exception as err:
                # Top-level boundary: report the failure instead of exiting
                # silently with a success status.
                log("(error): accept loop failed: %s" % err)
                sys.exit(1)


# Main
# "-p" as the first argument starts the background thread that writes the
# table to keys.db; without it, changes are never saved to disk.
if len(sys.argv) > 1 and sys.argv[1] == "-p":
    mode = "persistence"
    persist = True
else:
    mode = "ephemeral"
    persist = False

# Raw string: the banner art is full of backslashes that are not escape
# sequences, which a normal string literal reports as a SyntaxWarning on
# recent Python versions. The resulting text is unchanged.
t = r"""
     ___
    /\  \ 
   /::\  \       Erebor %s
  /:/\:\  \ 
 /:/  \:\  \ 
/:/__/ \:\__\    Running in %s mode
\:\  \ /:/  /    Port: %-10s 
 \:\  /:/  /     PID:  %-10s 
  \:\/:/  / 
   \::/  /             https://pypi.org/project/erebor 
    \/__/ 
        """ % (
    version,
    mode,
    port,
    pid,
)
print(t)
if persist:
    Thread(target=db, args=(), daemon=True).start()
# Blocks for the lifetime of the server.
listen()
