#! python
import os
import typer
from typing import Optional
from together_cli.src.core.instances import pprint_instances
from together_cli.src.model import serve_model
from together_cli.src.system import check_binary_exists, check_folders, check_lockable_drive

# Typer application; the command functions below register themselves on it
# via the @app.command() decorator.
app = typer.Typer()
# The invoking user's home directory, resolved once at import time.
# NOTE: the ``serve`` command has a ``home_dir`` parameter that shadows this
# module-level name inside that function.
home_dir = os.path.expanduser("~")
# Default Together home directory (``<user home>/together``); used as the
# default of ``serve``'s --home-dir option.
default_together_home = os.path.join(home_dir, "together")


@app.command()
def check():
    """Report which cluster/container tools are available on this machine.

    For each tool, prints its label followed by whatever
    ``check_binary_exists`` returns for the corresponding binary. Slurm is
    probed through its ``sinfo`` executable.
    """
    probes = (
        ("Slurm", "sinfo"),
        ("Singularity", "singularity"),
        ("Docker", "docker"),
    )
    for label, binary in probes:
        print(label + ": ", check_binary_exists(binary))


@app.command()
def serve(
        # required arguments
        model: str = typer.Option(...,
                                  prompt="What's the model name you want to serve?"),
        home_dir: str = typer.Option(
            default_together_home, help="The home directory for Together? It cannot be on an NFS drive."),
        data_dir: str = typer.Option(
            ..., prompt="The directory you want to store model weights? It could be on an NFS drive."),

        # optional, but suggested arguments
        gpus: Optional[str] = typer.Option(
            None, help="GPU Specifiers (e.g., titanrtx:1), required if you are not using baremetal nodes"),
        queue: Optional[str] = typer.Option(None, help="Queue name - default is None"),
        singularity: bool = typer.Option(
            False, help="Use singularity to serve the model"),
        docker: Optional[bool] = typer.Option(
            False, help="Use docker to serve the model"),
        tags: Optional[str] = typer.Option("", help="tags"),
        account: Optional[str] = typer.Option(
            None, help="Account name - default is None"),
        modules: Optional[str] = typer.Option(
            None, help="Modules to load - default is None"),
        duration: str = typer.Option(
            "1:00:00", help="Duration of the job - default is '1:00:00'"),
        matchmaker_addr: str = typer.Option(
            "wss://api.together.xyz/websocket", help="Global Matchmaker address - leave it as it is in most cases"),
        port: int = typer.Option(
            8092, help="Port number - default is 8092-8093. In case of conflict, change it to a different number, increase by 2"),
        node_list: Optional[str] = typer.Option(
            None, help="Node list - default is None"),
        cluster: str = typer.Option(
            "baremetal", help="Cluster Management System - default is 'baremetal'"),
        dry_run: bool = typer.Option(
            False, help="Only Generate submission scripts for review - default is False"),
        owner: Optional[str] = typer.Option(
            None, help="Owner of the instance - default is None"),
    ):
    """Validate serving options and launch a model via ``serve_model``.

    Validation performed before anything is launched:

    * exactly one of ``--docker`` / ``--singularity`` must be selected;
    * ``--gpus`` is required whenever ``--cluster`` is not ``baremetal``;
    * ``~`` in ``home_dir`` / ``data_dir`` is expanded;
    * both directories are passed to ``check_folders``;
    * ``home_dir`` must be reported lockable by ``check_lockable_drive``.

    Every validation failure prints an ``[ERROR]`` line and returns without
    calling ``serve_model``.
    """
    # NOTE(review): errors are reported by printing and returning, so the
    # process still exits with status 0; consider ``raise typer.Exit(code=1)``.

    # Exactly one container runtime must be chosen.
    if docker and singularity:
        print("[ERROR] You can only choose one of docker or singularity")
        return
    if not (docker or singularity):
        # Abort here instead of falling through to serve_model with no
        # container runtime selected.
        print("[ERROR] You must choose one of docker or singularity")
        return
    if docker:
        print("[INFO] Containerization: Docker")
    else:
        print("[INFO] Containerization: Singularity")

    if cluster != 'baremetal' and gpus is None:
        print("[ERROR] You must specify gpus if you are not using baremetal nodes")
        return

    # expanduser leaves paths that do not start with "~" unchanged.
    home_dir = os.path.expanduser(home_dir)
    data_dir = os.path.expanduser(data_dir)

    check_folders(home_dir=home_dir, data_dir=data_dir)
    if not check_lockable_drive(home_dir):
        print(
            "[ERROR] Your home directory is not lockable. Please choose another directory.")
        return

    serve_model(
        model_name=model,
        queue_name=queue,
        home_dir=home_dir,
        data_dir=data_dir,
        matchmaker_addr=matchmaker_addr,
        tags=tags,
        use_docker=docker,
        use_singularity=singularity,
        gpus=gpus,
        account=account,
        modules=modules,
        node_list=node_list,
        port=port,
        duration=duration,
        cluster=cluster,
        dry_run=dry_run,
        owner=owner,
    )


@app.command(name="list")
def list():
    """Show instances by delegating to ``pprint_instances``."""
    # NOTE: this function name shadows the ``list`` builtin at module scope.
    # It is kept as-is because it is part of this module's public names.
    pprint_instances()


@app.command()
def main():
    """Print the string ``TOMA``."""
    banner = "TOMA"
    print(banner)


@app.command()
def logs(
    instance_id: str = typer.Option(..., prompt="What's the ID of the instance you want to inspect?")):
    """Inspect a served instance (not implemented yet).

    Args:
        instance_id: ID of the instance to inspect; prompted for if not
            given on the command line.
    """
    # TODO: implement instance inspection. Until then, say so explicitly
    # rather than silently exiting after prompting the user for an ID.
    print(f"[ERROR] The logs command is not implemented yet; cannot inspect instance {instance_id}")


if __name__ == "__main__":
    # Dispatch to the Typer app only when run as a script, not on import.
    app()
