python, database, amazon-s3, pytorch, amazon-sagemaker

S3 Torchconnector loads data as a list of tensors


I'm setting up a model and about to start training on a dataset from an S3 bucket. To load the data from S3 I'm using s3torchconnector.S3MapDataset.from_prefix, which loads the data into the SageMaker space I'm using. However, when I start training, the connector yields a list of tensors (or a map dataset) instead of a tensor to be processed by samples = samples.to(device, non_blocking=True). As expected, I'm getting the error 'list' object has no attribute 'to'.

Here are some snippets of my code:

import torch
import torchvision
from PIL import Image
from torchvision import transforms

def load_image(obj):  # renamed from `object` to avoid shadowing the built-in
    img = Image.open(obj)  # obj is a readable S3 object that exposes .key
    resize = transforms.Resize(size=(224, 224))
    img = resize(img)
    img = transforms.functional.pil_to_tensor(img)
    # returns a (key, tensor) tuple rather than a bare tensor
    return (obj.key, torchvision.transforms.functional.convert_image_dtype(img, dtype=torch.float32))


import s3torchconnector

dataset_train = s3torchconnector.S3MapDataset.from_prefix(
    args.IMAGES_URI,
    region=args.REGION,
    transform=load_image,
)  # here s3torchconnector apparently yields a list of tensors

data_loader_train = torch.utils.data.DataLoader(
    dataset_train, sampler=sampler_train,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=True,
)

# define the model
model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)

model.to(device)

for epoch in range(args.start_epoch, args.epochs):
    if args.distributed:
        data_loader_train.sampler.set_epoch(epoch)
    train_stats = train_one_epoch(
        model, data_loader_train,
        optimizer, device, epoch, loss_scaler,
        log_writer=log_writer,
        args=args
    )

This code works well with a custom dataset and local files in my SageMaker space, but I can't take that approach here since I need to train on millions of images, hence the S3 data loader. I'm not sure whether s3torchconnector just works this way, or whether it's possible to turn whatever it loads into a single tensor that the model and the sampler can process.
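
For context, here is a minimal standalone sketch (no S3 access required; the keys and shapes are made up) of what is most likely happening: because load_image returns a (key, tensor) tuple, the DataLoader's default collate function turns each batch of tuples into a two-element list holding the keys and the stacked image tensor. The sketch assumes PyTorch >= 1.11, where default_collate became a public export:

import torch
from torch.utils.data import default_collate

# two fake samples shaped like load_image's output
batch = [
    ("images/0001.jpg", torch.rand(3, 224, 224)),
    ("images/0002.jpg", torch.rand(3, 224, 224)),
]

collated = default_collate(batch)
print(type(collated))     # <class 'list'>  -- the "list of tensors"
print(collated[0])        # ('images/0001.jpg', 'images/0002.jpg')
print(collated[1].shape)  # torch.Size([2, 3, 224, 224])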


Solution

  • The way this problem was solved was by adding the following check to the code that consumes the batches from the data loader (here, inside train_one_epoch):

    if isinstance(samples, list):
        # samples is [keys, image_batch]; the stacked tensor is at index 1
        samples = samples[1].to(device, non_blocking=True)
    else:
        samples = samples.to(device, non_blocking=True)
    

    This works both with a dataset stored locally in SageMaker and with s3torchconnector: the keys end up at index 0 and the batched image tensor at index 1, so the code can read and train on the list the connector produces.
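
  • An alternative that avoids the type check is to pass a custom collate_fn to the DataLoader so that the keys and the image batch arrive already unpacked. A sketch, assuming the same load_image transform as above (the collate_images helper and the loop shown are illustrative, not part of the original code):

    import torch

    def collate_images(batch):
        # batch is a list of (key, tensor) pairs produced by load_image
        keys = [key for key, _ in batch]
        images = torch.stack([image for _, image in batch])
        return keys, images

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
        collate_fn=collate_images,
    )

    # inside the training loop, samples is now always a plain tensor:
    for keys, samples in data_loader_train:
        samples = samples.to(device, non_blocking=True)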