
That `overloaded` Trick: Overloading Lambdas in C++17

Tamir Bahar ・ 4 min read

C++17 granted us std::variant. Simply put, it is a type-safe union. To access the value it stores, you can either request a specific type (using std::get or something similar) or "visit" the variant, automatically handling only the data type that is actually there.
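For contrast, here is a minimal sketch of the first option, requesting a specific type. It defines the same var_t alias used in the examples that follow; the helper names int_or_zero and holds_string are hypothetical, and the functions are constexpr only so that their results can be checked at compile time.

```cpp
#include <variant>

using var_t = std::variant<int, const char*>;

// Access by type: the caller names the alternative it expects.
// std::get_if returns nullptr when v currently holds a different type.
constexpr int int_or_zero(const var_t& v) {
    if (const int* p = std::get_if<int>(&v)) {
        return *p;
    }
    return 0;  // v holds a const char*
}

// std::holds_alternative only asks which type is currently active.
constexpr bool holds_string(const var_t& v) {
    return std::holds_alternative<const char*>(v);
}
```

int_or_zero(var_t{42}) returns 42 and int_or_zero(var_t{"Hello"}) returns 0. std::get<int>(v) performs the same lookup but throws std::bad_variant_access on a mismatch, which is why the sketch uses std::get_if.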
Visiting is done using std::visit, and is fairly straightforward.

#include <variant>
#include <cstdio>
#include <vector>

using var_t = std::variant<int, const char*>; // (1)

struct Print { // (2)
    void operator() (int i) {
        printf("%d\n", i);
    }

    void operator () (const char* str) {
        puts(str);
    }
};

int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"}; // (3)

    for (auto& v : vars) {
        std::visit(Print{}, v); // (4)
    }

    return 0;
}

In (1) we define our variant type. In (2) we define a class with an overloaded operator(). This is needed for the call to std::visit. In (3) we define a vector of variants. In (4) we visit each variant. We pass in an instance of Print, and overload resolution ensures that the correct overload will be called for every type.
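A small variation makes that dispatch easy to observe. std::visit returns whatever the selected overload returns, so in this sketch each overload returns an int instead of printing. The names Describe and describe are hypothetical, and the code is constexpr only so the results can be verified at compile time.

```cpp
#include <variant>

using var_t = std::variant<int, const char*>;

// Same shape as Print, but each overload returns a value.
// std::visit requires the visitor to return the same type
// for every alternative (int here).
struct Describe {
    constexpr int operator()(int i) const { return i * 10; }
    constexpr int operator()(const char*) const { return -1; }
};

constexpr int describe(const var_t& v) {
    return std::visit(Describe{}, v);  // picks the overload for the active type
}
```

describe(var_t{4}) yields 40, while describe(var_t{"Hello"}) yields -1.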
But this example forces us to write and name a class just to hold the overloaded operator(). We can do better. In fact, the example for std::visit on cppreference already does. Here is an example derived from it:

#include <variant>
#include <cstdio>
#include <vector>

template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; }; // (1)
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;  // (2)

using var_t = std::variant<int, const char*>;

int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"};

    for (auto& v : vars) {
        std::visit(overloaded {  // (3)
            [](int i) { printf("%d\n", i); },
            [](const char* str) { puts(str); }
        }, v);
    }

    return 0;
}

This is certainly more compact, and we removed the Print struct. But how does it work? You can see a class template (1), lambdas passed in as arguments for the construction (3), and something with an arrow and some more template magic (2). Let's build it step by step.

First, we want to break the print functions out of Print and compose them later.

struct PrintInt { //(1)
    void operator() (int i) {
        printf("%d\n", i);
    }
};

struct PrintCString { // (2)
    void operator () (const char* str) {
        puts(str);
    }
};

struct Print : PrintInt, PrintCString { // (3)
    using PrintInt::operator();
    using PrintCString::operator();
};

In (1) and (2), we define the same operators as before, but in separate structs. In (3), we inherit from both of those structs, then bring both of their operator() overloads into scope with using-declarations. The result behaves exactly like the original Print. Next, we convert Print into a class template. I'll jump ahead and convert it directly to a variadic template.
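Why are the using-declarations needed? Without them, looking up operator() in Print would find declarations in two different base classes, and the standard treats that lookup as ambiguous, so a call such as Print{}(1) would be ill-formed. The using-declarations merge both into one overload set. A constexpr sketch with returning overloads (the names IntPart, CStrPart and Combined are hypothetical):

```cpp
// Each base contributes one overload of operator().
struct IntPart {
    constexpr int operator()(int i) const { return i + 1; }
};

struct CStrPart {
    constexpr int operator()(const char*) const { return 0; }
};

// The using-declarations bring both operator() overloads into
// Combined's scope, where ordinary overload resolution picks one.
struct Combined : IntPart, CStrPart {
    using IntPart::operator();
    using CStrPart::operator();
};
```

Combined{}(41) returns 42 and Combined{}("Hello") returns 0.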

template <class... Ts> // (1)
struct Print : Ts... {
    using Ts::operator()...;
};

int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"};

    for (auto& v : vars) {
        std::visit(Print<PrintCString, PrintInt>{}, v); // (2)
    }

    return 0;
}

In (1) we define the template. We take an arbitrary number of classes, inherit from them, and use their operator(). In (2) we instantiate the Print class template with PrintCString and PrintInt, so that it inherits the overloads of both.
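Because the bases only contribute overloads, the order of the template arguments should not matter: overload resolution looks at parameter types, not at the position of a base in the list. A hedged sketch with three hypothetical parts, using return values so the choice is visible:

```cpp
template <class... Ts>
struct Print : Ts... {
    using Ts::operator()...;  // C++17: pack expansion in a using-declaration
};

struct IntPart    { constexpr int operator()(int) const { return 1; } };
struct DoublePart { constexpr int operator()(double) const { return 2; } };
struct CStrPart   { constexpr int operator()(const char*) const { return 3; } };

// The same overload set, with the bases listed in different orders.
using P1 = Print<IntPart, DoublePart, CStrPart>;
using P2 = Print<CStrPart, DoublePart, IntPart>;
```

P1{}(5) and P2{}(5) both return 1, and P1{}(2.5) returns 2.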
Next, we want to use lambdas to do the same. This works because a lambda is not a function: a lambda expression creates an object of a unique, unnamed class type that implements operator().
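To make that concrete: the lambda's body becomes the operator() of its closure type. In the sketch below, TwiceByHand is a hypothetical hand-written class that behaves like the closure type of twice.

```cpp
#include <type_traits>

// The lambda's body becomes the operator() of its closure type.
constexpr auto twice = [](int i) { return i * 2; };

// A hand-written class with the same behavior (hypothetical name).
struct TwiceByHand {
    constexpr int operator()(int i) const { return i * 2; }
};

// The closure type is an ordinary class type, so it can serve as
// a base class, which is exactly what Print<...> needs.
static_assert(std::is_class_v<decltype(twice)>);
```

twice(21), twice.operator()(21) and TwiceByHand{}(21) all evaluate to 42.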

int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"};
    auto PrintInt = [](int i) { printf("%d\n", i); }; // (1)
    auto PrintCString = [](const char* str) { puts(str); };

    for (auto& v : vars) {
        std::visit(
            Print<decltype(PrintCString), decltype(PrintInt)>{PrintCString, PrintInt}, // (2)
            v);
    }

    return 0;
}

In (1) we define the lambdas we need. In (2) we instantiate the template with our lambdas. This is ugly. Since every lambda has its own unique, unnamed type, we must define the lambdas before we can name their types as template arguments (using decltype). Then we must also pass the lambdas themselves as arguments for aggregate initialization, because in C++17 a lambda's closure type has a deleted default constructor. We are close, but not quite there yet.
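Here is what the braces in (2) rely on, sketched with hypothetical lambdas inside a hypothetical demo function. Since C++17, a class with public base classes and no user-declared constructors can still be an aggregate, so Print<...> is one, and each lambda in the braces initializes the corresponding base.

```cpp
#include <type_traits>

template <class... Ts>
struct Print : Ts... {
    using Ts::operator()...;
};

constexpr bool demo() {
    auto on_int  = [](int x) { return x; };
    auto on_char = [](char) { return -1; };

    using P = Print<decltype(on_int), decltype(on_char)>;

    // Public bases do not prevent Print<...> from being an aggregate.
    static_assert(std::is_aggregate_v<P>);

    // Aggregate initialization: the first lambda initializes the first
    // base, the second lambda the second base.
    P p{on_int, on_char};
    return p(7) == 7 && p('x') == -1;
}
```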
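The <decltype(PrintCString), decltype(PrintInt)> part is really ugly and repetitive. But it is needed: C++17 deduces class template arguments only through deduction guides, and the only implicit guides are generated from constructors, which our aggregate Print does not declare. So, in proper C++ style (think std::make_pair), we write a helper function to work around that.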

template <class... Ts> // (1)
auto MakePrint(Ts... ts) {
    return Print<Ts...>{ts...};
}

int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"};

    for (auto& v : vars) {
        std::visit(
            MakePrint( // (2)
                [](const char* str) { puts(str); },
                [](int i) { printf("%d\n", i); }
                ),
            v);
    }

    return 0;
}

In (1) we define our helper function, which deduces the lambda types and forwards the lambdas to aggregate initialization. In (2) we take advantage of our newfound type deduction to define the lambdas inline. But this is C++17, and we can do better.
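Before moving on, note that MakePrint takes its arguments by value and copies them into the bases. A common refinement, sketched here as a hypothetical MakePrintFwd rather than the article's code, takes forwarding references and decays the deduced types, so temporary lambdas are moved into the bases and lvalue lambdas are copied.

```cpp
#include <type_traits>
#include <utility>

template <class... Ts>
struct Print : Ts... {
    using Ts::operator()...;
};

// Hypothetical forwarding variant of MakePrint. std::decay_t strips
// references and const, so Print always inherits from plain closure types.
template <class... Ts>
constexpr auto MakePrintFwd(Ts&&... ts) {
    return Print<std::decay_t<Ts>...>{std::forward<Ts>(ts)...};
}

constexpr auto sample = MakePrintFwd(
    [](int i) { return i + 1; },
    [](double) { return 0; });
```

sample(1) returns 2 and sample(1.5) returns 0. For captureless lambdas the difference is invisible; it only matters once a lambda captures something expensive to copy.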

C++17 added user-defined deduction guides. These let us tell the compiler how to deduce the template arguments from the initializer, exactly as our helper function did, but without an extra function. Using a suitable deduction guide, the code is as follows.

#include <variant>
#include <cstdio>
#include <vector>

using var_t = std::variant<int, const char*>;

template <class... Ts>
struct Print : Ts... {
    using Ts::operator()...;
};

template <class...Ts> Print(Ts...) -> Print<Ts...>; // (1)

int main() {
    std::vector<var_t> vars = {1, 2, "Hello, World!"};

    for (auto& v : vars) {
        std::visit(
            Print{ // (2)
                [](const char* str) { puts(str); },
                [](int i) { printf("%d\n", i); }
            },
            v);
    }

    return 0;
}

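In (1) we define a deduction guide that does the job of our previous helper function: it tells the compiler to deduce Ts... from the types of the arguments. In (2) we then initialize Print directly, using aggregate initialization, instead of calling a helper function. Done.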
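The same mechanism works for any class template. In this sketch, Box is a hypothetical aggregate that, like Print, declares no constructors. The first guide mirrors the one for Print; the second shows that a guide may also map an argument to a different template argument, here turning string literals into std::string_view. When both guides match, the non-template guide wins, just as with ordinary function overloads.

```cpp
#include <string_view>
#include <type_traits>

template <class T>
struct Box {  // hypothetical aggregate: no constructors declared
    T value;
};

// Without a guide, C++17 cannot deduce T from Box{...}, because
// Box has no constructors from which implicit guides could be generated.
template <class T> Box(T) -> Box<T>;

// A guide may also redirect deduction: string literals become string_views.
Box(const char*) -> Box<std::string_view>;

constexpr Box b1{42};    // deduces Box<int>
constexpr Box b2{"hi"};  // deduces Box<std::string_view>
```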

Now we have fully recreated the original example. Since the name Print no longer describes what the class template does, overloaded is a better name.

template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
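One common use of these two lines, sketched here with hypothetical names, is a catch-all: a generic lambda handles every alternative that has no dedicated overload. For an int, the [](int) overload and the specialization of [](auto) are equally good matches, and overload resolution then prefers the non-template.

```cpp
#include <variant>

template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

using value_t = std::variant<int, double, const char*>;

// The generic lambda catches every alternative without its own overload.
constexpr int classify(const value_t& v) {
    return std::visit(overloaded{
        [](int) { return 1; },
        [](auto) { return 0; }
    }, v);
}
```

classify returns 1 for an int and 0 for a double or a const char*.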

Discussion

Mikhail Matrosov

It would be nice to mention why we changed () to {} when we switched to the deduction guide. E.g. if you add the constructor

overloaded(Ts... ts) : Ts(ts)... {}

then you can use () again. At least I was curious about this.
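For reference, a minimal sketch of the version described in this comment. With a user-provided constructor the class is no longer an aggregate, but C++17 generates an implicit deduction guide from that constructor, so overloaded(...) deduces Ts... without a user-written guide. The variable with_parens is hypothetical.

```cpp
// Each argument initializes the matching base, and the implicit
// deduction guide generated from this constructor deduces Ts...
template <class... Ts>
struct overloaded : Ts... {
    constexpr overloaded(Ts... ts) : Ts(ts)... {}
    using Ts::operator()...;
};

constexpr auto with_parens = overloaded(
    [](int i) { return i; },
    [](const char*) { return -1; });
```

Braces, as in overloaded{...}, keep working too, since they now call the same constructor.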

Tamir Bahar Author

I did mention that we need to use aggregate initialization. Though I now think my reasoning for it is wrong. Also - I never tried making it work with ().

Marcin Misiaszek

This also works without a deduction guide:

template<class... Ts> struct overload : Ts... {
  overload(Ts...) = delete;
  using Ts::operator()...;
};

Ma-XX-oN

I'm confused with this part:

        Print{ // (2)
            [](const char* str) { puts(str); },
            [](int i) { printf("%d\n", i); }
        }

The braces indicate a constructor call, but there is none. There is a non-member helper template though, but it has no body. So, what's happening here?

Ma-XX-oN

It's not a non-member helper template, it's a user defined deduction guide. See this article:

arne-mertz.de/2017/06/class-templa...

Marcin Misiaszek

It also works without a deduction guide, and the 'using' fix is for g++. 'operator()' cannot be ambiguous in visitors, but gcc is very careful with overloading derived functions. In clang we can easily get the same results with minimal code (I checked the output assembly with Compiler Explorer):

template<class... Ts> struct overload : Ts... {
  overload(Ts...) = delete;
};

All that is needed is aggregate initialisation plus variadic templates.

Alexey Voinov

Very cool code transformations. :) Thank you.

void78

Excellent break-down, ta!

Stanimir Lukanov

That's just... weird...
PS: the (...) in this comment is not a fold expression

Stanimir Lukanov

Where is the standard going? :) Nevertheless, the implementation is cool!

Mikhail Matrosov

(2) and (3) in the second code snippet are mixed up.

Tamir Bahar Author

Fixed. Thank you!