The required time complexity of this question differs from that of a similar question that has been asked before. This is a question from the Zauba developer hiring challenge (the event ended a month ago):
f(0) = p
f(1) = q
f(2) = r
for n > 2:
f(n) = a*f(n-1) + b*f(n-2) + c*f(n-3) + g(n)
where g(n) = n*n*(n+1)
p, q, r, a, b, c and n are given. n can be as large as 10^18.
Link to a similar problem
The time complexity was not specified in the linked problem, and I have already solved this problem in O(n). The pseudocode is below (just the approach; all boundaries and edge cases were handled in the contest):
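if (n == 0) return p;
if (n == 1) return q;
if (n == 2) return r;
for (long i = 3; i <= n; i++) {
    long now = a*r + b*q + c*p + i*i*(i+1); // f(i); modulo omitted in this sketch
    p = q; q = r; r = now;                  // slide the window of the last three values
}
return r;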
Please note that in the original code I applied modulo 10^9 + 7 wherever appropriate to handle overflow, handled edge cases wherever necessary, and used Java's long data type (if it helps).
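For reference, one way to keep the g(i) term from overflowing when i gets large is to reduce each factor modulo m before multiplying. This is a minimal sketch, not the poster's code; the class and method names are hypothetical, and it assumes 0 < m <= 3*10^9 so that the product of two reduced values still fits in a long:

```java
// Hypothetical helper: evaluates g(i) = i*i*(i+1) mod m without overflowing a long.
// Assumes 0 < m <= 3_000_000_000L, so (m - 1) * (m - 1) fits in a signed 64-bit long.
final class CubicTerm {
    static long g(long i, long m) {
        long x = Math.floorMod(i, m);      // i mod m, in [0, m)
        long y = Math.floorMod(i + 1, m);  // (i + 1) mod m, in [0, m)
        long xx = (x * x) % m;             // i^2 mod m
        return (xx * y) % m;               // i^2 * (i + 1) mod m
    }
}
```

Calling `CubicTerm.g(1_000_000_000_000_000_000L, 1_000_000_007L)` stays within long range, whereas computing `i*i*(i+1)` directly for i = 10^18 would overflow.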
But since this still requires O(n) time, I am looking for a better solution that can handle n ~ 10^18.
EDIT
As user גלעד ברקן pointed out its relation to matrix exponentiation, I tried that approach but got stuck at one point: I am not sure what to place in the 4th row, 3rd column of the matrix. Kindly make any suggestions and corrections.
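$$
\underbrace{\begin{pmatrix}
a & b & c & 1? \\
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
0 & 0 & ?! & 0
\end{pmatrix}}_{M}
\underbrace{\begin{pmatrix} f(n) \\ f(n-1) \\ f(n-2) \\ g(n) \end{pmatrix}}_{A}
\Rightarrow
\underbrace{\begin{pmatrix} f(n+1) \\ f(n) \\ f(n-1) \\ g(n+1) \end{pmatrix}}_{B}
$$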
Matrix exponentiation is indeed the right way to go, but there's a little more work to be done.
Since g(n) is not constant-valued, there is no way to apply matrix exponentiation efficiently (O(log n) instead of O(n)) to the recurrence relation in its current form.
A similar recurrence relation needs to be found for g(n), with only a constant term trailing. Since g(n) is cubic, 3 recursive terms are required:
g(n) = x*g(n-1) + y*g(n-2) + z*g(n-3) + w
Expand the cubic expressions for each of them:
n³ + n² = x(n³-2n²+n) + y(n³-5n²+8n-4) + z*(n³-8n²+21n-18) + w
= n³(x+y+z) + n²(-2x-5y-8z) + n(x+8y+21z) + (w-4y-18z)
Match the coefficients to obtain three simultaneous equations for x, y and z, plus another to calculate w:
x + y + z = 1
-2x - 5y - 8z = 1
x + 8y + 21z = 0
w - 4y - 18z = 0
Solve them to obtain:
x = 3, y = -3, z = 1, w = 6
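A quick numeric check of the solved coefficients, using exact long arithmetic for moderate n. This is an illustrative helper, not part of the original answer; the class name is hypothetical:

```java
// Illustrative check: confirms g(n) = 3*g(n-1) - 3*g(n-2) + g(n-3) + 6
// for g(n) = n*n*(n+1), over a range of n small enough to avoid overflow.
final class GRecurrenceCheck {
    static long g(long n) {
        return n * n * (n + 1);
    }

    // Returns true if the derived recurrence holds for every n in [3, limit].
    static boolean holdsUpTo(long limit) {
        for (long n = 3; n <= limit; n++) {
            long lhs = g(n);
            long rhs = 3 * g(n - 1) - 3 * g(n - 2) + g(n - 3) + 6;
            if (lhs != rhs) return false;
        }
        return true;
    }
}
```

For example, with n = 4: 3*36 - 3*12 + 2 + 6 = 80 = g(4).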
Conveniently, these coefficients are also integers*, which means modular arithmetic can be directly performed on the recurrence.
* I doubt this was a coincidence - it could well have been the intention of the hiring examiner.
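Another way to see where these particular integers come from (an observation added here, not part of the original derivation): 1, -3, 3, -1 are the binomial coefficients of the third backward difference, and the third difference of any cubic is the constant 3! times its leading coefficient, which is 6 for g(n) = n³ + n²:

$$
\nabla^3 g(n) = g(n) - 3g(n-1) + 3g(n-2) - g(n-3) = 3! \cdot 1 = 6
\quad\Longrightarrow\quad
g(n) = 3g(n-1) - 3g(n-2) + g(n-3) + 6
$$

So any cubic g with integer coefficients yields x = 3, y = -3, z = 1 and an integer w.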
The matrix recurrence equation is therefore:
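$$
\begin{pmatrix}
a & b & c & 1 & 0 & 0 & 0 \\
1 & 0 & 0 & 0 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 3 & -3 & 1 & 6 \\
0 & 0 & 0 & 1 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & 1 & 0 & 0 \\
0 & 0 & 0 & 0 & 0 & 0 & 1
\end{pmatrix}
\times
\begin{pmatrix} f(n-1) \\ f(n-2) \\ f(n-3) \\ g(n) \\ g(n-1) \\ g(n-2) \\ 1 \end{pmatrix}
=
\begin{pmatrix} f(n) \\ f(n-1) \\ f(n-2) \\ g(n+1) \\ g(n) \\ g(n-1) \\ 1 \end{pmatrix}
$$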
The final matrix exponentiation equation is:
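$$
\begin{pmatrix}
a & b & c & 1 & 0 & 0 & 0 \\
1 & 0 & 0 & 0 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 3 & -3 & 1 & 6 \\
0 & 0 & 0 & 1 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & 1 & 0 & 0 \\
0 & 0 & 0 & 0 & 0 & 0 & 1
\end{pmatrix}^{n-2}
\times
\begin{pmatrix} f(2) \\ f(1) \\ f(0) \\ g(3) \\ g(2) \\ g(1) \\ 1 \end{pmatrix}
=
\begin{pmatrix} f(n) \\ f(n-1) \\ f(n-2) \\ g(n+1) \\ g(n) \\ g(n-1) \\ 1 \end{pmatrix},
\qquad
\begin{pmatrix} f(2) \\ f(1) \\ f(0) \\ g(3) \\ g(2) \\ g(1) \\ 1 \end{pmatrix}
=
\begin{pmatrix} r \\ q \\ p \\ 36 \\ 12 \\ 2 \\ 1 \end{pmatrix}
$$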
(Every operation is implicitly modulo 10^9 + 7, or whichever such number is supplied.)
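Ignoring the modulus for a moment (so only small n are safe), one way to convince yourself that the 7x7 matrix and the initial vector are right is to apply the matrix one step at a time and compare against the direct loop. This is an illustrative sketch with hypothetical names, not the answer's implementation:

```java
// Illustrative check (no modulus, so keep n small): applies the 7x7 matrix
// n - 2 times to [f(2), f(1), f(0), g(3), g(2), g(1), 1] and compares with the loop.
final class MatrixStepCheck {
    // Recurrence matrix, laid out exactly as in the equation above.
    static long[][] buildMatrix(long a, long b, long c) {
        return new long[][] {
            {a, b, c, 1,  0, 0, 0},
            {1, 0, 0, 0,  0, 0, 0},
            {0, 1, 0, 0,  0, 0, 0},
            {0, 0, 0, 3, -3, 1, 6},
            {0, 0, 0, 1,  0, 0, 0},
            {0, 0, 0, 0,  1, 0, 0},
            {0, 0, 0, 0,  0, 0, 1},
        };
    }

    // Plain matrix-vector product.
    static long[] apply(long[][] mat, long[] v) {
        long[] out = new long[v.length];
        for (int i = 0; i < mat.length; i++) {
            long s = 0;
            for (int k = 0; k < v.length; k++) s += mat[i][k] * v[k];
            out[i] = s;
        }
        return out;
    }

    // f(n) via repeated matrix-vector steps; component 0 ends up as f(n).
    static long fViaMatrix(long a, long b, long c, long p, long q, long r, int n) {
        if (n == 0) return p;
        if (n == 1) return q;
        if (n == 2) return r;
        long[][] mat = buildMatrix(a, b, c);
        long[] state = {r, q, p, 36, 12, 2, 1};
        for (int step = 0; step < n - 2; step++) state = apply(mat, state);
        return state[0];
    }

    // f(n) via the direct loop, for comparison.
    static long fDirect(long a, long b, long c, long p, long q, long r, int n) {
        if (n == 0) return p;
        if (n == 1) return q;
        if (n == 2) return r;
        for (long i = 3; i <= n; i++) {
            long now = a * r + b * q + c * p + i * i * (i + 1);
            p = q; q = r; r = now;
        }
        return r;
    }
}
```

With a = b = c = 1 and p = q = r = 0, both give f(3) = 36 and f(4) = 36 + 80 = 116.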
Note that Java's % operator is the remainder, which differs from the mathematical modulus for negative numbers. Example:
-1 % 5 == -1      // Java
-1 ≡ 4 (mod 5)    // mathematical modulus
The workaround is rather simple:
long mod(long b, long a)
{
    // computes a mod b
    // assumes that b is positive
    return (b + (a % b)) % b;
}
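Since Java 8 the standard library also offers Math.floorMod(long, long), whose result takes the sign of the divisor, so for a positive modulus it gives the same value as the mod helper above. A small comparison sketch (the class name is hypothetical; the argument order of mod follows the answer):

```java
final class ModComparison {
    // The answer's helper: computes a mod b, assuming b > 0.
    static long mod(long b, long a) {
        return (b + (a % b)) % b;
    }

    // Same result for b > 0, using the JDK's floor modulus.
    static long modFloor(long b, long a) {
        return Math.floorMod(a, b);
    }

    public static void main(String[] args) {
        System.out.println(-1 % 5);           // -1 (remainder)
        System.out.println(mod(5, -1));       // 4
        System.out.println(modFloor(5, -1));  // 4
    }
}
```

Math.floorMod also avoids the intermediate b + (a % b), which could only overflow for moduli close to Long.MAX_VALUE.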
The original iterative algorithm:
long recurrence_original(
    long a, long b, long c,
    long p, long q, long r,
    long n, long m // 10^9 + 7 or whatever
) {
    // base cases
    if (n == 0) return p;
    if (n == 1) return q;
    if (n == 2) return r;
    long f0, f1, f2;
    f0 = p; f1 = q; f2 = r;
    for (long i = 3; i <= n; i++) {
        long f3 = mod(m,
            mod(m, a*f2) + mod(m, b*f1) + mod(m, c*f0) +
            mod(m, mod(m, i) * mod(m, i)) * mod(m, i+1)
        );
        f0 = f1; f1 = f2; f2 = f3;
    }
    return f2;
}
Modulo matrix functions:
long[][] matrix_create(int n)
{
    return new long[n][n];
}

void matrix_multiply(int n, long m, long[][] c, long[][] a, long[][] b)
{
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            long s = 0;
            for (int k = 0; k < n; k++)
                s = mod(m, s + mod(m, a[i][k]*b[k][j]));
            c[i][j] = s;
        }
    }
}

void matrix_pow(int n, long m, long p, long[][] y, long[][] x)
{
    // swap matrices
    long[][] a = matrix_create(n);
    long[][] b = matrix_create(n);
    long[][] c = matrix_create(n);
    // initialize accumulator to identity
    for (int i = 0; i < n; i++)
        a[i][i] = 1;
    // initialize base to original matrix
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            b[i][j] = x[i][j];
    // exponentiation by squaring
    // there are better algorithms, but this is the easiest to implement
    // and is still O(log n)
    long[][] t = null;
    for (long s = p; s > 0; s /= 2) {
        if (s % 2 == 1) {
            matrix_multiply(n, m, c, a, b);
            t = c; c = a; a = t;
        }
        matrix_multiply(n, m, c, b, b);
        t = c; c = b; b = t;
    }
    // write to output
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            y[i][j] = a[i][j];
}
And finally, the new algorithm itself:
long recurrence_matrix(
    long a, long b, long c,
    long p, long q, long r,
    long n, long m
) {
    if (n == 0) return p;
    if (n == 1) return q;
    if (n == 2) return r;
    // original recurrence matrix
    long[][] mat = matrix_create(7);
    mat[0][0] = a; mat[0][1] = b; mat[0][2] = c; mat[0][3] = 1;
    mat[1][0] = 1; mat[2][1] = 1;
    mat[3][3] = 3; mat[3][4] = -3; mat[3][5] = 1; mat[3][6] = 6;
    mat[4][3] = 1; mat[5][4] = 1;
    mat[6][6] = 1;
    // exponentiate
    long[][] res = matrix_create(7);
    matrix_pow(7, m, n - 2, res, mat);
    // multiply the first row with the initial vector
    return mod(m, mod(m, res[0][6])
        + mod(m, res[0][0]*r) + mod(m, res[0][1]*q) + mod(m, res[0][2]*p)
        + mod(m, res[0][3]*36) + mod(m, res[0][4]*12) + mod(m, res[0][5]*2)
    );
}
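A compact, self-contained sketch that ties the pieces together and cross-checks the matrix method against the O(n) loop. It is a restructured variant rather than the answer's exact code (matrices are returned instead of written into output parameters, and every input is reduced with Math.floorMod first); the class and method names are hypothetical, and it assumes 0 < m <= 3*10^9 so that products of reduced values fit in a long:

```java
// Hypothetical sketch: f(n) mod m via 7x7 matrix exponentiation, plus the O(n)
// loop for cross-checking. Assumes 0 < m <= 3_000_000_000L.
final class RecurrenceCrossCheck {
    // z = x * y (mod m); all entries of x and y must already be in [0, m).
    static long[][] multiply(long[][] x, long[][] y, long m) {
        int n = x.length;
        long[][] z = new long[n][n];
        for (int i = 0; i < n; i++)
            for (int k = 0; k < n; k++)
                for (int j = 0; j < n; j++)
                    z[i][j] = (z[i][j] + x[i][k] * y[k][j]) % m;
        return z;
    }

    // base^e (mod m) by exponentiation by squaring: O(log e) multiplications.
    static long[][] power(long[][] base, long e, long m) {
        int n = base.length;
        long[][] result = new long[n][n];
        for (int i = 0; i < n; i++) result[i][i] = 1 % m;  // identity matrix
        while (e > 0) {
            if ((e & 1) == 1) result = multiply(result, base, m);
            base = multiply(base, base, m);
            e >>= 1;
        }
        return result;
    }

    // Matrix method, same layout as the answer's 7x7 matrix.
    static long fMatrix(long a, long b, long c, long p, long q, long r, long n, long m) {
        if (n == 0) return Math.floorMod(p, m);
        if (n == 1) return Math.floorMod(q, m);
        if (n == 2) return Math.floorMod(r, m);
        long[][] mat = {
            {a, b, c, 1,  0, 0, 0},
            {1, 0, 0, 0,  0, 0, 0},
            {0, 1, 0, 0,  0, 0, 0},
            {0, 0, 0, 3, -3, 1, 6},
            {0, 0, 0, 1,  0, 0, 0},
            {0, 0, 0, 0,  1, 0, 0},
            {0, 0, 0, 0,  0, 0, 1},
        };
        // Reduce every entry (including -3) into [0, m) before multiplying.
        for (long[] row : mat)
            for (int j = 0; j < row.length; j++) row[j] = Math.floorMod(row[j], m);
        long[][] res = power(mat, n - 2, m);
        // Initial vector [f(2), f(1), f(0), g(3), g(2), g(1), 1].
        long[] init = {r, q, p, 36, 12, 2, 1};
        long s = 0;
        for (int k = 0; k < 7; k++)
            s = (s + res[0][k] * Math.floorMod(init[k], m)) % m;
        return s;
    }

    // Plain O(n) loop with every intermediate value reduced mod m.
    static long fLoop(long a, long b, long c, long p, long q, long r, long n, long m) {
        long f0 = Math.floorMod(p, m), f1 = Math.floorMod(q, m), f2 = Math.floorMod(r, m);
        if (n == 0) return f0;
        if (n == 1) return f1;
        if (n == 2) return f2;
        long am = Math.floorMod(a, m), bm = Math.floorMod(b, m), cm = Math.floorMod(c, m);
        for (long i = 3; i <= n; i++) {
            long gi = Math.floorMod(i, m) * Math.floorMod(i, m) % m * Math.floorMod(i + 1, m) % m;
            long f3 = (am * f2 % m + bm * f1 % m + cm * f0 % m + gi) % m;
            f0 = f1; f1 = f2; f2 = f3;
        }
        return f2;
    }
}
```

Calling fMatrix with n = 10^18 needs about 60 squarings of a 7x7 matrix, so it returns quickly, while fLoop is only practical for small n.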
Here are some sample benchmarks for both algorithms above.
Original iterative algorithm:
n        time (μs)
------------------
10^1     9.3
10^2     44.9
10^3     401.501
10^4     3882.099
10^5     27940.9
10^6     88873.599
10^7     877100.5
10^8     9057329.099
10^9     91749994.4
New matrix algorithm:
n        time (μs)
------------------
10^1     69.168
10^2     128.771
10^3     212.697
10^4     258.385
10^5     318.195
10^6     380.9
10^7     453.487
10^8     560.428
10^9     619.835
10^10    652.344
10^11    750.518
10^12    769.901
10^13    851.845
10^14    934.915
10^15    1016.732
10^16    1079.613
10^17    1123.413
10^18    1225.323
The old algorithm took over 90 seconds to calculate n = 10^9, whereas the new algorithm did it in just over 0.6 milliseconds (roughly a 150,000x speed-up)!
The original algorithm's time complexity was evidently linear (as expected); n = 10^10 took too long to complete, so I didn't continue.
The new algorithm's time complexity was evidently logarithmic: doubling the order of magnitude of n (e.g. from 10^9 to 10^18) roughly doubled the execution time (again, as expected from exponentiation by squaring).
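To put the doubling in concrete terms, the loop in matrix_pow runs once per bit of the exponent p = n - 2: 30 times for n = 10^9 and 60 times for n = 10^18. A tiny sketch (hypothetical class name) that counts iterations the same way that loop does:

```java
// Illustrative only: counts the iterations of matrix_pow's
// "for (long s = p; s > 0; s /= 2)" loop, i.e. the bit length of p.
final class SquaringSteps {
    static int loopIterations(long p) {
        int count = 0;
        for (long s = p; s > 0; s /= 2) count++;  // mirrors matrix_pow's loop
        return count;
    }

    public static void main(String[] args) {
        System.out.println(loopIterations(1_000_000_000L - 2));              // 30
        System.out.println(loopIterations(1_000_000_000_000_000_000L - 2));  // 60
    }
}
```

Each iteration performs one or two 7x7 matrix multiplications, so the total work grows with the number of digits of n rather than with n itself.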
For "small" values of n (n ≤ 100 in these benchmarks), the overhead of matrix allocation and operations outweighed the gains, making the matrix version slower than the loop, but this overhead quickly became insignificant as n increased.