 

Spark GraphX Aggregation Summation

I'm trying to compute the sum of node values in a Spark GraphX graph. The graph is a tree, and the top node (root) should hold the sum of all its children and their children. The tree looks like this, and the expected sum at the root is 1850:

    VertexId 20   (value: sum of 11 & 911)
     +-- VertexId 11   (value: sum of 14 & 24)
     |    +-- VertexId 14   (value: 1000)
     |    +-- VertexId 24   (value: 550)
     +-- VertexId 911  (value: 300)
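To pin down the expected numbers independently of GraphX, here is a plain-Scala sketch (no Spark involved; the object and method names are hypothetical) that computes each vertex's subtree sum recursively. It encodes the tree above as (child, parent) pairs, matching the child -> parent direction of the GraphX edges, and assumes the input really is a tree (a cycle would make the recursion run forever).

```scala
// Plain-Scala reference (hypothetical names, no Spark): subtree sums for the tree above.
object ReferenceSubtreeSums {
  // Own value of each vertex.
  val values: Map[Long, Long] =
    Map(20L -> 0L, 11L -> 0L, 14L -> 1000L, 24L -> 550L, 911L -> 300L)

  // (child, parent) pairs, i.e. the same direction as Edge(srcId = child, dstId = parent).
  val childToParent: Seq[(Long, Long)] =
    Seq(14L -> 11L, 24L -> 11L, 11L -> 20L, 911L -> 20L)

  // Group the children under each parent.
  val children: Map[Long, Seq[Long]] =
    childToParent.groupBy(_._2).map { case (parent, pairs) => parent -> pairs.map(_._1) }

  // A vertex's total is its own value plus the totals of all its children, recursively.
  def subtreeSum(id: Long): Long =
    values(id) + children.getOrElse(id, Seq.empty[Long]).map(subtreeSum).sum

  def main(args: Array[String]): Unit =
    values.keys.toSeq.sorted.foreach(id => println((id, subtreeSum(id))))
}
```

Running it gives 1850 for vertex 20 and 1550 for vertex 11, which is what the GraphX version should reproduce.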

The first stab at this looks like this:

import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

val vertices: RDD[(VertexId, Int)] =
  sc.parallelize(Array(
    (20L, 0),
    (11L, 0),
    (14L, 1000),
    (24L, 550),
    (911L, 300)
  ))

// Note that the last value in the edge is a factor (positive or negative)
val edges: RDD[Edge[Int]] =
  sc.parallelize(Array(
    Edge(14L, 11L, 1),
    Edge(24L, 11L, 1),
    Edge(11L, 20L, 1),
    Edge(911L, 20L, 1)
  ))

val dataItemGraph = Graph(vertices, edges)

// One round of message passing: every child sends its current value to its parent
val sum: VertexRDD[(Int, BigDecimal, Int)] = dataItemGraph.aggregateMessages[(Int, BigDecimal, Int)](
  sendMsg = { triplet => triplet.sendToDst((1, BigDecimal(triplet.srcAttr), 1)) },
  mergeMsg = { (a, b) => (a._1, a._2 * a._3 + b._2 * b._3, 1) }
)

sum.collect().foreach(println)

This returns the following:

(20,(1,300,1))
(11,(1,1550,1))

It computes the sum for vertex 11, but the result doesn't roll up to the root node (vertex 20). What am I missing, or is there a better way of doing this? The tree can be of arbitrary size, and each vertex can have an arbitrary number of child edges.

Asked by will on Jan 03 '17 at 20:01



1 Answer

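Your aggregateMessages call performs a single round of message passing, so vertex 20 receives vertex 11's initial value (0) rather than its computed sum. Given that the graph is directed (as it seems to be in your example), it should be possible to write a Pregel program that does what you're asking for, since Pregel repeats the message passing until no more messages are sent: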

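import org.apache.spark.graphx._

val result =
  dataItemGraph.pregel(0, activeDirection = EdgeDirection.Out)(
    (_, vd, msg) => msg + vd,             // vprog: add the incoming message to the vertex value
    t => Iterator((t.dstId, t.srcAttr)),  // sendMsg: push the child's current value to its parent
    (x, y) => x + y                       // mergeMsg: sum messages coming from several children
  )

result.vertices.collect().foreach(println)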

// Output is:
// (24,550)
// (20,1850)
// (14,1000)
// (11,1550)
// (911,300)

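I'm using EdgeDirection.Out so that, after the first iteration, an edge only sends a message when its source vertex (the child) received a message in the previous iteration, so values only travel from the bottom up. With the default EdgeDirection.Either, an edge would also fire whenever its parent was active, the parent would receive its children's values again in every iteration, and we would get into an endless loop.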
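One caveat: sendMsg above forwards the child's whole current value, so a vertex that receives contributions in more than one iteration forwards its running total more than once, and its parent adds the earlier contributions again. The example tree never triggers this, but a tree with leaves at different depths does: for hypothetical vertices 1 to 5 with values (0, 0, 1, 0, 5) and edges 5 -> 4, 4 -> 2, 3 -> 2, 2 -> 1, the program above ends with 7 at vertex 1 instead of 6. The sketch below is one possible workaround, not part of the original answer: each vertex keeps (sum, last delta) and forwards only the delta it received in the previous iteration. It assumes an Option[Long] message type with None as the initial message so that vprog can leave the starting values untouched; the names SubtreeSums and rollUp are hypothetical, and, like the answer, it ignores the edge factor.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

object SubtreeSums {
  // Hypothetical helper. Edges point child -> parent, as in the question.
  // Vertex attribute: (running subtree sum, delta received in the last iteration).
  def rollUp(graph: Graph[Int, Int]): Graph[(Long, Long), Int] = {
    // Every vertex starts with its own value as both its sum and its pending delta.
    val init: Graph[(Long, Long), Int] =
      graph.mapVertices((_, v) => (v.toLong, v.toLong))

    init.pregel(Option.empty[Long], activeDirection = EdgeDirection.Out)(
      // vprog: None only arrives as the initial message and leaves the attribute as is;
      // otherwise add the incoming delta and remember it so it is forwarded exactly once.
      (_, attr, msg) => msg match {
        case Some(delta) => (attr._1 + delta, delta)
        case None        => attr
      },
      // sendMsg: forward only what the child received most recently, not its total.
      t => Iterator((t.dstId, Option(t.srcAttr._2))),
      // mergeMsg: add up deltas arriving from several children in the same iteration.
      (a, b) => Some(a.getOrElse(0L) + b.getOrElse(0L))
    )
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("subtree-sums").setMaster("local[*]"))
    val vertices: RDD[(VertexId, Int)] =
      sc.parallelize(Seq((20L, 0), (11L, 0), (14L, 1000), (24L, 550), (911L, 300)))
    val edges: RDD[Edge[Int]] =
      sc.parallelize(Seq(Edge(14L, 11L, 1), Edge(24L, 11L, 1), Edge(11L, 20L, 1), Edge(911L, 20L, 1)))

    // Keep only the sums and print them ordered by vertex id.
    rollUp(Graph(vertices, edges)).vertices
      .mapValues(_._1)
      .collect()
      .sortBy(_._1)
      .foreach(println)
    sc.stop()
  }
}
```

On the question's tree this yields the same sums as the answer (1850 at vertex 20, 1550 at vertex 11); the difference only shows up when a vertex collects contributions over several iterations.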

Answered by lpiepiora on Sep 22 '22 at 17:09
