from parse_rosbags import load_messages, match_closest_msgs, transform_cloud_to_odom_frame

import numpy as np
import cv2
import matplotlib.pyplot as plt
import sys
import os

# Indices (into the list of matched message triples) of the frames to render.
FRAME_INDICES = (80, 100, 120, 130, 140)

# Each risk-map cell is rendered as a ZOOM_FACTOR x ZOOM_FACTOR pixel square.
ZOOM_FACTOR = 10


def select_by_indices(items, indices):
    """Return ``items[i]`` for every ``i`` in *indices* that is in range.

    Indices outside ``0 <= i < len(items)`` are reported on stdout and
    skipped, so a short recording does not abort the script with an
    ``IndexError``.
    """
    selected = []
    for i in indices:
        if 0 <= i < len(items):
            selected.append(items[i])
        else:
            print(f"Skipping index {i}: only {len(items)} matched messages available.")
    return selected


def risk_to_gray(data, height, width):
    """Convert flat risk values into a ``(height, width)`` uint8 gray image.

    Values are clipped to ``[0, 100]`` and inverted so that 0 maps to white
    (255) and 100 maps to black (0).
    """
    grid = np.array(data).reshape(height, width)
    grid = np.clip(grid, 0, 100)
    return (255 - (grid * 2.55)).astype(np.uint8)


def world_to_pixel(x, y, origin_x, origin_y, resolution, zoom_factor, image_width):
    """Map a position to ``(pixel_x, pixel_y)`` in the zoomed overlay image.

    *x*/*y* must be expressed in the same frame and units as the grid origin.
    The grid x index becomes the image row and the grid y index becomes the
    image column.
    """
    # Fractional grid cell coordinates.
    px = (x - origin_x) / resolution
    py = (y - origin_y) / resolution

    # Rotate 90 degrees CCW.
    px_rot = -py
    py_rot = px

    # Shift by the image width before truncating, then mirror back, so the
    # rotated (negative) coordinate lands in positive pixel space.
    pixel_x = image_width - int(px_rot * zoom_factor + image_width)
    pixel_y = int(py_rot * zoom_factor)
    return pixel_x, pixel_y


def write_image(path, image):
    """Write *image* to *path*, warning on stdout if OpenCV reports failure."""
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(path, image):
        print(f"Warning: failed to write {path}")


def main():
    """Render selected risk maps from a bag, with point clouds overlaid.

    Reads the bag path from ``sys.argv[1]``, writes PNGs to
    ``validation_imgs/<bag_name>/`` next to this script, and shows each
    overlay in a window until a key is pressed.
    """
    if len(sys.argv) != 2:
        print("Usage: python script.py <bag_path>")
        sys.exit(1)

    rosbag_path = sys.argv[1]
    # normpath strips a trailing separator, which would otherwise make
    # basename() return an empty string.
    bag_name = os.path.basename(os.path.normpath(rosbag_path))

    # abspath guards against a relative __file__, for which dirname() is ""
    # and the output path would resolve to the filesystem root.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    out_dir = os.path.join(script_dir, "validation_imgs", bag_name)
    # cv2.imwrite does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    odometry_messages = load_messages(rosbag_path, '/dlio/odom_node/odom', 'nav_msgs/msg/Odometry')
    risk_map_messages = load_messages(rosbag_path, '/obstacle_detection/blurred_risk_map_rviz', 'nav_msgs/msg/OccupancyGrid')
    pointcloud_messages = load_messages(rosbag_path, '/ouster/points', 'sensor_msgs/msg/PointCloud2')

    matched_msgs = match_closest_msgs(risk_map_messages, odometry_messages, pointcloud_messages)
    print(f"Matched {len(matched_msgs)} messages.")

    # Choose the frames to visualize.
    matched_msgs = select_by_indices(matched_msgs, FRAME_INDICES)
    if not matched_msgs:
        print("No matched messages at the requested indices; nothing to visualize.")
        sys.exit(1)

    # Project each point cloud onto its respective risk map.
    for i, (risk_map_msg, odometry_msg, pointcloud_msg) in enumerate(matched_msgs):
        print(f"Processing risk map {i+1}/{len(matched_msgs)}")

        info = risk_map_msg.info
        risk_map_gray = risk_to_gray(risk_map_msg.data, info.height, info.width)
        risk_map_image = cv2.cvtColor(risk_map_gray, cv2.COLOR_GRAY2BGR)

        resized_image = cv2.resize(
            risk_map_image,
            None,
            fx=ZOOM_FACTOR,
            fy=ZOOM_FACTOR,
            interpolation=cv2.INTER_NEAREST,
        )
        # NOTE(review): flipud followed by a 90-degree CCW rotation places grid
        # cell (row, col) at the opposite corner from where world_to_pixel()
        # draws it; this may compensate for how the map is published.
        # Confirm the overlay alignment against a known frame.
        resized_image = np.flipud(resized_image)
        resized_image = cv2.rotate(resized_image, cv2.ROTATE_90_COUNTERCLOCKWISE)
        write_image(os.path.join(out_dir, f"risk_map_{i+1}.png"), resized_image)

        # Loop-invariant values hoisted out of the per-point loop.
        origin_x = info.origin.position.x
        origin_y = info.origin.position.y
        resolution = info.resolution
        image_height, image_width = resized_image.shape[:2]

        # Overlay the transformed point cloud on the risk map image.
        transformed_pointcloud = transform_cloud_to_odom_frame(pointcloud_msg, odometry_msg)
        for point in transformed_pointcloud:
            x, y = point[:2]
            pixel_x, pixel_y = world_to_pixel(
                x, y, origin_x, origin_y, resolution, ZOOM_FACTOR, image_width
            )
            if 0 <= pixel_x < image_width and 0 <= pixel_y < image_height:
                cv2.circle(resized_image, (pixel_x, pixel_y), 1, (255, 0, 0), -1)
        write_image(os.path.join(out_dir, f"risk_map_with_points_{i+1}.png"), resized_image)

        cv2.imshow("Risk Map Overlay", resized_image)
        cv2.waitKey(0)

    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()