 

How can I see the code Zygote generates for a differentiated function?

Tags:

julia

I have written a simple function in a .jl file that I can successfully differentiate using Zygote.forward. However, I am new to Julia and don't understand how to see the generated source code for the differentiated function. I've tried things like @code_lowered Zygote.forward(maxPool, [1.0, 2.0]) and @code_lowered Zygote.forward(maxPool), but they just show me the call to forward itself.

How can I see the code that Zygote generates for the forward and reverse passes?

using Zygote

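# Length of a 1-D array (equivalent to size(v, 1)).
function size1d(v)
    return size(v)[1]
end

# Returns the larger of a and b; ties go to b.
# Note: this defines a new `max` in Main that shadows Base.max.
function max(a, b)
    if a > b
        a
    else
        b
    end
end

# 1-D max pooling with window 2 and stride 2:
# [v1, v2, v3, v4] -> [max(v1, v2), max(v3, v4)]
function maxPool(v)
    return [max(v[2 * i - 1], v[2 * i])
            for i in 1:div(size1d(v), 2)]
end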

v = [1.0, 2.0, 3.0, 4.0]
df = [20.0, 30.0]

println("maxPool(v):")
println(maxPool(v))
println()

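println("maxAdjoint:")
# Zygote.forward(f, args...) returns (f(args...), back), where back maps an
# output sensitivity to a tuple of input sensitivities.
# In Zygote 0.4 and later, `forward` is named `pullback`.
maxAdjoint = Zygote.forward(max, 3.0, 4.0)[2]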
println(maxAdjoint(1.0))
println()

println("maxPoolAdjoint:")
maxPoolAdjoint = Zygote.forward(maxPool, v)[2]
println(maxPoolAdjoint(df))

Tom Ellis asked Apr 01 '26 09:04


1 Answer

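Zygote has its own macro, Zygote.@code_adjoint, for showing the lowered adjoint code, i.e., the code that computes the gradient of a function in reverse mode: a forward pass that records a pullback for each call, followed by a reverse pass that runs those pullbacks. Despite its name, Zygote.forward is part of reverse mode (it runs the forward pass and returns the pullback), so this covers what you asked about. I'm not sure about forward-mode AD, though.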

Here's a simple example in reverse mode:

julia> using Zygote

julia> f(x) = 2x + 1
f (generic function with 1 method)

julia> @code_lowered f(1)
CodeInfo(
1 ─ %1 = 2 * x
│   %2 = %1 + 1
└──      return %2
)

julia> Zygote.@code_adjoint f(1)
Zygote.Adjoint(1: (%3, %4 :: Zygote.Context, %1, %2)
  %5 = Zygote._forward(%4, Main.:*, 2, %2)
  %6 = Base.getindex(%5, 1)
  %7 = Base.getindex(%5, 2)
  %8 = Zygote._forward(%4, Main.:+, %6, 1)
  %9 = Base.getindex(%8, 1)
  %10 = Base.getindex(%8, 2)
  return %9
, 1: (%1)
  %2 = (@10)(%1)
  %3 = Zygote.gradindex(%2, 2)
  %4 = (@7)(%3)
  %5 = Zygote.gradindex(%4, 3)
  %6 = Zygote.tuple(nothing, %5)
  return %6
)
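How to read this output: the first block is the forward pass. It calls Zygote._forward on each primitive (* and then +) and keeps both the result (%6, %9) and a pullback closure (%7, %10). The second block is the reverse pass: it feeds the incoming sensitivity through those pullbacks in reverse order, using gradindex to pick the slot for the relevant argument. Slot 1 of each gradient tuple is for the function being called, which is why x is slot 3 for *(2, x) and the final tuple starts with nothing (for f itself). As a rough sketch, the same structure in plain Julia looks like this. It is a hand-written analogue for illustration, not code Zygote actually emits, and the names mul_pullback, add_pullback and f_pullback are hypothetical.

```julia
# Hand-written analogue of the adjoint printed for f(x) = 2x + 1.
# Illustration only: these names are hypothetical, not Zygote internals.

# Pullback of a * b. Slot 1 of the returned tuple is the gradient with
# respect to the function `*` itself (nothing), mimicking Zygote's layout.
function mul_pullback(a, b)
    y = a * b
    back(dy) = (nothing, dy * b, dy * a)
    return y, back
end

# Pullback of a + b, same layout.
function add_pullback(a, b)
    y = a + b
    back(dy) = (nothing, dy, dy)
    return y, back
end

function f_pullback(x)
    # Forward pass: mirrors %5 to %10 in the @code_adjoint output.
    y1, back_mul = mul_pullback(2, x)   # %6, %7
    y2, back_add = add_pullback(y1, 1)  # %9, %10
    # Reverse pass: mirrors the second block.
    function back(dy)
        g_add = back_add(dy)   # %2 = (@10)(%1)
        dy1 = g_add[2]         # %3 = gradindex(%2, 2): slot for y1
        g_mul = back_mul(dy1)  # %4 = (@7)(%3)
        dx = g_mul[3]          # %5 = gradindex(%4, 3): slot for x
        return (nothing, dx)   # %6 = tuple(nothing, %5): (df, dx)
    end
    return y2, back
end

y, back = f_pullback(1)
println(y)        # 3
println(back(1))  # (nothing, 2)
```

The derivative 2 here matches the constant that the LLVM code for f'(1) returns.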

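The length and apparent complexity of the lowered adjoint code printed by @code_adjoint might suggest that the gradient is slow, but we can check the LLVM code to confirm that it all gets optimized away. (Zygote overloads the postfix ' on functions, so f' is the derivative of f; here it compiles down to returning the constant 2.)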

julia> @code_llvm f'(1)

;  @ /Users/mason/.julia/packages/Zygote/SAZMM/src/compiler/interface.jl:50 within `#34'
define i64 @"julia_#34_18250"(i64) {
top:
  ret i64 2
}
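Applied to the question's maxPool, the pullback returned by Zygote.forward(maxPool, v)[2] has to route each output sensitivity back to whichever input won the comparison in max. Here is a hand-written sketch of that forward/reverse pair in plain Julia. It illustrates what the generated code computes, not what Zygote literally emits, and the names max2_pullback and maxpool_pullback are hypothetical.

```julia
# Hand-written forward/reverse pair for the question's maxPool.
# Hypothetical names; an illustration of the computation, not Zygote output.

# Same branch as the question's max: ties go to b.
function max2_pullback(a, b)
    y = a > b ? a : b
    back(dy) = a > b ? (dy, zero(dy)) : (zero(dy), dy)
    return y, back
end

function maxpool_pullback(v::AbstractVector)
    n = div(length(v), 2)
    # Forward pass: keep each pair's value and its pullback.
    results = [max2_pullback(v[2i - 1], v[2i]) for i in 1:n]
    y = [r[1] for r in results]
    backs = [r[2] for r in results]
    # Reverse pass: scatter each output sensitivity back to its pair.
    function back(dy)
        dv = zeros(promote_type(eltype(v), eltype(dy)), length(v))
        for i in 1:n
            da, db = backs[i](dy[i])
            dv[2i - 1] += da
            dv[2i] += db
        end
        return (dv,)  # one entry per argument of maxPool
    end
    return y, back
end

v = [1.0, 2.0, 3.0, 4.0]
y, back = maxpool_pullback(v)
println(y)                   # [2.0, 4.0]
println(back([20.0, 30.0]))  # ([0.0, 20.0, 0.0, 30.0],)
```

Only the positions that won each comparison receive a nonzero gradient; everything else stays zero.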
Mason answered Apr 03 '26 15:04


