import bisect
import os
from typing import List, Tuple, Any

import numpy as np
from rclpy.serialization import deserialize_message
from rosbag2_py import SequentialReader, StorageOptions, ConverterOptions, StorageFilter
from rosidl_runtime_py.utilities import get_message
from sensor_msgs_py import point_cloud2

def load_messages(
    bag_path: str,
    topic_name: str,
    msg_type_str: str,
    storage_id: str = 'mcap',
) -> List[Tuple[float, Any]]:
    """Read and deserialize every message published on one topic of a rosbag2 bag.

    Args:
        bag_path: Path to the bag; passed as the ``uri`` of the storage options.
        topic_name: Topic whose messages are returned. Other topics are skipped.
        msg_type_str: Message type string resolved with
            ``rosidl_runtime_py.utilities.get_message`` (e.g.
            ``"nav_msgs/msg/Odometry"``).
        storage_id: rosbag2 storage plugin used to open the bag. Defaults to
            ``'mcap'``; use ``'sqlite3'`` for ``.db3`` bags.

    Returns:
        ``(stamp, msg)`` tuples in the order the reader yields them. ``stamp``
        is ``msg.header.stamp`` in seconds. For message types without a
        ``header`` field, the bag record timestamp (nanoseconds) converted to
        seconds is used instead.

    Raises:
        FileNotFoundError: If ``bag_path`` does not exist.
    """
    if not os.path.exists(bag_path):
        raise FileNotFoundError(f"Bag path does not exist: {bag_path}")

    reader = SequentialReader()
    storage_options = StorageOptions(uri=bag_path, storage_id=storage_id)
    converter_options = ConverterOptions(
        input_serialization_format='cdr', output_serialization_format='cdr'
    )
    reader.open(storage_options, converter_options)
    # Let the storage layer skip other topics instead of handing every
    # record to Python only to discard it.
    reader.set_filter(StorageFilter(topics=[topic_name]))

    msg_type = get_message(msg_type_str)

    topic_msgs: List[Tuple[float, Any]] = []
    while reader.has_next():
        topic, data, record_time_ns = reader.read_next()
        if topic != topic_name:  # defensive; the filter should already ensure this
            continue
        msg = deserialize_message(data, msg_type)
        header = getattr(msg, 'header', None)
        if header is not None:
            stamp = header.stamp.sec + header.stamp.nanosec * 1e-9
        else:
            # Unstamped message types: fall back to the time it was recorded.
            stamp = record_time_ns * 1e-9
        topic_msgs.append((stamp, msg))
    return topic_msgs

def _sorted_stamp_index(msgs: List[Tuple[float, Any]]) -> Tuple[List[float], List[int]]:
    """Return the stamps of ``msgs`` sorted ascending, plus their original indices."""
    # sorted() is stable, so equal stamps keep their original relative order.
    order = sorted(range(len(msgs)), key=lambda i: msgs[i][0])
    return [msgs[i][0] for i in order], order


def _closest_index(stamps: List[float], order: List[int], ref_time: float) -> int:
    """Return the original index of the stamp nearest ``ref_time``.

    Ties (equal distance) resolve to the lowest original index, matching a
    linear ``min`` scan over the unsorted list.
    """
    pos = bisect.bisect_left(stamps, ref_time)
    candidates = []
    if pos < len(stamps):
        # Smallest stamp >= ref_time. bisect_left lands on the first of any
        # equal stamps, i.e. the lowest original index among them.
        candidates.append((abs(stamps[pos] - ref_time), order[pos]))
    if pos > 0:
        # Largest stamp < ref_time. Step back to the first of its equal run.
        first = bisect.bisect_left(stamps, stamps[pos - 1])
        candidates.append((abs(stamps[pos - 1] - ref_time), order[first]))
    return min(candidates)[1]


def match_closest_msgs(
    risk_map_msgs: List[Tuple[float, Any]],
    odometry_msgs: List[Tuple[float, Any]],
    pointcloud_msgs: List[Tuple[float, Any]],
) -> List[Tuple[Any, Any, Any]]:
    """Pair each risk-map message with the time-nearest odometry and point-cloud messages.

    Each argument is a list of ``(stamp, msg)`` tuples. The lists need not be
    sorted. When two candidates are equally close, the one appearing earlier in
    its list is chosen.

    Returns:
        ``(risk_map_msg, odometry_msg, pointcloud_msg)`` tuples, one per entry
        of ``risk_map_msgs``, in that list's order. The result is empty if
        ``risk_map_msgs`` is empty.

    Raises:
        ValueError: If there is at least one risk-map message but
            ``odometry_msgs`` or ``pointcloud_msgs`` is empty.
    """
    if not risk_map_msgs:
        return []
    if not odometry_msgs or not pointcloud_msgs:
        raise ValueError(
            "Cannot match risk map messages: odometry and point cloud lists must be non-empty"
        )

    # Build sorted indices once so each lookup is O(log n) instead of a full scan.
    odom_stamps, odom_order = _sorted_stamp_index(odometry_msgs)
    cloud_stamps, cloud_order = _sorted_stamp_index(pointcloud_msgs)

    return [
        (
            ref_msg,
            odometry_msgs[_closest_index(odom_stamps, odom_order, ref_time)][1],
            pointcloud_msgs[_closest_index(cloud_stamps, cloud_order, ref_time)][1],
        )
        for ref_time, ref_msg in risk_map_msgs
    ]

def transform_cloud_to_odom_frame(pointcloud_msg: Any, odometry_msg: Any) -> List[np.ndarray]:
    """Apply the odometry pose to every point of a point cloud.

    Each point ``p`` (its ``x``, ``y``, ``z`` fields, read with NaN points
    skipped) is mapped to ``R @ p + t``:
    - ``R`` is the rotation matrix of ``odometry_msg.pose.pose.orientation``.
    - ``t`` is ``odometry_msg.pose.pose.position``.

    Args:
        pointcloud_msg: Message accepted by ``sensor_msgs_py.point_cloud2.read_points``
            that has ``x``, ``y``, ``z`` fields.
        odometry_msg: Message exposing ``pose.pose.position`` and
            ``pose.pose.orientation`` (an x, y, z, w quaternion).

    Returns:
        One length-3 float array per non-NaN point, in read order.
    """
    position = odometry_msg.pose.pose.position
    orientation = odometry_msg.pose.pose.orientation

    # Normalize so a slightly non-unit quaternion (accumulated numerical drift)
    # does not scale or shear the cloud. A zero quaternion is left as-is.
    quat = np.array([orientation.x, orientation.y, orientation.z, orientation.w], dtype=float)
    norm = np.linalg.norm(quat)
    if norm > 0.0:
        quat = quat / norm
    qx, qy, qz, qw = quat

    rotation_matrix = np.array([
        [1 - 2*qy**2 - 2*qz**2, 2*qx*qy - 2*qz*qw, 2*qx*qz + 2*qy*qw],
        [2*qx*qy + 2*qz*qw, 1 - 2*qx**2 - 2*qz**2, 2*qy*qz - 2*qx*qw],
        [2*qx*qz - 2*qy*qw, 2*qy*qz + 2*qx*qw, 1 - 2*qx**2 - 2*qy**2],
    ])
    translation = np.array([position.x, position.y, position.z], dtype=float)

    points = point_cloud2.read_points(pointcloud_msg, field_names=("x", "y", "z"), skip_nans=True)
    # Gather into an (N, 3) array; reshape keeps an empty cloud as shape (0, 3).
    xyz = np.array([(x, y, z) for x, y, z in points], dtype=float).reshape(-1, 3)

    # Row-vector form of R @ p + t, applied to all points at once.
    transformed = xyz @ rotation_matrix.T + translation
    return list(transformed)