#!/usr/bin/env python3
"""
Compute per-file IoU, precision, recall and F1 score (plus their means) between
an *inverted* reference mask and several *inverted* transformed risk-maps,
visualising each comparison as a colour overlay.

USAGE
-----
    python compute_miou.py MASK_PATH COORDS_PATH [--delay 400]

    MASK_PATH   : reference mask image
    COORDS_PATH : coordinates.txt   (filename  angle  scale  offset_x  offset_y);
                  relative risk-map filenames are resolved against its folder
    --delay     : milliseconds each overlay is shown (0 = wait for key)
"""
import argparse
import sys
from pathlib import Path

import cv2
import numpy as np

from utils import affine, parse_coords_path

# ────────── CLI ────────────────────────────────────────────────────────────────
ap = argparse.ArgumentParser(description="mIoU (inverted masks) with overlay")
ap.add_argument("mask", help="reference mask image")
ap.add_argument("coords", help="coordinates.txt file")
ap.add_argument(
    "--delay",
    type=int,
    default=400,
    help="delay per overlay in ms (0 = wait)",
)
args = ap.parse_args()
# Negative delays are clamped to 0 ("wait for a key press").
delay = args.delay if args.delay > 0 else 0

mask_path, coords_path = (
    Path(raw).expanduser().resolve() for raw in (args.mask, args.coords)
)
# Risk-map filenames listed in the coords file are looked up relative to it.
risk_root = coords_path.parent


# ────────── helpers ────────────────────────────────────────────────────────────
def binarise_and_invert(img, thr=0.99):
    """Threshold an image to a 0/1 mask, then invert it.

    Pixels whose intensity, normalised to [0, 1], is strictly greater than
    *thr* become 0; every other pixel becomes 1.

    Parameters
    ----------
    img : np.ndarray
        2-D greyscale image, or 3-D image with 1, 3 (BGR) or 4 (BGRA)
        channels, e.g. as returned by ``cv2.imread(..., IMREAD_UNCHANGED)``.
    thr : float
        Threshold applied to the normalised intensity.

    Returns
    -------
    np.ndarray
        ``uint8`` array of shape ``img.shape[:2]`` holding 0/1 values.
    """
    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 1:
            img = img[..., 0]
        elif channels == 4:
            # IMREAD_UNCHANGED keeps the alpha channel (e.g. RGBA PNGs);
            # COLOR_BGR2GRAY rejects 4-channel input.
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Normalise by the full range of unsigned integer dtypes so 16-bit images
    # are not saturated (uint8 -> 255); other dtypes keep the 255 divisor.
    full_scale = float(np.iinfo(img.dtype).max) if img.dtype.kind == "u" else 255.0
    bin_img = (img.astype(np.float32) / full_scale > thr).astype(np.uint8)
    return 1 - bin_img  # invert: bright (> thr) pixels -> 0, the rest -> 1


def iou(a, b):
    """Return the Intersection over Union of two masks.

    IoU = TP / (TP + FP + FN) = |a AND b| / |a OR b|, where any non-zero
    element counts as set.  Two empty masks (empty union) score 1.0.
    """
    union = np.logical_or(a, b).sum()
    if not union:
        # Nothing set in either mask: treat as perfect agreement.
        return 1.0
    return np.logical_and(a, b).sum() / union

def recall(a, b):
    """Return the recall of prediction *b* against reference *a*.

    Recall = TP / (TP + FN): the fraction of elements set in *a* that are
    also set in *b*.  Returns 1.0 when *a* has no set elements.
    """
    hits = np.logical_and(a, b).sum()
    misses = np.logical_and(a, np.logical_not(b)).sum()
    positives = hits + misses
    if positives:
        return hits / positives
    return 1.0

def precision(a, b):
    """Return the precision of prediction *b* against reference *a*.

    Precision = TP / (TP + FP): the fraction of elements set in *b* that are
    also set in *a*.  Returns 1.0 when *b* has no set elements.
    """
    true_pos = np.logical_and(a, b).sum()
    false_pos = np.logical_and(np.logical_not(a), b).sum()
    predicted = true_pos + false_pos
    if predicted:
        return true_pos / predicted
    return 1.0


def f1_score(a, b):
    """Return the F1 score of prediction *b* against reference *a*.

    F1 = 2 * P * R / (P + R), the harmonic mean of ``precision(a, b)`` and
    ``recall(a, b)``.  Returns 0.0 when both precision and recall are zero.
    """
    prec, rec = precision(a, b), recall(a, b)
    total = prec + rec
    if not total:
        return 0.0
    return 2 * prec * rec / total

# ────────── load & invert reference mask ───────────────────────────────────────
# IMREAD_UNCHANGED keeps the file's own bit depth and channel count; a failed
# read is signalled by imread returning None, which aborts the script.
mask_img = cv2.imread(str(mask_path), cv2.IMREAD_UNCHANGED)
if mask_img is None:
    sys.exit(f"Error: cannot open {mask_path}")
mask_full = binarise_and_invert(mask_img)
(H_full, W_full) = mask_full.shape

# ────────── parse the coordinate path file ─────────────────────────────────────
# NOTE(review): parse_coords_path (utils) is expected to return a mapping
# filename -> {"ang", "scl", "tx", "ty"}, judging by how the main loop reads
# it; confirm against utils.
files = parse_coords_path(coords_path)

# ────────── main loop ──────────────────────────────────────────────────────────
# Per-file metrics, keyed by the resolved risk-map path.
ious = {}
precisions = {}
recalls = {}
f1_scores = {}
cv2.namedWindow("overlay", cv2.WINDOW_NORMAL)

for fname, params in files.items():
    ang = params["ang"]
    scl = params["scl"]
    tx = params["tx"]
    ty = params["ty"]
    print(fname, ang, scl, tx, ty)
    ang, scl, tx, ty = map(float, (ang, scl, tx, ty))
    # Relative filenames are resolved against the coordinates file's folder.
    risk_path = Path(fname) if Path(fname).is_absolute() else risk_root / fname

    risk_img = cv2.imread(str(risk_path), cv2.IMREAD_UNCHANGED)
    # imread returns None on failure: check *before* touching .shape, otherwise
    # an unreadable file crashes the run instead of being skipped.
    if risk_img is None:
        print(f"[warning] cannot read {risk_path}; skipping.")
        continue

    h, w = risk_img.shape[:2]
    print(f"Processing {risk_path} ({h}x{w})")
    # Keep only the central region of the image (the middle half of each axis).
    risk_img = risk_img[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4]
    h, w = risk_img.shape[:2]
    # Adjust the translation for the crop.
    # NOTE(review): presumably compensates for the removed border
    # (~original_w // 4, original_h // 4); confirm against how utils.affine
    # interprets tx/ty and the w/h arguments.
    tx = tx + w // 2
    ty = ty + h // 2

    risk_bin = binarise_and_invert(risk_img)
    M = affine(ang, scl, tx, ty, w, h)
    # Warp onto the reference-mask canvas. Nearest-neighbour sampling keeps
    # the values 0/1, and pixels outside the risk map are filled with 0.
    warped = cv2.warpAffine(
        risk_bin, M, (W_full, H_full), flags=cv2.INTER_NEAREST, borderValue=0
    )

    # ─── crop to bounding box of warped risk map ───────────────────────────
    # Metrics are evaluated only inside this (padded) box; if the warped map
    # is entirely empty the full canvas is used instead.
    nz = cv2.findNonZero(warped)
    if nz is None:
        x, y, w_box, h_box = 0, 0, W_full, H_full
    else:
        x, y, w_box, h_box = cv2.boundingRect(nz)
        pad = 10
        x = max(x - pad, 0)
        y = max(y - pad, 0)
        w_box = min(w_box + 2 * pad, W_full - x)
        h_box = min(h_box + 2 * pad, H_full - y)

    mask_crop = mask_full[y : y + h_box, x : x + w_box]
    warped_crop = warped[y : y + h_box, x : x + w_box]

    this_iou = iou(mask_crop, warped_crop)
    this_recall = recall(mask_crop, warped_crop)
    this_precision = precision(mask_crop, warped_crop)
    this_f1_score = f1_score(mask_crop, warped_crop)
    ious[risk_path] = this_iou
    precisions[risk_path] = this_precision
    recalls[risk_path] = this_recall
    f1_scores[risk_path] = this_f1_score

    # ─── colour overlay (green = inverted mask, red = inverted risk) ──────
    overlay = np.zeros((h_box, w_box, 3), np.uint8)
    overlay[..., 1] = mask_crop * 255
    overlay[..., 2] = warped_crop * 255
    cv2.imshow("overlay", overlay)
    cv2.setWindowTitle("overlay", f"{risk_path.name}  |  IoU = {this_iou:.4f}")

    # Esc / q / Q stops early; with delay == 0 this blocks until a key press.
    key = cv2.waitKey(delay) & 0xFF
    if key in (27, ord("q"), ord("Q")):
        break

cv2.destroyAllWindows()

# ────────── results ────────────────────────────────────────────────────────────
def _print_metric_report(values, heading, mean_label, empty_label):
    """Print the per-file scores and their mean for one metric.

    values      : dict mapping risk-map path -> score
    heading     : plural metric name for the "Per-file ..." header
    mean_label  : singular metric name for the "Mean ... over" line
    empty_label : metric name used when *values* is empty
    """
    if not values:
        print(f"No risk-map files processed for {empty_label}.")
        return
    print(f"\nPer-file {heading}:")
    for key, val in values.items():
        path = Path(key)
        # Show only "<parent dir>/<file name>" to keep the listing compact.
        print(f"  # {path.parent.name}/{path.name}: {val:.4f}")
    mean = np.mean(list(values.values()))
    print(f"\nMean {mean_label} over {len(values)} files: {mean:.4f}")


for _values, _heading, _mean_label, _empty_label in (
    (ious, "IoUs", "IoU", "ious"),
    (precisions, "Precisions", "Precision", "precision"),
    (recalls, "Recalls", "Recall", "recall"),
    (f1_scores, "F1 Scores", "F1 Score", "F1 score"),
):
    _print_metric_report(_values, _heading, _mean_label, _empty_label)