
Pre-train a brain.js model


My Question


I just started learning brain.js and built a model that predicts a category from the input text.

Every time I run the script, the model is trained from scratch. Training time depends on the number of iterations: more iterations take longer but improve the model's accuracy.

Is there a way to pre-train the model so the user doesn't have to wait for the output?

An example would really help.

My Code


// data.json

[
  {
    "text": "my unit test failed",
    "category": "software"
  },
  {
    "text": "my driver is working",
    "category": "hardware"
  }
]

const brain = require('brain.js');
const data = require('./data.json');                 // training data loaded from data.json

const network = new brain.recurrent.LSTM();

const trainingData = data.map(item => ({
  input: item.text,
  output: item.category
}));

network.train(trainingData, {
  log: (error) => console.log(error),
  iterations: 1000
});

console.log(network.run('buy me a driver'));         // outputs "hardware"


Solution

  • You can split the script in two. The first script trains the network on the data and saves its state to a JSON file using network.toJSON().

    The second script loads that state from the JSON file using network.fromJSON() and runs the network on new input, so no training happens when the user asks for a prediction.

    train-network.js

    const brain = require('brain.js');
    const data = require('./data.json');    
    const fs = require("fs");
    
    const network = new brain.recurrent.LSTM();
    
    const trainingData = data.map(item => ({
      input: item.text,
      output: item.category
    }));
    
    network.train(trainingData, {
      log: (error) => console.log(error),
      iterations: 1000
    });
    
    // Save network state to JSON file.
    const networkState = network.toJSON();
    fs.writeFileSync("network_state.json", JSON.stringify(networkState), "utf-8");
    

    load-network.js

    const brain = require('brain.js');
    const fs = require("fs");
    
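    const network = new brain.recurrent.LSTM();
    
    // Load the trained network state from the JSON file (no training happens here).
    // readFileSync with an encoding already returns a string, so no .toString() is needed.
    const networkState = JSON.parse(fs.readFileSync("network_state.json", "utf-8"));
    network.fromJSON(networkState);
    
    console.log(network.run('buy me a driver'));   // predicted category, e.g. "hardware"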
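If you would rather keep a single entry point, the two scripts can be combined into a "load or train" helper: train only when no saved state file exists, otherwise load the saved state. This is a minimal sketch, not part of brain.js: `loadOrTrain` and its parameters are hypothetical names, and it assumes the network object exposes `train()`, `toJSON()` and `fromJSON()` the way brain.js recurrent networks such as `brain.recurrent.LSTM` do. The network is created through a factory function, so the helper itself does not depend on brain.js.

```javascript
const fs = require("fs");

/**
 * Hypothetical helper: load a previously trained network from `statePath`
 * if the file exists; otherwise train a new network and save its state there.
 *
 * `createNetwork` is a factory such as `() => new brain.recurrent.LSTM()`.
 * The network is assumed to expose train(), toJSON() and fromJSON().
 */
function loadOrTrain(createNetwork, trainingData, statePath, trainOptions = {}) {
  const network = createNetwork();

  if (fs.existsSync(statePath)) {
    // Fast path: reuse the saved state, no training needed.
    const state = JSON.parse(fs.readFileSync(statePath, "utf-8"));
    network.fromJSON(state);
    return { network, trained: false };
  }

  // Slow path (first run only): train, then persist the state for next time.
  network.train(trainingData, trainOptions);
  fs.writeFileSync(statePath, JSON.stringify(network.toJSON()), "utf-8");
  return { network, trained: true };
}

module.exports = { loadOrTrain };

// Example usage (assumes brain.js is installed and data.json exists):
// const brain = require("brain.js");
// const data = require("./data.json");
// const trainingData = data.map(item => ({ input: item.text, output: item.category }));
// const { network } = loadOrTrain(
//   () => new brain.recurrent.LSTM(), trainingData, "network_state.json", { iterations: 1000 }
// );
// console.log(network.run("buy me a driver"));
```

Only the first run pays the training cost. The helper does not check whether the saved state matches the current data, so if you change data.json, delete network_state.json to force retraining.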