scala enumeration

How to write a function defined for a subset of all enumeration values?


Suppose I've got an enumeration A:

object A extends Enumeration { type A = Int; val A1 = 1; val A2 = 2; val A3 = 3 }

Also, I have a function defined only for A1 or A2 but not for A3.

def foo(a: A.A): Int = a match {
  case A.A1 => 1 // do something 
  case A.A2 => 2 // do something else 
  case A.A3 => throw new UnsupportedOperationException() 
}

Now I would like to get a compilation error for foo(A.A3). In pseudocode I define foo like this:

def foo(a: A1 | A2): Int = ???

How would you suggest writing foo so that it cannot be called with A.A3?


Solution

  • The pseudocode is on the right track: the supported values need to appear in the type signature of foo, and the unsupported ones must not.

    // Scala 3
    
    object A extends Enumeration {
      type A = Int
      val A1 = 1
      val A2 = 2
      val A3 = 3
    }
    
    def foo(a: A.A1.type | A.A2.type): Int = 1
    
    @main
    def main(): Unit = {
      foo(A.A1)
      foo(A.A3)
    }
    

    Compilation error:

    Found:    (A.A3 : Int)
    Required: (A.A1 : Int) | (A.A2 : Int)
      foo(A.A3)
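
    The foo above returns a constant for brevity. As a minimal sketch of how it could still branch on the accepted value, as the foo in the question does, the union-typed parameter can be matched against the same constants. The name demo is only illustrative, and this assumes a Scala 3 compiler:

    ```scala
    // Scala 3 sketch: same object A as above, foo now distinguishes A1 from A2.
    object A extends Enumeration {
      type A = Int
      val A1 = 1
      val A2 = 2
      val A3 = 3
    }

    // Only the singleton types of A.A1 and A.A2 are accepted as arguments.
    def foo(a: A.A1.type | A.A2.type): Int = a match {
      case A.A1 => 1 // do something
      case A.A2 => 2 // do something else
    }

    // `demo` is a hypothetical entry point used only for this illustration.
    @main
    def demo(): Unit = {
      println(foo(A.A1)) // prints 1
      println(foo(A.A2)) // prints 2
      // foo(A.A3)       // rejected at compile time, as in the error shown above
    }
    ```

    No case for A.A3 is needed in the match, because such a call is rejected before the code runs.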
    

    To make this possible in Scala 2 (literal types are needed, so only Scala 2.13 or Typelevel Scala 2.12), we need to give a more refined type to the A1-A3 constants and add an implicit conversion to the 'union' type A12:

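    import scala.language.implicitConversions // silences the feature warning for the implicit defs below
    
    object A extends Enumeration {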
      type A = Int
      val A1: 1 = 1
      val A2: 2 = 2
      val A3: 3 = 3
    
      sealed trait A12 { val a: A }
      private case class AnA12(a: A) extends A12
      implicit def a1ToA12(a1: 1): A12 = AnA12(a1)
      implicit def a2ToA12(a2: 2): A12 = AnA12(a2)
    }
    
    object Foo {
      import A._
    
      def foo(a: A12): Int = a.a
    
      def main(args: Array[String]): Unit = {
        foo(A1)
        foo(A3)
      }
    }
    

    Compilation fails with:

    type mismatch;
     found   : 3
     required: Foo.A12
        foo(A3)
    

    AnA12 is private so that callers cannot bypass the check by writing foo(AnA12(A3)).
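
    In the Scala 2 version, foo simply returns the wrapped value. As a rough sketch of how it could branch on which constant was passed, the wrapped value can be matched against A1 and A2. The results value is a hypothetical addition for illustration, and the code is written with Scala 2.13 in mind:

    ```scala
    import scala.language.implicitConversions

    object A extends Enumeration {
      type A = Int
      val A1: 1 = 1
      val A2: 2 = 2
      val A3: 3 = 3

      // A12 can only be obtained through the two implicit conversions below.
      sealed trait A12 { val a: A }
      private case class AnA12(a: A) extends A12
      implicit def a1ToA12(a1: 1): A12 = AnA12(a1)
      implicit def a2ToA12(a2: 2): A12 = AnA12(a2)
    }

    object Foo {
      import A._

      // Branch on the wrapped value; only A1 and A2 can reach this point.
      def foo(a: A12): Int = a.a match {
        case A1 => 1 // do something
        case A2 => 2 // do something else
      }

      // Hypothetical usage: both calls go through the implicit conversions.
      val results: List[Int] = List(foo(A1), foo(A2))
      // foo(A3) // does not compile: there is no conversion from 3 to A12
    }
    ```

    The match has no case for A3 because no A12 wrapping A3 can be constructed outside object A.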