HiveBrain v1.2.0
Get Started
← Back to all entries
snippetcppMinor

Efficient mechanism to generate operator overload functions

Submitted by: @import:stackexchange-codereview··
0
Viewed 0 times
operatoroverloadefficientgeneratemechanismfunctions

Problem

I'm attempting to tidy up a C++ framework (PyCXX).

One particular file contains ~400 lines of operator overload functions, which I've managed to reduce to the following:

// OP(op, l, r, cmpL, cmpR)
//   Defines the free function `bool operator op( l, r )` whose body evaluates
//   `cmpL op cmpR`. `l` and `r` are complete parameter declarations (type and
//   name, e.g. `const Long &a`); `cmpL` / `cmpR` are expressions over those
//   names (e.g. `a.as_long()`). Both expressions are parenthesized so that an
//   argument containing a lower-precedence operator (`x & mask`, `c ? x : y`)
//   is still compared as a whole instead of being split by `op`.
#define OP( op, l, r, cmpL, cmpR ) \
    bool operator op( l, r ) { return (cmpL) op (cmpR); }

// OPS(l, r, cmpL, cmpR)
//   Expands OP once for each of the six comparison operators, all sharing the
//   same parameter declarations and comparison expressions.
#define OPS( l, r, cmpL, cmpR ) \
    OP( !=, l, r, cmpL, cmpR ) \
    OP( ==, l, r, cmpL, cmpR ) \
    OP( > , l, r, cmpL, cmpR ) \
    OP( >=, l, r, cmpL, cmpR ) \
    OP( < , l, r, cmpL, cmpR ) \
    OP( <=, l, r, cmpL, cmpR )

    OPS( const Long &a, const Long &b,   a.as_long() , b.as_long() )
    OPS( const Long &a,       int   b,   a.as_long() , b           )
    OPS( const Long &a,       long  b,   a.as_long() , b           )
    OPS( int         a, const Long &b,   a           , b.as_long() )
    OPS( long        a, const Long &b,   a           , b.as_long() )

#ifdef HAVE_LONG_LONG
    OPS( const Long &a, PY_LONG_LONG b,   a.as_long_long() , b                )
    OPS( PY_LONG_LONG a, const Long &b,   a                , b.as_long_long() )
#endif

    //------------------------------------------------------------
    // compare operators
    OPS( const Float &a, const Float &b,   a.as_double() , b.as_double() )
    OPS( const Float &a, double       b,   a.as_double() , b             )
    OPS( double a,       const Float &b,   a             , b.as_double() )

}    // end of namespace Py


However, I wonder if it may be possible to tidy it up further.

It looks as though I can save having to pass all the parameters each time by using:

// Variadic form of OPS: forwards its arguments (l, r, cmpL, cmpR) unchanged
// to OP for each of the six comparison operators.
// Plain `__VA_ARGS__` is used instead of the GNU `, ##__VA_ARGS__`
// comma-elision extension (flagged under -pedantic). OP needs all four
// remaining arguments, so OPS can never be expanded validly with an empty
// argument list, and the extension buys nothing.
#define OPS( ... ) \
    OP( !=, __VA_ARGS__ ) \
    OP( ==, __VA_ARGS__ ) \
    OP( > , __VA_ARGS__ ) \
    OP( >=, __VA_ARGS__ ) \
    OP( < , __VA_ARGS__ ) \
    OP( <=, __VA_ARGS__ )


However, I have a hunch this might make the code less transparent.

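Another possible idea would involve storing pairs, i.e. each operand's parameter declaration together with its accessor expression (the snippet below is cut off in the original post):
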
// NOTE(review): fragment of an unfinished idea. Each macro packs a parameter
// declaration, a comparison operator and an accessor expression into one
// parenthesized tuple, presumably to be unpacked by another macro that is not
// shown here -- TODO confirm the intended consumer.
// NOTE(review): FD declares a non-const `Float &`, unlike the `const Float &`
// parameters passed to OPS in the Float instantiations.
#define FD(op, x) ( Float &x , op, x.as_double() )
#define L1(op, x) ( const Long &x , op, x.as_long() )
// NOTE(review): the snippet is truncated here; this `#define` has no macro
// name and will not compile as written.
#define

Solution

I looked a bit at your PyCXX and would like to suggest the Barton–Nackman trick. It uses the Curiously Recurring Template Pattern (CRTP) together with variadic template recursion (through partial specialization). This way you can handle all your types and comparison operators with one template:

template<typename T> struct Arithmetic;
using Long = Arithmetic<long long>;
using Double = Arithmetic<long double>;

// Recursion terminator: the type list is exhausted, so no operators are added.
template<class Base, class... Other> struct compare_operators {};

// Barton-Nackman trick via CRTP: for the first type in the list, injects the
// six comparison operators between `Base` and a raw `First`, and between
// `Base` and `Arithmetic<First>`; then inherits from the specialization for the
// remaining types. The operators are friends defined inside the class, so they
// are found only through argument-dependent lookup on `Base`.
//
// Requires `Base` to expose a public arithmetic data member `value` and a
// member template `as<X>()` converting it to X (both provided by Arithmetic<T>).
template<class Base, class First, class... Other>
  struct compare_operators<Base, First, Other...>
  : compare_operators<Base, Other...>
{
private:
    // Applies the comparison `Cmp` after converting both operands to their
    // common arithmetic type -- the type the built-in operator would use.
    // Converting the wrapped value to `First` instead would narrow it:
    // Long(1LL << 40) > 2 would convert the left side to an out-of-range int,
    // and Double(2.5) == 2 would end up comparing 2 == 2.
    template<class Cmp>
    static bool compare(const Base& a, First b) {
        using Common = std::common_type_t<decltype(a.value), First>;
        return Cmp{}(a.template as<Common>(), static_cast<Common>(b));
    }

public:
    // Base vs. a raw First.
    friend bool operator == (const Base& a, First b) {
        return compare<std::equal_to<>>(a, b); }
    friend bool operator != (const Base& a, First b) {
        return compare<std::not_equal_to<>>(a, b); }
    friend bool operator <  (const Base& a, First b) {
        return compare<std::less<>>(a, b); }
    friend bool operator <= (const Base& a, First b) {
        return compare<std::less_equal<>>(a, b); }
    friend bool operator >  (const Base& a, First b) {
        return compare<std::greater<>>(a, b); }
    friend bool operator >= (const Base& a, First b) {
        return compare<std::greater_equal<>>(a, b); }

    // Base vs. another wrapper holding a First.
    friend bool operator == (const Base& a, const Arithmetic<First>& b) {
        return compare<std::equal_to<>>(a, b.value); }
    friend bool operator != (const Base& a, const Arithmetic<First>& b) {
        return compare<std::not_equal_to<>>(a, b.value); }
    friend bool operator <  (const Base& a, const Arithmetic<First>& b) {
        return compare<std::less<>>(a, b.value); }
    friend bool operator <= (const Base& a, const Arithmetic<First>& b) {
        return compare<std::less_equal<>>(a, b.value); }
    friend bool operator >  (const Base& a, const Arithmetic<First>& b) {
        return compare<std::greater<>>(a, b.value); }
    friend bool operator >= (const Base& a, const Arithmetic<First>& b) {
        return compare<std::greater_equal<>>(a, b.value); }
};

// Thin wrapper around an arithmetic value of type T that gains all comparison
// operators against the listed built-in types and their Arithmetic wrappers.
template<typename T> struct Arithmetic
  : compare_operators<Arithmetic<T>
  , long long, int, short, double, float, long double>
{
    T value;
    Arithmetic(T value): value(value) {}

    // Returns the stored value converted to X (restricted to arithmetic X).
    template<typename X, typename = std::enable_if_t<std::is_arithmetic<std::decay_t<X>>::value>>
      X as() const { return static_cast<X>(value); }
};

// Demonstrates the generated operators: Long against a raw int, Long against
// a raw double, and Long against another Arithmetic wrapper (Double).
int main() {
    Long val(3);
    if(val > 2) std::cout << "val > 2\n";
    if(val <= 3.2) std::cout << "val <= 3.2\n";
    Double dbl(3.14);
    if(val < dbl) std::cout << "val < dbl\n";
}


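Beware that the comparison operators need some fine-tuning, especially when you compare integral vs. floating-point values (which type to choose as the common type). As written, `as<First>()` converts the wrapped value to the type of the other operand, so comparing a `Long` against an `int` (or a `Double` against an `int`) can narrow the wrapped value. Also, `decay_t` in the `as` conversion function may be problematic (not sure). (Finally, imagine `using namespace std;` here. I use my own header, which explicitly brings selected std features/classes into my namespace, and then I use that namespace.)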

BTW: here is the same code (with a slightly altered `as` conversion) in my own syntax, which you may find a bit familiar:

// The same design as the C++ version, written in the author's own
// indentation-based syntax (the meaning of "basics.hpp" and `using*firda`
// is not shown here -- TODO confirm).
#include "basics.hpp"
using*firda

forward template<typename T> struct Arithmetic
using Long = Arithmetic<long long>
using Double = Arithmetic<long double>

// Recursion terminator for the type list.
template<class Base, class... Other>
  struct compare_operators

// Injects the comparison operators for `First`, then recurses on `Other...`.
template<class Base, class First, class... Other>
  struct compare_operators<Base, First, Other...>
  : compare_operators<Base, Other...>

    friend bool operator == (const Base& a, First b)
        return a.template as<First>() == b
    friend bool operator != (const Base& a, First b)
        return a.template as<First>() != b
    friend bool operator <  (const Base& a, First b)
        return a.template as<First>() <  b
    friend bool operator <= (const Base& a, First b)
        return a.template as<First>() <= b
    friend bool operator >  (const Base& a, First b)
        return a.template as<First>() >  b
    friend bool operator >= (const Base& a, First b)
        return a.template as<First>() >= b

    friend bool operator == (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() == b.value
    friend bool operator != (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() != b.value
    friend bool operator <  (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() <  b.value
    friend bool operator <= (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() <= b.value
    friend bool operator >  (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() >  b.value
    friend bool operator >= (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() >= b.value

template<typename T> struct Arithmetic
  : compare_operators<Arithmetic<T>
  , long long, int, short, double, float, long double>

    T value
    Arithmetic(T value): value(value) {}

    // NOTE(review): as<First>() converts the stored value to the other
    // operand's type, which can narrow it (e.g. a long long compared with int).
    template<typename X>
      enable_if_t<is_arithmetic<X>::value,
      X> as() const
        return value

int main()
    Long val(3)
    if val > 2; cout << "val > 2\n"
    if val <= 3.2; cout << "val <= 3.2\n"
    Double dbl(3.14)
    if val < dbl; cout << "val < dbl\n"

Code Snippets

template<typename T> struct Arithmetic;
using Long = Arithmetic<long long>;
using Double = Arithmetic<long double>;

// Recursion terminator: the type list is exhausted, so no operators are added.
template<class Base, class... Other> struct compare_operators {};

// Barton-Nackman trick via CRTP: for the first type in the list, injects the
// six comparison operators between `Base` and a raw `First`, and between
// `Base` and `Arithmetic<First>`; then inherits from the specialization for the
// remaining types. The operators are friends defined inside the class, so they
// are found only through argument-dependent lookup on `Base`.
//
// Requires `Base` to expose a public arithmetic data member `value` and a
// member template `as<X>()` converting it to X (both provided by Arithmetic<T>).
template<class Base, class First, class... Other>
  struct compare_operators<Base, First, Other...>
  : compare_operators<Base, Other...>
{
private:
    // Applies the comparison `Cmp` after converting both operands to their
    // common arithmetic type -- the type the built-in operator would use.
    // Converting the wrapped value to `First` instead would narrow it:
    // Long(1LL << 40) > 2 would convert the left side to an out-of-range int,
    // and Double(2.5) == 2 would end up comparing 2 == 2.
    template<class Cmp>
    static bool compare(const Base& a, First b) {
        using Common = std::common_type_t<decltype(a.value), First>;
        return Cmp{}(a.template as<Common>(), static_cast<Common>(b));
    }

public:
    // Base vs. a raw First.
    friend bool operator == (const Base& a, First b) {
        return compare<std::equal_to<>>(a, b); }
    friend bool operator != (const Base& a, First b) {
        return compare<std::not_equal_to<>>(a, b); }
    friend bool operator <  (const Base& a, First b) {
        return compare<std::less<>>(a, b); }
    friend bool operator <= (const Base& a, First b) {
        return compare<std::less_equal<>>(a, b); }
    friend bool operator >  (const Base& a, First b) {
        return compare<std::greater<>>(a, b); }
    friend bool operator >= (const Base& a, First b) {
        return compare<std::greater_equal<>>(a, b); }

    // Base vs. another wrapper holding a First.
    friend bool operator == (const Base& a, const Arithmetic<First>& b) {
        return compare<std::equal_to<>>(a, b.value); }
    friend bool operator != (const Base& a, const Arithmetic<First>& b) {
        return compare<std::not_equal_to<>>(a, b.value); }
    friend bool operator <  (const Base& a, const Arithmetic<First>& b) {
        return compare<std::less<>>(a, b.value); }
    friend bool operator <= (const Base& a, const Arithmetic<First>& b) {
        return compare<std::less_equal<>>(a, b.value); }
    friend bool operator >  (const Base& a, const Arithmetic<First>& b) {
        return compare<std::greater<>>(a, b.value); }
    friend bool operator >= (const Base& a, const Arithmetic<First>& b) {
        return compare<std::greater_equal<>>(a, b.value); }
};

// Thin wrapper around an arithmetic value of type T that gains all comparison
// operators against the listed built-in types and their Arithmetic wrappers.
template<typename T> struct Arithmetic
  : compare_operators<Arithmetic<T>
  , long long, int, short, double, float, long double>
{
    T value;
    Arithmetic(T value): value(value) {}

    // Returns the stored value converted to X (restricted to arithmetic X).
    template<typename X, typename = std::enable_if_t<std::is_arithmetic<std::decay_t<X>>::value>>
      X as() const { return static_cast<X>(value); }
};

// Demonstrates the generated operators: Long against a raw int, Long against
// a raw double, and Long against another Arithmetic wrapper (Double).
int main() {
    Long val(3);
    if(val > 2) std::cout << "val > 2\n";
    if(val <= 3.2) std::cout << "val <= 3.2\n";
    Double dbl(3.14);
    if(val < dbl) std::cout << "val < dbl\n";
}
// The same design as the C++ version, written in the author's own
// indentation-based syntax (the meaning of "basics.hpp" and `using*firda`
// is not shown here -- TODO confirm).
#include "basics.hpp"
using*firda

forward template<typename T> struct Arithmetic
using Long = Arithmetic<long long>
using Double = Arithmetic<long double>

// Recursion terminator for the type list.
template<class Base, class... Other>
  struct compare_operators

// Injects the comparison operators for `First`, then recurses on `Other...`.
template<class Base, class First, class... Other>
  struct compare_operators<Base, First, Other...>
  : compare_operators<Base, Other...>

    friend bool operator == (const Base& a, First b)
        return a.template as<First>() == b
    friend bool operator != (const Base& a, First b)
        return a.template as<First>() != b
    friend bool operator <  (const Base& a, First b)
        return a.template as<First>() <  b
    friend bool operator <= (const Base& a, First b)
        return a.template as<First>() <= b
    friend bool operator >  (const Base& a, First b)
        return a.template as<First>() >  b
    friend bool operator >= (const Base& a, First b)
        return a.template as<First>() >= b

    friend bool operator == (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() == b.value
    friend bool operator != (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() != b.value
    friend bool operator <  (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() <  b.value
    friend bool operator <= (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() <= b.value
    friend bool operator >  (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() >  b.value
    friend bool operator >= (const Base& a, const Arithmetic<First>& b)
        return a.template as<First>() >= b.value

template<typename T> struct Arithmetic
  : compare_operators<Arithmetic<T>
  , long long, int, short, double, float, long double>

    T value
    Arithmetic(T value): value(value) {}

    // NOTE(review): as<First>() converts the stored value to the other
    // operand's type, which can narrow it (e.g. a long long compared with int).
    template<typename X>
      enable_if_t<is_arithmetic<X>::value,
      X> as() const
        return value

int main()
    Long val(3)
    if val > 2; cout << "val > 2\n"
    if val <= 3.2; cout << "val <= 3.2\n"
    Double dbl(3.14)
    if val < dbl; cout << "val < dbl\n"

Context

StackExchange Code Review Q#64829, answer score: 5

Revisions (0)

No revisions yet.