Merge sort in Scala
sort, scala, merge
Problem
I've implemented merge sort in Scala:
Could you look at this? Are there some Scala tricks to do it better, quicker, smaller or cleaner?
    object Lunch {
      def doMergeSortExample() = {
        val values:Array[Int] = List(5,11,8,4,2).toArray
        sort(values)
        printArray(values)
      }

      def sort(array:Array[Int]) {
        if (array.length > 1 ){
          var firstArrayLength = (array.length/2)
          var first:Array[Int] = array.slice(0, firstArrayLength)
          var second:Array[Int] = array.slice(firstArrayLength, array.length)
          sort(first)
          sort(second)
          merge(array, first, second)
        }
      }

      def merge(result:Array[Int], first:Array[Int], second:Array[Int]) {
        var i:Int = 0
        var j:Int = 0
        for (k <- 0 until result.length) {
          if (i<first.length && j<second.length){
            if (first(i) <= second(j)){
              result(k) = first(i)
              i=i+1
            } else {
              result(k) = second(j)
              j=j+1
            }
          } else if (i>=first.length && j<second.length){
            result(k) = second(j)
            j=j+1
          } else {
            result(k) = first(i)
            i=i+1
          }
        }
      }

      def printArray(array: Array[Int]) = {
        println(array.deep.mkString(", "))
      }

      def main(args: Array[String]) {
        doMergeSortExample();
      }
    }
Solution
Just a few unsorted ideas:
-
Mergesort can be very nicely expressed using Scala's streams. In particular:
    import Stream.Empty  // brings the Empty pattern used below into scope

    def merge(first: Stream[Int], second: Stream[Int]): Stream[Int] =
      (first, second) match {
        case (x #:: xs, ys @ (y #:: _)) if x <= y => x #:: merge(xs, ys)
        case (xs, y #:: ys)                      => y #:: merge(xs, ys)
        case (xs, Empty)                         => xs
        case (Empty, ys)                         => ys
      }
It'll be slower than working with arrays, but the method is much more concise, and it's completely stateless. It is also lazy: it computes only the elements that you ask for. With such a lazy merge sort, you can sort a sequence, then ask only for the first element, and you'll get it in O(n) time instead of O(n log n).
-
Instead of splitting the input into smaller and smaller pieces and then merging them, you can split it into singletons in a single pass and then just merge those singletons. For example, create a Stream of Streams like
    def col2strstr(c: Iterable[Int]): Stream[Stream[Int]] =
      for (x <- c.toStream) yield Stream(x)
and then merge pairs of them repeatedly. (Be sure to merge streams with the same or similar length, otherwise the process will be inefficient.)
-
This can be further improved: instead of just splitting the input into singletons, you can split it into non-decreasing subsequences. For example (using an informal list notation), you'd split [7,8,9,4,5,6,1,2,3] into [[7,8,9],[4,5,6],[1,2,3]]. This can dramatically reduce the number of merges. In particular, if you pass an already sorted input, it will just check that it's sorted in O(n) without doing any merge.
-
A further improvement is to look for both non-decreasing and non-increasing sequences (and reverse the non-increasing ones before merging them).
All these ideas can be seen in Haskell's sort implementation (Haskell's lists are lazy, just like Scala's Streams):
    sort = sortBy compare
    sortBy cmp = mergeAll . sequences
      where
        sequences (a:b:xs)
          | a `cmp` b == GT = descending b [a]  xs
          | otherwise       = ascending  b (a:) xs
        sequences xs = [xs]

        descending a as (b:bs)
          | a `cmp` b == GT = descending b (a:as) bs
        descending a as bs  = (a:as): sequences bs

        ascending a as (b:bs)
          | a `cmp` b /= GT = ascending b (\ys -> as (a:ys)) bs
        ascending a as bs   = as [a]: sequences bs

        mergeAll [x] = x
        mergeAll xs  = mergeAll (mergePairs xs)

        mergePairs (a:b:xs) = merge a b: mergePairs xs
        mergePairs xs       = xs

        merge as@(a:as') bs@(b:bs')
          | a `cmp` b == GT = b:merge as  bs'
          | otherwise       = a:merge as' bs
        merge [] bs         = bs
        merge as []         = as
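For readers who would rather see the ideas above in Scala than in Haskell, here is a rough sketch that follows the same structure: split the input into already-sorted runs (reversing strictly decreasing ones), then merge the runs pairwise until one stream is left. It is only an illustration, not a tuned implementation. The names `NaturalMergeSort`, `takeRun`, `runs`, `mergePairs` and `mergeAll` are made up for this example, and `merge` is the lazy stream merge from the first idea. It uses `Stream` to match the snippets above; on Scala 2.13 and later `Stream` is deprecated in favour of `LazyList`, so expect deprecation warnings there.

```scala
import scala.annotation.tailrec

object NaturalMergeSort {
  // Lazy merge of two sorted streams (same shape as the merge shown earlier).
  def merge(first: Stream[Int], second: Stream[Int]): Stream[Int] =
    (first, second) match {
      case (x #:: xs, ys @ (y #:: _)) if x <= y => x #:: merge(xs, ys)
      case (xs, y #:: ys)                      => y #:: merge(xs, ys)
      case (xs, Stream.Empty)                  => xs
      case (Stream.Empty, ys)                  => ys
    }

  // Extends a run while each adjacent pair satisfies `ok`.
  // Returns the run in reverse order together with the unconsumed input.
  @tailrec
  private def takeRun(last: Int, rest: List[Int], revRun: List[Int],
                      ok: (Int, Int) => Boolean): (List[Int], List[Int]) =
    rest match {
      case x :: tail if ok(last, x) => takeRun(x, tail, x :: revRun, ok)
      case _                        => (revRun, rest)
    }

  // Splits the input into sorted runs: non-decreasing runs are kept as they
  // are, strictly decreasing runs come out reversed (and therefore ascending).
  @tailrec
  def runs(input: List[Int], acc: List[List[Int]] = Nil): List[List[Int]] =
    input match {
      case Nil => acc.reverse
      case a :: b :: tail if a > b =>
        // A decreasing run collected in reverse is already ascending.
        val (revRun, rest) = takeRun(b, tail, List(b, a), _ > _)
        runs(rest, revRun :: acc)
      case a :: tail =>
        val (revRun, rest) = takeRun(a, tail, List(a), _ <= _)
        runs(rest, revRun.reverse :: acc)
    }

  // One pass of merging neighbouring streams; an odd one out is kept as is.
  def mergePairs(streams: List[Stream[Int]]): List[Stream[Int]] =
    streams.grouped(2).map {
      case List(a, b) => merge(a, b)
      case single     => single.head
    }.toList

  // Repeats pairwise merging until a single stream remains.
  @tailrec
  def mergeAll(streams: List[Stream[Int]]): Stream[Int] = streams match {
    case Nil           => Stream.empty
    case single :: Nil => single
    case many          => mergeAll(mergePairs(many))
  }

  def sort(input: Seq[Int]): Stream[Int] =
    mergeAll(runs(input.toList).map(_.toStream))

  def main(args: Array[String]): Unit = {
    val sorted = sort(List(5, 11, 8, 4, 2))
    println(sorted.head)                  // 2
    println(sorted.toList.mkString(", ")) // 2, 4, 5, 8, 11
  }
}
```

Merging neighbours pairwise keeps the merged streams at roughly similar lengths, which is the point of the remark about merging streams of the same or similar length. The run detection here is strict (it walks the whole input up front); only the merging is lazy.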
Context
StackExchange Code Review Q#21575, answer score: 15