为什么Clojure的for loop不会栈溢出

Clojure中的for loop是用macro实现的, 我发现展开之后的代码中是包含递归的, 例如下面的代码,

 
(for [a (range 500) :when (odd? a)] (* 2 a))
 

展开之后

 
(clojure.core/fn iter__10202 [s__10203]
  (clojure.core/lazy-seq
    (clojure.core/loop [s__10203 s__10203]
      (clojure.core/when-let [s__10203 (clojure.core/seq s__10203)]
        (if (clojure.core/chunked-seq? s__10203)
          ;; we can skip this, for (range x)
          (clojure.core/let [c__4633__auto__ (clojure.core/chunk-first s__10203) size__4634__auto__ (clojure.core/int (clojure.core/count c__4633__auto__)) b__10205 (clojure.core/chunk-buffer size__4634__auto__)] (if (clojure.core/loop [i__10204 (clojure.core/int 0)] (if (clojure.core/< i__10204 size__4634__auto__) (clojure.core/let [a (.nth c__4633__auto__ i__10204)] (if (odd? a) (do (clojure.core/chunk-append b__10205 (* 2 a)) (recur (clojure.core/unchecked-inc i__10204))) (recur (clojure.core/unchecked-inc i__10204)))) true)) (clojure.core/chunk-cons (clojure.core/chunk b__10205) (iter__10202 (clojure.core/chunk-rest s__10203))) (clojure.core/chunk-cons (clojure.core/chunk b__10205) nil)))
 
          (clojure.core/let [a (clojure.core/first s__10203)]
            (if (odd? a) (clojure.core/cons (* 2 a) (iter__10202 (clojure.core/rest s__10203))) (recur (clojure.core/rest s__10203))))
 
)))))
 

chunked sequence的部分可以先不理会, 如果我们传入的参数是(range 500)的话, 会走if语句的else部分. 这里面有个明显的递归, 因此我猜测如果range足够大, for一定会溢出, 但是实验结果表明, 将range的值加到很大例如100000也不会发生溢出的现象.

然后我写了一个简化版, 从递归的角度看效果是一样的

 
(defn iter [l]
  (if (empty? l) nil (let [a (first l)] (cons (+  2 a ) (iter (rest l)))))
)
 

这个的话就会很快溢出.

那么是不是和开头的lazy-seq, when-let, seq之类的有关呢? 再写一个版本, 和原版的for loop更加靠近

 
(defn for-iter [l]
  (lazy-seq
    (loop [loop-l l]
      (when-let [loop-l (seq loop-l)]
        (let [a (first loop-l)]
          (cons (+ 2 a) (for-iter (rest loop-l)))))))
)
 

这回不会溢出了.

这里面还是有很多噪音, 再剔除掉一些干扰

 
(defn iter [l]
  (lazy-seq
  (if (empty? l) nil (let [a (first l)] (cons (+  2 a ) (iter (rest l)))))
  )
)
 

这个版本同样不会溢出. 所以答案非常明了了, 原因就在lazy-seq. 仅仅是在表达式外层加上一个lazy-seq就可以将一个可能溢出的递归变成不会溢出的.

我们来看看加了lazy-seq之后macro展开的效果

 
user> (macroexpand-1 '(lazy-seq
  (if (empty? l) nil (let [a (first l)] (cons (+  2 a ) (iter (rest l)))))
  ))
 
(new clojure.lang.LazySeq
  (fn* []
    (if (empty? l) nil (let [a (first l)] (cons (+ 2 a) (iter (rest l)))))))
 

我们的表达式被封装在一个匿名函数当中了, 实际就是一个闭包, 并作为参数用来构造clojure.lang.LazySeq 对象. 因此这个调用是不会在iter函数返回之前执行的, 而是在返回之后, 那个时候调用已经不会增加栈的分配.

而原本需要栈来保存的数据, 保存在了闭包当中, 这些数据实际是分配在heap当中, 换句话说, 不是消耗栈而是改为消耗堆.

除了不消耗栈之外, lazy-seq还可以实现按需执行, 没有lazy-seq之前, 所有的递归是一次性执行完的, 加了lazy-seq之后, 这些递归对应每一次的闭包函数调用, 而这些调用是不需要一次性全部执行完的, 而是取一个元素就执行一次闭包, 直到取完为止, 不取则不计算.

不仅是for loop, Clojure当中常常使用map同样是用这个原理实现的.

Update: 什么时候for loop会溢出

上次说到了for loop不会溢出的原因, 是因为LazySeq并没有消耗栈, 现在来谈谈什么时候for 会导致栈溢出. 这部分内容是受到这篇文章的启发Clojure: lazy-seq and the StackOverflowException.

先梳理一下这篇文章的内容, 大意就是嵌套的lazy sequence的问题. 即如果对一个已经是lazy的sequence再次应用lazify是什么后果

 
(defn lazify [xs]
  (map identity xs)
)
 
(lazify (lazify (lazify (lazify [1 2 3 4 5 ]))))
 

这里重点是每一层lazify调用都会产生一个新的对象, 前面说到了map和for的原理一样是使用了LazySeq对象的, 而这个对象的最重要的成员就是一个闭包, 这些闭包之间有一个隐含的触发关系, 即前一个闭包的执行触发后一个闭包的执行, 而后一个闭包返回之前, 前一个闭包都要保留栈, 这样就形成了栈的消耗, 因此这个链条不能太长, 否则一定会栈溢出. 下面的函数可以创建长度为n的链条

 
(defn lazify-n [n seq]
  (loop [n n seq seq]
    (if ( > n 0)
      (recur (- n 1) (lazify seq))
      seq
    )
  )
)
 
 

长度到达一定数值的时候, 会溢出

 
user> (lazify-n 4000 [1 2 3 4 5])
StackOverflowError   clojure.core/seq (core.clj:133)
 

而for loop其实也隐含了这样的嵌套lazy-seq, 例如

 
(for [
      a [1 2]
      b [1 2 3]
     ]
  (* a b)
)
 

包含了两层

 
(lazy-seq
     .....
     (lazy-seq  
          ..... 
          [1 2 3]
          .....)
     [1 2 ]
)     
 

看看宏展开的结果就清楚了

 
(macroexpand
 '(for [
        a [1 2]
        b [1 2 3]
       ]
    (* a b)
  )
)
 

只要绑定的变量超过一个值, 就会溢出. 当然我们不会直接用手写, 那样太费事, 不过幸好我们有macro, 可以批量产生变量然后在用macro的形式嵌入到for的参数里面, 下面用前缀a加数字的方式给绑定变量命名, 然后每个变量绑定到向量[1 2]上面, 实际就是构造如下的vector:

 
[
a1 [1 2]
a2 [1 2]
a3 [1 2]
...
an [1 2]
]
 

可以用下面的代码做到

 
(defn generate-symbols [n]
  (for [a (range n)]
    (symbol (str "a" a))
  )
)
 
(defn implode [v sep]
  (conj (vec (interpose sep v)) sep)
)
 
(implode (generate-symbols 400) [1])
 
 

然后用macro将这个vector设置为for的参数

 
(defmacro test-for [n]
  `(for ~(implode (generate-symbols n) [1])
    nil
  )
)
 
user> (test-for 400)
CompilerException java.lang.StackOverflowError, compiling:(c:\bin\init.clj:1:1) 
 

实际测试的时候, 我们用的向量是[1], 因为如果是[1 2]的话, 那么循环次数将会是2的400次方, 即所有绑定向量的长度的乘积, JVM会直接挂掉, 实际上2的40次方足以让JVM崩溃.