Scala で diff を書いてみた
練習問題ということで。
うまくないところがたくさんあると思うので、添削してくださる方がいるとうれしいなぁ。
import scala.io.Source.fromFile class LCS[T] (a: Seq[T], b: Seq[T]) { type LCSPos = Tuple3[T, int, int] val a_ary: Array[T] = a.toArray val b_ary: Array[T] = b.toArray def positions: Array[LCSPos] = { val maxD = a_ary.length + b_ary.length if (maxD == 0) Array() else { val v_d: Array[int] = Array.make(maxD * 2 + 1, 0) val v_l: Array[int] = Array.make(maxD * 2 + 1, 0) val v_r: Array[List[LCSPos]] = Array.make(maxD * 2 + 1, List()) def finish() = { var maxL: int = 0 var r: List[LCSPos] = null for (i <- -maxD to maxD) { if (v_l(i + maxD) > maxL) { maxL = v_l(i + maxD) r = v_r(i + maxD) } } r.reverse.toArray } var d, x, y: int = 0 while (d <= maxD && (x < a_ary.length || y < b_ary.length)) { var k: int = -d var l: int = 0 var r: List[LCSPos] = null x = 0 y = 0 while (k <= d && (x < a_ary.length || y < b_ary.length)) { if (k == -d || k != d && v_d(k - 1 + maxD) < v_d(k + 1 + maxD)) { x = v_d(k + 1 + maxD) y = x - k l = v_l(k + 1 + maxD) r = v_r(k + 1 + maxD) } else { x = v_d(k - 1 + maxD) + 1 y = x - k l = v_l(k - 1 + maxD) r = v_r(k - 1 + maxD) } while (x < a_ary.length && y < b_ary.length && a_ary(x) == b_ary(y)) { r = (a_ary(x), x, y)::r x = x + 1 y = y + 1 l = l + 1 } v_d(k + maxD) = x v_l(k + maxD) = l v_r(k + maxD) = r k = k + 2 } d = d + 1 } finish() } } def foldLeft[U](a_only: (U, T) => U, b_only: (U, T) => U, both: (U, T) => U, seed: U):U = { def loop(common: Array[LCSPos], seed: U, a: Array[T], a_pos: int, b: Array[T], b_pos: int):U = if (common.isEmpty) b.foldLeft(a.foldLeft(seed)(a_only))(b_only) else common(0) match { case (elt, a_off, b_off) => { val a_skip = a_off - a_pos val b_skip = b_off - b_pos val a_head = a.take(a_skip) val a_tail = a.drop(a_skip + 1) val b_head = b.take(b_skip) val b_tail = b.drop(b_skip + 1) loop(common.drop(1), both(b_head.foldLeft(a_head.foldLeft(seed)(a_only))(b_only), elt), a_tail, a_off + 1, b_tail, b_off + 1) } } loop(positions, seed, a_ary, 0, b_ary, 0) } } object Diff { def main(args: Array[String]) { new LCS(fromFile(args(0)).getLines.toList, fromFile(args(1)).getLines.toList).foldLeft( (u:Unit, oldline) => {print("- " + oldline)}, (u:Unit, newline) => {print("+ " + newline)}, (u:Unit, comline) => {print(" " + comline)}, {}) } }
Gauche の util.lcs を移植したつもりなんだけど、なんでこんなに while だらけになるんだろう…? (while 書いたら負けって意識に問題があるんだろーか?)
あと、util.lcs を読むのに ViVi の作者さんによるアルゴリズムの解説、¶äridiffjASY を参考にさせていただきました。
ファイルを 1 行ずつ処理する方法が分からなかったのは、http://ja.doukaku.org/66/lang/scala/ を参考にさせていただきました。Scala Standard Library 2.12.8 の見方がなかなか分からなくて困った。