Tags: scala, shapeless, scala-cats, kind-projector

Extend NaturalTransformation of Coproducts


I have

F ~> H
G ~> H

Here, `~>` is `cats.NaturalTransformation`.

I'm able to construct a

λ[A => F[A] :+: G[A] :+: CNil] ~> H

(using kind-projector syntax for readability).

Here's how I'm doing it

/** Combines `f: F ~> H` with `g: G ~> H` into a single transformation over the
  * two-case coproduct `F[A] :+: G[A] :+: CNil`.
  *
  * NOTE(review): `F`, `H` and `f` are not declared in this snippet; they are
  * assumed to come from an enclosing scope that is not shown.
  */
def or[G[_]](g: G ~> H): λ[A => F[A] :+: G[A] :+: CNil] ~> H =
  new (λ[A => F[A] :+: G[A] :+: CNil] ~> H) {
    // Looks up each case by type with `select` and dispatches on whichever
    // lookup returned Some.
    def apply[A](fa: F[A] :+: G[A] :+: CNil): H[A] =
      (fa.select[F[A]], fa.select[G[A]]) match {
        case (Some(ff), None) => f(ff)
        case (None, Some(gg)) => g(gg)
        // this can't happen, due to the definition of Coproduct
        // NOTE(review): possibly reachable when F[A] and G[A] are the same type
        // (e.g. F == G), because both `select` calls then look up the same type — confirm.
        case _ => throw new Exception("Something is wrong")
  }
}

This works, although I'm open to suggestions as it doesn't look pretty.

Now, if I have

λ[A => F[A] :+: G[A] :+: CNil] ~> H
K ~> H

I should also be able to construct a

λ[A => F[A] :+: G[A] :+: K[A] :+: CNil] ~> H

and here's where I got stuck. I tried using ExtendRight from shapeless, but I can't get it to work. Here's my attempt:

/** Attempted generalisation: combine `f: F ~> H`, where each `F[A]` is a coproduct,
  * with `g: G ~> H` into a transformation over `FG`. `FG` is meant to be the coproduct
  * obtained by appending `G` on the right of `F` via shapeless `ExtendRight`.
  *
  * NOTE(review): as written, this is reported not to compile (missing `ExtendRight`
  * evidence). Likely causes, to confirm:
  *  - `F[_]`, `G[_]` and `FG[_]` in type-argument position denote existential types,
  *    not the constructors applied to a particular `A`. The evidence therefore does
  *    not relate `F[A]`, `G[A]` and `FG[A]`.
  *  - The `apply` below takes an extra implicit parameter list, so its signature
  *    differs from the abstract `apply` of `~>`. It is an overload, not an
  *    implementation.
  */
def or[F[_] <: Coproduct, G[_], H[_], FG[_] <: Coproduct](f: F ~> H, g: G ~> H)(
  implicit e: ExtendRight.Aux[F[_], G[_], FG[_]]
): FG ~> H = new (FG ~> H) {
  // Dispatches by asking the combined coproduct for an F[A] or a G[A] via `select`.
  def apply[A](fg: FG[A])(implicit
    sf: Selector[FG[A], F[A]],
    sg: Selector[FG[A], G[A]]
  ): H[A] =
    (fg.select[F[A]], fg.select[G[A]]) match {
      case (Some(ff), None) => f(ff)
      case (None, Some(gg)) => g(gg)
      // this can't happen, due to the definition of Coproduct
      case _ => throw new Exception("Something is wrong")
    }

}

However, the compiler can't find implicit evidence for the ExtendRight parameter.

Here's a minimal example that reproduces the problem:

import shapeless._
import shapeless.ops.coproduct._
import cats.~>

// Reproduction driver: three sample transformations into List, combined with the
// helpers in `Foo`.
object Bar {
  // Option -> List: None becomes Nil, Some(a) becomes List(a).
  val optionToList = new (Option ~> List) {
    def apply[A](x: Option[A]): List[A] = x match {
      case None => Nil
      case Some(a) => List(a)
    }
  }

  // Id -> List: wraps the value in a one-element list.
  val idToList = new (Id ~> List) {
    def apply[A](x: Id[A]): List[A] = List(x)
  }

  // Try -> List: Failure becomes Nil, Success(a) becomes List(a).
  val tryToList = new (scala.util.Try ~> List) {
    def apply[A](x: scala.util.Try[A]): List[A] = x match {
      case scala.util.Failure(_) => Nil
      case scala.util.Success(a) => List(a)
    }
  }

  // Two-case coproduct handled by `optionAndId`.
  type OI[A] = Option[A] :+: Id[A] :+: CNil
  val optionAndId: OI ~> List = Foo.or(optionToList, idToList)
  // NOTE(review): this `Foo.or2` call is presumably where the missing
  // `ExtendRight` evidence is reported.
  val all = Foo.or2(optionAndId, tryToList)

}

// Combinators from the question for merging natural transformations over coproducts.
object Foo {
  /** Combines `f: F ~> H` and `g: G ~> H` into one transformation over the two-case
    * coproduct `F[A] :+: G[A] :+: CNil`, dispatching with `select`.
    */
  def or[F[_], G[_], H[_]](f: F ~> H, g: G ~> H): λ[A => F[A] :+: G[A] :+: CNil] ~> H =
    new (λ[A => F[A] :+: G[A] :+: CNil] ~> H) {
      def apply[A](fa: F[A] :+: G[A] :+: CNil): H[A] =
        (fa.select[F[A]], fa.select[G[A]]) match {
          case (Some(ff), None) => f(ff)
          case (None, Some(gg)) => g(gg)
          // this can't happen, due to the definition of Coproduct
          // NOTE(review): possibly reachable when F[A] and G[A] are the same type
          // (e.g. F == G), because both `select` calls then look up the same type — confirm.
          case _ => throw new Exception("Something is wrong, most likely in the type system")
        }
    }

  /** Attempted extension of an existing coproduct transformation `f` with `g`, where
    * `FG` is meant to be `F` with `G` appended on the right via `ExtendRight`.
    *
    * NOTE(review): reported not to compile (missing `ExtendRight` evidence). Likely
    * causes, to confirm:
    *  - `F[_]`, `G[_]` and `FG[_]` as type arguments are existential types, not the
    *    constructors applied to `A`.
    *  - `apply` has an extra implicit parameter list, so it overloads rather than
    *    implements the abstract `apply` of `~>`.
    */
  def or2[F[_] <: Coproduct, G[_], H[_], FG[_] <: Coproduct](f: F ~> H, g: G ~> H)(implicit
    e: ExtendRight.Aux[F[_], G[_], FG[_]]
    ): FG ~> H = new (FG ~> H) {
      def apply[A](fg: FG[A])(implicit
        sf: Selector[FG[A], F[A]],
        sg: Selector[FG[A], G[A]]
      ): H[A] =
        (fg.select[F[A]], fg.select[G[A]]) match {
          case (Some(ff), None) => f(ff)
          case (None, Some(gg)) => g(gg)
          // this can't happen, due to the definition of Coproduct
          case _ => throw new Exception("Something is wrong, most likely in the type system")
        }
  }
}

Solution

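  • Sorry I didn't get around to posting this the other day, but I think this does what you want. The trick is to arrange things so that you don't need the selectors at all. Instead of appending a case on the right with `ExtendRight`, `or2` below prepends a new case on the left, so a plain match on `Inl`/`Inr` is enough.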

    import shapeless._
    import shapeless.ops.coproduct._
    import cats.~>
    
    /** Merges two natural transformations into one that accepts either source.
      *
      * The result takes the two-case coproduct `F[x] :+: G[x] :+: CNil` and sends
      * whichever case is present through `f` or `g`, respectively.
      *
      * @param f transformation used when the value sits in the `F` position
      * @param g transformation used when the value sits in the `G` position
      */
    def or[F[_], G[_], H[_]](
      f: F ~> H,
      g: G ~> H
    ): ({ type L[x] = F[x] :+: G[x] :+: CNil })#L ~> H =
      new (({ type L[x] = F[x] :+: G[x] :+: CNil })#L ~> H) {
        // One polymorphic case per coproduct position. Folding with a Poly1 means
        // there is no CNil branch to write, and so no "impossible" exception.
        object dispatch extends Poly1 {
          implicit def viaF[A]: Case.Aux[F[A], H[A]] = at[F[A]](fa => f(fa))
          implicit def viaG[A]: Case.Aux[G[A], H[A]] = at[G[A]](ga => g(ga))
        }

        def apply[A](c: F[A] :+: G[A] :+: CNil): H[A] = c.fold(dispatch)
      }
    
    /** Prepends a new case to an existing coproduct-based transformation.
      *
      * `g` already handles a coproduct `G`. The result handles `F[x] :+: G[x]`:
      * a value in the head position (`Inl`) goes through `f`, and anything in the
      * tail (`Inr`) is delegated to `g`.
      *
      * @param f transformation for the newly prepended case
      * @param g transformation for the existing coproduct cases
      */
    def or2[F[_], G[_] <: Coproduct, H[_]](
      f: F ~> H,
      g: G ~> H
    ): ({ type L[x] = F[x] :+: G[x] })#L ~> H =
      new (({ type L[x] = F[x] :+: G[x] })#L ~> H) {
        // A `:+:` value is either `Inl` (head) or `Inr` (tail), so these two cases cover it.
        def apply[A](c: F[A] :+: G[A]): H[A] =
          c match {
            case Inr(rest) => g(rest)
            case Inl(head) => f(head)
          }
      }
    

    (Note that I'm using a `Poly1` in `or` to avoid having to handle the `CNil` case with an exception.)

    Now you can write this:

    // Sample transformations into List. Option and Try give Nil when there is no
    // value (None / Failure) and a one-element list otherwise. Id always gives a
    // one-element list.
    val optionToList = new (Option ~> List) {
      def apply[A](x: Option[A]): List[A] = x.toList
    }
    
    val idToList = new (Id ~> List) {
      def apply[A](x: Id[A]): List[A] = x :: Nil
    }
    
    val tryToList = new (scala.util.Try ~> List) {
      def apply[A](x: scala.util.Try[A]): List[A] = x.toOption.toList
    }
    

    And then:

    scala> type OI[A] = Option[A] :+: Id[A] :+: CNil
    defined type alias OI
    
    scala> val optionAndId: OI ~> List = or(optionToList, idToList)
    optionAndId: cats.~>[OI,List] = $anon$1@55224c4a
    
    scala> val all = or2(tryToList, optionAndId)
    all: cats.~>[[x]shapeless.:+:[scala.util.Try[x],OI[x]],List] = $anon$2@536a993
    
    scala> all(Inl(scala.util.Try('foo)))
    res8: List[Symbol] = List('foo)
    
    scala> all(Inr(Inl(Option('foo))))
    res9: List[Symbol] = List('foo)
    
    scala> all(Inr(Inr(Inl('foo))))
    res10: List[Symbol] = List('foo)
    

    (Of course, you could also use `Coproduct[...](Option('foo))`, etc., if you don't mind writing out the type.)