To embed Scala as a "scripting language", I need to be able to compile text fragments to simple objects, such as Function0[Unit], that can be serialised to and deserialised from disk and which can be loaded into the current runtime and executed.
How would I go about this?
Say for example, my text fragment is (purely hypothetical):
Document.current.elements.headOption.foreach(_.open())
This might be wrapped into the following complete text:
package myapp.userscripts
import myapp.DSL._
object UserFunction1234 extends Function0[Unit] {
  def apply(): Unit = {
    Document.current.elements.headOption.foreach(_.open())
  }
}
What comes next? Should I use IMain to compile this code? I don't want to use the normal interpreter mode, because the compilation should be "context-free" and not accumulate requests.
What I need to get hold of from the compilation is, I guess, the binary class file? In that case, serialisation is straightforward (byte array). How would I then load that class into the runtime and invoke the apply method?
What happens if the code compiles to multiple auxiliary classes? The example above contains a closure, _.open(). How do I make sure I "package" all those auxiliary things into one object to serialise and class-load?
Note: Given that Scala 2.11 is imminent and the compiler API has probably changed, I am happy to receive hints on how to approach this problem on Scala 2.11.
Here is one idea: use a regular Scala compiler instance. Unfortunately it seems to require the use of hard disk files both for input and output. So we use temporary files for that. The output will be zipped up in a JAR which will be stored as a byte array (that would go into the hypothetical serialization process). We need a special class loader to retrieve the class again from the extracted JAR.
The following assumes Scala 2.10.3 with the scala-compiler library on the class path:
import scala.tools.nsc
import java.io._
import scala.annotation.tailrec
Wrapping user-provided code in a function class with a synthetic name, whose counter is incremented for each new fragment:
val packageName = "myapp"
var userCount = 0

def mkFunName(): String = {
  val c = userCount
  userCount += 1
  s"Fun$c"
}

def wrapSource(source: String): (String, String) = {
  val fun  = mkFunName()
  val code = s"""package $packageName
                |
                |class $fun extends Function0[Unit] {
                |  def apply(): Unit = {
                |    $source
                |  }
                |}
                |""".stripMargin
  (fun, code)
}
A function to compile a source fragment and return the byte array of the resulting jar:
/** Compiles source code consisting of a body which is wrapped in a `Function0`
  * apply method, and returns the function's class name (without package) and the
  * raw jar file produced in the compilation.
  */
def compile(source: String): (String, Array[Byte]) = {
  val set = new nsc.Settings
  // createTempFile only provides a unique name; swap the file for a directory
  val d = File.createTempFile("temp", ".out")
  d.delete(); d.mkdir()
  set.d.value = d.getPath
  // compile against the class path of the running JVM
  set.usejavacp.value = true
  val compiler = new nsc.Global(set)
  val f = File.createTempFile("temp", ".scala")
  val out = new BufferedOutputStream(new FileOutputStream(f))
  val (fun, code) = wrapSource(source)
  out.write(code.getBytes("UTF-8"))
  out.flush(); out.close()
  val run = new compiler.Run()
  run.compile(List(f.getPath))
  f.delete()
  // pack all class files (including auxiliary ones) before cleaning up
  val bytes = packJar(d)
  deleteDir(d)
  (fun, bytes)
}

def deleteDir(base: File): Unit = {
  base.listFiles().foreach { f =>
    if (f.isFile) f.delete()
    else deleteDir(f)
  }
  base.delete()
}
Note: This doesn't handle compiler errors yet!
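One way the missing error handling might be approached is to pass the compiler a StoreReporter, which collects messages instead of printing them, and to check it after the run. The following is only a sketch, assuming the Scala 2.10/2.11 compiler API with scala-compiler on the class path; the names CheckedCompile and compileFiles are hypothetical and not part of the code above.

```scala
import scala.tools.nsc
import scala.tools.nsc.reporters.StoreReporter

// Hypothetical helper, not part of the original answer.
object CheckedCompile {
  /** Compiles the given source files into `outDir`.
    * Returns the collected error messages on failure, or `Right(())` on success.
    */
  def compileFiles(paths: List[String], outDir: String): Either[List[String], Unit] = {
    val set = new nsc.Settings
    set.d.value = outDir
    set.usejavacp.value = true
    // StoreReporter keeps all messages in memory instead of printing them
    val reporter = new StoreReporter
    val compiler = new nsc.Global(set, reporter)
    val run = new compiler.Run()
    run.compile(paths)
    if (reporter.hasErrors) {
      // keep only messages with error severity, dropping warnings and infos
      val errors = reporter.infos.toList.collect {
        case info if info.severity == reporter.ERROR => info.msg
      }
      Left(errors)
    } else Right(())
  }
}
```

In compile, the same check could be made right after run.compile, so that the output directory is only packed when no errors were reported.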
The packJar method uses the compiler output directory and produces an in-memory jar file from it:
// cf. http://stackoverflow.com/questions/1281229
def packJar(base: File): Array[Byte] = {
  import java.util.jar._

  val mf = new Manifest
  mf.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
  val bs  = new java.io.ByteArrayOutputStream
  val out = new JarOutputStream(bs, mf)

  def add(prefix: String, f: File): Unit = {
    val name0 = prefix + f.getName
    val name  = if (f.isDirectory) name0 + "/" else name0
    val entry = new JarEntry(name)
    entry.setTime(f.lastModified())
    out.putNextEntry(entry)
    if (f.isFile) {
      val in = new BufferedInputStream(new FileInputStream(f))
      try {
        val buf = new Array[Byte](1024)
        @tailrec def loop(): Unit = {
          val count = in.read(buf)
          if (count >= 0) {
            out.write(buf, 0, count)
            loop()
          }
        }
        loop()
      } finally {
        in.close()
      }
    }
    out.closeEntry()
    if (f.isDirectory) f.listFiles.foreach(add(name, _))
  }

  base.listFiles().foreach(add("", _))
  out.close()
  bs.toByteArray
}
A utility function that takes the byte array found in deserialization and creates a map from class names to class byte code:
def unpackJar(bytes: Array[Byte]): Map[String, Array[Byte]] = {
  import java.util.jar._
  import scala.annotation.tailrec

  val in = new JarInputStream(new ByteArrayInputStream(bytes))
  val b  = Map.newBuilder[String, Array[Byte]]

  @tailrec def loop(): Unit = {
    val entry = in.getNextJarEntry
    if (entry != null) {
      if (!entry.isDirectory) {
        val name = entry.getName
        // cf. http://stackoverflow.com/questions/8909743
        val bs = new ByteArrayOutputStream
        var i = 0
        while (i >= 0) {
          i = in.read()
          if (i >= 0) bs.write(i)
        }
        val bytes = bs.toByteArray
        b += mkClassName(name) -> bytes
      }
      loop()
    }
  }
  loop()
  in.close()
  b.result()
}

def mkClassName(path: String): String = {
  require(path.endsWith(".class"))
  path.substring(0, path.length - 6).replace("/", ".")
}
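The loop above reads each entry one byte at a time, which works but may be slow for larger classes. A buffered variant might look like the sketch below. It uses only JDK classes; the names JarBytes, readFully and readEntries are hypothetical, and the map is keyed by the raw entry path, so mkClassName would still need to be applied to the keys.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
import java.util.jar.JarInputStream
import scala.annotation.tailrec

// Hypothetical helper, not part of the original answer.
object JarBytes {
  /** Copies the rest of the current entry into a byte array, up to 4096 bytes per read. */
  def readFully(in: InputStream): Array[Byte] = {
    val bs  = new ByteArrayOutputStream
    val buf = new Array[Byte](4096)
    @tailrec def loop(): Unit = {
      val count = in.read(buf) // -1 signals the end of the current entry
      if (count >= 0) {
        bs.write(buf, 0, count)
        loop()
      }
    }
    loop()
    bs.toByteArray
  }

  /** Maps each non-directory entry path to its content. */
  def readEntries(jar: Array[Byte]): Map[String, Array[Byte]] = {
    val in = new JarInputStream(new ByteArrayInputStream(jar))
    try {
      val b = Map.newBuilder[String, Array[Byte]]
      @tailrec def loop(): Unit = {
        val entry = in.getNextJarEntry
        if (entry != null) {
          if (!entry.isDirectory) b += entry.getName -> readFully(in)
          loop()
        }
      }
      loop()
      b.result()
    } finally in.close()
  }
}
```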
A suitable class loader:
class MemoryClassLoader(map: Map[String, Array[Byte]]) extends ClassLoader {
  override protected def findClass(name: String): Class[_] =
    map.get(name).map { bytes =>
      println(s"defineClass($name, ...)")
      defineClass(name, bytes, 0, bytes.length)
    } .getOrElse(super.findClass(name)) // throws exception
}
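A ClassLoader created with the no-argument constructor delegates to the system class loader. If the user code refers to application classes (such as the hypothetical myapp.DSL in the question) and the application itself was loaded by a different loader, for example inside a build tool or container, the parent may need to be passed explicitly. A minimal sketch of that idea, with the hypothetical names Loaders, ParentAwareLoader and forApp:

```scala
// Hypothetical variant of the in-memory loader, not part of the original answer.
object Loaders {
  /** Looks up classes in `map` only after the given parent has failed to find them. */
  class ParentAwareLoader(map: Map[String, Array[Byte]], parent: ClassLoader)
    extends ClassLoader(parent) {

    override protected def findClass(name: String): Class[_] =
      map.get(name) match {
        case Some(bytes) => defineClass(name, bytes, 0, bytes.length)
        case None        => throw new ClassNotFoundException(name)
      }
  }

  /** Uses the loader that loaded this object (and thus the application) as parent. */
  def forApp(map: Map[String, Array[Byte]]): ClassLoader =
    new ParentAwareLoader(map, getClass.getClassLoader)
}
```

Because loadClass asks the parent first, library and application classes resolve as usual, and only the compiled fragments are defined from the byte arrays.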
And a test case which contains additional classes (closures):
val exampleSource =
  """val xs = List("hello", "world")
    |println(xs.map(_.capitalize).mkString(" "))
    |""".stripMargin

def test(fun: String, cl: ClassLoader): Unit = {
  val clName = s"$packageName.$fun"
  println(s"Resolving class '$clName'...")
  val clazz = Class.forName(clName, true, cl)
  println("Instantiating...")
  val x = clazz.newInstance().asInstanceOf[() => Unit]
  println("Invoking 'apply':")
  x()
}

locally {
  println("Compiling...")
  val (fun, bytes) = compile(exampleSource)
  val map = unpackJar(bytes)
  println("Classes found:")
  map.keys.foreach(k => println(s"  '$k'"))
  val cl = new MemoryClassLoader(map)
  test(fun, cl) // should call `defineClass`
  test(fun, cl) // should find cached class
}
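The actual serialisation to disk is left open above. Since the result of compile is just a class name and a byte array, one possible approach is to write both into a single file with the JDK's data streams. This is only a sketch; the name ScriptStore and the file layout (name, length, bytes) are assumptions made for illustration.

```scala
import java.io._

// Hypothetical helper, not part of the original answer.
object ScriptStore {
  /** Writes the function's class name followed by the length-prefixed jar bytes. */
  def write(file: File, fun: String, jar: Array[Byte]): Unit = {
    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))
    try {
      out.writeUTF(fun)
      out.writeInt(jar.length)
      out.write(jar)
    } finally out.close()
  }

  /** Reads back the class name and jar bytes in the order `write` stored them. */
  def read(file: File): (String, Array[Byte]) = {
    val in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))
    try {
      val fun = in.readUTF()
      val jar = new Array[Byte](in.readInt())
      in.readFully(jar)
      (fun, jar)
    } finally in.close()
  }
}
```

After reading, the bytes could be passed to unpackJar and the name to test, exactly as with a freshly compiled fragment.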