 

Defining function for any array of integers

I want to define a function that takes as input any two-dimensional array whose elements are integers (and only integers). Although I know I don't have to specify the types of a function's arguments in Julia, I would like to do so in order to speed it up.

Using the type hierarchy, I can do this for a function that takes an integer as input with the following code:


julia> function sum_two(x::Integer)
           return x+2
       end
sum_two (generic function with 1 method)

julia> sum_two(Int8(4))
6

julia> sum_two(Int16(4))
6

However, when I try to do this for the type Array{Integer,2}, I get the following error:

julia> function sum_array(x::Array{Integer,2})
           return sum(x)
       end
sum_array (generic function with 1 method)

julia> sum_array(ones(Int8,10,10))
ERROR: MethodError: no method matching sum_array(::Array{Int8,2})
Closest candidates are:
  sum_array(::Array{Integer,2}) at REPL[4]:2
Stacktrace:
 [1] top-level scope at none:0

I don't know what I could do to solve this. One option would be to define the method for every concrete subtype of Integer, like this:

function sum_array(x::Array{Int8,2})
    return sum(x)
end

function sum_array(x::Array{UInt8,2})
    return sum(x)
end

# ... and so on, one method for each concrete integer type

But it doesn't look very practical.

choforito84 asked Jan 27 '23 00:01



1 Answer

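First of all: specifying the types of input arguments to a function does not speed up the code. This is a misunderstanding. By default, Julia compiles a specialized version of a method for the concrete argument types it is actually called with, whether or not the arguments are annotated. You should specify concrete field types when you define structs, but for function signatures it makes no difference whatsoever to performance. Type annotations on arguments are used to control dispatch, that is, to decide which method is called for which argument types.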
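A minimal sketch of this point, using two hypothetical functions, `add_two_untyped` and `add_two_typed` (names chosen here only for illustration). Both have the same body; the annotation on the second one does not change the result or the inferred return type for an `Int8` argument, it only changes which calls match. `Base.return_types` is Base's (non-exported) reflection helper and is used here just to compare what the compiler infers.

```julia
# Hypothetical example functions: identical bodies, only the signature differs.
add_two_untyped(x) = x + 2          # accepts any argument that supports `+ 2`
add_two_typed(x::Integer) = x + 2   # only matches integer arguments

# For an integer input, both give the same value and the same inferred return type,
# because Julia specializes each method on the concrete argument type (here Int8).
println(add_two_untyped(Int8(4)))                   # 6
println(add_two_typed(Int8(4)))                     # 6
println(Base.return_types(add_two_untyped, (Int8,)))
println(Base.return_types(add_two_typed, (Int8,)))

# What the annotation does change is dispatch: a Float64 argument no longer matches.
println(applicable(add_two_untyped, 4.0))           # true
println(applicable(add_two_typed, 4.0))             # false
```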

Now, to your question: Julia's type parameters are invariant, meaning that even if S<:T is true, A{S}<:A{T} is not true (unless S and T are the same type). You can read more about that here: https://docs.julialang.org/en/v1/manual/types/index.html#Parametric-Composite-Types-1

Therefore ones(Int8,10,10), whose type is Matrix{Int8} (an alias for Array{Int8,2}), is not an instance of Matrix{Integer}, because Matrix{Int8} is not a subtype of Matrix{Integer}.
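The invariance rule can be checked directly with the subtype operator `<:` and `isa`. A short sketch (the variable names `A` and `B` are just for illustration):

```julia
# Element types: Int8 is a subtype of the abstract type Integer.
println(Int8 <: Integer)                      # true

# Parametric types are invariant: Matrix{Int8} is NOT a subtype of Matrix{Integer}.
println(Matrix{Int8} <: Matrix{Integer})      # false

# Matrix{T} is just an alias for Array{T, 2}.
println(Matrix{Int8} === Array{Int8, 2})      # true

# So ones(Int8, 10, 10) is not an instance of Matrix{Integer} ...
A = ones(Int8, 10, 10)
println(A isa Matrix{Integer})                # false

# ... whereas Matrix{Integer} is itself a concrete container type whose elements
# may be any mix of integer types, e.g. built with a typed array literal:
B = Integer[Int8(1) UInt8(2); Int16(3) 4]
println(B isa Matrix{Integer})                # true
```

This is why the signature `Array{Integer,2}` only accepts arrays that were created with element type `Integer` itself, not arrays of a specific integer type such as `Int8`.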

To get your code to work, you can do this:

function sum_array(x::Array{T, 2}) where {T<:Integer}
    return sum(x)
end

or use this equivalent shorthand:

function sum_array(x::Array{<:Integer, 2})
    return sum(x)
end
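A quick check that the two signatures above accept the same arguments. To keep both definitions side by side, this sketch gives them hypothetical names, `sum_array_where` and `sum_array_short`; in your own code you would define only one `sum_array`. The `where` form is useful when you need to refer to the element type `T` inside the body; otherwise the `<:` shorthand is equivalent.

```julia
# Explicit `where` clause: T is bound and could be used in the body.
function sum_array_where(x::Array{T, 2}) where {T<:Integer}
    return sum(x)
end

# Shorthand: Array{<:Integer, 2} means Array{T, 2} where T<:Integer.
function sum_array_short(x::Array{<:Integer, 2})
    return sum(x)
end

println(Array{<:Integer, 2} == (Array{T, 2} where T<:Integer))  # true

for f in (sum_array_where, sum_array_short)
    println(f(ones(Int8, 10, 10)))               # 100
    println(f(ones(UInt8, 3, 3)))                # 9
    # A Matrix{Integer} matches too, since Integer <: Integer.
    println(f(Integer[1 2; 3 4]))                # 10
    # Float64 is not a subtype of Integer, so no method matches.
    println(applicable(f, ones(Float64, 2, 2)))  # false
end
```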
DNF answered Jan 31 '23 10:01
