I need to create a 2D projection of a 3D tensor onto a plane tangent to the unit sphere in 3D space, in such a way that the projection is differentiable.
import torch
import torch.nn.functional as F
import math
def _camera_direction(right_angle, left_angle, device):
    """Return the unit vector (shape (3,)) pointing from the origin to the camera."""
    if isinstance(right_angle, torch.Tensor) or isinstance(left_angle, torch.Tensor):
        # Tensor angles are kept on the autograd graph so the projection can be
        # differentiated with respect to the camera pose.
        dtype = torch.get_default_dtype()
        ra = torch.as_tensor(right_angle, dtype=dtype, device=device).reshape(())
        la = torch.as_tensor(left_angle, dtype=dtype, device=device).reshape(())
        cos_la = torch.cos(la)
        return torch.stack([cos_la * torch.cos(ra), cos_la * torch.sin(ra), torch.sin(la)])
    return torch.tensor(
        [
            math.cos(left_angle) * math.cos(right_angle),
            math.cos(left_angle) * math.sin(right_angle),
            math.sin(left_angle),
        ],
        device=device,
    )


def _plane_basis(q):
    """Return orthonormal (right, up) vectors spanning the plane perpendicular to q."""
    collinear_tol = 1e-6  # relative size of |ref x q| below which ref is treated as parallel to q
    ref = q.new_tensor([0.0, 0.0, 1.0])
    right_vec = torch.cross(ref, q, dim=0)
    if float(torch.norm(right_vec)) <= collinear_tol * float(torch.norm(q)):
        # q is (numerically) parallel to the z-axis, so the cross product above
        # carries no usable direction; fall back to the y-axis as reference.
        ref = q.new_tensor([0.0, 1.0, 0.0])
        right_vec = torch.cross(ref, q, dim=0)
    right_vec = right_vec / torch.norm(right_vec)
    up_vec = torch.cross(q, right_vec, dim=0)
    up_vec = up_vec / torch.norm(up_vec)
    return right_vec, up_vec


def get_proj(volume: torch.Tensor, right_angle, left_angle, distance_to_obj=4, surface_extent=3,
             N_samples_per_ray: int = 200, H_out: int = 128, W_out: int = 128,
             grid_sample_mode: str = 'bilinear') -> torch.Tensor:
    """
    Render a differentiable 2D projection of a 3D volume by casting rays from a camera.

    The camera sits at ``p = distance_to_obj * (cos(la)cos(ra), cos(la)sin(ra), sin(la))``
    and looks at the origin. The image plane passes through the antipodal point
    ``q = -p`` and is perpendicular to ``q`` (i.e. tangent to the sphere of radius
    ``|distance_to_obj|`` at ``q``). One ray is cast from ``p`` through every pixel of
    that plane, so rays diverge from the camera: this is a perspective (pinhole)
    projection, not an orthographic one. The volume is sampled with
    ``torch.nn.functional.grid_sample`` at evenly spaced points along each ray and
    the samples are summed.

    The volume is treated as occupying the cube ``[-1, 1]^3`` in world coordinates.
    Following the ``grid_sample`` convention, world ``x`` indexes the last volume
    dimension (W), ``y`` indexes H and ``z`` indexes D. Sample points outside the
    cube contribute zero (``grid_sample``'s default zero padding).

    Args:
        volume (torch.Tensor): 5D tensor of shape (N, C, D, H, W).
        right_angle (float or scalar torch.Tensor): Azimuthal angle in radians
            around the z-axis.
        left_angle (float or scalar torch.Tensor): Elevation angle in radians
            measured from the xy-plane.
        distance_to_obj (float, optional): Distance from the camera to the origin.
            Must be non-zero. Defaults to 4.
        surface_extent (float, optional): Half-extent of the image plane in world
            units; pixels span [-surface_extent, surface_extent] along both plane
            axes. Defaults to 3.
        N_samples_per_ray (int, optional): Number of sample points along each ray,
            from the camera (t=0) to the image plane (t=1) inclusive. Defaults to 200.
        H_out (int, optional): Height in pixels of the output. Defaults to 128.
        W_out (int, optional): Width in pixels of the output. Defaults to 128.
        grid_sample_mode (str, optional): Interpolation mode forwarded to
            ``grid_sample``. Defaults to 'bilinear'.

    Returns:
        torch.Tensor: Tensor of shape (N, C, H_out, W_out). Each pixel is the plain
        sum of the samples along its ray (not scaled by the step length, so the
        magnitude grows with ``N_samples_per_ray``).

    Raises:
        ValueError: If ``volume`` is not a 5D tensor or ``distance_to_obj`` is zero.

    Example:
        ```python
        import torch
        # Create a sample 3D volume
        volume = torch.zeros((1, 1, 32, 32, 32))
        volume[0, 0, 16, :, :] = 1  # Add a plane in the middle
        # Define camera angles
        right_angle = 0.5  # radians
        left_angle = 0.3  # radians
        # Generate the projection
        projection = get_proj(volume, right_angle, left_angle)
        # Visualize the projection
        import matplotlib.pyplot as plt
        plt.imshow(projection.squeeze().cpu().numpy(), cmap='gray')
        plt.show()
        ```

    Note:
        - Gradients flow to ``volume`` and, when the angles are passed as tensors,
          to the angles as well.
        - Near the poles (camera direction parallel to the z-axis) the in-plane
          axes switch to a y-axis reference, so the image orientation changes there.
        - Adjust ``N_samples_per_ray`` to trade performance for accuracy.
    """
    if not isinstance(volume, torch.Tensor) or volume.dim() != 5:
        raise ValueError("volume must be a 5D tensor of shape (N, C, D, H, W)")
    if float(distance_to_obj) == 0:
        raise ValueError("distance_to_obj must be non-zero")

    device = volume.device

    # Camera position p and the point q where the image plane touches the sphere.
    p = _camera_direction(right_angle, left_angle, device) * distance_to_obj
    q = -p
    right_vec, up_vec = _plane_basis(q)

    # Pixel centres on the image plane: P(u, v) = q + u * right + v * up.
    u_vals = torch.linspace(-surface_extent, surface_extent, W_out, device=device)
    v_vals = torch.linspace(-surface_extent, surface_extent, H_out, device=device)
    grid_v, grid_u = torch.meshgrid(v_vals, u_vals, indexing='ij')  # (H_out, W_out) each
    plane_points = q + grid_u.unsqueeze(-1) * right_vec + grid_v.unsqueeze(-1) * up_vec

    # Points along each ray: r(t) = p + t * (P - p), with t=0 at the camera and
    # t=1 on the image plane. Result shape: (N_samples, H_out, W_out, 3).
    t_vals = torch.linspace(0, 1, N_samples_per_ray, device=device)
    sample_grid = p + t_vals.view(-1, 1, 1, 1) * (plane_points - p).unsqueeze(0)

    # grid_sample needs one grid per batch element with the input's dtype; the
    # same rays are shared by every volume in the batch.
    sample_grid = sample_grid.unsqueeze(0).expand(volume.shape[0], -1, -1, -1, -1)
    if volume.is_floating_point():
        sample_grid = sample_grid.to(volume.dtype)

    # (N, C, N_samples, H_out, W_out) -> integrate (sum) along the ray dimension.
    proj_samples = F.grid_sample(volume, sample_grid, mode=grid_sample_mode, align_corners=False)
    return proj_samples.sum(dim=2)
It can be used like this:
import matplotlib.pyplot as plt
# Voxel grid that defines the 3D object, shape (N, C, D, H, W).
# It is created directly on the target device so that it stays a leaf tensor
# (calling .cuda() on a requires_grad tensor returns a non-leaf copy whose
# .grad is never populated), and it falls back to the CPU when no GPU exists.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
volume = torch.zeros(1, 1, 32, 32, 32, device=device, requires_grad=True)
def make_cube(volume):
    """Draw a wireframe cube in place: set the 12 boundary edges of volume[0, 0] to 1."""
    corner_indices = (0, -1)
    for first in corner_indices:
        for second in corner_indices:
            # For every pair of extreme indices, light up the full line that
            # runs along the remaining axis.
            volume[0, 0, :, first, second] = 1
            volume[0, 0, first, :, second] = 1
            volume[0, 0, first, second, :] = 1
# Draw the cube edges without recording the in-place writes in autograd.
with torch.no_grad():
    make_cube(volume)

# Camera angles in radians.
right_angle = 0.5
left_angle = 0.2
proj_image = get_proj(volume, right_angle, left_angle, surface_extent=4)

# Reduce the (1, 1, H, W) result to a plain 2D array for plotting. The
# transpose keeps the orientation the image was previously displayed in.
proj_image = proj_image.detach().cpu()[0, 0].T.numpy()

# Display the image on an explicit axis.
fig, ax = plt.subplots()
ax.imshow(proj_image, cmap='gray')
plt.show()