
Generate cartesian product in decreasing sum order

Given n lists A1, A2, ..., An of integers, each sorted in decreasing order, is there an algorithm that efficiently generates all elements of their Cartesian product in decreasing order of tuple sum?

For example, n=3

A1 = [9, 8, 0]
A2 = [4, 2]
A3 = [5, 1]

The expected output would be the cartesian product of A1xA2xA3 in the following order:
combination   sum
9, 4, 5       18
8, 4, 5       17
9, 2, 5       16
8, 2, 5       15
9, 4, 1       14
8, 4, 1       13
9, 2, 1       12
8, 2, 1       11
0, 4, 5       9
0, 2, 5       7
0, 4, 1       5
0, 2, 1       3

Ayoub Sbai asked Mar 21 '18 22:03


1 Answer

If the problem instance has n lists to cross, then you can think of the tuples in the product as an n-dimensional "rectangular" grid, where each tuple corresponds to a grid cell. You'll start by emitting the max-sum tuple [9,4,5], which is at one corner of the grid.

You'll keep track of a "candidate set" of un-emitted tuples, each of which is one step smaller on exactly one dimension than some tuple already emitted. If it helps, you can visualize the already-emitted tuples as a "solid" in the grid. The candidate set is all tuples that touch the solid's surface.

You'll repeatedly choose the next tuple to emit from the candidate set, then update the set with the neighbors of the newly emitted tuple. When the set is empty, you're done.

After emitting [9,4,5], the candidate set is

[8,4,5]  (one smaller on first dimension)
[9,2,5]  (one smaller on second dimension)
[9,4,1]  (one smaller on third dimension) 
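The "one smaller on a dimension" step is easiest to express on indices rather than values. Here is a minimal sketch, assuming each list is stored in decreasing order so that index 0 is the largest element (the implementation further down stores them ascending instead). The class and method names are made up for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class NeighborSketch {
  /** Returns the index vectors reached by moving one position down exactly one list. */
  static List<int[]> neighbors(int[] idx, int[][] lists) {
    List<int[]> out = new ArrayList<>();
    for (int d = 0; d < idx.length; ++d) {
      if (idx[d] + 1 >= lists[d].length) continue; // list d is already at its smallest element
      int[] next = Arrays.copyOf(idx, idx.length);
      next[d]++;
      out.add(next);
    }
    return out;
  }

  /** Maps an index vector back to the values it refers to. */
  static int[] values(int[] idx, int[][] lists) {
    int[] v = new int[idx.length];
    for (int d = 0; d < idx.length; ++d) v[d] = lists[d][idx[d]];
    return v;
  }

  public static void main(String[] args) {
    int[][] lists = {{9, 8, 0}, {4, 2}, {5, 1}};
    // Neighbors of the max-sum corner [9,4,5], i.e. index vector (0,0,0).
    for (int[] n : neighbors(new int[] {0, 0, 0}, lists))
      System.out.println(Arrays.toString(values(n, lists)));
  }
}
```

For the corner (0,0,0) this prints the three candidates listed above, in the same order.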

Next emit the one of these with the largest sum. That's [8,4,5]. Adjacent to that are

[0,4,5], [8,2,5], [8,4,1]

Add those to the candidate set, so we now have

[9,2,5], [9,4,1], [0,4,5], [8,2,5], [8,4,1]

Again pick the highest sum. That's [9,2,5]. Adjacent are

[8,2,5], [9,2,1]. 

So the new candidate set is

[9,4,1], [0,4,5], [8,2,5], [8,4,1], [9,2,1]

Note [8,2,5] came up again. Don't duplicate it.

This time the highest sum is [8,2,5]. Adjacent are

[0,2,5], [8,2,1]

At this point you should have the idea.

Use a max heap for the candidate set. Then removing the tuple with max sum requires O(log |C|), where C is the candidate set (just peeking at it is O(1)). Each insertion is O(log |C|) as well.

How big can the set get? Interesting question. I'll let you think about it. For 3 input sets as in your example, it's

|C| = O(|A1||A2| + |A2||A3| + |A1||A3|)

So the cost of emitting each tuple is

O(log(|A1||A2| + |A2||A3| + |A1||A3|))

If the sets have size at most N, then this is O(log(3 N^2)) = O(log 3 + 2 log N) = O(log N).
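If you'd rather measure the candidate set than reason about it, a small experiment can help. This is a rough sketch under a few assumptions: lists are in decreasing order with index 0 the largest, and a `seen` set remembers every index vector ever queued. That is simpler than a set that shrinks on removal, but it costs memory proportional to the whole product, so it is for measurement only. `CandidatePeak` and its methods are hypothetical names.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Set;

class CandidatePeak {
  /** Walks the product in decreasing sum order and returns the largest heap size observed. */
  static int peakCandidates(int[][] lists) {
    PriorityQueue<int[]> heap =
        new PriorityQueue<>((a, b) -> Integer.compare(sum(b, lists), sum(a, lists)));
    Set<String> seen = new HashSet<>(); // every index vector ever queued (measurement only)
    int[] start = new int[lists.length]; // all zeros: the max-sum corner
    heap.add(start);
    seen.add(Arrays.toString(start));
    int peak = 1;
    while (!heap.isEmpty()) {
      int[] idx = heap.poll();
      for (int d = 0; d < idx.length; ++d) {
        if (idx[d] + 1 >= lists[d].length) continue;
        int[] next = Arrays.copyOf(idx, idx.length);
        next[d]++;
        if (seen.add(Arrays.toString(next))) heap.add(next); // skip duplicates
      }
      peak = Math.max(peak, heap.size());
    }
    return peak;
  }

  /** Sum of the values an index vector refers to. */
  static int sum(int[] idx, int[][] lists) {
    int s = 0;
    for (int d = 0; d < idx.length; ++d) s += lists[d][idx[d]];
    return s;
  }

  public static void main(String[] args) {
    int[][] lists = {{9, 8, 0}, {4, 2}, {5, 1}};
    System.out.println(peakCandidates(lists)); // prints 6
  }
}
```

For the lists in the question the peak is 6, against 12 tuples in total and a bound of 3*2 + 2*2 + 3*2 = 16 from the formula above.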

There are |A1||A2||A3| tuples to emit, which is O(N^3).

The simpler algorithm of generating all tuples and sorting costs O(log N^3) = O(3 log N) = O(log N) per tuple. Very roughly, that's only about 50% slower (3 log N versus 2 log N), which is asymptotically the same. The main advantage of the more complex algorithm is that it saves a factor of N in space. The heap/priority queue size is only O(N^2), whereas sorting needs all O(N^3) tuples in memory at once.
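For comparison, the generate-everything-and-sort baseline might look roughly like this. It is a sketch with hypothetical names (`BruteForceProduct`, `sortedProduct`), and it recomputes sums inside the comparator for brevity rather than speed.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class BruteForceProduct {
  /** Builds every tuple of the product, then sorts the tuples by decreasing sum. */
  static List<int[]> sortedProduct(int[][] lists) {
    List<int[]> tuples = new ArrayList<>();
    tuples.add(new int[0]); // start from the single empty tuple
    for (int[] list : lists) {
      List<int[]> extended = new ArrayList<>();
      for (int[] prefix : tuples) {
        for (int v : list) {
          int[] t = Arrays.copyOf(prefix, prefix.length + 1);
          t[prefix.length] = v; // append one element of the current list
          extended.add(t);
        }
      }
      tuples = extended;
    }
    tuples.sort((a, b) -> Integer.compare(Arrays.stream(b).sum(), Arrays.stream(a).sum()));
    return tuples;
  }

  public static void main(String[] args) {
    int[][] lists = {{9, 8, 0}, {4, 2}, {5, 1}};
    for (int[] t : sortedProduct(lists)) System.out.println(Arrays.toString(t));
  }
}
```

This is handy as a reference to test the heap version against, since both should produce the same order whenever all sums are distinct.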

Here is a quick Java implementation meant to keep the code size small.

import java.util.Arrays;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Set;

public class SortedProduct {
  final SortedTuple [] tuples;
  final NoDupHeap candidates = new NoDupHeap();

  SortedProduct(SortedTuple [] tuple) {
    this.tuples = Arrays.copyOf(tuple, tuple.length);
    reset();
  }

  static class SortedTuple {
    final int [] elts;

    SortedTuple(int... elts) {
      this.elts = Arrays.copyOf(elts, elts.length);
      Arrays.sort(this.elts);  // ascending, so the largest element sits at the last index
    }

    @Override
    public String toString() {
      return Arrays.toString(elts);
    }
  }

  class RefTuple {
    final int [] refs;
    final int sum;

    RefTuple(int [] index, int sum) {
      this.refs = index;
      this.sum = sum;
    }

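    /** Returns the neighbor one step down list i, or null if list i is already at its smallest element. */
    RefTuple getSuccessor(int i) {
      if (refs[i] == 0) return null;
      int [] newRefs = Arrays.copyOf(this.refs, this.refs.length);
      int j = newRefs[i]--;  // j is the old index; newRefs[i] is now j - 1
      return new RefTuple(newRefs, sum - tuples[i].elts[j] + tuples[i].elts[j - 1]);
    }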

    int [] getTuple() {
      int [] val = new int[refs.length];
      for (int i = 0; i < refs.length; ++i) 
        val[i] = tuples[i].elts[refs[i]];
      return val;
    }

    @Override
    public int hashCode() {
      return Arrays.hashCode(refs);
    }

    @Override
    public boolean equals(Object o) {
      if (o instanceof RefTuple) {
        RefTuple t = (RefTuple) o;
        return Arrays.equals(refs, t.refs);
      }
      return false;
    }
  }

  RefTuple getInitialCandidate() {
    int [] index = new int[tuples.length];
    int sum = 0;
    for (int j = 0; j < index.length; ++j) 
      sum += tuples[j].elts[index[j] = tuples[j].elts.length - 1];
    return new RefTuple(index, sum);
  }

  final void reset() {
    candidates.clear();
    candidates.add(getInitialCandidate());
  }

  int [] getNext() {
    if (candidates.isEmpty()) return null;
    RefTuple next = candidates.poll();
    for (int i = 0; i < tuples.length; ++i) {
      RefTuple successor = next.getSuccessor(i);
      if (successor != null) candidates.add(successor);
    }
    return next.getTuple();
  }

  /** A max heap of indirect ref tuples that ignores addition of duplicates. */
  static class NoDupHeap {
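    // Largest sum first. Ties go to the tuple with the larger index total, so every
    // neighbor that could re-add a tuple is polled before that tuple is. Without the
    // tie-break, lists with repeated values can emit the same tuple twice.
    final PriorityQueue<RefTuple> heap = new PriorityQueue<>((a, b) -> a.sum != b.sum
        ? Integer.compare(b.sum, a.sum)
        : Integer.compare(Arrays.stream(b.refs).sum(), Arrays.stream(a.refs).sum()));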
    final Set<RefTuple> set = new HashSet<>();

    void add(RefTuple t) {
      if (set.contains(t)) return;
      heap.add(t);
      set.add(t);
    }

    RefTuple poll() {
      RefTuple t = heap.poll();
      set.remove(t);
      return t;
    }

    boolean isEmpty() {
      return heap.isEmpty();
    }

    void clear() {
      heap.clear();
      set.clear();
    }
  }

  public static void main(String [] args) {
    SortedTuple [] tuples = {
      new SortedTuple(9, 8, 0),
      new SortedTuple(4, 2),
      new SortedTuple(5, 1),
    };
    SortedProduct product = new SortedProduct(tuples);
    for (;;) {
      int[] next = product.getNext();
      if (next == null) break;
      System.out.println(Arrays.toString(next));
    }
  }
}

Gene answered Nov 07 '22 12:11