
Using quantile in Flux (Julia) in loss function

I am trying to use quantile inside a loss function passed to train! (for some robustness, as in least trimmed squares). However, quantile mutates an array internally, and Zygote throws the error "Mutating arrays is not supported", which comes from sort!. Below is a simple example (the data are random, so the fit itself is meaningless):

using Flux, StatsBase
xdata = randn(2, 100)   
ydata = randn(100)

model = Chain(Dense(2,10), Dense(10, 1))


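function trimmedLoss(x, y; trimFrac=0.05f0)
    yhat = model(x)
    # vec(yhat) turns the 1×N model output into a length-N vector, so it lines up with y
    absRes = abs.(vec(yhat) .- y)
    trimVal = quantile(absRes, 1.f0 - trimFrac)
    # zero out residuals above the quantile, then average over the retained fraction
    s = sum(ifelse.(absRes .> trimVal, 0.f0, absRes)) / (length(absRes) * (1.f0 - trimFrac))
    #s = sum(absRes)/length(absRes)   # using this and commenting out the two above works (no surprise)
end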

println(trimmedLoss(xdata, ydata)) #works ok

Flux.train!(trimmedLoss, params(model), zip([xdata], [ydata]), ADAM())

println(trimmedLoss(xdata, ydata)) #changed loss?

This is all with Flux 0.10 and Julia 1.2.

Thanks in advance for any hints or workaround!

MR_MPI-BGC asked Jan 16 '20 20:01


1 Answer

Ideally, we'd define a custom adjoint for quantile so that this works out of the box. (Feel free to open an issue to remind us to do this.)

In the meantime, there's a quick workaround. The trouble comes from the in-place sort that quantile performs internally, so if you call quantile(xs, p, sorted=true), which skips that sort, it'll work. This only gives correct results when xs is already sorted, so in general you'll want quantile(sort(xs), p, sorted=true).
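Applied to a trimmed loss like the one in the question, the workaround might look like the sketch below. The helper name trimmed_abs_loss and its trimfrac keyword are made up for illustration, and it works on a plain residual vector so that it runs with only the Statistics standard library. In the real loss, the residuals would come from the model output.

```julia
using Statistics

# Hypothetical helper (not part of Flux or Zygote): trimmed mean of absolute residuals.
# sort (non-mutating) is called explicitly, and sorted=true tells quantile
# to skip its own in-place sort!.
function trimmed_abs_loss(res::AbstractVector; trimfrac=0.05)
    absres = abs.(res)
    cutoff = quantile(sort(absres), 1 - trimfrac, sorted=true)
    # residuals above the cutoff contribute zero to the sum
    kept = ifelse.(absres .> cutoff, zero(eltype(absres)), absres)
    return sum(kept) / (length(absres) * (1 - trimfrac))
end

# The outlier -10.0 is trimmed away: (0.1 + 0.2 + 0.3) / (4 * 0.75) ≈ 0.2
println(trimmed_abs_loss([0.1, -0.2, 0.3, -10.0]; trimfrac=0.25))
```

Whether sort itself is differentiable out of the box depends on the Zygote version, as discussed next.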

Depending on your Zygote version you might also need an adjoint for sort. That one's pretty easy:

julia> using Statistics, Zygote

julia> using Zygote: @adjoint

julia> @adjoint function sort(x)
         p = sortperm(x)
         x[p], x̄ -> (x̄[invperm(p)],)
       end

julia> gradient(x -> quantile(sort(x), 0.5, sorted=true), [1, 2, 3, 3])
([0.0, 0.5, 0.5, 0.0],)

We'll make that built-in in the next Zygote release, but for now if you add that to your script it'll get your code working.
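To see what the pullback in that adjoint does, here is a small sketch in plain Julia with no Zygote involved. The function name sort_pullback and the example numbers are made up for illustration. Each entry of the gradient with respect to the sorted vector is routed back to the position its element had in the original vector.

```julia
# Hypothetical stand-alone version of the pullback used in the sort adjoint above.
# ybar is the gradient with respect to sort(x); the result is the gradient with respect to x.
function sort_pullback(x::AbstractVector, ybar::AbstractVector)
    p = sortperm(x)            # x[p] is the sorted vector
    return ybar[invperm(p)]    # undo the permutation
end

x = [3.0, 1.0, 2.0]
ybar = [10.0, 20.0, 30.0]      # gradients for the 1st, 2nd and 3rd smallest elements
xbar = sort_pullback(x, ybar)
println(xbar)                  # [30.0, 10.0, 20.0]: x[1] = 3.0 is the largest, so it receives 30.0
```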

one-more-minute answered Sep 28 '22 02:09