Is there some way to use flat indexing for only the remaining dimensions in NumPy? I'm trying to translate the following MATLAB function to Python:
function [indices, weights] = locate(values, gridpoints)
    indices = ones(size(values));
    weights = zeros([2, size(values)]);
    for ix = 1:numel(values)
        if values(ix) <= gridpoints(1)
            indices(ix) = 1;
            weights(:, ix) = [1; 0];
        elseif values(ix) >= gridpoints(end)
            indices(ix) = length(gridpoints) - 1;
            weights(:, ix) = [0; 1];
        else
            indices(ix) = find(gridpoints <= values(ix), 1, 'last');
            weights(:, ix) = ...
                [gridpoints(indices(ix) + 1) - values(ix); ...
                 values(ix) - gridpoints(indices(ix))] ...
                / (gridpoints(indices(ix) + 1) - gridpoints(indices(ix)));
        end
    end
end
but I can't work out what the NumPy equivalent of MATLAB's weights(:, ix) would be, that is, linear indexing over only the remaining dimensions.
I was hoping that the syntax could be translated directly. Suppose that values is a 3-by-4 array; then weights becomes a 2-by-3-by-4 array. In MATLAB, weights(:, ix) is then a 2-by-1 array. In NumPy, weights[:, ix] indexes only the second axis, so it gives a 2-by-4 array and raises an IndexError once ix exceeds 2.
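Below is a minimal sketch of one way to reproduce weights(:, ix), assuming values is a NumPy array. np.unravel_index converts the flat index into a tuple of subscripts for values.shape. Prepending a full slice to that tuple then selects the whole first axis.

NumPy's flat order is row-major (C order), whereas MATLAB's linear indexing is column-major. This is harmless here as long as values.flat and weights are addressed with the same index. The example array and the stored numbers are placeholders, not the real weight computation.

```python
import numpy as np

values = np.arange(12.0).reshape(3, 4)   # placeholder 3-by-4 input
weights = np.zeros((2,) + values.shape)  # 2-by-3-by-4, as in the question

for ix in range(values.size):
    # Flat (row-major) index ix -> subscripts (i, j) into values
    sub = np.unravel_index(ix, values.shape)
    # Counterpart of MATLAB's weights(:, ix): all of axis 0, then (i, j)
    weights[(slice(None),) + sub] = [values.flat[ix], -values.flat[ix]]

# Flat index 6 is row 1, column 2, so weights[:, 1, 2] holds 6.0 and -6.0
print(weights[:, 1, 2])
```

The same indexing expression works unchanged for values with any number of dimensions, because np.unravel_index returns one subscript per axis of values.shape.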
I think that I have handled everything else in the function below.
import numpy as np

def locate(values, gridpoints):
    indices = np.zeros(np.shape(values), dtype=int)
    weights = np.zeros((2,) + np.shape(values))
    for ix in range(values.size):
        if values.flat[ix] <= gridpoints[0]:
            indices.flat[ix] = 0
            # weights[:, ix] = [1, 0]
        elif values.flat[ix] >= gridpoints[-1]:
            indices.flat[ix] = gridpoints.size - 2
            # weights[:, ix] = [0, 1]
        else:
            indices.flat[ix] = (
                np.argwhere(gridpoints <= values.flat[ix]).flatten()[-1]
            )
            # weights[:, ix] = (
            #     np.array([gridpoints[indices.flat[ix] + 1] - values.flat[ix],
            #               values.flat[ix] - gridpoints[indices.flat[ix]]])
            #     / (gridpoints[indices.flat[ix] + 1] - gridpoints[indices.flat[ix]])
            # )
    return indices, weights
Do you have any suggestions? Perhaps I'm thinking about the problem all wrong. I have also tried to write the code as simply as possible, as I intend to use Numba to speed it up later.
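The MATLAB loop is essentially a bracketing search plus linear interpolation, so one alternative worth considering is to drop the loop entirely. The sketch below defines locate_vectorized, a hypothetical name not taken from any library. It uses np.searchsorted with side="right" to find the last gridpoint not exceeding each value, and np.clip to reproduce the two edge cases.

It assumes gridpoints is 1-D, strictly increasing and has at least two entries. Like the Python translation above, it returns 0-based indices.

```python
import numpy as np

def locate_vectorized(values, gridpoints):
    """Hypothetical vectorized variant of locate.

    Assumes gridpoints is 1-D, strictly increasing, with at least two points.
    """
    values = np.asarray(values, dtype=float)
    gridpoints = np.asarray(gridpoints, dtype=float)
    # Last index i with gridpoints[i] <= value, clamped to the valid range [0, n-2]
    indices = np.searchsorted(gridpoints, values, side="right") - 1
    indices = np.clip(indices, 0, gridpoints.size - 2)
    # Clamping the values reproduces the [1, 0] and [0, 1] edge cases
    clamped = np.clip(values, gridpoints[0], gridpoints[-1])
    lower = gridpoints[indices]
    upper = gridpoints[indices + 1]
    width = upper - lower
    # Stack along a new leading axis: weights has shape (2,) + values.shape
    weights = np.stack([(upper - clamped) / width, (clamped - lower) / width])
    return indices, weights

# Example: one value below the grid, two inside it, one above it
g = np.array([0.0, 1.0, 3.0])
v = np.array([[-1.0, 0.5], [2.0, 4.0]])
idx, w = locate_vectorized(v, g)
print(idx)         # [[0 0]
                   #  [1 1]]
print(w[:, 1, 0])  # weights for v = 2.0: [0.5 0.5]
```

Because every step is a whole-array NumPy operation, this may remove the need for the Numba step. It is still worth checking against the loop version on real data.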
As hpaulj pointed out in a comment, there doesn't seem to be a direct NumPy equivalent. Lacking one, the best I can think of is to build weights as a 2-by-N array and reshape it at the end, as in the code below, following the suggestion in NumPy for Matlab Users.
import numpy as np

def locate(values, gridpoints):
    indices = np.zeros(values.shape, dtype=int)
    weights = np.zeros((2, values.size))  # Temporarily make weights 2-by-N
    for ix in range(values.size):
        if values.flat[ix] <= gridpoints[0]:
            indices.flat[ix] = 0
            weights[:, ix] = [1, 0]
        elif values.flat[ix] >= gridpoints[-1]:
            indices.flat[ix] = gridpoints.size - 2
            weights[:, ix] = [0, 1]
        else:
            indices.flat[ix] = (
                np.argwhere(gridpoints <= values.flat[ix]).flatten()[-1]
            )
            weights[:, ix] = (
                np.array([gridpoints[indices.flat[ix] + 1] - values.flat[ix],
                          values.flat[ix] - gridpoints[indices.flat[ix]]])
                / (gridpoints[indices.flat[ix] + 1] - gridpoints[indices.flat[ix]])
            )
    # Give weights its final (2,) + values.shape dimensions; column ix of the
    # 2-by-N array corresponds to values.flat[ix] (row-major order)
    weights = weights.reshape((2,) + values.shape)
    return indices, weights
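A variation that skips the final reshape allocates weights in its final shape from the start and writes through a 2-by-N view returned by reshape. This sketch assumes weights is freshly allocated and therefore C-contiguous. For such an array, reshape returns a view rather than a copy, so flat_weights[:, ix] plays the role of MATLAB's weights(:, ix). The stored numbers are placeholders for the real weight computation.

```python
import numpy as np

values = np.arange(12.0).reshape(3, 4)   # placeholder 3-by-4 input
weights = np.zeros((2,) + values.shape)  # allocated in its final 2-by-3-by-4 shape

# 2-by-N view onto the same memory; reshape only returns a view when it can,
# so guard against silently writing into a copy
flat_weights = weights.reshape(2, values.size)
assert np.shares_memory(flat_weights, weights)

for ix in range(values.size):
    # MATLAB's weights(:, ix) becomes flat_weights[:, ix]
    # (placeholder numbers stand in for the real weight computation)
    flat_weights[:, ix] = [values.flat[ix], -values.flat[ix]]

print(weights[:, 1, 2])  # flat index 6 -> row 1, column 2
```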