Tags: java, cudnn, dl4j

DL4J - Image becomes too bright


I've been asked to write CNN code in DL4J using the YOLOv2 architecture. The problem is that after training completes, I run a simple GUI for validation testing, and the image shown is too bright; sometimes the image cannot be displayed at all. I'm not sure where this problem comes from, whether it starts at the earliest stage of training or somewhere else. I attach the code that I have so far. For the iterator:

public class faceMaskIterator {
private static final Logger log = org.slf4j.LoggerFactory.getLogger(faceMaskIterator.class);
private static final int seed = 123;
private static Random rng = new Random(seed);
private static String dataDir;
private static Path pathDirectory;
private static InputSplit trainData, testData;
private static final String[] allowedFormats  = NativeImageLoader.ALLOWED_FORMATS;
private static final double splitRatio = 0.8;
private static final int nChannels = 3;
public static final int gridWidth = 13;
public static final int gridHeight = 13;
public static final int yolowidth = 416;
public static final int yoloheight = 416;

private static RecordReaderDataSetIterator makeIterator(InputSplit split, Path dir, int batchSize) throws Exception {

    ObjectDetectionRecordReader recordReader = new ObjectDetectionRecordReader(yoloheight, yolowidth, nChannels,
            gridHeight, gridWidth, new VocLabelProvider(dir.toString()));
    recordReader.initialize(split);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, 1, true);
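    // Normalise raw pixel values from [0, 255] to [0, 1] before they reach the network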
    iter.setPreProcessor(new ImagePreProcessingScaler(0, 1));

    return iter;
}

public static RecordReaderDataSetIterator trainIterator(int batchSize) throws Exception {
    return makeIterator(trainData, pathDirectory, batchSize);
}

public static RecordReaderDataSetIterator testIterator(int batchSize) throws Exception {
    return makeIterator(testData, pathDirectory, batchSize);
}

public static void setup() throws IOException {
    log.info("Load data...");
    dataDir = Paths.get(
            System.getProperty("user.home"),
            Helper.getPropValues("dl4j_home.data")
    ).toString();
    pathDirectory = Paths.get(dataDir, "face_mask_dataset");
    FileSplit fileSplit = new FileSplit(new File(pathDirectory.toString()), allowedFormats, rng);
    PathFilter pathFilter = new RandomPathFilter(rng, allowedFormats);
    InputSplit[] sample = fileSplit.sample(pathFilter, splitRatio, 1 - splitRatio);
    trainData = sample[0];
    testData = sample[1];
}
}

For training:

public class faceMaskPreTrained {
private static final Logger log = LoggerFactory.getLogger(faceMaskPreTrained.class);
private static int seed = 420;
private static double detectionThreshold = 0.9;
private static int nBoxes = 3;
private static double lambdaNoObj = 0.7;
private static double lambdaCoord = 1.0;
private static double[][] priorBoxes = {{1, 1}, {2, 1}, {1, 2}};

private static int batchSize = 3;
private static int nEpochs = 1;
private static double learningRate = 1e-4;
private static int nClasses = 3;
private static List<String> labels;

private static File modelFilename = new File(System.getProperty("user.dir"), "generated-models/facemask_detector.zip");
private static ComputationGraph model;
private static Frame frame = null;
private static final Scalar GREEN = RGB(0, 255, 0);
private static final Scalar YELLOW = RGB(255, 255, 0);
private static final Scalar RED = RGB(255, 0, 0);
private static Scalar[] colormap = {GREEN, YELLOW, RED};
private static String labeltext = null;

public static void main(String[] args) throws Exception {
    faceMaskIterator.setup();
    RecordReaderDataSetIterator trainIter = faceMaskIterator.trainIterator(batchSize);
    RecordReaderDataSetIterator testIter = faceMaskIterator.testIterator(1);
    labels = trainIter.getLabels();

    if (modelFilename.exists()) {
        Nd4j.getRandom().setSeed(seed);
        log.info("Load model...");
        model = ModelSerializer.restoreComputationGraph(modelFilename);
    } else {
        Nd4j.getRandom().setSeed(seed);
        INDArray priors = Nd4j.create(priorBoxes);

        log.info("Build model...");
        ComputationGraph pretrained = (ComputationGraph) YOLO2.builder().build().initPretrained();

        
        FineTuneConfiguration fineTuneConf = getFineTuneConfiguration();
        model = getComputationGraph(pretrained, priors, fineTuneConf);
        System.out.println(model.summary(InputType.convolutional(
                faceMaskIterator.yoloheight,
                faceMaskIterator.yolowidth,
                3))); // input depth is the 3 colour channels, not nClasses (equal here only by coincidence)

        log.info("Train model...");
        UIServer server = UIServer.getInstance();
        StatsStorage storage = new InMemoryStatsStorage();
        server.attach(storage);
        model.setListeners(new ScoreIterationListener(5), new StatsListener(storage,5));

        for (int i = 1; i < nEpochs + 1; i++) {
            trainIter.reset();
            while (trainIter.hasNext()) {
                model.fit(trainIter.next());
            }
            log.info("*** Completed epoch {} ***", i);
        }
        ModelSerializer.writeModel(model, modelFilename, true);
        System.out.println("Model saved.");
    }
    // Evaluate the model visually by using the test iterator.
    offlineValidationWithTestDataset(testIter);
    // Run inference on the webcam stream and draw the predictions.
    doInference();
}

private static ComputationGraph getComputationGraph(ComputationGraph pretrained, INDArray priors, FineTuneConfiguration fineTuneConf) {

    return new TransferLearning.GraphBuilder(pretrained)
            .fineTuneConfiguration(fineTuneConf)
            .removeVertexKeepConnections("conv2d_23")
            .removeVertexKeepConnections("outputs")
            .addLayer("conv2d_23",
                    new ConvolutionLayer.Builder(1, 1)
                            .nIn(1024)
                            .nOut(nBoxes * (5 + nClasses))
                            .stride(1, 1)
                            .convolutionMode(ConvolutionMode.Same)
                            .weightInit(WeightInit.XAVIER)
                            .activation(Activation.IDENTITY)
                            .build(),
                    "leaky_re_lu_22")
            .addLayer("outputs",
                    new Yolo2OutputLayer.Builder()
                            .lambdaNoObj(lambdaNoObj)
                            .lambdaCoord(lambdaCoord)
                            .boundingBoxPriors(priors.castTo(DataType.FLOAT))
                            .build(),
                    "conv2d_23")
            .setOutputs("outputs")
            .build();
}

private static FineTuneConfiguration getFineTuneConfiguration() {

    return new FineTuneConfiguration.Builder()
            .seed(seed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(1.0)
            .updater(new Adam.Builder().learningRate(learningRate).build())
            .l2(0.00001)
            .activation(Activation.IDENTITY)
            .trainingWorkspaceMode(WorkspaceMode.ENABLED)
            .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
            .build();
}

// Visually evaluate the performance of the trained object detection model
private static void offlineValidationWithTestDataset(RecordReaderDataSetIterator test) throws InterruptedException {
    NativeImageLoader imageLoader = new NativeImageLoader();
    CanvasFrame canvas = new CanvasFrame("Validate Test Dataset");
    OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
    org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
    Mat convertedMat = new Mat();
    Mat convertedMat_big = new Mat();

    while (test.hasNext() && canvas.isVisible()) {
        org.nd4j.linalg.dataset.DataSet ds = test.next();
        INDArray features = ds.getFeatures();
        INDArray results = model.outputSingle(features);
        List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
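        // Non-max suppression: discard overlapping detections above IoU 0.4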
        YoloUtils.nms(objs, 0.4);
        Mat mat = imageLoader.asMat(features);
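        // Scale the normalised [0, 1] floats back to 8-bit [0, 255] for display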
        mat.convertTo(convertedMat, CV_8U, 255, 0);
        int w = mat.cols() * 2;
        int h = mat.rows() * 2;
        resize(convertedMat, convertedMat_big, new Size(w, h));
        convertedMat_big = drawResults(objs, convertedMat_big, w, h);
        canvas.showImage(converter.convert(convertedMat_big));
        canvas.waitKey();
    }
    canvas.dispose();
}

// Stream video frames from Webcam and run them through YOLOv2 model and get predictions
private static void doInference() {

    String cameraPos = "front";
    int cameraNum = 0;
    Thread thread = null;
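    // Load frames at the network input size and convert OpenCV's BGR channel order to RGB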
    NativeImageLoader loader = new NativeImageLoader(
            faceMaskIterator.yolowidth,
            faceMaskIterator.yoloheight,
            3,
            new ColorConversionTransform(COLOR_BGR2RGB));
    ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);

    if (!cameraPos.equals("front") && !cameraPos.equals("back")) {
        throw new IllegalArgumentException("Unknown argument for camera position. Choose between front and back");
    }

    OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
    FrameGrabber grabber;
    try {
        grabber = FrameGrabber.createDefault(cameraNum);
        grabber.start();
    } catch (FrameGrabber.Exception e) {
        // Fail fast instead of continuing with a grabber that never started
        throw new RuntimeException("Could not start the webcam grabber", e);
    }

    CanvasFrame canvas = new CanvasFrame("Object Detection");
    int w = grabber.getImageWidth();
    int h = grabber.getImageHeight();
    canvas.setCanvasSize(w, h);

    while (true) {
        try {
            frame = grabber.grab();
        } catch (FrameGrabber.Exception e) {
            e.printStackTrace();
        }

        //if a thread is null, create new thread
        if (thread == null) {
            thread = new Thread(() ->
            {
                while (frame != null) {
                    try {
                        Mat rawImage = new Mat();

                        //Flip the camera if opening front camera
                        if (cameraPos.equals("front")) {
                            Mat inputImage = converter.convert(frame);
                            flip(inputImage, rawImage, 1);
                        } else {
                            rawImage = converter.convert(frame);
                        }

                        Mat resizeImage = new Mat();
                        resize(rawImage, resizeImage, new Size(faceMaskIterator.yolowidth, faceMaskIterator.yoloheight));
                        INDArray inputImage = loader.asMatrix(resizeImage);
                        scaler.transform(inputImage);
                        INDArray outputs = model.outputSingle(inputImage);
                        org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
                        List<DetectedObject> objs = yout.getPredictedObjects(outputs, detectionThreshold);
                        YoloUtils.nms(objs, 0.4);
                        rawImage = drawResults(objs, rawImage, w, h);
                        canvas.showImage(converter.convert(rawImage));
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            });
            thread.start();
        }

        KeyEvent t = null;
        try {
            t = canvas.waitKey(33);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        if ((t != null) && (t.getKeyCode() == KeyEvent.VK_Q)) {
            break;
        }
    }
}

private static Mat drawResults(List<DetectedObject> objects, Mat mat, int w, int h) {
    for (DetectedObject obj : objects) {
        double[] xy1 = obj.getTopLeftXY();
        double[] xy2 = obj.getBottomRightXY();
        String label = labels.get(obj.getPredictedClass());
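        // YOLO predicts box corners in grid-cell units; scale them to pixel coordinates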
        int x1 = (int) Math.round(w * xy1[0] / faceMaskIterator.gridWidth);
        int y1 = (int) Math.round(h * xy1[1] / faceMaskIterator.gridHeight);
        int x2 = (int) Math.round(w * xy2[0] / faceMaskIterator.gridWidth);
        int y2 = (int) Math.round(h * xy2[1] / faceMaskIterator.gridHeight);
        //Draw bounding box
        rectangle(mat, new Point(x1, y1), new Point(x2, y2), colormap[obj.getPredictedClass()], 2, 0, 0);
        //Display label text
        labeltext = label + " " + String.format("%.2f", obj.getConfidence() * 100) + "%";
        int[] baseline = {0};
        Size textSize = getTextSize(labeltext, FONT_HERSHEY_DUPLEX, 1, 1, baseline);
        rectangle(mat, new Point(x1 + 2, y2 - 2), new Point(x1 + 2 + textSize.get(0), y2 - 2 - textSize.get(1)), colormap[obj.getPredictedClass()], FILLED, 0, 0);
        putText(mat, labeltext, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, RGB(0, 0, 0));
    }
    return mat;
}
}

Solution

  • CanvasFrame applies gamma correction by default because it's typically needed by the cameras used for computer vision, but cheap webcams usually output images that are already gamma-corrected, so make sure to let CanvasFrame know about it this way:

    // We should also specify the relative monitor/camera response for proper gamma correction.
    CanvasFrame frame = new CanvasFrame("Some Title", CanvasFrame.getDefaultGamma()/grabber.getGamma());
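
Applied to the code above, that means passing the gamma argument wherever a CanvasFrame is constructed. A minimal sketch, assuming the webcam delivers already gamma-corrected (sRGB-like, gamma ≈ 2.2) frames whenever the grabber reports no gamma of its own; the 2.2 fallback is an assumption, not something the original code provides:

    // In doInference(), after grabber.start():
    double cameraGamma = grabber.getGamma() > 0 ? grabber.getGamma() : 2.2; // fallback value is an assumption
    CanvasFrame canvas = new CanvasFrame("Object Detection",
            CanvasFrame.getDefaultGamma() / cameraGamma);

    // In offlineValidationWithTestDataset() there is no grabber to query, so
    // assume the dataset images are already gamma-corrected (gamma ≈ 2.2):
    CanvasFrame canvas = new CanvasFrame("Validate Test Dataset",
            CanvasFrame.getDefaultGamma() / 2.2);

With the monitor/camera response ratio close to 1, CanvasFrame stops applying a second round of gamma correction, and the displayed images should no longer look washed out.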