Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Split vector into balanced list (balancing sum of list elements)

Having a hard time figuring out an efficient solution to the following problem. The question is very verbose because I'm not sure if I'm making this problem harder than it can be.

Given a named vector

t <- c(2, 0, 0, 30, 0, 0, 10, 2000, 0, 20, 0, 40, 60, 10)
names(t) <- c(1, 0, 0, 2, 0, 0, 3, 4, 0, 5, 0, 6, 7, 8)

I want to split t into a list of 4 elements that's balanced based on the sum of the resulting list elements while keeping the order of elements, and only splitting on non-zero elements. Expected outcome

L[1] <- c(2, 0, 0, 30, 0, 0, 10) # sum = 42
L[2] <- c(2000, 0)               # sum = 2000
L[3] <- c(20, 0, 40)             # sum = 60
L[4] <- c(60, 10)                # sum = 70

The error function I use is minimizing sd(rowSums(L)) or sd(sapply(L, sum))

Trying to split the vector using something like the following doesn't quite work

split(t, cut(cumsum(t), 4))

# $`(-0.17,544]`
 # 1  0  0  2  0  0  3 
 # 2  0  0 30  0  0 10 

# $`(544,1.09e+03]`
# named numeric(0)

# $`(1.09e+03,1.63e+03]`
# named numeric(0)

# $`(1.63e+03,2.17e+03]`
   # 4    0    5    0    6    7    8 
# 2000    0   20    0   40   60   10 

I wrote a function to split the list the way that I wanted (see error function above)

# Brute-force search for the break points that split `val` into `nchunks`
# consecutive chunks whose sums have the smallest standard deviation.
# A chunk may only start at a non-zero element of `val`.
#
# Args:
#   val:     numeric vector (names are not required).
#   nchunks: number of chunks to produce; at least 1 and at most the number
#            of non-zero elements of `val`.
#
# Returns a 0/1 numeric vector of length(val). A 1 marks the first element of
# every chunk after the first, so split(val, cumsum(break_at(val, nchunks)))
# yields the chunks.
#
# Note: this is still an exhaustive search (choose(k - 1, nchunks - 1)
# candidates for k non-zero elements), so it is only practical for small k.
break_at <- function(val, nchunks) {
    stopifnot(length(nchunks) == 1, nchunks >= 1)
    nonzero_pos <- which(val != 0)
    n_nonzero <- length(nonzero_pos)
    n_breaks <- nchunks - 1
    mark_breaks <- rep(0, length(val))
    if (n_breaks == 0) {
        # A single chunk needs no break at all.
        return(mark_breaks)
    }
    if (n_breaks > n_nonzero - 1) {
        stop("cannot split ", n_nonzero, " non-zero element(s) into ",
             nchunks, " chunks", call. = FALSE)
    }
    nonzero <- val[nonzero_pos]
    # The first non-zero element always belongs to the first chunk; a break
    # there would only relabel that chunk and produce one chunk too few.
    # Breaks are therefore drawn from non-zero positions 2..n_nonzero.
    # combn() is called with a count (not a vector) to avoid its special
    # handling of a length-one first argument.
    candidates <- utils::combn(n_nonzero - 1, n_breaks) + 1
    # combn() lists early break positions first; reversing the columns makes
    # which.min() prefer later break positions when several candidates tie.
    candidates <- candidates[, rev(seq_len(ncol(candidates))), drop = FALSE]
    spread <- vapply(
        seq_len(ncol(candidates)),
        function(j) {
            starts <- numeric(n_nonzero)
            starts[candidates[, j]] <- 1
            chunk_sums <- vapply(split(nonzero, cumsum(starts)), sum, numeric(1))
            stats::sd(chunk_sums)
        },
        numeric(1)
    )
    best_breaks <- candidates[, which.min(spread)]
    # Map positions among the non-zero elements back to positions in `val`.
    mark_breaks[nonzero_pos[best_breaks]] <- 1
    mark_breaks
}

You can see the result is much better

break_at(t, 4)
# 0 0 0 0 0 0 0 1 0 1 0 0 1 0

split(t, cumsum(break_at(t, 4)))

# $`0`
 # 1  0  0  2  0  0  3 
 # 2  0  0 30  0  0 10 

# $`1`
   # 4    0 
# 2000    0 

# $`2`
 # 5  0  6 
# 20  0 40 

# $`3`
 # 7  8 
# 60 10 

It works by using gtools::permutations(n = 2, r = length(nonzero), v = c(1, 0), repeats.allowed = TRUE) to look at all potential splits. See how the above works for r = 3

     # [,1] [,2] [,3]
# [1,]    0    0    0
# [2,]    0    0    1
# [3,]    0    1    0
# [4,]    0    1    1
# [5,]    1    0    0
# [6,]    1    0    1
# [7,]    1    1    0
# [8,]    1    1    1

which I then filter, all_groupings[rowSums(all_groupings) == nchunks, ]. This only looks at potential splits that produce nchunks.

My issue is that this works horribly with my real data because of the number of permutations involved.

hard <- structure(c(2, 0, 1, 2, 0, 1, 1, 1, 5, 0, 0, 0, 0, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 0, 0, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 1, 1, 1, 2, 0, 1, 1, 1, 1, 1, 1,
1, 1, 2, 0, 2, 0, 1, 4, 0, 0, 0, 1, 3, 0, 0, 4, 0, 0, 0, 2, 0,
1, 1, 1, 3, 0, 0, 1, 1, 1, 1, 2, 0, 1, 2, 0, 1, 1, 2, 0, 1, 6,
0, 0, 0, 0, 0, 1, 1, 1, 3, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 0,
1, 1, 2, 0, 1, 2, 0, 1, 1, 4, 0, 0, 0, 1, 1, 3, 0, 0, 1, 2, 0,
1, 1, 2, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 1, 1, 2, 0, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 1, 1, 2, 0, 3,
0, 0, 1, 1, 2, 0, 1, 2, 0, 1, 1, 1, 2, 0, 2, 0, 1, 3, 0, 0, 1,
1, 1, 1, 1, 2, 0, 1, 1, 1, 2, 0, 1, 2, 0, 1, 1, 1, 1, 1, 1, 2,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
0, 1, 1, 1, 1, 1, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
1, 2, 0, 1, 1, 1, 2, 0, 1, 1, 1, 2, 0, 8, 0, 0, 0, 0, 0, 0, 0,
1, 2, 0, 1, 1, 1, 1, 1, 1, 2, 0, 1, 1, 1, 1, 1, 2, 0, 1, 1, 1,
3, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 1, 1,
1, 1, 1, 1, 1, 2, 0, 1, 1, 1, 1, 1, 2, 0, 1, 1, 1, 1, 1, 3, 0,
0, 1, 1, 1, 2, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 1, 1, 1, 1,
1, 1, 1, 2, 0, 1, 1, 1, 1, 5, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 2, 0, 1, 1, 1, 1, 2, 0, 2, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 2, 0, 1, 1, 2, 0, 1, 2, 0, 1, 8, 0, 0, 0, 0, 0, 0, 0, 2,
0, 1, 9, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0, 0, 1, 1, 1,
1, 6, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 1, 3, 0, 0, 1, 1, 1, 3,
0, 0, 7, 0, 0, 0, 0, 0, 0, 1, 1, 2, 0, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 2, 0, 1, 1, 1, 1, 1, 1, 1), .Names = c("1", "0",
"2", "3", "0", "4", "5", "6", "7", "0", "0", "0", "0", "8", "9",
"10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
"21", "22", "23", "24", "0", "0", "25", "26", "27", "28", "29",
"30", "31", "32", "33", "34", "35", "36", "37", "38", "39", "0",
"40", "41", "42", "43", "0", "44", "45", "46", "47", "48", "49",
"50", "51", "52", "0", "53", "0", "54", "55", "0", "0", "0",
"56", "57", "0", "0", "58", "0", "0", "0", "59", "0", "60", "61",
"62", "63", "0", "0", "64", "65", "66", "67", "68", "0", "69",
"70", "0", "71", "72", "73", "0", "74", "75", "0", "0", "0",
"0", "0", "76", "77", "78", "79", "0", "0", "80", "81", "82",
"83", "84", "85", "86", "87", "88", "0", "89", "90", "91", "0",
"92", "93", "0", "94", "95", "96", "0", "0", "0", "97", "98",
"99", "0", "0", "100", "101", "0", "102", "103", "104", "0",
"105", "106", "0", "0", "107", "108", "0", "0", "109", "110",
"111", "112", "0", "113", "114", "115", "116", "117", "118",
"119", "120", "121", "122", "123", "124", "125", "126", "127",
"128", "129", "130", "131", "0", "132", "133", "134", "0", "135",
"0", "0", "136", "137", "138", "0", "139", "140", "0", "141",
"142", "143", "144", "0", "145", "0", "146", "147", "0", "0",
"148", "149", "150", "151", "152", "153", "0", "154", "155",
"156", "157", "0", "158", "159", "0", "160", "161", "162", "163",
"164", "165", "166", "0", "167", "168", "169", "170", "171",
"172", "173", "174", "175", "176", "177", "178", "179", "180",
"181", "182", "183", "184", "185", "186", "0", "187", "188",
"189", "190", "191", "192", "0", "0", "0", "0", "0", "0", "0",
"0", "0", "0", "193", "194", "195", "196", "197", "0", "198",
"199", "200", "201", "0", "202", "203", "204", "205", "0", "206",
"0", "0", "0", "0", "0", "0", "0", "207", "208", "0", "209",
"210", "211", "212", "213", "214", "215", "0", "216", "217",
"218", "219", "220", "221", "0", "222", "223", "224", "225",
"0", "0", "226", "227", "228", "229", "230", "231", "232", "233",
"234", "235", "236", "237", "238", "239", "240", "0", "241",
"242", "243", "244", "245", "246", "247", "248", "0", "249",
"250", "251", "252", "253", "254", "0", "255", "256", "257",
"258", "259", "260", "0", "0", "261", "262", "263", "264", "0",
"265", "266", "267", "268", "269", "270", "271", "272", "273",
"274", "0", "275", "276", "277", "278", "279", "280", "281",
"282", "0", "283", "284", "285", "286", "287", "0", "0", "0",
"0", "288", "0", "0", "0", "0", "0", "289", "290", "291", "292",
"293", "294", "295", "296", "297", "298", "299", "300", "301",
"302", "303", "304", "305", "306", "307", "308", "309", "310",
"311", "312", "313", "314", "315", "316", "317", "318", "319",
"320", "321", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0",
"0", "0", "322", "323", "324", "325", "326", "327", "328", "329",
"330", "331", "332", "333", "334", "335", "336", "337", "338",
"339", "340", "341", "0", "342", "343", "344", "345", "346",
"0", "347", "0", "348", "349", "350", "351", "352", "353", "354",
"355", "356", "357", "358", "359", "360", "0", "361", "362",
"363", "0", "364", "365", "0", "366", "367", "0", "0", "0", "0",
"0", "0", "0", "368", "0", "369", "370", "0", "0", "0", "0",
"0", "0", "0", "0", "371", "0", "0", "372", "0", "0", "0", "373",
"374", "375", "376", "377", "0", "0", "0", "0", "0", "378", "0",
"0", "0", "0", "0", "379", "380", "0", "0", "381", "382", "383",
"384", "0", "0", "385", "0", "0", "0", "0", "0", "0", "386",
"387", "388", "0", "389", "390", "391", "392", "393", "394",
"395", "396", "397", "398", "399", "400", "401", "402", "0",
"403", "404", "405", "406", "407", "408", "409"))
like image 959
CPak Avatar asked Feb 08 '18 04:02

CPak


2 Answers

I don't know whether there is an analytical solution. But if you treat it as an integer programming problem, you could use the "SANN" heuristic implemented in optim. For example, consider some (sub-optimal) random split points to cut the vector t

> startpar <- sort(sample(length(t)-1, 3))
> startpar
[1] 5 6 9
> # result in a sub-optimal split
> split(t, cut(1:length(t), c(0, startpar, length(t)), labels = 1:4))
$`1`
 1  0  0  2  0 
 2  0  0 30  0 

$`2`
0 
0 

$`3`
   3    4    0 
  10 2000    0 

$`4`
 5  0  6  7  8 
20  0 40 60 10 

The error function could be written as

> # from manual: A function to be minimized (or maximized)
> fn <- function(par, vec){
+   ind_vec <- cut(1:length(vec), c(0, par, length(vec)), labels = 1:4)
+   sd(unlist(lapply(split(vec, ind_vec), sum)))
+ }
> # evaluated at the starting parameters
> fn(startpar, t)
[1] 979.5625

The "SANN" heuristics (Simulated annealing) needs a method to generate a new candidate solution. There can be more sophisticated ways to select either the functions or the starting values, but the present choices still lead to the/an [edit:] near optimal solution (and maybe in acceptable time?).

> # from manual: For the "SANN" method it specifies a function to generate a new candidate point
> gr <- function(par, vec){
+   ind <- sample(length(par), 1)
+   par[ind] <- par[ind] + sample(-1:1, 1)
+   par[ind] <- max(c(par[ind], ifelse(ind == 1, 1, par[ind - 1] + 1)))
+   par[ind] <- min(c(par[ind], ifelse(ind == 3, length(vec) - 1, par[ind + 1] - 1)))
+   par
+ }

Applied to the toy data

> optimpar <- optim(startpar, fn, gr, method = "SANN", vec = t)$par
> split(t, cut(1:length(t), c(0, optimpar, length(t)), labels = 1:4))
$`1`
 1  0  0  2 
 2  0  0 30 

$`2`
 0  0  3 
 0  0 10 

$`3`
   4 
2000 

$`4`
 0  5  0  6  7  8 
 0 20  0 40 60 10 

> fn(optimpar, t)
[1] 972.7329
> 

Applied to the real data

> # use for "hard"
> startpar <- sort(sample(length(hard)-1, 3))
> optimpar <- optim(startpar, fn, gr, method = "SANN", vec = hard)
> optimpar
$par
[1] 146 293 426

$value
[1] 4.573474
...[output shortened]

[Edit] since my initial results were sub-optimal.

I'm sure you found a sufficient alternative yourself already, but for the sake of completeness: Regarding the present toy and real data examples a better choice for gr (I'll call it gr2 for later reference) would have a different sampling length (e.g. dependent on the length of the data) in order to generate the new candidate which will be less dependent from the incumbent (the current solution). For example

> gr2 <- function(par, vec){
+   ind <- sample(length(par), 1)
+   l <- round(log(length(vec), 2))
+   par[ind] <- par[ind] + sample(-l:l, 1)
+   par[ind] <- max(c(par[ind], ifelse(ind == 1, 1, par[ind - 1] + 1)))
+   par[ind] <- min(c(par[ind], ifelse(ind == 3, length(vec) - 1, par[ind + 1] - 1)))
+   par
+ }

For the real data resulting in

> set.seed(1337)
> 
> startpar <- sort(sample(length(hard)-1, 3))
> opt <- optim(startpar, fn, gr2, method = "SANN", vec = hard)
> opt$value
[1] 4.5
> lapply(split(hard, cut(1:length(hard), c(0, opt$par, length(hard)), labels = 1:4)), sum)
$`1`
[1] 140

$`2`
[1] 141

$`3`
[1] 144

$`4`
[1] 150

And for the toy data resulting in

> startpar <- sort(sample(length(t)-1, 3))
> opt <- optim(startpar, fn, gr2, method = "SANN", vec = t)
> opt$value
[1] 971.4024
> split(t, cut(1:length(t), c(0, opt$par, length(t)), labels = 1:4))
$`1`
 1  0  0  2  0  0  3 
 2  0  0 30  0  0 10 

$`2`
   4 
2000 

$`3`
 0  5  0  6 
 0 20  0 40 

$`4`
 7  8 
60 10 

Regarding the optimality for the real data (using gr2), I ran a short simulation of 100 optimization runs from different starting parameters: Each of those runs terminated at a value of 4.5.

like image 188
Tom Avatar answered Oct 10 '22 12:10

Tom


By using dynamic programming you can get the true optimum in O(N^2) time. The trick is to see that minimizing the standard deviation is the same as minimizing the sum of squares of rowSums: the number of groups and the total of all values are fixed, so the mean of the group sums is fixed too, and the variance differs from the sum of squares only by constants. Since the error contributions of each subvector are independent, we can reduce the search-space of possible splits by ignoring extensions of suboptimal splits of subvectors.

If for instance (3, 5) is a better split for V[1:7] than (2, 4), then any split of V starting with (3, 5, 8, ...) is better than the split starting with (2, 4, 8, ...) that continues with the same later split points. So if we find, for each 1 < k < len(V), the best 2-group split of V[1:k], we can find the best 3-group split of each V[1:k] by only considering extensions of the optimal 2-group splits of the shorter subvectors. In general we find the best (n+1)-group split by extending the optimal n-group splits.

The balanced.split function below takes in a vector of values and the number of chunks to produce (n.splits) and returns a list of subvectors. This yields a solution with row sums 140, 141, 144, 150 on the hard set.

# Split all.values into n.splits consecutive chunks so that the sum of the
# squared chunk sums is minimal (equivalently, the chunk sums have minimal
# standard deviation). Chunks only start at non-zero elements; zeros stay
# attached to the chunk of the preceding non-zero element (leading zeros go
# to the first chunk).
#
# Args:
#   all.values: numeric vector.
#   n.splits:   number of chunks to produce; at least 1 and at most the
#               number of non-zero elements of all.values.
#
# Returns an unnamed list of n.splits subvectors of all.values, in order.
balanced.split <- function(all.values, n.splits) {
    stopifnot(length(n.splits) == 1, n.splits >= 1)
    if (n.splits == 1) {
        # Nothing to optimise: the whole vector is the only chunk.
        return(list(all.values))
    }
    nonzero.idxs <- which(all.values != 0)
    values <- all.values[nonzero.idxs]
    if (n.splits > length(values)) {
        stop("cannot split ", length(values), " non-zero element(s) into ",
             n.splits, " chunks", call. = FALSE)
    }
    cumsums <- c(0, cumsum(values))
    error.table <- outer(cumsums, cumsums, FUN = "-")^2
    # error.table[i, j] = error contribution of segment
    # values[i:(j-1)]; only i < j describes a real, non-empty segment,
    # so everything on or below the diagonal is ruled out.
    error.table[lower.tri(error.table, diag = TRUE)] <- Inf

    # Iteratively find best i splits
    n.breaks <- n.splits - 1
    n.cols <- ncol(error.table)
    index.matrix <- array(dim = c(n.breaks, n.cols))
    cur.best.splits <- error.table[1, ]
    for (i in seq_len(n.breaks)) {
        # Row r of error.sums adds the best cost of the prefix ending before
        # r to the cost of a final segment starting at r.
        error.sums <- cur.best.splits + error.table
        best.rows <- apply(error.sums, 2, which.min)
        index.matrix[i, ] <- best.rows
        # index.matrix[i, k] = last split of optimal (i+1)-group
        # split of values[1:(k-1)]
        cur.best.splits <- error.sums[cbind(best.rows, seq_len(n.cols))]
        # cur.best.splits[k] = minimal error function
        # of (i+1)-group split of values[1:(k-1)]
    }
    # Trace best splits, starting from the column of the full vector
    cur.idx <- n.cols
    splits <- vector("numeric", n.breaks)
    for (i in rev(seq_len(n.breaks))) {
        cur.idx <- index.matrix[i, cur.idx]
        splits[i] <- cur.idx
    }
    # Split values vector: translate positions among the non-zero values
    # back to positions in all.values and cut there.
    bounds <- c(1, nonzero.idxs[splits], length(all.values) + 1)
    lapply(seq_len(n.splits), function(i) {
        all.values[bounds[i]:(bounds[i + 1] - 1)]
    })
}

Below is more detailed code for the same algorithm

# Build the table of squared segment sums. With prefix = c(0, cumsum(values)),
# entry [i, j] is (prefix[i] - prefix[j])^2, i.e. the squared sum of the
# values lying between positions i and j.
.makeErrorTable <- function(values) {
    prefix.sums <- c(0, cumsum(values))
    prefix.gaps <- outer(prefix.sums, prefix.sums, FUN = "-")
    prefix.gaps^2
}

# Backtrace the optimal split points from an index matrix.
# Starting at the last column, row i of index.matrix gives the column to
# continue from in row i - 1; the visited column indices are the split points.
# Returns a numeric vector with one entry per row of index.matrix
# (numeric(0) for a matrix without rows).
.findPath <- function(index.matrix) {
    nrows <- nrow(index.matrix)
    cur.idx <- ncol(index.matrix)
    path <- vector("numeric", nrows)
    # rev(seq_len()) instead of nrows:1, which would count 0, 1 for nrows == 0
    for (i in rev(seq_len(nrows))) {
        cur.idx <- index.matrix[i, cur.idx]
        path[i] <- cur.idx
    }
    path
}

# Dynamic-programming search for the split points that minimise the summed
# error of n.splits consecutive, non-empty segments.
#
# Args:
#   error.table: square matrix; entry [i, j] with i < j is the error of the
#                segment made of elements i..(j-1).
#   n.splits:    number of segments to produce.
#
# Returns the index of the first element of every segment after the first
# (numeric(0) when n.splits is 1). Stops with an error when the elements
# cannot be divided into n.splits non-empty segments.
.findSplits <- function(error.table, n.splits) {
    n.diffs <- nrow(error.table)
    n.breaks <- n.splits - 1
    if (n.breaks < 1) {
        # A single segment has no split point.
        return(numeric(0))
    }
    # Only entries above the diagonal describe real, non-empty segments.
    error.table[lower.tri(error.table, diag = TRUE)] <- Inf

    # Table used to backtrace the optimal path
    idx.table <- array(dim = c(n.breaks, n.diffs))
    cur.best.splits <- error.table[1, ]
    for (i in seq_len(n.breaks)) {
        # Row r: best cost of the prefix ending before r, plus the cost of a
        # final segment that starts at r.
        error.sums <- cur.best.splits + error.table
        best.rows <- apply(error.sums, 2, which.min)
        idx.table[i, ] <- best.rows
        cur.best.splits <- error.sums[cbind(best.rows, seq_len(n.diffs))]
    }
    if (!is.finite(cur.best.splits[n.diffs])) {
        stop("cannot split ", n.diffs - 1, " element(s) into ", n.splits,
             " non-empty segments", call. = FALSE)
    }
    return(.findPath(idx.table))
}

# Split values at given split points.
# `splits` holds, in ascending order, the index of the first element of every
# chunk after the first. Returns an unnamed list of length(splits) + 1
# consecutive subvectors of `values`.
.splitChunks <- function(values, splits) {
    bounds <- c(1, splits, length(values) + 1)
    stopifnot(!is.unsorted(bounds))
    lapply(seq_len(length(bounds) - 1), function(i) {
        # seq_len() keeps an empty chunk empty, whereas a:(a - 1) would count
        # downwards and pick up an element that belongs to another chunk.
        values[bounds[i] + seq_len(bounds[i + 1] - bounds[i]) - 1]
    })
}

#' Main function that splits all.values into n.splits consecutive
#' chunks, minimizing the standard deviation of the chunk sums.
#'
#' Chunks only start at non-zero elements; zeros stay with the chunk of the
#' preceding non-zero element.
#'
#' @param all.values Numeric vector to split.
#' @param n.splits Number of chunks to produce; at least 1 and at most the
#'   number of non-zero elements of all.values.
#' @return A list of n.splits subvectors of all.values, in order.
balanced.split <- function(all.values, n.splits) {
    stopifnot(length(n.splits) == 1, n.splits >= 1)
    if (n.splits == 1) {
        # Nothing to optimise: the whole vector is the only chunk.
        return(list(all.values))
    }
    nonzero.idxs <- which(all.values != 0)
    if (n.splits > length(nonzero.idxs)) {
        stop("cannot split ", length(nonzero.idxs),
             " non-zero element(s) into ", n.splits, " chunks",
             call. = FALSE)
    }
    values <- all.values[nonzero.idxs]
    error.table <- .makeErrorTable(values)
    splits <- .findSplits(error.table, n.splits)
    # Translate positions among the non-zero values back to all.values.
    full.splits <- nonzero.idxs[splits]
    .splitChunks(all.values, full.splits)
}
like image 43
kuppern87 Avatar answered Oct 10 '22 12:10

kuppern87