 

Class Polymorphism and equality operators

I am trying to wrap my head around something I've been wondering about for quite some time now. Assume I have a class Base:

class Base
{
public:
    virtual ~Base(){}
    virtual bool operator== ( const Base & rhs ) const;
};

Now, another class inherits from it. It has two equality operators:

class A : public Base
{
public:
    bool operator== ( const A & rhs ) const;
    bool operator== ( const Base & rhs ) const;
private:
    int index__;
};

And yet another class which also inherits from Base, and which also has two equality operators:

class B : public Base
{
public:
    bool operator== ( const B & rhs ) const;
    bool operator== ( const Base & rhs ) const;
private:
    int index__;
};

Here is what I understand (which is not necessarily correct). I can use the first operator only to check whether two objects of the same class are equal. I can use the second operator to check whether two objects are of the same class and, if so, whether they are equal. Now, there is yet another class, which wraps pointers to Base that actually point to objects of the polymorphic types A or B.

class Z
{
public:
    bool operator== ( const Z & rhs ) const;
private:
    std::shared_ptr<Base> ptr__;
};

First of all, I've found that I can't seem to have both operator== overloads. The compiler reports no errors, but when I run the program, it just hangs. I am guessing it has something to do with RTTI, which is beyond me.

What I have been using, which is quite ugly, is to attempt a downcast within class Z and, if it succeeds, compare the instances:

bool Z::operator== ( const Z & rhs ) const
{
    // Braces are required: without them the 'else' binds to the inner 'if',
    // so the B branch would never be reached when this->ptr__ holds a B
    if ( const auto a1 = std::dynamic_pointer_cast<A>( this->ptr__ ) )
    {
        if ( const auto a2 = std::dynamic_pointer_cast<A>( rhs.ptr__ ) )
            return *a1 == *a2;
    }
    else if ( const auto b1 = std::dynamic_pointer_cast<B>( this->ptr__ ) )
    {
        if ( const auto b2 = std::dynamic_pointer_cast<B>( rhs.ptr__ ) )
            return *b1 == *b2;
    }
    return false;
}

This is quite ugly, and it assumes that classes A and B each have an equality operator that takes a parameter of their own type.

So I tried to come up with an approach that uses the second kind of operator: more type-agnostic, more elegant if you will. And I failed. This approach requires implementing the operator in both A and B, which moves the comparison logic out of class Z.

bool A::operator== ( const Base & rhs ) const
{
    return ( typeid( *this ) == typeid( rhs ) ) && ( *this == rhs );
}

The same goes for class B. This doesn't seem to work (the application hangs without any errors). Furthermore, which operator does *this == rhs actually call: some kind of default operator, or the base class operator? Ideally, it should both use Base::operator== and compare the class types.

If, however, I want a more elaborate comparison based on a member of class A or B, such as index__, then it seems I have to make each class a friend, because when I try the following, it won't compile (unless, of course, I add a getter or make the member visible in some other way):

bool A::operator== ( const Base & rhs ) const
{
    return ( typeid( *this ) == typeid( rhs ) )
           && (*this == *rhs )
           && (this->index__ == rhs.index__ );
}

Is there an elegant, simple solution to this? Am I confined to downcasting and trying, or is there some other way to achieve what I want?

asked Mar 31 '15 23:03 by Ælex


2 Answers

I agree with @vsoftco about only implementing operator== in the base class and using the NVI idiom. However, I'd provide a pure virtual function that the derived classes need to implement to perform the equality check. In this way the base class doesn't know or care about what it means for any derived classes to be equivalent.

Code

#include <iostream>
#include <string>
#include <typeinfo>

class Base
{
public:
    virtual ~Base() {}

    bool operator==(const Base& other) const
    {
        // If the derived types are the same then compare them
        return typeid(*this) == typeid(other) && isEqual(other);
    }

private:
    // A pure virtual function derived classes must implement.
    // Furthermore, this function has a precondition that it will only
    // be called when the 'other' is the same type as the instance
    // invoking the function.
    virtual bool isEqual(const Base& other) const = 0;
};

class D1 : public Base
{
public:
    explicit D1(double v = 0.0) : mValue(v) {}
    virtual ~D1() override {}

private:
    bool isEqual(const Base& other) const override
    {
        // The cast is safe because of the precondition documented in the
        // base class
        return mValue == static_cast<const D1&>(other).mValue;
    }

private:
    double mValue;
};

class D2 : public Base
{
public:
    explicit D2(std::string v = "") : mValue(v) {}
    virtual ~D2() override {}

private:
    bool isEqual(const Base& other) const override
    {
        return mValue == static_cast<const D2&>(other).mValue;
    }

private:
    std::string mValue;
};

class D3 : public Base
{
public:
    explicit D3(int v = 0) : mValue(v) {}
    virtual ~D3() override {}

private:
    bool isEqual(const Base& other) const override
    {
        return mValue == static_cast<const D3&>(other).mValue;
    }

private:
    int mValue;
};

int main()
{
    D1 d1a(1.0);
    D1 d1b(2.0);
    D1 d1c(1.0);

    D2 d2a("1");
    D2 d2b("2");
    D2 d2c("1");

    D3 d3a(1);
    D3 d3b(2);
    D3 d3c(1);

    std::cout << "Compare D1 types\n";
    std::cout << std::boolalpha << (d1a == d1b) << "\n";
    std::cout << std::boolalpha << (d1b == d1c) << "\n";
    std::cout << std::boolalpha << (d1a == d1c) << "\n";

    std::cout << "Compare D2 types\n";
    std::cout << std::boolalpha << (d2a == d2b) << "\n";
    std::cout << std::boolalpha << (d2b == d2c) << "\n";
    std::cout << std::boolalpha << (d2a == d2c) << "\n";

    std::cout << "Compare D3 types\n";
    std::cout << std::boolalpha << (d3a == d3b) << "\n";
    std::cout << std::boolalpha << (d3b == d3c) << "\n";
    std::cout << std::boolalpha << (d3a == d3c) << "\n";

    std::cout << "Compare mixed derived types\n";
    std::cout << std::boolalpha << (d1a == d2a) << "\n";
    std::cout << std::boolalpha << (d2a == d3a) << "\n";
    std::cout << std::boolalpha << (d1a == d3a) << "\n";
    std::cout << std::boolalpha << (d1b == d2b) << "\n";
    std::cout << std::boolalpha << (d2b == d3b) << "\n";
    std::cout << std::boolalpha << (d1b == d3b) << "\n";
    std::cout << std::boolalpha << (d1c == d2c) << "\n";
    std::cout << std::boolalpha << (d2c == d3c) << "\n";
    std::cout << std::boolalpha << (d1c == d3c) << "\n";

    return 0;
}

Output

Compare D1 types
false
false
true
Compare D2 types
false
false
true
Compare D3 types
false
false
true
Compare mixed derived types
false
false
false
false
false
false
false
false
false
answered Oct 01 '22 07:10 by James Adkison


In general, a hierarchy should have a common interface, and IMO operator== should be implemented only in the Base class, using (virtual) getters from that interface. Otherwise it's like redefining non-virtual functions down the hierarchy, which is almost always a bad idea. So you may want to rethink your design; having more than one operator== seems fishy.

Very simple example:

#include <iostream>

class A
{
    int _x;
public:
    A(int x):_x(x){}
    virtual int getx() const { return _x; } // runtime
    bool operator==(const A& other) const {return getx() == other.getx();} // one implementation, const so it works on const objects
};

class B: public A
{
    using A::A;
    int getx() const override // Make all B's equal, runtime
    {
        return 0; // so always 0, all B's are equal
    }
};

int main()
{
    A a1(10), a2(20);
    B b1(10), b2(20);
    std::cout << std::boolalpha << (a1==a2) << std::endl; // false
    std::cout << std::boolalpha << (b1==b2) << std::endl; // always true
}

This pattern is usually called the non-virtual interface (NVI) idiom and is a manifestation of the so-called Template Method pattern (which has nothing to do with C++ templates; it's just an unfortunate name), in which clients (along the hierarchy) call virtual functions only indirectly, through public non-virtual member functions. Item 35 of Scott Meyers' Effective C++ (3rd edition) has an excellent discussion of this issue.

answered Oct 01 '22 08:10 by vsoftco