Skip to content

Commit

Permalink
Optimize tail call via symmetric transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
BewareMyPower committed Apr 1, 2024
1 parent 6878abb commit 9725dd8
Showing 1 changed file with 32 additions and 29 deletions.
61 changes: 32 additions & 29 deletions coroutines/tail_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,44 +13,34 @@ class Task {
struct promise_type {
auto get_return_object();
auto initial_suspend() const noexcept { return std::suspend_always{}; }
auto return_value(T value) noexcept {
sum += value;
std::cout << "Promise::return_value " << sum << " " << this << std::endl;
}
auto final_suspend() const noexcept {
std::cout << "Promise::final_suspend " << this << std::endl;
return std::suspend_always{};
}
auto return_value(T value) noexcept { sum = value; }
auto final_suspend() const noexcept { return FinalAwaiter{}; }
auto unhandled_exception() const noexcept {}

T sum;
std::coroutine_handle<promise_type> caller;

promise_type() {
std::cout << "Promise(" << sum << ") " << this << std::endl;
}

~promise_type() {
std::cout << "~Promise(" << sum << ") " << this << std::endl;
}
struct FinalAwaiter : std::suspend_always {
std::coroutine_handle<> await_suspend(
std::coroutine_handle<promise_type> caller) noexcept;
};
};

Task(promise_type& promise)
: handle_(std::coroutine_handle<promise_type>::from_promise(promise)) {}
~Task() { handle_.destroy(); }

decltype(auto) operator()() noexcept {
handle_.resume();
return std::move(handle_.promise().sum);
}
decltype(auto) operator()() noexcept { handle_.resume(); }

auto result() const noexcept { return handle_.promise().sum; }

constexpr auto await_ready() const noexcept { return false; }

auto await_suspend(std::coroutine_handle<promise_type> caller) const {
handle_.resume();
return caller;
}
decltype(auto) await_resume() noexcept {
return std::move(handle_.promise().sum);
handle_.promise().caller = caller;
return handle_;
}
decltype(auto) await_resume() noexcept { return handle_.promise().sum; }

private:
std::coroutine_handle<promise_type> handle_;
Expand All @@ -61,15 +51,28 @@ inline auto Task<T>::promise_type::get_return_object() {
return Task{*this};
}

Task<int> f(int initial, int x) {
std::cout << "f(" << initial << ", " << x << ")\n";
template <Addable T>
inline std::coroutine_handle<>
Task<T>::promise_type::FinalAwaiter::await_suspend(
std::coroutine_handle<promise_type> handle) noexcept {
auto caller = handle.promise().caller;
if (caller) {
return caller;
} else {
return std::noop_coroutine();
}
}

Task<long> f(long initial, long x) {
if (x <= 0) {
co_return initial;
}
co_return co_await f(initial + x, x - 1);
}

int main() {
auto coro = f(0, 3);
std::cout << coro() << std::endl;
int main(int argc, char* argv[]) {
auto n = (argc > 1) ? std::stol(argv[1]) : 5;
auto coro = f(0, n);
coro();
std::cout << coro.result() << std::endl;
}

0 comments on commit 9725dd8

Please sign in to comment.