#!/usr/bin/env python3
import argparse
import json
import os
import subprocess
import sys

from os.path import expanduser, isfile, join

import boto3

from botocore.exceptions import ClientError


# Optional per-user JSON configuration file, located in the home directory.
conf_file_name = '.sshaws.conf'
home = expanduser("~")
conf_file_path = join(home, conf_file_name)

# Parsed configuration; stays an empty dict when the file does not exist.
conf = {}
if isfile(conf_file_path):
    with open(conf_file_path, 'r') as conf_file:
        conf = json.load(conf_file)

# Region reported by the boto3 session, or us-east-1 when none is set.
default_region = boto3.session.Session().region_name or 'us-east-1'
default_regions = conf.get('regions', [default_region])

# Key pair and login user used when neither config nor command line say otherwise.
default_key_file_path_private = conf.get('key_file_path_private', f'{home}/.ssh/id_rsa')
default_key_file_path_public = conf.get('key_file_path_public', f'{home}/.ssh/id_rsa.pub')
default_os_user = conf.get('os_user', 'ec2-user')

# 'use_private_id' is a misspelled key that versions <= 0.3.2 used; it is
# still honoured as a fallback when 'use_private_ip' is absent or falsy.
default_use_private_ip = (
    conf.get('use_private_ip', False)
    or conf.get('use_private_id', False)
)

default_use_ssm = conf.get('use_ssm', False)
default_forward_agent = conf.get('forward_agent', False)

# Mapping of user-chosen alias -> instance id.
default_aliases_dict = conf.get('aliases', {})

# Command line interface. Each option's default is the matching module-level
# default_* value, so the command line overrides whatever was preconfigured.
parser = argparse.ArgumentParser(
    description=(
        'Connect to any ec2 instance your IAM user has rights for and that are reachable. '
        f'You can use {conf_file_path} to preconfigure the options. '
        'See https://github.com/FrederikP/sshaws/blob/master/README.md'
    ),
)
parser.add_argument(
    'instance_id',
    type=str,
    help='Id of the instance to connect to. Example: i-0bd8e1251a4713c17',
)
parser.add_argument(
    '--os-user',
    type=str,
    default=default_os_user,
    help=f'OS username to use for connection. Default: {default_os_user}',
)
# FileType('r') makes argparse open the key files (including the defaults)
# while parsing, so the parsed values are open file objects, not paths.
parser.add_argument(
    '--public-key-file',
    type=argparse.FileType('r'),
    default=default_key_file_path_public,
    help=f'Public key file to use for connection. Default: {default_key_file_path_public}',
)
parser.add_argument(
    '--private-key-file',
    type=argparse.FileType('r'),
    default=default_key_file_path_private,
    help=f'Private key file to use for connection. Default: {default_key_file_path_private}',
)
parser.add_argument(
    '--regions',
    type=str,
    nargs='+',
    default=default_regions,
    help=f'Look for the instance in the given regions. Default: {default_regions}',
)
parser.add_argument(
    '--use-private-ip',
    action='store_true',
    default=default_use_private_ip,
    help=(
        'Use private IP even if public IP is available. '
        'sshaws will use private ip automatically, if no public IP is available.'
    ),
)
parser.add_argument(
    '--use-ssm',
    action='store_true',
    default=default_use_ssm,
    help='Use AWS System Manager Session Manager to proxy the connection',
)
parser.add_argument(
    '-a', '--forward-agent',
    action='store_true',
    default=default_forward_agent,
    help='Enables ssh agent forwarding',
)
args = parser.parse_args()

def connect(region):
    """Find the requested instance in ``region`` and push the public key to it.

    The instance id given on the command line is first translated through the
    configured aliases. When the instance exists in ``region``, the contents
    of the public key file are sent to it with EC2 Instance Connect
    (``send_ssh_public_key``) for the selected OS user.

    Args:
        region: Name of the AWS region to search.

    Returns:
        The IP address to connect to: the private address when
        ``--use-private-ip`` is set or the instance has no public address,
        the public address otherwise. ``None`` when the instance id is not
        found in ``region``.

    Raises:
        botocore.exceptions.ClientError: For any ``describe_instances`` error
            other than ``InvalidInstanceID.NotFound``.
    """
    target_id = default_aliases_dict.get(args.instance_id, args.instance_id)
    ec2 = boto3.client('ec2', region_name=region)

    print(f'Looking for instance {target_id} in region {region}')
    try:
        described = ec2.describe_instances(InstanceIds=[target_id])
    except ClientError as err:
        # Only "not found" is an expected outcome; anything else propagates.
        if err.response['Error']['Code'] == 'InvalidInstanceID.NotFound':
            print(f'Did not find instance {target_id} in region {region}')
            return None
        raise

    details = described['Reservations'][0]['Instances'][0]
    availability_zone = details['Placement']['AvailabilityZone']

    prefer_private = args.use_private_ip or 'PublicIpAddress' not in details
    address = details['PrivateIpAddress' if prefer_private else 'PublicIpAddress']

    instance_connect = boto3.client('ec2-instance-connect', region_name=region)
    instance_connect.send_ssh_public_key(
        InstanceId=target_id,
        InstanceOSUser=args.os_user,
        SSHPublicKey=args.public_key_file.read(),
        AvailabilityZone=availability_zone,
    )
    return address

# Try each region in order and stop at the first one where connect() finds
# the instance (it returns the IP to use, or None when the id is unknown there).
ip_to_connect_to = None
for region in args.regions:
    ip_to_connect_to = connect(region)
    if ip_to_connect_to:
        break
if not ip_to_connect_to:
    print(f'Error: Did not find instance id {args.instance_id} in any region: {args.regions}')
    sys.exit(1)

# Translate a configured alias into the real instance id. With --use-ssm the
# ssh host name is substituted for %h in the ProxyCommand and handed to
# `aws ssm start-session --target`, so it must be the instance id itself and
# not the alias the user typed.
resolved_instance_id = default_aliases_dict.get(args.instance_id, args.instance_id)

command_list = ['ssh']
if args.forward_agent:
    command_list.append('-A')
if args.use_ssm:
    # Tunnel the ssh connection through an SSM session instead of connecting to the IP.
    command_list.extend(['-o', 'ProxyCommand=sh -c "aws ssm start-session --target %h --document-name AWS-StartSSHSession --parameters \'portNumber=%p\'"'])
ssh_host = resolved_instance_id if args.use_ssm else ip_to_connect_to
# argparse opened the private key file; only its path is needed for `ssh -i`.
command_list.extend(['-i', args.private_key_file.name, f'{args.os_user}@{ssh_host}'])
subprocess.run(command_list)
