
Performance issues with the Burrows-Wheeler transform in Python

I was trying to implement the Burrows-Wheeler transform in Python. (This is one of the assignments in an online course, but I hope I have done enough work to be qualified to ask for help.)

The algorithm works as follows. Take a string which ends with a special character ($ in my case) and create all cyclic rotations of this string. Sort all these strings alphabetically, with the special character always comparing less than any other character. Then take the last character of each sorted string.

This gave me a one-liner:

''.join([i[-1] for i in sorted([text[i:] + text[0:i] for i in xrange(len(text))])])

Which is correct and reasonably fast for reasonably big strings (which is enough to solve the problem):

 60,000 chars - 16 secs
 40,000 chars -  7 secs
 25,000 chars -  2 secs

But when I tried to process a really huge string of a few million characters, I failed (it took too much time to process).

I assume that the problem is with storing too many strings in memory.

Is there any way to overcome this?

P.S. I just want to point out that although this might look like a homework problem, my solution already passes the grader and I am just looking for a way to make it faster. Also, I am not spoiling the fun for other people: anyone who wants a solution will find a similar one in the wiki article. I also checked this question, which sounds similar but answers a harder problem: how to decode a string encoded with this algorithm.
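For readers on Python 3, the one-liner can be written as a small function (a sketch; bwt_naive is my name for it, and it relies on '$' having a lower ASCII code than any letter):

```python
def bwt_naive(text):
    # Build every cyclic rotation, sort them (the '$' sentinel sorts
    # below ASCII letters), and take the last character of each.
    rotations = sorted(text[i:] + text[:i] for i in range(len(text)))
    return ''.join(r[-1] for r in rotations)

print(bwt_naive('banana$'))  # annb$aa
```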

Salvador Dali, asked Jan 23 '14

3 Answers

I agree with the previous answer: string/list slicing in Python becomes a bottleneck when performing heavy algorithmic computation. The idea is to avoid slicing altogether.

[EDIT: not only slicing, but also list indexing. If you use array.array instead of lists, the execution time is cut in half, since indexing arrays is straightforward while indexing lists is a more complicated process.]

Here is a more functional solution to your problem.

The idea is to have a generator that will act as a slicer (rslice). It is similar in spirit to itertools.islice, but it wraps around to the beginning of the string when it reaches the end, and it stops just before the start position you specified when creating it. With this trick you are not copying any substrings into memory, so in the end you only have pointers moving over your string without creating copies everywhere.

So we create a list of pairs [rslice, last char of the slice] and sort it using the rslice as the key (see the cf comparison function).

Once it is sorted, you only need to collect, for each element of the list, the second member of the pair (the last character of the slice, stored up front).

from itertools import izip
def cf(i1,i2):
    for i,j in izip(i1[0](),i2[0]()): # Grab the first element (a lambda) and call it to get a fresh generator
        if i<j: return -1
        elif i>j: return 1
    return 0

def rslice(cad,pos): # Slice that rotates through the string (it's a generator)
    pini=pos
    lc=len(cad)
    while pos<lc:
        yield cad[pos]
        pos+=1
    pos=0
    while pos<pini-1:
        yield cad[pos]
        pos+=1

def lambdagen(start,cad): # Closure to hold a generator
    return lambda: rslice(cad,start)

def bwt(txt):
    lt=len(txt)
    arry=list(txt)+[None]

    l=[(lambdagen(0,arry),None)]+[(lambdagen(i,arry),arry[i-1]) for i in range(1,lt+1)]
    # What we keep in the list is the generator for the rotating-slice, plus the 
    # last character of the slice, so we save the time of going through the whole 
    # string to get the last character

    l.sort(cmp=cf)   # We sort using our cf function
    return [i[1] for i in l]

print bwt('Text I want to apply BTW to :D')

# ['D', 'o', 'y', 't', 'o', 'W', 't', 'I', ' ', ' ', ':', ' ', 'B', None, 'T', 'w', ' ', 
# 'T', 'p', 'a', 't', 't', 'p', 'a', 'x', 'n', ' ', ' ', ' ', 'e', 'l']

EDIT: Using arrays (execution time cut in half):

import array

def bwt(txt):
    lt=len(txt)
    arry=array.array('h',[ord(i) for i in txt])
    arry.append(-1)   # sentinel smaller than any character code

    l=[(lambdagen(0,arry),None)]+[(lambdagen(i,arry),arry[i-1]) for i in range(1,lt+1)]

    l.sort(cmp=cf)
    return [i[1] for i in l]
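On Python 3, izip is gone and list.sort() no longer accepts cmp=; here is a sketch of how the array-based version might be ported, using zip and functools.cmp_to_key (my adaptation, not part of the original answer):

```python
import array
from functools import cmp_to_key

def cf(i1, i2):
    # Call each stored lambda to get a fresh generator over the rotation.
    for i, j in zip(i1[0](), i2[0]()):
        if i < j: return -1
        elif i > j: return 1
    return 0

def rslice(cad, pos):
    # Generator that walks the rotation starting at pos; it stops just
    # before the rotation's last character, which is stored separately.
    pini = pos
    lc = len(cad)
    while pos < lc:
        yield cad[pos]
        pos += 1
    pos = 0
    while pos < pini - 1:
        yield cad[pos]
        pos += 1

def lambdagen(start, cad):
    return lambda: rslice(cad, start)

def bwt(txt):
    lt = len(txt)
    arry = array.array('h', [ord(c) for c in txt])
    arry.append(-1)   # sentinel smaller than any character code
    l = [(lambdagen(0, arry), None)] + \
        [(lambdagen(i, arry), arry[i - 1]) for i in range(1, lt + 1)]
    l.sort(key=cmp_to_key(cf))
    return [i[1] for i in l]
```

As in the Python 2 version, the result is a list of character codes plus one None marker corresponding to the appended sentinel.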
Carlos del Ojo, answered Nov 11 '22

It takes a long time to make all those string slices with long strings. It's at least O(N^2) (since you create N strings of length N, and each one has to be copied into memory from the original), which destroys the overall performance and makes the sorting irrelevant. Not to mention the memory requirement!

Instead of actually slicing the string, the next thought is to order the i values you use to create the cyclic strings, in order of how the resulting string would compare - without actually creating it. This turns out to be somewhat tricky. (Removed/edited some stuff here that was wrong; please see @TimPeters' answer.)

The approach I've taken here is to bypass the standard library - which makes it difficult (though not impossible) to compare those strings 'on demand' - and do my own sorting. The natural choice of algorithm here is radix sort, since we need to consider the strings one character at a time anyway.

Let's get set up first. I am writing code for version 3.2, so season to taste. (In particular, in 3.3 and up, we could take advantage of yield from.) I am using the following imports:

from random import choice
from timeit import timeit
from functools import partial

I wrote a general-purpose radix sort function like this:

def radix_sort(values, key, step=0):
    if len(values) < 2:
        for value in values:
            yield value
        return

    bins = {}
    for value in values:
        bins.setdefault(key(value, step), []).append(value)

    for k in sorted(bins.keys()):
        for r in radix_sort(bins[k], key, step + 1):
            yield r

Of course, we don't need to be general-purpose (our 'bins' can only be labelled with single characters, and presumably you really mean to apply the algorithm to a sequence of bytes ;) ), but it doesn't hurt. Might as well have something reusable, right? Anyway, the idea is simple: we handle a base case, and then we drop each element into a "bin" according to the result from the key function, and then we pull values out of the bins in sorted bin order, recursively sorting each bin's contents.
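As a toy usage example (my own; char_key is not part of the answer), the same radix_sort can order plain strings, provided the key returns a sentinel once a string runs out of characters, and provided the inputs are distinct (two equal values would recurse forever with this key):

```python
def radix_sort(values, key, step=0):
    # Reproduced from above so this snippet runs standalone.
    if len(values) < 2:
        for value in values:
            yield value
        return
    bins = {}
    for value in values:
        bins.setdefault(key(value, step), []).append(value)
    for k in sorted(bins.keys()):
        for r in radix_sort(bins[k], key, step + 1):
            yield r

def char_key(value, step):
    # '' compares below every character, so shorter strings sort first.
    return value[step] if step < len(value) else ''

words = ['banana', 'apple', 'band', 'bandana', 'ape']
print(list(radix_sort(words, char_key)))
# ['ape', 'apple', 'banana', 'band', 'bandana']
```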

The interface requires that key(value, n) gives us the nth "radix" of value. So for simple cases, like comparing strings directly, that could be as simple as lambda v, n: v[n]. Here, though, the idea is to compare indices into the string, according to the data in the string at that point (considered cyclically). So let's define a key:

def bw_key(text, value, step):
    return text[(value + step) % len(text)]

Now the trick to getting the right results is to remember that we're conceptually joining up the last characters of the strings we aren't actually creating. If we consider the virtual string made using index n, its last character is at index n - 1, because of how we wrap around - and a moment's thought will confirm to you that this still works when n == 0 ;) . [However, when we wrap forwards, we still need to keep the string index in-bounds - hence the modulo operation in the key function.]
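That wrap-around reasoning is easy to sanity-check on a small input (my own illustration, reusing the bw_key above):

```python
def bw_key(text, value, step):
    return text[(value + step) % len(text)]

text = 'banana$'
n = 2
rotation = text[n:] + text[:n]      # 'nana$ba', the virtual string for index 2
assert rotation[-1] == text[n - 1]  # its last character sits at index n - 1
# bw_key walks the same rotation character by character without building it:
assert [bw_key(text, n, k) for k in range(len(text))] == list(rotation)
```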

This is a general key function that needs to be passed in the text to which it will refer when transforming the values for comparison. That's where functools.partial comes in - you could also just mess around with lambda, but this is arguably cleaner, and I've found it's usually faster, too.

Anyway, now we can easily write the actual transform using the key:

def burroughs_wheeler_custom(text):
    return ''.join(text[i - 1] for i in radix_sort(range(len(text)), partial(bw_key, text)))
    # Notice I've dropped the square brackets; this means I'm passing a generator
    # expression to `join` instead of a list comprehension. In general, this is
    # a little slower, but uses less memory. And the underlying code uses lazy
    # evaluation heavily, so :)

Nice and pretty. Let's see how it does, shall we? We need a standard to compare it against:

def burroughs_wheeler_standard(text):
    return ''.join([i[-1] for i in sorted([text[i:] + text[:i] for i in range(len(text))])])

And a timing routine:

def test(n):
    data = ''.join(choice('abcdefghijklmnopqrstuvwxyz') for i in range(n)) + '$'
    custom = partial(burroughs_wheeler_custom, data)
    standard = partial(burroughs_wheeler_standard, data)
    assert custom() == standard()
    trials = 1000000 // n
    custom_time = timeit(custom, number=trials)
    standard_time = timeit(standard, number=trials)
    print("custom: {} standard: {}".format(custom_time, standard_time))

Notice the math I've done to decide on a number of trials, inversely related to the length of the test string. This should keep the total time used for testing in a reasonably narrow range - right? ;) (Wrong, of course, since we established that the standard algorithm is at least O(N^2).)

Let's see how it does (*drumroll*):

>>> imp.reload(burroughs_wheeler)
<module 'burroughs_wheeler' from 'burroughs_wheeler.py'>
>>> burroughs_wheeler.test(100)
custom: 4.7095093091438684 standard: 0.9819262643716229
>>> burroughs_wheeler.test(1000)
custom: 5.532266880287807 standard: 2.1733253807396977
>>> burroughs_wheeler.test(10000)
custom: 5.954826800612864 standard: 42.50686064849015

Whoa, that's a bit of a frightening jump. Anyway, as you can see, the new approach adds a ton of overhead on short strings, but enables the actual sorting to be the bottleneck instead of string slicing. :)

Karl Knechtel, answered Nov 11 '22

Just adding a bit to @KarlKnechtel's spot-on response.

First, the "standard way" to speed cyclic-permutation extraction is just to paste two copies together and index directly into that. After:

N = len(text)
text2 = text * 2

then the cyclic permutation starting at index i is just text2[i: i+N], and character j in that permutation is just text2[i+j]. No need for pasting together two slices, or for modulus (%) operations.
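A quick sanity check of that identity (my own illustration):

```python
text = 'banana$'
N = len(text)
text2 = text * 2
for i in range(N):
    rotation = text[i:] + text[:i]
    # Each cyclic permutation is one contiguous slice of the doubled string...
    assert text2[i:i + N] == rotation
    # ...and character j of it is a plain index, no % needed:
    assert all(text2[i + j] == rotation[j] for j in range(N))
```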

Second, the builtin sort() can be used for this, although:

  1. It's funky ;-)
  2. For strings with few distinct characters (compared to the length of the string) Karl's radix sort will almost certainly be faster.

As proof-of-concept, here's a drop-in replacement for that part of Karl's code (although this sticks to Python 2):

def burroughs_wheeler_custom(text):
    N = len(text)
    text2 = text * 2
    class K:
        def __init__(self, i):
            self.i = i
        def __lt__(a, b):
            i, j = a.i, b.i
            for k in xrange(N): # use `range()` in Python 3
                if text2[i+k] < text2[j+k]:
                    return True
                elif text2[i+k] > text2[j+k]:
                    return False
            return False # they're equal

    inorder = sorted(range(N), key=K)
    return "".join(text2[i+N-1] for i in inorder)

Note that the builtin sort()'s implementation computes the key exactly once for each element in its input, and does save those results for the duration of the sort. In this case, the results are lazy little K instances that just remember the starting index, and whose __lt__ method compares one character pair at a time until "less than!" or "greater than!" is resolved.
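For reference, the same drop-in works under Python 3 with only xrange changed to range (this is my port, not Tim's original, though the behaviour is otherwise identical):

```python
def burroughs_wheeler_custom(text):
    N = len(text)
    text2 = text * 2
    class K:
        # Lazy key: remembers a start index; __lt__ compares the two
        # virtual rotations one character pair at a time.
        def __init__(self, i):
            self.i = i
        def __lt__(a, b):
            i, j = a.i, b.i
            for k in range(N):
                if text2[i+k] < text2[j+k]:
                    return True
                elif text2[i+k] > text2[j+k]:
                    return False
            return False  # they're equal

    inorder = sorted(range(N), key=K)
    return "".join(text2[i+N-1] for i in inorder)

print(burroughs_wheeler_custom('banana$'))  # annb$aa
```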

Tim Peters, answered Nov 11 '22