java deep-learning recurrent-neural-network deeplearning4j dl4j

Deeplearning4J RNN Training : Exception 3D input expected to RNN layer expected, got 2


with the following code (tweaked for hours with different params), I keep getting an exception java.lang.IllegalStateException: 3D input expected to RNN layer expected, got 2

What I am trying to accomplish is to train a RNN to predict the next value (double) in a sequence based on a bunch of training sequences. I am generating the features with a simple random data generator, and using the last val in a sequence as the training label (in this case predicted value).

my code:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

import java.util.Random;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;

// Minimal reproduction of the question's failure: trains an RNN on random
// sequences to predict the next value. As written this throws
// java.lang.IllegalStateException ("3D input expected to RNN layer expected, got 2")
// because both the training data and the inference input are built as rank-2
// arrays, while DL4J recurrent layers require rank-3 tensors — see the
// NOTE(review) comments below for the exact lines at fault.
public class RnnPredictionExample {

  public static void main(String[] args) {
    //generate 100 rows of data that have 50 columns/features each
    DataSet trainingdata = getRandomDataset(100, 51, 1);
    // Train the RNN model...
    MultiLayerNetwork trainedModel = trainRnnModel(trainingdata, 50, 10, 1);

    // generate a sequence, and Perform next value prediction on the sequence
    double[] inputSequence = randomData(50, 1);
    double predictedValue = predictNextValue(trainedModel, inputSequence);
    System.out.println("Predicted Next Value: " + predictedValue);
  }

  /**
   * Builds an LSTM -> RnnOutputLayer regression network (MSE loss, identity
   * activation) and fits it on the supplied DataSet for numEpochs passes.
   *
   * NOTE(review): the sequenceLength and numHiddenUnits parameters are never
   * read — the layer sizes are hard-coded (nIn(1)/nOut(50) and nIn(50)/nOut(1)).
   */
  public static MultiLayerNetwork trainRnnModel(DataSet trainingdataandlabels, int sequenceLength, int numHiddenUnits, int numEpochs) {
    // ... Create network configuration ...

    // Create and initialize the network
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            //.seed(123)
            .list()
            .layer(new LSTM.Builder()
                    .nIn(1)
                    .nOut(50)
                    .build()
            )
            .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .activation(Activation.IDENTITY)
                    .nIn(50)
                    .nOut(1) // Set nOut to 1
                    .build()
            )
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(config);
    net.init();

    // One fit() call over the full DataSet per epoch.
    for (int i = 0; i < numEpochs; i++) {
      net.fit(trainingdataandlabels);
    }

    return net;
  }

  /**
   * Feeds one sequence through the trained network and returns the last
   * predicted value.
   *
   * NOTE(review): Nd4j.create(double[]) yields a rank-2 row vector, but
   * rnnTimeStep() requires a rank-3 [miniBatch, features, timeSteps] tensor —
   * this is one of the two places that trigger the "3D input expected" error.
   */
  public static double predictNextValue(MultiLayerNetwork trainedModel, double[] inputSequence) {
    INDArray inputArray = Nd4j.create(inputSequence);
    INDArray predicted = trainedModel.rnnTimeStep(inputArray);

    // Predicted value is the last element of the predicted sequence
    return predicted.getDouble(predicted.length() - 1);
  }

  static Random random = new Random();

  /**
   * Returns an array of `length` uniform random doubles scaled by rangeMultiplier.
   */
  public static double[] randomData(int length, int rangeMultiplier) {

    double[] out = new double[length];
    for (int i = 0; i < out.length; i++) {
      out[i] = random.nextDouble() * rangeMultiplier;
    }
    return out;
  }

  // Assumes the label is the last value in each generated sequence: the first
  // lengthEach-1 values become features, the final value becomes the label.
  // NOTE(review): both arrays here are rank-2 ([numRows, cols]); RNN layers
  // need rank-3 input, which is the root cause of the exception in the question.
  public static DataSet getRandomDataset(int numRows, int lengthEach, int rangeMultiplier) {
    INDArray training = Nd4j.zeros(numRows, lengthEach - 1);
    INDArray labels = Nd4j.zeros(numRows, 1);

    for (int i = 0; i < numRows; i++) {
      double[] randomData = randomData(lengthEach, rangeMultiplier);
      for (int j = 0; j < randomData.length - 1; j++) {
        training.putScalar(new int[]{i, j}, randomData[j]);
      }
      labels.putScalar(new int[]{i, 0}, randomData[randomData.length - 1]);

    }

    return new DataSet(training, labels);

  }
}

thanks

For those interested, I made the changes based on the accepted answer, and here is the working code again in entirety

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

import java.util.Random;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;

// Working version (after applying the accepted answer): the training features,
// the labels, and the inference input are all rank-3 tensors, which is what
// DL4J recurrent layers require. The commented-out lines show the original
// rank-2 code each change replaced.
public class RnnPredictionExample {

  public static void main(String[] args) {
    //generate 100 rows of data that have 50 columns/features each
    DataSet trainingdata = getRandomDataset(100, 51, 1);
    // Train the RNN model...
    MultiLayerNetwork trainedModel = trainRnnModel(trainingdata, 50, 10, 1);

    // generate a sequence, and Perform next value prediction on the sequence
    double[] inputSequence = randomData(50, 1);
    double predictedValue = predictNextValue(trainedModel, inputSequence);
    System.out.println("Predicted Next Value: " + predictedValue);
  }

  /**
   * Builds an LSTM -> RnnOutputLayer regression network (MSE loss, identity
   * activation) and fits it on the supplied DataSet for numEpochs passes.
   *
   * NOTE(review): the sequenceLength and numHiddenUnits parameters are never
   * read — the sizes are hard-coded; nIn(50) must match the 50-element
   * dimension of the tensors built in getRandomDataset/predictNextValue.
   */
  public static MultiLayerNetwork trainRnnModel(DataSet trainingdataandlabels, int sequenceLength, int numHiddenUnits, int numEpochs) {
    // ... Create network configuration ...

    // Create and initialize the network
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            //.seed(123)
            .list()
            .layer(new LSTM.Builder()
                    .nIn(50)
                    .nOut(1)
                    .build()
            )
            .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .activation(Activation.IDENTITY)
                    .nIn(1)
                    .nOut(1) // Set nOut to 1
                    .build()
            )
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(config);
    net.init();

    // One fit() call over the full DataSet per epoch.
    for (int i = 0; i < numEpochs; i++) {
      net.fit(trainingdataandlabels);
    }

    return net;
  }

  /**
   * Feeds one sequence through the trained network and returns the last
   * predicted value. The input is reshaped to rank 3,
   * [1, inputSequence.length, 1] — consistent with the LSTM's nIn(50), the
   * 50-element dimension 1 is the feature dimension.
   */
  public static double predictNextValue(MultiLayerNetwork trainedModel, double[] inputSequence) {
    // INDArray inputArray = Nd4j.create(inputSequence);
    INDArray inputArray = Nd4j.create(inputSequence).reshape(1, inputSequence.length, 1);
    INDArray predicted = trainedModel.rnnTimeStep(inputArray);

    // Predicted value is the last element of the predicted sequence
    // (length() is the total element count, so this indexes the flattened output).
    return predicted.getDouble(predicted.length() - 1);
  }

  static Random random = new Random();

  /**
   * Returns an array of `length` uniform random doubles scaled by rangeMultiplier.
   */
  public static double[] randomData(int length, int rangeMultiplier) {

    double[] out = new double[length];
    for (int i = 0; i < out.length; i++) {
      out[i] = random.nextDouble() * rangeMultiplier;
    }
    return out;
  }

  // Assumes the label is the last value in each generated sequence: the first
  // lengthEach-1 values become features, the final value becomes the label.
  // Both tensors are rank 3 now (a trailing singleton dimension was added)
  // so they satisfy the RNN layers' 3D-input requirement.
  public static DataSet getRandomDataset(int numRows, int lengthEach, int rangeMultiplier) {
    //INDArray training = Nd4j.zeros(numRows, lengthEach - 1);
    INDArray training = Nd4j.zeros(numRows, lengthEach - 1, 1);
    //INDArray labels = Nd4j.zeros(numRows, 1);
    INDArray labels = Nd4j.zeros(numRows, 1, 1);

    for (int i = 0; i < numRows; i++) {
      double[] randomData = randomData(lengthEach, rangeMultiplier);
      for (int j = 0; j < randomData.length - 1; j++) {
        // training.putScalar(new int[]{i, j}, randomData[j]);
        training.putScalar(new int[]{i, j, 0}, randomData[j]);
      }
      //labels.putScalar(new int[]{i, 0}, randomData[randomData.length - 1]);
      labels.putScalar(new int[]{i, 0, 0}, randomData[randomData.length - 1]);
    }

    return new DataSet(training, labels);

  }
}

Solution

  • RNNs expect sequences of data, so the features and labels must be structured as 3D tensors. Note that in this code the feature count sits on dimension 1 — the reshaped input [1, 50, 1] has its 50-element dimension matching the LSTM's nIn(50) — so the dimension order here is (batchSize, numFeatures, timeSeriesLength) rather than (batchSize, sequenceLength, numFeatures).

    However, the generated random data is only 2D, and you need to convert it into the appropriate 3D format.

    Please modify as below.

    return predicted.getDouble(predicted.length() - 1);
    return predicted.getDouble(0, predicted.size(1) - 1, 0);
    
    INDArray inputArray = Nd4j.create(inputSequence);
    INDArray inputArray = Nd4j.create(inputSequence).reshape(1, inputSequence.length, 1);
    
    INDArray training = Nd4j.zeros(numRows, lengthEach - 1);
    INDArray training = Nd4j.zeros(numRows, lengthEach - 1, 1);
    
    INDArray labels = Nd4j.zeros(numRows, 1);
    INDArray labels = Nd4j.zeros(numRows, 1, 1);
    
    training.putScalar(new int[]{i, j}, randomData[j]);
    training.putScalar(new int[]{i, j, 0}, randomData[j]);
    
    labels.putScalar(new int[]{i, 0}, randomData[randomData.length - 1]);
    labels.putScalar(new int[]{i, 0, 0}, randomData[randomData.length - 1]);