pythonmachine-learningmachine-learning-model

Model Trainer Issue on End-to-End ML Project - TypeError: initiate_model_training() missing 4 required positional arguments


I am following the process shown on Wine Quality Prediction End-to-End ML Project on Krish Naik's YouTube channel to do a Flight Fare Prediction Project.

I run this cell of model trainer pipeline on 04_model_trainer.ipynb:

try:
    config = ConfigurationManager()
    model_trainer_config = config.get_model_trainer_config()
    model_trainer_config = ModelTrainer(model_trainer_config)
    # model_trainer_config.train()
    model_trainer_config.initiate_model_training()
except Exception as e:
    raise e

I get this error:

TypeError: initiate_model_training() missing 4 required positional arguments: 'X_train', 'X_test', 'y_train', and 'y_test'

Here is the full traceback:

[2023-12-16 21:58:22,484: INFO: common: yaml file: config\config.yaml loaded successfully]
[2023-12-16 21:58:22,493: INFO: common: yaml file: params.yaml loaded successfully]
[2023-12-16 21:58:22,493: INFO: common: yaml file: schema.yaml loaded successfully]
[2023-12-16 21:58:22,493: INFO: common: created directory at: artifacts]
[2023-12-16 21:58:22,493: INFO: common: created directory at: artifacts/model_trainer]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[25], line 8
      6     model_trainer_config.initiate_model_training()
      7 except Exception as e:
----> 8     raise e

Cell In[25], line 6
      4     model_trainer_config = ModelTrainer(model_trainer_config)
      5     # model_trainer_config.train()
----> 6     model_trainer_config.initiate_model_training()
      7 except Exception as e:
      8     raise e

TypeError: initiate_model_training() missing 4 required positional arguments: 'X_train', 'X_test', 'y_train', and 'y_test'

Here is the code of ModelTrainer class:

class ModelTrainer:

    def __init__(self, model_trainer_config):
        self.model_trainer_config = model_trainer_config  
    # def __init__(self):
        # self.model_trainer_config = ModelTrainerConfig()    


    def save_obj(file_path, obj):
        try:

            dir_path = os.path.dirname(file_path)
            os.makedirs(dir_path, exist_ok=True)

            with open(file_path, 'wb') as file_obj:
                joblib.dump(obj, file_obj, compress= ('gzip'))

        except Exception as e:
            logger.info('Error occured in utils save_obj')
            raise e
        

    def evaluate_model(X_train, y_train, X_test, y_test, models):

        try:
            report = {}
            for i in range(len(models)):

                model = list(models.values())[i]

                # Train model
                model.fit(X_train,y_train)

                # Predict Testing data
                y_test_pred = model.predict(X_test)

                # Get R2 scores for train and test data
                test_model_score = r2_score(y_test,y_test_pred)

                report[list(models.keys())[i]] =  test_model_score

            return report

        except Exception as e:
            logger.info('Exception occured during model training')
            raise e    



    def initiate_model_training(self, X_train, X_test, y_train, y_test):
        try:
            logger.info('Splitting ')

            models={
            'LinearRegression':LinearRegression(),
            'Lasso':Lasso(),
            'Ridge':Ridge(),
            'Elasticnet':ElasticNet(),
            'RandomForestRegressor': RandomForestRegressor(),
            'GradientBoostRegressor()' : GradientBoostingRegressor(),
            "AdaBoost" : AdaBoostRegressor(),
            'DecisionTreeRegressor' : DecisionTreeRegressor(),
            "SupportVectorRegressor" : SVR(),
            "KNN" : KNeighborsRegressor()
            }

            model_report:dict = ModelTrainer.evaluate_model(X_train,y_train, X_test, y_test, models)
            print(model_report)
            print("\n====================================================================================")
            logger.info(f'Model Report : {model_report}')

            # to get best model score from dictionary
            best_model_score = max(sorted(model_report.values()))

            best_model_name = list(model_report.keys())[
                list(model_report.values()).index(best_model_score)
            ]

            best_model = models[best_model_name]

            print(f"Best Model Found, Model Name :{best_model_name}, R2-score: {best_model_score}")
            print("\n====================================================================================")
            logger.info(f"Best Model Found, Model name: {best_model_name}, R2-score: {best_model_score}")
            logger.info(f"{best_model.feature_names_in_}")
            
            ModelTrainer.save_obj(
            file_path = self.model_trainer_config.trained_model_file_path,
            obj = best_model
            )

        except Exception as e:
            logger.info('Exception occured at model trianing')
            raise e

Here is my file in GitHub.

My file encoding is UTF-8

Would you please help me to fix this issue?


Solution

  • So based on the logic that you have in your jupyter notebook, you should likely have initial_model_training() modified to look like the below function:

        def initiate_model_training(self): # removing the required variables to be passed into the function because those variables are created below (assuming they were correctly generated in train() )
            # lines below taken from your commented out train() function
            train_data = pd.read_csv(self.config.train_data_path)
            test_data = pd.read_csv(self.config.test_data_path)
    
            X_train = train_data.drop([self.config.target_column], axis=1)
            X_test = test_data.drop([self.config.target_column], axis=1)
            y_train = train_data[[self.config.target_column]]
            y_test = test_data[[self.config.target_column]]
            # lines above taken from your commented out train() function
    
            try:
                logger.info('Splitting ')
    
                models={
                'LinearRegression':LinearRegression(),
                'Lasso':Lasso(),
                'Ridge':Ridge(),
                'Elasticnet':ElasticNet(),
                'RandomForestRegressor': RandomForestRegressor(),
                'GradientBoostRegressor()' : GradientBoostingRegressor(),
                "AdaBoost" : AdaBoostRegressor(),
                'DecisionTreeRegressor' : DecisionTreeRegressor(),
                "SupportVectorRegressor" : SVR(),
                "KNN" : KNeighborsRegressor()
                }
    
                model_report:dict = ModelTrainer.evaluate_model(X_train,y_train, X_test, y_test, models)
                print(model_report)
                print("\n====================================================================================")
                logger.info(f'Model Report : {model_report}')
    
                # to get best model score from dictionary
                best_model_score = max(sorted(model_report.values()))
    
                best_model_name = list(model_report.keys())[
                    list(model_report.values()).index(best_model_score)
                ]
    
                best_model = models[best_model_name]
    
                print(f"Best Model Found, Model Name :{best_model_name}, R2-score: {best_model_score}")
                print("\n====================================================================================")
                logger.info(f"Best Model Found, Model name: {best_model_name}, R2-score: {best_model_score}")
                logger.info(f"{best_model.feature_names_in_}")
                
                ModelTrainer.save_obj(
                file_path = self.model_trainer_config.trained_model_file_path,
                obj = best_model
                )
    
            except Exception as e:
                logger.info('Exception occured at model trianing')
                raise e
    

    you should look through and understand how these https://drive.google.com/file/d/1c7k8i1l2X_r9i4yWAkQzxiP1Nu8_wqap/view these files work (from the tutorial that you are following), before you start to modify the code.