Problem with MaskRCNN ("NotImplementedError") which seems to be related to torchvision.transforms.v2._transform


I am learning MaskRCNN, and to that end I started following this tutorial step by step.

Everything works fine until I reach the block titled "Test the transforms", which reads:

    # Extract the labels for the sample
    labels = [shape['label'] for shape in annotation_df.loc[file_id]['shapes']]
    # Extract the polygon points for segmentation mask
    shape_points = [shape['points'] for shape in annotation_df.loc[file_id]['shapes']]
    # Format polygon points for PIL
    xy_coords = [[tuple(p) for p in points] for points in shape_points]
    # Generate mask images from polygons
    mask_imgs = [create_polygon_mask(sample_img.size, xy) for xy in xy_coords]
    # Convert mask images to tensors
    masks = torch.concat([Mask(transforms.PILToTensor()(mask_img), dtype=torch.bool) for mask_img in mask_imgs])
    # Generate bounding box annotations from segmentation masks
    bboxes = BoundingBoxes(data=torchvision.ops.masks_to_boxes(masks), format='xyxy', canvas_size=sample_img.size[::-1])
    
    # Get colors for dataset sample
    sample_colors = [int_colors[i] for i in [class_names.index(label) for label in labels]]
    
    # Prepare mask and bounding box targets
    targets = {
        'masks': Mask(masks), 
        'boxes': bboxes, 
        'labels': torch.Tensor([class_names.index(label) for label in labels])
    }
    
    # Crop the image
    cropped_img, targets = iou_crop(sample_img, targets)
    
    # Resize the image
    resized_img, targets = resize_max(cropped_img, targets)
    
    # Pad the image
    padded_img, targets = pad_square(resized_img, targets)
    
    # Ensure the padded image is the target size
    resize = transforms.Resize([train_sz] * 2, antialias=True)
    resized_padded_img, targets = resize(padded_img, targets)
    sanitized_img, targets = transforms.SanitizeBoundingBoxes()(resized_padded_img, targets)
    
    # Annotate the sample image with segmentation masks
    annotated_tensor = draw_segmentation_masks(
        image=transforms.PILToTensor()(sanitized_img), 
        masks=targets['masks'], 
        alpha=0.3, 
        colors=sample_colors
    )
    
    # Annotate the sample image with labels and bounding boxes
    annotated_tensor = draw_bboxes(
        image=annotated_tensor, 
        boxes=targets['boxes'], 
        labels=[class_names[int(label.item())] for label in targets['labels']], 
        colors=sample_colors
    )
    
    # Display the annotated image
    display(tensor_to_pil(annotated_tensor))
    
    pd.Series({
        "Source Image:": sample_img.size,
        "Cropped Image:": cropped_img.size,
        "Resized Image:": resized_img.size,
        "Padded Image:": padded_img.size,
        "Resized Padded Image:": resized_padded_img.size,
    }).to_frame().style.hide(axis='columns')

When I execute this block, I get the following error, which I can neither understand nor fix:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[33], line 28
      25 cropped_img, targets = iou_crop(sample_img, targets)
      27 # Resize the image
---> 28 resized_img, targets = resize_max(cropped_img, targets)
      30 # Pad the image
      31 padded_img, targets = pad_square(resized_img, targets)

File ~\anaconda3\envs\test-env\lib\site-packages\torch\nn\modules\module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
    1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
    1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~\anaconda3\envs\test-env\lib\site-packages\torch\nn\modules\module.py:1750, in Module._call_impl(self, *args, **kwargs)
    1745 # If we don't have any hooks, we want to skip the rest of the logic in
    1746 # this function, and just call forward.
    1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
    1748         or _global_backward_pre_hooks or _global_backward_hooks
    1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
    1752 result = None
    1753 called_always_called_hooks = set()

File ~\anaconda3\envs\test-env\lib\site-packages\torchvision\transforms\v2\_transform.py:68, in Transform.forward(self, *inputs)
      63 needs_transform_list = self._needs_transform_list(flat_inputs)
      64 params = self.make_params(
      65     [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
      66 )
---> 68 flat_outputs = [
      69     self.transform(inpt, params) if needs_transform else inpt
      70     for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
      71 ]
      73 return tree_unflatten(flat_outputs, spec)

File ~\anaconda3\envs\test-env\lib\site-packages\torchvision\transforms\v2\_transform.py:69, in <listcomp>(.0)
      63 needs_transform_list = self._needs_transform_list(flat_inputs)
      64 params = self.make_params(
      65     [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
      66 )
      68 flat_outputs = [
---> 69     self.transform(inpt, params) if needs_transform else inpt
      70     for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
      71 ]
      73 return tree_unflatten(flat_outputs, spec)

File ~\anaconda3\envs\test-env\lib\site-packages\torchvision\transforms\v2\_transform.py:55, in Transform.transform(self, inpt, params)
      51 def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
      52     """Method to override for custom transforms.
      53
      54     See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`"""
---> 55     raise NotImplementedError

NotImplementedError:

The problem seems to originate in torchvision.transforms.v2._transform, which contains the following:

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Method to override for custom transforms.

        See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`"""
        raise NotImplementedError
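One plausible reading of the traceback, offered as an assumption rather than a confirmed diagnosis: `Transform.forward` calls the public `transform()` hook, and the base-class version above raises `NotImplementedError` unless a subclass overrides it. If the custom transform behind `resize_max` was written for an older torchvision release and only overrides a private `_transform()` hook (my assumption; I have not inspected its source), the call falls through to the base method. The pure-Python model below mimics that dispatch. `BaseTransform`, `OldStyleDouble`, `PatchedDouble` and `try_call` are hypothetical stand-ins, not torchvision code.

```python
from typing import Any, Dict, List


class BaseTransform:
    """Hypothetical stand-in for a base class whose __call__ dispatches to
    the public make_params()/transform() hooks (modelled on the traceback)."""

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        return {}

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Default hook: subclasses are expected to override this one.
        raise NotImplementedError

    def __call__(self, *inputs: Any) -> Any:
        params = self.make_params(list(inputs))
        outputs = [self.transform(inpt, params) for inpt in inputs]
        return outputs[0] if len(outputs) == 1 else tuple(outputs)


class OldStyleDouble(BaseTransform):
    """Hypothetical custom transform that overrides only the old private hook,
    so the inherited transform() still raises."""

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return inpt * 2


class PatchedDouble(OldStyleDouble):
    """Same transform, with the public hook delegating to the private one."""

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._transform(inpt, params)


def try_call(t: BaseTransform, *inputs: Any) -> Any:
    """Call a transform and report NotImplementedError instead of raising."""
    try:
        return t(*inputs)
    except NotImplementedError:
        return "NotImplementedError"
```

With this model, `try_call(OldStyleDouble(), 3)` reports `NotImplementedError`, while `try_call(PatchedDouble(), 3, 4)` returns `(6, 8)`. The answer's pin to torchvision 0.20.1 is consistent with this reading, but I have not confirmed the exact release in which the hook names changed, nor tested a delegating `transform()` against the real `cjm_torchvision_tfms` classes.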

Does anybody have an idea what to do?


Solution

  • I found that this is a PyTorch version issue.

    You need to downgrade PyTorch.

    Versions 0.0.25 and 0.0.26 of cjm_torchvision_tfms try to upgrade torch, so I downgraded that package too (to 0.0.24).

    The following setup works for me:

    python 3.11.3
    pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
    pip install matplotlib pandas pillow torchtnt==0.2.0 tqdm tabulate
    pip install distinctipy
    pip install cjm_pandas_utils cjm_psl_utils cjm_pil_utils cjm_pytorch_utils cjm_torchvision_tfms==0.0.24 ipywidgets albumentations
    pip install jupyterlab
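Because this fix depends on exact pins, it can help to check the running environment before executing the notebook. The sketch below compares installed distribution versions against the pins listed above, using the standard-library `importlib.metadata.version`. The `PINS` dict only restates those pins, and the helper names `installed_version` and `find_mismatches` are my own, not part of any library.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, List, Optional

# Pins restated from the setup above; edit these if your setup differs.
PINS: Dict[str, str] = {
    "torch": "2.5.1",
    "torchvision": "0.20.1",
    "torchaudio": "2.5.1",
    "cjm_torchvision_tfms": "0.0.24",
}


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def find_mismatches(
    pins: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> List[str]:
    """Return human-readable mismatches between pinned and installed versions."""
    problems: List[str] = []
    for dist, wanted in pins.items():
        found = lookup(dist)
        if found is None:
            problems.append(f"{dist}: not installed (want {wanted})")
        elif found.split("+")[0] != wanted:
            # Wheels from the CUDA index carry a local suffix such as
            # "2.5.1+cu124"; only the part before "+" is compared.
            problems.append(f"{dist}: {found} installed (want {wanted})")
    return problems


if __name__ == "__main__":
    for line in find_mismatches(PINS) or ["All pinned versions match."]:
        print(line)
```

The `lookup` parameter exists so the comparison can be exercised with a plain dict (for example `{"torch": "2.5.1+cu124"}.get`) without depending on what happens to be installed.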