Tags: javascript, html, deep-learning, mediapipe

How to work with the MediaPipe multi-class selfie segmentation model?


I am working with the MediaPipe library and want to detect hair and clothes in an image. Their website provides a good example that you can download and run. However, I want the detected hair or clothes areas to become white while every other area stays unchanged. The problem is that MediaPipe only provides an example using the "DeepLab-v3" model. Here you can see their example.

https://github.com/googlesamples/mediapipe/tree/main/examples/image_segmentation/js
https://developers.google.com/mediapipe/solutions/vision/image_segmenter/index#models

This is their callback function; I want to detect hair or clothes and make those areas white.

function callback(result) {
    const cxt = canvasClick.getContext("2d");
    const { width, height } = result.categoryMask;

    let imageData = cxt.getImageData(0, 0, width, height).data;

    canvasClick.width = width;
    canvasClick.height = height;
    let category = "";

    const mask = result.categoryMask.getAsUint8Array();

    for (let i in mask) {
        if (mask[i] > 0) {
            category = labels[mask[i]];
        }
        const legendColor = legendColors[mask[i] % legendColors.length];
        imageData[i * 4] = (legendColor[0] + imageData[i * 4]) / 2;
        imageData[i * 4 + 1] = (legendColor[1] + imageData[i * 4 + 1]) / 2;
        imageData[i * 4 + 2] = (legendColor[2] + imageData[i * 4 + 2]) / 2;
        imageData[i * 4 + 3] = (legendColor[3] + imageData[i * 4 + 3]) / 2;
    }
    
    const uint8Array = new Uint8ClampedArray(imageData.buffer);
    const dataNew = new ImageData(uint8Array, width, height);
    cxt.putImageData(dataNew, 0, 0);
    const p = event.target.parentNode.getElementsByClassName("classification")[0];
    p.classList.remove("removed");
    p.innerText = "Category: " + category;
}

Solution

  • You can load the multi-class selfie segmentation model instead of DeepLab-v3, as shown in the example code below:

    // FilesetResolver and ImageSegmenter come from the @mediapipe/tasks-vision package
    // (also available via the jsDelivr CDN used for the wasm files below).
    import { FilesetResolver, ImageSegmenter } from "@mediapipe/tasks-vision";

    let imageSegmenter;
    let labels;

    const createImageSegmenter = async () => {
        const vision = await FilesetResolver.forVisionTasks(
            "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@0.10.0/wasm"
        );
        imageSegmenter = await ImageSegmenter.createFromOptions(vision, {
            baseOptions: {
                modelAssetPath: "./path/to/selfie_multiclass.tflite", // Download the selfie multiclass model from their website or use the CDN URL
                delegate: "GPU"
            },
            runningMode: "IMAGE",
            outputCategoryMask: true,
            outputConfidenceMasks: true
        });
        labels = imageSegmenter.getLabels(); // gives all the labels present in the model
    };
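
    For completeness, here is a minimal sketch of how the segmenter created above might be run on an image. The <img> element and the click handler are illustrative assumptions; the callback argument is the function shown below:

    // Assumed markup for illustration: <img class="segmentOnClick" src="...">
    const img = document.querySelector("img.segmentOnClick");
    img.addEventListener("click", async (event) => {
        if (!imageSegmenter) {
            await createImageSegmenter(); // build the segmenter lazily on first use
        }
        // IMAGE running mode: segment() invokes the callback with the result
        imageSegmenter.segment(event.target, callback);
    });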
    

    And then in your code make the following changes:

    function callback(result) {
        const cxt = canvasClick.getContext("2d");
        const { width, height } = result.categoryMask;

        let imageData = cxt.getImageData(0, 0, width, height).data;

        canvasClick.width = width;
        canvasClick.height = height;
        let category = "";

        const mask = result.categoryMask.getAsUint8Array();

        for (let i = 0; i < mask.length; i++) {
            if (mask[i] > 0) {
                category = labels[mask[i]];
            }
            // Draw the segmentation mask only for the categories you want.
            // Check the label of the current pixel (rather than the last
            // detected category) so background pixels are left untouched.
            const pixelLabel = labels[mask[i]];
            if (pixelLabel === "hair" || pixelLabel === "clothes") {
                const legendColor = legendColors[mask[i] % legendColors.length];
                imageData[i * 4] = (legendColor[0] + imageData[i * 4]) / 2;
                imageData[i * 4 + 1] = (legendColor[1] + imageData[i * 4 + 1]) / 2;
                imageData[i * 4 + 2] = (legendColor[2] + imageData[i * 4 + 2]) / 2;
                imageData[i * 4 + 3] = (legendColor[3] + imageData[i * 4 + 3]) / 2;
            }
        }

        const uint8Array = new Uint8ClampedArray(imageData.buffer);
        const dataNew = new ImageData(uint8Array, width, height);
        cxt.putImageData(dataNew, 0, 0);
        const p = event.target.parentNode.getElementsByClassName("classification")[0];
        p.classList.remove("removed");
        p.innerText = "Category: " + category;
    }
    

    In the legendColors variable you can set the RGBA value of the color that should be drawn over the selected area (note that the loop above averages the legend color with the original pixel rather than replacing it).
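
    Since the goal in the question is to make the hair and clothes areas white rather than tint them, a small variation of the loop above (a sketch, not part of the original example) skips the blending and writes a solid white pixel for the matching categories:

    // Sketch: overwrite matching pixels with solid white instead of blending.
    for (let i = 0; i < mask.length; i++) {
        const pixelLabel = labels[mask[i]];
        if (pixelLabel === "hair" || pixelLabel === "clothes") {
            imageData[i * 4] = 255;     // R
            imageData[i * 4 + 1] = 255; // G
            imageData[i * 4 + 2] = 255; // B
            imageData[i * 4 + 3] = 255; // A (fully opaque)
        }
    }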

    The classes/labels in the selfie multiclass model are as below:

      0 - background
      1 - hair
      2 - body-skin
      3 - face-skin
      4 - clothes
      5 - others (accessories)
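
    If you prefer to compare against the numeric category indices instead of the label strings, you can look them up once from the labels array (a small convenience sketch):

    // Sketch: resolve the numeric indices of the wanted labels once, outside the pixel loop.
    const targetIndices = new Set([labels.indexOf("hair"), labels.indexOf("clothes")]);
    // Inside the pixel loop, each mask value can then be tested directly:
    //   if (targetIndices.has(mask[i])) { /* paint this pixel as shown above */ }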

    You can find more information about the model on the MediaPipe image segmenter documentation page linked above.