I am trying to use quantile in a loss function passed to train! (for some robustness, like least trimmed squares), but it mutates the array and Zygote throws the error "Mutating arrays is not supported", coming from sort!. Below is a simple example (the content does not make sense, of course):
using Flux, StatsBase

xdata = randn(2, 100)
ydata = randn(100)
model = Chain(Dense(2, 10), Dense(10, 1))

function trimmedLoss(x, y; trimFrac=0.05f0)  # 5% trim fraction as a Float32 literal
    yhat = model(x)
    absRes = abs.(vec(yhat) .- y)  # vec first so the residuals stay a length-100 vector
    trimVal = quantile(absRes, 1.f0 - trimFrac)
    s = sum(ifelse.(absRes .> trimVal, 0.f0, absRes)) / (length(absRes) * (1.f0 - trimFrac))
    #s = sum(absRes)/length(absRes) # using this and commenting out the two above works (no surprise)
end

println(trimmedLoss(xdata, ydata)) # works ok
Flux.train!(trimmedLoss, params(model), zip([xdata], [ydata]), ADAM())
println(trimmedLoss(xdata, ydata)) # changed loss?
This is all in Flux 0.10 with Julia 1.2.
Thanks in advance for any hints or workarounds!
Ideally, we'd define a custom adjoint for quantile so that this works out of the box. (Feel free to open an issue to remind us to do this.)
In the meantime there's a quick workaround. It's actually the sorting that causes trouble here, so if you do quantile(xs, p, sorted=true) it'll work. Obviously this requires xs to be sorted to get correct results, so you might need to use quantile(sort(xs), ...).
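Applied to the trimmed loss from the question, the workaround might look like the sketch below. It only uses the Statistics standard library, so it runs on its own. The names trimmed_mean_abs and trimfrac are made up here for illustration and are not part of Flux or StatsBase. Whether Zygote differentiates through it without an extra adjoint for sort depends on your Zygote version.

```julia
using Statistics  # stdlib; provides quantile and its `sorted` keyword

# Hypothetical helper: trimmed mean of absolute residuals, using the
# sort-then-quantile workaround so quantile never calls sort! internally.
function trimmed_mean_abs(absres::AbstractVector; trimfrac=0.05)
    # sort returns a new array, so absres itself is never mutated
    trimval = quantile(sort(absres), 1 - trimfrac, sorted=true)
    # residuals above the cutoff contribute zero to the sum
    kept = ifelse.(absres .> trimval, zero(eltype(absres)), absres)
    return sum(kept) / (length(absres) * (1 - trimfrac))
end

absres = [0.1, 0.2, 0.3, 0.4, 10.0]
loss = trimmed_mean_abs(absres; trimfrac=0.2)  # the outlier 10.0 is trimmed away
```

In the question's loss, this would replace the quantile line: compute the cutoff from sort(absRes) with sorted=true and leave the rest unchanged.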
Depending on your Zygote version you might also need an adjoint for sort. That one's pretty easy:
julia> using Zygote: @adjoint
julia> @adjoint function sort(x)
         p = sortperm(x)
         x[p], x̄ -> (x̄[invperm(p)],)
       end
julia> gradient(x -> quantile(sort(x), 0.5, sorted=true), [1, 2, 3, 3])
([0.0, 0.5, 0.5, 0.0],)
We'll make that built-in in the next Zygote release, but for now if you add that to your script it'll get your code working.
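If it's not obvious why the pullback indexes with invperm(p): sorted[i] is x[p[i]], so the gradient that arrives for sorted[i] has to be routed back to position p[i] of x. Here is a small check of that in plain Julia, without Zygote, comparing the hand-written pullback against central finite differences. The helper sort_with_pullback, the weights w and the test function f are all made up for this illustration.

```julia
# Forward pass and pullback written out by hand, mirroring the sort adjoint.
function sort_with_pullback(x)
    p = sortperm(x)
    # sorted[i] == x[p[i]], so the gradient for sorted[i] belongs to x[p[i]];
    # indexing x̄ with invperm(p) performs exactly that scatter.
    return x[p], x̄ -> x̄[invperm(p)]
end

# Hypothetical test function: a weighted sum of the sorted values.
w = [1.0, 10.0, 100.0]
f(x) = sum(w .* sort(x))

x = [3.0, 1.0, 2.0]
sorted, back = sort_with_pullback(x)
grad = back(w)  # the gradient of f with respect to the sorted values is w

# Central finite differences as an independent check.
h = 1e-6
fd = map(eachindex(x)) do i
    e = zeros(length(x)); e[i] = h
    (f(x .+ e) - f(x .- e)) / (2h)
end
```

The largest entry of x picks up the largest weight and the smallest entry the smallest, and fd agrees with grad up to finite-difference error.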