How can I write a function which will check if the return type is statically inferable for each called method?

Tags:

julia

I want to write a function that throws an error if Julia is not able to infer a concrete return type for the function. How can I do this without any runtime overhead?

Mason asked Sep 24 '19 00:09


1 Answer

One way to do this is with a generated function. For instance, suppose the function in question was

f(x) = x + (rand(Bool) ? 1.0 : 1)

We can instead write

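_f(x) = x + (rand(Bool) ? 1.0 : 1)
@generated function f(x)
    # Inside the generator, `x` is bound to the argument's *type*, not its value
    out_type = Core.Compiler.return_type(_f, Tuple{x})
    if !isconcretetype(out_type)
        error("$f($x) does not infer to a concrete type")
    end
    # The returned expression becomes the body of the compiled method
    :(_f(x))
end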

Now we can test this out at the REPL. Floating-point inputs are fine, but integers error:

julia> f(1.0)
2.0

julia> f(1)
ERROR: f(Int64) does not infer to a concrete type
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] #s28#4(::Any, ::Any) at ./REPL[5]:4
 [3] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:524
 [4] top-level scope at REPL[8]:1

Because of the way we used the generated function, the type checking and error throwing happen only at compile time, so we pay no runtime cost for them.
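To see exactly what the generator is checking, you can query inference for `_f` directly at the REPL. This is a minimal sketch using the same `Core.Compiler.return_type` call as the generated function; keep in mind that `Core.Compiler.return_type` is an internal compiler function rather than a documented public API, so its results (especially for non-concrete cases) may differ between Julia versions.

```julia
# Same function as above, whose inferability we want to check
_f(x) = x + (rand(Bool) ? 1.0 : 1)

# What the generator sees for f(1.0): both branches produce a Float64
float_T = Core.Compiler.return_type(_f, Tuple{Float64})

# What the generator sees for f(1): Int + 1.0 is a Float64, but Int + 1 is an Int
int_T = Core.Compiler.return_type(_f, Tuple{Int})

println(float_T, " is concrete: ", isconcretetype(float_T))
println(int_T, " is concrete: ", isconcretetype(int_T))
```

The first query should come back as `Float64`, while the second is a small `Union` (or possibly something wider), which is why `isconcretetype` rejects it.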

If the above looks like too much boilerplate for you, we can write a macro that automatically generates the inner function and the generated function for arbitrary function signatures:

using MacroTools: splitdef, combinedef

strip_type_asserts(ex::Expr) = ex.head == :(::) ? ex.args[1] : ex
strip_type_asserts(s) = s

macro checked(fdef)
    d = splitdef(fdef)

    f = d[:name]
    args = d[:args]
    whereparams = d[:whereparams]

    d[:name] = gensym()
    shadow_fdef = combinedef(d)

    args_stripped = strip_type_asserts.(args)

    quote
        $shadow_fdef
        @generated function $f($(args...)) where {$(whereparams...)}
            # Inside the generator, each argument name is bound to that argument's type
            T = Tuple{$(args_stripped...)}
            shadowf = $(d[:name])
            out_type = Core.Compiler.return_type(shadowf, T)
            if !isconcretetype(out_type)
                f = $f
                # join also handles zero-argument functions (Tuple{} has no parameters)
                sig = join(T.parameters, ", ")
                error("$f($sig) does not infer to a concrete type")
            end
            args = $args_stripped
            #Core.println("statically inferred return type was $out_type")
            # The returned expression is the method body: forward the call to the shadow function
            :($(shadowf)($(args...)))
        end
    end |> esc
end
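If the macro is hard to follow, it may help to see roughly what it produces. For `@checked g(x, y) = x + (rand(Bool) ? 1.0 : 1)*y`, the expansion is approximately equivalent to the hand-written version below. This is a simplified sketch: the real shadow function gets a `gensym` name, for which `g_shadow` is a hypothetical stand-in, and the error message is built inline here.

```julia
# Shadow function holding the original body (the macro uses a gensym'd name instead)
g_shadow(x, y) = x + (rand(Bool) ? 1.0 : 1) * y

@generated function g(x, y)
    # Here x and y are bound to the argument types, e.g. T == Tuple{Int64, Float64}
    T = Tuple{x, y}
    out_type = Core.Compiler.return_type(g_shadow, T)
    if !isconcretetype(out_type)
        error("g($(join(T.parameters, ", "))) does not infer to a concrete type")
    end
    # Forward the call to the shadow function
    :(g_shadow(x, y))
end
```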

Now, at the REPL, we just have to annotate a function definition with @checked:

julia> @checked g(x, y) = x + (rand(Bool) ? 1.0 : 1)*y
g (generic function with 1 method)

julia> g(1, 2.0)
3.0

julia> g(1, 2)
ERROR: g(Int64, Int64) does not infer to a concrete type
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] #s28#5(::Any, ::Any, ::Any) at ./REPL[11]:22
 [3] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:524
 [4] top-level scope at REPL[14]:1

Edit: It's been pointed out in the comments that I am violating one of the 'rules' for using generated functions here. The check runs only when the generator is invoked, and Julia does not re-run the generator if someone later redefines a function that the @checked function relies on, so the result can be silently invalidated. For example:

julia> g(x) = x + 1;

julia> @checked f(x) = g(x) + 1;

julia> f(1) # shouldn't error
3

julia> g(x) = rand(Bool) ? 1.0 : 1
g (generic function with 1 method)

julia> f(1) # Should error but doesn't!!!
2.0

julia> f(1)
2

So be warned: if you use something like this interactively, be careful about redefining functions you rely on. If, for whatever reason, you decide to use this macro in a package, be warned that anyone committing type piracy on the functions you call can silently invalidate your type checking.

If someone wanted to apply this technique to important code, I would suggest either reconsidering or putting serious thought into how to make it safer. If you have any ideas for making it safer, I'd love to hear them! Perhaps there are tricks to force recompilation of the function every time a method it depends on changes.
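One possible direction, offered as an untested suggestion rather than a verified fix: do the check in an ordinary function instead of a generator. As far as I understand, Julia's inference has special handling for calls to `Core.Compiler.return_type` with known argument types, so it can often fold such a check to a constant while still recording which methods it depends on; a redefinition like the `g` example above would then trigger recompilation instead of leaving a stale result. Whether the check is really folded away (and hence free at runtime) depends on the Julia version, so inspect it with `@code_typed`. The name `checked_f` is hypothetical, and `_f` is the function from the first example.

```julia
_f(x) = x + (rand(Bool) ? 1.0 : 1)

# Hypothetical ordinary-function variant of the check (no @generated).
# If inference constant-folds the return_type call, the branch below is resolved
# at compile time; otherwise the check runs on every call.
function checked_f(x)
    out_type = Core.Compiler.return_type(_f, Tuple{typeof(x)})
    if !isconcretetype(out_type)
        error("checked_f($(typeof(x))) does not infer to a concrete type")
    end
    return _f(x)
end
```

With this definition, `checked_f(1.0)` returns `2.0` and `checked_f(1)` throws; `using InteractiveUtils; @code_typed checked_f(1.0)` shows whether the check survived into the compiled code.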

Mason answered Oct 19 '22 15:10