
How to compare Scala function values for equality

Tags:

scala

How can you compare two Scala function values for equality? The use case is that I have a list of functions that can contain duplicates, and I only want to execute each function once.

If I have:

scala> object A {
     |   def a {}
     | }
defined module A

scala> val f1 = A.a _
f1: () => Unit = <function0>

scala> val f2 = A.a _
f2: () => Unit = <function0>

If I compare the functions with either == or eq, I get false in both cases:

scala> f1 == f2
res0: Boolean = false

scala> f1 eq f2
res1: Boolean = false
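Why both comparisons return false: function literals and eta-expansions compile to classes that do not override equals, so == on a function value is plain reference equality, and each separate lambda or eta-expansion in the source code produces its own object. A minimal sketch of what that means for the deduplication use case, using a hypothetical hello method with a call counter in place of A.a: distinct removes repeated references to one function object, but not two separately created wrappers around the same method.

```scala
// Hypothetical stand-in for A.a: counts how many times it actually runs.
var calls = 0
def hello(): Unit = calls += 1

// Two separate lambdas wrapping the same method: two distinct objects.
val f1: () => Unit = () => hello()
val f2: () => Unit = () => hello()

// One lambda object that is referenced twice.
val g: () => Unit = () => hello()

val fromSameRef = List(g, g).distinct      // duplicate reference removed
val fromTwoLambdas = List(f1, f2).distinct // both objects survive

// Running the "deduplicated" list still calls hello twice.
fromTwoLambdas.foreach(f => f())
println(s"f1 == f2: ${f1 == f2}, g == g: ${g == g}")
println(s"distinct sizes: ${fromSameRef.size}, ${fromTwoLambdas.size}; calls = $calls")
```

Deduplicating by function value therefore only works when every part of the program shares the same function instance.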
asked Oct 04 '12 at 10:10 by jelovirt



2 Answers

Short answer: It's not possible.

Longer answer: You could have some kind of function factory that ensures that "identical" functions are actually the same object. Depending on the architecture of your application, though, that might not be feasible.
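A minimal sketch of such a factory, assuming the caller can name each function with a string key (FunctionFactory, get and the key "greet" are hypothetical names, not a library API). The factory caches the first function it builds for each key, so later requests return the same object, and ordinary reference equality, and therefore distinct, treats them as duplicates.

```scala
import scala.collection.mutable

// Hypothetical factory: hands out one canonical function object per key,
// so "identical" functions are literally the same object.
object FunctionFactory {
  private val cache = mutable.Map.empty[String, () => Unit]

  // Build the function only on the first request for a key; reuse it afterwards.
  def get(key: String)(make: => () => Unit): () => Unit =
    cache.getOrElseUpdate(key, make)
}

var runs = 0

// Both requests use the key "greet", so the second lambda is never even built.
val h1 = FunctionFactory.get("greet")(() => runs += 1)
val h2 = FunctionFactory.get("greet")(() => runs += 1)

// Reference equality now holds, so distinct removes the duplicates.
List(h1, h2, h1).distinct.foreach(f => f())
println(s"h1 eq h2: ${h1 eq h2}, runs = $runs")
```

The cache is a plain mutable.Map, so it is not thread-safe; a concurrent program would need a concurrent map or synchronization. The approach also trusts that equal keys really mean equivalent functions.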

answered Nov 16 '22 at 23:11 by Kim Stebel



I want to extend Kim's answer a bit and give an example of how to achieve limited comparability of function values.

If you have some kind of descriptive definition of your function, you can check for equality on that description. For example, you can define a class (in the mathematical sense, not an OO class) of simple arithmetic functions in the following way:

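// Each subclass is described entirely by its constructor argument, so
// case-class equality compares the description, not the function body.
sealed trait ArthFun extends (Double => Double)
case class Mult(x: Double) extends ArthFun { def apply(y: Double) = x * y }
case class Add(x: Double) extends ArthFun { def apply(y: Double) = x + y }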

With this setup, where an ArthFun is fully determined by its class and its constructor argument, you can compare ArthFun values with plain ==, which uses the structural equality the compiler generates for case classes.

scala> trait ArthFun extends (Double => Double)
defined trait ArthFun

scala> case class Mult(y: Double) extends ArthFun { def apply(x: Double) = x * y; override def toString = "*" + y}
defined class Mult

scala> case class Add(y: Double) extends ArthFun { def apply(x: Double) = x + y; override def toString = "+" + y }
defined class Add

scala> Seq(Mult(5),Mult(4),Add(4),Add(3),Mult(5)).distinct
res4: Seq[Product with ArthFun with Serializable] = List(*5.0, *4.0, +4.0, +3.0)
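Applied to the question's use case, the same idea can describe zero-argument actions instead of arithmetic functions. A minimal sketch, where Task, Refresh, Notify and the Log buffer are hypothetical: case classes extending () => Unit get compiler-generated equality on their constructor arguments and a case object is a singleton, so distinct drops the repeated actions and each one runs once.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical descriptive actions: equality compares the description,
// never the body of apply.
sealed trait Task extends (() => Unit)

// Records which actions ran, in order.
object Log {
  val lines = ListBuffer.empty[String]
}

case object Refresh extends Task {
  def apply(): Unit = Log.lines += "refresh"
}

case class Notify(user: String) extends Task {
  def apply(): Unit = Log.lines += s"notify $user"
}

val pending: List[Task] =
  List(Refresh, Notify("ann"), Refresh, Notify("bob"), Notify("ann"))

// distinct keeps the first occurrence of each equal Task, then each runs once.
pending.distinct.foreach(task => task())
println(Log.lines.mkString(", "))
```

As in the arithmetic example, this only identifies duplicates among functions built from the descriptive type; arbitrary lambdas mixed into the list still compare by reference.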
answered Nov 16 '22 at 23:11 by ziggystar
