Line data Source code
1 : //
2 : // Copyright (c) 2026 Steve Gerbino
3 : //
4 : // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 : //
7 : // Official repository: https://github.com/cppalliance/capy
8 : //
9 :
10 : #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 : #define BOOST_CAPY_WHEN_ALL_HPP
12 :
13 : #include <boost/capy/detail/config.hpp>
14 : #include <boost/capy/concept/executor.hpp>
15 : #include <boost/capy/concept/io_launchable_task.hpp>
16 : #include <boost/capy/coro.hpp>
17 : #include <boost/capy/ex/executor_ref.hpp>
18 : #include <boost/capy/ex/frame_allocator.hpp>
19 : #include <boost/capy/task.hpp>
20 :
21 : #include <array>
22 : #include <atomic>
23 : #include <exception>
24 : #include <optional>
25 : #include <stop_token>
26 : #include <tuple>
27 : #include <type_traits>
28 : #include <utility>
29 :
30 : namespace boost {
31 : namespace capy {
32 :
33 : namespace detail {
34 :
/** Maps one task result type to a tuple fragment.

    Produces `std::tuple<>` when `T` is `void` and `std::tuple<T>`
    otherwise, so that concatenating the fragments drops every void
    entry.
*/
template<typename T>
using wrap_non_void_t = std::conditional_t<
    !std::is_void_v<T>,
    std::tuple<T>,
    std::tuple<>>;

/** Tuple of the non-void types in `Ts...`, kept in their original order.

    Void-returning tasks do not contribute a value to the result tuple,
    so they are removed here.

    Example: `filter_void_tuple_t<int, void, string>` is
    `std::tuple<int, string>`.
*/
template<typename... Ts>
using filter_void_tuple_t = decltype(
    std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47 :
/** Holds the result of a single task within when_all.

    In when_all the value is written once by the task's runner and
    moved out once when results are extracted.

    @tparam T The task's result type. It must be move-constructible.
    Move-assignability is not required, because the value is
    constructed in place.
*/
template<typename T>
struct result_holder
{
    std::optional<T> value_;

    /** Store the task's result.

        `emplace` is used instead of assignment so that result types
        which are move-constructible but not assignable (for example,
        types with const or reference members) are supported.
    */
    void set(T v)
    {
        value_.emplace(std::move(v));
    }

    /** Move the stored result out.

        Precondition: `set` has been called. Calling this on an empty
        holder dereferences a disengaged optional, which is undefined
        behavior.
    */
    T get() &&
    {
        return std::move(*value_);
    }
};

/** Specialization for void tasks - no value storage needed.
*/
template<>
struct result_holder<void>
{
};
72 :
73 : /** Shared state for when_all operation.
74 :
75 : @tparam Ts The result types of the tasks.
76 : */
77 : template<typename... Ts>
78 : struct when_all_state
79 : {
80 : static constexpr std::size_t task_count = sizeof...(Ts);
81 :
82 : // Completion tracking - when_all waits for all children
83 : std::atomic<std::size_t> remaining_count_;
84 :
85 : // Result storage in input order
86 : std::tuple<result_holder<Ts>...> results_;
87 :
88 : // Runner handles - destroyed in await_resume while allocator is valid
89 : std::array<coro, task_count> runner_handles_{};
90 :
91 : // Exception storage - first error wins, others discarded
92 : std::atomic<bool> has_exception_{false};
93 : std::exception_ptr first_exception_;
94 :
95 : // Stop propagation - on error, request stop for siblings
96 : std::stop_source stop_source_;
97 :
98 : // Connects parent's stop_token to our stop_source
99 : struct stop_callback_fn
100 : {
101 : std::stop_source* source_;
102 1 : void operator()() const { source_->request_stop(); }
103 : };
104 : using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 : std::optional<stop_callback_t> parent_stop_callback_;
106 :
107 : // Parent resumption
108 : coro continuation_;
109 : executor_ref caller_ex_;
110 :
111 24 : when_all_state()
112 24 : : remaining_count_(task_count)
113 : {
114 24 : }
115 :
116 24 : ~when_all_state()
117 : {
118 84 : for(auto h : runner_handles_)
119 60 : if(h)
120 60 : h.destroy();
121 24 : }
122 :
123 : /** Capture an exception (first one wins).
124 : */
125 11 : void capture_exception(std::exception_ptr ep)
126 : {
127 11 : bool expected = false;
128 11 : if(has_exception_.compare_exchange_strong(
129 : expected, true, std::memory_order_relaxed))
130 8 : first_exception_ = ep;
131 11 : }
132 :
133 : /** Signal that a task has completed.
134 :
135 : The last child to complete triggers resumption of the parent.
136 : Dispatch handles thread affinity: resumes inline if on same
137 : thread, otherwise posts to the caller's executor.
138 : */
139 60 : coro signal_completion()
140 : {
141 60 : auto remaining = remaining_count_.fetch_sub(1, std::memory_order_acq_rel);
142 60 : if(remaining == 1)
143 24 : caller_ex_.dispatch(continuation_);
144 60 : return std::noop_coroutine();
145 : }
146 :
147 : };
148 :
149 : /** Wrapper coroutine that intercepts task completion.
150 :
151 : This runner awaits its assigned task and stores the result in
152 : the shared state, or captures the exception and requests stop.
153 : */
154 : template<typename T, typename... Ts>
155 : struct when_all_runner
156 : {
157 : struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
158 : {
159 : when_all_state<Ts...>* state_ = nullptr;
160 : executor_ref ex_;
161 : std::stop_token stop_token_;
162 :
163 60 : when_all_runner get_return_object()
164 : {
165 60 : return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
166 : }
167 :
168 60 : std::suspend_always initial_suspend() noexcept
169 : {
170 60 : return {};
171 : }
172 :
173 60 : auto final_suspend() noexcept
174 : {
175 : struct awaiter
176 : {
177 : promise_type* p_;
178 :
179 60 : bool await_ready() const noexcept
180 : {
181 60 : return false;
182 : }
183 :
184 60 : coro await_suspend(coro) noexcept
185 : {
186 : // Signal completion; last task resumes parent
187 60 : return p_->state_->signal_completion();
188 : }
189 :
190 0 : void await_resume() const noexcept
191 : {
192 0 : }
193 : };
194 60 : return awaiter{this};
195 : }
196 :
197 49 : void return_void()
198 : {
199 49 : }
200 :
201 11 : void unhandled_exception()
202 : {
203 11 : state_->capture_exception(std::current_exception());
204 : // Request stop for sibling tasks
205 11 : state_->stop_source_.request_stop();
206 11 : }
207 :
208 : template<class Awaitable>
209 : struct transform_awaiter
210 : {
211 : std::decay_t<Awaitable> a_;
212 : promise_type* p_;
213 :
214 60 : bool await_ready()
215 : {
216 60 : return a_.await_ready();
217 : }
218 :
219 60 : decltype(auto) await_resume()
220 : {
221 60 : return a_.await_resume();
222 : }
223 :
224 : template<class Promise>
225 60 : auto await_suspend(std::coroutine_handle<Promise> h)
226 : {
227 60 : return a_.await_suspend(h, p_->ex_, p_->stop_token_);
228 : }
229 : };
230 :
231 : template<class Awaitable>
232 60 : auto await_transform(Awaitable&& a)
233 : {
234 : using A = std::decay_t<Awaitable>;
235 : if constexpr (IoAwaitable<A>)
236 : {
237 : return transform_awaiter<Awaitable>{
238 120 : std::forward<Awaitable>(a), this};
239 : }
240 : else
241 : {
242 : static_assert(sizeof(A) == 0, "requires IoAwaitable");
243 : }
244 60 : }
245 : };
246 :
247 : std::coroutine_handle<promise_type> h_;
248 :
249 60 : explicit when_all_runner(std::coroutine_handle<promise_type> h)
250 60 : : h_(h)
251 : {
252 60 : }
253 :
254 : // Enable move for all clang versions - some versions need it
255 : when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
256 :
257 : // Non-copyable
258 : when_all_runner(when_all_runner const&) = delete;
259 : when_all_runner& operator=(when_all_runner const&) = delete;
260 : when_all_runner& operator=(when_all_runner&&) = delete;
261 :
262 60 : auto release() noexcept
263 : {
264 60 : return std::exchange(h_, nullptr);
265 : }
266 : };
267 :
268 : /** Create a runner coroutine for a single task.
269 :
270 : Task is passed directly to ensure proper coroutine frame storage.
271 : */
272 : template<std::size_t Index, typename T, typename... Ts>
273 : when_all_runner<T, Ts...>
274 60 : make_when_all_runner(task<T> inner, when_all_state<Ts...>* state)
275 : {
276 : if constexpr (std::is_void_v<T>)
277 : {
278 : co_await std::move(inner);
279 : }
280 : else
281 : {
282 : std::get<Index>(state->results_).set(co_await std::move(inner));
283 : }
284 120 : }
285 :
286 : /** Internal awaitable that launches all runner coroutines and waits.
287 :
288 : This awaitable is used inside the when_all coroutine to handle
289 : the concurrent execution of child tasks.
290 : */
291 : template<typename... Ts>
292 : class when_all_launcher
293 : {
294 : std::tuple<task<Ts>...>* tasks_;
295 : when_all_state<Ts...>* state_;
296 :
297 : public:
298 24 : when_all_launcher(
299 : std::tuple<task<Ts>...>* tasks,
300 : when_all_state<Ts...>* state)
301 24 : : tasks_(tasks)
302 24 : , state_(state)
303 : {
304 24 : }
305 :
306 24 : bool await_ready() const noexcept
307 : {
308 24 : return sizeof...(Ts) == 0;
309 : }
310 :
311 24 : coro await_suspend(coro continuation, executor_ref caller_ex, std::stop_token parent_token = {})
312 : {
313 24 : state_->continuation_ = continuation;
314 24 : state_->caller_ex_ = caller_ex;
315 :
316 : // Forward parent's stop requests to children
317 24 : if(parent_token.stop_possible())
318 : {
319 8 : state_->parent_stop_callback_.emplace(
320 : parent_token,
321 4 : typename when_all_state<Ts...>::stop_callback_fn{&state_->stop_source_});
322 :
323 4 : if(parent_token.stop_requested())
324 1 : state_->stop_source_.request_stop();
325 : }
326 :
327 : // Launch all tasks concurrently
328 24 : auto token = state_->stop_source_.get_token();
329 48 : [&]<std::size_t... Is>(std::index_sequence<Is...>) {
330 24 : (..., launch_one<Is>(caller_ex, token));
331 24 : }(std::index_sequence_for<Ts...>{});
332 :
333 : // Let signal_completion() handle resumption
334 48 : return std::noop_coroutine();
335 24 : }
336 :
337 24 : void await_resume() const noexcept
338 : {
339 : // Results are extracted by the when_all coroutine from state
340 24 : }
341 :
342 : private:
343 : template<std::size_t I>
344 60 : void launch_one(executor_ref caller_ex, std::stop_token token)
345 : {
346 60 : auto runner = make_when_all_runner<I>(
347 60 : std::move(std::get<I>(*tasks_)), state_);
348 :
349 60 : auto h = runner.release();
350 60 : h.promise().state_ = state_;
351 60 : h.promise().ex_ = caller_ex;
352 60 : h.promise().stop_token_ = token;
353 :
354 60 : coro ch{h};
355 60 : state_->runner_handles_[I] = ch;
356 60 : state_->caller_ex_.dispatch(ch);
357 60 : }
358 : };
359 :
360 : /** Compute the result type for when_all.
361 :
362 : Returns void when all tasks are void (P2300 aligned),
363 : otherwise returns a tuple with void types filtered out.
364 : */
365 : template<typename... Ts>
366 : using when_all_result_t = std::conditional_t<
367 : std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
368 : void,
369 : filter_void_tuple_t<Ts...>>;
370 :
371 : /** Helper to extract a single result, returning empty tuple for void.
372 : This is a separate function to work around a GCC-11 ICE that occurs
373 : when using nested immediately-invoked lambdas with pack expansion.
374 : */
375 : template<std::size_t I, typename... Ts>
376 39 : auto extract_single_result(when_all_state<Ts...>& state)
377 : {
378 : using T = std::tuple_element_t<I, std::tuple<Ts...>>;
379 : if constexpr (std::is_void_v<T>)
380 2 : return std::tuple<>();
381 : else
382 37 : return std::make_tuple(std::move(std::get<I>(state.results_)).get());
383 : }
384 :
385 : /** Extract results from state, filtering void types.
386 : */
387 : template<typename... Ts>
388 15 : auto extract_results(when_all_state<Ts...>& state)
389 : {
390 30 : return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
391 15 : return std::tuple_cat(extract_single_result<Is>(state)...);
392 30 : }(std::index_sequence_for<Ts...>{});
393 : }
394 :
395 : } // namespace detail
396 :
397 : /** Execute multiple tasks concurrently and collect their results.
398 :
399 : Launches all tasks simultaneously and waits for all to complete
400 : before returning. Results are collected in input order. If any
401 : task throws, cancellation is requested for siblings and the first
402 : exception is rethrown after all tasks complete.
403 :
404 : @li All child tasks run concurrently on the caller's executor
405 : @li Results are returned as a tuple in input order
406 : @li Void-returning tasks do not contribute to the result tuple
407 : @li If all tasks return void, `when_all` returns `task<void>`
408 : @li First exception wins; subsequent exceptions are discarded
409 : @li Stop is requested for siblings on first error
410 : @li Completes only after all children have finished
411 :
412 : @par Thread Safety
413 : The returned task must be awaited from a single execution context.
414 : Child tasks execute concurrently but complete through the caller's
415 : executor.
416 :
417 : @param tasks The tasks to execute concurrently. Each task is
418 : consumed (moved-from) when `when_all` is awaited.
419 :
420 : @return A task yielding a tuple of non-void results. Returns
421 : `task<void>` when all input tasks return void.
422 :
423 : @par Example
424 :
425 : @code
426 : task<> example()
427 : {
428 : // Concurrent fetch, results collected in order
429 : auto [user, posts] = co_await when_all(
430 : fetch_user( id ), // task<User>
431 : fetch_posts( id ) // task<std::vector<Post>>
432 : );
433 :
434 : // Void tasks don't contribute to result
435 : co_await when_all(
436 : log_event( "start" ), // task<void>
437 : notify_user( id ) // task<void>
438 : );
439 : // Returns task<void>, no result tuple
440 : }
441 : @endcode
442 :
443 : @see task
444 : */
445 : template<typename... Ts>
446 : [[nodiscard]] task<detail::when_all_result_t<Ts...>>
447 24 : when_all(task<Ts>... tasks)
448 : {
449 : using result_type = detail::when_all_result_t<Ts...>;
450 :
451 : // State is stored in the coroutine frame, using the frame allocator
452 : detail::when_all_state<Ts...> state;
453 :
454 : // Store tasks in the frame
455 : std::tuple<task<Ts>...> task_tuple(std::move(tasks)...);
456 :
457 : // Launch all tasks and wait for completion
458 : co_await detail::when_all_launcher<Ts...>(&task_tuple, &state);
459 :
460 : // Propagate first exception if any.
461 : // Safe without explicit acquire: capture_exception() is sequenced-before
462 : // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
463 : // last task's decrement that resumes this coroutine.
464 : if(state.first_exception_)
465 : std::rethrow_exception(state.first_exception_);
466 :
467 : // Extract and return results
468 : if constexpr (std::is_void_v<result_type>)
469 : co_return;
470 : else
471 : co_return detail::extract_results(state);
472 48 : }
473 :
474 : /// Compute the result type of `when_all` for the given task types.
475 : template<typename... Ts>
476 : using when_all_result_type = detail::when_all_result_t<Ts...>;
477 :
478 : } // namespace capy
479 : } // namespace boost
480 :
481 : #endif
|