Given two 32-bit numbers, N and M, and two bit positions, i and j. Write a method to set all bits between i and j in N equal to M

You are given two 32-bit numbers, N and M, and two bit positions, i and j. Write a method to set all bits between i and j in N equal to M (e.g., M becomes a substring of N, starting at bit i and ending at bit j).

EXAMPLE:
Input:  N = 10000000000, M = 10101, i = 2, j = 6
Output: N = 10001010100

This problem is from Cracking the Coding Interview. I was able to solve it with the following O(j - i) algorithm:

def set_bits(a, b, i, j):
    if not b: return a
    while i <= j:
        if b & 1 == 1:
            # current bit of b is 1: set bit i of a
            last_bit = (b & 1) << i
            a |= last_bit
        else:
            # current bit of b is 0: clear bit i of a
            set_bit = ~(1 << i)
            a &= set_bit
        b >>= 1
        i += 1
    return a
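
Running it on the example from the book gives the expected result (interpreter session shown for illustration):

>>> bin(set_bits(0b10000000000, 0b10101, 2, 6))
'0b10001010100'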

The author gave this O(1) algorithm as a solution:

def update_bits(n, m, i, j):
    max = ~0 # All 1s

    # 1s through position j, then zeroes
    left = max - ((1 << j) - 1)

    # 1s after position i
    right = ((1 << i) - 1)

    # 1s, with 0s between i and j
    mask = left | right

    #  Clear i through j, then put m in there 
    return (n & mask) | (m << i)

I noticed that for some test cases the author's algorithm seems to output the wrong answer. For example, for N = 488, M = 5, i = 2, j = 6 it outputs 468, when the output should be 404, which is what my O(j - i) algorithm produces.
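
In binary, those values are N = 0b111101000 and M = 0b101. The expected result keeps bits 7 and 8 of N and places M in bits 2 through 6, giving 0b110010100 = 404, whereas 468 = 0b111010100 still has bit 6 of the original N set. A quick check of the values in the interpreter, for illustration:

>>> bin(488), bin(5), bin(404), bin(468)
('0b111101000', '0b101', '0b110010100', '0b111010100')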

Question: Is there a way to get a constant-time algorithm that works for all cases?

Sebastian Cruz asked Dec 05 '22

1 Answer

I think the author of the algorithm takes the bound j (six in your example) to be exclusive; this boils down to the question of whether a range from 2 to 6 should include 6 (in Python's range it does not). In other words, if the algorithm is modified to:

def update_bits(n, m, i, j):
    max = ~0 # All 1s

    # 1s through position j, then zeroes
    left = max - ((1 << (j+1)) - 1)

    # 1s after position i
    right = ((1 << i) - 1)

    # 1s, with 0s between i and j
    mask = left | right

    #  Clear i through j, then put m in there 
    return (n & mask) | (m << i)

It works.
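
For instance, with your failing test case it now returns the expected value (interpreter check for illustration):

>>> update_bits(488, 5, 2, 6)
404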

Nevertheless, you can speed things up a bit as follows:

def update_bits(n, m, i, j):
    # 1s through position j, then zeroes
    left = (~0) << (j+1)

    # 1s after position i
    right = ((1 << i) - 1)

    # 1s, with 0s between i and j
    mask = left | right

    #  Clear i through j, then put m in there 
    return (n & mask) | (m << i)

Here we simply shift zeroes in from the right instead of subtracting; conceptually, the low ones are shifted out of the register.
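
Since Python integers have arbitrary precision, (~0) << (j + 1) evaluates to the same negative number as ~0 - ((1 << (j + 1)) - 1), so the two ways of building left agree; a quick sanity check:

for j in range(32):
    # both expressions are 1s above position j and 0s in positions 0 through j
    assert (~0) << (j + 1) == ~0 - ((1 << (j + 1)) - 1)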

Note that you made an error in your own algorithm: the case b = 0 does not mean you can simply return a, since the bits in that range still have to be cleared. Say a = '0b1011001111101111' and b = '0b0', and i and j are 6 and 8 respectively; one expects the result to be '0b1011001000101111'. The algorithm thus should be:

def set_bits(a, b, i, j):
    while i <= j:
        if b & 1 == 1:
            last_bit = (b & 1) << i
            a |= last_bit
        else:
            set_bit = ~(1 << i)
            a &= set_bit
        b >>= 1
        i += 1
    return a
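
With this change, the b = 0 case from above is handled correctly (interpreter check for illustration):

>>> bin(set_bits(0b1011001111101111, 0b0, 6, 8))
'0b1011001000101111'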

If I make this modification and test the program with 10,000,000 random inputs, both algorithms always produce the same result:

from random import randint

for _ in range(10000000):
    m = randint(0, 65536)
    i = randint(0, 15)
    j = randint(i, 16)
    n = randint(0, 2**(j-i))
    if set_bits(m, n, i, j) != update_bits(m, n, i, j):
        print((bin(m), bin(n), i, j, bin(set_bits(m, n, i, j)), bin(update_bits(m, n, i, j))))  # this line is never printed

Of course this is not a proof that both algorithms are equivalent (perhaps there is a tiny corner case where they differ), but I'm quite confident that for valid input (i and j positive, i < j, etc.) both should always produce the same result.

Willem Van Onsem answered May 18 '23