Wrong specialized generic function gets called in Swift 3 from an indirect call

I have code that follows the general design of:

protocol DispatchType {}
class DispatchType1: DispatchType {}
class DispatchType2: DispatchType {}

func doBar<D:DispatchType>(value:D) {
    print("general function called")
}

func doBar(value:DispatchType1) {
    print("DispatchType1 called")
}

func doBar(value:DispatchType2) {
    print("DispatchType2 called")
}

where in reality DispatchType is a backend storage type, and the doBar functions are optimized implementations that depend on the correct storage type. Everything works fine if I do:

let d1 = DispatchType1()
let d2 = DispatchType2()

doBar(value: d1)    // "DispatchType1 called"
doBar(value: d2)    // "DispatchType2 called"

However, if I make a function that calls doBar:

func test<D:DispatchType>(value:D) {
    doBar(value: value)
}

and I try a similar calling pattern, I get:

test(value: d1)     // "general function called"
test(value: d2)     // "general function called"

This seems like something Swift should be able to handle, since the type constraints are known at compile time. As a quick test, I also tried writing doBar as:

func doBar<D:DispatchType>(value:D) where D:DispatchType1 {
    print("DispatchType1 called")
}

func doBar<D:DispatchType>(value:D) where D:DispatchType2 {
    print("DispatchType2 called")
}

but get the same results.

Is this correct Swift behavior, and if so, what is a good way to work around it?

Edit 1: An example of why I was trying to avoid using protocols. Suppose I have the following code (greatly simplified from my actual code):

protocol Storage {
     // ...
}

class Tensor<S:Storage> {
    // ...
}

For the Tensor class I have a base set of operations that can be performed on the Tensors. However, the operations themselves will change their behavior based on the storage. Currently I accomplish this with:

func dot<S:Storage>(_ lhs:Tensor<S>, _ rhs:Tensor<S>) -> Tensor<S> { ... }

While I can put these in the Tensor class and use extensions:

extension Tensor where S:CBlasStorage {
    func dot(_ tensor:Tensor<S>) -> Tensor<S> {
       // ...
    }
}

this has a few side effects which I don't like:

  1. I think dot(lhs, rhs) is preferable to lhs.dot(rhs). Convenience functions could be written to get around this, but that would create a huge explosion of code.

  2. This will cause the Tensor class to become monolithic. I really prefer having it contain the minimal amount of code necessary and expand its functionality by auxiliary functions.

  3. Related to (2), this means that anyone who wants to add new functionality will have to touch the base class, which I consider bad design.

Edit 2: One alternative is that things work as expected if you use constraints everywhere:

func test<D:DispatchType>(value:D) where D:DispatchType1 {
    doBar(value: value)
}

func test<D:DispatchType>(value:D) where D:DispatchType2 {
    doBar(value: value)
}

This causes the correct doBar to be called. It isn't ideal either, since it means writing a lot of extra code, but at least it lets me keep my current design.

Edit 3: I came across documentation showing the use of the static keyword with generics. This helps at least with point (1):

class Tensor<S:Storage> {
   // ...
   static func cos(_ tensor:Tensor<S>) -> Tensor<S> {
       // ...
   }
}

allows you to write:

let result = Tensor.cos(value)

and it supports operator overloading:

let result = value1 + value2

It does add the verbosity of the required Tensor prefix. This can be made a little better with:

typealias T<S:Storage> = Tensor<S>
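A minimal runnable sketch of the Edit 3 pattern, assuming a placeholder CPUStorage backend and a plain [Double] payload (both hypothetical; the real Tensor body is elided in the question). It shows that S is inferred from the argument when calling a static member on the unparameterized Tensor, and that operators can be declared as static members of the generic class:

```swift
import Foundation

// Hypothetical storage backend, standing in for the real ones.
protocol Storage {}
final class CPUStorage: Storage {}

// File-scope helper: inside Tensor, an unqualified `cos` would find the
// static member below, so the scalar math function is wrapped here.
private func scalarCos(_ x: Double) -> Double {
    return cos(x)
}

final class Tensor<S: Storage> {
    // Illustrative payload only; a real tensor would delegate to S.
    let values: [Double]

    init(_ values: [Double]) {
        self.values = values
    }

    // Called as Tensor.cos(t); S is inferred from the argument.
    static func cos(_ tensor: Tensor<S>) -> Tensor<S> {
        return Tensor<S>(tensor.values.map(scalarCos))
    }

    // Operators are declared as static members, so `+` lives with the type.
    static func + (lhs: Tensor<S>, rhs: Tensor<S>) -> Tensor<S> {
        precondition(lhs.values.count == rhs.values.count, "shape mismatch")
        return Tensor<S>(zip(lhs.values, rhs.values).map { $0 + $1 })
    }
}

let a = Tensor<CPUStorage>([0.0, 0.0])
let b = Tensor<CPUStorage>([1.0, 2.0])
let cosA = Tensor.cos(a)   // no explicit <CPUStorage> needed
let sum = a + b
print(cosA.values, sum.values)
```

Note that this still places every operation inside Tensor, so it addresses point (1) but not points (2) and (3).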
asked Feb 01 '17 by Abe Schneider


1 Answer

This is indeed correct behaviour, as overload resolution takes place at compile time (doing it at runtime would be a pretty expensive operation). Therefore, from within test(value:), the only thing the compiler knows about value is that it's of some type that conforms to DispatchType, so the only overload it can dispatch to is func doBar<D : DispatchType>(value: D).

Things would be different if generic functions were always specialised by the compiler, because then a specialised implementation of test(value:) would know the concrete type of value and thus be able to pick the appropriate overload. However, specialisation of generic functions is currently only an optimisation (as without inlining, it can add significant bloat to your code), so this doesn't change the observed behaviour.
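One workaround that is not part of this answer: recover the concrete type at runtime with a dynamic cast inside the generic function, and only then call doBar. Inside each case, the cast binding has a concrete static type, so ordinary compile-time overload resolution picks the specialised overload. The cost is a runtime check on every call and a switch that has to be kept in sync with the overloads by hand. In this sketch the functions return their message instead of printing it, purely so the result can be checked:

```swift
protocol DispatchType {}
class DispatchType1: DispatchType {}
class DispatchType2: DispatchType {}

func doBar<D: DispatchType>(value: D) -> String {
    return "general function called"
}

func doBar(value: DispatchType1) -> String {
    return "DispatchType1 called"
}

func doBar(value: DispatchType2) -> String {
    return "DispatchType2 called"
}

// Overload resolution here only sees D: DispatchType, so the concrete
// type is recovered at runtime before calling doBar.
func test<D: DispatchType>(value: D) -> String {
    switch value {
    case let v as DispatchType1:
        return doBar(value: v)      // v is statically DispatchType1 here
    case let v as DispatchType2:
        return doBar(value: v)      // v is statically DispatchType2 here
    default:
        return doBar(value: value)  // only the generic overload applies
    }
}

let d1 = DispatchType1()
let d2 = DispatchType2()
print(test(value: d1))   // "DispatchType1 called"
print(test(value: d2))   // "DispatchType2 called"
```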

One solution that allows for polymorphism is to leverage the protocol witness table (see this great WWDC talk on them): add doBar() as a protocol requirement, implement the specialised versions in the respective classes that conform to the protocol, and provide the general implementation in a protocol extension.

This will allow for the dynamic dispatch of doBar(), thus allowing it to be called from test(value:) and having the correct implementation called.

protocol DispatchType {
    func doBar()
}

extension DispatchType {
    func doBar() {
        print("general function called")
    }
}

class DispatchType1: DispatchType {
    func doBar() {
        print("DispatchType1 called")
    }
}

class DispatchType2: DispatchType {
    func doBar() {
        print("DispatchType2 called")
    }
}

func test<D : DispatchType>(value: D) {
    value.doBar()
}

let d1 = DispatchType1()
let d2 = DispatchType2()

test(value: d1)    // "DispatchType1 called"
test(value: d2)    // "DispatchType2 called"
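Applied to the Tensor/Storage design from Edit 1 of the question, the same witness-table idea could look roughly like the sketch below. Assumptions (all hypothetical): NaiveStorage, the name requirement, a [Double] payload, and a dot that returns a scalar instead of the question's Tensor<S>. The free function keeps the dot(lhs, rhs) spelling and lives outside Tensor, while S.dot(...) is resolved through S's conformance rather than by overload resolution:

```swift
protocol Storage {
    // Requirements are dispatched through the witness table, so generic
    // code over S still reaches the conforming type's implementation.
    static var name: String { get }
    static func dot(_ lhs: [Double], _ rhs: [Double]) -> Double
}

extension Storage {
    // General fallback, used by any storage that does not override it.
    static var name: String { return "general" }
    static func dot(_ lhs: [Double], _ rhs: [Double]) -> Double {
        var sum = 0.0
        for (a, b) in zip(lhs, rhs) { sum += a * b }
        return sum
    }
}

// Hypothetical backend that relies entirely on the defaults.
enum NaiveStorage: Storage {}

enum CBlasStorage: Storage {
    static var name: String { return "CBlas" }
    static func dot(_ lhs: [Double], _ rhs: [Double]) -> Double {
        // A real backend would call into BLAS here (e.g. cblas_ddot);
        // this sketch only marks where the specialised path would go.
        precondition(lhs.count == rhs.count, "length mismatch")
        return zip(lhs, rhs).map { $0 * $1 }.reduce(0, +)
    }
}

// Tensor stays minimal: it only holds data.
final class Tensor<S: Storage> {
    let values: [Double]
    init(_ values: [Double]) { self.values = values }
}

// Free functions keep the dot(lhs, rhs) style outside the class.
func dot<S: Storage>(_ lhs: Tensor<S>, _ rhs: Tensor<S>) -> Double {
    return S.dot(lhs.values, rhs.values)
}

func backendName<S: Storage>(_ tensor: Tensor<S>) -> String {
    return S.name
}

let x = Tensor<CBlasStorage>([1, 2, 3])
let y = Tensor<CBlasStorage>([4, 5, 6])
let z = Tensor<NaiveStorage>([1, 2, 3])
print(backendName(x), dot(x, y))   // CBlas 32.0
print(backendName(z), dot(z, z))   // general 14.0
```

The trade-off is that each storage-dependent operation becomes a requirement on Storage, so adding a new operation of that kind means touching the protocol rather than the Tensor class.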
answered Oct 04 '22 by Hamish