 

Checking for varargs type ascription in Scala macros

Suppose I have this macro:

import language.experimental.macros
import scala.reflect.macros.Context

object FooExample {
  def foo[A](xs: A*): Int = macro foo_impl[A]
  def foo_impl[A](c: Context)(xs: c.Expr[A]*) = c.literal(xs.size)
}

This works as expected with "real" varargs:

scala> FooExample.foo(1, 2, 3)
res0: Int = 3

But the behavior with a sequence ascribed to the varargs type is confusing to me (in Scala 2.10.0-RC3):

scala> FooExample.foo(List(1, 2, 3): _*)
res1: Int = 1

And to show that nothing fishy is going on with the inferred type:

scala> FooExample.foo[Int](List(1, 2, 3): _*)
res2: Int = 1
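The result of 1 suggests that the ascribed call does not arrive at the macro as three arguments. One way to check is to print the tree the compiler builds for such a call using runtime reflection. This is a minimal sketch, assuming Scala 2.10 with scala-reflect on the classpath. `ShowVarargsTree` and `f` are hypothetical names, and `f` is an ordinary (non-macro) varargs method standing in for `foo`:

```scala
import scala.reflect.runtime.universe._

object ShowVarargsTree {
  // Hypothetical plain varargs method, standing in for the macro foo
  def f(xs: Int*): Int = xs.size

  // reify captures the call as a tree; showRaw renders its case-class structure
  val raw: String = showRaw(reify(f(List(1, 2, 3): _*)).tree)

  def main(args: Array[String]): Unit = println(raw)
}
```

The printed structure should show the `List(1, 2, 3)` call wrapped in a single `Typed(..., Ident(tpnme.WILDCARD_STAR))` node. In the macro case, that one tree is the single `c.Expr` that `foo_impl` receives, which is why `xs.size` is 1.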

I would have expected a compile-time error here, and that's what I want. I've used the following approach in most of the macros I've written:

object BarExample {
  def bar(xs: Int*): Int = macro bar_impl
  def bar_impl(c: Context)(xs: c.Expr[Int]*) = {
    import c.universe._
    c.literal(
      xs.map(_.tree).headOption map {
        case Literal(Constant(x: Int)) => x
        case _ => c.abort(c.enclosingPosition, "bar wants literal arguments!")
      } getOrElse c.abort(c.enclosingPosition, "bar wants arguments!")
    )
  }
}

And this catches the problem at compile time:

scala> BarExample.bar(3, 2, 1)
res3: Int = 3

scala> BarExample.bar(List(3, 2, 1): _*)
<console>:8: error: bar wants literal arguments!
              BarExample.bar(List(3, 2, 1): _*)

This feels like a hack to me, though: it mixes one bit of validation (checking that the arguments are literals) with another (confirming that we really have varargs). I can also imagine cases where I don't need the arguments to be literals (or where I want their type to be generic).

I know I could do the following:

object BazExample {
  def baz[A](xs: A*): Int = macro baz_impl[A]
  def baz_impl[A](c: Context)(xs: c.Expr[A]*) = {
    import c.universe._

    xs.toList.map(_.tree) match {
      case Typed(_, Ident(tpnme.WILDCARD_STAR)) :: Nil =>
        c.abort(c.enclosingPosition, "baz wants real varargs!")
      case _ => c.literal(xs.size)
    }
  }
}

But this is an ugly way of handling a very simple (and I'd suppose widely necessary) bit of argument validation. Is there a trick I'm missing here? What's the simplest way that I can make sure that foo(1 :: Nil: _*) in my first example gives a compile-time error?
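For reference, the `Typed`/`WILDCARD_STAR` check from `baz_impl` can at least be pulled out into a reusable helper so that each macro implementation calls it in one line. This is a sketch under stated assumptions, not an established API: `VarargsCheck`, `requireRealVarargs`, and `QuxExample` are hypothetical names, and the code targets the Scala 2.10 `scala.reflect.macros.Context` API used above:

```scala
import language.experimental.macros
import scala.reflect.macros.Context

// Hypothetical helper: aborts compilation if the varargs came from a `: _*` ascription
object VarargsCheck {
  def requireRealVarargs(c: Context)(xs: Seq[c.Expr[_]], name: String): Unit = {
    import c.universe._
    xs.map(_.tree) match {
      // An ascribed sequence arrives as exactly one Typed(..., _*) tree
      case Seq(Typed(_, Ident(tpnme.WILDCARD_STAR))) =>
        c.abort(c.enclosingPosition, s"$name wants real varargs!")
      case _ => ()
    }
  }
}

// Hypothetical macro using the helper; the argument type stays generic
object QuxExample {
  def qux[A](xs: A*): Int = macro qux_impl[A]
  def qux_impl[A](c: Context)(xs: c.Expr[A]*): c.Expr[Int] = {
    VarargsCheck.requireRealVarargs(c)(xs, "qux")
    c.literal(xs.size)
  }
}
```

As with any macro, `QuxExample.qux` has to be called from a separate compilation unit. The helper only relocates the pattern match. It does not make the underlying check any less ad hoc.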

Travis Brown asked Dec 01 '12 20:12

1 Answer

Does this work as expected?

object BarExample {
  def bar(xs: Int*): Int = macro bar_impl
  def bar_impl(c: Context)(xs: c.Expr[Int]*) = {
    import c.universe._
    import scala.collection.immutable.Stack
    // Match on the first argument tree and ignore the rest
    Stack[Tree](xs map (_.tree): _*) match {
      case Stack(Literal(Constant(x: Int)), _*) => c.literal(x)
      // A `: _*`-ascribed sequence arrives as one Typed tree, so it ends up here
      case _ => c.abort(c.enclosingPosition, "bar wants integer constant arguments!")
    }
  }
}
idonnie answered Sep 22 '22 14:09