performance matlab matrix pdist

Is there a faster/compact way of obtaining the indices from squareform? (Matlab)


I have a matrix of 3-dimensional data points called "data", with dimensions N×3. I am trying to compute two things:

First, the indices "m" and "n" of a distance matrix "Dist", where

Dist = squareform(pdist(data));

such that

[m,n] = find( Dist<=rc & Dist>0 );

where "rc" is a certain cutoff distance, "m" is the row index, and "n" is the column index.

Second, the conditional distances "ConDist", where

D = pdist(data);
ConDist = D( D<=rc & D>0 );

This code works fine for small "data" (N < 3500); however, for large "data" (N > 25000), it takes too much time and memory. I therefore tried to reduce both as follows:

Dist = zeros(size(data,1));
Dist(tril(true(size(data,1)),-1)) = pdist(data);
[m,n] = find(Dist <= rc  &  Dist > 0);
ConDist = Dist(Dist <= rc  &  Dist > 0);

Here, I fill only the lower triangle of the matrix that "squareform" would produce, hoping to reduce computation time or memory (I am not sure whether MATLAB actually handles this version more efficiently). However, computing the "Dist" variable still takes a lot of time and memory.

Is there a faster or less memory-hungry way to calculate "m", "n", and "ConDist"? Thank you very much in advance.


Solution

  • This could be one approach -

    N = size(data,1); %// datasize
    
    %// Store transpose of data, as we need to use later on at several places
    data_t = data.';  %//'
    
    %// Calculate squared distances with matrix multiplication based technique;
    %// tril(...,-1) keeps the strictly lower triangle, so self-pairs that are
    %// only approximately zero due to round-off cannot slip through
    sqdist = tril([-2*data data.^2 ones(N,3)]*[data_t ; ones(3,N) ; data_t.^2], -1);
    
    %// Logical array with the size of the distance array, true where the
    %// squared distance is within the cutoff and non-zero
    mask_dists = sqdist <= rc^2  &  sqdist > 0;
    
    %// Indices & distances from distances array that satisfy thresholding criteria
    [m,n] = find(mask_dists);
    ConDist = sqrt(sqdist(mask_dists));
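
    The product works because of the identity |x_i - x_j|^2 = |x_i|^2 + |x_j|^2 - 2*(x_i . x_j): row i of the left factor holds [-2*x_i, x_i.^2, 1 1 1] and column j of the right factor holds [x_j; 1; 1; 1; x_j.^2]. As a minimal sanity check, the sketch below compares it against a brute-force loop. The three points and the cutoff are made-up values chosen so the distances are easy to verify by hand; they are not from the question.

    ```matlab
    % Three hypothetical 3-D points and a hypothetical cutoff (illustration only)
    data = [0 0 0; 3 4 0; 0 0 1];
    rc = 2;
    N = size(data,1);
    data_t = data.';

    % Entry (i,j) equals |x_i|^2 + |x_j|^2 - 2*dot(x_i, x_j)
    sqdist = [-2*data data.^2 ones(N,3)]*[data_t ; ones(3,N) ; data_t.^2];

    % Brute-force reference: squared Euclidean distance for every pair
    ref = zeros(N);
    for i = 1:N
        for j = 1:N
            ref(i,j) = sum((data(i,:) - data(j,:)).^2);
        end
    end
    maxErr = max(abs(sqdist(:) - ref(:)));   % largest deviation from the reference

    % Keep pairs in the strictly lower triangle that are within the cutoff
    mask = tril(sqdist <= rc^2 & sqdist > 0, -1);
    [m,n] = find(mask);
    ConDist = sqrt(sqdist(mask));
    ```

    Here only the first and third points lie within the cutoff of each other, so a single (m,n) pair is returned.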
    

    You can introduce bsxfun here to replace tril (keeping the rest as it is) and see if that speeds it up a bit further -

    sqdist = [-2*data data.^2 ones(N,3)]*[data_t ; ones(3,N) ; data_t.^2];
    %// @gt keeps only row > column, i.e. the strictly lower triangle (no self-pairs)
    mask_dists = sqdist <= rc^2  &  sqdist > 0 & bsxfun(@gt,(1:N).',1:N);
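
    Both variants above still build a full N-by-N matrix, which for N = 25000 in double precision is about 5 GB (25000^2 * 8 bytes). If memory is the bottleneck, one possible workaround is to apply the same product to a block of rows at a time and collect the matches as you go. This is only a sketch under assumptions: blockSize is a hypothetical tuning parameter, and the grid data and cutoff are placeholders standing in for the real inputs.

    ```matlab
    % Placeholder inputs: a 5x5x5 integer grid and a made-up cutoff
    [gx,gy,gz] = ndgrid(0:4);
    data = [gx(:) gy(:) gz(:)];
    rc = 1.5;
    blockSize = 32;               % hypothetical: rows processed per block

    N = size(data,1);
    data_t = data.';
    right = [data_t ; ones(3,N) ; data_t.^2];   % right-hand factor, built once
    m = cell(0,1); n = cell(0,1); ConDist = cell(0,1);

    for first = 1:blockSize:N
        rows = first:min(first+blockSize-1, N);
        last = rows(end);
        % Squared distances from this block of rows to points 1..last only;
        % columns beyond 'last' can never lie below the diagonal
        left = [-2*data(rows,:) data(rows,:).^2 ones(numel(rows),3)];
        sq = left * right(:,1:last);
        % Within cutoff, non-zero, and strictly below the diagonal
        mask = sq <= rc^2 & sq > 0 & bsxfun(@gt, rows.', 1:last);
        [r,c] = find(mask);
        vals = sq(mask);
        rowsCol = rows(:);
        m{end+1,1} = rowsCol(r(:));       % map block-local rows to global indices
        n{end+1,1} = c(:);
        ConDist{end+1,1} = sqrt(vals(:));
    end
    m = vertcat(m{:}); n = vertcat(n{:}); ConDist = vertcat(ConDist{:});
    ```

    Peak memory is then roughly blockSize-by-N instead of N-by-N, at the cost of a loop. The pairs come out grouped by block, so their order differs from the single-matrix version; sort them afterwards if the order matters.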