deeplearning4j dl4j

ModelSerializer.writeModel and saveUpdater param


According to the description from the link: https://github.com/deeplearning4j/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/features/modelsavingloading/SaveLoadMultiLayerNetwork.java

boolean saveUpdater = true;
//Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to train your network more in the future

With saveUpdater = true, I expected the zip archive to contain three files: configuration.json, coefficients.bin and updaterState.bin.

But in my case the zip contains only configuration.json and coefficients.bin.

What am I missing?

Below is the IrisClassifierLearning example from the developer's site.

//First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0;
char delimiter = ',';
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
recordReader.initialize(new FileSplit(new File("iris.txt")));

//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();
allData.shuffle();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();

//We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(trainingData);     //Apply normalization to the training data
normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set

final int numInputs = 4;
int outputNum = 3;
long seed = 6;

log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .activation(Activation.TANH)
        .weightInit(WeightInit.XAVIER)
        .updater(new Sgd(0.1))
        .l2(1e-4)
        .list()
        .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
                .build())
        .layer(new DenseLayer.Builder().nIn(3).nOut(3)
                .build())
        .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
                .nIn(3).nOut(outputNum).build())
        .build();

//run the model
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
//record score once every 100 iterations
model.setListeners(new ScoreIterationListener(100));

for(int i=0; i<1000; i++ ) {
    model.fit(trainingData);
}

//evaluate the model on the test set
Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatures());

System.out.println("--- Output ---");
System.out.println(output);

System.out.println("--- Labels ---");
System.out.println(testData.getLabels());

eval.eval(testData.getLabels(), output);
log.info(eval.stats());

// Save the Model
File locationToSave = new File("trained_iris_model.zip");
ModelSerializer.writeModel(model, locationToSave, true);

Solution

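  • SGD does not have any state to save, so with .updater(new Sgd(0.1)) there is nothing to write and updaterState.bin is omitted even when saveUpdater = true. The file only appears when the updater actually keeps state between iterations, as Adam, AdaGrad, RMSProp and momentum-based updaters do. If you need updater state in the archive, configure one of those instead of plain SGD.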
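
To check which entries a saved model archive actually contains, the zip can be listed with the standard java.util.zip classes. The following is a minimal sketch, not part of DL4J: the class name ModelZipInspector and its helper methods are hypothetical, and the entry name updaterState.bin is taken from the question above.

```java
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

// Hypothetical helper (not a DL4J class): lists the entries of a saved model zip.
public class ModelZipInspector {

    // Entry name as given in the question; assumed to match what ModelSerializer writes.
    static final String UPDATER_ENTRY = "updaterState.bin";

    /** Returns the names of all entries in the given zip archive. */
    static List<String> listEntries(File zip) throws IOException {
        List<String> names = new ArrayList<>();
        try (ZipFile zf = new ZipFile(zip)) {
            Enumeration<? extends ZipEntry> entries = zf.entries();
            while (entries.hasMoreElements()) {
                names.add(entries.nextElement().getName());
            }
        }
        return names;
    }

    /** True if the archive contains an updater state entry. */
    static boolean hasUpdaterState(File zip) throws IOException {
        return listEntries(zip).contains(UPDATER_ENTRY);
    }

    public static void main(String[] args) throws IOException {
        // Defaults to the file name used in the question.
        File zip = new File(args.length > 0 ? args[0] : "trained_iris_model.zip");
        if (!zip.isFile()) {
            System.out.println("No such file: " + zip.getPath());
            return;
        }
        for (String name : listEntries(zip)) {
            System.out.println(name);
        }
        System.out.println("Updater state saved: " + hasUpdaterState(zip));
    }
}
```

Run against the archive from the question, this should report no updater state for the Sgd configuration; after switching to a stateful updater and saving again, updaterState.bin should be listed as well.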