Scala で diff を書いてみた

練習問題ということで。

うまくないところがたくさんあると思うので、添削してくださる方がいるとうれしいなぁ。

import scala.io.Source.fromFile

class LCS[T] (a: Seq[T], b: Seq[T]) {
  type LCSPos = Tuple3[T, int, int]

  val a_ary: Array[T] = a.toArray
  val b_ary: Array[T] = b.toArray

  def positions: Array[LCSPos] = {
    val maxD = a_ary.length + b_ary.length
    if (maxD == 0)
      Array()
    else {
      val v_d: Array[int] = Array.make(maxD * 2 + 1, 0)
      val v_l: Array[int] = Array.make(maxD * 2 + 1, 0)
      val v_r: Array[List[LCSPos]] = Array.make(maxD * 2 + 1, List())

      def finish() = {
        var maxL: int = 0
        var r: List[LCSPos] = null

        for (i <- -maxD to maxD) {
          if (v_l(i + maxD) > maxL) {
            maxL = v_l(i + maxD)
            r = v_r(i + maxD)
          }
        }
        r.reverse.toArray
      }

      var d, x, y: int = 0

      while (d <= maxD && (x < a_ary.length || y < b_ary.length)) {
        var k: int = -d
        var l: int = 0
        var r: List[LCSPos] = null

        x = 0
        y = 0

        while (k <= d && (x < a_ary.length || y < b_ary.length)) {
          if (k == -d || k != d && v_d(k - 1 + maxD) < v_d(k + 1 + maxD)) {
            x = v_d(k + 1 + maxD)
            y = x - k
            l = v_l(k + 1 + maxD)
            r = v_r(k + 1 + maxD)
          }
          else {
            x = v_d(k - 1 + maxD) + 1
            y = x - k
            l = v_l(k - 1 + maxD)
            r = v_r(k - 1 + maxD)
          }
          while (x < a_ary.length && y < b_ary.length && a_ary(x) == b_ary(y)) {
            r = (a_ary(x), x, y)::r
            x = x + 1
            y = y + 1
            l = l + 1
          }
          v_d(k + maxD) = x
          v_l(k + maxD) = l
          v_r(k + maxD) = r

          k = k + 2
        }

        d = d + 1
      }

      finish()
    }
  }

  def foldLeft[U](a_only: (U, T) => U,
                  b_only: (U, T) => U,
                  both: (U, T) => U,
                  seed: U):U = {
    def loop(common: Array[LCSPos], seed: U,
             a: Array[T], a_pos: int, b: Array[T], b_pos: int):U =
      if (common.isEmpty)
        b.foldLeft(a.foldLeft(seed)(a_only))(b_only)
      else
        common(0) match {
          case (elt, a_off, b_off) => {
            val a_skip = a_off - a_pos
            val b_skip = b_off - b_pos
            val a_head = a.take(a_skip)
            val a_tail = a.drop(a_skip + 1)
            val b_head = b.take(b_skip)
            val b_tail = b.drop(b_skip + 1)
            loop(common.drop(1),
                 both(b_head.foldLeft(a_head.foldLeft(seed)(a_only))(b_only),
                      elt),
                 a_tail, a_off + 1, b_tail, b_off + 1)
          }
        }
    loop(positions, seed, a_ary, 0, b_ary, 0)
  }
}

object Diff {
  def main(args: Array[String]) {
    new LCS(fromFile(args(0)).getLines.toList,
            fromFile(args(1)).getLines.toList).foldLeft(
              (u:Unit, oldline) => {print("- " + oldline)},
              (u:Unit, newline) => {print("+ " + newline)},
              (u:Unit, comline) => {print("  " + comline)},
              {})
  }
}

Gauche の util.lcs を移植したつもりなんだけど、なんでこんなに while だらけになるんだろう…? (while 書いたら負けって意識に問題があるんだろーか?)

あと、util.lcs を読むのに ViVi の作者さんによるアルゴリズムの解説、•¶‘”äŠridiffjƒAƒ‹ƒSƒŠƒYƒ€ を参考にさせていただきました。

ファイルを 1 行ずつ処理する方法が分からなかったのは、http://ja.doukaku.org/66/lang/scala/ を参考にさせていただきました。Scala Standard Library 2.12.8 の見方がなかなか分からなくて困った。

追記

Rainy Day Codings の中の人に添削というか、模範解答を示していただきました。このコードとは全然べつもんです。トラックバック元をご覧下さい。

…きれーに書けるもんだなぁ。