computational-geometry, pytorch

Create differentiable 2d projection of 3d tensor in pytorch?


I need to create a 2D projection of a 3D tensor onto the plane tangent to the unit sphere in 3D space, in such a way that the projection is differentiable.


Solution

    import torch
    import torch.nn.functional as F
    import math

    def get_proj(volume, right_angle, left_angle, distance_to_obj=4, surface_extent=3, N_samples_per_ray=200, H_out=128, W_out=128, grid_sample_mode='bilinear'):
        """
        Generates a 2D projection of a 3D volume by casting rays from a specified camera position.
    
        This function simulates a perspective projection of a 3D volume onto a 2D plane. The camera is positioned on a sphere
        of radius `distance_to_obj` centered at the origin, with its position determined by the provided right and left angles.
        Rays are cast from the camera through points on a plane tangent to that sphere on the opposite side of the volume, and
        the volume is sampled along these rays to produce the projection.
    
        Args:
            volume (torch.Tensor): A 5D tensor of shape (N, C, D, H, W) representing the 3D volume to be projected.
            right_angle (float): The azimuthal angle (in radians) determining the camera's position around the z-axis.
            left_angle (float): The polar angle (in radians) determining the camera's elevation from the xy-plane.
            distance_to_obj (float, optional): The distance from the camera to the origin. Defaults to 4.
            surface_extent (float, optional): The half-extent of the tangent plane in world units. Defines the plane's size. Defaults to 3.
            N_samples_per_ray (int, optional): The number of sample points along each ray. Higher values yield more accurate projections. Defaults to 200.
            H_out (int, optional): The height (in pixels) of the output 2D projection. Defaults to 128.
            W_out (int, optional): The width (in pixels) of the output 2D projection. Defaults to 128.
            grid_sample_mode (str, optional): Interpolation mode passed to F.grid_sample ('bilinear' or 'nearest'). Defaults to 'bilinear'.
    
        Returns:
            torch.Tensor: A 4D tensor of shape (1, 1, H_out, W_out) representing the 2D projection of the input volume.
    
        Raises:
            RuntimeError: If the input volume is not a 5D tensor of shape (N, C, D, H, W) (raised by F.grid_sample).
    
        Example:
            ```python
            import torch
    
            # Create a sample 3D volume
            volume = torch.zeros((1, 1, 32, 32, 32))
            volume[0, 0, 16, :, :] = 1  # Add a plane in the middle
    
            # Define camera angles
            right_angle = 0.5  # radians
            left_angle = 0.3   # radians
    
            # Generate the projection
            projection = get_proj(volume, right_angle, left_angle)
    
            # Visualize the projection
            import matplotlib.pyplot as plt
            plt.imshow(projection.squeeze().cpu().numpy(), cmap='gray')
            plt.show()
            ```
    
        Note:
            - The volume is assumed to occupy the cube [-1, 1]^3 in grid_sample's normalized coordinates; sample points outside this cube contribute zero (the default padding mode).
            - The function uses a perspective (pinhole) projection model: all rays originate at the camera position.
            - Adjust `N_samples_per_ray` for a trade-off between performance and projection accuracy.
        """
        device = volume.device
        
        ra = right_angle
        la = left_angle
        
        # Compute the camera direction as a unit vector, then scale it to the camera distance.
        p = torch.tensor([
            math.cos(la) * math.cos(ra),
            math.cos(la) * math.sin(ra),
            math.sin(la)
        ]).to(device)
        p = p * distance_to_obj
        # p is of shape (3,) and lies on a sphere of radius distance_to_obj.
    
        # The camera is at position p and always looks to the origin.
        # Define the opposite point on the sphere:
        q = -p  # This will be the point of tangency of the projection plane.
    
        # -------------------------------------------------------------------
        # 3. Define an orthonormal basis for the projection plane tangent to the camera sphere at q.
        # We need two vectors (right, up) lying in the plane.
        # One way is to choose a reference vector not collinear with q.
        # -------------------------------------------------------------------
        ref = torch.tensor([0.0, 0.0, 1.0]).to(device)
        if torch.abs(q[2]) > (1.0 - 1e-3) * torch.norm(q):
            # q is (nearly) parallel to the z-axis, so use a different reference vector.
            ref = torch.tensor([0.0, 1.0, 0.0]).to(device)
    
        # Compute right as the normalized cross product of ref and q.
        right_vec = torch.cross(ref, q,dim=0)
        right_vec = right_vec / torch.norm(right_vec)
    
        # Compute up as the cross product of q and right.
        up_vec = torch.cross(q, right_vec, dim=0)
        up_vec = up_vec / torch.norm(up_vec)
    
        # -------------------------------------------------------------------
        # 4. Build the image plane grid.
        #
        # We want to form an image on the plane tangent to the sphere at q.
        # The plane is defined by the equation: q · x = |q|^2.
        #
        # A convenient parameterization is:
        #
        #    For (u, v) in some range, the 3D point on the plane is:
        #       P(u,v) = q + u * right_vec + v * up_vec.
        #
        # Note: q is perpendicular to both right_vec and up_vec,
        # so q · P(u,v) = q · q = |q|^2 automatically.
        #
        # Choose an output image resolution and an extent for u and v.
        # -------------------------------------------------------------------
        # Choose an extent large enough that the cast rays cover the whole volume.
        # (The volume occupies [-1, 1]^3 in normalized coordinates; sample points outside
        # that cube are zero-padded by grid_sample and only add empty border.)
        extent = surface_extent  # you may adjust this value
    
        u_vals = torch.linspace(-extent, extent, W_out).to(device)
        v_vals = torch.linspace(-extent, extent, H_out).to(device)
        grid_v, grid_u = torch.meshgrid(v_vals, u_vals, indexing='ij')  # shapes: (H_out, W_out)
    
        # For each pixel (u,v) on the plane, compute its world coordinate.
        # P = q + u * right_vec + v * up_vec.
        plane_points = q.unsqueeze(0).unsqueeze(0) + \
                    grid_u.unsqueeze(-1) * right_vec + \
                    grid_v.unsqueeze(-1) * up_vec
        # plane_points shape: (H_out, W_out, 3)
    
        # -------------------------------------------------------------------
        # 5. For each pixel, sample along the ray from the camera p through the point P.
        #
        # Since the camera is at p and the ray passing through a pixel is along the line from p to P,
        # the ray can be parameterized as:
        #
        #    r(t) = p + t*(P - p),   for t in [0, 1]
        #
        # t=0 gives the camera position, t=1 gives the intersection with the image plane (P).
        # -------------------------------------------------------------------
        N_samples = N_samples_per_ray
        t_vals = torch.linspace(0, 1, N_samples).to(device)  # shape: (N_samples,)
    
        # Expand plane_points to sample along t:
        # plane_points has shape (H_out, W_out, 3). We want to combine it with p.
        # Compute (P - p): note that p is a vector; we can reshape it appropriately.
        P_minus_p = plane_points - p.unsqueeze(0).unsqueeze(0)  # shape: (H_out, W_out, 3)
    
        # Now, for each t, compute the sample point:
        # sample_point(t, u, v) = p + t*(P(u,v) - p)
        # We can do:
        sample_grid = p.unsqueeze(0).unsqueeze(0).unsqueeze(0) + \
                    t_vals.view(N_samples, 1, 1, 1) * P_minus_p.unsqueeze(0)
        # sample_grid now has shape: (N_samples, H_out, W_out, 3).
    
        # Add a batch dimension (batch size 1) so that grid_sample sees a grid of shape:
        # (1, N_samples, H_out, W_out, 3)
        sample_grid = sample_grid.unsqueeze(0)
    
        # IMPORTANT: grid_sample expects grid coordinates in the normalized coordinate system
        # of the input volume, i.e. in [-1, 1], with the last dimension ordered (x, y, z)
        # indexing the (W, H, D) axes. Our volume is defined on [-1, 1]^3; sample points that
        # fall outside this cube are zero-padded and contribute nothing to the projection.
    
        # -------------------------------------------------------------------
        # 6. Use grid_sample to sample the volume along each ray and integrate.
        # -------------------------------------------------------------------
        # grid_sample expects input volume of shape [N, C, D, H, W] and grid of shape [N, D_out, H_out, W_out, 3].
        proj_samples = F.grid_sample(volume, sample_grid, mode=grid_sample_mode, align_corners=False)
        # proj_samples has shape: (1, 1, N_samples, H_out, W_out)
    
        # For a simple projection (like an X-ray), integrate along the ray.
        # Here we simply sum along the sample (ray) dimension.
        proj_image = proj_samples.sum(dim=2)  # shape: (1, 1, H_out, W_out)
        return proj_image
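
    The sampling grid passed to F.grid_sample above uses the volume's normalized coordinate system: every coordinate lies in [-1, 1], and the last dimension of the grid is ordered (x, y, z), indexing the (W, H, D) axes of the input. A minimal standalone sanity check of that convention (independent of get_proj, using a tiny hand-made volume):

    import torch
    import torch.nn.functional as F

    # A 1x1x2x2x2 volume with a single marked voxel at (d, h, w) = (1, 0, 1).
    vol = torch.zeros(1, 1, 2, 2, 2)
    vol[0, 0, 1, 0, 1] = 1.0

    # Grid shape is (N, D_out, H_out, W_out, 3); the last dim is (x, y, z) -> (W, H, D).
    # With align_corners=False, the voxel centers of a size-2 axis sit at -0.5 and +0.5.
    grid = torch.tensor([[[[[0.5, -0.5, 0.5]]]]])  # one sample point

    print(F.grid_sample(vol, grid, mode='bilinear', align_corners=False))
    # tensor([[[[[1.]]]]]) -- the marked voxel is recovered exactly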
    

    It can be used like this:

    import matplotlib.pyplot as plt

    # This is the volume that defines the 3D object. Create it directly on the GPU so it
    # stays a leaf tensor and receives gradients; calling .cuda() on a tensor that already
    # has requires_grad=True would return a non-leaf copy whose .grad is not populated.
    volume = torch.zeros(1, 1, 32, 32, 32, device='cuda', requires_grad=True)
    
    def make_cube(volume):
        volume[0, 0, :, 0, 0] = 1
        volume[0, 0, :, -1, 0] = 1
        volume[0, 0, :, 0, -1] = 1
        volume[0, 0, :, -1, -1] = 1
    
        volume[0, 0, 0, :, 0] = 1
        volume[0, 0, -1, :, 0] = 1
        volume[0, 0, 0, :, -1] = 1
        volume[0, 0, -1, :, -1] = 1
    
        volume[0, 0, 0, -1, :] = 1
        volume[0, 0, 0, 0, :] = 1
        volume[0, 0, -1, 0, :] = 1
        volume[0, 0, -1, -1, :] = 1
    
    with torch.no_grad():
        make_cube(volume)
    
    # Create a figure and axis
    fig, ax = plt.subplots()
    
    right_angle = 0.5
    left_angle = 0.2
    proj_image = get_proj(volume, right_angle, left_angle, surface_extent=4)

    # Drop the batch and channel dimensions and move to CPU for plotting.
    proj_image = proj_image.detach().cpu()[0, 0].transpose(0, 1)
    
    
    # Display the projection
    plt.imshow(proj_image, cmap='gray')
    plt.show()
    

    (Image: projection of the cube.)
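
    Because every operation in the pipeline is differentiable (grid_sample and a sum along the ray dimension), gradients flow from the 2D projection back to the 3D volume. A minimal sketch of a backward pass through get_proj; the random target image here is only a hypothetical stand-in for whatever reference projection you actually want to match:

    # Hypothetical target image, used only to build a scalar loss for illustration.
    target = torch.rand(1, 1, 128, 128, device=volume.device)

    proj = get_proj(volume, right_angle, left_angle, surface_extent=4)
    loss = ((proj - target) ** 2).mean()
    loss.backward()

    # Gradients reach the 3D volume, confirming the projection is differentiable.
    print(volume.grad.shape)  # torch.Size([1, 1, 32, 32, 32])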