I am trying to write code that computes the Mean Average Precision (MAP) for multi-label data; the details are below.
I have written the MAP computation in MATLAB, but it is quite slow, mainly because the variable Lrx is recomputed for every value of r.
I would like to make the code much faster.
function [map] = map_at_R(sim_x,L_tr,L_te)
% sim_x(i,j) denotes the similarity between query j and database item i
tn = size(sim_x,2);
APx = zeros(tn,1);
R = 100;    % number of top-ranked database items evaluated per query
for i = 1 : tn
    Px = zeros(R,1);
    deltax = zeros(R,1);
    label = L_te(i,:);
    % rank database items by similarity, most similar first
    [~,inxx] = sort(sim_x(:,i),'descend');
    % deltax(r) = 1 if the item at rank r has the same label as the query
    % or shares at least one label with it, else deltax(r) = 0;
    % Lx = sum(deltax) is the denominator of the AP calculation
    search_set = L_tr(inxx(1:R),:);
    for r = 1 : R
        %% FAST COMPUTATION
        % Lrx = number of relevant items among the top r
        Lrx = sum(diag(repmat(label,r,1)*search_set(1:r,:).')>0);
        %% SLOW COMPUTATION
        % Lrx = 0;
        % for j=1:r
        %     if sum(label*(search_set(j,:)).')>0
        %         Lrx = Lrx+1;
        %     end
        % end
        if sum(label*(search_set(r,:)).')>0
            deltax(r) = 1;
        end
        Px(r) = Lrx/r;    % precision at rank r
    end
    Lx = sum(deltax);
    if Lx ~= 0
        APx(i) = sum(Px.*deltax)/Lx;
    end
end
map = mean(APx);
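The "fast" line builds an r-by-r matrix only to read its diagonal. Entry k of that diagonal is `label*search_set(k,:).'`, which is exactly entry k of the matrix-vector product `search_set(1:r,:)*label.'`, so the same count can be obtained without the r-by-r intermediate (O(r·c) work instead of O(r²·c) for each r). A small check on made-up data (the values of `label` and `search_set` here are arbitrary):

```matlab
% Made-up example: one query label vector (c = 4 classes) and 5 retrieved items
label      = [1 0 1 0];
search_set = [0 1 0 0;
              1 0 0 1;
              0 0 1 1;
              0 1 0 1;
              1 0 1 0];
r = 4;
% Original trick: forms an r-by-r matrix and keeps only its diagonal
Lrx_diag = sum(diag(repmat(label,r,1)*search_set(1:r,:).')>0);
% Same count from an r-by-1 matrix-vector product
Lrx_mv   = sum(search_set(1:r,:)*label.' > 0);
```

Both give 2 here: among the first four items, only items 2 and 3 share a label with the query.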
The inputs to the function are:
% sim_x : similarity score matrix, gallery_data_size x probe_data_size
%         (if a distance matrix is used instead, the sort must be 'ascend')
% L_tr  : labels of the gallery set, gallery_data_size x c
% L_te  : labels of the probe set, probe_data_size x c
% where c is the number of classes; note that the data is multi-label
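For timing either version, random inputs with these shapes can be generated. This is only a sketch: the sizes and the 0.3 label density below are arbitrary choices, not taken from the question.

```matlab
rng(0);                          % reproducible random data
gallery_data_size = 1000;        % arbitrary size (must be >= R = 100)
probe_data_size   = 50;          % arbitrary size
c = 10;                          % arbitrary number of classes
sim_x = rand(gallery_data_size, probe_data_size);   % higher = more similar
L_tr  = double(rand(gallery_data_size, c) < 0.3);   % multi-label 0/1 rows
L_te  = double(rand(probe_data_size, c) < 0.3);
```

With `map_at_R.m` on the path, `tic; map_at_R(sim_x,L_tr,L_te); toc` then gives a rough timing.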
Is it possible to make the code even faster? I am unable to figure it out myself.
In APx(i) = sum(Px.*deltax)/Lx, the delta function zeroes out every Px(r) for which deltax(r) == 0, so those iterations of r = 1:R are wasted work. Since deltax can be computed before the loop, why not iterate only over the r where deltax(r) == 1?
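Computing the delta vector before the loop takes one matrix-vector product: entry r of `search_set*label.'` counts the labels shared by the query and the item at rank r, so comparing it with zero gives every `deltax(r)` at once, and `find` turns it into the list of ranks to iterate over. A sketch on made-up data, checked against the per-rank test from the question's loop:

```matlab
% Made-up example: query label (c = 3) and labels of the top R = 5 retrieved items
label      = [1 0 0];
search_set = [1 0 0; 0 1 0; 0 0 1; 1 1 0; 1 0 1];
R = size(search_set,1);
% Vectorized delta: true where the rank-r item shares at least one label
deltax = (search_set*label.') > 0;
% Reference: the per-rank test used inside the question's r loop
deltax_loop = zeros(R,1);
for r = 1 : R
    if sum(label*(search_set(r,:)).')>0
        deltax_loop(r) = 1;
    end
end
r_range = find(deltax);   % ranks worth iterating over
```

Here `r_range` is `[1; 4; 5]`, the ranks of the three items carrying class 1.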
% Replaces the r loop and the Lx/APx update inside "for i = 1 : tn",
% after search_set has been computed.
% Multiply each row of the retrieved labels by the query label
mult = bsxfun(@times, search_set(1:R,:), label);
% Sum each row; r_range is equivalent to find(deltax == 1)
r_range = find(sum(mult,2) > 0);
Px = zeros(numel(r_range),1);
% r_range is a column vector, so loop over its indices; "for r = r_range"
% would run only once, with r equal to the whole vector
for k = 1 : numel(r_range)
    r = r_range(k);
    Lrx = sum(diag(repmat(label,r,1)*search_set(1:r,:).')>0);
    Px(k) = Lrx/r;
end
Lx = numel(r_range);
if Lx ~= 0
    APx(i) = sum(Px)/Lx;
end
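Taking the idea one step further (an extra observation, not part of the original answer): `find` returns `r_range` in ascending order, so the k-th relevant rank has exactly k relevant items at or above it. That means `Lrx = k` and `Px(k) = k / r_range(k)`, and the inner loop disappears. The sketch below applies this to a small made-up example; the data and the cap `R = min(100, gallery size)` are assumptions for illustration. Worked by hand, the expected AP is 0.7 for the first query and 5/6 for the second.

```matlab
% Made-up example: 5 gallery items, 2 queries, c = 3 classes
L_tr  = [1 0 0; 0 1 0; 1 1 0; 0 0 1; 1 0 1];   % gallery labels
L_te  = [1 0 0; 0 0 1];                         % query labels
sim_x = [0.9 0.1; 0.8 0.2; 0.3 0.7; 0.5 0.9; 0.1 0.4];
tn  = size(sim_x,2);
R   = min(100, size(sim_x,1));   % assumption: cap R at the gallery size
APx = zeros(tn,1);
for i = 1 : tn
    label = L_te(i,:);
    [~,inxx] = sort(sim_x(:,i),'descend');
    search_set = L_tr(inxx(1:R),:);
    % Ranks (ascending) whose item shares at least one label with the query
    r_range = find(search_set*label.' > 0);
    Lx = numel(r_range);
    if Lx ~= 0
        % k-th relevant rank has k relevant items at or above it: Lrx = k
        Px = (1:Lx).' ./ r_range;
        APx(i) = sum(Px)/Lx;
    end
end
map = mean(APx);
```

Per query this is O(R·c) for the product plus O(R) for the rest, instead of the O(R²·c) or worse of the original r loop.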