😽

Scalaによる再帰/末尾再帰処理の基礎

2021/02/24に公開

はじめに

関数プログラミングことはじめ in 福岡を使ってScalaでの再帰処理の基礎のメモ。
Qrunchに載せてたけどサービス終了してしまったので、今回ローカルに保存していたのを見返す機会があったので記事にしておく。

末尾再帰とは

再帰呼び出しの最後から最初の呼び出しに戻ることなく、
最後の呼び出し時点で結果が求まるもの。
呼び出しを戻る必要がないため呼び出し元をメモリに保存する必要が無く
コンパイラによりループに最適化される?

1からnまでの自然数の和を求めよ

  • 再帰
def sumN(n:Int):BigInt = n match {
  case 1 => 1
  case n => n + sumN(n-1)
}
  • 末尾再帰
def sumN(n:Int):BigInt = {
  def loop(acc:BigInt,n:Int):BigInt = n match {
    case 1 => acc
    case n => loop(acc + n,n-1)
  }
  loop(0,n)
}

自然数nの階乗を求めよ

  • 再帰
def factorial(n :Int) :BigInt = n match {
  case 0 => 1
  case n => n * factorial(n-1)
}
  • 末尾再帰
def factorial(n: Int): BigInt = {
  def loop(acc:BigInt,n: Int): BigInt = n match {
    case 0 => acc
    case n => loop(acc * n,n-1)
  }
  loop(1,n)
}

フィボナッチ数列のn番目の値を求めよ

フィボナッチ数列とは前2つの値の和が自分の値になる数列。
n番目の値 = n-1番目の値 + n-2番目の値
f(0) = 0,f(1) = 1とする。

  • 再帰
def fibonacci(n:Int): BigInt = n match {
  case 0 => 0
  case 1 => 1
  case n => fibonacci(n-1) + fibonacci(n-2)
}
  • 末尾再帰
def fibonacci(n:Int): BigInt = {
  def loop(acc0:BigInt,acc1:BigInt,n:Int):BigInt = n match {
    case 0 => acc0
    case 1 => acc1
    case n => loop(acc1,acc0 + acc1,n-1)
  }
  loop(0,1,n)
}

リスト内の数の合計値を求めよ

  • 再帰
def sum(ints:List[Int]):Int = ints match {
  case Nil => 0
  case head::tail => head + sum(tail)
}
  • 末尾再帰
def sum(ints:List[Int]):Int = {
  def loop(acc:Int,lst:List[Int]):Int = lst match {
    case Nil => acc
    case head::tail => loop(acc + head,tail)
  }
  loop(0,ints)
}

リスト内の全ての数を掛け合わせた数を求めよ

  • 再帰
def product(ints:List[Int]):BigInt = ints match {
  case Nil => 1
  case head::tail => head * product(tail)
}
  • 末尾再帰
def product(ints:List[Int]):BigInt = {
  def loop(acc:BigInt,lst:List[Int]):BigInt = lst match {
    case Nil => acc
    case head::tail => loop(acc * head,tail)
  }
  loop(1,ints)
}

リスト内の最大値を求めよ

  • 再帰
def max(ints: List[Int]): Int = ints match {
  case Nil => 0
  case head::tail => {
    val maxValue = max(tail)
    if (head > maxValue) head else maxValue
  }
}
  • 末尾再帰
def max(ints: List[Int]): Int = {
  def loop(acc: Int, lst: List[Int]): Int = lst match {
    case Nil => acc
    case head::tail => {
      val maxValue = if (head > acc) head else acc
      loop(maxValue,tail)
    }
  }
  ints match {
    case Nil => 0
    case ints => loop(ints.head,ints.tail)
  }
}

リストを逆順にしたものを求めよ

  • 再帰
def reverse(ints: List[Int]): List[Int] = ints match {
  case Nil => Nil
  case head::tail => reverse(tail) ::: List(head)
}
  • 末尾再帰
def reverse(ints: List[Int]): List[Int] = {
  def loop(acc: List[Int],lst: List[Int]): List[Int] = lst match {
    case Nil => acc
    case head::tail => loop(head :: acc,tail)
  }
  loop(Nil,ints)
}

リストの長さを求めよ

  • 再帰
def length(ints: List[Int]): Int = ints match {
  case Nil => 0
  case _::tail => 1 + length(tail)
}
  • 末尾再帰
def length(ints: List[Int]): Int = {
  def loop(acc: Int,lst: List[Int]): Int = lst match {
    case Nil => acc
    case _::tail => loop(acc + 1,tail)
  }
  loop(0,ints)
}

二分木のデータ構造

二分木の各ノードは値と左右の子要素を持つ

// sealed abstract class Tree(trait Treeの代わりにこっちでもOK)
trait Tree
case class Node(value: Int,left: Tree,right: Tree) extends Tree
case object Empty extends Tree

二分木の数の合計を求めよ

  • 再帰
def sum(tree: Tree):Int = tree match {
  case Empty => 0
  case Node(x,left,right) => x + sum(left) + sum(right)
}
  • 末尾再帰
def sum(tree: Tree): Int = {
  def loop(tree: Tree,next: List[Tree],acc: Int):Int = tree match {
    case Empty => {
      next match {
        case Nil => acc
        case head::tail => loop(head,tail,acc)
      }
    }
    case Node(x,left,right) => loop(left,right :: next,acc + x)
  }
  tree match {
    case Empty => 0
    case Node(x,left,right) => loop(left,List(right),x)
  }
}

二分木内の数の最大値を求めよ

  • 再帰
def max(tree: Tree):Int = tree match {
  case Empty => 0
  case Node(x,left,right) => {
    val maxLeft = max(left)
    val maxRight = max(right)
    val maxChild = if (maxLeft > maxRight) maxLeft else maxRight
    if (x > maxChild) x else maxChild
  }
}
  • 末尾再帰
def max(tree: Tree): Int = {
  def loop(tree: Tree,next: List[Tree],acc: Int):Int = tree match {
    case Empty => {
      next match {
        case Nil => acc
        case head::tail => loop(head,tail,acc)
      }
    }
    case Node(x,left,right) => {
      val maxValue = if (x > acc) x else acc
      loop(left,right :: next,maxValue)
    }
  }
  tree match {
    case Empty => 0
    case Node(x,left,right) => loop(left,List(right),x)
  }
}

二分木のノードの数を求めよ

  • 再帰
def size(tree: Tree): Int = tree match {
  case Empty => 0
  case Node(_,left,right) => 1 + size(left) + size(right)
}
  • 末尾再帰
def sizet(tree: Tree): Int = {
  def loop(tree: Tree,next: List[Tree],acc: Int):Int = tree match {
    case Empty => {
      next match {
        case Nil => acc
        case head::tail => loop(head,tail,acc)
      }
    }
    case Node(_,left,right) => loop(left,right :: next,acc + 1)
  }
  tree match {
    case Empty => 0
    case Node(_,left,right) => loop(left,List(right),1)
  }
}

二分木の末端のノードまでの最長パスを求めよ(最大の深さ)

  • 再帰
def depth(tree: Tree): Int = tree match {
  case Empty => 0
  case Node(_,left,right) => {
    val depthLeft = depth(left)
    val depthRight = depth(right)
    val depthChild = if (depthLeft > depthRight) depthLeft else depthRight
    1 + depthChild
  }
}
  • 末尾再帰
def depth(tree: Tree): Int = {
  def loop(tree: (Tree,Int),next: List[(Tree,Int)],acc: Int): Int = tree match {
    case (Empty,_) => next match {
      case Nil => acc
      case head::tail => {
        val maxDepth = if (head._2 > acc) head._2 else acc
        loop(head,tail,maxDepth)
      }
    }
    case (Node(_,left,right),d) => loop((left,d+1),(right,d+1)::next,maxDepth)
  }
  tree match {
    case Empty => 0 
    case Node(_,left,right) => loop((left,1),List((right,1)),0)
  }
}

参考文献

Discussion