Tags: computational-geometry, torch

How can I create a differentiable 2D projection of a 3D tensor in PyTorch?


I need to create a 2D projection of a 3D tensor onto a plane tangent to the unit sphere in 3D space, in such a way that the projection is differentiable.


Solution

    import math
    import warnings
    from typing import Literal

    import torch
    import torch.nn.functional as F
    
    def get_proj(volume, right_angle, left_angle, rotation_angle=0.0, distance_to_obj=4, surface_extent=3, N_samples_per_ray=200, H_out=128, W_out=128, grid_sample_mode: Literal['bilinear', 'nearest']='nearest', projection_aggregation: Literal['sum', 'first', 'max']='sum'):
        """
        Render a differentiable 2D perspective projection of a 3D volume.

        The volume is interpreted in the normalized coordinate system of
        ``torch.nn.functional.grid_sample``: it fills the cube [-1, 1]^3, with world x running
        along the W axis, y along the H axis and z along the D axis (``align_corners=False``).
        Sample points outside that cube read as zero.

        The camera sits at ``p = distance_to_obj * (cos(la)cos(ra), cos(la)sin(ra), sin(la))``
        (ra = right_angle, la = left_angle) and looks at the origin. The image plane is
        perpendicular to the viewing direction and centred at ``q = -p``, on the far side of
        the origin; it spans ``[-surface_extent, surface_extent]`` along both in-plane axes.
        One ray per pixel is cast from ``p`` to that pixel's point on the plane and sampled at
        ``N_samples_per_ray`` evenly spaced points (both endpoints included). Because every ray
        starts at the same point ``p`` this is a perspective (pinhole) projection, not an
        orthographic one.

        The result is differentiable with respect to ``volume`` (for 'max' and 'first' the
        gradient only reaches the selected sample of each ray). The camera parameters are
        evaluated with ``math`` and are therefore treated as constants.

        Args:
            volume (torch.Tensor): Floating-point tensor of shape (N, C, D, H, W). The sampling
                geometry is built with the same dtype and on the same device.
            right_angle (float): Azimuth (radians) of the camera around the z-axis, measured from +x.
            left_angle (float): Elevation (radians) of the camera above the xy-plane.
            rotation_angle (float, optional): Rotation (radians) of the image plane about its normal. Defaults to 0.0.
            distance_to_obj (float, optional): Distance from the camera to the origin; must be non-zero. Defaults to 4.
            surface_extent (float, optional): Half-width of the image plane in world units. Defaults to 3.
            N_samples_per_ray (int, optional): Number of samples along each ray. More samples give a
                more accurate projection at a higher cost. Defaults to 200.
            H_out (int, optional): Height (pixels) of the output. Defaults to 128.
            W_out (int, optional): Width (pixels) of the output. Defaults to 128.
            grid_sample_mode (str, optional): 'nearest' or 'bilinear' (trilinear for 5D input). Defaults to 'nearest'.
            projection_aggregation (str, optional): How the samples along a ray are combined:
                - 'sum': sum of all samples (default). Its scale grows with ``N_samples_per_ray``.
                - 'max': maximum sample.
                - 'first': first non-zero sample counted from the camera (0 if the ray hits nothing).

        Returns:
            torch.Tensor: Tensor of shape (N, C, H_out, W_out). Row ``i`` corresponds to the
            in-plane coordinate ``linspace(-surface_extent, surface_extent, H_out)[i]`` along the
            plane's "up" axis (so row 0 is the lower edge); columns run the same way along the
            plane's "right" axis.

        Raises:
            ValueError: If ``volume`` is not 5D, ``distance_to_obj`` is 0, or
                ``projection_aggregation`` is not one of 'sum', 'max', 'first'.

        Warns:
            UserWarning: For ``grid_sample_mode='bilinear'`` combined with ``projection_aggregation='first'``.

        Example:
            ```python
            import torch

            # A 32^3 volume containing a single plane in the middle
            volume = torch.zeros((1, 1, 32, 32, 32))
            volume[0, 0, 16, :, :] = 1

            projection = get_proj(volume, right_angle=0.5, left_angle=0.3, rotation_angle=0.785)
            print(projection.shape)  # torch.Size([1, 1, 128, 128])
            ```
        """
        if volume.dim() != 5:
            raise ValueError(f"volume must be a 5D tensor of shape (N, C, D, H, W), got shape {tuple(volume.shape)}")
        if projection_aggregation not in ('sum', 'max', 'first'):
            raise ValueError(f"Unknown projection_aggregation: {projection_aggregation}")
        if distance_to_obj == 0:
            raise ValueError("distance_to_obj must be non-zero")
        if grid_sample_mode == 'bilinear' and projection_aggregation == 'first':
            # Trilinear blending leaks small non-zero values in front of surfaces, which the
            # "first non-zero sample" rule then picks up.
            warnings.warn("grid_sample_mode='bilinear' with projection_aggregation='first' may break render. Better use 'nearest' render in this case.", stacklevel=2)

        # --- Camera basis: a few 3-vectors, computed once in float64 on the CPU ---
        ra, la = right_angle, left_angle
        view_dir = torch.tensor([
            math.cos(la) * math.cos(ra),
            math.cos(la) * math.sin(ra),
            math.sin(la),
        ], dtype=torch.float64)  # unit vector from the origin towards the camera
        p = view_dir * distance_to_obj  # camera position
        q = -p  # centre of the image plane, on the opposite side of the origin

        # World +z is the reference "up" so that images stay upright. It only fails when q is
        # exactly parallel to z (zero cross product); fall back to +y in that case only.
        ref = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float64)
        right_vec = torch.cross(ref, q, dim=0)
        if torch.norm(right_vec) == 0:
            ref = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float64)
            right_vec = torch.cross(ref, q, dim=0)
        right_vec = right_vec / torch.norm(right_vec)
        up_vec = torch.cross(q, right_vec, dim=0)
        up_vec = up_vec / torch.norm(up_vec)

        # Rotate the in-plane basis about the plane normal by rotation_angle.
        cos_theta = math.cos(rotation_angle)
        sin_theta = math.sin(rotation_angle)
        right_vec_rot = right_vec * cos_theta + up_vec * sin_theta
        up_vec_rot = -right_vec * sin_theta + up_vec * cos_theta

        # --- Sampling grid, built directly with the volume's dtype and device ---
        # F.grid_sample requires the grid to have the same dtype as its input.
        factory = {'device': volume.device, 'dtype': volume.dtype}
        p = p.to(**factory)
        q = q.to(**factory)
        right_vec_rot = right_vec_rot.to(**factory)
        up_vec_rot = up_vec_rot.to(**factory)

        u_vals = torch.linspace(-surface_extent, surface_extent, W_out, **factory)
        v_vals = torch.linspace(-surface_extent, surface_extent, H_out, **factory)
        grid_v, grid_u = torch.meshgrid(v_vals, u_vals, indexing='ij')
        # World coordinates of every pixel on the image plane: (H_out, W_out, 3).
        plane_points = q + grid_u.unsqueeze(-1) * right_vec_rot + grid_v.unsqueeze(-1) * up_vec_rot

        # Points along each ray p -> plane point, t = 0 at the camera: (N_samples, H_out, W_out, 3).
        t_vals = torch.linspace(0, 1, N_samples_per_ray, **factory).view(-1, 1, 1, 1)
        sample_grid = p + t_vals * (plane_points - p)
        # grid_sample needs one grid per batch item; expand() shares memory instead of copying.
        sample_grid = sample_grid.unsqueeze(0).expand(volume.shape[0], -1, -1, -1, -1)

        # (N, C, N_samples, H_out, W_out); out-of-volume samples are 0 (default padding_mode='zeros').
        proj_samples = F.grid_sample(volume, sample_grid, mode=grid_sample_mode, align_corners=False)

        if projection_aggregation == 'sum':
            return proj_samples.sum(dim=2)
        if projection_aggregation == 'max':
            return proj_samples.max(dim=2).values
        # 'first': argmax returns the first index of the maximum, i.e. the first non-zero sample
        # along the ray; rays that hit nothing get index 0, whose sample value is 0.
        first_indices = (proj_samples != 0).float().argmax(dim=2, keepdim=True)
        return torch.gather(proj_samples, 2, first_indices).squeeze(2)
    

    It can be used like this:

    import matplotlib.pyplot as plt

    # Voxel grid that defines the 3D object. Create it directly on the target device so it
    # stays a leaf tensor: calling .cuda() after requires_grad=True returns a non-leaf copy
    # whose .grad is never populated by backward().
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    volume = torch.zeros(1, 1, 32, 32, 32, device=device, requires_grad=True)
    
    def make_cube(volume):
        """Set every voxel on the 12 edges of the grid's bounding box to 1, in place.

        Only batch item 0 and channel 0 of ``volume`` are written.
        """
        for a, b in [(first, second) for first in (0, -1) for second in (0, -1)]:
            volume[0, 0, :, a, b] = 1  # the 4 edges running along D
            volume[0, 0, a, :, b] = 1  # the 4 edges running along H
            volume[0, 0, a, b, :] = 1  # the 4 edges running along W
    
    # In-place writes into a tensor that requires grad must not be tracked by autograd.
    with torch.no_grad():
        make_cube(volume)

    # Create a figure and axis
    fig, ax = plt.subplots()

    right_angle = 0.5
    left_angle = 0.2
    proj_image = get_proj(volume, right_angle, left_angle, surface_extent=4)

    # proj_image has shape (1, 1, H_out, W_out): take the single batch item and channel as a
    # 2D array. The transpose keeps the display orientation of the original example.
    proj_image = proj_image.detach().cpu()[0, 0].T

    # Display the new image
    ax.imshow(proj_image, cmap='gray')
    plt.show()
    

    [Image: rendered projection of the wireframe cube]