haskell · functional-programming · thunk

How are Haskell's thunks so efficient?


The following Haskell implementation of Fibonacci runs in linear time:

fib = 0 : 1 : zipWith (+) fib (tail fib)

From what I understand, the fib references here are thunks, so they are evaluated progressively and lazily.

However, what really surprises me is that it runs in linear time. Coming from a procedural language, when I look at all the inner fib calls, I'd expect them to recurse independently and take roughly exponential time.

I've tried a somewhat equivalent version in Ruby:

def fib
  Enumerator.new do |yielder|
    yielder.to_proc.call(0)
    yielder.to_proc.call(1)

    gen_a = fib
    gen_b = fib
    gen_b.next()

    while true
      yielder.to_proc.call(gen_a.next() + gen_b.next())
    end
  end
end

gen = fib
20.times { puts(gen.next()) }

and Python:

def fib():
    yield 0
    yield 1

    lhs = fib()
    rhs = fib()
    rhs.__next__()

    while True:
        yield (lhs.__next__() + rhs.__next__())


gen = fib()
for _ in range(0, 20):
    print(gen.__next__())

They are both lazy and run in exponential time, as I'd expect.

It's possible that my implementation is wrong, but I can't help wondering whether thunks leverage the immutability of pure functions to store values for reuse, and maybe that's why Haskell's execution can skip most of the recursive calls.

Is this true? Or am I totally wrong? Is there any good material to read about this?


Solution

    def fib
      Enumerator.new do |yielder|
        yielder.to_proc.call(0)
        yielder.to_proc.call(1)
    
        gen_a = fib
        gen_b = fib
        gen_b.next()
    
        while true
          yielder.to_proc.call(gen_a.next() + gen_b.next())
        end
      end
    end
    

    Here you're defining a function that returns an enumerator, and that function is called twice from inside the enumerator's block, creating a new, independent enumerator each time. Clearly this has exponential runtime. The Haskell version doesn't do that: it creates a single list and no functions.

    To be equivalent to the Haskell version, your Ruby version would have to get rid of the method and instead define fib as fib = Enumerator.new do ... end. However, it's not quite as easy because if we just do that, we end up with this code (which won't work):

    fib = Enumerator.new do |yielder|
      yielder << 0
      yielder << 1
    
      gen_a = fib
      gen_b = fib
      gen_b.next()
    
      while true
        yielder << gen_a.next() + gen_b.next()
      end
    end
    

    So why doesn't this work?

    1. Because now gen_a = fib and gen_b = fib aren't function calls anymore; instead they bind gen_a and gen_b to the same enumerator. That is closer to the Haskell version, which doesn't have multiple lists/enumerators either, but the problem is that, unlike Haskell's lists, enumerators are mutable. So calling gen_b.next() now also advances gen_a, and your entire logic breaks; and
    2. Because enumerators don't allow you to access the enumerator being created while creating it.
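
    The first point is easy to demonstrate with a plain Ruby enumerator. Two variables bound to the same Enumerator share its iteration state, so advancing it through one name advances it through the other as well (the array and values here are just for illustration):

```ruby
# Two names, one enumerator: Enumerator#next mutates shared internal state.
e = [10, 20, 30].each
gen_a = e
gen_b = e          # gen_a and gen_b reference the SAME enumerator object

gen_a.next         # => 10
gen_b.next         # => 20, not 10: gen_a's call already consumed the 10
```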

    To get something that works like the Haskell example, you can define a Ruby version of a lazy list like this:

    class LazyList
      attr_reader :head
    
      def initialize(head, &blk)
        @head = head
        @tail = :unevaluated
        @tail_blk = blk
      end
    
      def tail
        if @tail == :unevaluated
          @tail = @tail_blk[]
        end
        @tail
      end
    
      def take(n)
        result = []
        list = self
        while n > 0 && list
          result << list.head
          list = list.tail
          n -= 1
        end
        result
      end
    
      def zip_with(other_list, &blk)
        LazyList.new(blk[head, other_list.head]) do
          tail.zip_with(other_list.tail, &blk)
        end
      end
    end
    
    fib = LazyList.new(0) do
      LazyList.new(1) do
        fib.zip_with(fib.tail, &:+)
      end
    end
    
    p fib.take(42)
    

    This runs in linear time (if we pretend integer operations are O(1), anyway), just like the Haskell version.
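
    To make the linear-versus-exponential difference concrete, you can count additions directly. The sketch below (naive_fib and memo_fib are hypothetical names, not part of the code above) contrasts plain double recursion with a version that caches each result, which is essentially what the memoized tail in LazyList does:

```ruby
# Count additions: naive double recursion recomputes subproblems,
# while caching computes each fib(n) exactly once.
$adds = 0
def naive_fib(n)
  return n if n < 2
  $adds += 1
  naive_fib(n - 1) + naive_fib(n - 2)
end

$memo_adds = 0
def memo_fib(n, cache = {})
  return n if n < 2
  cache[n] ||= begin
    $memo_adds += 1
    memo_fib(n - 1, cache) + memo_fib(n - 2, cache)
  end
end

naive_fib(20)  # performs 10945 additions ($adds == 10945)
memo_fib(20)   # performs only 19 additions ($memo_adds == 19)
```

    The cached version does one addition per index from 2 to 20, which is exactly the work the LazyList version does when each cell is computed once and then reused.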


    PS: Since you've specifically asked about thunks, here's another version that uses explicit thunks just to show that thunks aren't magical:

    class Thunk
      def initialize(&blk)
        @value = :unevaluated
        @blk = blk
      end
    
      def get
        if @value == :unevaluated
          @value = @blk[]
        end
        @value
      end
    end
    
    class LazyList
      attr_reader :head
    
      def initialize(head, tail)
        @head = head
        @tail = tail
      end
    
      def tail
        @tail.get
      end
    
      def take(n)
        result = []
        list = self
        while n > 0 && list
          result << list.head
          list = list.tail
          n -= 1
        end
        result
      end
    
      def zip_with(other_list, &blk)
        LazyList.new(blk[head, other_list.head], Thunk.new do
          tail.zip_with(other_list.tail, &blk)
        end)
      end
    end
    
    fib = LazyList.new(0, Thunk.new do
      LazyList.new(1, Thunk.new do
        fib.zip_with(fib.tail, &:+)
      end)
    end)
    
    p fib.take(42)
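
    As a final check that memoization really is the whole trick: a thunk is nothing more than a zero-argument block plus a "has it run yet?" flag. This closure-based sketch (make_thunk is a hypothetical stand-in for the Thunk class above) shows the block's body executing exactly once, no matter how often the value is demanded:

```ruby
# A thunk as a plain closure: the block runs on first demand,
# and every later demand returns the cached value.
def make_thunk(&blk)
  done = false
  value = nil
  -> do
    unless done
      value = blk.call   # run the block on first demand only
      done = true
    end
    value                # cached result on every later demand
  end
end

runs = 0
t = make_thunk { runs += 1; 21 * 2 }
t.call   # => 42, runs the block
t.call   # => 42, cached; runs is still 1
```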