
Understanding for in Scala


I ran the following code to understand how for comprehensions work in Scala:

def f(): List[(Int, Int)] = {
  val out = for
    i <- 1 to 4
    p1 = println(s"i $i")
    j <- 2 to 3
    p2 = println(s"j $j")
  yield (i, j)
  println(out)
  out.toList
}

The output I see suggests that this may not be a straightforward nested loop:

output:

i 1
i 2
i 3
i 4
j 2
j 3
j 2
j 3
j 2
j 3
j 2
j 3
Vector((1,2), (1,3), (2,2), (2,3), (3,2), (3,3), (4,2), (4,3))

Similar behavior shows up with for loops (do instead of yield) as well:

def f(): Unit = {
  for
    i <- List(10, 20)
    p1 = println(s"i $i")
    j <- List(50, 60)
    p2 = println(s"i $i, j $j")
  do
    i + j
}

output:

i 10
i 20
i 10, j 50
i 10, j 60
i 20, j 50
i 20, j 60

So I don't think this comes down to a difference in how for expressions (with yield) and for loops (with do) are implemented, as was suggested in the comments.

The final Vector looks fine, but the inline printlns tell a different story. It is as if the outer loop runs to completion first, and then the inner loop runs once for each iteration of the outer loop. Can you help me understand what is happening here?


Solution

  • Preface: I realize I've written this as if the reader already knows that for comprehensions translate to chained flatMap/withFilter/map/foreach calls. I won't explain that here, because it's not the heart of the issue; instead, here is the official tour page for reference: https://docs.scala-lang.org/tour/for-comprehensions.html

    So, the confusing part here is how the println/assignment statements are desugared. The flatMap and map calls generated from the for comprehension execute the way you expect, as the final output shows.

    What happens is that each value definition (p1 = ..., p2 = ...) is folded into an extra map over the generator right before it, so it runs as part of producing that generator's elements, earlier than you might expect. The desugared code looks more like this:

    (1 to 4).map(i => {println(s"i $i"); i}).flatMap { i =>
      (2 to 3).map(j => {println(s"j $j"); j}).map { j =>
        (i, j)
      }
    }
    

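    Notice the extra map calls right after the ranges. (I'm cutting corners here: the real desugaring is more involved, but the extra .map is the main point I want to get across.) Because ranges are strict, (1 to 4).map(...) runs over the whole range, printing all four "i" lines, before flatMap even starts; each inner (2 to 3).map(...) then prints its two "j" lines. This code produces the same output as your for comprehension.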
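    One way to check this without eyeballing console output is the sketch below, which records each side effect in a ListBuffer instead of printing it and then compares the for comprehension with the hand-desugared chain. The names viaFor, viaChain and the log parameter are illustrative, not from the question; it assumes Scala 3 syntax and can be pasted into the REPL or a scala-cli script.

```scala
import scala.collection.mutable.ListBuffer

// The question's for comprehension, logging to a buffer instead of printing.
def viaFor(log: ListBuffer[String]): IndexedSeq[(Int, Int)] =
  for
    i <- 1 to 4
    p1 = log.append(s"i $i")
    j <- 2 to 3
    p2 = log.append(s"j $j")
  yield (i, j)

// The hand-desugared chain (extra map after each range), with the same logging.
def viaChain(log: ListBuffer[String]): IndexedSeq[(Int, Int)] =
  (1 to 4).map { i => log.append(s"i $i"); i }.flatMap { i =>
    (2 to 3).map { j => log.append(s"j $j"); j }.map { j => (i, j) }
  }

val forLog = ListBuffer.empty[String]
val chainLog = ListBuffer.empty[String]
println(viaFor(forLog) == viaChain(chainLog)) // true: same pairs
println(forLog == chainLog)                   // true: same side-effect order
println(forLog.mkString(", "))
// i 1, i 2, i 3, i 4, j 2, j 3, j 2, j 3, j 2, j 3, j 2, j 3
```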

    This is as if you had written your for comprehension this way (again, not accurate, but a useful mental model):

    for
      i <- (1 to 4).tapEach(i => println(s"i $i"))
      j <- (2 to 3).tapEach(j => println(s"j $j"))
    yield (i, j)
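    The tapEach model can be checked the same way. In this sketch (viaTapEach is a hypothetical name), tapEach is the standard Scala 2.13+ collection method that applies a function to every element and returns the elements unchanged; on a strict Range it does so eagerly, so the log should match the original for comprehension.

```scala
import scala.collection.mutable.ListBuffer

// tapEach runs the side effect on every element and returns the elements unchanged.
// On a strict Range this happens eagerly, just like the extra map in the desugaring.
def viaTapEach(log: ListBuffer[String]): IndexedSeq[(Int, Int)] =
  for
    i <- (1 to 4).tapEach(x => log.append(s"i $x"))
    j <- (2 to 3).tapEach(x => log.append(s"j $x"))
  yield (i, j)

val tapLog = ListBuffer.empty[String]
println(viaTapEach(tapLog))
println(tapLog.mkString(", "))
// i 1, i 2, i 3, i 4, j 2, j 3, j 2, j 3, j 2, j 3, j 2, j 3
```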
    

    I should also note that another answer claimed that the desugaring would look like so:

    // This is incorrect!!!
    (1 to 4).flatMap { i =>
      val p1 = println(s"i $i")
      (2 to 3).map { j =>
        val p2 = println(s"j $j")
        (i, j)
      }
    }
    

    This, while sometimes a useful way to think about the desugaring, is incorrect, and it would produce the interleaved output you were probably originally expecting rather than the output you actually got.
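    For contrast, here is that nested shape with the same buffer-based logging (nestedBlocks is an illustrative name). It builds the same pairs, but the i and j entries come out interleaved, which is the nested-loop order the question was expecting.

```scala
import scala.collection.mutable.ListBuffer

// The nested shape: each side effect sits inside the function body for its generator,
// so "i" is logged right before that iteration's inner range is processed.
def nestedBlocks(log: ListBuffer[String]): IndexedSeq[(Int, Int)] =
  (1 to 4).flatMap { i =>
    log.append(s"i $i")
    (2 to 3).map { j =>
      log.append(s"j $j")
      (i, j)
    }
  }

val nestedLog = ListBuffer.empty[String]
println(nestedBlocks(nestedLog)) // same pairs as the for comprehension
println(nestedLog.mkString(", "))
// i 1, j 2, j 3, i 2, j 2, j 3, i 3, j 2, j 3, i 4, j 2, j 3
```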

    As a final note, if you'd like to see the desugaring for yourself, start the REPL with scala -Xprint:typer. The output is a lot noisier than what I've posted, but it shows roughly the same idea.

    Without the printlns, you probably wouldn't have found anything surprising. The moral of the story: be careful with side effects inside for comprehensions ;)
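    A related detail, offered as a sketch rather than part of the original answer: all the "i" lines appear up front only because Range is strict, so the extra map finishes before flatMap starts. If the outer generator is made lazy, here by calling .iterator on the range (a choice made for illustration), the same desugaring pulls one outer element at a time and the log interleaves. The name lazyOuter is hypothetical.

```scala
import scala.collection.mutable.ListBuffer

// Same comprehension as the question, but the outer generator is a lazy Iterator.
def lazyOuter(log: ListBuffer[String]): List[(Int, Int)] =
  val pairs =
    for
      i <- (1 to 4).iterator
      p1 = log.append(s"i $i")
      j <- 2 to 3
      p2 = log.append(s"j $j")
    yield (i, j)
  // Nothing has been logged yet: the iterator is only consumed by toList.
  pairs.toList

val lazyLog = ListBuffer.empty[String]
println(lazyOuter(lazyLog))
println(lazyLog.mkString(", "))
// i 1, j 2, j 3, i 2, j 2, j 3, i 3, j 2, j 3, i 4, j 2, j 3
```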

    Addendum:

    Here's the full REPL output when you run your for comprehension with -Xprint:typer:

    scala> for
         |   i <- 1 to 4
         |   p1 = println(s"i $i")
         |   j <- 2 to 3
         |   p2 = println(s"j $j")
         | yield (i,j)
    i 1
    i 2
    i 3
    i 4
    j 2
    j 3
    j 2
    j 3
    j 2
    j 3
    j 2
    j 3
    [[syntax trees at end of                     typer]] // rs$line$1
    package <empty> {
      final lazy module val rs$line$1: rs$line$1 = new rs$line$1()
      final module class rs$line$1() extends Object() { this: rs$line$1.type =>
        val res0: IndexedSeq[(Int, Int)] =
          intWrapper(1).to(4).map[(Int, Unit)](
            {
              def $anonfun(i: Int): (Int, Unit) =
                {
                  val p1: Unit = println(_root_.scala.StringContext.apply(["i ","" : String]*).s([i : Any]*))
                  Tuple2.apply[Int, Unit](i, p1)
                }
              closure($anonfun)
            }
          ).flatMap[(Int, Int)](
            {
              def $anonfun(x$1: (Int, Unit)): IterableOnce[(Int, Int)] =
                x$1:(x$1 : (Int, Unit)) @unchecked match
                  {
                    case Tuple2.unapply[Int, Unit](i @ _, p1 @ _) =>
                      intWrapper(2).to(3).map[(Int, Unit)](
                        {
                          def $anonfun(j: Int): (Int, Unit) =
                            {
                              val p2: Unit = println(_root_.scala.StringContext.apply(["j ","" : String]*).s([j : Any]*))
                              Tuple2.apply[Int, Unit](j, p2)
                            }
                          closure($anonfun)
                        }
                      ).map[(Int, Int)](
                        {
                          def $anonfun(x$1: (Int, Unit)): (Int, Int) =
                            x$1:(x$1 : (Int, Unit)) @unchecked match
                              {
                                case Tuple2.unapply[Int, Unit](j @ _, p2 @ _) => Tuple2.apply[Int, Int](i, j)
                              }
                          closure($anonfun)
                        }
                      )
                  }
              closure($anonfun)
            }
          )
      }
    }
    
    val res0: IndexedSeq[(Int, Int)] = Vector((1,2), (1,3), (2,2), (2,3), (3,2), (3,3), (4,2), (4,3))