I want to write a function that throws an error if Julia is not able to infer a concrete return type for it. How can I do this without any runtime overhead?
One way to do this is with a generated function. For instance, suppose the function in question was
f(x) = x + (rand(Bool) ? 1.0 : 1)
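Before rewriting it, it helps to see what the compiler actually infers for this f. A minimal sketch using the reflection function Base.return_types (it exists in Base but is not exported; the exact way the Union is printed may differ between Julia versions):

```julia
f(x) = x + (rand(Bool) ? 1.0 : 1)

# Base.return_types gives one inferred return type per matching method; `only` unwraps it.
T_float = only(Base.return_types(f, (Float64,)))  # Float64 + Float64 and Float64 + Int are both Float64
T_int   = only(Base.return_types(f, (Int,)))      # Int + Float64 is Float64, but Int + Int is Int

println(T_float, " concrete? ", isconcretetype(T_float))
println(T_int, " concrete? ", isconcretetype(T_int))
```

For Int arguments the inferred type is a Union of Float64 and Int64, which is not concrete; that is exactly the case the check below should reject.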
We can instead write
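_f(x) = x + (rand(Bool) ? 1.0 : 1)  # the original implementation, renamed
@generated function f(x)
    # In the generator, `x` is the argument's type (e.g. Int64), not its value.
    out_type = Core.Compiler.return_type(_f, Tuple{x})
    if !isconcretetype(out_type)
        error("$f($x) does not infer to a concrete type")
    end
    # The code that runs at call time is just a call to _f.
    :(_f(x))
end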
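The generator can pass Tuple{x} to Core.Compiler.return_type because, inside a @generated function body, each argument name is bound to that argument's type rather than its value. A tiny sketch that makes this visible by splicing whatever x is bound to into the generated code as a constant; argtype is a hypothetical name used only for illustration:

```julia
# In the generator, `x` is the *type* of the argument (e.g. Int64), not its value.
# Returning :($x) makes the generated body simply that type object, as a constant.
@generated function argtype(x)
    return :($x)
end

println(argtype(1))    # the generator saw Int64
println(argtype(2.5))  # the generator saw Float64
```

The generator here is pure (it only looks at its input type), which is what generated functions are supposed to be.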
Now we can test this out at the REPL. Floating-point inputs are fine, but integers throw an error:
julia> f(1.0)
2.0
julia> f(1)
ERROR: f(Int64) does not infer to a concrete type
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] #s28#4(::Any, ::Any) at ./REPL[5]:4
[3] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:524
[4] top-level scope at REPL[8]:1
Because the check lives in the generator, the type inference and the error happen only at compile time (when f is compiled for a given argument type), so we pay no runtime cost for this.
If the above looks like too much boilerplate for you, we can write a macro that automatically generates the inner function and the generated function for arbitrary function signatures:
using MacroTools: splitdef, combinedef
# Turn an argument expression like `x::Int` into just `x`; leave anything else alone.
strip_type_asserts(ex::Expr) = ex.head == :(::) ? ex.args[1] : ex
strip_type_asserts(s) = s
macro checked(fdef)
    d = splitdef(fdef)
    f = d[:name]
    args = d[:args]
    whereparams = d[:whereparams]
    # The real body goes into a hidden "shadow" function with a gensym'd name.
    d[:name] = gensym()
    shadow_fdef = combinedef(d)
    args_stripped = strip_type_asserts.(args)
    quote
        $shadow_fdef
        @generated function $f($(args...)) where {$(whereparams...)}
            # Inside the generator, each argument name is bound to that argument's type.
            T = Tuple{$(args_stripped...)}
            shadowf = $(d[:name])
            out_type = Core.Compiler.return_type(shadowf, T)
            if !isconcretetype(out_type)
                f = $f
                sig = join(T.parameters, ", ")  # e.g. "Int64, Int64"; also works with zero arguments
                error("$f($sig) does not infer to a concrete type")
            end
            args = $args_stripped
            # The code that runs at call time is only a call to the shadow function.
            :($(shadowf)($(args...)))
        end
    end |> esc
end
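For reference, this is what strip_type_asserts does to the argument expressions that splitdef returns. A minimal self-contained check (it repeats the two-line helper so it runs without MacroTools):

```julia
strip_type_asserts(ex::Expr) = ex.head == :(::) ? ex.args[1] : ex
strip_type_asserts(s) = s

args = [:(x::Int), :y, :(z::AbstractVector{T})]
stripped = strip_type_asserts.(args)  # drop each `::T` annotation, keep only the argument name
println(stripped)
```

Note that an anonymous argument such as `::Int` has no name to keep, so the helper returns the type expression itself; the macro as written does not handle that case.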
Now, at the REPL, we just have to annotate a function definition with @checked:
julia> @checked g(x, y) = x + (rand(Bool) ? 1.0 : 1)*y
f (generic function with 2 methods)
julia> g(1, 2.0)
3.0
julia> g(1, 2)
ERROR: g(Int64, Int64) does not infer to a concrete type
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] #s28#5(::Any, ::Any, ::Any) at ./REPL[11]:22
[3] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:524
[4] top-level scope at REPL[14]:1
Edit: It's been pointed out in the comments that I am violating one of the 'rules' for using generated functions here (a generator is not supposed to observe non-constant global state, and the method table that inference consults is such state): what happens at compile time in the generated function can be silently invalidated if someone redefines a function that the @checked function relies on. For example:
julia> g(x) = x + 1;
julia> @checked f(x) = g(x) + 1;
julia> f(1) # shouldn't error
3
julia> g(x) = rand(Bool) ? 1.0 : 1
g (generic function with 1 method)
julia> f(1) # Should error but doesn't!!!
2.0
julia> f(1)
2
So be warned: if you use something like this interactively, be careful about redefining functions you rely on. If, for whatever reason, you decide to use this macro in a package, be aware that anyone committing type piracy on a function you depend on can silently invalidate your type checking.
If someone were to apply this technique to important code, I would suggest either reconsidering or putting some serious thought into how to make it safer. If you have any ideas on making it safer, I'd love to hear them! Perhaps there are some tricks to force recompilation of the function every time a dependent method is changed.