python naivebayes py-bnlearn

How to predict target variable with Naive Bayes from evidence not present in training set?


I'm working on a uni project and the deadline is coming soon. I was assigned a naive Bayes classification problem using bnlearn. After reading the dataset I do the usual train-test split, but the split always leaves the test set with some rows containing values (not columns) that are not present in the training set. For example, the training set might contain ages ranging from 20 to 60 but be missing the value 40.

When I try to predict from the test set and it contains a row with age = 40, prediction stops with the error "KeyError: 40.0", even if that is the only missing value; I expected the model to ignore it. The value was present in the original dataset, but after the split it ended up only in the test set, so the model knows nothing about it and fails to predict. This happens for multiple values over multiple attributes.

I've tried stratifying the split, with no success. I've also let the model learn and predict from the WHOLE dataset, but as you may guess it overfit. How should I handle previously unseen evidence in a naive Bayes model? Do I just skip to the next row of the test set? I fear this might leave me with a very limited model, incapable of drawing any solid conclusions.
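For reference, the values that trigger the KeyError can be listed before predicting. This is only a minimal sketch with toy data; `unseen_values` is a hypothetical helper written for illustration and is not part of bnlearn.

```python
import pandas as pd


def unseen_values(train, test):
    """Return {column: sorted values found in test but never in train}."""
    report = {}
    for col in test.columns:
        missing = set(test[col].dropna()) - set(train[col].dropna())
        if missing:
            report[col] = sorted(missing)
    return report


# Toy data (made up): age 40 only appears in the test split.
train = pd.DataFrame({"age": [20, 30, 50, 60], "sex": ["m", "f", "f", "m"]})
test = pd.DataFrame({"age": [40, 30], "sex": ["f", "m"]})

print(unseen_values(train, test))
```

Running this on the real split shows which attributes are affected and how many distinct values each one is missing.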


Solution

  • As both Dan Nagle and Ivan suggested, splitting each column into bins groups its values into categories. This makes it more likely that a value in the test set is accepted by the model, since the model no longer reads the value itself but its bin/category, which is shared with other values present in the training set. After binning the columns with the most distinct values (more than 10), the model's accuracy went from 24% to 64%.

    The code for this is:

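    import pandas as pd

    # x: the feature DataFrame; names: the column names to check
    for name in names:
        # Bin only the columns with more than 10 distinct values
        if x[name].nunique() > 10:
            # 10 equal-width bins, labelled 0..9
            x[name] = pd.cut(x[name], bins=10, labels=list(range(10)),
                             include_lowest=True)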
    

    Thanks everyone!
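
One caveat with binning the whole dataset before the split is that the bin edges are then computed from test rows too. A possible variant, sketched here under assumptions, computes the edges from the training column only and reuses them for the test column, with the outer bins opened to infinity so out-of-range test values still land in a bin. The helpers `fit_bin_edges` and `apply_bins` and the toy age values are hypothetical and only for illustration.

```python
import numpy as np
import pandas as pd


def fit_bin_edges(train_col, n_bins=10):
    """Compute equal-width bin edges from the training column only."""
    _, edges = pd.cut(train_col, bins=n_bins, retbins=True, include_lowest=True)
    edges = np.asarray(edges, dtype=float).copy()
    # Open the outer bins so values outside the training range still get a bin.
    edges[0], edges[-1] = -np.inf, np.inf
    return edges


def apply_bins(col, edges):
    """Map values to integer bin codes using precomputed edges."""
    return pd.cut(col, bins=edges, labels=False, include_lowest=True)


# Toy data (made up): 40, 75 and 18 never appear in the training column.
train_age = pd.Series([20, 25, 31, 38, 44, 52, 60])
test_age = pd.Series([40, 75, 18])

edges = fit_bin_edges(train_age, n_bins=4)
train_binned = apply_bins(train_age, edges)
test_binned = apply_bins(test_age, edges)

# Every bin code in the test column was also seen during training.
print(set(test_binned) <= set(train_binned))
```

A test value can still fall into a bin that is empty in the training data, so fewer, wider bins may be needed for sparse columns.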