from __future__ import annotations

import atexit
import json
import signal
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple

import requests

from .types import Config, SdkMetric

# Apdex threshold T (ms) used when the config does not provide apdexTMs.
APDEX_T_MS_DEFAULT = 500
# Width of one aggregation time bucket, in ms (one minute).
MIN_BUCKET_MS = 60 * 1000

# Lower edges (ms) of the latency histogram bins, in ascending order.
# The last bin is open-ended: every duration >= 300 s lands there.
BINS_MS: Tuple[int, ...] = (
    0, 50, 100, 200, 300, 400, 500, 750,
    1_000, 1_500, 2_000, 3_000, 5_000, 8_000, 12_000,
    20_000, 30_000, 45_000, 60_000, 120_000, 300_000,
)


@dataclass
class AggBucket:
    """Running aggregate for one (minute bucket, method, route, service, env, release) group.

    Instances are updated in place by ``_agg_from_metric`` and converted to the
    wire format by ``_serialize_agg``.
    """

    # Start of the minute window, in epoch milliseconds.
    ts_bucket_ms: int
    service: Optional[str]
    route: str
    method: str
    env: Optional[str]
    release: Optional[str]
    # Number of metrics folded into this bucket.
    req_count: int = 0
    # Number of those metrics whose status was >= 500.
    err_count: int = 0
    # Sum of the (clamped, integer-truncated) durations, in ms.
    sum_dur_ms: float = 0.0
    # One counter per entry of BINS_MS (fresh list per instance).
    hist_counts: List[int] = field(default_factory=lambda: [0 for _ in BINS_MS])
    # Apdex counters: satisfied, tolerating, and total samples considered.
    sat_count: int = 0
    tol_count: int = 0
    tot_count: int = 0


def _bucket_minute(ts_ms: int) -> int:
    """Floor an epoch-millisecond timestamp to the start of its minute bucket.

    A missing (``None``) or non-positive timestamp is replaced by the current
    wall-clock time. The result is always an ``int``, even for float input, so
    bucket keys and the serialized ``tsBucketMs`` stay integral.
    """
    import time

    if ts_ms is None or ts_ms <= 0:
        ts_ms = time.time() * 1000
    # Positive values only here, so int() truncation equals flooring; this
    # keeps a float timestamp from producing a float bucket like 1.7e12.
    ts_ms = int(ts_ms)
    return (ts_ms // MIN_BUCKET_MS) * MIN_BUCKET_MS


def _bin_index(dur_ms: int) -> int:
    if dur_ms < 0:
        dur_ms = 0
    for i in range(len(BINS_MS) - 1):
        if dur_ms < BINS_MS[i + 1]:
            return i
    return len(BINS_MS) - 1


def _resolve_batch_params(cfg: Config) -> Tuple[int, int]:
    env_str = (cfg.env or "PROD").upper()

    if cfg.batchSize and cfg.batchSize > 0:
        batch_size = cfg.batchSize
    else:
        if env_str == "DEV":
            batch_size = 100
        elif env_str == "HMG":
            batch_size = 300
        else:
            batch_size = 1000

    if cfg.flushIntervalMs and cfg.flushIntervalMs > 0:
        flush_interval_ms = cfg.flushIntervalMs
    else:
            flush_interval_ms = 120_000

    return batch_size, flush_interval_ms


def _parse_int_safe(v: Any) -> Optional[int]:
    try:
        n = int(v)
        return n if n >= 0 else None
    except Exception:
        return None


def _bucket_key(
    m: SdkMetric,
) -> Tuple[int, str, str, Optional[str], Optional[str], Optional[str]]:
    """Return the grouping key ``(ts_bucket, method, route, service, env, release)``.

    The timestamp is floored to its minute bucket, the method is upper-cased
    (missing means "GET"), a missing/empty route becomes "/", and a string env
    is upper-cased. Service and release are used unchanged.
    """
    return (
        _bucket_minute(m.ts),
        (m.method or "GET").upper(),
        m.route or "/",
        m.service,
        # Only string envs are normalised; other values pass through untouched.
        m.env.upper() if isinstance(m.env, str) else m.env,
        m.release,
    )


def _agg_from_metric(
    m: SdkMetric,
    apdex_t_ms: int,
    existing: Optional[AggBucket] = None,
) -> AggBucket:
    """Fold a single metric into an aggregate bucket and return that bucket.

    Args:
        m: The raw metric. ``status`` and ``dur_ms`` are always read. The key
            fields are read only when a new bucket must be created.
        apdex_t_ms: Apdex threshold T in ms. Apdex counters are only updated
            when this is > 0.
        existing: Bucket to update in place, or None to create a new one from
            the metric's bucket key.

    Returns:
        ``existing`` (mutated) or the newly created bucket.
    """
    if existing is not None:
        bucket = existing
    else:
        ts_bucket, method, route, service, env, release = _bucket_key(m)
        bucket = AggBucket(
            ts_bucket_ms=ts_bucket,
            service=service,
            route=route,
            method=method,
            env=env,
            release=release,
        )

    bucket.req_count += 1
    # Any status >= 500 counts as an error.
    if m.status >= 500:
        bucket.err_count += 1

    # A missing duration counts as 0; the value is truncated to int and
    # negatives are clamped to 0 before being accumulated.
    duration = max(int(m.dur_ms or 0), 0)
    bucket.sum_dur_ms += duration
    bucket.hist_counts[_bin_index(duration)] += 1

    if not (apdex_t_ms > 0):
        return bucket

    # Apdex: satisfied if <= T, tolerating if <= 4T; anything slower is only
    # reflected in tot_count.
    if duration <= apdex_t_ms:
        bucket.sat_count += 1
    elif duration <= 4 * apdex_t_ms:
        bucket.tol_count += 1
    bucket.tot_count += 1
    return bucket


def _serialize_agg(b: AggBucket) -> Dict[str, Any]:
    return {
        "tsBucketMs": b.ts_bucket_ms,
        "service": b.service,
        "route": b.route,
        "method": b.method,
        "reqCount": b.req_count,
        "errCount": b.err_count,
        "sumDurMs": int(b.sum_dur_ms),
        "histCounts": list(b.hist_counts),
        "satCount": b.sat_count,
        "tolCount": b.tol_count,
        "totCount": b.tot_count,
        "env": b.env,
        "release": b.release,
    }


class _BatchManager:
    """Buffer raw SDK metrics and ship them to ``cfg.remoteUrl`` as aggregates.

    ``add`` appends metrics to an in-memory buffer. A flush aggregates the
    buffered metrics (``_build_agg_payload``) and POSTs them as JSON.

    Flushes are triggered by:
    - the buffer reaching the resolved batch size;
    - a repeating timer every flush interval;
    - interpreter exit and SIGINT/SIGTERM;
    - ``stop``.

    At most one flush runs at a time. A batch that fails to send is put back
    at the front of the buffer for the next flush.
    """

    def __init__(self, cfg: Config):
        self.cfg = cfg
        self._batch_size, self._flush_interval_ms = _resolve_batch_params(cfg)
        self._buf: List[SdkMetric] = []
        # Guards _buf, _flushing, _stop and _timer. Re-entrant because the
        # SIGINT/SIGTERM handler runs flush() on the main thread, possibly
        # while that same thread is already inside `with self._lock` (e.g. in
        # add()); a plain Lock would deadlock the process there.
        self._lock = threading.RLock()
        # True while a flush is in flight; concurrent flush() calls return early.
        self._flushing = False
        # Set by stop(); prevents the periodic timer from re-arming.
        self._stop = False
        # NOTE(review): not read by any code visible in this module section;
        # confirm it is consumed elsewhere.
        self._protected: Set[str] = set(cfg.protectedRoutes or [])
        self._timer: Optional[threading.Timer] = None
        self._request_timeout_ms = getattr(cfg, "requestTimeoutMs", None) or 30_000
        self._start_timer()
        self._install_exit_hooks()

    def _log(self, *args: Any):
        """Print a debug line prefixed with "[observify]" when ``cfg.debug`` is truthy.

        This is best effort: a failing stdout (e.g. closed or a broken pipe
        during shutdown) never propagates into the flush path.
        """
        if not self.cfg.debug:
            return
        try:
            print("[observify]", *args)
        except (OSError, ValueError):
            pass

    def _start_timer(self):
        """Flush once now, then re-arm a daemon timer every flush interval until stop()."""

        def _tick():
            try:
                self.flush()
            finally:
                # Check _stop and publish the new timer under the lock. A
                # concurrent stop() then either sees this timer (and cancels
                # it), or this tick sees _stop and does not re-arm.
                with self._lock:
                    if not self._stop:
                        self._timer = threading.Timer(
                            self._flush_interval_ms / 1000.0, _tick
                        )
                        self._timer.daemon = True
                        self._timer.start()

        _tick()

    def _install_exit_hooks(self):
        """Flush at interpreter exit and on SIGINT/SIGTERM (then exit with status 0)."""
        import sys

        def _drain(*_a):
            try:
                self.flush()
            except Exception as err:  # best effort: never break shutdown
                self._log("drain_error", str(err))

        def _on_signal(*_a):
            _drain()
            # sys.exit rather than the site-provided builtin exit(), which is
            # missing under `python -S` and in some embedded/frozen runtimes.
            sys.exit(0)

        atexit.register(_drain)
        # NOTE(review): this replaces any SIGINT/SIGTERM handlers the host
        # application installed; consider chaining to the previous handler.
        try:
            signal.signal(signal.SIGINT, _on_signal)
            signal.signal(signal.SIGTERM, _on_signal)
        except Exception as err:
            # e.g. ValueError when not called from the main thread.
            self._log("signal_hook_unavailable", str(err))

    def add(self, m: SdkMetric):
        """Buffer one metric; start a background flush once the batch size is reached."""
        with self._lock:
            self._buf.append(m)
            # No point spawning a thread while a flush is running: flush()
            # would return immediately anyway.
            should_flush = len(self._buf) >= self._batch_size and not self._flushing
        if should_flush:
            threading.Thread(target=self.flush, daemon=True).start()

    def _build_agg_payload(self, batch: List[SdkMetric]) -> List[Dict[str, Any]]:
        """Fold raw metrics into one AggBucket per bucket key and serialize them.

        Returns [] for an empty batch.
        """
        if not batch:
            return []

        # A missing/zero apdexTMs falls back to the module default.
        apdex_t = self.cfg.apdexTMs or APDEX_T_MS_DEFAULT
        buckets: Dict[
            Tuple[int, str, str, Optional[str], Optional[str], Optional[str]],
            AggBucket,
        ] = {}

        for m in batch:
            key = _bucket_key(m)
            buckets[key] = _agg_from_metric(
                m, apdex_t_ms=apdex_t, existing=buckets.get(key)
            )

        return [_serialize_agg(b) for b in buckets.values()]

    def _requeue(self, batch: List[SdkMetric]) -> None:
        """Put an unsent batch back at the front of the buffer, ahead of newer metrics."""
        # NOTE(review): the buffer is never capped, so a long outage grows it
        # without bound; consider a maximum buffer size.
        with self._lock:
            self._buf = batch + self._buf

    def _send(self, body: str, batch: List[SdkMetric]) -> None:
        """POST an encoded payload; re-queue ``batch`` on HTTP >= 400 or a transport error."""
        try:
            res = requests.post(
                self.cfg.remoteUrl,
                data=body,
                headers={
                    "content-type": "application/json",
                    "x-api-key": self.cfg.apiKey,
                },
                timeout=self._request_timeout_ms / 1000.0,
            )
        except Exception as err:
            self._log("network_error", str(err))
            self._requeue(batch)
            return

        if res.status_code >= 400:
            txt = ""
            try:
                txt = res.text
            except Exception:
                pass
            self._log("server_error", res.status_code, txt)
            self._requeue(batch)

    def flush(self):
        """Aggregate and send everything currently buffered.

        Returns immediately if the buffer is empty or another flush is in
        flight. A batch that fails to send (HTTP >= 400 or a network error) is
        re-queued. A batch that cannot be aggregated or JSON-encoded is logged
        and dropped. That step depends only on the batch, so it would fail
        again on every retry and block all metrics buffered after it.
        """
        with self._lock:
            if self._flushing or not self._buf:
                return
            self._flushing = True
            batch = self._buf
            self._buf = []

        try:
            try:
                payload = self._build_agg_payload(batch)
                body = json.dumps(payload) if payload else None
            except Exception as err:
                self._log("encode_error", str(err), {"dropped": len(batch)})
                return
            if body is None:
                return

            self._log("sending_agg", {"count": len(payload)})
            self._send(body, batch)
        finally:
            with self._lock:
                self._flushing = False

    def stop(self):
        """Stop the periodic timer and attempt a final flush.

        The final flush is a no-op if another flush is already in flight.
        """
        with self._lock:
            self._stop = True
            timer, self._timer = self._timer, None
        if timer is not None:
            timer.cancel()
        self.flush()


# Process-wide registry of batch managers, keyed by "remoteUrl|apiKey".
_global_managers: Dict[str, _BatchManager] = {}
# Serializes check-then-create in get_manager so concurrent first calls for
# the same key cannot build two managers (each with its own timer and hooks).
_global_managers_lock = threading.Lock()


def get_manager(cfg: Config) -> _BatchManager:
    """Return the shared batch manager for ``cfg``'s (remoteUrl, apiKey) pair.

    The manager is created on first use. Later calls with the same URL and
    key return that first manager, so any other settings in a later ``cfg``
    (batch size, env, ...) are ignored.
    """
    key = f"{cfg.remoteUrl}|{cfg.apiKey}"
    with _global_managers_lock:
        mgr = _global_managers.get(key)
        if mgr is None:
            mgr = _BatchManager(cfg)
            _global_managers[key] = mgr
    return mgr


# Public (non-underscore) alias of _parse_int_safe.
parse_int_safe = _parse_int_safe
