java, deep-learning, dl4j, nd4j

DL4J LSTM not successful


I'm trying to copy the exercise about halfway down the page at this link: https://d2l.ai/chapter_recurrent-neural-networks/sequence.html

The exercise uses a sine function to create 1000 data points between -1 and 1, and uses a recurrent network to approximate the function.

Below is the code I used. I'm going back to study why this isn't working, because it doesn't make much sense to me: I was easily able to approximate this function with a feed-forward network.

    // get data
    ArrayList<DataSet> list = new ArrayList<>();

    DataSet dss = DataSetFetch.getDataSet(Constants.DataTypes.math, "sine", 20, 500, 0, 0);

    DataSet dsMain = dss.copy();

    if (!dss.isEmpty()) {
        list.add(dss);
    }

    if (list.isEmpty()) {
        return;
    }

    // format dataset for recurrent input
    list = DataSetFormatter.formatReccurnent(list, 0);

    // build network: 1 input -> 10 hidden (LSTM) -> 1 output
    int history = 10;
    ArrayList<LayerDescription> ldlist = new ArrayList<>();
    LayerDescription l = new LayerDescription(1, history, Activation.RELU);
    ldlist.add(l);
    LayerDescription ll = new LayerDescription(history, 1, Activation.IDENTITY, LossFunctions.LossFunction.MSE);
    ldlist.add(ll);

    ListenerDescription ld = new ListenerDescription(20, true, false);

    MultiLayerNetwork network = Reccurent.getLstm(ldlist, 123, WeightInit.XAVIER, new RmsProp(), ld);

    // train network
    final List<DataSet> lister = list.get(0).asList();
    DataSetIterator iter = new ListDataSetIterator<>(lister, 50);
    network.fit(iter, 50);
    network.rnnClearPreviousState();

    // test network
    ArrayList<DataSet> resList = new ArrayList<>();
    DataSet result = new DataSet();
    INDArray arr = Nd4j.zeros(lister.size() + 1);
    INDArray holder;

    if (list.size() > 1) {
        // test on training data
        System.err.println("oops");
    } else {
        // test on original or scaled data
        for (int i = 0; i < lister.size(); i++) {
            holder = network.rnnTimeStep(lister.get(i).getFeatures());
            arr.putScalar(i, holder.getFloat(0));
        }
    }

    // add original data
    resList.add(dsMain);

    // add network predictions
    result.setFeatures(dsMain.getFeatures());
    result.setLabels(arr);
    resList.add(result);

    // display
    DisplayData.plot2DScatterGraph(resList);

Can you explain the code I would need for an LSTM network with 1 input, 10 hidden units, and 1 output to approximate a sine function?

I'm not using any normalization, as the function is already in the range -1 to 1, and I'm using each Y value as the feature and the following Y value as the label to train the network.
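
To make that concrete, here is a minimal sketch of the pairing scheme I mean (the 0.1 sampling step is just an illustrative assumption):

    double[] y = new double[1000];
    for (int t = 0; t < y.length; t++) {
        y[t] = Math.sin(t * 0.1); // assumed sampling step, for illustration
    }

    // the feature at step t is y[t]; the label is the next value y[t+1]
    double[] features = new double[y.length - 1];
    double[] labels = new double[y.length - 1];
    for (int t = 0; t < y.length - 1; t++) {
        features[t] = y[t];
        labels[t] = y[t + 1];
    }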

You'll notice I'm building a class that allows for easier construction of nets. I have tried throwing many changes at the problem, but I'm sick of guessing.

Here are some examples of my results. Blue is the data, red is the network output.

[plots: blue = data, red = network output]


Solution

  • This is one of those times where you go from wondering why this wasn't working to wondering how in the hell my original results were as good as they were.

    My failing was not understanding the documentation clearly, and also not understanding BPTT (backpropagation through time).

    With feed-forward networks, each iteration is stored as a row and each input as a column; for example, the data has shape [dataset.size, networkInputs.size].

    However, with recurrent input it's reversed: the data is rank 3, with each row an input and each step along the third dimension an iteration in time, which is what builds up the state of the LSTM's chain of events. At minimum my input needed to be [1, networkInputs.size, dataset.size], but it could also be [dataset.size, networkInputs.size, stateLength].
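
    For a single long sequence, a minimal ND4J sketch of that layout would be something like this (the sampling step and sequence length are illustrative assumptions):

        // one example of shape [miniBatchSize = 1, nIn = 1, timeSteps]
        int timeSteps = 999; // 1000 samples -> 999 (value, next value) pairs
        INDArray features = Nd4j.zeros(1, 1, timeSteps);
        INDArray labels = Nd4j.zeros(1, 1, timeSteps);

        for (int t = 0; t < timeSteps; t++) {
            features.putScalar(new int[]{0, 0, t}, Math.sin(t * 0.1));     // value at time t
            labels.putScalar(new int[]{0, 0, t}, Math.sin((t + 1) * 0.1)); // next value
        }

        DataSet ds = new DataSet(features, labels); // rank-3 features and labels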

    In my previous example I was training the network with data in the format [dataset.size, networkInputs.size, 1], i.e. with a sequence length of only one time step. So, from my low-resolution understanding, the LSTM network should never have worked at all, yet it somehow produced at least something.
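
    And for reference, here is a rough sketch of the 1-in / 10-hidden / 1-out network from the question, using the plain DL4J builders instead of my wrapper classes (on older DL4J versions the layer class is GravesLSTM rather than LSTM):

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .weightInit(WeightInit.XAVIER)
                .updater(new RmsProp())
                .list()
                .layer(0, new LSTM.Builder()
                        .nIn(1).nOut(10)
                        .activation(Activation.TANH)
                        .build())
                .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .nIn(10).nOut(1)
                        .activation(Activation.IDENTITY)
                        .build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.fit(ds); // ds is the rank-3 DataSet from the sketch above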

    There may also have been some issue with converting the dataset to a list, since I also changed how I feed the network, but I think the bulk of the problem was the data structure.

    Below are my new results. Not perfect, but this is only 5 epochs of training, so it's good considering.