Creating serializable objects from Scala source code at runtime

To embed Scala as a "scripting language", I need to be able to compile text fragments to simple objects, such as Function0[Unit], that can be serialised to and deserialised from disk, loaded into the current runtime, and executed.

How would I go about this?

Say, for example, my text fragment is (purely hypothetical):

Document.current.elements.headOption.foreach(_.open())

This might be wrapped into the following complete text:

package myapp.userscripts
import myapp.DSL._

object UserFunction1234 extends Function0[Unit] {
  def apply(): Unit = {
    Document.current.elements.headOption.foreach(_.open())
  }
}

What comes next? Should I use IMain to compile this code? I don't want to use the normal interpreter mode, because the compilation should be "context-free" and not accumulate requests.

I guess what I need to get hold of from the compilation is the binary class file? In that case, serialisation is straightforward (a byte array). How would I then load that class into the runtime and invoke the apply method?

What happens if the code compiles to multiple auxiliary classes? The example above contains a closure _.open(). How do I make sure I "package" all those auxiliary things into one object to serialize and class-load?


Note: Given that Scala 2.11 is imminent and the compiler API has probably changed, I am happy to receive hints on how to approach this problem in Scala 2.11.

asked Feb 23 '14 at 18:02 by 0__


1 Answer

Here is one idea: use a regular Scala compiler instance. Unfortunately, it seems to require files on disk for both input and output, so we use temporary files. The output is zipped up into a JAR, which is stored as a byte array (this is what would go into the hypothetical serialization process). A special class loader then retrieves the classes from the unpacked JAR.
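
The "hypothetical serialization process" could be as small as writing the class name and the jar bytes into one file. A minimal sketch, assuming a record layout of our own choosing; the names saveScript and loadScript and the format (UTF name, Int length, raw bytes) are hypothetical and not part of any library:

```scala
import java.io._

// Hypothetical on-disk format: class name (modified UTF-8), jar length, raw jar bytes.
def saveScript(file: File, className: String, jar: Array[Byte]): Unit = {
  val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))
  try {
    out.writeUTF(className)
    out.writeInt(jar.length)
    out.write(jar)
  } finally out.close()
}

def loadScript(file: File): (String, Array[Byte]) = {
  val in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))
  try {
    val className = in.readUTF()
    val jar       = new Array[Byte](in.readInt())
    in.readFully(jar) // fills the whole array, or throws EOFException if the file is truncated
    (className, jar)
  } finally in.close()
}
```

Storing the length explicitly lets several such records be appended to one stream if needed.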

The following assumes Scala 2.10.3 with the scala-compiler library on the class path:

import scala.tools.nsc
import java.io._
import scala.annotation.tailrec

Wrapping user-provided code in a function class whose synthetic name is incremented for each new fragment:

val packageName = "myapp"

var userCount = 0

def mkFunName(): String = {
  val c = userCount
  userCount += 1
  s"Fun$c"
}

def wrapSource(source: String): (String, String) = {
  val fun = mkFunName()
  val code = s"""package $packageName
                |
                |class $fun extends Function0[Unit] {
                |  def apply(): Unit = {
                |    $source
                |  }
                |}
                |""".stripMargin
  (fun, code)
}
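
Because userCount is a plain var, two threads compiling at the same time could receive the same name. If that matters, an AtomicInteger makes the counter safe. A sketch of that variant (only the counter differs from the code above):

```scala
import java.util.concurrent.atomic.AtomicInteger

val packageName = "myapp"

// getAndIncrement returns the previous value atomically, so names stay unique across threads
val userCount = new AtomicInteger(0)

def mkFunName(): String = s"Fun${userCount.getAndIncrement()}"

def wrapSource(source: String): (String, String) = {
  val fun = mkFunName()
  val code = s"""package $packageName
                |
                |class $fun extends Function0[Unit] {
                |  def apply(): Unit = {
                |    $source
                |  }
                |}
                |""".stripMargin
  (fun, code)
}
```

Note that stripMargin runs after interpolation, so any line of the user's fragment that begins with optional whitespace followed by | would be trimmed as well.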

A function to compile a source fragment and return the byte array of the resulting jar:

/** Compiles source code consisting of a body which is wrapped in a `Function0`
  * apply method, and returns the function's class name (without package) and the
  * raw jar file produced in the compilation.
  */
def compile(source: String): (String, Array[Byte]) = {
  val set             = new nsc.Settings
  val d               = File.createTempFile("temp", ".out")
  d.delete(); d.mkdir()
  set.d.value         = d.getPath
  set.usejavacp.value = true
  val compiler        = new nsc.Global(set)
  val f               = File.createTempFile("temp", ".scala")
  val out             = new BufferedOutputStream(new FileOutputStream(f))
  val (fun, code)     = wrapSource(source)
  out.write(code.getBytes("UTF-8"))
  out.flush(); out.close()
  val run             = new compiler.Run()
  run.compile(List(f.getPath))
  f.delete()

  val bytes = packJar(d)
  deleteDir(d)

  (fun, bytes)
}

def deleteDir(base: File): Unit = {
  base.listFiles().foreach { f =>
    if (f.isFile) f.delete()
    else deleteDir(f)
  }
  base.delete()
}
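
File.listFiles() returns null when a path is not a directory or cannot be read, which would make deleteDir throw a NullPointerException. A java.nio alternative (Java 8+) is sketched below; deleteRecursively is a hypothetical name, not part of the answer:

```scala
import java.io.File
import java.nio.file.{Files, Path}
import scala.collection.mutable.ListBuffer

// Deletes `base` and everything below it.
def deleteRecursively(base: File): Unit = {
  val paths  = ListBuffer.empty[Path]
  val stream = Files.walk(base.toPath) // pre-order: each directory precedes its contents
  try {
    val it = stream.iterator()
    while (it.hasNext) paths += it.next()
  } finally stream.close()
  // Reversed pre-order lists every entry before the directory containing it,
  // so each directory is already empty when Files.delete reaches it.
  paths.reverse.foreach(p => Files.delete(p))
}
```

Unlike File.delete, which returns false on failure, Files.delete throws an IOException, so a failed cleanup is not silently ignored.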

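Note: this doesn't handle compiler errors yet! If compilation fails, compile still returns a jar (most likely without the expected class), and the problem only shows up later when the class cannot be loaded.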

The packJar method uses the compiler output directory and produces an in-memory jar file from it:

// cf. http://stackoverflow.com/questions/1281229
def packJar(base: File): Array[Byte] = {
  import java.util.jar._

  val mf = new Manifest
  mf.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
  val bs    = new java.io.ByteArrayOutputStream
  val out   = new JarOutputStream(bs, mf)

  def add(prefix: String, f: File): Unit = {
    val name0 = prefix + f.getName
    val name  = if (f.isDirectory) name0 + "/" else name0
    val entry = new JarEntry(name)
    entry.setTime(f.lastModified())
    out.putNextEntry(entry)
    if (f.isFile) {
      val in = new BufferedInputStream(new FileInputStream(f))
      try {
        val buf = new Array[Byte](1024)
        @tailrec def loop(): Unit = {
          val count = in.read(buf)
          if (count >= 0) {
            out.write(buf, 0, count)
            loop()
          }
        }
        loop()
      } finally {
        in.close()
      }
    }
    out.closeEntry()
    if (f.isDirectory) f.listFiles.foreach(add(name, _))
  }

  base.listFiles().foreach(add("", _))
  out.close()
  bs.toByteArray
}
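
The same packing can be sketched with java.nio (Java 8+) instead of recursive listFiles calls. This is a variant, not the answer's code, and packDir is a hypothetical name. Entry names are relative to the output directory, always use / as the separator (as the ZIP format requires), and directory entries end with /:

```scala
import java.io.{ByteArrayOutputStream, File}
import java.nio.file.{Files, Path}
import java.util.jar.{Attributes, JarEntry, JarOutputStream, Manifest}

// Packs every file and directory below `base` into an in-memory jar.
def packDir(base: Path): Array[Byte] = {
  val mf = new Manifest
  mf.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
  val bs     = new ByteArrayOutputStream
  val out    = new JarOutputStream(bs, mf)
  val stream = Files.walk(base) // pre-order: a directory comes before its contents
  try {
    val it = stream.iterator()
    while (it.hasNext) {
      val p   = it.next()
      val rel = base.relativize(p).toString.replace(File.separatorChar, '/')
      if (rel.nonEmpty) { // skip `base` itself
        val isDir = Files.isDirectory(p)
        out.putNextEntry(new JarEntry(if (isDir) rel + "/" else rel))
        if (!isDir) out.write(Files.readAllBytes(p))
        out.closeEntry()
      }
    }
  } finally stream.close()
  out.close()
  bs.toByteArray
}
```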

A utility function that takes the byte array obtained from deserialization and creates a map from class names to class byte code:

def unpackJar(bytes: Array[Byte]): Map[String, Array[Byte]] = {
  import java.util.jar._

  val in = new JarInputStream(new ByteArrayInputStream(bytes))
  val b  = Map.newBuilder[String, Array[Byte]]

  @tailrec def loop(): Unit = {
    val entry = in.getNextJarEntry
    if (entry != null) {
      if (!entry.isDirectory) {
        val name  = entry.getName  
        // cf. http://stackoverflow.com/questions/8909743
        val bs  = new ByteArrayOutputStream
        var i   = 0
        while (i >= 0) {
          i = in.read()
          if (i >= 0) bs.write(i)
        }
        val bytes = bs.toByteArray
        b += mkClassName(name) -> bytes
      }
      loop()
    }
  }
  loop()
  in.close()
  b.result()
}

def mkClassName(path: String): String = {
  require(path.endsWith(".class"))
  path.substring(0, path.length - 6).replace("/", ".")
}
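
unpackJar reads each entry one byte at a time with in.read(), which works but is comparatively slow, and the require in mkClassName rejects any non-class resource in the jar. A buffered sketch that skips non-class entries instead (a behavioural change; unpackClasses is a hypothetical name):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.jar.JarInputStream

// Reads every .class entry of an in-memory jar into a map keyed by binary class name.
def unpackClasses(jar: Array[Byte]): Map[String, Array[Byte]] = {
  val in  = new JarInputStream(new ByteArrayInputStream(jar))
  val b   = Map.newBuilder[String, Array[Byte]]
  val buf = new Array[Byte](4096)
  try {
    var entry = in.getNextJarEntry
    while (entry != null) {
      val name = entry.getName
      if (!entry.isDirectory && name.endsWith(".class")) { // other resources are skipped
        val bs = new ByteArrayOutputStream
        var n  = in.read(buf) // -1 marks the end of the current entry
        while (n >= 0) { bs.write(buf, 0, n); n = in.read(buf) }
        b += name.stripSuffix(".class").replace('/', '.') -> bs.toByteArray
      }
      entry = in.getNextJarEntry
    }
  } finally in.close()
  b.result()
}
```

Closure classes such as Fun0$$anonfun$apply$1 simply become additional map entries, which is how the auxiliary classes from the question travel along with the main one.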

A suitable class loader:

class MemoryClassLoader(map: Map[String, Array[Byte]]) extends ClassLoader {
  override protected def findClass(name: String): Class[_] =
    map.get(name).map { bytes =>
      println(s"defineClass($name, ...)")
      defineClass(name, bytes, 0, bytes.length)
    }.getOrElse(super.findClass(name)) // throws ClassNotFoundException
}
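
Because ClassLoader.loadClass asks the parent loader first and only then calls findClass, the override is reached only for classes the application class path does not already contain, such as the freshly compiled ones. To exercise such a loader without a compiler at hand, the sketch below feeds it a hand-assembled class file for a hypothetical class myapp.Empty with no members, and counts how often findClass defines a class (the counter is an addition for illustration):

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

// Same idea as the answer's loader, minus the println, plus a counter.
class MemoryClassLoader(map: Map[String, Array[Byte]]) extends ClassLoader {
  var defined = 0
  override protected def findClass(name: String): Class[_] =
    map.get(name) match {
      case Some(bytes) =>
        defined += 1
        defineClass(name, bytes, 0, bytes.length)
      case None => super.findClass(name) // throws ClassNotFoundException
    }
}

// Bytes of `public class <name> extends java.lang.Object` with no fields or methods
// (not even a constructor): loadable, but not instantiable.
// DataOutputStream.writeUTF emits the u2 length + modified UTF-8 that CONSTANT_Utf8 uses.
def emptyClassBytes(binaryName: String): Array[Byte] = {
  val bs  = new ByteArrayOutputStream
  val out = new DataOutputStream(bs)
  out.writeInt(0xCAFEBABEL.toInt)                              // magic
  out.writeShort(0); out.writeShort(50)                        // minor, major (Java 6)
  out.writeShort(5)                                            // constant pool count (#1..#4)
  out.writeByte(7); out.writeShort(2)                          // #1 Class -> #2
  out.writeByte(1); out.writeUTF(binaryName.replace('.', '/')) // #2 Utf8
  out.writeByte(7); out.writeShort(4)                          // #3 Class -> #4
  out.writeByte(1); out.writeUTF("java/lang/Object")           // #4 Utf8
  out.writeShort(0x0021)                                       // ACC_PUBLIC | ACC_SUPER
  out.writeShort(1); out.writeShort(3)                         // this_class, super_class
  out.writeShort(0); out.writeShort(0)                         // interfaces, fields
  out.writeShort(0); out.writeShort(0)                         // methods, attributes
  out.close()
  bs.toByteArray
}
```

After the first lookup, later lookups return the already-defined class instead of calling defineClass again, which would otherwise fail with a LinkageError for a duplicate definition.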

And a test case which contains additional classes (closures):

val exampleSource =
  """val xs = List("hello", "world")
    |println(xs.map(_.capitalize).mkString(" "))
    |""".stripMargin

def test(fun: String, cl: ClassLoader): Unit = {
  val clName  = s"$packageName.$fun"
  println(s"Resolving class '$clName'...")
  val clazz = Class.forName(clName, true, cl)
  println("Instantiating...")
  val x     = clazz.getDeclaredConstructor().newInstance().asInstanceOf[() => Unit]
  println("Invoking 'apply':")
  x()
}

locally {
  println("Compiling...")
  val (fun, bytes) = compile(exampleSource)

  val map = unpackJar(bytes)
  println("Classes found:")
  map.keys.foreach(k => println(s"  '$k'"))

  val cl = new MemoryClassLoader(map)
  test(fun, cl)   // should call `defineClass`
  test(fun, cl)   // should find cached class
}
answered Sep 22 '22 at 23:09 by 0__