How to require a typesafe constant-size array in Scala?

I need something like this:

  def encryptBlock(arr: FixedArray[Size16]) = ???
  val blocks = arr.splitFixed[Size16]
  val encrypted = encryptBlock(FixedArray[Size16]())

That is, I want to be sure that I receive only a 128-bit (16-byte) array as input.

asked Feb 02 '15 at 22:02 by dk14


1 Answer

Shapeless can do that for sequences with Sized:

import shapeless._
import nat._
import syntax.sized._

scala> def func(l: Sized[List[Int], _3]) = l
func: (l: shapeless.Sized[List[Int],shapeless.nat._3])shapeless.Sized[List[Int],shapeless.nat._3]

scala> List(1,2,3,4,5,6).grouped(3).map(_.sized(3).get).map(func)
res26: Iterator[shapeless.Sized[List[Int],shapeless.nat._3]] = non-empty iterator

scala> List(1,2,3,4,5,6).grouped(2).map(_.sized(2).get).map(func)
<console>:25: error: type mismatch;
 found   : shapeless.Sized[List[Int],shapeless.nat._3] => shapeless.Sized[List[Int],shapeless.nat._3]
 required: shapeless.Sized[List[Int],shapeless.Succ[shapeless.Succ[shapeless._0]]] => ?
              List(1,2,3,4,5,6).grouped(2).map(_.sized(2).get).map(func)

The argument passed to .sized must be an integer literal (the macro expects Literal(Constant(n: Int))), so you can't pass a variable or an expression.
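Why a literal: .sized(n) has to turn the number into a type (for 2 that is shapeless.Succ[shapeless.Succ[shapeless._0]], as the error message above shows) while compiling, and only the length comparison is left for runtime. The following is a hand-rolled, dependency-free sketch of that idea, modeled loosely on shapeless' Nat and ToInt. The names Nat, Zero, Succ, ToInt, SizedList and SizedDemo are hypothetical and are not shapeless' actual definitions.

```scala
// Hand-rolled Peano naturals, loosely modeled on shapeless' Nat (hypothetical names)
sealed trait Nat
sealed trait Zero extends Nat
sealed trait Succ[N <: Nat] extends Nat

// Type class that recovers the runtime Int for a type-level natural
trait ToInt[N <: Nat] { def value: Int }

object ToInt {
  implicit val zero: ToInt[Zero] = new ToInt[Zero] { val value = 0 }
  implicit def succ[N <: Nat](implicit prev: ToInt[N]): ToInt[Succ[N]] =
    new ToInt[Succ[N]] { val value = prev.value + 1 }
}

// A list whose length is recorded in its type; only the companion can build one
final class SizedList[A, N <: Nat] private (val unsized: List[A])

object SizedList {
  // N is fixed at compile time; the only runtime work is the length check
  def apply[A, N <: Nat](xs: List[A])(implicit n: ToInt[N]): Option[SizedList[A, N]] =
    if (xs.length == n.value) Some(new SizedList[A, N](xs)) else None
}

object SizedDemo {
  type Three = Succ[Succ[Succ[Zero]]]

  // Accepts only lists whose type says "length 3"
  def func(l: SizedList[Int, Three]): SizedList[Int, Three] = l

  // Like grouped(3).map(_.sized(3).get), but drops a short trailing group instead of throwing
  def blocks(xs: List[Int]): List[SizedList[Int, Three]] =
    xs.grouped(3).flatMap(g => SizedList[Int, Three](g)).map(func).toList
}
```

Because N is a type parameter, there is no way to pick it from a runtime Int value, which is the same reason shapeless needs a literal.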

You can also convert an Array to an IndexedSeq (.toSeq), such as a Vector (.toVector), and size it the same way.
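If you'd rather keep working from Array[Byte] without pulling in shapeless, a common alternative (a different technique from type-level sizes) is a wrapper type whose only constructors check the length, so that encryptBlock can demand it in its signature. This is a minimal sketch under that assumption: Block16, fromArray, splitFixed and Crypto.encryptBlock are hypothetical names echoing the question, the size is checked at runtime on construction rather than at compile time, and the XOR body is a placeholder, not a cipher.

```scala
// Hypothetical 128-bit block: the only way to obtain one is through the checked constructors below
final class Block16 private (val bytes: Vector[Byte])

object Block16 {
  val Size = 16

  // Copies into an immutable Vector so later mutation of `arr` cannot break the invariant
  def fromArray(arr: Array[Byte]): Option[Block16] =
    if (arr.length == Size) Some(new Block16(arr.toVector)) else None

  // Splits input into 16-byte blocks; None unless the length is an exact multiple of 16
  def splitFixed(arr: Array[Byte]): Option[Vector[Block16]] =
    if (arr.length % Size != 0) None
    else Some(arr.grouped(Size).map(chunk => new Block16(chunk.toVector)).toVector)
}

object Crypto {
  // Placeholder body (NOT encryption): the signature alone guarantees exactly 16 input bytes
  def encryptBlock(block: Block16): Vector[Byte] =
    block.bytes.map(b => (b ^ 0x5A).toByte)
}
```

The trade-off compared with Sized is that a wrong length shows up as None at runtime, but once a Block16 exists, every function taking one can rely on its size without re-checking.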

You can also specify a set of acceptable sizes using type disjunction:

def func[A <: Nat](l: Sized[List[Int], A])(implicit ev: (_2 with _3) <:< A) = l

scala> func(List(1,2).sized(2).get)
res17: shapeless.Sized[List[Int],shapeless.Succ[shapeless.Succ[shapeless._0]]] = shapeless.Sized@3ac1111f

scala> func(List(1,2,3).sized(3).get)
res18: shapeless.Sized[List[Int],shapeless.Succ[shapeless.Succ[shapeless.Succ[shapeless._0]]]] = shapeless.Sized@17191095

scala> func(List(1,2,3,4).sized(4).get)
<console>:24: error: Cannot prove that shapeless.nat._2 with shapeless.nat._3 <:< nat_1.N.
              func(List(1,2,3,4).sized(4).get)
                  ^
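The disjunction works because (_2 with _3) <:< A only holds when A is a supertype of _2 with _3: _2 and _3 qualify, _4 does not. Note that a common supertype such as Nat itself would also satisfy it. The same trick can be seen without shapeless in this minimal sketch, where Size, Two, Three, Four and Tagged are hypothetical phantom marker types rather than real sizes:

```scala
// Hypothetical phantom marker types standing in for shapeless' _2, _3 and _4
sealed trait Size
sealed trait Two extends Size
sealed trait Three extends Size
sealed trait Four extends Size

// Carries a size only in its type; the list length is NOT checked here
final case class Tagged[S <: Size](values: List[Int])

object Accept {
  // Compiles only when (Two with Three) <: S, i.e. S is Two, Three, or a common parent like Size
  def func[S <: Size](t: Tagged[S])(implicit ev: (Two with Three) <:< S): Tagged[S] = t
}
```

Accept.func(Tagged[Two](List(1, 2))) compiles, while Accept.func(Tagged[Four](List(1, 2, 3, 4))) is rejected with a "Cannot prove that ..." style error (exact wording depends on the Scala version), matching the REPL output above.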

Restricting to a maximum N (suggested by @DougC and @Miles Sabin):

import ops.nat._
import LT._
scala> def func[N <: Nat](l: Sized[List[Int], N])(implicit ev: N < _3) = l
func: [N <: shapeless.Nat](l: shapeless.Sized[List[Int],N])(implicit ev: shapeless.ops.nat.LT.<[N,shapeless.nat._3])shapeless.Sized[List[Int],N]

scala> func(List(1,2).sized(2).get)
res25: shapeless.Sized[List[Int],shapeless.Succ[shapeless.Succ[shapeless._0]]] = shapeless.Sized@3ac1111f

scala> func(List(1).sized(1).get)
res26: shapeless.Sized[List[Int],shapeless.Succ[shapeless._0]] = shapeless.Sized@73f49b57

scala> func(List(1,2,3).sized(3).get)
<console>:30: error: could not find implicit value for parameter ev: shapeless.ops.nat.LT[nat_1.N,shapeless.nat._3]
              func(List(1,2,3).sized(3).get)
                  ^
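N < _3 is evidence from shapeless' LT type class, and an instance exists only when the compiler can derive it step by step. A minimal hand-rolled version of that induction looks roughly like this, assuming a Peano encoding of our own; Nat, Zero, Succ, LT and Tagged inside MaxSize are hypothetical and not shapeless' source.

```scala
object MaxSize {
  // Hypothetical Peano naturals, loosely modeled on shapeless' Nat
  sealed trait Nat
  sealed trait Zero extends Nat
  sealed trait Succ[N <: Nat] extends Nat

  type One = Succ[Zero]
  type Two = Succ[One]
  type Three = Succ[Two]

  // Evidence that A < B, derived inductively by implicit search
  sealed trait LT[A <: Nat, B <: Nat]
  object LT {
    // Base case: 0 < b + 1
    implicit def zeroLtSucc[B <: Nat]: LT[Zero, Succ[B]] = new LT[Zero, Succ[B]] {}
    // Step: a < b implies a + 1 < b + 1
    implicit def succLtSucc[A <: Nat, B <: Nat](implicit ev: LT[A, B]): LT[Succ[A], Succ[B]] =
      new LT[Succ[A], Succ[B]] {}
  }

  // Size tracked only in the type, for brevity; the list length is NOT checked here
  final case class Tagged[N <: Nat](values: List[Int])

  // Accepts only sizes strictly below 3, mirroring `implicit ev: N < _3`
  def func[N <: Nat](l: Tagged[N])(implicit ev: LT[N, Three]): Tagged[N] = l
}
```

MaxSize.func(MaxSize.Tagged[MaxSize.Three](List(1, 2, 3))) does not compile: after peeling off three Succ layers the search reaches LT[Zero, Zero], which neither rule can produce.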
answered Sep 28 '22 at 11:09 by dk14