I need to create a 2D projection of a 3D tensor onto a plane tangent to a sphere surrounding the object, in such a way that the projection is differentiable with respect to the volume.
from typing import Literal
import torch
import torch.nn.functional as F
import math
def get_proj(volume, right_angle, left_angle, rotation_angle=0.0, distance_to_obj=4, surface_extent=3, N_samples_per_ray=200, H_out=128, W_out=128, grid_sample_mode: Literal['bilinear', 'nearest']='nearest', projection_aggregation: Literal['sum', 'first', 'max']='sum'):
"""
Generates a 2D projection of a 3D volume by casting rays from a specified camera position.
This function simulates a perspective-style projection of a 3D volume onto a 2D plane via ray casting. The camera is positioned on a sphere
centered at the origin, with its position determined by the provided right and left angles. Rays are cast from the camera
through points on a plane tangent to the sphere, and the volume is sampled along these rays to produce the projection.
The projection plane can be rotated around its normal axis using the rotation_angle parameter.
Args:
volume (torch.Tensor): A 5D tensor of shape (N, C, D, H, W) representing the 3D volume to be projected.
right_angle (float): The azimuthal angle (in radians) determining the camera's position around the z-axis.
left_angle (float): The polar angle (in radians) determining the camera's elevation from the xy-plane.
rotation_angle (float, optional): The angle (in radians) to rotate the projection plane around its normal axis. Defaults to 0.0.
distance_to_obj (float, optional): The distance from the camera to the origin. Defaults to 4.
surface_extent (float, optional): The half-extent of the tangent plane in world units. Defines the plane's size. Defaults to 3.
N_samples_per_ray (int, optional): The number of sample points along each ray. Higher values yield more accurate projections. Defaults to 200.
H_out (int, optional): The height (in pixels) of the output 2D projection. Defaults to 128.
W_out (int, optional): The width (in pixels) of the output 2D projection. Defaults to 128.
grid_sample_mode (str, optional): The interpolation mode for grid sampling. Defaults to 'nearest'.
projection_aggregation (str, optional): Method to aggregate samples along each ray. Options are:
- 'sum': Sum of all samples along the ray (default).
- 'max': Maximum value along the ray.
- 'first': First non-zero sample along the ray, i.e. the one closest to the camera.
Returns:
torch.Tensor: A 4D tensor of shape (1, C, H_out, W_out) representing the 2D projection of the input volume (the sampling grid is built for a batch size of 1).
Raises:
ValueError: If projection_aggregation is not one of 'sum', 'max' or 'first'.
RuntimeError: Propagated from F.grid_sample if the input volume is not a 5D tensor of shape (N, C, D, H, W).
Example:
```python
import torch
# Create a sample 3D volume
volume = torch.zeros((1, 1, 32, 32, 32))
volume[0, 0, 16, :, :] = 1 # Add a plane in the middle
# Define camera angles and rotation
right_angle = 0.5 # radians
left_angle = 0.3 # radians
rotation_angle = 0.785 # 45 degrees in radians
# Generate the projection with rotation
projection = get_proj(volume, right_angle, left_angle, rotation_angle=rotation_angle)
# Visualize the projection
import matplotlib.pyplot as plt
plt.imshow(projection.squeeze().cpu().numpy(), cmap='gray')
plt.show()
```
Note:
- F.grid_sample reads the volume in normalized coordinates: the volume occupies the cube [-1, 1]^3, so distance_to_obj and surface_extent are expressed in these normalized units, and samples falling outside the cube are zero.
- All rays originate at the camera point, so the projection behaves like a perspective (pinhole) projection rather than an orthographic one.
- Adjust `N_samples_per_ray` for a trade-off between performance and projection accuracy.
"""
if grid_sample_mode == 'bilinear' and projection_aggregation == 'first':
print("grid_sample_mode='bilinear' with projection_aggregation='first' can produce artifacts, because interpolation blurs the zero/non-zero boundary; 'nearest' is usually the better choice here.")
device = volume.device
ra = right_angle
la = left_angle
# Compute camera position p on the unit sphere.
p = torch.tensor([
math.cos(la) * math.cos(ra),
math.cos(la) * math.sin(ra),
math.sin(la)
]).to(device)
p *= distance_to_obj
# p is of shape (3,) and lies on a sphere of radius distance_to_obj.
# The camera is at position p and always looks to the origin.
# Define the opposite point on the sphere:
q = -p # This will be the point of tangency of the projection plane.
# Define an orthonormal basis for the projection plane tangent to the sphere at q.
ref = torch.tensor([0.0, 0.0, 1.0]).to(device)
# If q is (nearly) parallel to the z-axis, its cross product with ref would vanish; fall back to the y-axis.
if torch.allclose(torch.abs(q / torch.norm(q)), ref, atol=1e-3):
ref = torch.tensor([0.0, 1.0, 0.0]).to(device)
# Compute right as the normalized cross product of ref and q.
right_vec = torch.cross(ref, q, dim=0)
right_vec = right_vec / torch.norm(right_vec)
# Compute up as the cross product of q and right.
up_vec = torch.cross(q, right_vec, dim=0)
up_vec = up_vec / torch.norm(up_vec)
# Rotate the basis vectors around q by rotation_angle
cos_theta = math.cos(rotation_angle)
sin_theta = math.sin(rotation_angle)
right_vec_rot = right_vec * cos_theta + up_vec * sin_theta
up_vec_rot = -right_vec * sin_theta + up_vec * cos_theta
# Build the image plane grid.
extent = surface_extent
u_vals = torch.linspace(-extent, extent, W_out).to(device)
v_vals = torch.linspace(-extent, extent, H_out).to(device)
grid_v, grid_u = torch.meshgrid(v_vals, u_vals, indexing='ij')
# For each pixel (u,v) on the plane, compute its world coordinate using rotated basis vectors.
plane_points = q.unsqueeze(0).unsqueeze(0) + \
grid_u.unsqueeze(-1) * right_vec_rot + \
grid_v.unsqueeze(-1) * up_vec_rot
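# plane_points has shape (H_out, W_out, 3): the world coordinate of every pixel on the plane.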
# Sample along rays from camera through plane points
N_samples = N_samples_per_ray
t_vals = torch.linspace(0, 1, N_samples).to(device)
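# t = 0 is the camera position p, t = 1 is the pixel on the tangent plane.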
P_minus_p = plane_points - p.unsqueeze(0).unsqueeze(0)
sample_grid = p.unsqueeze(0).unsqueeze(0).unsqueeze(0) + \
t_vals.view(N_samples, 1, 1, 1) * P_minus_p.unsqueeze(0)
sample_grid = sample_grid.unsqueeze(0)
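# sample_grid has shape (1, N_samples, H_out, W_out, 3), the layout F.grid_sample expects for 5D input (last dim read as x, y, z).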
# Sample the volume along rays
proj_samples = F.grid_sample(volume, sample_grid, mode=grid_sample_mode, align_corners=False)
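# proj_samples has shape (1, C, N_samples, H_out, W_out); dim=2 indexes the samples along each ray.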
# Aggregate samples based on projection_aggregation
if projection_aggregation == 'sum':
proj_image = proj_samples.sum(dim=2)
elif projection_aggregation == 'max':
proj_image = proj_samples.max(dim=2)[0]
elif projection_aggregation == 'first':
mask = (proj_samples != 0).float()
first_indices = torch.argmax(mask, dim=2) # Index of first non-zero sample
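# For an all-zero ray argmax returns index 0, so the gathered value is simply 0.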
index = first_indices.unsqueeze(2)
proj_image = torch.gather(proj_samples, 2, index).squeeze(2)
else:
raise ValueError(f"Unknown projection_aggregation: {projection_aggregation}")
return proj_image
It can be used like this:
import matplotlib.pyplot as plt
# this volume defines the 3D object; create it directly on the GPU so it stays a leaf tensor and volume.grad gets populated after backward()
volume = torch.zeros(1, 1, 32, 32, 32, device='cuda', requires_grad=True)
def make_cube(volume):
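# mark the 12 edges of the volume's bounding cube with ones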
volume[0, 0, :, 0, 0] = 1
volume[0, 0, :, -1, 0] = 1
volume[0, 0, :, 0, -1] = 1
volume[0, 0, :, -1, -1] = 1
volume[0, 0, 0, :, 0] = 1
volume[0, 0, -1, :, 0] = 1
volume[0, 0, 0, :, -1] = 1
volume[0, 0, -1, :, -1] = 1
volume[0, 0, 0, -1, :] = 1
volume[0, 0, 0, 0, :] = 1
volume[0, 0, -1, 0, :] = 1
volume[0, 0, -1, -1, :] = 1
with torch.no_grad():
make_cube(volume)
# Create a figure and axis
fig, ax = plt.subplots()
right_angle = 0.5
left_angle = 0.2
proj_image = get_proj(volume, right_angle, left_angle, surface_extent=4)
# drop the batch and channel dims (imshow needs a 2D array) and swap axes for display
proj_image = proj_image.detach().cpu()[0, 0].transpose(0, 1)
# Display the projection
ax.imshow(proj_image, cmap='gray')
plt.show()
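Because the sampling grid is built from ordinary tensor arithmetic and the volume is only read through F.grid_sample, gradients flow from the rendered image back to the voxel values. Below is a minimal sketch of a backward pass, reusing make_cube from above; the all-ones target image is a made-up placeholder just to give the loss something to compare against:
```python
import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Leaf tensor on the chosen device, so volume.grad is populated by backward().
volume = torch.zeros(1, 1, 32, 32, 32, device=device, requires_grad=True)
with torch.no_grad():
    make_cube(volume)

# Hypothetical target: an all-ones 128x128 image, used only to exercise the backward pass.
target = torch.ones(1, 1, 128, 128, device=device)

proj = get_proj(volume, right_angle=0.5, left_angle=0.2, surface_extent=4,
                grid_sample_mode='bilinear', projection_aggregation='sum')
loss = F.mse_loss(proj, target)
loss.backward()

print(volume.grad.shape)               # torch.Size([1, 1, 32, 32, 32])
print(volume.grad.abs().sum().item())  # non-zero, so gradients reach the volume
```
With grid_sample_mode='nearest' the gradient with respect to the sampling coordinates is effectively zero, but the gradient with respect to the voxel values still flows, so either mode works when only the volume is optimized; 'bilinear' simply gives smoother gradients.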