
Grouping fs2 streams into sub-streams based on predicate


I need a combinator that solves the following problem:

test("groupUntil") {
  val s = Stream(1, 2, 3, 4, 1, 2, 3, 3, 1, 2, 2).covary[IO]
  val grouped: Stream[IO, Stream[IO, Int]] = s.groupUntil(_ == 1)

  val result =
    for {
      group   <- grouped
      element <- group.fold(0)(_ + _)
    } yield element

  assertEquals(result.compile.toList.unsafeRunSync(), List(10, 9, 5))
}

The inner streams must also be lazy. (Note: groupUntil is the imaginary combinator I'm asking for.)

NOTE: I must process every element of an inner stream as soon as it arrives in the original stream, i.e. I cannot wait to collect an entire group into a chunk.


Solution

  • One way to achieve laziness here is to use a Stream as the accumulator of a fold:

    import cats.effect.IO
    import fs2.Stream
    
    val s = Stream(1, 2, 3, 4, 1, 2, 3, 3, 1, 2, 2).covary[IO]
    val acc: Stream[IO, Stream[IO, Int]] = Stream.empty
    val grouped: Stream[IO, Stream[IO, Int]] = s.fold(acc) {
      // a 1 starts a new group, which is prepended to the stream of groups
      case (streamOfStreams, nextInt) if nextInt == 1 =>
        Stream(Stream(nextInt).covary[IO]).append(streamOfStreams)
      // any other element is appended to the current (head) group
      case (streamOfStreams, nextInt) =>
        streamOfStreams.head.map(_.append(Stream(nextInt).covary[IO])) ++
          streamOfStreams.tail
    }.flatten
    
    // evaluate each inner stream's sum as an effect, one group after another
    val result: Stream[IO, Int] = grouped.evalMap(_.compile.foldMonoid)
    assertEquals(result.compile.toList.unsafeRunSync().reverse, List(10, 9, 5))
    
    

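    Be careful: the groups come out in reverse order. Appending to the last group of a stream is awkward, so the fold prepends each new group and updates the head instead, which is why the result has to be reversed at the end. Also note that fold emits its result only after the source stream has ended, so no group is available until the whole input has been read.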
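To make the prepend-then-reverse idea concrete, here is a minimal sketch of the same fold on a plain Scala List instead of fs2 streams. The object and method names are hypothetical. One difference worth noting: in the fs2 fold, elements that arrive before the first 1 are dropped (head of an empty stream is empty), whereas this sketch starts a group for them.

```scala
object PrependThenReverse {
  // Hypothetical plain-List analogue of the fs2 fold: a matching element
  // starts a new group at the head; any other element is appended to the
  // head group. Groups accumulate in reverse, hence the final reverse.
  def groupUntil[A](xs: List[A])(p: A => Boolean): List[List[A]] =
    xs.foldLeft(List.empty[List[A]]) {
      case (groups, x) if p(x)  => List(x) :: groups
      case (current :: rest, x) => (current :+ x) :: rest
      // No group yet and x does not match: start one anyway
      // (the fs2 fold would drop x instead)
      case (Nil, x)             => List(List(x))
    }.reverse

  def main(args: Array[String]): Unit = {
    val groups = groupUntil(List(1, 2, 3, 4, 1, 2, 3, 3, 1, 2, 2))(_ == 1)
    println(groups.map(_.sum)) // List(10, 9, 5)
  }
}
```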

    Another way is to use groupAdjacentBy, which groups adjacent elements by the value of a function (here, the predicate):

    import fs2.Chunk
    
    val groupedOnceAndOthers: fs2.Stream[IO, (Boolean, Chunk[Int])] =
      s.groupAdjacentBy(x => x == 1)
    

    Here you get (predicate result, chunk) pairs:

    (true,Chunk(1)), (false,Chunk(2, 3, 4)), 
    (true,Chunk(1)), (false,Chunk(2, 3, 3)), 
    (true,Chunk(1)), (false,Chunk(2, 2))
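These pairs can be checked on a pure stream (no IO), since a Stream[Pure, A] can be turned into a List directly. A small sketch; the object name is hypothetical, and chunks are converted to Lists only to make the result easy to print and compare:

```scala
import fs2.Stream

object GroupAdjacentByDemo {
  // groupAdjacentBy emits (key, Chunk) pairs and starts a new pair
  // whenever the key (here: x == 1) changes.
  val pairs: List[(Boolean, List[Int])] =
    Stream(1, 2, 3, 4, 1, 2, 3, 3, 1, 2, 2)
      .groupAdjacentBy(x => x == 1)
      .map { case (key, chunk) => (key, chunk.toList) }
      .toList

  def main(args: Array[String]): Unit =
    pairs.foreach(println)
}
```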
    

    To join each group of 1s with the group that follows it, we can use chunkN(2, allowFewer = true) (similar to grouped on a Scala List), then map over each chunk to drop the Boolean keys and flatMap to flatten the inner chunks:

    val grouped = groupedOnceAndOthers
      .chunkN(2, allowFewer = true)
      .map(ch => ch.flatMap(_._2).toList)
    

    The resulting grouped stream emits List(1, 2, 3, 4), List(1, 2, 3, 3) and List(1, 2, 2).
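Pairing groups with chunkN(2) assumes the input starts with a 1 and that two 1s are never adjacent: for Stream(1, 1, 2), groupAdjacentBy yields a single (true, Chunk(1, 1)) group, so the pipeline produces List(1, 1, 2) instead of List(1) and List(1, 2). A possible variant, sketched below under the assumption that a new group starts at every element matching the predicate, numbers the groups with mapAccumulate and then groups adjacent elements by that number. The names are hypothetical, and like the original it still collects each group before emitting it.

```scala
import fs2.{Pure, Stream}

object GroupIdVariant {
  // Hypothetical helper: give every element a group number that is bumped
  // each time p holds, then group adjacent elements with the same number.
  // Elements before the first match get number 0 and form their own group.
  def groupUntil[F[_], A](s: Stream[F, A])(p: A => Boolean): Stream[F, List[A]] =
    s.mapAccumulate(0)((id, a) => (if (p(a)) id + 1 else id, a))
      .groupAdjacentBy(_._1)                             // (id, Chunk[(id, a)])
      .map { case (_, chunk) => chunk.toList.map(_._2) } // keep only the elements

  // The chunkN(2) pairing from the answer, applied to adjacent 1s, for comparison.
  val pairedVersion: List[List[Int]] =
    Stream(1, 1, 2)
      .groupAdjacentBy(_ == 1)
      .chunkN(2, allowFewer = true)
      .map(ch => ch.flatMap(_._2).toList)
      .toList

  def main(args: Array[String]): Unit = {
    val s: Stream[Pure, Int] = Stream(1, 1, 2)
    println(groupUntil(s)(_ == 1).toList) // List(List(1), List(1, 2))
    println(pairedVersion)                // List(List(1, 1, 2))
  }
}
```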

    Full working sample of the groupAdjacentBy approach:

    import cats.effect.IO
    import fs2.Stream
    
    val s = Stream(1, 2, 3, 4, 1, 2, 3, 3, 1, 2, 2).covary[IO]
    val grouped: Stream[IO, Stream[IO, Int]] = s.groupAdjacentBy(x => x == 1)
      // pair each (true, 1s) group with the (false, rest) group after it
      .chunkN(2, allowFewer = true)
      // drop the Boolean keys and wrap the elements in an inner stream
      .map(ch => Stream.chunk(ch.flatMap(_._2)).covary[IO])
    
    // evaluate each inner stream's sum as an effect, one group after another
    val result: Stream[IO, Int] = grouped.evalMap(_.compile.foldMonoid)
    assertEquals(result.compile.toList.unsafeRunSync(), List(10, 9, 5))
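Both approaches above only hand out a group after it has been fully read, which conflicts with the requirement to handle each element as it arrives. If what you ultimately need per group is a fold (as with the sums in the test), one way around this is to skip the inner streams and keep a running accumulator instead. The sketch below is an alternative rather than a groupUntil implementation: the helper name is hypothetical, it is specialised to Int sums, and zipWithNext adds a one-element lookahead so the code can tell when a group ends. It is generic in the effect type, so it applies to the Stream[IO, Int] from the question as well as to the pure stream used in the demo.

```scala
import fs2.{Pure, Stream}

object IncrementalGroupSums {
  // Hypothetical helper: add each element to a running sum and emit the sum
  // when the group ends, i.e. when the next element starts a new group or
  // the stream ends. Elements before the first group start form their own group.
  def groupSums[F[_]](s: Stream[F, Int])(startsGroup: Int => Boolean): Stream[F, Int] =
    s.zipWithNext // (element, Some(next)) or (element, None) for the last one
      .mapAccumulate(0) { case (acc, (x, next)) =>
        val sum = acc + x
        val groupEnds = next.forall(startsGroup)
        // reset the running sum after a completed group and emit its total
        (if (groupEnds) 0 else sum, if (groupEnds) Some(sum) else None)
      }
      .map(_._2)
      .unNone

  def main(args: Array[String]): Unit = {
    val s: Stream[Pure, Int] = Stream(1, 2, 3, 4, 1, 2, 3, 3, 1, 2, 2)
    println(groupSums(s)(_ == 1).toList) // List(10, 9, 5)
  }
}
```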