# -*- coding: utf-8 -*-

import uuid

import numpy as np

from sail_utils.cv.head.deep_sort import iou_matching
from sail_utils.cv.head.deep_sort import kalman_filter
from sail_utils.cv.head.deep_sort import linear_assignment
from sail_utils.cv.head.deep_sort.config import (
    MAX_AGE,
    MAX_IOU_DISTANCE,
    N_INIT
)
from sail_utils.cv.head.deep_sort.detection import Detection
from sail_utils.cv.head.deep_sort.track import Track


def _convert_detection_struct(from_struct: list) -> list:
    """
    convert detection format to deep sort compatible format
    :param from_struct:
    :return:
    """
    to_struct = []
    for each in from_struct:
        xmin = each['location'][0]
        ymin = each['location'][1]
        xmax = each['location'][2]
        ymax = each['location'][3]
        confidence = each['score']
        feature = each['feature'] if 'feature' in each else np.zeros(5, dtype=np.float64) + 1.0
        to_struct.append(
            Detection(
                np.array([xmin, ymin, xmax - xmin, ymax - ymin],
                         dtype=np.float64),
                confidence,
                feature=feature))
    return to_struct


class Tracker:
    """
    Multi-target tracker that keeps a list of tracks in sync with detections.

    Parameters
    ----------
    metric : nn_matching.NearestNeighborDistanceMetric
        Distance metric for measurement-to-track association. This class
        relies on its `distance` and `partial_fit` methods and on its
        `matching_threshold` attribute.
    max_iou_distance
        Threshold passed to `linear_assignment.min_cost_matching` for the
        IOU association stage. Defaults to the configured
        `MAX_IOU_DISTANCE`.
    max_age : int
        Maximum number of misses before a track is deleted. It is handed to
        every new `Track`, used as the depth of the matching cascade and as
        the `time_since_update` limit for tracks entering the IOU stage.
    n_init : int
        Number of consecutive detections before a track is confirmed. It is
        handed to every new `Track`.

    Attributes
    ----------
    metric : nn_matching.NearestNeighborDistanceMetric
        The distance metric used for measurement to track association.
    max_iou_distance
        The threshold used by the IOU association stage.
    max_age : int
        Maximum number of misses before a track is deleted.
    n_init : int
        Number of frames that a track remains in initialization phase.
    kf : kalman_filter.KalmanFilter
        The Kalman filter shared by all tracks.
    tracks : List[Track]
        The tracks that are alive at the current time step.

    """

    def __init__(self,
                 metric,
                 max_iou_distance=MAX_IOU_DISTANCE,
                 max_age=MAX_AGE,
                 n_init=N_INIT):
        # Association settings supplied by the caller.
        self.metric = metric
        self.max_iou_distance = max_iou_distance
        self.max_age = max_age
        self.n_init = n_init

        # Tracking state owned by this instance.
        self.kf = kalman_filter.KalmanFilter()
        self.tracks = []
        # Identifier reserved for the next track that gets created.
        self._next_id = uuid.uuid1()

    def predict(self):
        """Advance every track by one time step using the shared filter.

        Call this once per time step, before `update`.
        """
        kalman = self.kf
        for trk in self.tracks:
            trk.predict(kalman)

    def result(self, time_stamp):
        """
        Report the confirmed tracks as plain dictionaries.

        Parameters
        ----------
        time_stamp
            Value stored unchanged under the ``'time_stamp'`` key of every
            returned entry.

        Returns
        -------
        List[dict]
            One entry per confirmed track with the keys ``'id'``,
            ``'location'`` (``[x0, y0, x1, y1]`` truncated to ints) and
            ``'time_stamp'``, ordered by ``(time_stamp, id)``.

        """
        rows = []
        for trk in self.tracks:
            if not trk.is_confirmed():
                continue
            tlwh = trk.to_tlwh()
            # Far corners are computed before truncation to int.
            left = int(tlwh[0])
            right = int(tlwh[0] + tlwh[2])
            top = int(tlwh[1])
            bottom = int(tlwh[1] + tlwh[3])
            rows.append({
                'id': trk.track_id,
                'location': [left, top, right, bottom],
                'time_stamp': time_stamp,
            })
        rows.sort(key=lambda row: (row['time_stamp'], row['id']))
        return rows

    def update(self, detections):
        """
        Run the measurement update and maintain the track list.

        Parameters
        ----------
        detections : List[dict]
            Detections of the current time step, in the format accepted by
            `_convert_detection_struct`.

        """
        dets = _convert_detection_struct(detections)
        matched_pairs, missed_tracks, new_detections = self._match(dets)

        # Apply the association outcome to the track list.
        for trk_pos, det_pos in matched_pairs:
            self.tracks[trk_pos].update(self.kf, dets[det_pos])
        for trk_pos in missed_tracks:
            self.tracks[trk_pos].mark_missed()
        for det_pos in new_detections:
            self._initiate_track(dets[det_pos])
        self.tracks = [trk for trk in self.tracks if not trk.is_deleted()]

        # Feed the features gathered by confirmed tracks to the metric and
        # empty the per-track buffers afterwards.
        confirmed = [trk for trk in self.tracks if trk.is_confirmed()]
        active_targets = [trk.track_id for trk in confirmed]
        features = []
        targets = []
        for trk in confirmed:
            features.extend(trk.features)
            # One target id per feature keeps both lists aligned.
            targets.extend([trk.track_id] * len(trk.features))
            trk.features = []
        self.metric.partial_fit(
            np.asarray(features), np.asarray(targets), active_targets)

    def _match(self, detections):
        """
        Associate the current tracks with the given detections.

        Confirmed tracks go through the appearance-based matching cascade
        first; unconfirmed tracks and eligible leftovers are then matched
        by IOU against the detections that are still unassigned.

        Parameters
        ----------
        detections : List[Detection]
            Converted detections of the current time step.

        Returns
        -------
        (matches, unmatched_tracks, unmatched_detections)
            The combined matches of both stages, the de-duplicated indices
            of unmatched tracks, and the unmatched detection indices
            returned by the IOU stage.

        """
        def appearance_cost(tracks, dets, track_indices, detection_indices):
            # Appearance distance between the selected detections and
            # tracks, gated through `linear_assignment.gate_cost_matrix`.
            det_features = np.array(
                [dets[i].feature for i in detection_indices])
            track_ids = np.array(
                [tracks[i].track_id for i in track_indices])
            raw_cost = self.metric.distance(det_features, track_ids)
            return linear_assignment.gate_cost_matrix(
                self.kf, raw_cost, tracks, dets, track_indices,
                detection_indices)

        # Partition the track indices by confirmation state.
        confirmed_idx = []
        tentative_idx = []
        for idx, trk in enumerate(self.tracks):
            if trk.is_confirmed():
                confirmed_idx.append(idx)
            else:
                tentative_idx.append(idx)

        # Stage one: appearance features, confirmed tracks only.
        cascade_result = linear_assignment.matching_cascade(
            appearance_cost, self.metric.matching_threshold, self.max_age,
            self.tracks, detections, confirmed_idx)
        matches_a, leftover_tracks, unmatched_detections = cascade_result

        # Leftovers updated within `max_age` steps get a second chance in
        # the IOU stage; the others stay unmatched.
        retry_idx = []
        stale_idx = []
        for idx in leftover_tracks:
            if self.tracks[idx].time_since_update <= self.max_age:
                retry_idx.append(idx)
            else:
                stale_idx.append(idx)

        # Stage two: IOU, unconfirmed tracks plus the retried leftovers.
        iou_result = linear_assignment.min_cost_matching(
            iou_matching.iou_cost, self.max_iou_distance, self.tracks,
            detections, tentative_idx + retry_idx, unmatched_detections)
        matches_b, unmatched_tracks_b, unmatched_detections = iou_result

        unmatched_tracks = list(set(stale_idx + unmatched_tracks_b))
        return matches_a + matches_b, unmatched_tracks, unmatched_detections

    def _initiate_track(self, detection):
        """
        Start a new track from a detection and append it to `tracks`.

        The track receives the currently reserved identifier; a fresh one
        is generated afterwards for the next track.

        Parameters
        ----------
        detection : Detection
            The detection that seeds the new track.

        """
        mean, covariance = self.kf.initiate(detection.to_xyah())
        new_track = Track(
            mean, covariance, self._next_id, self.n_init, self.max_age,
            detection.feature)
        self.tracks.append(new_track)
        self._next_id = uuid.uuid1()
