As an artificial example suppose I have a parametric struct where T <: AbstractFloat
mutable struct Summary{T<:AbstractFloat}
count
sum::T
end
I would like to type the count
field as UInt16
when T === Float16
or as UInt32
when T === Float32
and as UInt64
in all other cases.
My current approach is to use a union type Union{UInt16, UInt32, UInt64}
for the count
field
module SummaryStats
export Summary, avg
const CounterType = Union{UInt16, UInt32, UInt64}
mutable struct Summary{T<:AbstractFloat}
count::CounterType
sum::T
# explicitly typed no-arg constructor
Summary{T}() where {T<:AbstractFloat} = new(_counter(T), zero(T))
end
# untyped no-arg constructor defaults to Float64
Summary() = Summary{Float64}()
function avg(summary::Summary{T})::T where {T <: AbstractFloat}
if summary.count > zero(_counter(typeof(T)))
summary.sum / summary.count
else
zero(T)
end
end
# internal helper functions, not exported
Base.@pure _counter(::Type{Float16})::UInt16 = UInt16(0)
Base.@pure _counter(::Type{Float32})::UInt32 = UInt32(0)
Base.@pure _counter(::DataType)::UInt64 = UInt64(0)
end # module
This seems to work but, obviously, @code_warntype
is not happy with the union type for the count
field.
I am wondering whether it is possible to somehow compute the correct concrete type according to the rules laid out above?
The "outer-only" constructors are mainly for these kinda use cases:
julia> const CounterType = Union{UInt16, UInt32, UInt64}
Union{UInt16, UInt32, UInt64}
julia> mutable struct Summary{T<:AbstractFloat, S<:CounterType}
count::S
sum::T
function Summary{T}() where {T<:AbstractFloat}
S = T === Float16 ? UInt16 :
T === Float32 ? UInt32 :
T === Float64 ? UInt64 : throw(ArgumentError("unexpected type: $(T)!"))
new{T,S}(zero(S), zero(T))
end
end
julia> Summary() = Summary{Float64}()
Summary
julia> function avg(summary::Summary{T})::T where {T <: AbstractFloat}
if summary.count > zero(summary.count)
summary.sum / summary.count
else
zero(T)
end
end
avg (generic function with 1 method)
julia> avg(Summary())
0.0
julia> @code_warntype avg(Summary())
Body::Float64
1 ─ %1 = (Base.getfield)(summary, :count)::UInt64
│ %2 = (Base.ult_int)(0x0000000000000000, %1)::Bool
└── goto #3 if not %2
2 ─ %4 = (Base.getfield)(summary, :sum)::Float64
│ %5 = (Base.getfield)(summary, :count)::UInt64
│ %6 = (Base.uitofp)(Float64, %5)::Float64
│ %7 = (Base.div_float)(%4, %6)::Float64
└── return %7
3 ─ return 0.0
julia> @code_warntype avg(Summary{Float32}())
Body::Float32
1 ─ %1 = (Base.getfield)(summary, :count)::UInt32
│ %2 = (Base.ult_int)(0x00000000, %1)::Bool
└── goto #3 if not %2
2 ─ %4 = (Base.getfield)(summary, :sum)::Float32
│ %5 = (Base.getfield)(summary, :count)::UInt32
│ %6 = (Base.uitofp)(Float32, %5)::Float32
│ %7 = (Base.div_float)(%4, %6)::Float32
└── return %7
3 ─ return 0.0f0
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With