python, list, scikit-learn, pprint

How to print the many values of a hyperparameter row-wise (along a line) rather than one per line when "pprint()" is used during a RandomizedSearchCV


When fitting my random forest model during a randomized search (with sklearn), I would like to print the candidate values of the hyperparameter "n_estimators":

from pprint import pprint

from sklearn.ensemble import RandomForestClassifier

# build a classifier
clf_rf = RandomForestClassifier()

# Set up the hyperparameter search
param_dist = {"max_depth": [3, None],
              "n_estimators": list(range(10, 200)),
              "max_features": list(range(1, X_test.shape[1]+1)),
              "min_samples_split": list(range(2, 11)),
              "min_samples_leaf": list(range(1, 11)),
              "bootstrap": [True, False],
              "criterion": ["gini", "entropy"]}

pprint(param_dist)  # show every hyperparameter candidate defined above

As you can see, n_estimators is list(range(10, 200)), so it holds many values.

Ideally, I would like the values to be printed like this: [10, 11, 12, 13, 14, 15, etc ...] -> along a row

Instead of something like:

'n_estimators': [10,
                  11,
                  12,
                  13,
                  14,
                  15,
                  16,
                  17,
                  etc
                    ]

Do you know how I can print the n_estimators list row-wise with pprint, so that it appears in the same fashion as the other entries here:

{'bootstrap': [True, False],
 'criterion': ['gini', 'entropy'],
 'max_depth': [3, None],
 'max_features': [1, 2, 3, 4, 5, 6, 7, 8],
 'min_samples_leaf': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
 'min_samples_split': [2, 3, 4, 5, 6, 7, 8, 9, 10]}

Thank you


Solution

  • To print this in the requested format, pass the width parameter to pprint, like so:

    pprint(param_dist, width=1000)
    
    # Outputs:
    
    # {'bootstrap': [True, False],
    #  'criterion': ['gini', 'entropy'],
    #  'max_depth': [3, None],
    #  'max_features': [1, 2, 3, 4],
    #  'min_samples_leaf': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    #  'min_samples_split': [2, 3, 4, 5, 6, 7, 8, 9, 10],
    #  'n_estimators': [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, ...]}
    

    Note: pprint splits n_estimators across lines because the list has 190 entries, so its one-line representation is far wider than pprint's default width of 80 characters. The ellipsis in the output above is shorthand for the skipped values of n_estimators.
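
The effect of width can be checked without running a search. The sketch below is only an illustration: it rebuilds the question's param_dist with a hard-coded max_features range (an assumption standing in for X_test.shape[1], which is not available here) and uses pformat so the three layouts can be compared as strings. It also tries compact=True, a standard pprint option (Python 3.4+) that was not part of the original answer and that wraps long lists within the width instead of putting one item per line.

```python
from pprint import pformat

# Stand-in for the question's search space; the 8 features are an assumption.
param_dist = {
    "max_depth": [3, None],
    "n_estimators": list(range(10, 200)),
    "max_features": list(range(1, 9)),
    "min_samples_split": list(range(2, 11)),
    "min_samples_leaf": list(range(1, 11)),
    "bootstrap": [True, False],
    "criterion": ["gini", "entropy"],
}

# Default width (80): the 190-item list does not fit, so each item gets its own line.
default_text = pformat(param_dist)

# A width larger than the longest entry keeps every list on a single row.
wide_text = pformat(param_dist, width=1000)

# compact=True packs as many items as fit on each line within the default width.
compact_text = pformat(param_dist, compact=True)

for label, text in [("default", default_text),
                    ("width=1000", wide_text),
                    ("compact=True", compact_text)]:
    print(f"{label}: {len(text.splitlines())} lines")

print(wide_text)
```

If a very wide line is awkward to read in a terminal, compact=True may be the more practical choice, since it keeps lines within the default width while still avoiding one value per line.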