libs/capy/include/boost/capy/when_all.hpp

98.9% Lines (91/92) 91.4% Functions (298/326) 95.2% Branches (20/21)
libs/capy/include/boost/capy/when_all.hpp
Line Branch Hits Source Code
1 //
2 // Copyright (c) 2026 Steve Gerbino
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 // Official repository: https://github.com/cppalliance/capy
8 //
9
10 #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 #define BOOST_CAPY_WHEN_ALL_HPP
12
13 #include <boost/capy/detail/config.hpp>
14 #include <boost/capy/concept/executor.hpp>
15 #include <boost/capy/concept/io_launchable_task.hpp>
16 #include <boost/capy/coro.hpp>
17 #include <boost/capy/ex/executor_ref.hpp>
18 #include <boost/capy/ex/frame_allocator.hpp>
19 #include <boost/capy/task.hpp>
20
21 #include <array>
22 #include <atomic>
23 #include <exception>
24 #include <optional>
25 #include <stop_token>
26 #include <tuple>
27 #include <type_traits>
28 #include <utility>
29
30 namespace boost {
31 namespace capy {
32
33 namespace detail {
34
/** Maps one task result type to a tuple fragment.

    A `void` result contributes nothing and maps to `std::tuple<>`;
    any other type `T` maps to the one-element `std::tuple<T>`.
*/
template<typename T>
using wrap_non_void_t = std::conditional_t<
    ! std::is_void_v<T>,
    std::tuple<T>,
    std::tuple<>>;

/** The tuple of all non-void types in `Ts...`, in their original order.

    Every type is wrapped with `wrap_non_void_t` and the fragments are
    concatenated, so `void` entries simply vanish.

    Example: `filter_void_tuple_t<int, void, std::string>` is
    `std::tuple<int, std::string>`.
*/
template<typename... Ts>
using filter_void_tuple_t = decltype(
    std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47
/** Storage slot for the value produced by one child task of `when_all`.

    The slot starts empty and is filled by `set`. `get` moves the stored
    value out; calling it before `set` dereferences an empty
    `std::optional`, so `set` must have been called first.

    @tparam T The child task's result type.
*/
template<typename T>
struct result_holder
{
    std::optional<T> value_;

    /// Store `v` into the slot.
    void set(T v)
    {
        value_ = std::move(v);
    }

    /// Move the stored value out. Precondition: `set` has been called.
    T get() &&
    {
        return *std::move(value_);
    }
};

/** Slot for a `void` child task: there is no value to hold.
*/
template<>
struct result_holder<void>
{
};
72
73 /** Shared state for when_all operation.
74
75 @tparam Ts The result types of the tasks.
76 */
77 template<typename... Ts>
78 struct when_all_state
79 {
80 static constexpr std::size_t task_count = sizeof...(Ts);
81
82 // Completion tracking - when_all waits for all children
83 std::atomic<std::size_t> remaining_count_;
84
85 // Result storage in input order
86 std::tuple<result_holder<Ts>...> results_;
87
88 // Runner handles - destroyed in await_resume while allocator is valid
89 std::array<coro, task_count> runner_handles_{};
90
91 // Exception storage - first error wins, others discarded
92 std::atomic<bool> has_exception_{false};
93 std::exception_ptr first_exception_;
94
95 // Stop propagation - on error, request stop for siblings
96 std::stop_source stop_source_;
97
98 // Connects parent's stop_token to our stop_source
99 struct stop_callback_fn
100 {
101 std::stop_source* source_;
102 1 void operator()() const { source_->request_stop(); }
103 };
104 using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 std::optional<stop_callback_t> parent_stop_callback_;
106
107 // Parent resumption
108 coro continuation_;
109 executor_ref caller_ex_;
110
111 24 when_all_state()
112
1/1
✓ Branch 5 taken 24 times.
24 : remaining_count_(task_count)
113 {
114 24 }
115
116 24 ~when_all_state()
117 {
118
2/2
✓ Branch 0 taken 60 times.
✓ Branch 1 taken 24 times.
84 for(auto h : runner_handles_)
119
1/2
✓ Branch 1 taken 60 times.
✗ Branch 2 not taken.
60 if(h)
120 60 h.destroy();
121 24 }
122
123 /** Capture an exception (first one wins).
124 */
125 11 void capture_exception(std::exception_ptr ep)
126 {
127 11 bool expected = false;
128
2/2
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 3 times.
11 if(has_exception_.compare_exchange_strong(
129 expected, true, std::memory_order_relaxed))
130 8 first_exception_ = ep;
131 11 }
132
133 /** Signal that a task has completed.
134
135 The last child to complete triggers resumption of the parent.
136 Dispatch handles thread affinity: resumes inline if on same
137 thread, otherwise posts to the caller's executor.
138 */
139 60 coro signal_completion()
140 {
141 60 auto remaining = remaining_count_.fetch_sub(1, std::memory_order_acq_rel);
142
2/2
✓ Branch 0 taken 24 times.
✓ Branch 1 taken 36 times.
60 if(remaining == 1)
143 24 caller_ex_.dispatch(continuation_);
144 60 return std::noop_coroutine();
145 }
146
147 };
148
149 /** Wrapper coroutine that intercepts task completion.
150
151 This runner awaits its assigned task and stores the result in
152 the shared state, or captures the exception and requests stop.
153 */
154 template<typename T, typename... Ts>
155 struct when_all_runner
156 {
157 struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
158 {
159 when_all_state<Ts...>* state_ = nullptr;
160 executor_ref ex_;
161 std::stop_token stop_token_;
162
163 60 when_all_runner get_return_object()
164 {
165 60 return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
166 }
167
168 60 std::suspend_always initial_suspend() noexcept
169 {
170 60 return {};
171 }
172
173 60 auto final_suspend() noexcept
174 {
175 struct awaiter
176 {
177 promise_type* p_;
178
179 bool await_ready() const noexcept
180 {
181 return false;
182 }
183
184 coro await_suspend(coro) noexcept
185 {
186 // Signal completion; last task resumes parent
187 return p_->state_->signal_completion();
188 }
189
190 void await_resume() const noexcept
191 {
192 }
193 };
194 60 return awaiter{this};
195 }
196
197 49 void return_void()
198 {
199 49 }
200
201 11 void unhandled_exception()
202 {
203 11 state_->capture_exception(std::current_exception());
204 // Request stop for sibling tasks
205 11 state_->stop_source_.request_stop();
206 11 }
207
208 template<class Awaitable>
209 struct transform_awaiter
210 {
211 std::decay_t<Awaitable> a_;
212 promise_type* p_;
213
214 60 bool await_ready()
215 {
216 60 return a_.await_ready();
217 }
218
219 60 decltype(auto) await_resume()
220 {
221 60 return a_.await_resume();
222 }
223
224 template<class Promise>
225 60 auto await_suspend(std::coroutine_handle<Promise> h)
226 {
227
1/1
✓ Branch 3 taken 54 times.
60 return a_.await_suspend(h, p_->ex_, p_->stop_token_);
228 }
229 };
230
231 template<class Awaitable>
232 60 auto await_transform(Awaitable&& a)
233 {
234 using A = std::decay_t<Awaitable>;
235 if constexpr (IoAwaitable<A>)
236 {
237 return transform_awaiter<Awaitable>{
238 120 std::forward<Awaitable>(a), this};
239 }
240 else
241 {
242 static_assert(sizeof(A) == 0, "requires IoAwaitable");
243 }
244 60 }
245 };
246
247 std::coroutine_handle<promise_type> h_;
248
249 60 explicit when_all_runner(std::coroutine_handle<promise_type> h)
250 60 : h_(h)
251 {
252 60 }
253
254 // Enable move for all clang versions - some versions need it
255 when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
256
257 // Non-copyable
258 when_all_runner(when_all_runner const&) = delete;
259 when_all_runner& operator=(when_all_runner const&) = delete;
260 when_all_runner& operator=(when_all_runner&&) = delete;
261
262 60 auto release() noexcept
263 {
264 60 return std::exchange(h_, nullptr);
265 }
266 };
267
268 /** Create a runner coroutine for a single task.
269
270 Task is passed directly to ensure proper coroutine frame storage.
271 */
272 template<std::size_t Index, typename T, typename... Ts>
273 when_all_runner<T, Ts...>
274
1/1
✓ Branch 1 taken 60 times.
60 make_when_all_runner(task<T> inner, when_all_state<Ts...>* state)
275 {
276 if constexpr (std::is_void_v<T>)
277 {
278 co_await std::move(inner);
279 }
280 else
281 {
282 std::get<Index>(state->results_).set(co_await std::move(inner));
283 }
284 120 }
285
286 /** Internal awaitable that launches all runner coroutines and waits.
287
288 This awaitable is used inside the when_all coroutine to handle
289 the concurrent execution of child tasks.
290 */
291 template<typename... Ts>
292 class when_all_launcher
293 {
294 std::tuple<task<Ts>...>* tasks_;
295 when_all_state<Ts...>* state_;
296
297 public:
298 24 when_all_launcher(
299 std::tuple<task<Ts>...>* tasks,
300 when_all_state<Ts...>* state)
301 24 : tasks_(tasks)
302 24 , state_(state)
303 {
304 24 }
305
306 24 bool await_ready() const noexcept
307 {
308 24 return sizeof...(Ts) == 0;
309 }
310
311 24 coro await_suspend(coro continuation, executor_ref caller_ex, std::stop_token parent_token = {})
312 {
313 24 state_->continuation_ = continuation;
314 24 state_->caller_ex_ = caller_ex;
315
316 // Forward parent's stop requests to children
317
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 20 times.
24 if(parent_token.stop_possible())
318 {
319 8 state_->parent_stop_callback_.emplace(
320 parent_token,
321 4 typename when_all_state<Ts...>::stop_callback_fn{&state_->stop_source_});
322
323
2/2
✓ Branch 1 taken 1 time.
✓ Branch 2 taken 3 times.
4 if(parent_token.stop_requested())
324 1 state_->stop_source_.request_stop();
325 }
326
327 // Launch all tasks concurrently
328 24 auto token = state_->stop_source_.get_token();
329 [&]<std::size_t... Is>(std::index_sequence<Is...>) {
330 (..., launch_one<Is>(caller_ex, token));
331
1/1
✓ Branch 1 taken 24 times.
24 }(std::index_sequence_for<Ts...>{});
332
333 // Let signal_completion() handle resumption
334 48 return std::noop_coroutine();
335 24 }
336
337 24 void await_resume() const noexcept
338 {
339 // Results are extracted by the when_all coroutine from state
340 24 }
341
342 private:
343 template<std::size_t I>
344 60 void launch_one(executor_ref caller_ex, std::stop_token token)
345 {
346
1/1
✓ Branch 2 taken 60 times.
60 auto runner = make_when_all_runner<I>(
347 60 std::move(std::get<I>(*tasks_)), state_);
348
349 60 auto h = runner.release();
350 60 h.promise().state_ = state_;
351 60 h.promise().ex_ = caller_ex;
352 60 h.promise().stop_token_ = token;
353
354 60 coro ch{h};
355 60 state_->runner_handles_[I] = ch;
356
1/1
✓ Branch 1 taken 60 times.
60 state_->caller_ex_.dispatch(ch);
357 60 }
358 };
359
360 /** Compute the result type for when_all.
361
362 Returns void when all tasks are void (P2300 aligned),
363 otherwise returns a tuple with void types filtered out.
364 */
365 template<typename... Ts>
366 using when_all_result_t = std::conditional_t<
367 std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
368 void,
369 filter_void_tuple_t<Ts...>>;
370
371 /** Helper to extract a single result, returning empty tuple for void.
372 This is a separate function to work around a GCC-11 ICE that occurs
373 when using nested immediately-invoked lambdas with pack expansion.
374 */
375 template<std::size_t I, typename... Ts>
376 39 auto extract_single_result(when_all_state<Ts...>& state)
377 {
378 using T = std::tuple_element_t<I, std::tuple<Ts...>>;
379 if constexpr (std::is_void_v<T>)
380 2 return std::tuple<>();
381 else
382
1/1
✓ Branch 4 taken 37 times.
37 return std::make_tuple(std::move(std::get<I>(state.results_)).get());
383 }
384
385 /** Extract results from state, filtering void types.
386 */
387 template<typename... Ts>
388 15 auto extract_results(when_all_state<Ts...>& state)
389 {
390 15 return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
391 return std::tuple_cat(extract_single_result<Is>(state)...);
392
1/1
✓ Branch 1 taken 15 times.
30 }(std::index_sequence_for<Ts...>{});
393 }
394
395 } // namespace detail
396
397 /** Execute multiple tasks concurrently and collect their results.
398
399 Launches all tasks simultaneously and waits for all to complete
400 before returning. Results are collected in input order. If any
401 task throws, cancellation is requested for siblings and the first
402 exception is rethrown after all tasks complete.
403
404 @li All child tasks run concurrently on the caller's executor
405 @li Results are returned as a tuple in input order
406 @li Void-returning tasks do not contribute to the result tuple
407 @li If all tasks return void, `when_all` returns `task<void>`
408 @li First exception wins; subsequent exceptions are discarded
409 @li Stop is requested for siblings on first error
410 @li Completes only after all children have finished
411
412 @par Thread Safety
413 The returned task must be awaited from a single execution context.
414 Child tasks execute concurrently but complete through the caller's
415 executor.
416
417 @param tasks The tasks to execute concurrently. Each task is
418 consumed (moved-from) when `when_all` is awaited.
419
420 @return A task yielding a tuple of non-void results. Returns
421 `task<void>` when all input tasks return void.
422
423 @par Example
424
425 @code
426 task<> example()
427 {
428 // Concurrent fetch, results collected in order
429 auto [user, posts] = co_await when_all(
430 fetch_user( id ), // task<User>
431 fetch_posts( id ) // task<std::vector<Post>>
432 );
433
434 // Void tasks don't contribute to result
435 co_await when_all(
436 log_event( "start" ), // task<void>
437 notify_user( id ) // task<void>
438 );
439 // Returns task<void>, no result tuple
440 }
441 @endcode
442
443 @see task
444 */
445 template<typename... Ts>
446 [[nodiscard]] task<detail::when_all_result_t<Ts...>>
447
1/1
✓ Branch 1 taken 24 times.
24 when_all(task<Ts>... tasks)
448 {
449 using result_type = detail::when_all_result_t<Ts...>;
450
451 // State is stored in the coroutine frame, using the frame allocator
452 detail::when_all_state<Ts...> state;
453
454 // Store tasks in the frame
455 std::tuple<task<Ts>...> task_tuple(std::move(tasks)...);
456
457 // Launch all tasks and wait for completion
458 co_await detail::when_all_launcher<Ts...>(&task_tuple, &state);
459
460 // Propagate first exception if any.
461 // Safe without explicit acquire: capture_exception() is sequenced-before
462 // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
463 // last task's decrement that resumes this coroutine.
464 if(state.first_exception_)
465 std::rethrow_exception(state.first_exception_);
466
467 // Extract and return results
468 if constexpr (std::is_void_v<result_type>)
469 co_return;
470 else
471 co_return detail::extract_results(state);
472 48 }
473
474 /// Compute the result type of `when_all` for the given task types.
475 template<typename... Ts>
476 using when_all_result_type = detail::when_all_result_t<Ts...>;
477
478 } // namespace capy
479 } // namespace boost
480
481 #endif
482