machine-learning deep-learning computer-vision pytorch skorch

Passing a TensorDataset or DataLoader to skorch


I want to apply cross-validation in PyTorch using skorch. I prepared my model and a TensorDataset whose items are (image, caption, caption_length) tuples. Since the dataset already contains both X and y, I cannot pass y separately to fit, so I call:

net.fit(dataset)

But when I tried that, I got an error:

ValueError: Stratified CV requires explicitly passing a suitable y

Here's part of my code:

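import time

from torch import nn, optim
from skorch import NeuralNetClassifier

# decoder, dataset and args are defined elsewhere in my script
start = time.time()
net = NeuralNetClassifier(
    decoder,
    criterion=nn.CrossEntropyLoss,
    max_epochs=args.epochs,
    lr=args.lr,
    optimizer=optim.SGD,
    device='cuda',  # train on the GPU
)
net.fit(dataset, y=None)  # raises the ValueError shown above
end = time.time()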

Solution

  • You are (implicitly) using skorch's internal CV split. For NeuralNetClassifier this split is stratified, which means it needs to know the labels before training starts.

    When you pass X and y to fit separately, this works fine because y is accessible at all times. The problem is that you are passing a torch.utils.data.Dataset, which is lazy and does not give skorch direct access to y, hence the error.
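To see why a stratified split cannot work without labels, here is a minimal sketch that uses scikit-learn's StratifiedKFold directly. The toy labels are made up, and the sketch only illustrates the principle; it is not a copy of what skorch does internally.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy labels (assumption): 80 samples of class 0 and 20 of class 1.
y = np.array([0] * 80 + [1] * 20)
indices = np.arange(len(y))

# The splitter needs y to decide which indices go into each fold,
# so that every fold keeps roughly the same class ratio as the full data.
splitter = StratifiedKFold(n_splits=5)
train_idx, valid_idx = next(iter(splitter.split(indices, y)))

# Per-class sample counts in the first validation fold.
valid_counts = np.bincount(y[valid_idx])
print(valid_counts)
```

Without y there is nothing to balance the folds on, which is the situation skorch is in when it only receives a lazy dataset.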

    Your options are the following.

    Set train_split=None to disable the internal CV split

    net = NeuralNetClassifier(
        decoder,  # plus your other arguments (criterion, lr, ...)
        train_split=None,
    )
    

    You will lose internal validation and, as such, features like early stopping.

    Split your data beforehand

    Split your dataset into two datasets, dataset_train and dataset_valid, then use skorch.helper.predefined_split:

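    from skorch.helper import predefined_split

    net = NeuralNetClassifier(
        decoder,  # plus your other arguments (criterion, lr, ...)
        train_split=predefined_split(dataset_valid),
    )
    net.fit(dataset_train, y=None)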
    

    You lose nothing, but depending on your data, producing the split yourself might be complicated.

    Extract your y and pass it to fit

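    import numpy as np

    # assumes each item is an (X, y) pair; adjust the unpacking if your items have more fields
    y_train = np.array([y for X, y in my_dataset])
    net.fit(my_dataset, y=y_train)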
    

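    This only works if your y fits into memory. Since you are using a TensorDataset, you can also read the labels directly from its tensors attribute, assuming the labels are the second tensor you passed to the constructor:

    y_train = my_dataset.tensors[1]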
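The two ways of extracting labels can be compared on a toy dataset. This is a minimal sketch that assumes the labels are the second tensor passed to TensorDataset; the data is made up.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset

# Toy data (assumption): 6 samples with 4 features and integer labels.
X = torch.randn(6, 4)
y = torch.tensor([0, 1, 2, 0, 1, 2])
my_dataset = TensorDataset(X, y)

# TensorDataset keeps its inputs in the `tensors` tuple, in constructor order.
y_direct = my_dataset.tensors[1].numpy()

# Generic fallback for any map-style dataset: index every item and keep the label.
y_looped = np.array([int(my_dataset[i][1]) for i in range(len(my_dataset))])

print(np.array_equal(y_direct, y_looped))
```

Either array can then be passed as net.fit(my_dataset, y=y_direct). The direct route avoids touching every sample, which matters when loading a sample is expensive.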