I am studying a text generation example: https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/charmodelling/generatetext/GenerateTxtCharCompGraphModel.java. As I understand it, the output of the LSTM network is a probability distribution: a double array in which each value is the probability of the character whose index matches that position in the array. So I cannot understand the following code, which gets the character index from the distribution:
/** Given a probability distribution over discrete classes, sample from the distribution
 * and return the generated class index.
 * @param distribution Probability distribution over classes. Must sum to 1.0
 */
static int sampleFromDistribution(double[] distribution, Random rng){
    double d = 0.0;
    double sum = 0.0;
    for( int t=0; t<10; t++ ) {
        d = rng.nextDouble();
        sum = 0.0;
        for( int i=0; i<distribution.length; i++ ){
            sum += distribution[i];
            if( d <= sum ) return i;
        }
        //If we haven't found the right index yet, maybe the sum is slightly
        //lower than 1 due to rounding error, so try again.
    }
    //Should be extremely unlikely to happen if distribution is a valid probability distribution
    throw new IllegalArgumentException("Distribution is invalid? d="+d+", sum="+sum);
}
It seems that we are getting a random value. Why don't we just choose the index where the value is highest? And what should I do if I want to select not one, but the two or three most likely next characters?
This function samples from the distribution instead of simply returning the most probable character class.
That means you aren't necessarily getting the most likely character; you are getting a random character, where each character is drawn with the probability that the given distribution assigns to it.
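For comparison, picking the highest value, or the two or three highest values as the question asks, could look roughly like the sketch below. It is not part of the DL4J example: the class name `TopKSelection`, the helpers `argMax` and `topK`, and the probabilities in `main` are all made up for illustration. Note that a greedy choice like this always yields the same continuation for the same network output, whereas sampling does not.

```java
import java.util.Arrays;

// Hypothetical helper class, not taken from the DL4J example.
final class TopKSelection {

    /** Returns the index of the largest probability (the greedy choice). */
    static int argMax(double[] distribution) {
        int best = 0;
        for (int i = 1; i < distribution.length; i++) {
            if (distribution[i] > distribution[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Returns the indices of the k largest probabilities, most likely first. */
    static int[] topK(double[] distribution, int k) {
        Integer[] indices = new Integer[distribution.length];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = i;
        }
        // Sort the indices by descending probability. The sort is stable,
        // so equal probabilities keep their original index order.
        Arrays.sort(indices, (a, b) -> Double.compare(distribution[b], distribution[a]));

        int n = Math.min(k, indices.length);
        int[] result = new int[n];
        for (int i = 0; i < n; i++) {
            result[i] = indices[i];
        }
        return result;
    }

    public static void main(String[] args) {
        // Assumed probabilities for a six-character alphabet a..f.
        double[] distribution = {0.15, 0.15, 0.10, 0.10, 0.30, 0.20};
        System.out.println(argMax(distribution));                   // 4
        System.out.println(Arrays.toString(topK(distribution, 3))); // [4, 5, 0]
    }
}
```

The returned indices can be mapped back to characters in the same way the example maps the sampled index.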
This works by first getting a random value between 0 and 1 from a uniform distribution (rng.nextDouble()) and then finding where that value falls in the given distribution.
You can imagine it to be something like this (if you had only a to f in your alphabet):
[ a | b | c | d | e | f ]
0.0 0.3 0.5 1.0
If the random value that is drawn is just over 0.5, it would produce an e; if it is just less than that, it would be a d.
Each letter occupies a proportional amount of space on this line between 0 and 1 according to the weight it has in the distribution.
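To make the picture concrete, here is a minimal sketch of the same lookup with the random value passed in explicitly. The probabilities are assumed values, chosen only so that the boundary between d and e falls at 0.5 as in the diagram; the class name `IntervalSampling` and the method `indexFor` are hypothetical and not part of the DL4J example.

```java
import java.util.Random;

// Hypothetical demo class, not taken from the DL4J example.
final class IntervalSampling {

    /** Returns the index whose cumulative interval contains d, or -1 if none does. */
    static int indexFor(double[] distribution, double d) {
        double sum = 0.0;
        for (int i = 0; i < distribution.length; i++) {
            sum += distribution[i]; // right edge of the interval for index i
            if (d <= sum) {
                return i;
            }
        }
        return -1; // d lies beyond the total, e.g. because of rounding error
    }

    public static void main(String[] args) {
        char[] alphabet = {'a', 'b', 'c', 'd', 'e', 'f'};
        // Assumed probabilities; the right edges are 0.15, 0.3, 0.4, 0.5, 0.8, 1.0.
        double[] distribution = {0.15, 0.15, 0.10, 0.10, 0.30, 0.20};

        System.out.println(alphabet[indexFor(distribution, 0.51)]); // e
        System.out.println(alphabet[indexFor(distribution, 0.49)]); // d

        // Drawing many uniform values shows that each letter comes up
        // roughly in proportion to the width of its interval.
        Random rng = new Random(12345);
        int draws = 100_000;
        int[] counts = new int[alphabet.length];
        for (int n = 0; n < draws; n++) {
            int idx = indexFor(distribution, rng.nextDouble());
            if (idx >= 0) {
                counts[idx]++;
            }
        }
        for (int i = 0; i < alphabet.length; i++) {
            System.out.printf("%c: %.3f%n", alphabet[i], counts[i] / (double) draws);
        }
    }
}
```

The observed frequencies printed at the end should come out close to the assumed probabilities, which is the sense in which the function samples from the distribution rather than always returning the same character.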