By definition, area interpolation weights each input pixel by the area it contributes to the output pixel.
So I suppose that if the scale factor is 1.5, then output pixel (0, 0) contains all of input pixel (0, 0), half of (0, 1) and (1, 0), and a quarter of (1, 1), with each weight being in_pixel_area / 1.5^2.
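For example (my own sketch of the weighting I have in mind, not PyTorch code):
import torch

x = torch.arange(9, dtype=torch.float).reshape(3, 3)  # any 3x3 input
scale = 1.5
# output pixel (0, 0) would cover input rows/cols [0, 1.5):
# pixel (0, 0) fully, (0, 1) and (1, 0) half, (1, 1) a quarter
out_00 = (x[0, 0] + 0.5*x[0, 1] + 0.5*x[1, 0] + 0.25*x[1, 1]) / scale**2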
However, this does not seem to be the case:
import torch
import torch.nn.functional as F

x = torch.tensor(
    [[3, 106, 107, 40, 148, 112, 254, 151],
     [62, 173, 91, 93, 33, 111, 139, 25],
     [99, 137, 80, 231, 101, 204, 74, 219],
     [240, 173, 85, 14, 40, 230, 160, 152],
     [230, 200, 177, 149, 173, 239, 103, 74],
     [19, 50, 209, 82, 241, 103, 3, 87],
     [252, 191, 55, 154, 171, 107, 6, 123],
     [7, 101, 168, 85, 115, 103, 32, 11]],
    dtype=torch.float).unsqueeze(0).unsqueeze(0)
print(x.shape, x.sum())

for scale in [8/6, 2]:
    pixel_area = scale**2
    y = F.interpolate(x, scale_factor=1/scale, mode="area")
    print(y.shape, y, y.sum()*pixel_area)
    # expected top-left output pixel from the area weighting described above
    print((3 + 106*(scale-1) + 62*(scale-1) + 173*(scale-1)**2)/pixel_area)
    # expected bottom-right output pixel from the area weighting described above
    print((11 + 123*(scale-1) + 32*(scale-1) + 6*(scale-1)**2)/pixel_area)
the output is:
torch.Size([1, 1, 8, 8]) tensor(7707.)
torch.Size([1, 1, 6, 6]) tensor([[[[ 86.0000, 119.2500, 82.7500, 101.0000, 154.0000, 142.2500],
[117.7500, 120.2500, 123.7500, 112.2500, 132.0000, 114.2500],
[162.2500, 118.7500, 102.5000, 143.7500, 167.0000, 151.2500],
[124.7500, 159.0000, 154.2500, 189.0000, 112.0000, 66.7500],
[128.0000, 126.2500, 125.0000, 155.5000, 54.7500, 54.7500],
[137.7500, 128.7500, 115.5000, 124.0000, 62.0000, 43.0000]]]]) tensor(7665.7778)
43.99999999999999
35.62499999999999
torch.Size([1, 1, 4, 4]) tensor([[[[ 86.0000, 82.7500, 101.0000, 142.2500],
[162.2500, 102.5000, 143.7500, 151.2500],
[124.7500, 154.2500, 189.0000, 66.7500],
[137.7500, 115.5000, 124.0000, 43.0000]]]]) tensor(7707.)
86.0
43.0
We can see that when scale = 2 it works fine, but when scale = 8/6 the result is strange:
1. y.sum()*pixel_area does not equal x.sum().
2. When I directly calculate the (0, 0) output value from the weights, I get 44 instead of 86.
3. I would expect the (0, 0) output pixel to differ between the two scales, but it is 86 in both cases. Why?
Update: on closer look, it seems that when scale = 8/6 it simply takes a 2x2 average with stride 1. But isn't this against the definition of area interpolation?
When using mode="area", PyTorch computes the output with an adaptive average pooling operation. You can find the relevant code here:
...
if input.dim() == 4 and mode == "area":
assert output_size is not None
return adaptive_avg_pool2d(input, output_size)
...
You can verify this via:
x = ...
scale_factor = 0.75
pool1 = F.interpolate(x, scale_factor=scale_factor, mode="area")
output_size = [int(i*scale_factor) for i in x.shape[-2:]]
pool2 = F.adaptive_avg_pool2d(x, output_size)
pool1 == pool2
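Both results should match element for element, since interpolate with mode="area" dispatches directly to adaptive_avg_pool2d, as shown in the source snippet above.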
Adaptive average pooling breaks the input into roughly equal-sized chunks and computes a simple average of each chunk. There is no weighting of pixels like you describe. You can see the code for adaptive pooling here and the indexing code here.
It may help to look at the indexing code:
inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
return (a / b) * c + ((a % b) * c) / b;
}
inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
return 1 + ((a + 1) * c - 1) / b;
}
Here, a is the output position in a dimension of size b, mapping to an input of size c.
Take your example, where scale_factor = 1/(8/6) = 0.75. The input has size (..., 8, 8), so the output will be of size (..., 6, 6) (int(0.75*8) = 6).
You can use the following to compute the value of a specific output element:
def start_index(a, b, c):
return (a * c) // b
def end_index(a, b, c):
return ((a + 1) * c + b - 1) // b
input_height = 8
input_width = 8
output_height = 6
output_width = 6
out_row = 0
out_col = 1
h0 = start_index(out_row, output_height, input_height)
h1 = end_index(out_row, output_height, input_height)
w0 = start_index(out_col, output_width, input_width)
w1 = end_index(out_col, output_width, input_width)
kh = h1-h0
kw = w1-w0
x[:, :, h0:h1, w0:w1].sum() / (kh*kw)
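For out_row = 0, out_col = 1 this averages x[:, :, 0:2, 1:3], i.e. (106 + 107 + 173 + 91) / 4 = 119.25, which matches the second element of the first row of the 6x6 output above.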
Also note that for adaptive pooling, stride is not constant. For example:
for a in range(6):
start = start_index(a, 6, 8)
end = end_index(a, 6, 8)
print(f"Position {a}: {start} -> {end}")
Position 0: 0 -> 2
Position 1: 1 -> 3
Position 2: 2 -> 4
Position 3: 4 -> 6 # note the jump here from 2 to 4
Position 4: 5 -> 7
Position 5: 6 -> 8
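This also answers why the (0, 0) output is 86 at both scales: the first chunk maps to the same 2x2 block of the input in both cases. A quick check, reusing the start_index/end_index helpers and the tensor x from above:
for out_size in (6, 4):
    h0 = start_index(0, out_size, 8)
    h1 = end_index(0, out_size, 8)
    print(out_size, (h0, h1), x[0, 0, h0:h1, h0:h1].mean().item())
# both print chunk (0, 2) and mean 86.0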