
Element-wise operations on arrays of derived type objects with overloaded operators


I am defining derived types in Fortran and overloading some operators (for example the + operator). This works well when working with scalar objects and I was (optimistically) hoping that this would seamlessly work with arrays of these objects, but it does not.

In the example below, I define a derived type for points in the 2D plane, and I overload the + operator to define what the sum of two points should be (this operation does not need to make sense mathematically -- this is just a simple example...). This works well with individual points, but not with arrays of points.

Error messages when compiling are:

What is the best/simplest approach to get this to work with arrays of points?

module points

  implicit none

  type :: t_point
    integer :: x, y
  contains
    procedure :: add
    generic :: operator(+) => add
  end type t_point

contains

  pure type(t_point) function add(self, other)
    class(t_point), intent(in) :: self, other
    add = t_point(self%x + other%x, self%y + other%y)
  end function add

end module points

program test

  use points

  implicit none

  type(t_point) :: p1, p2, p3

  p1 = t_point(0, 1)
  p2 = t_point(1, 4)
  p3 = t_point(5, -1)

  print*, p1 + p1 ! this works: (0, 2)
  print*, p1 + p2 ! this works: (1, 5)
  print*, p1 + p3 ! this works: (5, 0)
  print*, p2 + p3 ! this works: (6, 3)

  print*, (/1, 2, 3/) + (/10, 20, 30/) ! this works with intrinsic data types: (11, 22, 33)

  print*, (/p1, p2, p3/) + (/p3, p3, p3/) ! compilers do not like this line

end program test


Solution

  • The (pure) function defined by

      pure type(t_point) function add(self, other)
        class(t_point), intent(in) :: self, other
        add = t_point(self%x + other%x, self%y + other%y)
      end function add
    

    when used as an operator, only accepts scalar left- and right-hand operands.

    The operands in (/p1, p2, p3/) + (/p3, p3, p3/) are arrays, not scalars, so add does not match them. Since no other procedure is bound to + for array operands of type t_point, the compiler rejects the expression.

    You need to define a function which accepts array arguments. A natural way for + is to make add an elemental function:

      elemental type(t_point) function add(self, other)
        class(t_point), intent(in) :: self, other
        add = t_point(self%x + other%x, self%y + other%y)
      end function add
    

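    (An elemental procedure is pure unless it is explicitly marked impure.) Such an add accepts two scalars, two arrays of the same shape, or a scalar and an array, in which case the scalar is combined with every element of the array.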

    Note, too, that the same elemental add is what allows the direct calls add([p1,p2,p3],[p3,p3,p3]) and add(p1,p3): the operator form works in exactly the cases where the function call does.
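For reference, here is a minimal sketch of the complete module with the change applied. It is the question's module with only the pure prefix on add replaced by elemental, so the question's test program should compile against it unchanged. The comment about scalar broadcasting describes the general rules for elemental procedures, not anything specific to t_point.

```fortran
module points

  implicit none

  type :: t_point
    integer :: x, y
  contains
    procedure :: add
    generic :: operator(+) => add
  end type t_point

contains

  ! Elemental, and therefore pure since it is not marked impure.
  ! It is written for scalar arguments, but when called with conformable
  ! arrays it is applied element by element; a scalar argument is paired
  ! with every element of an array argument.
  elemental type(t_point) function add(self, other)
    class(t_point), intent(in) :: self, other
    add = t_point(self%x + other%x, self%y + other%y)
  end function add

end module points
```

With this module, (/p1, p2, p3/) + (/p3, p3, p3/) yields the three points (5, 0), (6, 3) and (10, -2), and mixed forms such as p1 + [p2, p3] are accepted as well.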