javaneural-networkmxnetimage-classificationdjl

Error when trying to train Deep Java Library neural Network with an Image folder Dataset


With this code:

public class Train {                                                                                          
    public static void main(String[] args) throws IOException, TranslateException {                           
                                                                                                              
        Application application = Application.CV.IMAGE_CLASSIFICATION;                                        
        long inputSize = 28*28;                                                                               
        long outputSize = 10;                                                                                 
        int batchSize=20;                                                                                     
        Shape shape= new Shape(1,28*28);                                                                      
        SequentialBlock block = new SequentialBlock()                                                         
                .add(Blocks.batchFlattenBlock(inputSize))                                                     
                .add(Linear.builder().setUnits(inputSize).build())                                            
                .add(Activation::relu)                                                                        
                .add(Linear.builder().setUnits(128).build())                                                  
                .add(Activation::sigmoid)                                                                     
                .add(Linear.builder().setUnits(outputSize).build());                                          
        Repository repository = Repository.newInstance("folder", Paths.get("src/main/java/Ressources"));      
                                                                                                              
        ImageFolder dataset = ImageFolder.builder()                                                           
                .setRepository(repository)                                                                    
                .optFlag(Image.Flag.GRAYSCALE)                                                                
                .addTransform(new Resize( 28,28))                                                             
                .addTransform(new ToTensor())                                                                 
                .setSampling(batchSize,true)                                                                  
                .build();                                                                                     
        dataset.prepare(new ProgressBar());                                                                   
                                                                                                              
                                                                                                                                                                                                                 
                                                                                                                                                                                                
        Model model = Model.newInstance("mlp");                                                               
        //model.setBlock(new Mlp((int) inputSize, (int) outputSize, new int[] {128, 64}));                    
        model.setBlock(block);                                                                                
        TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss())                                      
                .addEvaluator(new Accuracy())                                                                 
                .optOptimizer(Optimizer.adadelta().build())                                                   
                .addTrainingListeners(TrainingListener.Defaults.logging());                                   
                                                                                                              
        Trainer trainer = model.newTrainer(config);                                                           
        trainer.initialize(shape);                                                                            
        int epoch = 2;                                                                                        
        EasyTrain.fit(trainer,epoch,dataset,null);                                                            
                                                                 
                                                                                                              
        Path modelDir = Paths.get("build/mlp");                                                               
        Files.createDirectories(modelDir);                                                                    
                                                                                                              
        model.setProperty("Epoch", String.valueOf(epoch));                                                    
                                                                                                              
        model.save(modelDir, "mlp");                                                                          
        try {                                                                                                 
            Use.main(new String[]{});                                                                         
        } catch (MalformedModelException e) {                                                                 
            throw new RuntimeException(e);                                                                    
        }                                                                                                     
                                                                                                              
                                                                                                              
    }                                                                                                         
}                                                                                                             
                                                                                                              

When running it the EasyTrain.fit function gives following error:

Exception in thread "main" ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Check failed: src.Size() == dst->Size() (20 vs. 200) : Cannot reshape array of size 20 into shape [20,10]

The dataset at least seemed to work correctly.

I tried reducing the batch size, but dst.size is allways 10 times src.size

My dataset are multiple 28*28 jpg images


Solution

  • This error seems to be from reshape operation. You mentioned dataSet is working. So most probably it is in the excution of the blocks in the sequence blocks.

    To debug it, you can always first set break point to find at which step this error happens. And from there, given message " Cannot reshape array of size 20 into shape [20,10]", you check whether the size 20 is not as expected or the target shape [20,10] is not as expected.