I'm studying Deeplearning4j (ver. 1.0.0-M1.1) to build neural networks.
I'm using the IrisClassifier example from Deeplearning4j:
//First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0;
char delimiter = ',';
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
recordReader.initialize(new FileSplit(new File(DownloaderUtility.IRISDATA.Download(),"iris.txt")));
//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int labelIndex = 4; //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
int numClasses = 3; //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
int batchSize = 150; //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();
allData.shuffle();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); //Use 65% of data for training
DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();
//We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(trainingData); //Apply normalization to the training data
normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set
final int numInputs = 4;
int outputNum = 3;
long seed = 6;
log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.activation(Activation.TANH)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(0.1))
.l2(1e-4)
.list()
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
.build())
.layer(new DenseLayer.Builder().nIn(3).nOut(3)
.build())
.layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
.nIn(3).nOut(outputNum).build())
.build();
//run the model
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
//record score once every 100 iterations
model.setListeners(new ScoreIterationListener(100));
for(int i=0; i<1000; i++ ) {
model.fit(trainingData);
}
//evaluate the model on the test set
Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatures());
eval.eval(testData.getLabels(), output);
log.info(eval.stats());
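The normalization step in the code above (fit on the training data, then transform both training and test data with the same statistics) can be sketched in plain Java. This is only an illustration of the idea, not the DL4J implementation: the class and method names are hypothetical, and the use of the population standard deviation (dividing by N) is an assumption that may differ from what NormalizerStandardize does internally.

```java
import java.util.Arrays;

// Plain-Java sketch of standardization (zero mean, unit variance per column).
// Not the DL4J implementation; class and method names are hypothetical.
class StandardizeSketch {

    // Column-wise mean over all rows.
    static double[] columnMeans(double[][] rows) {
        double[] mean = new double[rows[0].length];
        for (double[] row : rows) {
            for (int j = 0; j < row.length; j++) {
                mean[j] += row[j] / rows.length;
            }
        }
        return mean;
    }

    // Column-wise population standard deviation (divides by N; an assumption of this sketch).
    static double[] columnStdDevs(double[][] rows, double[] mean) {
        double[] std = new double[mean.length];
        for (double[] row : rows) {
            for (int j = 0; j < row.length; j++) {
                double d = row[j] - mean[j];
                std[j] += d * d / rows.length;
            }
        }
        for (int j = 0; j < std.length; j++) {
            std[j] = Math.sqrt(std[j]);
        }
        return std;
    }

    // Standardize one row using statistics computed elsewhere (here: from the training rows).
    static double[] standardize(double[] row, double[] mean, double[] std) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) {
            out[j] = (row[j] - mean[j]) / std[j];
        }
        return out;
    }

    public static void main(String[] args) {
        // Three feature rows from the iris sample shown in this post, labels removed.
        double[][] train = {
            {5.1, 3.5, 1.4, 0.2},
            {7.0, 3.2, 4.7, 1.4},
            {6.3, 3.3, 6.0, 2.5}
        };
        double[] mean = columnMeans(train);
        double[] std = columnStdDevs(train, mean);

        // Illustrative unseen row: it is scaled with the TRAINING statistics, not its own.
        double[] newRow = {5.9, 3.0, 5.1, 1.8};
        System.out.println(Arrays.toString(standardize(newRow, mean, std)));
    }
}
```

The point of the sketch is that the mean and standard deviation are computed once from the training rows and then reused for every other row, which is what the fit/transform split in the code above expresses.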
The inputs for training the model look like this:
5.1,3.5,1.4,0.2,0
...
7.0,3.2,4.7,1.4,1
...
6.3,3.3,6.0,2.5,2
where the last item in each row is the class label for that row's inputs.
It works great: the model trains and is evaluated on the test set.
Now I want to use the trained model to predict the classes of new inputs, but I don't understand how to do it.
OK, I can save the model and load it again:
// Save the Model
File locationToSave = new File("C:/Projects/deeplearning4j/trained_iris_model.zip");
ModelSerializer.writeModel(model, locationToSave, false);
// Open the model
File locationToLoad = new File("C:/Projects/deeplearning4j/trained_iris_model.zip");
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(locationToLoad);
Next, as an example, I load the same data that was used for training, but without the class column:
5.1,3.5,1.4,0.2
...
7.0,3.2,4.7,1.4
...
6.3,3.3,6.0,2.5
int numLinesToSkip = 0;
char delimiter = ',';
CSVRecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter); //skip no lines at the top - i.e. no header
recordReader.initialize(new FileSplit(new File("C:/Projects/deeplearning4j/iris-to-predict.txt")));
But what next?
How can I get a prediction?
Thanks!
So, adding this code solved my problem:
int batchSize = 150;
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize);
DataSet allData = iterator.next();
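// Caveat: fitting a new normalizer here computes mean/stdev from the data being predicted,
// not from the training set. The statistics are similar in this example only because the
// file holds the same iris rows. For genuinely new inputs, reuse the normalizer fitted on
// the training data (ModelSerializer can store one with the model, e.g. via the writeModel
// overload that takes a DataNormalization, and read it back with restoreNormalizerFromFile).
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(allData);
normalizer.transform(allData);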
// Run the network: each output row holds the softmax probabilities for classes 0, 1 and 2
INDArray output = model.output(allData.getFeatures());
// Output
System.out.println(output);
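The printed output is a matrix of class probabilities, one row per input. To turn each row into a single predicted class, take the index of the largest value. Below is a minimal plain-Java sketch of that step. The class name, method names and sample probabilities are made up for illustration. As far as I know, DL4J also provides this directly through model.predict(features) or output.argMax(1), but check that against your version.

```java
import java.util.Arrays;

// Plain-Java sketch: convert softmax output rows into class indices.
// Class and method names are hypothetical; this is not DL4J code.
class ArgMaxSketch {

    // Index of the largest value in one row of probabilities.
    static int argMax(double[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return best;
    }

    // One predicted class index per output row.
    static int[] predictClasses(double[][] output) {
        int[] classes = new int[output.length];
        for (int r = 0; r < output.length; r++) {
            classes[r] = argMax(output[r]);
        }
        return classes;
    }

    public static void main(String[] args) {
        // Made-up probabilities standing in for the rows of the network output.
        double[][] output = {
            {0.97, 0.02, 0.01},
            {0.05, 0.80, 0.15},
            {0.01, 0.09, 0.90}
        };
        System.out.println(Arrays.toString(predictClasses(output)));
    }
}
```

With the iris data, index 0, 1 or 2 corresponds to the integer class labels used in the last column of the training file.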