
Mo's algorithm to compute "power" of array

Recently, I learned Mo's algorithm for the square-root decomposition of queries in order to speed up solutions to certain problems.

In order to practice implementation, I have been trying to solve D. Powerful array (a past contest problem on Codeforces) using this idea. The problem is as follows:

You have an array of n integers $a_1, a_2, \ldots, a_n$.

Consider an arbitrary subarray $a_l, a_{l+1}, \ldots, a_r$. Define $K_s$ to be the number of occurrences of an integer s in this subarray. The power of a subarray is defined as the sum of $K_s \cdot K_s \cdot s$ over all integers s (note that only finitely many terms are nonzero).

Answer q queries. In each query, given two integers l and r, compute the power of the subarray $a_l, a_{l+1}, \ldots, a_r$.

It holds:

  $1 \le n, q \le 200000$
  $1 \le a_i \le 10^6$
  $1 \le l \le r \le n$
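
To make the definition concrete, here is a naive reference that evaluates the power of one subarray directly. It is only a sketch for checking answers on small inputs: the name `naive_power` and the 0-based inclusive indices are choices made for this illustration, not part of the problem statement, and at O(r - l + 1) per query it is far too slow for the full problem.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Naive reference (illustrative helper, not part of the submitted solution):
// power of a[l..r], 0-based and inclusive, computed straight from the
// definition as the sum of K_s * K_s * s over the distinct values s.
long long naive_power(const std::vector<long long>& a, std::size_t l, std::size_t r)
{
    assert(l <= r && r < a.size());

    std::unordered_map<long long, long long> k;  // k[s] = occurrences of s
    for (std::size_t i = l; i <= r; i++)
        k[a[i]]++;

    long long power = 0;
    for (const auto& kv : k)
        power += kv.second * kv.second * kv.first;  // K_s * K_s * s
    return power;
}
```

For the array 1 2 1, for example, the whole array has K_1 = 2 and K_2 = 1, so its power is 2*2*1 + 1*1*2 = 6.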

Using Mo's algorithm, I have written code that solves this problem offline in $O((n + q)\sqrt{n})$ time. I am certain that this problem can be solved with this algorithm and time complexity, as I have inspected the accepted code of others and they use a similar algorithm.

My code, however, gets a time limit exceeded verdict.

Below is the code I have written:

#include <ios>
#include <iostream>
#include <cmath>
#include <algorithm>
#include <vector>
#include <utility>
#include <map>

int sqt;
long long int ans = 0;
long long int arr[200005] = {};
long long int cnt[1000005] = {};
long long int tans[200005] = {};

struct el
{
    int l, r, in;
};

bool cmp(const el &x, const el &y)
{
    if (x.l/sqt != y.l/sqt)
        return x.l/sqt < y.l/sqt;
    return x.r < y.r;
}

el qr[200005];

int main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);
    std::cout.tie(NULL);
    int n, q, a, b;
    std::cin >> n >> q;
    sqt = sqrt((double)(n))+27;
    for (int i = 0; i < n; i++)
        std::cin >> arr[i];
    for (int i = 0; i < q; i++)
    {
        std::cin >> a >> b;
        a--; b--;
        qr[i].l = a;
        qr[i].r = b;
        qr[i].in = i;
    }
    std::sort(qr, qr+q, cmp);

    int li = 0; //left iterator
    int ri = 0; //right iterator
    ans = arr[0];
    cnt[arr[0]]++;

    for (int i = 0; i < q; i++)
    {
        while (li < qr[i].l)
        {
            ans -= cnt[arr[li]]*cnt[arr[li]]*arr[li];
            cnt[arr[li]]--;
            ans += cnt[arr[li]]*cnt[arr[li]]*arr[li];
            li++;
        }
        while (li > qr[i].l)
        {
            li--;
            ans -= cnt[arr[li]]*cnt[arr[li]]*arr[li];
            cnt[arr[li]]++;
            ans += cnt[arr[li]]*cnt[arr[li]]*arr[li];
        }
        while (ri < qr[i].r)
        {
            ri++;
            ans -= cnt[arr[ri]]*cnt[arr[ri]]*arr[ri];
            cnt[arr[ri]]++;
            ans += cnt[arr[ri]]*cnt[arr[ri]]*arr[ri];
        }
        while (ri > qr[i].r)
        {
            ans -= cnt[arr[ri]]*cnt[arr[ri]]*arr[ri];
            cnt[arr[ri]]--;
            ans += cnt[arr[ri]]*cnt[arr[ri]]*arr[ri];
            ri--;
        }
        tans[qr[i].in] = ans;
    }
    for (int i = 0; i < q; i++)
        std::cout << tans[i] << '\n';
}

Can you suggest any non-asymptotic (or possibly even an asymptotic) improvement that can speed up the program enough to pass the time limit?

I have already tried the following things, to no avail:

  1. Using a vector instead of an array.
  2. Using two nested pairs instead of struct.
  3. Using only one pair, and then using a map to try to recover the correct order of answers.
  4. Adding some various constants to sqt (such as 27 in the code above).
  5. Overloading the < comparison operator within the struct el itself.

I feel like I'm missing something important, since the other codes I have inspected seem to pass the time limit with quite a bit of leeway (around half a second). Yet, they seem to be using the same algorithm as my code.

Any help would be highly appreciated!

Robin Yu asked Oct 30 '22 06:10


1 Answer

You could strength-reduce

    while (li < qr[i].l)
    {
        ans -= cnt[arr[li]]*cnt[arr[li]]*arr[li];
        cnt[arr[li]]--;
        ans += cnt[arr[li]]*cnt[arr[li]]*arr[li];
        li++;
    }

to

    while (li < qr[i].l)
    {
        ans -= (2*cnt[arr[li]]-1)*arr[li];
        cnt[arr[li]]--;
        li++;
    }

and likewise for the others.
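
The reduction follows from expanding the squares. If a value s currently occurs c times, adding one more occurrence changes the sum by (c+1)^2*s - c^2*s = (2c+1)*s, and removing one changes it by c^2*s - (c-1)^2*s = (2c-1)*s. A minimal sketch of both updates is below. The `PowerWindow` struct and its member names are hypothetical, introduced only for illustration, and the `assert` encodes the assumption that a value is never removed from a window that does not contain it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the original code): keeps the running power
// of the current window together with the per-value occurrence counts.
struct PowerWindow
{
    std::vector<long long> cnt;  // cnt[s] = occurrences of s in the window
    long long ans = 0;           // sum of cnt[s] * cnt[s] * s over all s

    explicit PowerWindow(std::size_t max_value) : cnt(max_value + 1, 0) {}

    // (c+1)^2*s - c^2*s = (2c+1)*s, where c is the count before the insertion.
    void add(long long s)
    {
        ans += (2 * cnt[s] + 1) * s;
        cnt[s]++;
    }

    // c^2*s - (c-1)^2*s = (2c-1)*s, where c is the count before the removal.
    void remove(long long s)
    {
        assert(cnt[s] > 0);  // assumption: s is present in the window
        ans -= (2 * cnt[s] - 1) * s;
        cnt[s]--;
    }
};
```

With such a helper, the four loops would shrink to one-liners such as `while (li > l) w.add(arr[--li]);` and `while (ri > r) w.remove(arr[ri--]);`. Running the two expanding loops before the two shrinking ones keeps every count non-negative, which is what the `assert` relies on.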

David Eisenstat answered Nov 15 '22 05:11