Clone a function in Julia

Tags:

julia

I want to overwrite a function in Julia using its old definition. It seems the way to do this would be to clone the function and overwrite the original using the copy — something like the following. However, it appears deepcopy(f) just returns a reference to f, so this doesn't work.

f(x) = x
f_old = deepcopy(f)
f(x) = 1 + f_old(x)

How can I clone a function?
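For context on why the attempt above fails: a function defined with `f(x) = x` is a generic function whose object is a field-less singleton, so `deepcopy` has nothing to copy and hands back the very same object. A minimal check in plain Julia (no packages assumed):

```julia
f(x) = x
f_old = deepcopy(f)

# `typeof(f)` is a singleton type with no fields, so deepcopy returns `f` itself.
@show f_old === f
@show fieldcount(typeof(f))

# This replaces the only method of the one generic function that both names
# refer to. Calling f(1) now would make f_old(1) dispatch to this same method,
# recursing until a StackOverflowError, so we do not call it here.
f(x) = 1 + f_old(x)
@show length(methods(f))
```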

Background: I'm interested in writing a macro @override that allows me to override functions pointwise (or maybe even piecewise).

fib(n::Int) = fib(n-1) + fib(n-2)
@override fib(0) = 1
@override fib(1) = 1

This particular example would be slow and could be made more efficient using @memoize. There may be good reasons not to do this, but there may also be situations in which one does not know a function fully when it is defined and overriding is necessary.
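For comparison, here is a hand-rolled memoized version in plain Julia. It does not use Memoize.jl's `@memoize`; `fib_memo` and `FIB_CACHE` are hypothetical names. Seeding the cache with the base cases up front acts as a simple pointwise override:

```julia
# Hypothetical cache: the base cases are seeded ("overridden") before the
# recursive rule is ever consulted.
const FIB_CACHE = Dict{Int,Int}(0 => 1, 1 => 1)

function fib_memo(n::Int)
    haskey(FIB_CACHE, n) && return FIB_CACHE[n]   # cached or overridden value
    n < 0 && throw(DomainError(n, "fib_memo is only defined for n >= 0"))
    v = fib_memo(n - 1) + fib_memo(n - 2)         # the general recursive rule
    FIB_CACHE[n] = v                              # remember it for next time
    return v
end

@show fib_memo(15)
```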

Asked Oct 15 '19 18:10 by Luke Burns




2 Answers

We can do this using IRTools.jl.

(Note: on newer versions of IRTools, you may need to use IRTools.Inner.code_ir instead of IRTools.code_ir.)

using IRTools

fib(n::Int) = fib(n-1) + fib(n-2)

const fib_ir  = IRTools.code_ir(fib, Tuple{Int})
const fib_old = IRTools.func(fib_ir)

fib(n::Int) = n < 2 ? 1 : fib_old(fib, n)

julia> fib(10)
89

What we did there was capture the intermediate representation (IR) of the function fib and then rebuild it into a new function, which we called fib_old. Then we were free to overwrite the definition of fib in terms of fib_old! Notice that since fib_old was defined as recursively calling fib, not fib_old, there's no stack overflow when we call fib(10).
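Conceptually, the rebuilt function behaves roughly like the hand-written stand-in below. This is a sketch in plain Julia that does not use IRTools; the name `self` is hypothetical. The copied body takes the function object as an extra first argument and keeps calling the global `fib`:

```julia
# Hypothetical stand-in for `IRTools.func(fib_ir)`: the old body, with the
# function object as an extra first argument (unused here, since `fib` is
# not a closure). Its recursive calls still go to the *global* `fib`.
fib_old(self, n::Int) = fib(n - 1) + fib(n - 2)

# The new definition handles the base cases and defers to the old body,
# whose recursion re-enters this new `fib`.
fib(n::Int) = n < 2 ? 1 : fib_old(fib, n)

@show fib(10)
```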

The other thing to notice is that when we called fib_old, we wrote fib_old(fib, n) instead of fib_old(n). This is due to how IRTools.func works.

According to Mike Innes on Slack:

In Julia IR, all functions take a hidden extra argument that represents the function itself. The reason for this is that closures are structs with fields, which you need access to in the IR.
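To illustrate what "closures are structs with fields" means, here is a small check in plain Julia (`make_adder` is a hypothetical example function). The captured variable shows up as a field of the closure object, which is why the IR needs the function itself as an argument:

```julia
# A closure is an instance of a compiler-generated struct whose fields
# hold the variables it captures.
make_adder(a) = x -> x + a
add3 = make_adder(3)

@show fieldnames(typeof(add3))   # the captured variable `a` is a field
@show getfield(add3, :a)
@show add3(4)

# A top-level generic function captures nothing, so its type has no fields.
@show fieldcount(typeof(make_adder))
```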

Here's an implementation of your @override macro with a slightly different syntax:

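using MacroTools: splitdef, @capture
using IRTools

# Build the Tuple type signature (e.g. Tuple{Int}) of the method defined by fdef
function _get_type_sig(fdef)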
    d = splitdef(fdef)
    types = []
    for arg in d[:args]
        if arg isa Symbol
            push!(types, :Any)
        elseif @capture(arg, x_::T_) 
            push!(types, T)
        else
            error("unsupported argument form: $arg")
        end
    end
    if isempty(d[:whereparams])
        :(Tuple{$(types...)})
    else
        :((Tuple{$(types...)} where {$(d[:whereparams]...)}).body)
    end
end

macro override(cond, fdef)
    d = splitdef(fdef)
    shadowf = gensym()
    sig = _get_type_sig(fdef)
    f = d[:name]
    quote
        const $shadowf = IRTools.func(IRTools.code_ir($(d[:name]), $sig))
        function $f($(d[:args]...)) where {$(d[:whereparams]...)}
            if $cond
                $(d[:body])
            else
                $shadowf($f, $(d[:args]...))
            end
        end
    end |> esc
end
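The `_get_type_sig` helper turns a method header into the `Tuple` type that `IRTools.code_ir` needs. A pure-Base sketch of the same idea for simple headers follows. It assumes no `where` clauses, keyword arguments, or defaults, it does not use MacroTools, and `type_sig` is a hypothetical name:

```julia
# Hypothetical pure-Base analogue of `_get_type_sig` for a call expression
# such as :(fib(n::Int)); untyped arguments become `Any`.
function type_sig(call::Expr)
    call.head === :call || error("expected a call expression, got $call")
    types = Any[]
    for arg in call.args[2:end]
        if arg isa Symbol
            push!(types, :Any)                  # untyped argument
        elseif arg isa Expr && arg.head === :(::) && length(arg.args) == 2
            push!(types, arg.args[2])           # `x::T` contributes T
        else
            error("unsupported argument form: $arg")
        end
    end
    return :(Tuple{$(types...)})                # still an expression; eval it
end

sig = eval(type_sig(:(fib(n::Int, m))))
@show sig
```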

Now one can type

fib(n::Int) = fib(n-1) + fib(n-2)
@override n < 2 fib(n::Int) = 1

julia> fib(10)
89

The best part is that this is nearly as fast (at runtime, not compile time!) as if we had written the conditions into the original function: in the benchmark below, about 4.2 μs versus 3.0 μs for n = 15.

using BenchmarkTools

fib2(n::Int) = n < 2 ? 1 : fib2(n-1) + fib2(n-2)

julia> @btime fib($(Ref(15))[])
  4.239 μs (0 allocations: 0 bytes)
987

julia> @btime fib2($(Ref(15))[])
  3.022 μs (0 allocations: 0 bytes)
987
Answered Sep 16 '22 19:09 by Mason


I really don't see why you'd want to do this (there must be a better way to get what you want!).

Nonetheless, although not exactly equivalent, you can get what you want by using anonymous functions:

julia> f = x->x                    
#3 (generic function with 1 method)

julia> f_old = deepcopy(f)         
#3 (generic function with 1 method)

julia> f = x->1+f_old(x)           
#5 (generic function with 1 method)

julia> f(4)                        
5                                  
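This works because `f` here is an ordinary global variable bound to an anonymous function, so the `deepcopy` is not strictly needed: binding a second name is enough, and rebinding `f` afterwards leaves that name pointing at the old function. A minimal sketch:

```julia
f = x -> x             # f is a non-const global bound to an anonymous function
f_old = f              # a second name for the same function object; no copy made
f = x -> 1 + f_old(x)  # rebinding f leaves f_old pointing at the old function

@show f_old(4)
@show f(4)
```

One caveat: the new `f` looks up the global `f_old` every time it runs, so reassigning `f_old` later would change `f` as well. Writing `f = let g = f_old; x -> 1 + g(x) end` captures the current value instead.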
Answered Sep 18 '22 19:09 by carstenbauer