Is there a way in Julia to generalise a pattern like the following?
function compute_sum(xs::Vector{Float64})
    res = 0
    for i in 1:length(xs)
        res += sqrt(xs[i])
    end
    res
end
This computes the square root of each vector element and then sums everything. It is much faster than the "naive" versions using an array comprehension or map, and it also doesn't allocate additional memory:
xs = rand(1000)
julia> @time compute_sum(xs)
0.000004 seconds
676.8372556762225
julia> @time sum([sqrt(x) for x in xs])
0.000013 seconds (3 allocations: 7.969 KiB)
676.837255676223
julia> @time sum(map(sqrt, xs))
0.000013 seconds (3 allocations: 7.969 KiB)
676.837255676223
Unfortunately, the "obvious" generic version performs terribly:
function compute_sum2(xs::Vector{Float64}, fn::Function)
    res = 0
    for i in 1:length(xs)
        res += fn(xs[i])
    end
    res
end
julia> @time compute_sum2(xs, x -> sqrt(x))
0.013537 seconds (19.34 k allocations: 1.011 MiB)
676.8372556762225
Function is an abstract type. So, for example, Vector{Function} is like a Vector{Any} or Vector{Integer}: Julia just can't infer the results.
As a heuristic, Julia avoids automatically specializing on argument type parameters in three specific cases: Type, Function, and Vararg. Julia will always specialize when the argument is used within the method, but not if the argument is just passed through to another function.
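One way to make the specialization explicit is to give the function argument its own type parameter. The following is only a minimal sketch: the name compute_sum3 and the argument order are hypothetical and not part of the original question. In compute_sum2 above, fn is called inside the method, so by the rule just quoted Julia should already specialize on it; the type parameter mainly matters when the function is merely passed through to another call.

```julia
# Hypothetical variant of compute_sum2: the type parameter F asks Julia
# to compile a specialized method for each concrete function type passed in.
function compute_sum3(fn::F, xs::Vector{Float64}) where {F}
    res = 0.0              # start from a Float64 zero so `res` keeps one type
    for x in xs
        res += fn(x)
    end
    return res
end

xs = rand(1000)
compute_sum3(sqrt, xs)         # pass the named function directly
compute_sum3(x -> x^2, xs)     # an anonymous function gets its own specialization
```

Putting the function first also allows the do-block syntax, e.g. compute_sum3(xs) do x ... end.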
The reason is that x -> sqrt(x) is defined as a new anonymous function each time you write the call to compute_sum2 at the REPL, so this triggers a new compilation every time you call it.
If you define it beforehand, e.g. like this:
julia> f = x -> sqrt(x)
then you have:
julia> @time compute_sum2(xs, f) # here you pay compilation cost
0.010053 seconds (19.46 k allocations: 1.064 MiB)
665.2469135020949
julia> @time compute_sum2(xs, f) # here you have already compiled everything
0.000003 seconds (1 allocation: 16 bytes)
665.2469135020949
Note that a natural approach would be to define a named function, like this:
julia> g(x) = sqrt(x)
g (generic function with 1 method)
julia> @time compute_sum2(xs, g)
0.000002 seconds
665.2469135020949
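The same effect can be reproduced in a script. The sketch below times the first and second call with the same function object using @elapsed. The helper name sum_with is a hypothetical stand-in for compute_sum2, and the actual timings depend on the machine and Julia version, so none are shown here.

```julia
# Hypothetical stand-in for compute_sum2 from the question.
function sum_with(xs::Vector{Float64}, fn::Function)
    res = 0.0
    for i in eachindex(xs)
        res += fn(xs[i])
    end
    return res
end

xs = rand(1000)
f = x -> sqrt(x)     # create the anonymous function once and reuse it

t_first  = @elapsed sum_with(xs, f)   # includes compiling sum_with for typeof(f)
t_second = @elapsed sum_with(xs, f)   # reuses the already compiled specialization

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

The first call is expected to be noticeably slower than the second, since only the first one pays the compilation cost.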
You can see that x -> sqrt(x) defines a fresh anonymous function each time it is encountered when you write e.g.:
julia> typeof(x -> sqrt(x))
var"#3#4"
julia> typeof(x -> sqrt(x))
var"#5#6"
julia> typeof(x -> sqrt(x))
var"#7#8"
Note that this is different when the anonymous function is defined inside a function body:
julia> h() = typeof(x -> sqrt(x))
h (generic function with 2 methods)
julia> h()
var"#11#12"
julia> h()
var"#11#12"
julia> h()
var"#11#12"
and you see that this time the anonymous function has the same type every time.
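The same check can be written as a small script. In this sketch, make_sqrt is a hypothetical helper name; the comparisons rely only on typeof.

```julia
# Two separate anonymous-function expressions: each one gets its own type.
f1 = x -> sqrt(x)
f2 = x -> sqrt(x)
println(typeof(f1) === typeof(f2))    # false

# One anonymous-function expression inside a function body:
# every call returns an object of the same type.
make_sqrt() = x -> sqrt(x)
println(typeof(make_sqrt()) === typeof(make_sqrt()))    # true
```

Because methods are compiled per argument type, a function that receives f1 and then f2 is compiled twice, while repeated results of make_sqrt() share one compiled specialization.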
In addition to the excellent response by Bogumil, I would just like to add that a very convenient way of generalizing this is to use the standard functional programming functions such as map, reduce, foldl, etc.
In this case, you're doing a map transformation (namely sqrt) and a reduce (namely +), so you can also achieve the result with mapreduce(sqrt, +, xs). This has essentially no overhead and is comparable to a manual loop in performance.
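A minimal sketch of this approach, assuming xs is a vector of non-negative floats as in the question. For the special case of summation, sum also accepts a function as its first argument.

```julia
xs = rand(1000)

# Apply sqrt to each element and combine the results with +,
# without building an intermediate array.
total = mapreduce(sqrt, +, xs)

# For summation specifically, sum(f, xs) does the same job.
total_sum = sum(sqrt, xs)

# A generator expression also avoids the temporary array that the
# comprehension [sqrt(x) for x in xs] would create.
total_gen = sum(sqrt(x) for x in xs)
```

All three forms should agree up to floating-point rounding; the summation order may differ slightly from that of the manual loop.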
If you have a really complicated series of transformations, you can still write it in a functional style and get very good performance by using the Transducers.jl package.