Scala: Get sum of nth element from tuple array/RDD

I have an array of tuples like this:

val a = Array((1,2,3), (2,3,4))

I want to write a generic version of a method like the one below:

def sum2nd(aa: Array[(Int, Int, Int)]) = {
  aa.map { a => a._2 }.sum
}

So what I am looking for is a method like:

def sumNth(aa: Array[(Int, Int, Int)], n: Int)

asked Jan 29 '16 18:01 by Mohitt


1 Answer

There are a few ways you can go about this. The simplest is to use productElement:

def unsafeSumNth[P <: Product](xs: Seq[P], n: Int): Int =
  xs.map(_.productElement(n).asInstanceOf[Int]).sum

And then (note that indexing starts at zero, so n = 1 gives us the second element):

scala> val a = Array((1, 2, 3), (2, 3, 4))
a: Array[(Int, Int, Int)] = Array((1,2,3), (2,3,4))

scala> unsafeSumNth(a, 1)
res0: Int = 5

This implementation can crash at runtime in two different ways, though:

scala> unsafeSumNth(List((1, 2), (2, 3)), 3)
java.lang.IndexOutOfBoundsException: 3
  at ...

scala> unsafeSumNth(List((1, "a"), (2, "b")), 1)
java.lang.ClassCastException: java.lang.String cannot be cast to java.lang.Integer
  at ...

That is, it fails if the tuple doesn't have enough elements, or if the element you're asking for isn't an Int.

You can write a version that reports these failures as a value instead of throwing:

import scala.util.Try

def saferSumNth[P <: Product](xs: Seq[P], n: Int): Try[Int] = Try(
  xs.map(_.productElement(n).asInstanceOf[Int]).sum
)

And then:

scala> saferSumNth(a, 1)
res4: scala.util.Try[Int] = Success(5)

scala> saferSumNth(List((1, 2), (2, 3)), 3)
res5: scala.util.Try[Int] = Failure(java.lang.IndexOutOfBoundsException: 3)

scala> saferSumNth(List((1, "a"), (2, "b")), 1)
res6: scala.util.Try[Int] = Failure(java.lang.ClassCastException: ...

This is an improvement, since it forces callers to address the possibility of failure, but it's also kind of annoying, since it forces callers to address the possibility of failure.
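To make that trade-off concrete, here is a minimal sketch of what a caller might do with the Try. It repeats saferSumNth so it can be pasted into a REPL on its own; the describe helper and the choice of 0 as a fallback are assumptions made for illustration, not part of the answer above.

```scala
import scala.util.{Failure, Success, Try}

// Same definition as above, repeated so the snippet is self-contained.
def saferSumNth[P <: Product](xs: Seq[P], n: Int): Try[Int] = Try(
  xs.map(_.productElement(n).asInstanceOf[Int]).sum
)

// Hypothetical helper: the caller has to handle both outcomes explicitly.
def describe(result: Try[Int]): String = result match {
  case Success(total) => s"sum = $total"
  case Failure(e)     => s"could not sum: ${e.getClass.getSimpleName}"
}

val good = saferSumNth(Seq((1, 2, 3), (2, 3, 4)), 1) // index 1 exists and is an Int
val bad  = saferSumNth(Seq((1, 2), (2, 3)), 3)       // index 3 is out of range

println(describe(good))
println(describe(bad))

// Alternatively, fall back to a default (0 here is an arbitrary choice).
println(bad.getOrElse(0))
```

Pattern matching keeps the failure reason available, while getOrElse discards it, so which one fits depends on whether the caller can do anything useful with the exception.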

If you're willing to use Shapeless you can have the best of both worlds:

import shapeless._, shapeless.ops.tuple.At

def sumNth[P <: Product](xs: Seq[P], n: Nat)(implicit
  atN: At.Aux[P, n.N, Int]
): Int = xs.map(p => atN(p)).sum

And then:

scala> sumNth(a, 1)
res7: Int = 5

But the bad calls don't even compile:

scala> sumNth(List((1, 2), (2, 3)), 3)
<console>:17: error: could not find implicit value for parameter atN: ...

This still isn't perfect, though, since it means the second argument has to be a literal number (since it needs to be known at compile time):

scala> val x = 1
x: Int = 1

scala> sumNth(a, x)
<console>:19: error: Expression x does not evaluate to a non-negative Int literal
       sumNth(a, x)
                 ^

In many cases that's not a problem, though.
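When the index really is only known at runtime, one possible middle ground (a sketch, not something the approaches above provide) is to stay with productElement but check the arity and the element type explicitly rather than casting. The name sumNthOption and the decision to return None on any mismatch are assumptions made for this example.

```scala
// Hypothetical variant: takes a runtime index, never throws for a bad index
// or a non-Int element, and returns None if any tuple does not fit.
def sumNthOption[P <: Product](xs: Seq[P], n: Int): Option[Int] =
  xs.foldLeft(Option(0)) { (acc, p) =>
    acc.flatMap { total =>
      if (n < 0 || n >= p.productArity) None // tuple is too short
      else p.productElement(n) match {
        case i: Int => Some(total + i)       // element has the expected type
        case _      => None                  // element is not an Int
      }
    }
  }

val a = Array((1, 2, 3), (2, 3, 4))
val x = 1 // an ordinary runtime value, not a literal

println(sumNthOption(a.toSeq, x))
println(sumNthOption(List((1, 2), (2, 3)), 3))
println(sumNthOption(List((1, "a"), (2, "b")), 1))
```

Unlike the Shapeless version, mistakes are still only detected at runtime; the difference from the Try version is that no exception is thrown or caught, and the caller gets an Option without a failure reason.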

To sum up: If you're willing to take responsibility for reasonable-looking code crashing your program, use productElement. If you want a little more safety (at the cost of some inconvenience), use productElement with Try. If you want compile-time safety (but some limitations), use Shapeless.

answered Sep 21 '22 03:09 by Travis Brown