image-processing video-tracking mean-shift back-projection

MeanShift formula understanding


I'm implementing the MeanShift algorithm for object tracking, using ideas from here: http://www.cse.psu.edu/~rtc12/CSE598C/meanShiftColor.pdf

Now I have a backprojection image for the subsequent frame. Each pixel in this image gives the probability that the pixel belongs to the tracked object: [image: backprojection]

The MeanShift formula in the aforementioned source looks like this: [image: formula]

w(xi) = pixel in backprojection image.
x = current center pixel.
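
For reference, here is the weighted mean-shift step as it is usually written, using the symbols defined above. This is a transcription that assumes the standard form (K is the spatial kernel, x_i runs over the pixels, and \Delta x is the shift applied to the current center), so it may differ in notation from the source:

```latex
\Delta x \;=\;
\frac{\sum_i K(x_i - x)\, w(x_i)\, (x_i - x)}
     {\sum_i K(x_i - x)\, w(x_i)}
```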

I don't understand what the spatial kernel is.
Assuming it can be a 2D Gaussian kernel of size, say, 5x5, I thought K(xi-x)*w(xi) could be replaced by the corresponding pixel of a pre-blurred image.

My code looks like this:

    fvect2 findMeanShift(const PlainImage<uint8>& weights_smoothed, fvect2 old_center, DebugOutput& debug)
    {
        //LOGE("first center: %.2f %.2f", old_center.x, old_center.y);

        const int w=weights_smoothed.getWidth(), h=weights_smoothed.getHeight();

        int iter_count = 0;
        fvect2 total_shift(0.0,0.0);

        while(iter_count++ < 20)
        {
            fvect2 fTop(0,0);
            float fBottom=0.0;
            for(int y=0;y<h;++y)
                for(int x=0;x<w;++x)
                {
                    fvect2 cur_center(x, y);
                    float mult = weights_smoothed.at(x, y)[0]/255.0;
                    fBottom += mult;
                    fTop += (cur_center-old_center) * mult;
                }
            fvect2 mean_shift = fTop/fBottom;
            //printf("mean_shift: %.2f %.2f", mean_shift.x, mean_shift.y);

            debug.addArrow(old_center, old_center+mean_shift);

            old_center += mean_shift;
            //printf("old_center: %.2f %.2f", old_center.x, old_center.y);

            total_shift += mean_shift;

            if(mean_shift.lengthF()<0.1)
                break;
        }

        return total_shift;
    }

So I just iterate over the smoothed backprojection image, and for each pixel I:
add its value to the denominator
add its value, multiplied by the offset from the current center, to the numerator.

It converges on the second iteration, but the shift is wrong, and I don't know how to debug it. The problem is probably in my implementation of the formula. [image]

Please explain in plain language what the spatial kernel is and how to apply it to the weight image. Thanks!


Solution

  • Well, I've figured it out. The spatial mask is the search window. The backprojection doesn't need to be blurred. Viewed over the entire image, the mask is 1 inside the search window and 0 outside. Page 2 of this has a simpler explanation.

    If the search extends far from the tracked object, the results are bad. So the wrong shift point in the question is probably the center of mass of the entire image.
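
    A short derivation supports this, under the assumption that summing over the whole image amounts to K = 1 everywhere. One step from any center x then lands on the weighted centroid of the image, regardless of x, which would also explain why the loop in the question converges on the second iteration:

    ```latex
    x + \Delta x
      \;=\; x + \frac{\sum_i w(x_i)\,(x_i - x)}{\sum_i w(x_i)}
      \;=\; x + \frac{\sum_i w(x_i)\, x_i}{\sum_i w(x_i)} - x
      \;=\; \frac{\sum_i w(x_i)\, x_i}{\sum_i w(x_i)}
    ```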

    Now my function looks like this:

        fvect2 findMeanShift(const PlainImage<uint8>& weights, const ImageRect& spatial_roi, fvect2 old_center, DebugOutput& debug)
        {
            const int w=weights.getWidth(), h=weights.getHeight();
    
            int iter_count = 0;
            fvect2 total_shift(0.0,0.0);
    
            while(iter_count++ < 20)
            {
                fvect2 fTop(0,0);
                float fBottom=0.0;
    
                for(int y=std::max(spatial_roi.m_y, 0);y<std::min(h, spatial_roi.m_y+spatial_roi.m_height);++y)
                    for(int x=std::max(spatial_roi.m_x, 0);x<std::min(w, spatial_roi.m_x+spatial_roi.m_width);++x)
                    {
                        assert(y>=0 && x>=0 && y<h && x<w);
    
                        fvect2 cur_center(x, y);
                        float mult = weights.at(x, y)[0]/255.0;
                        fBottom += mult;
                        fTop += (cur_center-old_center) * mult;
                    }
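                // No weight inside the search window: stop instead of dividing by zero.
                if(fBottom <= 0.0f)
                    break;
                fvect2 mean_shift = fTop/fBottom;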
                //LOGE("mean_shift: %.2f %.2f", mean_shift.x, mean_shift.y);
    
                debug.addArrow(old_center, old_center+mean_shift);
    
                old_center += mean_shift;
                //LOGE("old_center: %.2f %.2f", old_center.x, old_center.y);
    
                total_shift += mean_shift;
    
                if(mean_shift.lengthF()<0.1)
                    break;
            }
    
            return total_shift;
        }
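
As a self-contained illustration of the "search window as spatial mask" idea, here is a minimal sketch that does not use the types from the code above. `Vec2`, `WeightImage`, `meanShift` and `makeTwoBlobImage` are hypothetical names made up for this example. It also assumes the window is re-centered on the current estimate at every iteration, which is a variation on the function above (where `spatial_roi` stays fixed).

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

struct Vec2 { float x, y; };

// Hypothetical row-major weight image with values in [0, 1].
struct WeightImage {
    int w, h;
    std::vector<float> data;
    float at(int x, int y) const { return data[y * w + x]; }
};

// Flat-kernel mean shift: the spatial mask is 1 inside a window of
// (2*half_w+1) x (2*half_h+1) pixels around the current center and 0 outside.
// Assumption: the window follows the center on every iteration.
Vec2 meanShift(const WeightImage& img, Vec2 center,
               int half_w, int half_h, int max_iter = 20)
{
    assert(half_w >= 0 && half_h >= 0);
    for (int it = 0; it < max_iter; ++it) {
        const int cx = static_cast<int>(std::lround(center.x));
        const int cy = static_cast<int>(std::lround(center.y));
        // Clip the window to the image bounds.
        const int x0 = std::max(cx - half_w, 0), x1 = std::min(cx + half_w, img.w - 1);
        const int y0 = std::max(cy - half_h, 0), y1 = std::min(cy + half_h, img.h - 1);

        float top_x = 0.0f, top_y = 0.0f, bottom = 0.0f;
        for (int y = y0; y <= y1; ++y)
            for (int x = x0; x <= x1; ++x) {
                const float wgt = img.at(x, y);
                bottom += wgt;
                top_x  += (x - center.x) * wgt;
                top_y  += (y - center.y) * wgt;
            }
        if (bottom <= 0.0f) break;              // nothing to follow in this window

        const float sx = top_x / bottom, sy = top_y / bottom;
        center.x += sx;
        center.y += sy;
        if (std::sqrt(sx * sx + sy * sy) < 0.1f) break;  // converged
    }
    return center;
}

// Synthetic test data: a 5x5 "target" blob centered at (10, 10) and a larger
// 9x9 distractor blob centered at (30, 20) in a 40x30 image.
WeightImage makeTwoBlobImage()
{
    WeightImage img{40, 30, std::vector<float>(40 * 30, 0.0f)};
    for (int y = 8; y <= 12; ++y)
        for (int x = 8; x <= 12; ++x) img.data[y * img.w + x] = 1.0f;
    for (int y = 16; y <= 24; ++y)
        for (int x = 26; x <= 34; ++x) img.data[y * img.w + x] = 1.0f;
    return img;
}
```

Starting near the small blob, a call such as `meanShift(img, Vec2{12.0f, 11.0f}, 6, 6)` settles on that blob, while a window large enough to cover the whole image (for example half sizes 40 and 30) is pulled to a point between the two blobs, which matches the wrong shift described in the question.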