Tags: python, google-cloud-automl, google-ai-platform

Get feature importance from trained Vertex AI Tabular regression model using Python


I am working with models trained using Vertex AI's Tabular AutoML in GCP. Training and batch predictions work fine. I am trying to use the feature importance values in visualizations and want to retrieve them from within Python. I can get to the model:

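from google.cloud import aiplatform

client_options = {"api_endpoint": "us-central1-aiplatform.googleapis.com"}  # regional endpoint
client = aiplatform.gapic.ModelServiceClient(client_options=client_options)
name = client.model_path(project=project, location='us-central1', model=modelnum)
response = client.get_model(name=name)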

But I cannot figure out how to get the trained model's feature importances that were generated in the training pipeline. I can see them on the model page but can't access them from Python.


Solution

  • To get the details shown on the "Evaluate" page, use list_model_evaluations(). It returns a google.cloud.aiplatform_v1.services.model_service.pagers.ListModelEvaluationsPager containing the values you see on the "Evaluate" page. Since you want the feature importance, loop through that object and read model_explanation from each evaluation. See the code below:

    from google.cloud import aiplatform_v1 as aiplatform
    
    api_endpoint = 'us-central1-aiplatform.googleapis.com'
    client_options = {"api_endpoint": api_endpoint} # api_endpoint is required for client_options
    client_model = aiplatform.services.model_service.ModelServiceClient(client_options=client_options)
    project_id = 'your-project-id'
    location = 'us-central1'
    model_id = '9999999999999'
    
    model_name = f'projects/{project_id}/locations/{location}/models/{model_id}'
    list_eval_request = aiplatform.types.ListModelEvaluationsRequest(parent=model_name)
    list_eval = client_model.list_model_evaluations(request=list_eval_request)
    for val in list_eval:
        print(val.model_explanation)
    

    For testing, I used Google's sample data (gs://cloud-ml-tables-data/bank-marketing.csv).

    Response from the code:

    mean_attributions {
      feature_attributions {
        struct_value {
          fields {
            key: "Age"
            value {
              number_value: 0.027145349596062344
            }
          }
          fields {
            key: "Balance"
            value {
              number_value: 0.009469658279914696
            }
          }
          fields {
            key: "Campaign"
            value {
              number_value: 0.009621628534664564
            }
          }
          fields {
            key: "Contact"
            value {
              number_value: 0.006477007587775141
            }
          }
          fields {
            key: "Day"
            value {
              number_value: 0.013976069802316006
            }
          }
          fields {
            key: "Default"
            value {
              number_value: 1.528606850783311e-08
            }
          }
          fields {
            key: "Duration"
            value {
              number_value: 0.1395725763431482
            }
          }
          fields {
            key: "Education"
            value {
              number_value: 0.007015091678270283
            }
          }
          fields {
            key: "Housing"
            value {
              number_value: 0.055101036115872845
            }
          }
          fields {
            key: "Job"
            value {
              number_value: 0.021222775094579954
            }
          }
          fields {
            key: "Loan"
            value {
              number_value: 0.002048753814978598
            }
          }
          fields {
            key: "MaritalStatus"
            value {
              number_value: 0.005709941134721149
            }
          }
          fields {
            key: "Month"
            value {
              number_value: 0.12325089337437695
            }
          }
          fields {
            key: "PDays"
            value {
              number_value: 0.023952343173674555
            }
          }
          fields {
            key: "POutcome"
            value {
              number_value: 0.06695149606670256
            }
          }
          fields {
            key: "Previous"
            value {
              number_value: 0.03921166116430856
            }
          }
        }
      }
    }
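
The printed protobuf above is awkward to use directly. A minimal sketch for turning it into a ranked list follows. It assumes that `mean_attributions` is a repeated field and that `feature_attributions` behaves like a mapping from feature name to number when read through the Python client, as the output above suggests. The helper names `rank_features` and `fetch_mean_attributions` are hypothetical and not part of the library.

```python
from typing import Dict, List, Mapping, Tuple


def rank_features(attributions: Mapping[str, float]) -> List[Tuple[str, float]]:
    """Sort feature attributions from most to least important."""
    return sorted(attributions.items(), key=lambda item: item[1], reverse=True)


def fetch_mean_attributions(model_name: str, api_endpoint: str) -> List[Dict[str, float]]:
    """Hypothetical helper: one dict of mean attributions per attribution entry.

    Needs google-cloud-aiplatform and valid credentials, so it is not called below.
    """
    from google.cloud import aiplatform_v1  # imported lazily on purpose

    client = aiplatform_v1.ModelServiceClient(
        client_options={"api_endpoint": api_endpoint}
    )
    results = []
    for evaluation in client.list_model_evaluations(parent=model_name):
        for attribution in evaluation.model_explanation.mean_attributions:
            # Assumption: feature_attributions is exposed as a dict-like object.
            results.append(dict(attribution.feature_attributions))
    return results


# A few values copied from the response above, to show the ranking step offline.
sample = {
    "Age": 0.027145349596062344,
    "Duration": 0.1395725763431482,
    "Month": 0.12325089337437695,
    "Default": 1.528606850783311e-08,
}
ranked = rank_features(sample)
print(ranked[0][0])
```

Keeping the ranking step separate from the API call means it can be checked without credentials and reused on any dict of attributions.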
    

    From the "Evaluate" page: [screenshot not preserved]

    EDIT (2021-09-20):

    I repeated this with my regression model, fetching the data with the aiplatform library, and the response still contains the model_explanation attribute. The library version is google-cloud-aiplatform==1.4.3.

    Code used:

    from google.cloud import aiplatform
    
    api_endpoint = 'us-central1-aiplatform.googleapis.com'
    client_options = {"api_endpoint": api_endpoint}
    client_model = aiplatform.gapic.ModelServiceClient(client_options=client_options)
    #client_model = aiplatform.services.model_service.ModelServiceClient(client_options=client_options)
    
    project_id = 'your-project-id'
    location = 'us-central1'
    model_id = '999999999'
    
    model_name = f'projects/{project_id}/locations/{location}/models/{model_id}'
    list_eval = client_model.list_model_evaluations(parent=model_name)
    print(list_eval)
    

    Full response (protobuf text format):

    ListModelEvaluationsPager<model_evaluations {
      name: "projects/xxxxxxx/locations/us-central1/models/99999999/evaluations/8888888"
      metrics_schema_uri: "gs://google-cloud-aiplatform/schema/modelevaluation/regression_metrics_1.0.0.yaml"
      metrics {
        struct_value {
          fields {
            key: "meanAbsoluteError"
            value {
              number_value: 0.1303236
            }
          }
          fields {
            key: "meanAbsolutePercentageError"
            value {
              number_value: 9.991856
            }
          }
          fields {
            key: "rSquared"
            value {
              number_value: 0.39691383
            }
          }
          fields {
            key: "rootMeanSquaredError"
            value {
              number_value: 0.24697715
            }
          }
          fields {
            key: "rootMeanSquaredLogError"
            value {
              number_value: 0.10037828
            }
          }
        }
      }
      create_time {
        seconds: 1632106497
        nanos: 416614000
      }
      model_explanation {
        mean_attributions {
          feature_attributions {
            struct_value {
              fields {
                key: "Age"
                value {
                  number_value: 0.033690840005874634
                }
              }
              fields {
                key: "Balance"
                value {
                  number_value: 0.021756498143076897
                }
              }
              fields {
                key: "Campaign"
                value {
                  number_value: 0.03156016394495964
                }
              }
              fields {
                key: "Contact"
                value {
                  number_value: 0.09849491715431213
                }
              }
              fields {
                key: "Day"
                value {
                  number_value: 0.08989512920379639
                }
              }
              fields {
                key: "Default"
                value {
                  number_value: 0.00012870959471911192
                }
              }
              fields {
                key: "Duration"
                value {
                  number_value: 0.3097792863845825
                }
              }
              fields {
                key: "Education"
                value {
                  number_value: 0.01789841242134571
                }
              }
              fields {
                key: "Housing"
                value {
                  number_value: 0.05525226518511772
                }
              }
              fields {
                key: "Job"
                value {
                  number_value: 0.010000345297157764
                }
              }
              fields {
                key: "Loan"
                value {
                  number_value: 0.00856288243085146
                }
              }
              fields {
                key: "MaritalStatus"
                value {
                  number_value: 0.01715957187116146
                }
              }
              fields {
                key: "Month"
                value {
                  number_value: 0.22002224624156952
                }
              }
              fields {
                key: "PDays"
                value {
                  number_value: 0.026749607175588608
                }
              }
              fields {
                key: "POutcome"
                value {
                  number_value: 0.05268073454499245
                }
              }
              fields {
                key: "Previous"
                value {
                  number_value: 0.00636840146034956
                }
              }
            }
          }
        }
      }
    }
    >
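
Printing the pager dumps every evaluation in one blob, as above. As a sketch, each evaluation could instead be converted to a plain dict with proto-plus's `to_dict` class method and then reduced to the parts of interest. This assumes the resulting dict keeps snake_case field names and renders `metrics` and `feature_attributions` as ordinary dicts. `split_evaluation` is a hypothetical helper, and the stand-in dict is hand-built from a few numbers in the response above.

```python
from typing import Any, Dict, Mapping, Tuple


def split_evaluation(
    eval_dict: Mapping[str, Any]
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Return (metrics, mean feature attributions) from one evaluation dict."""
    metrics = dict(eval_dict.get("metrics") or {})
    explanation = eval_dict.get("model_explanation") or {}
    mean_attributions = explanation.get("mean_attributions") or []
    features: Dict[str, float] = {}
    if mean_attributions:
        # Only the first attribution entry is used; the response above
        # contains exactly one for this regression model.
        features = dict(mean_attributions[0].get("feature_attributions") or {})
    return metrics, features


# With a live client (not run here), the dict would come from something like:
#     for evaluation in client_model.list_model_evaluations(parent=model_name):
#         eval_dict = type(evaluation).to_dict(evaluation)

# Hand-built stand-in using a few numbers from the regression response.
example = {
    "name": "projects/xxxxxxx/locations/us-central1/models/99999999/evaluations/8888888",
    "metrics": {"meanAbsoluteError": 0.1303236, "rSquared": 0.39691383},
    "model_explanation": {
        "mean_attributions": [
            {
                "feature_attributions": {
                    "Duration": 0.3097792863845825,
                    "Month": 0.22002224624156952,
                }
            }
        ]
    },
}
metrics, features = split_evaluation(example)
print(metrics["rSquared"], max(features, key=features.get))
```

Working on plain dicts keeps the extraction logic independent of the client version, which may matter if field access differs between releases.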
    

    From the "Evaluate" page: [screenshot not preserved]
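
Since the question was about using these values in visualizations, here is a dependency-free sketch that renders attributions as a text bar chart. The `text_bar_chart` helper and the default width are illustrative assumptions. The same ranked values could be passed to any plotting library instead.

```python
from typing import List, Mapping


def text_bar_chart(attributions: Mapping[str, float], width: int = 40) -> List[str]:
    """Render attributions as text bars, scaled so the largest value fills `width`."""
    if not attributions:
        return []
    largest = max(attributions.values())
    label_width = max(len(name) for name in attributions)
    lines = []
    ordered = sorted(attributions.items(), key=lambda kv: kv[1], reverse=True)
    for name, value in ordered:
        # Bars are proportional to the largest attribution, not to the sum.
        bar_length = round(width * value / largest) if largest > 0 else 0
        lines.append(f"{name:<{label_width}} | {'#' * bar_length} {value:.4f}")
    return lines


# Three values copied from the regression response above.
sample = {
    "Duration": 0.3097792863845825,
    "Month": 0.22002224624156952,
    "Contact": 0.09849491715431213,
}
for line in text_bar_chart(sample):
    print(line)
```

Scaling against the largest value keeps the top feature at full width, which makes relative differences easy to read in a terminal or log.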