
Inversion count using merge sort


Problem

count = 0

def merge_sort(li):
    if len(li) < 2:
        return li
    m = len(li) // 2  # integer division; "/" returns a float in Python 3
    return merge(merge_sort(li[:m]), merge_sort(li[m:]))

def merge(l, r):
    global count
    result = []
    i = j = 0
    while i < len(l) and j < len(r):
        if l[i] <= r[j]:  # take from l on ties: equal elements are not inversions
            result.append(l[i])
            i += 1
        else:
            result.append(r[j])
            # r[j] is smaller than every element remaining in l[i:]
            count = count + (len(l) - i)
            j += 1
    result.extend(l[i:])
    result.extend(r[j:])
    return result

unsorted = [10,2,3,22,33,7,4,1,2]
print(merge_sort(unsorted))
print(count)
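The merge-sort count can be checked against the definition directly: an inversion is a pair of positions i < j with a[i] > a[j]. The O(n^2) helper below, count_inversions_brute_force, is a hypothetical name used only for this sketch. Note that the two 2s in the sample input are equal and therefore not an inversion; a merge that compares with a strict < takes the right-hand 2 first on that tie and reports one inversion too many (23 instead of 22).

```python
from itertools import combinations

def count_inversions_brute_force(a):
    """Count pairs i < j with a[i] > a[j] by checking every pair: O(n^2)."""
    # combinations() keeps the input order, so x always comes before y in a
    return sum(1 for x, y in combinations(a, 2) if x > y)

unsorted = [10, 2, 3, 22, 33, 7, 4, 1, 2]
print(count_inversions_brute_force(unsorted))  # 22
```

This is only useful as a test oracle for small inputs; the merge-sort version is what you want for anything large.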

Solution

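Rather than a global count, I would suggest either passing the counter in as a parameter or returning a tuple that carries the count back out of each recursive call. This also makes the function thread-safe, provided each call gets its own counter, because no module-level state is shared between calls.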

def merge_sort(li, c):
    if len(li) < 2:
        return li
    m = len(li) // 2
    return merge(merge_sort(li[:m], c), merge_sort(li[m:], c), c)

def merge(l, r, c):
    result = []


Since l and r are fresh lists created by merge_sort (slices, or the results of earlier merges), we can modify them in place without affecting the caller's data. We first reverse both lists in O(n) so that s.pop() can remove the smallest remaining element from the end in O(1) (thanks to @ofer.sheffer for pointing out the mistake).

    l.reverse()
    r.reverse()
    while l and r:
        # On ties take from l: equal elements are not inversions
        s = l if l[-1] <= r[-1] else r
        result.append(s.pop())


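Counting is separate from the actual business of merge sort, so it is nicer to keep it on its own line, inside the while loop right after the pop. If the element just taken came from r, it is smaller than every element still in l (both lists are sorted), so it forms an inversion with each of those len(l) elements: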

        # Identity check: s == r would compare list contents and can be True when s is l
        if s is r:
            c[0] += len(l)
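To see why len(l) is the right amount to add, here is the top-level merge of the sample input traced step by step: the left half [10, 2, 3, 22] sorts to [2, 3, 10, 22] and the right half [33, 7, 4, 1, 2] sorts to [1, 2, 4, 7, 33]. trace_merge is a hypothetical helper that repeats the merge loop above but prints each count update and works on reversed copies instead of reversing its inputs in place.

```python
def trace_merge(l, r):
    """Merge two sorted lists like merge() above, printing each count update."""
    # Work on reversed copies so the caller's lists are left untouched
    l, r = list(reversed(l)), list(reversed(r))
    result, total = [], 0
    while l and r:
        s = l if l[-1] <= r[-1] else r
        x = s.pop()
        result.append(x)
        if s is r:
            # x is smaller than every element still in l
            print(f"took {x} from r: {len(l)} elements remain in l, +{len(l)}")
            total += len(l)
    rest = l if l else r
    result.extend(reversed(rest))
    return result, total

merged, total = trace_merge([2, 3, 10, 22], [1, 2, 4, 7, 33])
print(merged, total)
```

The four updates, 4 + 3 + 2 + 2 = 11, are exactly the inversions that cross the two halves; adding the 2 inversions inside the left half and the 9 inside the right half gives 22 for the whole list.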


Finally, append whatever is left in the list that is not yet empty. It is still reversed, so reverse it back first:

    rest = l if l else r
    rest.reverse()
    result.extend(rest)
    return result
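As an aside, the reversal trick exists only because list.pop(0) is O(n). A swapped-in alternative, not the approach used in this answer, is collections.deque, whose popleft() is O(1), so the merge can consume both inputs from the front without reversing them. A minimal sketch, where merge_deque is a hypothetical name and c is the same one-element counter list:

```python
from collections import deque

def merge_deque(l, r, c):
    """Merge two sorted lists, adding cross inversions to c[0] (a one-element list)."""
    left, right = deque(l), deque(r)
    result = []
    while left and right:
        if left[0] <= right[0]:
            result.append(left.popleft())
        else:
            # right[0] is smaller than everything still in left
            result.append(right.popleft())
            c[0] += len(left)
    # At most one of these is non-empty, and both are already ascending
    result.extend(left)
    result.extend(right)
    return result

c = [0]
print(merge_deque([1, 4], [2, 3], c))  # [1, 2, 3, 4]
print(c[0])  # 2
```

Copying into deques costs O(n) per merge, the same as the two reversals, so the asymptotic cost is unchanged; the benefit is only that the loop reads front to back.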

unsorted = [10,2,3,22,33,7,4,1,2]


Use a mutable data structure (here a one-element list) to simulate pass-by-reference:

count = [0]
print(merge_sort(unsorted, count))
print(count[0])
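The one-element list is needed because Python ints are immutable: augmented assignment on an int parameter only rebinds the function's local name, while indexing into a list mutates the very object the caller holds. bump_int and bump_list are hypothetical names for this illustration.

```python
def bump_int(n):
    n += 1  # rebinds the local name only; the caller's int is unchanged

def bump_list(c):
    c[0] += 1  # mutates the shared list object, which the caller also holds

n = 0
bump_int(n)
print(n)  # 0

c = [0]
bump_list(c)
print(c[0])  # 1
```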

Code Snippets

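def merge_sort(li, c):
    # c is a one-element list used as a mutable inversion counter
    if len(li) < 2:
        return li
    m = len(li) // 2
    return merge(merge_sort(li[:m], c), merge_sort(li[m:], c), c)

def merge(l, r, c):
    result = []
    # Reverse so that pop() removes the smallest remaining element in O(1)
    l.reverse()
    r.reverse()
    while l and r:
        # On ties take from l: equal elements are not inversions
        s = l if l[-1] <= r[-1] else r
        result.append(s.pop())
        # Identity check: s == r would compare contents and can misfire
        if s is r:
            # The element came from r, so it is smaller than everything left in l
            c[0] += len(l)
    rest = l if l else r
    rest.reverse()
    result.extend(rest)
    return result


unsorted = [10,2,3,22,33,7,4,1,2]
count = [0]
print(merge_sort(unsorted, count))
print(count[0])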
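The answer's other suggestion, returning a tuple instead of threading a counter through the calls, might look like the sketch below. merge_sort_count and merge_count are hypothetical names; the merge is index-based like the question's code, so it neither mutates its inputs nor relies on any shared state.

```python
def merge_sort_count(li):
    """Return (sorted copy of li, number of inversions in li)."""
    if len(li) < 2:
        return list(li), 0
    m = len(li) // 2
    left, inv_left = merge_sort_count(li[:m])
    right, inv_right = merge_sort_count(li[m:])
    merged, inv_split = merge_count(left, right)
    return merged, inv_left + inv_right + inv_split

def merge_count(l, r):
    """Merge two sorted lists; also count pairs (x in l, y in r) with x > y."""
    result = []
    inversions = 0
    i = j = 0
    while i < len(l) and j < len(r):
        if l[i] <= r[j]:
            result.append(l[i])
            i += 1
        else:
            result.append(r[j])
            inversions += len(l) - i  # r[j] is smaller than all of l[i:]
            j += 1
    result.extend(l[i:])
    result.extend(r[j:])
    return result, inversions

unsorted = [10, 2, 3, 22, 33, 7, 4, 1, 2]
print(merge_sort_count(unsorted))  # ([1, 2, 2, 3, 4, 7, 10, 22, 33], 22)
```

Each recursive call adds the inversions inside its left half, inside its right half, and across the split, so the total is assembled on the way back up the recursion.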

Context

StackExchange Code Review Q#12922, answer score: 7
