Scala type constraint to check argument values

I'm trying to implement Conway's surreal numbers in Scala. A surreal number is defined recursively, as a pair of sets of surreal numbers, called left and right, such that no element in the right set is less than or equal to any element in the left set. Here the relation "less than or equal to" between surreal numbers is also defined recursively: we say that x ≤ y if

  • there is no element a in the left set of x such that y ≤ a, and
  • there is no element b in the right set of y such that b ≤ x.

We start by defining zero as a pair of empty sets, then use zero to define 1 and -1, and so on.
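Unfolding the definition for the first few numbers shows how the recursion bottoms out at the empty sets. In the derivation below, L(x) and R(x) denote the left and right sets of x, and {L | R} is Conway's notation for the pair; this notation is chosen here for brevity.

```latex
\begin{aligned}
x \le y &\iff \bigl(\nexists\, a \in L(x) : y \le a\bigr) \wedge \bigl(\nexists\, b \in R(y) : b \le x\bigr) \\
0 &= \{\ \mid\ \}, \quad 1 = \{\,0 \mid\ \}, \quad -1 = \{\ \mid 0\,\} \\
0 \le 0 &\text{ since } L(0) = R(0) = \varnothing \text{ (both conditions hold vacuously)} \\
0 \le 1 &\text{ since } L(0) = \varnothing \text{ and } R(1) = \varnothing \\
1 \not\le 0 &\text{ since } 0 \in L(1) \text{ and } 0 \le 0 \\
-1 \le 0 &\text{ since } L(-1) = \varnothing \text{ and } R(0) = \varnothing \\
0 \not\le -1 &\text{ since } 0 \in R(-1) \text{ and } 0 \le 0
\end{aligned}
```

These results match the assertions in the question's code.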

I cannot figure out how to enforce the definition of a surreal number at compile time. This is what I have now:

case class SurrealNumber(left: Set[SurrealNumber], right: Set[SurrealNumber]) {
  if ((for { a <- left; b <- right; if b <= a } yield (a, b)).nonEmpty)
    throw new Exception
  def <=(other: SurrealNumber): Boolean =
    !this.left.exists(other <= _) && !other.right.exists(_ <= this)
}

val zero = SurrealNumber(Set.empty, Set.empty)
val one = SurrealNumber(Set(zero), Set.empty)
val minusOne = SurrealNumber(Set.empty, Set(zero))

assert(zero <= zero)
assert((zero <= one) && !(one <= zero))
assert((minusOne <= zero) && !(zero <= minusOne))

When the arguments are invalid, as in SurrealNumber(Set(one), Set(zero)), this code throws a runtime exception. Is it possible to express the validity check as a type constraint, so that SurrealNumber(Set(one), Set(zero)) wouldn't compile?
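For comparison, a common runtime-only alternative to throwing from the class body is a smart constructor that returns an Either. This is a sketch, assuming Scala 2.13 or 3; the names Surreal, Surreal.of and SurrealDemo are hypothetical and not part of the question's code. It accepts {0 | 1} (which in Conway's construction is 1/2) and rejects {1 | 0}, the invalid example above. It is still a run-time check, not a compile-time one.

```scala
// Sketch with hypothetical names: validation lives in a smart constructor
// that returns Either instead of throwing from the class body.
final case class Surreal private (left: Set[Surreal], right: Set[Surreal]) {
  // x <= y iff no a in x.left has y <= a, and no b in y.right has b <= x
  def <=(other: Surreal): Boolean =
    !left.exists(other <= _) && !other.right.exists(_ <= this)
}

object Surreal {
  // Valid iff no element of `right` is <= any element of `left`.
  def of(left: Set[Surreal], right: Set[Surreal]): Either[String, Surreal] = {
    val clash = left.exists(a => right.exists(b => b <= a))
    if (clash) Left("invalid surreal number: a right element is <= a left element")
    else Right(new Surreal(left, right)) // private constructor is accessible from the companion
  }
}

object SurrealDemo {
  val zero: Surreal = Surreal.of(Set.empty, Set.empty).toOption.get
  val one: Surreal = Surreal.of(Set(zero), Set.empty).toOption.get
  val half: Either[String, Surreal] = Surreal.of(Set(zero), Set(one))    // {0 | 1}
  val invalid: Either[String, Surreal] = Surreal.of(Set(one), Set(zero)) // {1 | 0}
}
```

Callers then have to handle the Left case explicitly (for example by pattern matching) instead of catching an exception.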

asked Jun 26 '20 08:06 by siarhei


1 Answer

You could define a macro to perform the validity check at compile time:

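import scala.language.experimental.macros
import scala.reflect.macros.blackbox

// Scala 2 def macro: call sites must be compiled separately from (after) this definition.
case class SurrealNumber private(left: Set[SurrealNumber], right: Set[SurrealNumber]) {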
  def <=(other: SurrealNumber): Boolean =
    !this.left.exists(other <= _) && !other.right.exists(_ <= this)
}

object SurrealNumber {
  def unsafeApply(left: Set[SurrealNumber], right: Set[SurrealNumber]): SurrealNumber =
    new SurrealNumber(left, right)

  def apply(left: Set[SurrealNumber], right: Set[SurrealNumber]): SurrealNumber =
    macro applyImpl

  def applyImpl(c: blackbox.Context)(left: c.Tree, right: c.Tree): c.Tree = {
    import c.universe._
    def eval[A](t: Tree): A = c.eval(c.Expr[A](c.untypecheck(t)))
    val l = eval[Set[SurrealNumber]](left)
    val r = eval[Set[SurrealNumber]](right)
    if ((for { a <- l; b <- r; if b <= a } yield (a, b)).nonEmpty)
      c.abort(c.enclosingPosition, "invalid surreal number")
    else q"SurrealNumber.unsafeApply($left, $right)"
  }
}
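Because a Scala 2 macro cannot be expanded in the same compilation run that defines it, code like this usually lives in its own sbt subproject. A minimal build.sbt sketch, assuming sbt 1.x and Scala 2.13; the subproject names macros and core and the exact Scala version are hypothetical:

```scala
// build.sbt (sketch): "macros" and "core" are hypothetical subproject names.
ThisBuild / scalaVersion := "2.13.12"

// Holds SurrealNumber and its macro apply/applyImpl; def macros need scala-reflect.
lazy val macros = (project in file("macros"))
  .settings(
    libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value
  )

// Code that calls SurrealNumber(...) goes here, so it is compiled after `macros`.
lazy val core = (project in file("core"))
  .dependsOn(macros)
```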

The catch, though, is that

SurrealNumber(Set.empty, Set.empty)

is a compile-time value (zero), whereas in

SurrealNumber(Set(zero), Set.empty)
SurrealNumber(Set.empty, Set(zero))

(one and minusOne) the argument zero is a runtime value, and the compiler doesn't have access to it. So

SurrealNumber(Set(SurrealNumber(Set.empty, Set.empty)), Set.empty)
SurrealNumber(Set.empty, Set(SurrealNumber(Set.empty, Set.empty)))

compile, but

SurrealNumber(Set(zero), Set.empty)
SurrealNumber(Set.empty, Set(zero))

don't.


So to get the check at compile time you would have to redesign SurrealNumber to work at the type level, for example with shapeless:

import shapeless.{::, HList, HNil, IsDistinctConstraint, OrElse, Poly1, Poly2, Refute, poly}
import shapeless.ops.hlist.{CollectFirst, LeftReducer}
import shapeless.test.illTyped

class SurrealNumber[L <: HList : IsDistinctConstraint : IsSorted, 
                    R <: HList : IsDistinctConstraint : IsSorted](implicit
  notExist: Refute[CollectFirst[L, CollectPoly[R]]]
)

trait LEq[S, S1]
object LEq {
  implicit def mkLEq[S,  L  <: HList,  R <: HList, 
                     S1, L1 <: HList, R1 <: HList](implicit
    ev:        S  <:< SurrealNumber[L, R],
    ev1:       S1 <:< SurrealNumber[L1, R1],
    notExist:  Refute[CollectFirst[L, FlippedLEqPoly[S1]]],
    notExist1: Refute[CollectFirst[R1, LEqPoly[S]]]
  ): S LEq S1 = null
}

trait CollectPoly[R <: HList] extends Poly1
object CollectPoly {
  implicit def cse[R <: HList, LElem](implicit 
    exist: CollectFirst[R, LEqPoly[LElem]]
  ): poly.Case1.Aux[CollectPoly[R], LElem, Unit] = null
}

trait LEqPoly[FixedElem] extends Poly1
object LEqPoly {
  implicit def cse[FixedElem, Elem](implicit 
    leq: Elem LEq FixedElem
  ): poly.Case1.Aux[LEqPoly[FixedElem], Elem, Unit] = null
}

trait FlippedLEqPoly[FixedElem] extends Poly1
object FlippedLEqPoly {
  implicit def cse[FixedElem, Elem](implicit 
    leq: FixedElem LEq Elem
  ): poly.Case1.Aux[FlippedLEqPoly[FixedElem], Elem, Unit] = null
}

object isSortedPoly extends Poly2 {
  implicit def cse[Elem, Elem1](implicit 
    leq: Elem LEq Elem1
  ): Case.Aux[Elem, Elem1, Elem1] = null
}
type IsSorted[L <: HList] = (L <:< HNil) OrElse LeftReducer[L, isSortedPoly.type]

val zero = new SurrealNumber[HNil, HNil]
val one = new SurrealNumber[zero.type :: HNil, HNil]
val minusOne = new SurrealNumber[HNil, zero.type :: HNil]
illTyped("new SurrealNumber[one.type :: HNil, zero.type :: HNil]")
new SurrealNumber[zero.type :: HNil, one.type :: HNil]
answered Oct 31 '22 21:10 by Dmytro Mitin