Number of substrings in range [l, r] that can be permuted to palindrome

Given a string s of lowercase English letters with |s| <= 10^5, there are up to q <= 10^5 queries, each giving a range [l, r] and asking: how many substrings of s[l...r] can be permuted to form a palindrome?

A string can be permuted into a palindrome iff the number of characters appearing an odd number of times is at most 1. I tried to use a segment tree but can't see how to merge two ranges. How can I approach it?
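The parity criterion can be checked in O(|s|) by packing each letter's parity into one bit of a 26-bit mask: at most one letter has an odd count exactly when at most one bit is set, i.e. when `mask & (mask - 1) == 0`. A minimal sketch, assuming lowercase a-z input; the function name `can_permute_to_palindrome` is hypothetical:

```python
def can_permute_to_palindrome(s: str) -> bool:
    """Return True iff some permutation of s is a palindrome (s is lowercase a-z)."""
    mask = 0
    for c in s:
        # Flip the parity bit of this letter.
        mask ^= 1 << (ord(c) - ord('a'))
    # At most one odd-count letter iff at most one bit is set.
    return mask & (mask - 1) == 0


print(can_permute_to_palindrome("aabbc"))  # True ("abcba")
print(can_permute_to_palindrome("abc"))    # False
```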

Artur asked Jan 24 '17 11:01


1 Answer

For want of a better answer, we can apply Mo's algorithm to get an O((n + q) sqrt(n) |alphabet|)-time algorithm, i.e. O(n^(3/2) |alphabet|) when q = O(n). That might be fast enough for you. The core is the following incremental algorithm for computing the number of palindromically-permutable substrings of the whole string (in Python 3):

import collections


def ppcount(s):
    n = 0
    sigcount = collections.Counter()
    sig = 0
    for c in s:
        sigcount[sig] += 1
        sig ^= 1 << (ord(c) - ord('a'))
        n += sigcount[sig]
        for i in range(26):
            n += sigcount[sig ^ (1 << i)]
    return n

The variable sig tracks which letters have odd frequency in the input so far. A substring s[l:r] (including l, excluding r) is palindromically-permutable iff the signature of the prefix of length l is at Hamming distance at most 1 from the signature of the prefix of length r. The map sigcount tracks how many prefixes have a particular signature.
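To make the prefix-signature claim concrete, the brute-force check below compares, for every substring s[l:r], the direct parity test with the Hamming-distance test on the signatures of the prefixes of length l and r. It is only a sketch for verifying the reasoning on small inputs, not part of the algorithm; the helper names (`prefix_signatures`, `direct_ok`, `signature_ok`, `check`) are hypothetical.

```python
import collections
import itertools


def prefix_signatures(s):
    """sigs[k] is the parity mask of s[:k]; sigs[0] == 0."""
    sigs = [0]
    for c in s:
        sigs.append(sigs[-1] ^ (1 << (ord(c) - ord('a'))))
    return sigs


def direct_ok(sub):
    # At most one letter occurs an odd number of times.
    return sum(v % 2 for v in collections.Counter(sub).values()) <= 1


def signature_ok(sigs, l, r):
    # The two prefix signatures differ in at most one bit.
    return bin(sigs[l] ^ sigs[r]).count('1') <= 1


def check(s):
    sigs = prefix_signatures(s)
    # Every pair l < r of prefix lengths is one nonempty substring s[l:r].
    for l, r in itertools.combinations(range(len(s) + 1), 2):
        assert direct_ok(s[l:r]) == signature_ok(sigs, l, r), (l, r)
    return True


print(check("abacabadc"))  # True
```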

To apply Mo's algorithm, first write the inverse of the loop body above (i.e., subtract from n, restore sig, and decrement sigcount). Read in all of the queries and sort them by (l // int(sqrt(len(s))), r). For each query in sorted order, use the update and inverse-update operations to adjust the window to s[l:r+1], then report the current total.
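A minimal sketch of the loop body and its inverse, assuming for simplicity that the window only changes at its right end (a full Mo's implementation also needs the left-end versions). The inverse undoes each step of the update in reverse order, so it must be called with the last character pushed. The names `RightCounter`, `push`, `pop`, and `mo_order` are hypothetical; `mo_order` shows the sort key described above.

```python
import collections
import math


def bit(c):
    return 1 << (ord(c) - ord('a'))


class RightCounter:
    """Counts palindromically-permutable substrings of a window that changes at its right end."""

    def __init__(self):
        self.n = 0                             # current answer
        self.sig = 0                           # parity mask of the whole window
        self.sigcount = collections.Counter()  # signatures of the earlier prefix boundaries

    def push(self, c):
        # Same steps as the loop body of ppcount above.
        self.sigcount[self.sig] += 1
        self.sig ^= bit(c)
        self.n += self.sigcount[self.sig]
        for i in range(26):
            self.n += self.sigcount[self.sig ^ (1 << i)]

    def pop(self, c):
        # Inverse of push: undo its steps in reverse order (c = last pushed char).
        for i in range(26):
            self.n -= self.sigcount[self.sig ^ (1 << i)]
        self.n -= self.sigcount[self.sig]
        self.sig ^= bit(c)
        self.sigcount[self.sig] -= 1


def mo_order(queries, length):
    # Sort by (block of l, r); the block size is about sqrt(len(s)).
    block = max(1, math.isqrt(length))
    return sorted(queries, key=lambda q: (q[0] // block, q[1]))


ctr = RightCounter()
for ch in "abab":
    ctr.push(ch)
print(ctr.n)  # 7: a, b, a, b, aba, bab, abab
ctr.push("c")
ctr.pop("c")
print(ctr.n)  # 7 again
```

Because each operation only adds or subtracts counts for one boundary, push followed by pop restores the exact previous state, which is what lets Mo's algorithm move the window back and forth cheaply.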

Working Python code (a naive version comes first, for comparison):

import collections
import math
import random


def odd(n):
    return bool(n & 1)


def ispp(s):
    return sum(odd(n) for n in collections.Counter(s).values()) <= 1


def naiveppcount(s):
    n = len(s)
    return sum(ispp(s[l:r + 1]) for l in range(n) for r in range(l, n))


def bit(c):
    # Map 'a'..'z' to bits 0..25: ord('a') - 1 == 96 and 96 & 31 == 0.
    return 1 << ((ord(c) - 1) & 31)


def neighbors(sig):
    yield sig
    for i in range(26):
        yield sig ^ (1 << i)


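class PPCounter(object):
    """Counts palindromically-permutable substrings of a sliding window s[l..r]."""

    def __init__(self):
        self.count = 0
        # Multiset of prefix signatures at the window's boundaries l..r+1
        # (the empty window at l = 0 has the single boundary 0, signature 0).
        self._sigcount = collections.Counter({0: 1})
        self._leftsig = 0   # signature of s[:l]
        self._rightsig = 0  # signature of s[:r + 1]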

    def pushleft(self, c):
        self._leftsig ^= bit(c)
        for sig in neighbors(self._leftsig):
            self.count += self._sigcount[sig]
        self._sigcount[self._leftsig] += 1

    def popleft(self, c):
        self._sigcount[self._leftsig] -= 1
        for sig in neighbors(self._leftsig):
            self.count -= self._sigcount[sig]
        self._leftsig ^= bit(c)

    def pushright(self, c):
        self._rightsig ^= bit(c)
        for sig in neighbors(self._rightsig):
            self.count += self._sigcount[sig]
        self._sigcount[self._rightsig] += 1

    def popright(self, c):
        self._sigcount[self._rightsig] -= 1
        for sig in neighbors(self._rightsig):
            self.count -= self._sigcount[sig]
        self._rightsig ^= bit(c)


def ppcount(s, intervals):
    # Block size for Mo's ordering; max() avoids division by zero when s is empty.
    sqrtn = max(1, int(math.sqrt(len(s))))
    intervals = sorted(
        intervals, key=lambda interval: (interval[0] // sqrtn, interval[1]))
    l = 0
    r = -1
    ctr = PPCounter()
    for interval in intervals:
        il, ir = interval
        # Grow the window before shrinking it, so it never has negative length.
        while l > il:
            l -= 1
            ctr.pushleft(s[l])
        while r < ir:
            r += 1
            ctr.pushright(s[r])
        while l < il:
            ctr.popleft(s[l])
            l += 1
        while r > ir:
            ctr.popright(s[r])
            r -= 1
        yield interval, ctr.count


def test():
    n = 100
    s = [random.choice('abcd') for i in range(n)]
    intervals = []
    for i in range(1000):
        l = random.randrange(n)
        r = random.randrange(n)
        intervals.append((min(l, r), max(l, r)))
    for (l, r), count in ppcount(s, intervals):
        assert count == naiveppcount(s[l:r + 1])


if __name__ == '__main__':
    test()
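Note that the generator above yields (interval, count) pairs in Mo's sorted order, not in the order the queries were read. If the answers are needed in input order, a small helper can map them back; this is a sketch, and the name `in_input_order` is hypothetical. It assumes intervals are hashable tuples such as (l, r).

```python
def in_input_order(intervals, mo_results):
    """Map (interval, count) pairs yielded in Mo's order back to the query order.

    intervals: the queries as originally read, e.g. [(l, r), ...]
    mo_results: an iterable of (interval, count) pairs, such as the
    generator returned by ppcount(s, intervals) above.
    """
    answers = {}
    for interval, count in mo_results:
        # Identical intervals get identical counts, so overwriting is safe.
        answers[tuple(interval)] = count
    return [answers[tuple(interval)] for interval in intervals]


# Dummy (interval, count) pairs, listed in a different order from the queries:
queries = [(2, 5), (0, 1), (2, 5)]
results = [((0, 1), 2), ((2, 5), 6)]
print(in_input_order(queries, results))  # [6, 2, 6]
```

With the code above, usage would be `in_input_order(intervals, ppcount(s, intervals))`.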

David Eisenstat answered Sep 30 '22 13:09