
Check that a list contains an instance of each possible case class


In Scala 3, say that I have code such as:

sealed trait Shape

final case class Circle(radius: Float) extends Shape
final case class Square(side: Float) extends Shape
final case class Rectangle(width: Float, height: Float) extends Shape

How can I define a method such as:

def listContainsAtLeastOneOfEach(shapes: List[Shape]): Boolean = ???

that returns true when the passed-in list contains at least one instance of each descendant, and false if any descendant is missing, without my having to manually maintain a list of the descendants inside the method.

I've already tried doing something with scala.deriving.Mirror.SumOf, but I can't quite pin it down. I'm open to using macros if necessary, though I feel this should be doable with just inlining.


Solution

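  • You can indeed use scala.deriving.Mirror.SumOf together with its ordinal method, which maps each value to the index of its direct subtype. Collect the distinct ordinals in the list and compare their count with the number of subtypes: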

    import scala.compiletime.constValue
    import scala.deriving.*
    
    inline def listContainsAtLeastOneOfSubtype[T](list: List[T])(using m: Mirror.SumOf[T]): Boolean =
      // Collect the ordinal (index of the direct subtype) of every element.
      val ordinals: Set[Int] = list.foldLeft(Set.empty[Int]) { (ords, elem) =>
        ords + m.ordinal(elem)
      }
      // Number of direct subtypes, computed at compile time from the mirror.
      val nSubtypes = constValue[Tuple.Size[m.MirroredElemTypes]]
      ordinals.size == nSubtypes
    

    (Of course, the above can still be optimized, e.g. by stopping the iteration as soon as every subtype has been seen, since the result is then known to be true.)
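The early-exit optimization mentioned in the answer could look roughly like the following. This is a hypothetical sketch, not part of the original answer. The name `containsEachSubtypeEarlyExit` is illustrative, and the sketch swaps the `Set` for a `Boolean` array indexed by ordinal so that the loop can stop once every subtype has been seen. Like the original, it relies only on `Mirror.SumOf`, `ordinal` and `constValue`, and it counts the direct children listed in `MirroredElemTypes`.

```scala
import scala.compiletime.constValue
import scala.deriving.Mirror

sealed trait Shape
final case class Circle(radius: Float) extends Shape
final case class Square(side: Float) extends Shape
final case class Rectangle(width: Float, height: Float) extends Shape

// Hypothetical early-exit variant: stops scanning once every direct subtype was seen.
inline def containsEachSubtypeEarlyExit[T](list: List[T])(using m: Mirror.SumOf[T]): Boolean =
  // Number of direct subtypes, computed at compile time from the mirror.
  val nSubtypes: Int = constValue[Tuple.Size[m.MirroredElemTypes]]
  // seen(i) becomes true once an element with ordinal i has been encountered.
  val seen = new Array[Boolean](nSubtypes)
  var remaining: Int = nSubtypes
  val it = list.iterator
  while remaining > 0 && it.hasNext do
    val ord = m.ordinal(it.next())
    if !seen(ord) then
      seen(ord) = true
      remaining -= 1
  // Zero remaining means every subtype occurred at least once.
  remaining == 0

@main def demo(): Unit =
  val all: List[Shape] = List(Circle(1f), Square(2f), Rectangle(1f, 2f))
  val noRectangle: List[Shape] = List(Circle(1f), Square(2f), Circle(3f))
  println(containsEachSubtypeEarlyExit(all))
  println(containsEachSubtypeEarlyExit(noRectangle))
```

With the `Shape` hierarchy from the question, the first call prints `true` and the second prints `false`, because no `Rectangle` is present. The explicit `List[Shape]` annotations make sure `T` is inferred as `Shape`, so that the compiler summons `Mirror.SumOf[Shape]`.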