I want to increment a static array A at location i by x. If it were a mutable array, I would just do A[i] += x. But since it's a StaticArray, I need to create a new one. If I knew the size of A, I would do something like
A = A + @SVector [0,0,x]
and have a branch for each i. But in this case the SVector is a user input, so I only know its size through type information. I would rather not turn my core logic into a generated function just to handle this, so I was hoping there was an easy solution, or maybe this requires an @generated helper function.
Note that this problem is equivalent to creating an @SVector that has value x at location i but is otherwise zero. If there's an easy way to do that, then my problem is solved as well.
The naive approach using array comprehensions would be
julia> k = 4
4
julia> @SVector [i == k ? 1.0 : 0 for i in 1:5]
5-element SVector{5,Float64}:
0.0
0.0
0.0
1.0
0.0
This is a good first step you could take after reading the "Quick Start" part of the StaticArrays.jl README.
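Since the question reduces to building a vector that holds x at position k and zero elsewhere, here is a minimal sketch of that construction with the length taken from a type parameter instead of a literal 1:5 range. The names onehot and onehot_like are hypothetical helpers, not part of StaticArrays.jl; the sketch assumes Julia 1.x, where ntuple accepts a Val(L) instance.

```julia
using StaticArrays

# Hypothetical helper (not part of StaticArrays.jl): returns an SVector of
# length L holding `x` at position `k` and zero(x) everywhere else.
# The length comes from the Val type parameter, not from a literal range.
onehot(::Val{L}, x, k) where {L} =
    SVector(ntuple(i -> ifelse(i == k, x, zero(x)), Val(L)))

# Hypothetical wrapper: reads L from an existing SVector's type,
# so it also works when the vector is user input.
onehot_like(A::SVector{L}, x, k) where {L} = onehot(Val(L), x, k)

e4 = onehot(Val(5), 1.0, 4)        # SVector(0.0, 0.0, 0.0, 1.0, 0.0)
A  = @SVector [10, 20, 30]
B  = A + onehot_like(A, 5, 2)      # SVector(10, 25, 30)
```

Because the length lives in the type, the compiler knows the size of the result, and no branch per index has to be written by hand.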
However, in Julia we care a lot about type stability and generic code, so a more Julian way would be to use
julia> function increment_value(A::SVector{L,T},x,k) where {L,T}
_A = [i == k ? x : zero(x) for i in 1:L]
A+_A
end
julia> A = @SVector [0, 0, 0, 0, 10]
5-element SVector{5,Int64}:
0
0
0
0
10
julia> increment_value(A,5,2)
5-element SVector{5,Int64}:
0
5
0
0
10
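To see where the extra allocation in this version comes from: the comprehension inside increment_value builds an ordinary Vector, not a static array. A quick top-level check, as a sketch reusing the same values as the call above (L = 5, k = 2, x = 5):

```julia
using StaticArrays

L, k, x = 5, 2, 5
# Same comprehension as in increment_value: the result is a plain,
# heap-allocated Vector, even though its length matches the SVector.
_A = [i == k ? x : zero(x) for i in 1:L]

A = @SVector [0, 0, 0, 0, 10]
B = A + _A   # mixed static + dynamic addition still works elementwise
```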
However, a better final answer avoids that extra array allocation and lets the compiler emit branch-free code (friendlier to the CPU pipeline) via the handy ifelse function:
julia> using StaticArrays, BenchmarkTools
julia> function increment_value(A :: SVector{L,T}, x,k) where {T,L}
SVector(ntuple(i->ifelse(i == k, A[i]+x, A[i]), Val(L)))
end
increment_value (generic function with 1 method)
julia> a = @SVector [ 1, 2, 3, 4, 5]
5-element SVector{5,Int64}:
1
2
3
4
5
julia> @benchmark increment_value($a,$3,$5)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
minimum time: 3.178 ns (0.00% GC)
median time: 3.285 ns (0.00% GC)
mean time: 3.293 ns (0.00% GC)
maximum time: 13.620 ns (0.00% GC)
samples: 10000
evals/sample: 1000
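A small sketch that checks the ntuple/ifelse version is type-stable rather than just fast. It assumes Julia 1.x and uses @inferred from the Test standard library, which throws if the call's return type is not inferred to match the actual result type.

```julia
using StaticArrays, Test

# ntuple + ifelse version from above; Val(L) is an instance, the form
# ntuple expects in Julia 1.x.
increment_value(A::SVector{L,T}, x, k) where {L,T} =
    SVector(ntuple(i -> ifelse(i == k, A[i] + x, A[i]), Val(L)))

a = @SVector [1, 2, 3, 4, 5]
# @inferred throws if the return type is not inferred, so this
# doubles as a type-stability check.
b = @inferred increment_value(a, 3, 5)            # SVector(1, 2, 3, 4, 8)

# Any length and element type carried by the input's type works.
c = increment_value(SVector(1.5, 2.5), 0.5, 1)    # SVector(2.0, 2.5)
```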
It is a bit tricky to get all the necessary values for this problem available at compile time. What I've got now is:
@generated updateindex(s::SVector{L,T},j::Type{Val{I}},v) where {L,T,I} =
Expr(:call, :(SVector{L,T}), (ifelse(i==I, :(s[$i]+v), :(s[$i])) for i=1:L)...)
or just to set a coordinate:
@generated setindex(s::SVector{L,T},j::Type{Val{I}},v) where {L,T,I} =
Expr(:call, :(SVector{L,T}), (ifelse(i==I, :v, :(s[$i])) for i=1:L)...)
And this can be used as:
julia> Z = @SVector [1,1,1,1,1];
julia> updateindex(Z,Val{3},4)
5-element SVector{5,Int64}:
1
1
5
1
1
And benchmarked as:
julia> using BenchmarkTools
julia> @btime updateindex($Z,Val{3},4);
2.032 ns (0 allocations: 0 bytes)
The code is minimal:
julia> @code_native updateindex(Z,Val{3},4)
.text
Filename: REPL[13]
pushq %rbp
movq %rsp, %rbp
Source line: 1
vmovups (%rsi), %xmm0
addq 16(%rsi), %rcx
movq 24(%rsi), %rax
movq 32(%rsi), %rdx
vmovups %xmm0, (%rdi)
movq %rcx, 16(%rdi)
movq %rax, 24(%rdi)
movq %rdx, 32(%rdi)
movq %rdi, %rax
popq %rbp
retq
nopl (%rax)
Does this solve the conundrum?
BTW, if there are ways to rewrite this in a more readable form, I'd be happy to see them in the comments (and will update the answer accordingly).
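One possibly more readable way to write the same generator, as a sketch assuming Julia 1.x: take a Val instance such as Val(3) instead of the type Val{3}, and build the argument list with an explicit comprehension before splicing it into the call. The name updateindex_v is hypothetical; the index I is still a compile-time constant, so each (L, T, I) combination gets its own specialized body.

```julia
using StaticArrays

# Hypothetical variant of updateindex that takes Val(3) rather than Val{3}.
@generated function updateindex_v(s::SVector{L,T}, ::Val{I}, v) where {L,T,I}
    # Runs at generation time: build s[1], ..., s[I] + v, ..., s[L]
    args = [i == I ? :(s[$i] + v) : :(s[$i]) for i in 1:L]
    # The returned expression becomes the method body for this (L, T, I)
    return :(SVector{L,T}($(args...)))
end

Z = @SVector [1, 1, 1, 1, 1]
updateindex_v(Z, Val(3), 4)   # SVector(1, 1, 5, 1, 1)
```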
UPDATE
Chris's comment correctly noted that a version with a runtime (non-Val) index j can be made:
@generated setindex(s::SVector{L,T},j,v) where {L,T} =
Expr(:call, :(SVector{L,T}), (:(ifelse($i==j, v, s[$i])) for i=1:L)...)
Demo and low-level code (as the assembly shows, some performance is paid for not knowing which index to update: every element now gets a compare and a conditional move):
julia> setindex(Z,4,3)
5-element SVector{5,Int64}:
1
1
1
3
1
julia> @code_native setindex(Z,4,3)
.text
Filename: REPL[15]
pushq %rbp
movq %rsp, %rbp
Source line: 1
cmpq $1, %rdx
movq (%rsi), %r8
cmoveq %rcx, %r8
cmpq $2, %rdx
movq 8(%rsi), %r9
cmoveq %rcx, %r9
cmpq $3, %rdx
movq 16(%rsi), %r10
cmoveq %rcx, %r10
cmpq $4, %rdx
movq 24(%rsi), %rax
cmoveq %rcx, %rax
cmpq $5, %rdx
cmovneq 32(%rsi), %rcx
movq %r8, (%rdi)
movq %r9, 8(%rdi)
movq %r10, 16(%rdi)
movq %rax, 24(%rdi)
movq %rcx, 32(%rdi)
movq %rdi, %rax
popq %rbp
retq
nopw %cs:(%rax,%rax)
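For comparison, two non-generated alternatives for a runtime index, as sketches. setindex_nt is a hypothetical name that keeps the same (vector, index, value) argument order as the setindex above; ntuple with Val(L) produces one ifelse per element, though I haven't compared its assembly with the listing above. The second line assumes your StaticArrays.jl version provides the non-mutating setindex(vec, x, index) (value before index, the reverse of the function above); check its documentation if the call is not found.

```julia
using StaticArrays

# Hypothetical non-generated equivalent of the runtime-index setindex above:
# same argument order (vector, index j, value v), one ifelse per element.
setindex_nt(s::SVector{L,T}, j, v) where {L,T} =
    SVector{L,T}(ntuple(i -> ifelse(i == j, convert(T, v), s[i]), Val(L)))

Z = @SVector [1, 1, 1, 1, 1]
setindex_nt(Z, 4, 3)                        # SVector(1, 1, 1, 3, 1)

# Assumed API: StaticArrays.setindex(vec, value, index) returns a copy
# with vec[index] replaced; A itself is left unchanged.
A = @SVector [0, 0, 0, 0, 10]
i, x = 2, 5
A2 = StaticArrays.setindex(A, A[i] + x, i)  # SVector(0, 5, 0, 0, 10)
```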