#!/usr/bin/env python3
import argparse
import re
from pathlib import Path

import cv2
import numpy as np

from utils import affine, parse_coords_path


def main():
    """Blend each warped risk map listed in an offsets file over an overhead image.

    Command-line arguments:
        overhead: path to the base image (read as 8-bit BGR).
        offsets:  text file parsed by ``utils.parse_coords_path``; each entry
                  maps a risk-map filename to its ``ang``/``scl``/``tx``/``ty``
                  alignment values. Relative filenames are resolved against
                  the directory containing the offsets file.
        --alpha:  weight of the risk-map overlay, clipped to [0, 1].
        --delay:  milliseconds each overlay is shown; 0 waits for a key press.

    Occurrences of ``risk_map`` in listed filenames are rewritten to
    ``risk_map_with_points`` before loading. Unreadable risk maps are skipped
    with a warning. Press Esc, ``q`` or ``Q`` to stop early.

    Raises:
        SystemExit: if the overhead image cannot be read.
    """
    parser = argparse.ArgumentParser(
        description="Visualise risk map offsets over an overhead image"
    )
    parser.add_argument("overhead", help="overhead/base image")
    parser.add_argument("offsets", help="offsets text file")
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.6,
        help="weight of the risk map overlay (0-1)",
    )
    parser.add_argument(
        "--delay",
        type=int,
        default=400,
        help="delay per overlay in ms (0 = wait)",
    )
    args = parser.parse_args()

    delay = max(args.delay, 0)
    alpha = float(np.clip(args.alpha, 0.0, 1.0))

    base_img = cv2.imread(str(args.overhead), cv2.IMREAD_COLOR)
    if base_img is None:
        raise SystemExit(f"Error: cannot open {args.overhead}")
    h_base, w_base = base_img.shape[:2]

    offset_path = Path(args.offsets).expanduser().resolve()
    risk_root = offset_path.parent

    files = parse_coords_path(offset_path)

    cv2.namedWindow("fused", cv2.WINDOW_NORMAL)

    for fname, vals in files.items():
        angle = float(vals["ang"])
        scale = float(vals["scl"])
        tx = float(vals["tx"])
        ty = float(vals["ty"])

        # Load the "with points" variant of each risk map. The lookahead keeps
        # the rewrite idempotent: a name that already contains
        # "risk_map_with_points" is left alone instead of being doubled.
        fname = re.sub(
            r"risk_map(?!_with_points)", "risk_map_with_points", fname
        )
        risk_path = Path(fname)
        # Short "parent/name" label drawn onto the overlay.
        filestr = f"{risk_path.parent.name}/{risk_path.name}"
        if not risk_path.is_absolute():
            risk_path = risk_root / risk_path

        # Read as 8-bit 3-channel BGR (like base_img). With IMREAD_UNCHANGED,
        # BGRA or 16-bit maps break addWeighted (type mismatch) and grayscale
        # maps would get the red label drawn as black.
        risk_img = cv2.imread(str(risk_path), cv2.IMREAD_COLOR)
        if risk_img is None:
            print(f"[warning] cannot read {risk_path}; skipping")
            continue

        risk_img = cv2.putText(
            risk_img,
            filestr,
            (50, 50),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 0, 255),
            2,
        )

        h_risk, w_risk = risk_img.shape[:2]

        M = affine(angle, scale, tx, ty, w_risk, h_risk)
        # BORDER_TRANSPARENT leaves destination pixels outside the warped map
        # unmodified. Without an explicit dst those pixels would come from a
        # freshly allocated (uninitialised) buffer. Warping into a copy of
        # the base means uncovered regions blend back to the base unchanged.
        warped = cv2.warpAffine(
            risk_img,
            M,
            (w_base, h_base),
            dst=base_img.copy(),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_TRANSPARENT,
        )

        fused = cv2.addWeighted(base_img, 1 - alpha, warped, alpha, 0)

        cv2.imshow("fused", fused)
        cv2.setWindowTitle("fused", risk_path.name)
        k = cv2.waitKey(delay) & 0xFF
        # Esc, q or Q ends the slideshow early.
        if k in (27, ord("q"), ord("Q")):
            break

    cv2.destroyAllWindows()


if __name__ == "__main__":
    # Run the viewer when executed as a script; main() returns None, so the
    # process exits with status 0 unless main() raises.
    raise SystemExit(main())
