Mean Average Precision for Multi-Label Multi-Class Data


I am trying to write code that computes the Mean Average Precision (MAP) for multi-label data. To make the metric concrete, the computation is described below.

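As a small worked illustration of the quantity being computed, here is the average precision for a single query. The relevance pattern `deltax` is made up for this sketch and is not taken from real data; the formula mirrors the one used in the code further down.

```matlab
% Toy relevance pattern for the top R = 4 retrieved items (made-up values):
% 1 means the retrieved item shares at least one label with the query.
deltax = [1; 0; 1; 1];
R = numel(deltax);

% Precision at rank r: relevant items among the first r, divided by r.
Px = zeros(R, 1);
for r = 1:R
    Px(r) = sum(deltax(1:r)) / r;
end
% Px = [1; 1/2; 2/3; 3/4]

% Average precision: mean of Px over the ranks that hold a relevant item.
AP = sum(Px .* deltax) / sum(deltax);
% AP = (1 + 2/3 + 3/4) / 3 = 29/36
```

MAP is then the mean of this AP value over all queries.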

I have written the MAP computation in MATLAB, but it is quite slow. The bottleneck is the computation of the variable Lrx for each value of r.

I want to make the code much faster.

function [map] = map_at_R(sim_x,L_tr,L_te)

% sim_x(i,j) denotes the similarity between query j and database item i
tn = size(sim_x,2);
APx = zeros(tn,1);
R = 100;

for i = 1 : tn    
    Px = zeros(R,1);
    deltax = zeros(R,1);
    label = L_te(i,:);
    [~,inxx] = sort(sim_x(:,i),'descend');

    % deltax(r) = 1 if the r-th retrieved item has the same label as the
    % query or shares at least one label with it, else deltax(r) = 0.
    % Lx = sum(deltax) is the denominator in the AP calculation.
    search_set = L_tr(inxx(1:R),:);

    for r = 1 : R        
        %% FAST COMPUTATION
        Lrx = sum(diag(repmat(label,r,1)*search_set(1:r,:).')>0);

        %% SLOW COMPUTATION
%         Lrx = 0;
%         for j=1:r
%             if sum(label*(search_set(j,:)).')>0
%                 Lrx = Lrx+1;
%             end
%         end        

        if sum(label*(search_set(r,:)).')>0
            deltax(r) = 1;
        end

        Px(r) = Lrx/r;
    end
    Lx = sum(deltax);
    if Lx ~=0
        APx(i) = sum(Px.*deltax)/Lx;
    end
end
map = mean(APx);

The inputs to the code are:

% sim_x = similarity score matrix or distance matrix
% size(sim_x) = gallery_data_size x probe_data_size

% L_tr = labels of the gallery set
% size(L_tr) = gallery_data_size x c

% L_te = labels of the probe set
% size(L_te) = probe_data_size x c

% where c is the number of classes
% please note that the data is multi-label
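
For anyone who wants to reproduce the timing, inputs with these shapes could be generated roughly as follows. This is only a sketch: the sizes, the random seed and the 0.7 threshold are arbitrary assumptions for illustration, not my real data.

```matlab
% Hypothetical sizes, chosen only for illustration.
n_gallery = 200;   % gallery_data_size (must be at least R = 100)
n_probe   = 5;     % probe_data_size
c         = 4;     % number of classes

rng(0);            % fixed seed so the toy data is reproducible

% Random similarity scores: one column per probe, one row per gallery item.
sim_x = rand(n_gallery, n_probe);

% Random multi-label indicator matrices (0/1); a row may contain several 1s.
L_tr = double(rand(n_gallery, c) > 0.7);
L_te = double(rand(n_probe,   c) > 0.7);
```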

Is it possible to make the code even faster? I am unable to figure it out myself.


Solution

  • Because of the weighting by deltax in APx(i) = sum(Px.*deltax)/Lx, every r with deltax(r) == 0 contributes nothing, so that proportion of your r = 1:R iterations is wasted. Since deltax can be computed before the loop, why not iterate only over the r where deltax(r) == 1?

    % r_range is equivalent to find(deltax == 1)
    % Multiply each row of search_set by label
    mult = bsxfun(@times, search_set(1:R,:), label);
    % Sum each row; a positive sum means at least one shared label
    r_range = find(sum(mult,2) > 0);
    % r_range @ i should equal find(deltax) @ i
    
    Px = zeros(numel(r_range),1);
    
    % r_range is a column vector, so transpose it to loop over its elements
    for r = r_range.'
        Lrx = sum(diag(repmat(label,r,1)*search_set(1:r,:).')>0);
        Px(r == r_range) = Lrx/r;
    end 
    
    Lx = numel(r_range);
    if Lx ~=0
        APx(i) = sum(Px)/Lx;
    end
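
Going one step further, Lrx at rank r is just the running count of relevant items among the first r, i.e. cumsum(deltax), so the inner loop over r can probably be dropped altogether. The following is a minimal sketch of that idea, not a drop-in tested replacement: the function name map_at_R_vec and the extra argument R are hypothetical (the original hard-codes R = 100), and relevance uses the same "dot product > 0" test as the question.

```matlab
function map = map_at_R_vec(sim_x, L_tr, L_te, R)
% Sketch of MAP@R with the loop over r replaced by cumsum.
% sim_x(i,j): similarity between query j and database item i
% Assumes size(sim_x,1) >= R, as the original code does.
tn    = size(sim_x, 2);      % number of queries
APx   = zeros(tn, 1);
ranks = (1:R).';             % column vector 1..R

for i = 1:tn
    [~, inxx] = sort(sim_x(:, i), 'descend');

    % deltax(r) = 1 if the r-th retrieved item shares a label with query i
    deltax = double((L_tr(inxx(1:R), :) * L_te(i, :).') > 0);

    Lx = sum(deltax);
    if Lx ~= 0
        % cumsum(deltax) is Lrx for every r at once
        Px = cumsum(deltax) ./ ranks;
        APx(i) = sum(Px .* deltax) / Lx;
    end
end
map = mean(APx);
end
```

Called as map_at_R_vec(sim_x, L_tr, L_te, 100), this should give the same value as map_at_R while doing O(R) work per query instead of building an r-by-r matrix at every rank. It is worth checking against the original on a small data set before relying on it.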