python, computer-vision, pytorch, data-generation

PyTorch Data Generator for extracting 2D images from many 3D cubes


I'm struggling to create a data generator in PyTorch that extracts 2D images from many 3D cubes saved in .dat format.

There are 200 3D cubes in total, each with shape 128*128*128. I want to extract 2D images from all of these cubes along the length and breadth.

For example, say a is a cube of size 128*128*128.

I want to extract all 2D images along the length, i.e., [:, i, :], which gives me 128 2D images per cube, and similarly along the width, i.e., [:, :, i], which gives me another 128 2D images. That's a total of 256 2D images from one 3D cube, and repeating this for all 200 cubes gives me 51,200 2D images.
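As a sanity check, here is a small NumPy sketch of that slicing, using a random array as a stand-in for one cube (it only confirms the 256-slices-per-cube count, nothing here comes from the actual data):

import numpy as np

a = np.random.rand(128, 128, 128).astype(np.single)  # stand-in for one 128*128*128 cube

# 128 slices along the length axis plus 128 along the width axis
length_slices = [a[:, i, :] for i in range(a.shape[1])]  # each slice is 128x128
width_slices = [a[:, :, i] for i in range(a.shape[2])]   # each slice is 128x128

print(len(length_slices) + len(width_slices))  # 256 per cube, so 51,200 for 200 cubes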

So far I've tried a very basic implementation which works fine but takes approximately 10 minutes to run. I'd like help creating a more optimal implementation, keeping time and space complexity in mind. Right now my approach has a time complexity of O(n²); can we decrease it further to reduce the time complexity?

I'm providing my current implementation below:

from os.path import join as pjoin
import torch
import numpy as np
import os
from tqdm import tqdm
from torch.utils import data


class DataGenerator(data.Dataset):

    def __init__(self, is_transform=True, augmentations=None):

        self.is_transform = is_transform
        self.augmentations = augmentations
        self.dim = (128, 128, 128)

        seismicSections = [] #Input
        faultSections = [] #Ground Truth
        for fileName in tqdm(os.listdir(pjoin('train', 'seis')), total = len(os.listdir(pjoin('train', 'seis')))):
            unrolledVolSeismic = np.fromfile(pjoin('train', 'seis', fileName), dtype = np.single) #dat file contains unrolled cube, we need to reshape it
            reshapedVolSeismic = np.transpose(unrolledVolSeismic.reshape(self.dim)) #need to transpose the axis to get height axis at axis = 0, while length (axis = 1), and width(axis = 2)

            unrolledVolFault = np.fromfile(pjoin('train', 'fault', fileName),dtype=np.single)
            reshapedVolFault = np.transpose(unrolledVolFault.reshape(self.dim))

            for idx in range(reshapedVolSeismic.shape[2]):
                seismicSections.append(reshapedVolSeismic[:, :, idx])
                faultSections.append(reshapedVolFault[:, :, idx])

            for idx in range(reshapedVolSeismic.shape[1]):
                seismicSections.append(reshapedVolSeismic[:, idx, :])
                faultSections.append(reshapedVolFault[:, idx, :])

        self.seismicSections = seismicSections
        self.faultSections = faultSections

    def __len__(self):
        return len(self.seismicSections)

    def __getitem__(self, index):

        X = self.seismicSections[index]
        Y = self.faultSections[index]

        return X, Y

Please Help!!!


Solution

  • Why not store only the 3D data in memory and let the __getitem__ method slice it on the fly? A sketch follows, along with an example of how you might load the volumes into it.

    import torch
    from torch.utils.data import Dataset

    class CachedVolumeDataset(Dataset):
        def __init__(self, volumes_x, volumes_y):
            super().__init__()
            self._volumes_x = volumes_x  # a list of 200 128x128x128 volumes (input)
            self._volumes_y = volumes_y  # a list of 200 128x128x128 volumes (ground truth)

        def __len__(self):
            # 128 slices along each of the two sliced axes, per volume
            return len(self._volumes_x) * (128 + 128)

        def __getitem__(self, index):
            # extract volume index from the global index:
            vidx = index // (128 + 128)
            # extract slice index within that volume
            sidx = index % (128 + 128)
            if sidx < 128:
                # slice along the last axis (width), like reshapedVol[:, :, idx]
                x = self._volumes_x[vidx][:, :, sidx]
                y = self._volumes_y[vidx][:, :, sidx]
            else:
                sidx -= 128
                # slice along the middle axis (length), like reshapedVol[:, idx, :]
                x = self._volumes_x[vidx][:, sidx, :]
                y = self._volumes_y[vidx][:, sidx, :]
            return torch.squeeze(x), torch.squeeze(y)
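    To use it, you'd build the two volume lists once and hand them to the dataset. A minimal loading sketch, assuming the same train/seis and train/fault layout, dtype, and reshape/transpose as in the question (the batch size is only an example):

    from os.path import join as pjoin
    import os
    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    volumes_x, volumes_y = [], []
    for fileName in os.listdir(pjoin('train', 'seis')):
        # load each unrolled cube once, reshape and transpose exactly as in the question
        seis = np.fromfile(pjoin('train', 'seis', fileName), dtype=np.single)
        fault = np.fromfile(pjoin('train', 'fault', fileName), dtype=np.single)
        volumes_x.append(torch.from_numpy(np.transpose(seis.reshape(128, 128, 128)).copy()))
        volumes_y.append(torch.from_numpy(np.transpose(fault.reshape(128, 128, 128)).copy()))

    dataset = CachedVolumeDataset(volumes_x, volumes_y)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)  # batch size is arbitrary
    x_batch, y_batch = next(iter(loader))  # each element is a 128x128 section

    This keeps only the 200 volumes in memory instead of 51,200 separate slices, and each section is produced as a cheap indexing view at access time rather than being precomputed in the constructor.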