当前位置:  开发笔记 > 编程语言 > 正文

"Scala编程"中的合并排序会导致堆栈溢出

如何解决《"Scala编程"中的合并排序会导致堆栈溢出》经验,为你挑选了2个好方法。

直接剪切和粘贴以下算法:

def msort[T](less: (T, T) => Boolean)
            (xs: List[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys
      case (_, Nil) => xs
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) x :: merge(xs1, ys)
        else y :: merge(xs, ys1)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
     merge(msort(less)(ys), msort(less)(zs))
  }
}

导致5000个长列表上的StackOverflowError.

有没有办法优化这个,以便不会发生这种情况?



1> Daniel C. So..:

这样做是因为它不是尾递归的.您可以通过使用非严格集合或使其尾递归来解决此问题.

后一个解决方案是这样的:

def msort[T](less: (T, T) => Boolean) 
            (xs: List[T]): List[T] = { 
  def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] = 
    (xs, ys) match { 
      case (Nil, _) => ys.reverse ::: acc 
      case (_, Nil) => xs.reverse ::: acc
      case (x :: xs1, y :: ys1) => 
        if (less(x, y)) merge(xs1, ys, x :: acc) 
        else merge(xs, ys1, y :: acc) 
    } 
  val n = xs.length / 2 
  if (n == 0) xs 
  else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs), Nil).reverse
  } 
} 

使用非严格性涉及按名称传递参数,或使用非严格的集合,如Stream.以下代码Stream仅用于防止堆栈溢出,以及List其他地方:

def msort[T](less: (T, T) => Boolean) 
            (xs: List[T]): List[T] = { 
  def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match {
    case (x :: xs, y :: ys) if less(x, y) => Stream.cons(x, merge(xs, right))
    case (x :: xs, y :: ys) => Stream.cons(y, merge(left, ys))
    case _ => if (left.isEmpty) right.toStream else left.toStream
  }
  val n = xs.length / 2 
  if (n == 0) xs 
  else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs)).toList
  } 
}


我认为这里值得一提的是,msort本身不是尾递归,而是合并.对于只被编译器说服的人来说,将@tailrec添加到merge的定义中,你会发现它被接受为尾递归函数,就像Daniel概述的那样.

2> timday..:

只是玩scala TailCalls(蹦床支持),我怀疑这个问题最初提出时并不存在.这是Rex答案中合并的递归不可变版本.

import scala.util.control.TailCalls._

def merge[T <% Ordered[T]](x:List[T],y:List[T]):List[T] = {

  def build(s:List[T],a:List[T],b:List[T]):TailRec[List[T]] = {
    if (a.isEmpty) {
      done(b.reverse ::: s)
    } else if (b.isEmpty) {
      done(a.reverse ::: s)
    } else if (a.head

List[Long]在64位OpenJDK(在i7上的Debian/Squeeze amd64)上运行Scala 2.9.1 上的大s上的可变版本的速度和运行速度一样快.

推荐阅读
coco2冰冰
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有