As part of learning Scala, I am trying to implement Haskell's flip function (a function with signature (A => B => C) => (B => A => C)) in Scala, as a function (using val) rather than as a method (using def).
I can implement it as a method, for instance this way:
def flip[A, B, C](f: (A, B) => C):((B, A) => C) = (b: B, a: A) => f(a, b)
val minus = (a: Int, b: Int) => a - b
val f = flip(minus)
println(f(3, 5))
However, when I try to implement it as a function, it does not work:
val flip = (f: ((Any, Any) => Any)) => ((a: Any, b: Any) => f(b, a))
val minus = (a: Int, b: Int) => a - b
val f = flip(minus)
println(f(3, 5))
When I try to compile this code, it fails with this message:
Error:(8, 18) type mismatch;
found : (Int, Int) => Int
required: (Any, Any) => Any
val f = flip(minus)
I understand why it fails: I am passing (Int, Int) => Int where (Any, Any) => Any is expected. However, I don't know how to fix this problem. Is it possible at all?
Scala 2 doesn't support polymorphic function values; only methods can take type parameters. This follows from functions being first-class values: they are simply instances of the FunctionN traits, and an instance must have its type arguments bound at the point where it is created.
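To make the point about FunctionN concrete, here is a minimal sketch (the names FunctionValueDemo, minusSugar and minusExplicit are made up for illustration). It shows that a lambda assigned to a val behaves like an instance of Function2 whose three type arguments are already fixed, which is why there is no place left to put a type parameter:

```scala
object FunctionValueDemo {
  // The lambda syntax produces a value of type Function2[Int, Int, Int].
  val minusSugar: (Int, Int) => Int = (a: Int, b: Int) => a - b

  // Roughly the same thing written out by hand: the type arguments
  // Int, Int, Int have to be supplied when the instance is created.
  val minusExplicit: Function2[Int, Int, Int] = new Function2[Int, Int, Int] {
    def apply(a: Int, b: Int): Int = a - b
  }

  def main(args: Array[String]): Unit = {
    println(minusSugar(5, 3))    // calls minusSugar.apply(5, 3)
    println(minusExplicit(5, 3)) // calls minusExplicit.apply(5, 3)
  }
}
```

Both calls print 2. A method, by contrast, is not a value, so its type parameters can stay open until each call site.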
If we took the flip method and tried to eta-expand it into a function, we'd see:
val flipFn = flip _
We'd get back a value of type:
((Nothing, Nothing) => Nothing) => (Nothing, Nothing) => Nothing
None of the type parameters were bound, so the compiler falls back to the bottom type, Nothing.
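If a single concrete instantiation is enough, one workaround is to bind the type parameters yourself when turning the method into a function value. The following is a sketch under that assumption (EtaExpansionDemo and flipInt are illustrative names, and the resulting value only works for Int functions, so it is not a polymorphic function):

```scala
object EtaExpansionDemo {
  // The generic method from the question.
  def flip[A, B, C](f: (A, B) => C): (B, A) => C = (b: B, a: A) => f(a, b)

  // Eta-expansion with A, B and C fixed to Int, so nothing is inferred as Nothing.
  val flipInt: ((Int, Int) => Int) => (Int, Int) => Int = flip[Int, Int, Int]

  val minus = (a: Int, b: Int) => a - b

  def main(args: Array[String]): Unit = {
    val f = flipInt(minus)
    println(f(3, 5)) // evaluates minus(5, 3)
  }
}
```

This prints 2, but a separate val would be needed for every combination of types, which is the limitation the rest of the answer works around.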
However, not all hope is lost. There is a library called shapeless which does allow us to define polymorphic functions via PolyN.
We can implement flip like this:
import shapeless.Poly1
object flip extends Poly1 {
  implicit def genericCase[A, B, C] = at[(A, B) => C](f => (b: B, a: A) => f(a, b))
}
From the caller's side, flip is no different from a FunctionN instance: it defines an apply method, which is what gets called.
We use it like this:
def main(args: Array[String]): Unit = {
  val minus = (a: Int, b: Int) => a - b
  val f = flip(minus)
  println(f(3, 5))
}
Yielding:
2
This also works for String:
def main(args: Array[String]): Unit = {
  val stringConcat = (a: String, b: String) => a + b
  val f = flip(stringConcat)
  println(f("hello", "world"))
}
Yielding:
worldhello
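For comparison, the same "value with an apply method" idea can be sketched without any library, assuming a singleton object is an acceptable stand-in for a val. This is not how shapeless implements Poly1 (which selects an implicit case per argument type); it is only the simplest analogy, and flipPlain and FlipPlainDemo are hypothetical names:

```scala
// A singleton object whose apply method carries the type parameters.
object flipPlain {
  def apply[A, B, C](f: (A, B) => C): (B, A) => C = (b: B, a: A) => f(a, b)
}

object FlipPlainDemo {
  def main(args: Array[String]): Unit = {
    val minus = (a: Int, b: Int) => a - b
    val stringConcat = (a: String, b: String) => a + b

    println(flipPlain(minus)(3, 5))                    // evaluates minus(5, 3)
    println(flipPlain(stringConcat)("hello", "world")) // evaluates stringConcat("world", "hello")
  }
}
```

This prints 2 and worldhello, matching the shapeless version. The trade-off is that flipPlain is an object with a generic method rather than a FunctionN value, so it cannot be passed where a plain function type is expected without first fixing its type parameters.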