Given an array of positive integers, how do I find a subsequence of length L with the maximum sum, such that the distance between any two neighboring elements of the subsequence does not exceed K?
I have the following solution, but I don't know how to take the length L into account.
Constraints: 1 <= N <= 100000, 1 <= L <= 200, 1 <= K <= N
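f[i] contains the max sum of a subsequence that ends at index i:

```python
def max_sum_any_length(a, K):  # illustrative wrapper name
    N = len(a)
    # f[i] = max sum of a subsequence ending at i (the length L is not tracked yet)
    f = list(a)  # a single element is already a valid subsequence
    for i in range(1, N):
        for j in range(1, min(K, i) + 1):
            f[i] = max(f[i], f[i - j] + a[i])
    return max(f)
```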
(edit: slightly simplified non-recursive solution)
You can do it like this: for each element, consider whether it should be included or excluded.
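```python
def f(maxK, K, N, L, S):
    # maxK: allowed distance between neighbors; K: remaining gap budget
    # N: remaining list; L: elements still to pick; S: running sum
    if L == 0:
        return S                  # picked enough elements
    if not N or K == 0:
        return float("-inf")      # ran out of elements, or the gap became too large
    # either the element is included (the gap budget resets)
    included = f(maxK, maxK, N[1:], L - 1, S + N[0])
    # or excluded (the gap budget shrinks by one)
    excluded = f(maxK, K - 1, N[1:], L, S)
    return max(included, excluded)

assert f(2, 2, [10, 1, 1, 1, 1, 10], 3, 0) == 12
assert f(3, 3, [8, 3, 7, 6, 2, 1, 9, 2, 5, 4], 4, 0) == 30
# Pass len(N) + 1 as the initial K so that the first pick may be anywhere in N
assert f(2, 7, [1, 1, 1, 10, 10, 10], 3, 0) == 30
```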
If N is very long, you can consider switching to a table-based version; you could also convert the input to tuples and use memoization.
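A minimal sketch of the memoization idea, assuming it is acceptable to memoize on positions rather than on tuple slices (hashing a tuple slice costs O(n) per call). The name `max_subseq_sum_memo` is hypothetical. The running sum S is dropped from the state so cached results can be reused; the recursion depth is still about len(a), so this only suits small inputs.

```python
from functools import lru_cache

def max_subseq_sum_memo(a, L, K):
    # Hypothetical memoized rewrite of the include/exclude recursion.
    # best(i, gap, left) = best sum obtainable from a[i:] when `left` picks remain
    # and the next pick must happen within `gap` positions.
    a = tuple(a)
    n = len(a)
    NEG = float("-inf")

    @lru_cache(maxsize=None)
    def best(i, gap, left):
        if left == 0:
            return 0
        if i == n or gap == 0:
            return NEG  # impossible to finish from here
        take = a[i] + best(i + 1, K, left - 1)  # include a[i]; gap budget resets
        skip = best(i + 1, gap - 1, left)       # exclude a[i]; budget shrinks
        return max(take, skip)

    # An initial gap of n + 1 lets the first pick be anywhere.
    return best(0, n + 1, L)

print(max_subseq_sum_memo([10, 1, 1, 1, 1, 10], 3, 2))  # 12
```

It returns -inf when no valid subsequence of length L exists.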
Since the OP later added that N can be 100 000, we can't really use recursive solutions like this. So here is a solution that runs in O(nKL) time and uses O(nL) memory:
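```python
import numpy as np

def f(n, K, L):
    # t[i, l] = best sum of a subsequence with exactly l elements ending at index i;
    # -inf marks cells where no such subsequence exists
    t = np.full((len(n), L + 1), -np.inf)
    t[:, 1] = n
    for l in range(2, L + 1):
        for i in range(len(n)):
            prev = max((t[i - k, l - 1] for k in range(1, K + 1) if i - k >= 0), default=-np.inf)
            t[i, l] = n[i] + prev
    return np.max(t[:, L])

assert f([10, 1, 1, 1, 1, 10], 2, 3) == 12
assert f([8, 3, 7, 6, 2, 1, 9], 3, 4) == 30
```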
Explanation of the non-recursive solution: each cell t[i, l] holds the maximum sum of a subsequence with exactly l elements that ends with the element at position i, uses only elements at position i or lower, and has neighboring elements at most K apart.
Subsequences of length 1 (those in t[i, 1]) have only one element, n[i].
Longer subsequences consist of n[i] plus a subsequence of l-1 elements that ends at most K rows earlier; we pick the one with the maximal value. Because the table is filled one length l at a time, that value has always been calculated already.
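Stated as a recurrence (a reading of the explanation above; treating an empty maximum as $-\infty$ is an assumption, so that cells with too few earlier elements never count):

$$
t[i,1] = n[i], \qquad
t[i,l] = n[i] + \max_{\substack{1 \le k \le K \\ i-k \ge 0}} t[i-k,\; l-1] \quad (l \ge 2), \qquad
\text{answer} = \max_i \; t[i,L].
$$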
Further memory improvements are possible, since each cell only looks at most K steps back.
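A minimal sketch of that memory saving, assuming plain Python instead of NumPy and L >= 1 (as in the question's constraints). The name `max_subseq_sum_window` is hypothetical. It swaps the loop order so the index i is the outer loop; each row then only reads the K rows before it, so a `collections.deque(maxlen=K)` of rows is enough: O(K*L) memory, same O(nKL) time.

```python
from collections import deque

def max_subseq_sum_window(n, K, L):
    # Hypothetical variant of f with the loops swapped: index outer, length inner.
    NEG = float("-inf")
    window = deque(maxlen=K)  # rows for the last K indices; window[-1] is i-1
    best = NEG
    for x in n:
        row = [NEG] * (L + 1)  # row[l] = best sum of exactly l elements ending here
        row[1] = x
        for l in range(2, L + 1):
            prev = max((r[l - 1] for r in window), default=NEG)
            row[l] = x + prev
        best = max(best, row[L])
        window.append(row)  # maxlen silently drops the row that is now K+1 back
    return best

print(max_subseq_sum_window([10, 1, 1, 1, 1, 10], 2, 3))  # 12
```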
Here is a bottom-up (i.e. non-recursive) dynamic programming solution in Python. It takes O(l * n) memory and O(l * n * k) time.
```python
def max_subseq_sum(k, l, values):
    # table[i][j] will be the highest value from a sequence of length j
    # ending at position i
    table = []
    for i in range(len(values)):
        # Length 0 sums to 0; length 1 ending at i sums to values[i].
        table.append([0, values[i]])
        # By length of previous subsequence
        for subseq_len in range(1, l):
            # We look back up to k for the best.
            prev_val = None
            for last_i in range(i-k, i):
                # We don't look back if the sequence was not that long.
                if subseq_len <= last_i+1:
                    # Is this better?
                    this_val = table[last_i][subseq_len]
                    if prev_val is None or prev_val < this_val:
                        prev_val = this_val
            # Do we have a best to offer?
            if prev_val is not None:
                table[i].append(prev_val + values[i])
    # Now we look for the best entry of length l.
    best_val = None
    for row in table:
        # If the row has entries for 0...l, it will have len > l.
        if l < len(row):
            if best_val is None or best_val < row[l]:
                best_val = row[l]
    return best_val

print(max_subseq_sum(2, 3, [10, 1, 1, 1, 1, 10]))  # 12
print(max_subseq_sum(3, 4, [8, 3, 7, 6, 2, 1, 9, 2, 5, 4]))  # 30
```
If I wanted to be slightly clever, I could make the memory O(n) pretty easily by calculating one layer at a time and throwing away the previous one. It takes a lot more cleverness to reduce the running time to O(l*n*log(k)), but that is doable. (Use a priority queue for the best value in the last k. Updating it costs O(log(k)) per element, but the queue naturally grows. Every k values you throw it away and rebuild it at O(k) cost; that happens O(n/k) times, for a total rebuild cost of O(n).)
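The "one layer at a time" idea can be sketched like this (the name `max_subseq_sum_layers` is hypothetical; it assumes l >= 1 and returns -inf when no valid subsequence exists). Only the previous layer is kept, so memory is O(n) while time stays O(l*n*k).

```python
def max_subseq_sum_layers(k, l, values):
    # layer[i] = best sum of a subsequence of the current length ending at i
    NEG = float("-inf")
    n = len(values)
    layer = list(values)  # length-1 subsequences
    for _ in range(1, l):
        new_layer = [NEG] * n
        for i in range(n):
            # best previous-layer value among the k positions before i
            prev = max((layer[j] for j in range(max(0, i - k), i)), default=NEG)
            new_layer[i] = values[i] + prev
        layer = new_layer  # the previous layer is discarded here
    return max(layer, default=NEG)

print(max_subseq_sum_layers(2, 3, [10, 1, 1, 1, 1, 10]))  # 12
```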
And here is the clever version. Memory O(n). Time O(n*l*log(k)) in the worst case, and O(n*l) in the average case. You hit the worst case when the input is sorted in ascending order.
```python
import heapq

def max_subseq_sum(k, l, values):
    if l > len(values):
        return None  # no subsequence of length l exists
    prev_best = [0 for _ in values]
    # i represents how many elements are in the previous subsequences.
    # It ranges from 0..(l-1).
    for i in range(l):
        # We are building subsequences of length i+1.
        # None of them can end before position i, so those slots are None.
        best = [None for _ in range(i)]
        # Our heap will be (-sum, index). It is a min-heap, so the
        # minimum element has the largest sum. We track the index
        # so that we know when it is in the last k.
        min_heap = [(-prev_best[i-1], i-1)]
        for j in range(i, len(values)):
            # Remove best elements that are more than k back.
            while min_heap[0][-1] < j-k:
                heapq.heappop(min_heap)
            # We append this value + (best prev sum) using -(-..) = +.
            best.append(values[j] - min_heap[0][0])
            heapq.heappush(min_heap, (-prev_best[j], j))
            # And now keep min_heap from growing too big.
            if 2*k < len(min_heap):
                # Filter out elements too far back.
                min_heap = [_ for _ in min_heap if j - k < _[1]]
                # And make into a heap again.
                heapq.heapify(min_heap)
        # And now finish this layer.
        prev_best = best
    # Skip the None placeholders: comparing None with int raises TypeError.
    return max(v for v in prev_best if v is not None)
```
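A swapped-in alternative (not the heap used above): a monotonic deque, the standard sliding-window-maximum technique, keeps candidate indices of the previous layer in decreasing order of value, so each index is pushed and popped at most once per layer. That would make the worst case O(n*l), including ascending input. The name `max_subseq_sum_deque` is hypothetical; the sketch assumes l >= 1 and returns -inf when no valid subsequence exists.

```python
from collections import deque

def max_subseq_sum_deque(k, l, values):
    NEG = float("-inf")
    n = len(values)
    prev = list(values)  # best sums for length-1 subsequences ending at each index
    for _ in range(1, l):
        cur = [NEG] * n
        dq = deque()  # indices into prev; their prev values are decreasing
        for j in range(n):
            # The window for position j is prev[j-k .. j-1]; bring in index j-1.
            if j >= 1:
                while dq and prev[dq[-1]] <= prev[j - 1]:
                    dq.pop()  # dominated: older and not larger
                dq.append(j - 1)
            # Drop indices that are more than k positions back.
            while dq and dq[0] < j - k:
                dq.popleft()
            if dq:
                cur[j] = values[j] + prev[dq[0]]  # front holds the window maximum
        prev = cur
    return max(prev, default=NEG)

print(max_subseq_sum_deque(2, 3, [1, 2, 3, 4, 5]))  # 12
```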
Extending the code for itertools.combinations shown in the docs, I built a version that includes an argument for the maximum index distance (K) between two neighboring values. It only needed an additional `and indices[i] - indices[i-1] < K` check in the iteration:
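```python
def combinations_with_max_dist(iterable, r, K):
    # Like itertools.combinations, but neighboring chosen indices are at most K apart.
    # combinations_with_max_dist('ABCD', 2, 1) --> AB BC CD
    # combinations_with_max_dist(range(4), 3, 2) --> 012 013 023 123
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            # indices[i] may advance only if it is not at its maximum and,
            # after advancing, it is still within K of the previous index
            if indices[i] != i + n - r and (i == 0 or indices[i] - indices[i-1] < K):
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)
```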
Using this, you can brute-force over all combinations that respect K, and then pick the one with the maximum sum:
```python
def find_subseq(a, L, K):
    return max((sum(values), values) for values in combinations_with_max_dist(a, L, K))
```

Results:

```python
print(*find_subseq([10, 1, 1, 1, 1, 10], L=3, K=2))
# 12 (10, 1, 1)
print(*find_subseq([8, 3, 7, 6, 2, 1, 9, 2, 5, 4], L=4, K=3))
# 30 (8, 7, 6, 9)
```
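I'm not sure about the performance if your value lists become very long, though: the number of combinations that respect K grows up to roughly n * K^(L-1), so for N = 100000 and L = 200 this brute force is not feasible.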
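To see how fast the brute force grows, one can count the index sets it would visit without enumerating them. This is a sketch with a hypothetical name, `count_valid_index_sets`; it counts the same sets that `combinations_with_max_dist` yields, using a small DP over "last chosen index".

```python
def count_valid_index_sets(n, L, K):
    # ways[i] = number of valid index sets of the current length whose last index is i
    ways = [1] * n  # length 1: any single index
    for _ in range(1, L):
        # extend each set by a new last index i, at most K after the previous one
        ways = [sum(ways[max(0, i - K):i]) for i in range(n)]
    return sum(ways)

print(count_valid_index_sets(6, 3, 2))  # 12
```

The count is bounded by n * K^(L-1) (n choices for the first index, at most K for each following gap), which is why the brute force only suits short inputs.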
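Basic idea:
- Loop over the input array, treating each index as the first taken element, firstIdx.
- For each firstIdx, the next taken element is in [firstIdx + 1, firstIdx + K], both inclusive.
- Recurse on each of those candidates, with L - 1 as the new L.
- For each (firstIndex, L) pair, cache its max sum for reuse. This is probably necessary for large input.

Constraints:
- array length <= 1 << 17 // 131072
- K <= 1 << 6 // 64
- L <= 1 << 8 // 256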
Complexity:
- Time: O(n * L * K), since each (firstIdx, L) pair is calculated only once, and each calculation contains an iteration over K.
- Space: O(n * L), for the cache of (firstIdx, L) pairs.
Tips:
- The recursion depth depends on L, not on the array length.
- The array length and K could actually be of any size as long as there is enough memory, since they are handled via iteration. L is handled via recursion, so it does have a limit.

Code (Java):
SubSumLimitedDistance.java:
```java
import java.util.HashMap;
import java.util.Map;

public class SubSumLimitedDistance {
    public static final long NOT_ENOUGH_ELE = -1; // sum that indicate not enough element, should be < 0,
    public static final int MAX_ARR_LEN = 1 << 17; // max length of input array,
    public static final int MAX_K = 1 << 6; // max K, should not be too long, otherwise slow,
    public static final int MAX_L = 1 << 8; // max L, should not be too long, otherwise stackoverflow,

    /**
     * Find max sum of sum array.
     *
     * @param arr
     * @param K
     * @param L
     * @return max sum,
     */
    public static long find(int[] arr, int K, int L) {
        if (K < 1 || K > MAX_K)
            throw new IllegalArgumentException("K should be between [1, " + MAX_K + "], but get: " + K);
        if (L < 0 || L > MAX_L)
            throw new IllegalArgumentException("L should be between [0, " + MAX_L + "], but get: " + L);
        if (arr.length > MAX_ARR_LEN)
            throw new IllegalArgumentException("input array length should <= " + MAX_ARR_LEN + ", but get: " + arr.length);

        Map<Integer, Map<Integer, Long>> cache = new HashMap<>(); // cache,

        long maxSum = NOT_ENOUGH_ELE;
        for (int i = 0; i < arr.length; i++) {
            long sum = findTakeFirst(arr, K, L, i, cache);
            if (sum == NOT_ENOUGH_ELE) break; // not enough elements,
            if (sum > maxSum) maxSum = sum; // larger found,
        }

        return maxSum;
    }

    /**
     * Find max sum of sum array, with index of first taken element specified,
     *
     * @param arr
     * @param K
     * @param L
     * @param firstIdx index of first taken element,
     * @param cache
     * @return max sum,
     */
    private static long findTakeFirst(int[] arr, int K, int L, int firstIdx, Map<Integer, Map<Integer, Long>> cache) {
        // System.out.printf("findTakeFirst(): K = %d, L = %d, firstIdx = %d\n", K, L, firstIdx);
        if (L == 0) return 0; // done,
        if (firstIdx + L > arr.length) return NOT_ENOUGH_ELE; // not enough elements,

        // check cache,
        Map<Integer, Long> map = cache.get(firstIdx);
        Long cachedResult;
        if (map != null && (cachedResult = map.get(L)) != null) {
            // System.out.printf("hit cache, cached result = %d\n", cachedResult);
            return cachedResult;
        }

        // cache not exists, calculate,
        long maxRemainSum = NOT_ENOUGH_ELE;
        for (int i = firstIdx + 1; i <= firstIdx + K; i++) {
            long remainSum = findTakeFirst(arr, K, L - 1, i, cache);
            if (remainSum == NOT_ENOUGH_ELE) break; // not enough elements,
            if (remainSum > maxRemainSum) maxRemainSum = remainSum;
        }

        if ((map = cache.get(firstIdx)) == null) cache.put(firstIdx, map = new HashMap<>());

        if (maxRemainSum == NOT_ENOUGH_ELE) { // not enough elements,
            map.put(L, NOT_ENOUGH_ELE); // cache - as not enough elements,
            return NOT_ENOUGH_ELE;
        }

        long maxSum = arr[firstIdx] + maxRemainSum; // max sum,
        map.put(L, maxSum); // cache - max sum,

        return maxSum;
    }
}
```
SubSumLimitedDistanceTest.java (test case, via TestNG):
```java
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.concurrent.ThreadLocalRandom;

public class SubSumLimitedDistanceTest {
    private int[] arr;
    private int K;
    private int L;
    private int maxSum;

    private int[] arr2;
    private int K2;
    private int L2;
    private int maxSum2;

    private int[] arrMax;
    private int KMax;
    private int KMaxLargest;

    private int LMax;
    private int LMaxLargest;

    @BeforeClass
    private void setUp() {
        // init - arr,
        arr = new int[]{10, 1, 1, 1, 1, 10};
        K = 2;
        L = 3;
        maxSum = 12;

        // init - arr2,
        arr2 = new int[]{8, 3, 7, 6, 2, 1, 9, 2, 5, 4};
        K2 = 3;
        L2 = 4;
        maxSum2 = 30;

        // init - arrMax,
        arrMax = new int[SubSumLimitedDistance.MAX_ARR_LEN];
        ThreadLocalRandom rd = ThreadLocalRandom.current();
        long maxLongEle = Long.MAX_VALUE / SubSumLimitedDistance.MAX_ARR_LEN;
        int maxEle = maxLongEle > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) maxLongEle;
        for (int i = 0; i < arrMax.length; i++) {
            arrMax[i] = rd.nextInt(maxEle);
        }

        KMax = 5;
        LMax = 10;

        KMaxLargest = SubSumLimitedDistance.MAX_K;
        LMaxLargest = SubSumLimitedDistance.MAX_L;
    }

    @Test
    public void test() {
        Assert.assertEquals(SubSumLimitedDistance.find(arr, K, L), maxSum);
        Assert.assertEquals(SubSumLimitedDistance.find(arr2, K2, L2), maxSum2);
    }

    @Test(timeOut = 6000)
    public void test_veryLargeArray() {
        run_printDuring(arrMax, KMax, LMax);
    }

    @Test(timeOut = 60000) // takes seconds,
    public void test_veryLargeArrayL() {
        run_printDuring(arrMax, KMax, LMaxLargest);
    }

    @Test(timeOut = 60000) // takes seconds,
    public void test_veryLargeArrayK() {
        run_printDuring(arrMax, KMaxLargest, LMax);
    }

    // run find once, and print during,
    private void run_printDuring(int[] arr, int K, int L) {
        long startTime = System.currentTimeMillis();
        long sum = SubSumLimitedDistance.find(arr, K, L);
        long during = System.currentTimeMillis() - startTime; // during in milliseconds,
        System.out.printf("arr length = %5d, K = %3d, L = %4d, max sum = %15d, running time = %.3f seconds\n", arr.length, K, L, sum, during / 1000.0);
    }

    @Test
    public void test_corner_notEnoughEle() {
        Assert.assertEquals(SubSumLimitedDistance.find(new int[]{1}, 2, 3), SubSumLimitedDistance.NOT_ENOUGH_ELE); // not enough element,
        Assert.assertEquals(SubSumLimitedDistance.find(new int[]{0}, 1, 3), SubSumLimitedDistance.NOT_ENOUGH_ELE); // not enough element,
    }

    @Test
    public void test_corner_ZeroL() {
        Assert.assertEquals(SubSumLimitedDistance.find(new int[]{1, 2, 3}, 2, 0), 0); // L = 0,
        Assert.assertEquals(SubSumLimitedDistance.find(new int[]{0}, 1, 0), 0); // L = 0,
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void test_invalid_K() {
        // SubSumLimitedDistance.find(new int[]{1, 2, 3}, 0, 2); // K = 0,
        // SubSumLimitedDistance.find(new int[]{1, 2, 3}, -1, 2); // K = -1,
        SubSumLimitedDistance.find(new int[]{1, 2, 3}, SubSumLimitedDistance.MAX_K + 1, 2); // K = SubSumLimitedDistance.MAX_K+1,
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void test_invalid_L() {
        // SubSumLimitedDistance.find(new int[]{1, 2, 3}, 2, -1); // L = -1,
        SubSumLimitedDistance.find(new int[]{1, 2, 3}, 2, SubSumLimitedDistance.MAX_L + 1); // L = SubSumLimitedDistance.MAX_L+1,
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void test_invalid_tooLong() {
        SubSumLimitedDistance.find(new int[SubSumLimitedDistance.MAX_ARR_LEN + 1], 2, 3); // input array too long,
    }
}
```
Output of test case for large input:

```
arr length = 131072, K =   5, L =   10, max sum =     20779205738, running time = 0.303 seconds
arr length = 131072, K =  64, L =   10, max sum =     21393422854, running time = 1.917 seconds
arr length = 131072, K =   5, L =  256, max sum =    461698553839, running time = 9.474 seconds
```