
Dynamic prefix sum

Is there a data structure that can return the prefix sum [1] of an array, update an element, and insert/remove elements, all in O(log n)?

[1] The "prefix sum" is the sum of all elements from the first one up to a given index.

For example, given the array of non-negative integers 8 1 10 7, the prefix sum of the first three elements is 19 (8 + 1 + 10). Updating the first element to 7, inserting 3 as the second element and removing the third one gives 7 3 10 7. The prefix sum of the first three elements is then 20 (7 + 3 + 10).
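For reference, the example above can be reproduced with a plain Python list. This is only a naive baseline to pin down the expected behaviour (the helper name `prefix_sum` is made up for illustration); insert, remove and the sum are each O(n) here, which is exactly what the question wants to avoid.

```python
def prefix_sum(arr, count):
    """Sum of the first `count` elements; O(n) on a plain list."""
    return sum(arr[:count])

arr = [8, 1, 10, 7]
before = prefix_sum(arr, 3)   # 8 + 1 + 10

arr[0] = 7          # update the first element
arr.insert(1, 3)    # insert 3 as the second element
del arr[2]          # remove the third element (the old 1)
after = prefix_sum(arr, 3)    # 7 + 3 + 10
```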

For prefix sum and update, there is the Fenwick tree. But I don't know how to handle insertion/removal in O(log n) with it.

On the other hand, there are several binary search trees, such as the red-black tree, all of which handle update/insert/remove in logarithmic time. But I don't know how to maintain the given ordering and compute the prefix sum in O(log n) with them.
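To make the limitation concrete, here is a minimal Fenwick tree sketch (the class and method names are my own, not from any library). It supports prefix sum and point update in O(log n), but every internal slot covers a range determined by the bits of its index, so inserting or removing an element shifts all later indices and invalidates those ranges.

```python
class Fenwick:
    """Fixed-length Fenwick (binary indexed) tree over integers."""

    def __init__(self, values):
        self.n = len(values)
        self.tree = [0] * (self.n + 1)   # stored 1-based internally
        for i, v in enumerate(values):
            self.add(i, v)

    def add(self, index, delta):
        """Add delta to the element at 0-based index; O(log n)."""
        i = index + 1
        while i <= self.n:
            self.tree[i] += delta
            i += i & -i                  # move to the next slot covering i

    def prefix_sum(self, count):
        """Sum of the first `count` elements; O(log n)."""
        total = 0
        i = count
        while i > 0:
            total += self.tree[i]
            i -= i & -i                  # drop the lowest set bit
        return total

fw = Fenwick([8, 1, 10, 7])
first_three = fw.prefix_sum(3)   # 8 + 1 + 10
fw.add(0, 7 - 8)                 # set element 0 to 7 by adding the difference
updated = fw.prefix_sum(3)       # 7 + 1 + 10
```

Note that there is no cheap way to write `insert` or `remove` here without rebuilding part of the tree.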

Ecir Hana asked Jan 16 '15 18:01



2 Answers

A treap with implicit keys can perform all these operations in expected O(log n) time per query. The idea of implicit keys is simple: we do not store any keys in the nodes. Instead, we maintain the subtree size for every node and use this information to find the appropriate position when we add or remove an element.
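As a rough illustration of the implicit-key idea only, the sketch below finds the element at a given position using nothing but subtree sizes. It leaves out priorities and rebalancing entirely, and the names `Node`, `size` and `kth` are chosen for this example; the tree is built by hand.

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
        # No key is stored; the size of the subtree is enough to locate positions.
        self.size = 1 + size(left) + size(right)

def size(node):
    """Size of the subtree rooted at node, 0 for an empty subtree."""
    return node.size if node else 0

def kth(node, k):
    """Return the value at 0-based position k, found via subtree sizes only."""
    while node:
        left_size = size(node.left)
        if k < left_size:
            node = node.left             # position lies in the left subtree
        elif k == left_size:
            return node.val              # this node is at position k
        else:
            k -= left_size + 1           # skip the left subtree and this node
            node = node.right
    raise IndexError("position out of range")

# In-order traversal of this hand-built tree is 7 3 10 7.
root = Node(3, Node(7), Node(10, None, Node(7)))
values = [kth(root, i) for i in range(4)]
```

The same descent, with a sum stored next to each size, is what makes the prefix sum logarithmic.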

Here is my implementation:

#include <cstdlib>   // rand
#include <iostream>
#include <memory>
#include <string>
#include <utility>   // std::pair, std::make_pair

struct Node {
  int priority;
  int val;
  long long sum;
  int size;
  std::shared_ptr<Node> left;
  std::shared_ptr<Node> right;

  Node(int val): 
    priority(rand()), val(val), sum(val), size(1), left(), right() {}
};

// Returns the size of a node owned by t if it is not empty and 0 otherwise.
int getSize(std::shared_ptr<Node> t) {
  if (!t)
    return 0;
  return t->size;
}

// Returns the sum of a node owned by t if it is not empty and 0 otherwise.
long long getSum(std::shared_ptr<Node> t) {
  if (!t)
    return 0;
  return t->sum;
}


// Updates a node owned by t if it is not empty.
void update(std::shared_ptr<Node> t) {
  if (t) {
    t->size = 1 + getSize(t->left) + getSize(t->right);
    t->sum = t->val + getSum(t->left) + getSum(t->right);
  }
}

// Merges the nodes owned by L and R and returns the result.
std::shared_ptr<Node> merge(std::shared_ptr<Node> L, 
    std::shared_ptr<Node> R) {
  if (!L || !R)
    return L ? L : R;
  if (L->priority > R->priority) {
    L->right = merge(L->right, R);
    update(L);
    return L;
  } else {
    R->left = merge(L, R->left);
    update(R);
    return R;
  }
}

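// Splits the subtree rooted at t into two trees: the first holds the elements
// at positions < pos, the second holds the rest. `add` is the number of
// elements that precede this subtree in the whole tree (0 for the root).
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> split(
    std::shared_ptr<Node> t,
    int pos, int add) {
  if (!t)
    return std::make_pair(std::shared_ptr<Node>(), std::shared_ptr<Node>());
  int cur = getSize(t->left) + add;  // position of t in the whole tree
  std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> res;
  if (pos <= cur) {
    // t belongs to the second part; split its left subtree further.
    auto ret = split(t->left, pos, add);
    t->left = ret.second;
    res = std::make_pair(ret.first, t); 
  } else {
    // t belongs to the first part; split its right subtree further.
    auto ret = split(t->right, pos, cur + 1);
    t->right = ret.first;
    res = std::make_pair(t, ret.second); 
  }
  update(t);
  return res;
}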

// Returns a prefix sum of [0 ... pos]
long long getPrefixSum(std::shared_ptr<Node>& root, int pos) {
  auto parts = split(root, pos + 1, 0);
  long long res = getSum(parts.first);
  root = merge(parts.first, parts.second);
  return res;
}

// Adds a new element at a position pos with a value newValue.
// Indices are zero-based.
void addElement(std::shared_ptr<Node>& root, int pos, int newValue) {
  auto parts = split(root, pos, 0);
  std::shared_ptr<Node> newNode = std::make_shared<Node>(newValue);
  auto temp = merge(parts.first, newNode);
  root = merge(temp, parts.second);
}

// Removes an element at the given position pos.
// Indices are zero-based.
void removeElement(std::shared_ptr<Node>& root, int pos) {
  auto parts1 = split(root, pos, 0);
  auto parts2 = split(parts1.second, 1, 0);
  root = merge(parts1.first, parts2.second);
}

int main() {
  std::shared_ptr<Node> root;
  int n;
  std::cin >> n;
  for (int i = 0; i < n; i++) {
    std::string s;
    std::cin >> s;
    if (s == "add") {
      int pos, val;
      std::cin >> pos >> val;
      addElement(root, pos, val);
    } else if (s == "remove") {
      int pos;
      std::cin >> pos;
      removeElement(root, pos);
    } else {
      int pos;
      std::cin >> pos;
      std::cout << getPrefixSum(root, pos) << std::endl;
    }
  }
  return 0;
}

kraskevich answered Sep 22 '22 10:09


One idea is to modify an AVL tree so that additions and deletions are done by index. Every node keeps the count and the sum of each of its subtrees, which allows all operations in O(log n).

Proof of concept with add_node, update_node and prefix_sum implemented (removal is not shown):

class Node:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None
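        self.left_height = 0
        self.right_height = 0
        # left_count and left_sum cover the left subtree plus this node itself
        self.left_count = 1
        self.left_sum = value
        # right_count and right_sum cover the right subtree only
        self.right_count = 0
        self.right_sum = 0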

    def set_value(self, value):
        self.value = value
        self.left_sum = self.left.left_sum + self.left.right_sum+self.value if self.left else self.value

    def set_left(self, node):
        self.left = node
        self.left_height = max(node.left_height, node.right_height)+1 if node else 0
        self.left_count = node.left_count + node.right_count+1 if node else 1
        self.left_sum = node.left_sum + node.right_sum+self.value if node else self.value

    def set_right(self, node):
        self.right = node
        self.right_height = max(node.left_height, node.right_height)+1 if node else 0
        self.right_count = node.left_count + node.right_count if node else 0
        self.right_sum = node.left_sum + node.right_sum if node else 0

    def rotate_left(self):
        b = self.right
        self.set_right(b.left)
        b.set_left(self)
        return b

    def rotate_right(self):
        a = self.left
        self.set_left(a.right)
        a.set_right(self)
        return a

    def factor(self):
        return self.right_height - self.left_height

def add_node(root, index, node):
    if root is None: return node

    if index < root.left_count:
        root.set_left(add_node(root.left, index, node))
        if root.factor() < -1:
            if root.left.factor() > 0:
                root.set_left(root.left.rotate_left())
            return root.rotate_right()
    else:
        root.set_right(add_node(root.right, index-root.left_count, node))
        if root.factor() > 1:
            if root.right.factor() < 0:
                root.set_right(root.right.rotate_right())
            return root.rotate_left()

    return root

def update_node(root, index, value):
    if root is None: return root

    if index+1 < root.left_count:
        root.set_left(update_node(root.left, index, value))
    elif index+1 > root.left_count:
        root.set_right(update_node(root.right, index - root.left_count, value))
    else:
        root.set_value(value)

    return root


def prefix_sum(root, index):
    if root is None: return 0

    if index+1 < root.left_count:
        return prefix_sum(root.left, index)
    else:
        return root.left_sum + prefix_sum(root.right, index-root.left_count)


tree = None
tree = add_node(tree, 0, Node(10))
tree = add_node(tree, 1, Node(40))
tree = add_node(tree, 1, Node(20))
tree = add_node(tree, 2, Node(70))   # array is now 10 20 70 40

tree = update_node(tree, 2, 30)      # array is now 10 20 30 40

print(prefix_sum(tree, 0))
print(prefix_sum(tree, 1))
print(prefix_sum(tree, 2))
print(prefix_sum(tree, 3))
print(prefix_sum(tree, 4))           # index past the end: sum of the whole array

Juan Lopes answered Sep 23 '22 10:09