 

Implementing a generic Vector in Scala

Tags:

scala

I'm trying to implement a generic (mathematical) vector in Scala, and I'm running into a couple of issues of how to do it properly:

1) How do you handle + and - such that operating on a Vector[Int] and a Vector[Double] returns a Vector[Double]? In short, how would I go about auto-promoting numeric types (preferably by taking advantage of Scala's own auto-promotion)? Using implicit n: Numeric[T] only works if both vectors have the same element type.

2) Relatedly, how should I define a * operation so that it takes any Numeric type and returns a vector of the right numeric type? That is, Vector[Int] * 2.0 should return a Vector[Double].

This is my current code (which doesn't behave the way I want it to):

case class Vector2[T](val x: T, val y: T)(implicit n: Numeric[T]) {
  import n._

  def length = sqrt(x.toDouble() * x.toDouble() + y.toDouble() * y.toDouble())
  def unary_- = new Vector2(-x, -y)

  def +(that: Vector2) = new Vector2(x + that.x, y + that.y)
  def -(that: Vector2) = new Vector2(x - that.x, y - that.y)

  def *(s: ???) = new Vector2(x * s, y * s)
}

Update

After a lot of thought, I've decided to accept Chris K's answer, because it works for all the situations I've asked about, despite the verbosity of the type class solution (the numeric types in Scala are Byte, Short, Int, Long, Float, Double, BigInt, BigDecimal, which makes for a very fun time implementing all the operations between each possible pair of types).

I've upvoted both answers, because they're both excellent answers. And I really wish Gabriele Petronella's answer worked for all possible scenarios, if only because it's a very elegant and concise answer. I do hope there'll be some way to make it work eventually.

asked Dec 08 '22 06:12 by Benedict Lee


1 Answer

A possible approach is to unify the types of the two vectors before applying the operation. That way, operations on Vector2[A] can always take a Vector2[A] as a parameter.

A similar approach can be used for multiplication (see the example below).

Using an implicit conversion from Vector2[A] to Vector2[B] (provided that Numeric[A] and Numeric[B] both exist and that you have implicit evidence that A can be converted to B), you can do:

case class Vector2[A](val x: A, val y: A)(implicit n: Numeric[A]) {
  import n.mkNumericOps
  import scala.math.sqrt

  def map[B: Numeric](f: (A => B)): Vector2[B] = Vector2(f(x), f(y))

  def length = sqrt(x.toDouble * x.toDouble + y.toDouble * y.toDouble)
  def unary_- = this.map(-_)

  def +(that: Vector2[A]) = Vector2(x + that.x, y + that.y)
  def -(that: Vector2[A]) = Vector2(x - that.x, y - that.y)
  def *[B](s: B)(implicit ev: A => B, nb: Numeric[B]) = this.map(ev(_)).map(nb.times(_, s))
}

object Vector2 {
  implicit def toV[A: Numeric, B: Numeric](v: Vector2[A])(
    implicit ev: A => B // provided by the Scala standard library for numeric widenings such as Int => Long
  ): Vector2[B] = v.map(ev(_))
}
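The comment on toV relies on implicit views such as Int => Long, which the standard library defines on the numeric companions (Int.int2long and friends). Below is a minimal sketch, assuming Scala 2.13, that summons a few of them with implicitly. Only widening directions exist (there is no Long => Int), and the identity view Int => Int comes from Predef.$conforms, which is what lets x * 2 work on a Vector2[Int]. The object and val names are hypothetical.

```scala
// Hypothetical object collecting some standard-library numeric views (Scala 2.13).
object WideningViews {
  // Implicit methods such as scala.Int.int2long are eta-expanded to satisfy a function-typed implicit.
  val intToLong: Int => Long = implicitly[Int => Long]
  val intToDouble: Int => Double = implicitly[Int => Double]
  val longToDouble: Long => Double = implicitly[Long => Double]
  // Defined on the scala.math.BigDecimal companion.
  val intToBigDecimal: Int => BigDecimal = implicitly[Int => BigDecimal]
  // The identity view comes from Predef.$conforms.
  val intToInt: Int => Int = implicitly[Int => Int]
}
```

Asking for implicitly[Long => Int] instead would be a compile error, which is why toV can only ever widen a vector.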

Examples:

val x = Vector2(1, 2)         //> x  : Solution.Vector2[Int] = Vector2(1,2)
val y = Vector2(3.0, 4.0)     //> y  : Solution.Vector2[Double] = Vector2(3.0,4.0)
val z = Vector2(5L, 6L)       //> z  : Solution.Vector2[Long] = Vector2(5,6)

x + y                         //> res0: Solution.Vector2[Double] = Vector2(4.0,6.0)
y + x                         //> res1: Solution.Vector2[Double] = Vector2(4.0,6.0)
x + z                         //> res2: Solution.Vector2[Long] = Vector2(6,8)
z + x                         //> res3: Solution.Vector2[Long] = Vector2(6,8)
y + z                         //> res4: Solution.Vector2[Double] = Vector2(8.0,10.0)
z + y                         //> res5: Solution.Vector2[Double] = Vector2(8.0,10.0)

x * 2                         //> res6: Solution.Vector2[Int] = Vector2(2,4)
x * 2.0                       //> res7: Solution.Vector2[Double] = Vector2(2.0,4.0)
x * 2L                        //> res8: Solution.Vector2[Long] = Vector2(2,4)
x * 2.0f                      //> res9: Solution.Vector2[Float] = Vector2(2.0,4.0)
x * BigDecimal(2)             //> res10: Solution.Vector2[scala.math.BigDecimal] = Vector2(2,4)

As per Chris' request in the comments, here's an example of how the chain of implicit conversions works.

If we run the Scala REPL with scala -Xprint:typer, we can see the implicits at work explicitly. For instance,

z + x

becomes

val res1: Vector2[Long] = $line7.$read.$iw.$iw.$iw.z.+($iw.this.Vector2.toV[Int, Long]($line4.$read.$iw.$iw.$iw.x)(math.this.Numeric.IntIsIntegral, math.this.Numeric.LongIsIntegral, {
        ((x: Int) => scala.this.Int.int2long(x))
      }));

which, translated into more readable terms, is

val res: Vector2[Long] = z + toV[Int, Long](x){ i: Int => Int.int2long(i) }
// where toV[Int, Long](x){ ... } evaluates to a Vector2[Long]

Conversely, x + z becomes

val res: Vector2[Long] = toV[Int, Long](x){ i: Int => Int.int2long(i) } + z

The way it works is roughly this (V stands for Vector2):

  1. we write z: V[Long] + x: V[Int]
  2. the compiler sees that z has a method + that takes a V[Long]
  3. it looks for a conversion from V[Int] to V[Long]
  4. it finds toV
  5. it looks for a conversion from Int to Long, as required by toV
  6. it finds Int.int2long, i.e. a function Int => Long
  7. it can then use toV[Int, Long], i.e. a function V[Int] => V[Long]
  8. it does z + toV(x)
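To check the desugared forms, here is a trimmed, self-contained copy of the answer's Vector2 (only map, + and toV), assuming Scala 2.13. It compares the implicit z + x with the version whose implicit arguments are spelled out as in the typer output, and also writes out the explicit form of x + z. The wrapper object DesugaringSketch and the val names are hypothetical.

```scala
// Hypothetical wrapper; Vector2 is trimmed from the answer's version (Scala 2.13).
object DesugaringSketch {
  case class Vector2[A](x: A, y: A)(implicit n: Numeric[A]) {
    import n.mkNumericOps

    def map[B: Numeric](f: A => B): Vector2[B] = Vector2(f(x), f(y))
    def +(that: Vector2[A]): Vector2[A] = Vector2(x + that.x, y + that.y)
  }

  object Vector2 {
    // Widens a Vector2[A] to a Vector2[B] whenever an implicit A => B view exists.
    implicit def toV[A: Numeric, B: Numeric](v: Vector2[A])(implicit ev: A => B): Vector2[B] =
      v.map(ev(_))
  }

  val x: Vector2[Int] = Vector2(1, 2)
  val z: Vector2[Long] = Vector2(5L, 6L)

  // z + x: the argument x is converted to a Vector2[Long] through the companion's toV.
  val implicitZX: Vector2[Long] = z + x
  val explicitZX: Vector2[Long] =
    z + Vector2.toV[Int, Long](x)(Numeric.IntIsIntegral, Numeric.LongIsIntegral, (i: Int) => Int.int2long(i))

  // x + z written out by hand: the receiver x is the side that gets converted.
  val explicitXZ: Vector2[Long] =
    Vector2.toV[Int, Long](x)(Numeric.IntIsIntegral, Numeric.LongIsIntegral, (i: Int) => Int.int2long(i)) + z
}
```

The context bounds on toV are merged into its implicit parameter list, which is why the explicit call passes the two Numeric instances before the Int => Long function.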

If instead we write x: V[Int] + z: V[Long]:

  1. the compiler sees that x has a method + that takes a V[Int]
  2. it looks for a conversion from V[Long] to V[Int]
  3. it finds toV
  4. it looks for a conversion from Long to Int, as required by toV
  5. it can't find it!
  6. it then tries to convert x itself, since a V[Long] has a method + that takes a V[Long]

and we're back to step 3 of the previous example, which this time ends with toV(x) + z.
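The asymmetry in steps 4 and 5 comes down to which implicit views exist. A small probe, assuming Scala 2.13: the hypothetical hasView helper relies on the rule that an implicit parameter with a default value falls back to that default when the implicit search fails, so it reports whether a view A => B is in scope.

```scala
// Hypothetical probe object (Scala 2.13).
object ViewProbe {
  // If no implicit A => B is found, the default null is used instead of a compile error.
  def hasView[A, B](implicit ev: A => B = null): Boolean = ev != null
}
```

With it, ViewProbe.hasView[Int, Long] is true while ViewProbe.hasView[Long, Int] is false, which is exactly why only x can be converted in x + z.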


Update

As noted in the comments, there's a problem with

Vector2(2.0, 1.0) * 2.0f

This is pretty much the issue:

2.0f * 3.0 // 6.0: Double

but also

2.0 * 3.0f // 6.0: Double

So regardless of which operand is the Float, mixing Doubles and Floats always yields a Double. Unfortunately, * requires evidence of A => B in order to convert the vector to the type of s, and there is no Double => Float view; in this case we actually want to convert s to the type of the vector instead.
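Both claims can be checked directly. A small sketch, assuming Scala 2.13: the hypothetical staticClass helper reads the ClassTag the compiler resolves for an expression's static type, so it shows that both mixed products are typed as Double, and the Float => Double view is summoned to show that the widening goes in that direction only.

```scala
import scala.reflect.ClassTag

// Hypothetical helper object for inspecting mixed Float/Double arithmetic (Scala 2.13).
object FloatDoubleMix {
  // ClassTag[T] is resolved from the static type T, so this reports what the compiler inferred.
  def staticClass[T](value: T)(implicit ct: ClassTag[T]): Class[_] = ct.runtimeClass

  val floatTimesDouble = 2.0f * 3.0 // Float's *(Double) overload returns a Double
  val doubleTimesFloat = 2.0 * 3.0f // Double's *(Float) overload also returns a Double

  // This view exists; implicitly[Double => Float] would not compile.
  val floatToDouble: Float => Double = implicitly[Float => Double]
}
```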

We need to handle both cases. A first, naive approach could be:

def *[B](s: B)(implicit ev: A => B, nb: Numeric[B]): Vector2[B] =
  this.map(a => nb.times(ev(a), s)) // convert A to B
def *[B](s: B)(implicit ev: B => A, na: Numeric[A]): Vector2[A] =
  this.map(a => na.times(a, ev(s))) // convert B to A

Neat, right? Too bad it doesn't work: Scala does not consider implicit arguments when disambiguating overloaded methods. We have to work around this using the magnet pattern, as suggested here.

case class Vector2[A](val x: A, val y: A)(implicit na: Numeric[A]) {
  object ToBOrToA {
    implicit def fromA[B: Numeric](implicit ev: A => B): ToBOrToA[B] = ToBOrToA(Left(ev))
    implicit def fromB[B: Numeric](implicit ev: B => A): ToBOrToA[B] = ToBOrToA(Right(ev))
  }
  // Left: convert the vector's elements to B; Right: convert the scalar to A
  case class ToBOrToA[B: Numeric](e: Either[(A => B), (B => A)])

  def *[B](s: B)(implicit ev: ToBOrToA[B], nb: Numeric[B]) = ev match {
    case ToBOrToA(Left(f)) => Vector2[B](nb.times(f(x), s), nb.times(f(y), s))
    case ToBOrToA(Right(f)) => Vector2[A](na.times(x, f(s)), na.times(y, f(s)))
  }
}

We have only one * method, and we inspect the implicit parameter ev to know whether we have to convert everything to the type of the vector or to the type of s.

The main drawback of this approach is the result type: ev match { ... } returns a Vector2 whose type parameter is only known to be a supertype of A with B, and I still haven't found a workaround for it.

val a = x * 2.0    //> a  : Solution.Vector2[_ >: Double with Int] = Vector2(2.0,4.0)
val b = y * 2      //> b  : Solution.Vector2[_ >: Int with Double] = Vector2(6.0,8.0)
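The ToBOrToA version has one more snag besides the result type: when the vector and the scalar have the same type (e.g. x * 2), both fromA and fromB can supply the evidence, so the implicit search would be ambiguous. Below is a sketch of the same single-* idea, assuming Scala 2.13, that breaks the tie with the common low-priority-implicits trait and moves the evidence type out of Vector2 so that its companion is found through the implicit scope. Promote, LowPriorityPromote, vectorToScalar, scalarToVector and MagnetSketch are hypothetical names. It does not solve the result-type problem: * is still typed as Vector2[_].

```scala
// Hypothetical names throughout; a sketch of the single-* approach (Scala 2.13).
object MagnetSketch {
  // Evidence that either A widens to the scalar type B (Left) or B widens to A (Right).
  final case class Promote[A, B](e: Either[A => B, B => A])

  trait LowPriorityPromote {
    // Fallback, used only when no A => B view exists: convert the scalar to A.
    implicit def scalarToVector[A, B](implicit ev: B => A): Promote[A, B] = Promote(Right(ev))
  }

  object Promote extends LowPriorityPromote {
    // Preferred: widen the vector's elements to B. Since this object extends
    // LowPriorityPromote, this candidate wins when both apply (e.g. A == B).
    implicit def vectorToScalar[A, B](implicit ev: A => B): Promote[A, B] = Promote(Left(ev))
  }

  final case class Vector2[A](x: A, y: A)(implicit na: Numeric[A]) {
    def map[C: Numeric](f: A => C): Vector2[C] = Vector2(f(x), f(y))

    // Single * method; the static result type is still only Vector2[_].
    def *[B](s: B)(implicit ev: Promote[A, B], nb: Numeric[B]): Vector2[_] = ev.e match {
      case Left(toB)  => map(a => nb.times(toB(a), s)) // result is a Vector2[B]
      case Right(toA) => map(a => na.times(a, toA(s))) // result is a Vector2[A]
    }
  }
}
```

With this, MagnetSketch.Vector2(2.0, 1.0) * 2.0f picks scalarToVector (there is no Double => Float) and yields Vector2(4.0, 2.0), while MagnetSketch.Vector2(1, 2) * 2 resolves to vectorToScalar instead of failing as ambiguous.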
answered Jan 13 '23 03:01 by Gabriele Petronella