javagoogle-cloud-platformgoogle-cloud-automlautomlgcp-ai-platform-training

Google AutoMl Tables: How to set target column in the dataset using JAVA client library


As per my requirement I need to automate the whole cycle of:-

  1. taking data as a CSV in proper format from the Database.
  2. uploading it into the google bucket
  3. creating a dataset in google AutoML tables
  4. importing my uploaded CSV from bucket to the dataset
  5. training the model based on the created dataset
  6. deploying the model after training

I've successfully completed from step 1 to step 4 and I've also completed the step 6 but I'm facing issues in step 5 where I need to define what is the target column in my CSV file.

To solve this I went through the AutoMl docs but couldn't find any way of doing so via it's (AutoMl Tables) JAVA client library. Here's a code provided in the docs where we need 2 important parameters those are tableSpecId and columnSpecId below is the code from the docs:- https://cloud.google.com/automl-tables/docs/train

import com.google.cloud.automl.v1beta1.AutoMlClient;
import com.google.cloud.automl.v1beta1.ColumnSpec;
import com.google.cloud.automl.v1beta1.ColumnSpecName;
import com.google.cloud.automl.v1beta1.LocationName;
import com.google.cloud.automl.v1beta1.Model;
import com.google.cloud.automl.v1beta1.OperationMetadata;
import com.google.cloud.automl.v1beta1.TablesModelMetadata;
import java.io.IOException;
import java.util.concurrent.ExecutionException;

class TablesCreateModel {

  public static void main(String[] args)
      throws IOException, ExecutionException, InterruptedException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String tableSpecId = "YOUR_TABLE_SPEC_ID";
    String columnSpecId = "YOUR_COLUMN_SPEC_ID";
    String displayName = "YOUR_DATASET_NAME";
    createModel(projectId, datasetId, tableSpecId, columnSpecId, displayName);
  }

  // Create a model
  static void createModel(
      String projectId,
      String datasetId,
      String tableSpecId,
      String columnSpecId,
      String displayName)
      throws IOException, ExecutionException, InterruptedException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (AutoMlClient client = AutoMlClient.create()) {
      // A resource that represents Google Cloud Platform location.
      LocationName projectLocation = LocationName.of(projectId, "us-central1");

      // Get the complete path of the column.
      ColumnSpecName columnSpecName =
          ColumnSpecName.of(projectId, "us-central1", datasetId, tableSpecId, columnSpecId);

      // Build the get column spec.
      ColumnSpec targetColumnSpec =
          ColumnSpec.newBuilder().setName(columnSpecName.toString()).build();

      // Set model metadata.
      TablesModelMetadata metadata =
          TablesModelMetadata.newBuilder()
              .setTargetColumnSpec(targetColumnSpec)
              .setTrainBudgetMilliNodeHours(24000)
              .build();

      Model model =
          Model.newBuilder()
              .setDisplayName(displayName)
              .setDatasetId(datasetId)
              .setTablesModelMetadata(metadata)
              .build();

      // Create a model with the model metadata in the region.
      OperationFuture<Model, OperationMetadata> future =
          client.createModelAsync(projectLocation, model);
      // OperationFuture.get() will block until the model is created, which may take several hours.
      // You can use OperationFuture.getInitialFuture to get a future representing the initial
      // response to the request, which contains information while the operation is in progress.
      System.out.format("Training operation name: %s%n", future.getInitialFuture().get().getName());
      *System*.out.println("Training started...");
    }
  }
}

From these 2 important parameters I'm able to get tableSpecId after successfully importing my CSV into the dataset (step 4) but I'm unable to get columnSpecId As it defines the correlation between the columns and as per my understanding it also defines which column is the target column.

After doing bit more research online I found that the below mentioned REST API sends request to set target column in the dataset https://automl.clients6.google.com/v1beta1/projects/[projectId]/locations/[Location]/datasets/[DatasetId]?updateMask=tablesDatasetMetadata.targetColumnSpecId&key=[auth stuff] But I'm using SpringBoot which means I'm working with AutoMl's JAVA client library. Now in the AutoML's client library I'm unable to find any method using which I can explicitly send a request telling the Google AutoMl Tables that this particular column is my target column. I guess there are methods available in the python client library but not in java. I thank you in advance.

Error(When passing Empty string in columnSpecId):-

com.google.api.gax.rpc.InvalidArgumentException: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
java.util.concurrent.ExecutionException: com.google.api.gax.rpc.InvalidArgumentException: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
    at com.google.common.util.concurrent.AbstractFuture.getDoneValue(AbstractFuture.java:566)
    at com.google.common.util.concurrent.AbstractFuture.get(AbstractFuture.java:547)
    at com.google.common.util.concurrent.FluentFuture$TrustedFuture.get(FluentFuture.java:86)
    at com.google.common.util.concurrent.ForwardingFuture.get(ForwardingFuture.java:62)
    at com.realcoderz.ai.AutoMlTables.TablesCreateModel.createModel(TablesCreateModel.java:76)
    at com.realcoderz.ai.AutoMlTables.TablesCreateModel.getTablesCreateModel(TablesCreateModel.java:29)
    at com.realcoderz.ai.controller.AiMasterController.CreateModel(AiMasterController.java:105)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:64)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:564)
    at org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:197)
    at org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:141)
    at org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:106)
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:894)
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:808)
    at org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)
    at org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:1060)
    at org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:962)
    at org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:1006)
    at org.springframework.web.servlet.FrameworkServlet.doGet(FrameworkServlet.java:898)
    at javax.servlet.http.HttpServlet.service(HttpServlet.java:626)
    at org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:883)
    at javax.servlet.http.HttpServlet.service(HttpServlet.java:733)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:231)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:53)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:100)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:119)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.FormContentFilter.doFilterInternal(FormContentFilter.java:93)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:119)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:201)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:119)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:202)
    at org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:97)
    at org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:542)
    at org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:143)
    at org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:92)
    at org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:78)
    at org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:343)
    at org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:374)
    at org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:65)
    at org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:888)
    at org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1597)
    at org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:49)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
    at org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61)
    at java.base/java.lang.Thread.run(Thread.java:832)
Caused by: com.google.api.gax.rpc.InvalidArgumentException: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
    at com.google.api.gax.rpc.ApiExceptionFactory.createException(ApiExceptionFactory.java:49)
    at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:72)
    at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:60)
    at com.google.api.gax.grpc.GrpcExceptionCallable$ExceptionTransformingFuture.onFailure(GrpcExceptionCallable.java:97)
    at com.google.api.core.ApiFutures$1.onFailure(ApiFutures.java:68)
    at com.google.common.util.concurrent.Futures$CallbackListener.run(Futures.java:1041)
    at com.google.common.util.concurrent.DirectExecutor.execute(DirectExecutor.java:30)
    at com.google.common.util.concurrent.AbstractFuture.executeListener(AbstractFuture.java:1215)
    at com.google.common.util.concurrent.AbstractFuture.complete(AbstractFuture.java:983)
    at com.google.common.util.concurrent.AbstractFuture.setException(AbstractFuture.java:771)
    at io.grpc.stub.ClientCalls$GrpcFuture.setException(ClientCalls.java:563)
    at io.grpc.stub.ClientCalls$UnaryStreamToFuture.onClose(ClientCalls.java:533)
    at io.grpc.internal.DelayedClientCall$DelayedListener$3.run(DelayedClientCall.java:464)
    at io.grpc.internal.DelayedClientCall$DelayedListener.delayOrExecute(DelayedClientCall.java:428)
    at io.grpc.internal.DelayedClientCall$DelayedListener.onClose(DelayedClientCall.java:461)
    at io.grpc.internal.ClientCallImpl.closeObserver(ClientCallImpl.java:617)
    at io.grpc.internal.ClientCallImpl.access$300(ClientCallImpl.java:70)
    at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1StreamClosed.runInternal(ClientCallImpl.java:803)
    at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1StreamClosed.runInContext(ClientCallImpl.java:782)
    at io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
    at io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:123)
    at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:515)
    at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
    at java.base/java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask.run(ScheduledThreadPoolExecutor.java:304)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
    ... 1 more
Caused by: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
    at io.grpc.Status.asRuntimeException(Status.java:533)
    ... 16 more

Here are 2 images for comparison in Image 1 i've sent request via my springboot app where I left columnSpecId Blank and in image 2 I manually set target and trained the model in google cloud console UI Image 1 Image 2


Solution

  • You can get the columnSpecId by calling ListColumnSpecs on the TableSpec: https://googleapis.dev/java/google-cloud-clients/latest/com/google/cloud/automl/v1beta1/AutoMlClient.html#listColumnSpecs-com.google.cloud.automl.v1beta1.ListColumnSpecsRequest-

    Something like this:

    TableSpecName parent = TableSpecName.of("[PROJECT]", "[LOCATION]", "[DATASET]", "[TABLE_SPEC]")
    ColumnSpec targetColumnSpec;
    for (ColumnSpec element : autoMlClient.listColumnSpecs(parent).iterateAll()) {
      if (element.getDisplayName().equals("[MY_TARGET_COLUMN]")) {
        targetColumnSpec = element;
        break;
      }
    }