"""

.. _sphx_glr__gallery_optimize_world_points.py:

Optimizing world points from two images projections
======================================================================================================

This example illustrates how to estimate world points from two images, using
either matched projected points or a flow computed between the two images.

The optimization is done using the ``optimize_chains_input_points_gn`` function, which optimizes the input points of a transformation chain using Gauss-Newton optimization.

.. seealso::

    - :func:`pycvcam.optimize_chains_input_points_gn`: Optimize the input points of a transformation chain using Gauss-Newton optimization.
    - :func:`pycvcam.project_points`: Project 3D points to 2D image points using a specified camera model.
    - :func:`pycvcam.compute_optical_flow`: Compute the optical flow between two images.
    - :func:`pycvcam.optimize_rays_intersect`: Optimize the intersection point of multiple rays using a least squares optimization.

"""

# %%
# Optimizing Input Points from Two Images Projections
# -----------------------------------------------------------------
# In this example, we will optimize the world points from their projections in two images using the ``optimize_chains_input_points_gn`` function. We will use a simple pinhole camera model for the transformations.
# We will create two transformations, one for each image, and then use the optimization function to find the world points that best fit the projections in both images.
#
# Let's consider a set of 100 random world points in 3D space, and two camera transformations at different positions and orientations.

import pycvcam
import numpy
import py3dframe

def _look_at_axes(position, target, up_hint=(0.0, 0.0, 1.0)):
    """Return the unit x, y and z axes of a camera at ``position`` looking at ``target``.

    The z axis points from ``position`` to ``target``, the x axis is
    ``up_hint`` cross z, and the y axis is z cross x. All three axes are
    normalized so that they form a right-handed orthonormal basis.

    Parameters
    ----------
    position : array_like
        Camera position, 3 components.
    target : array_like
        Point the camera looks at, 3 components.
    up_hint : array_like, optional
        Helper direction used to build the x axis; it must not be parallel to
        the viewing direction.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x_axis, y_axis, z_axis)``, each of shape (3,) and unit norm.

    Raises
    ------
    ValueError
        If ``position`` equals ``target`` or if the viewing direction is
        parallel to ``up_hint`` (the basis would be undefined).
    """
    z_axis = numpy.asarray(target, dtype=float) - numpy.asarray(position, dtype=float)
    z_norm = numpy.linalg.norm(z_axis)
    if z_norm == 0.0:
        raise ValueError("The camera position and target must be distinct.")
    z_axis = z_axis / z_norm

    # The cross product of two non-orthogonal vectors is not a unit vector,
    # so it has to be normalized explicitly before building the y axis.
    x_axis = numpy.cross(numpy.asarray(up_hint, dtype=float), z_axis)
    x_norm = numpy.linalg.norm(x_axis)
    if x_norm == 0.0:
        raise ValueError("The viewing direction must not be parallel to the up hint.")
    x_axis = x_axis / x_norm

    # z and x are orthonormal here, hence y is already a unit vector.
    y_axis = numpy.cross(z_axis, x_axis)
    return x_axis, y_axis, z_axis


# 100 random world points: x and y drawn uniformly in [-5, 5], z in [50, 100].
world_x = numpy.random.uniform(-5.0, 5.0, size=(100,))
world_y = numpy.random.uniform(-5.0, 5.0, size=(100,))
world_z = numpy.random.uniform(50.0, 100.0, size=(100,))
world_points = numpy.stack((world_x, world_y, world_z), axis=-1)  # shape (100, 3)

# Two cameras shifted along y, both looking at the centre of the point cloud.
camera_1_position = numpy.array([0.0, -4.0, 0.0])
camera_2_position = numpy.array([0.0, 4.0, 0.0])
camera_target = numpy.array([0.0, 0.0, 75.0])

camera_1_x_axis, camera_1_y_axis, camera_1_z_axis = _look_at_axes(
    camera_1_position, camera_target
)
camera_2_x_axis, camera_2_y_axis, camera_2_z_axis = _look_at_axes(
    camera_2_position, camera_target
)

# Camera 1: frame built from its position and axes, then wrapped as an extrinsic.
camera_1_axes = {
    "x_axis": camera_1_x_axis,
    "y_axis": camera_1_y_axis,
    "z_axis": camera_1_z_axis,
}
camera_1_frame = py3dframe.Frame.from_axes(origin=camera_1_position, **camera_1_axes)
extrinsic_1 = pycvcam.Cv2Extrinsic.from_frame(camera_1_frame)

# Camera 2: same construction with the second position and axes.
camera_2_axes = {
    "x_axis": camera_2_x_axis,
    "y_axis": camera_2_y_axis,
    "z_axis": camera_2_z_axis,
}
camera_2_frame = py3dframe.Frame.from_axes(origin=camera_2_position, **camera_2_axes)
extrinsic_2 = pycvcam.Cv2Extrinsic.from_frame(camera_2_frame)

# Shared intrinsic matrix: 3000 x 2000 image, focal length of 1000 on both
# axes, principal point at the centre pixel ((width - 1) / 2, (height - 1) / 2).
image_width = 3000
image_height = 2000
focal_length = 1000.0
principal_point_x = (image_width - 1) / 2
principal_point_y = (image_height - 1) / 2
camera_matrix = numpy.array(
    [
        [focal_length, 0.0, principal_point_x],
        [0.0, focal_length, principal_point_y],
        [0.0, 0.0, 1.0],
    ]
)
intrinsic = pycvcam.Cv2Intrinsic.from_matrix(camera_matrix)

# Zernike distortion model shared by both cameras, built from 12 example
# coefficients that are first given in pixel units.
distortion = pycvcam.ZernikeDistortion(
    parameters=[
        0.8541972545746392,
        -5.468596289790535,
        -5.974287819021697,
        14.292956075116104,
        2.1403205479372627,
        4.544169430137205,
        -0.10099732464199339,
        0.4363509204067417,
        -0.5106374355681896,
        -5.770087687650705,
        -0.39147505788710696,
        11.699411273002498,
    ]  # In pixel units, example values for a small distortion
)
# Centre of the distortion domain: the centre pixel of the image, i.e. the same
# point as the principal point used in the intrinsic matrix.
distortion.center = ((image_width - 1) / 2, (image_height - 1) / 2)
# Radius of the domain: distance from that centre to the pixel (0, 0).
distortion.radius = numpy.sqrt(
    (distortion.center[0]) ** 2 + (distortion.center[1]) ** 2
)
# Rescale everything from pixel units: lengths are divided by the focal lengths
# and the centre is shifted by the principal point first.
# NOTE(review): presumably needed because the distortion is applied before the
# intrinsic transformation (on normalized coordinates) - confirm against the
# pycvcam documentation. Statement order is kept as-is because these
# assignments go through the distortion object's attribute setters.
distortion.parameters_x = distortion.parameters_x / intrinsic.fx
distortion.parameters_y = distortion.parameters_y / intrinsic.fy
distortion.radius_x = distortion.radius_x / intrinsic.fx
distortion.radius_y = distortion.radius_y / intrinsic.fy
distortion.center_x = (distortion.center_x - intrinsic.cx) / intrinsic.fx
distortion.center_y = (distortion.center_y - intrinsic.cy) / intrinsic.fy

# Project the same world points through each camera; both cameras share the
# intrinsic and distortion models and differ only by their extrinsic.
projection_1 = pycvcam.project_points(
    world_points=world_points,
    intrinsic=intrinsic,
    distortion=distortion,
    extrinsic=extrinsic_1,
)
image_points_1 = projection_1.image_points

projection_2 = pycvcam.project_points(
    world_points=world_points,
    intrinsic=intrinsic,
    distortion=distortion,
    extrinsic=extrinsic_2,
)
image_points_2 = projection_2.image_points

# Report the shape and the per-coordinate extent of each set of projections.
print("Image points 1 shape:", image_points_1.shape)
print("Image points 2 shape:", image_points_2.shape)
lower_1, upper_1 = image_points_1.min(axis=0), image_points_1.max(axis=0)
print("Image points 1 range:", lower_1, upper_1)
lower_2, upper_2 = image_points_2.min(axis=0), image_points_2.max(axis=0)
print("Image points 2 range:", lower_2, upper_2)


# Every projected point must land inside the image: first coordinate in
# [0, image_width), second coordinate in [0, image_height).
for projected in (image_points_1, image_points_2):
    first_coordinate = projected[:, 0]
    second_coordinate = projected[:, 1]
    assert numpy.all((first_coordinate >= 0) & (first_coordinate < image_width))
    assert numpy.all((second_coordinate >= 0) & (second_coordinate < image_height))

# %%
# Now assume that we pair the image points from both images,
# and we want to optimize the world points that best fit these projections.
# We can use the ``optimize_chains_input_points_gn`` function for this purpose.
#
# Define two chains of transformations, one for each image, that project the world points to the image points.
#
# .. code-block:: console
#
#    Chain 1: world_points -> extrinsic_1 -> distortion -> intrinsic -> image_points_1
#
#    Chain 2: world_points -> extrinsic_2 -> distortion -> intrinsic -> image_points_2
#

# Pool of transformations; each chain is a tuple of indices into this pool and
# is paired with the image points it has to reproduce.
transforms = [extrinsic_1, extrinsic_2, distortion, intrinsic]
chains = [(0, 2, 3), (1, 2, 3)]
outputs = [image_points_1, image_points_2]

# Options common to both Gauss-Newton runs below.
# eps: precision at 0.01 pixels projection; ftol, xtol and gtol stop the
# iterations when that precision cannot be reached.
gauss_newton_options = {
    "ftol": 1e-8,
    "gtol": 1e-8,
    "xtol": 1e-8,
    "eps": 1e-2,
    "verbose": True,
    "return_convergence": True,
}

# Run 1: exact image points, starting from the true points perturbed by
# Gaussian noise (standard deviation 0.1).
initial_guess = world_points + numpy.random.normal(scale=0.1, size=world_points.shape)
optimized_world_points, conv = pycvcam.optimize_chains_input_points_gn(
    seq_transforms=transforms,
    seq_chains=chains,
    seq_outputs=outputs,
    guess=initial_guess,
    max_iterations=20,
    **gauss_newton_options,
)

# Per-point Euclidean distance to the ground truth, shape (100,), then its RMS.
error = numpy.linalg.norm(optimized_world_points - world_points, axis=1)
rms_error = numpy.sqrt(numpy.mean(numpy.square(error)))
print("RMS error in world points:", rms_error)

# Run 2: same problem with Gaussian noise (standard deviation 0.5) added to the
# image points; the solution should still be close to the original world points.
noised_image_points_1 = image_points_1 + numpy.random.normal(
    scale=0.5, size=image_points_1.shape
)
noised_image_points_2 = image_points_2 + numpy.random.normal(
    scale=0.5, size=image_points_2.shape
)

print("\n")
noised_run_guess = world_points + numpy.random.normal(
    scale=0.1, size=world_points.shape
)
optimized_world_points_noised, conv = pycvcam.optimize_chains_input_points_gn(
    seq_transforms=transforms,
    seq_chains=chains,
    seq_outputs=[noised_image_points_1, noised_image_points_2],
    guess=noised_run_guess,
    max_iterations=100,
    **gauss_newton_options,
)

error_noised = numpy.linalg.norm(optimized_world_points_noised - world_points, axis=1)
rms_error_noised = numpy.sqrt(numpy.mean(numpy.square(error_noised)))
print("RMS error in world points with noise:", rms_error_noised)


# %%
# Extract the optimized world points from the Rays intersection
# -------------------------------------------------------------------
# Another method consists in computing the rays from the camera centers through the image points,
# and then optimizing the intersection point of these rays using the ``optimize_rays_intersect`` function.
#

def _rays_from_image_points(image_points, extrinsic):
    """Compute the rays of ``image_points`` for the camera with the given ``extrinsic``.

    The shared ``intrinsic`` and ``distortion`` models are used, and every call
    receives its own copy of the same inverse-distortion solver settings.
    """
    return pycvcam.compute_rays(
        image_points=image_points,
        intrinsic=intrinsic,
        distortion=distortion,
        extrinsic=extrinsic,
        inverse_distortion_kwargs={
            "max_iterations": 10,
            "ftol": 1e-8,
            "gtol": 1e-8,
            "xtol": 1e-8,
        },
    )


# Exact image points: one set of rays per camera, then their intersection.
rays_1 = _rays_from_image_points(image_points_1, extrinsic_1)
rays_2 = _rays_from_image_points(image_points_2, extrinsic_2)
optimized_world_points_rays = pycvcam.optimize_rays_intersect([rays_1, rays_2])

# Per-point distance to the ground truth, then its RMS.
error_rays = numpy.linalg.norm(optimized_world_points_rays - world_points, axis=1)
rms_error_rays = numpy.sqrt(numpy.mean(numpy.square(error_rays)))
print("RMS error in world points from rays intersection:", rms_error_rays)

# Same procedure starting from the noised image points.
rays_noise_1 = _rays_from_image_points(noised_image_points_1, extrinsic_1)
rays_noise_2 = _rays_from_image_points(noised_image_points_2, extrinsic_2)
optimized_world_points_rays_noised = pycvcam.optimize_rays_intersect(
    [rays_noise_1, rays_noise_2]
)

error_rays_noised = numpy.linalg.norm(
    optimized_world_points_rays_noised - world_points, axis=1
)
rms_error_rays_noised = numpy.sqrt(numpy.mean(numpy.square(error_rays_noised)))
print(
    "RMS error in world points from rays intersection with noise:",
    rms_error_rays_noised,
)


# %%
# Conclusion
# -------------------------------------------------------------------
# Both methods, optimizing the input points of the transformation chains and optimizing the rays intersection, can
# be used to estimate the world points from their projections in two images. The optimization of the input points of the
# transformation chains can provide a more accurate solution, especially when the noise level is low,
# while the optimization of the rays intersection can be more robust to noise but may have a higher error in the estimated world points.
#
# The first method can avoid inverse transformations computation.
# The second is faster to compute but can be less accurate.
#

# Side-by-side recap of the four RMS errors computed above, without noise first
# and with noise second, each time chains optimization before rays intersection.
rms_summary = (
    ("RMS error in world points from chains optimization:", rms_error),
    ("RMS error in world points from rays intersection:", rms_error_rays),
    (
        "RMS error in world points from chains optimization with noise:",
        rms_error_noised,
    ),
    (
        "RMS error in world points from rays intersection with noise:",
        rms_error_rays_noised,
    ),
)
for summary_label, summary_value in rms_summary:
    print(summary_label, summary_value)
