Line data Source code
1 : //
2 : // Copyright (c) 2025 Vinnie Falco (vinnie dot falco at gmail dot com)
3 : //
4 : // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 : //
7 : // Official repository: https://github.com/cppalliance/capy
8 : //
9 :
10 : #ifndef BOOST_CAPY_EX_IO_AWAITABLE_SUPPORT_HPP
11 : #define BOOST_CAPY_EX_IO_AWAITABLE_SUPPORT_HPP
12 :
#include <boost/capy/detail/config.hpp>
#include <boost/capy/coro.hpp>
#include <boost/capy/ex/executor_ref.hpp>
#include <boost/capy/ex/frame_allocator.hpp>
#include <boost/capy/ex/this_coro.hpp>

#include <coroutine>
#include <cstddef>
#include <memory_resource>
#include <stop_token>
#include <type_traits>
#include <utility>
24 :
25 : namespace boost {
26 : namespace capy {
27 :
28 : /** CRTP mixin that adds I/O awaitable support to a promise type.
29 :
30 : Inherit from this class to enable these capabilities in your coroutine:
31 :
32 : 1. **Frame allocation** — The mixin provides `operator new/delete` that
33 : use the thread-local frame allocator set by `run_async`.
34 :
35 : 2. **Frame allocator storage** — The mixin stores the allocator pointer
36 : for propagation to child tasks.
37 :
38 : 3. **Stop token storage** — The mixin stores the `std::stop_token`
39 : that was passed when your coroutine was awaited.
40 :
41 : 4. **Stop token access** — Coroutine code can retrieve the token via
42 : `co_await this_coro::stop_token`.
43 :
44 : 5. **Executor storage** — The mixin stores the `executor_ref`
45 : that this coroutine is bound to.
46 :
47 : 6. **Executor access** — Coroutine code can retrieve the executor via
48 : `co_await this_coro::executor`.
49 :
50 : @tparam Derived The derived promise type (CRTP pattern).
51 :
52 : @par Basic Usage
53 :
54 : For coroutines that need to access their stop token or executor:
55 :
56 : @code
57 : struct my_task
58 : {
59 : struct promise_type : io_awaitable_support<promise_type>
60 : {
61 : my_task get_return_object();
62 : std::suspend_always initial_suspend() noexcept;
63 : std::suspend_always final_suspend() noexcept;
64 : void return_void();
65 : void unhandled_exception();
66 : };
67 :
68 : // ... awaitable interface ...
69 : };
70 :
71 : my_task example()
72 : {
73 : auto token = co_await this_coro::stop_token;
74 : auto ex = co_await this_coro::executor;
75 : // Use token and ex...
76 : }
77 : @endcode
78 :
79 : @par Custom Awaitable Transformation
80 :
81 : If your promise needs to transform awaitables (e.g., for affinity or
82 : logging), override `transform_awaitable` instead of `await_transform`:
83 :
84 : @code
85 : struct promise_type : io_awaitable_support<promise_type>
86 : {
87 : template<typename A>
88 : auto transform_awaitable(A&& a)
89 : {
90 : // Your custom transformation logic
91 : return std::forward<A>(a);
92 : }
93 : };
94 : @endcode
95 :
96 : The mixin's `await_transform` intercepts @ref this_coro::stop_token_tag and
97 : @ref this_coro::executor_tag, then delegates all other awaitables to your
98 : `transform_awaitable`.
99 :
100 : @par Making Your Coroutine an IoAwaitable
101 :
102 : The mixin handles the "inside the coroutine" part—accessing the token
103 : and executor. To receive these when your coroutine is awaited (satisfying
104 : @ref IoAwaitable), implement the `await_suspend` overload on your
105 : coroutine return type:
106 :
107 : @code
108 : struct my_task
109 : {
110 : struct promise_type : io_awaitable_support<promise_type> { ... };
111 :
112 : std::coroutine_handle<promise_type> h_;
113 :
114 : // IoAwaitable await_suspend receives and stores the token and executor
115 : coro await_suspend(coro cont, executor_ref ex, std::stop_token token)
116 : {
117 : h_.promise().set_stop_token(token);
118 : h_.promise().set_executor(ex);
119 : // ... rest of suspend logic ...
120 : }
121 : };
122 : @endcode
123 :
124 : @par Thread Safety
125 : The stop token and executor are stored during `await_suspend` and read
126 : during `co_await this_coro::stop_token` or `co_await this_coro::executor`.
127 : These occur on the same logical thread of execution, so no synchronization
128 : is required.
129 :
130 : @see this_coro::stop_token
131 : @see this_coro::executor
132 : @see IoAwaitable
133 : */
134 : template<typename Derived>
135 : class io_awaitable_support
136 : {
137 : executor_ref executor_;
138 : std::stop_token stop_token_;
139 : std::pmr::memory_resource* alloc_ = nullptr;
140 : executor_ref caller_ex_;
141 : mutable coro cont_{nullptr};
142 :
143 : public:
144 : //----------------------------------------------------------
145 : // Frame allocation support
146 : //----------------------------------------------------------
147 :
148 : private:
149 : static constexpr std::size_t ptr_alignment = alignof(void*);
150 :
151 : static std::size_t
152 5938 : aligned_offset(std::size_t n) noexcept
153 : {
154 5938 : return (n + ptr_alignment - 1) & ~(ptr_alignment - 1);
155 : }
156 :
157 : public:
158 : /** Allocate a coroutine frame.
159 :
160 : Uses the thread-local frame allocator set by run_async.
161 : Falls back to default memory resource if not set.
162 : Stores the allocator pointer at the end of each frame for
163 : correct deallocation even when TLS changes.
164 : */
165 : static void*
166 2969 : operator new(std::size_t size)
167 : {
168 2969 : auto* mr = current_frame_allocator();
169 2969 : if(!mr)
170 83 : mr = std::pmr::get_default_resource();
171 :
172 : // Allocate extra space for memory_resource pointer
173 2969 : std::size_t ptr_offset = aligned_offset(size);
174 2969 : std::size_t total = ptr_offset + sizeof(std::pmr::memory_resource*);
175 2969 : void* raw = mr->allocate(total, alignof(std::max_align_t));
176 :
177 : // Store the allocator pointer at the end
178 2969 : auto* ptr_loc = reinterpret_cast<std::pmr::memory_resource**>(
179 : static_cast<char*>(raw) + ptr_offset);
180 2969 : *ptr_loc = mr;
181 :
182 2969 : return raw;
183 : }
184 :
185 : /** Deallocate a coroutine frame.
186 :
187 : Reads the allocator pointer stored at the end of the frame
188 : to ensure correct deallocation regardless of current TLS.
189 : */
190 : static void
191 2969 : operator delete(void* ptr, std::size_t size)
192 : {
193 : // Read the allocator pointer from the end of the frame
194 2969 : std::size_t ptr_offset = aligned_offset(size);
195 2969 : auto* ptr_loc = reinterpret_cast<std::pmr::memory_resource**>(
196 : static_cast<char*>(ptr) + ptr_offset);
197 2969 : auto* mr = *ptr_loc;
198 :
199 2969 : std::size_t total = ptr_offset + sizeof(std::pmr::memory_resource*);
200 2969 : mr->deallocate(ptr, total, alignof(std::max_align_t));
201 2969 : }
202 :
203 2969 : ~io_awaitable_support()
204 : {
205 2969 : if (cont_)
206 1 : cont_.destroy();
207 2969 : }
208 :
209 : /** Store a frame allocator for later retrieval.
210 :
211 : Call this from initial_suspend to capture the current
212 : TLS allocator for propagation to child tasks.
213 :
214 : @param alloc The allocator to store.
215 : */
216 : void
217 2965 : set_frame_allocator(std::pmr::memory_resource* alloc) noexcept
218 : {
219 2965 : alloc_ = alloc;
220 2965 : }
221 :
222 : /** Return the stored frame allocator.
223 :
224 : @return The allocator, or nullptr if none was set.
225 : */
226 : std::pmr::memory_resource*
227 19521 : frame_allocator() const noexcept
228 : {
229 19521 : return alloc_;
230 : }
231 :
232 : //----------------------------------------------------------
233 : // Continuation support
234 : //----------------------------------------------------------
235 :
236 : /** Store continuation and caller's executor for completion dispatch.
237 :
238 : Call this from your coroutine type's `await_suspend` overload to
239 : set up the completion path. On completion, the coroutine will
240 : resume the continuation, dispatching through the caller's executor
241 : if it differs from this coroutine's executor.
242 :
243 : @param cont The continuation to resume on completion.
244 : @param caller_ex The caller's executor for completion dispatch.
245 : */
246 2915 : void set_continuation(coro cont, executor_ref caller_ex) noexcept
247 : {
248 2915 : cont_ = cont;
249 2915 : caller_ex_ = caller_ex;
250 2915 : }
251 :
252 : /** Return the handle to resume on completion with dispatch-awareness.
253 :
254 : If no continuation was set, returns `std::noop_coroutine()`.
255 : If the coroutine's executor matches the caller's executor, returns
256 : the continuation directly for symmetric transfer.
257 : Otherwise, dispatches through the caller's executor and returns
258 : `std::noop_coroutine()`.
259 :
260 : Call this from your `final_suspend` awaiter's `await_suspend`.
261 :
262 : @return A coroutine handle for symmetric transfer.
263 : */
264 2957 : coro complete() const noexcept
265 : {
266 2957 : if(!cont_)
267 43 : return std::noop_coroutine();
268 2914 : if(executor_ == caller_ex_)
269 2914 : return std::exchange(cont_, nullptr);
270 0 : caller_ex_.dispatch(std::exchange(cont_, nullptr));
271 0 : return std::noop_coroutine();
272 : }
273 :
274 : /** Store a stop token for later retrieval.
275 :
276 : Call this from your coroutine type's `await_suspend`
277 : overload to make the token available via
278 : `co_await this_coro::stop_token`.
279 :
280 : @param token The stop token to store.
281 : */
282 2917 : void set_stop_token(std::stop_token token) noexcept
283 : {
284 2917 : stop_token_ = token;
285 2917 : }
286 :
287 : /** Return the stored stop token.
288 :
289 : @return The stop token, or a default-constructed token if none was set.
290 : */
291 1846 : std::stop_token const& stop_token() const noexcept
292 : {
293 1846 : return stop_token_;
294 : }
295 :
296 : /** Store an executor for later retrieval.
297 :
298 : Call this from your coroutine type's `await_suspend`
299 : overload to make the executor available via
300 : `co_await this_coro::executor`.
301 :
302 : @param ex The executor to store.
303 : */
304 2917 : void set_executor(executor_ref ex) noexcept
305 : {
306 2917 : executor_ = ex;
307 2917 : }
308 :
309 : /** Return the stored executor.
310 :
311 : @return The executor, or a default-constructed executor_ref if none was set.
312 : */
313 1846 : executor_ref executor() const noexcept
314 : {
315 1846 : return executor_;
316 : }
317 :
318 : /** Transform an awaitable before co_await.
319 :
320 : Override this in your derived promise type to customize how
321 : awaitables are transformed. The default implementation passes
322 : the awaitable through unchanged.
323 :
324 : @param a The awaitable expression from `co_await a`.
325 :
326 : @return The transformed awaitable.
327 : */
328 : template<typename A>
329 : decltype(auto) transform_awaitable(A&& a)
330 : {
331 : return std::forward<A>(a);
332 : }
333 :
334 : /** Intercept co_await expressions.
335 :
336 : This function handles @ref this_coro::stop_token_tag and
337 : @ref this_coro::executor_tag specially, returning an awaiter that
338 : yields the stored value. All other awaitables are delegated to
339 : @ref transform_awaitable.
340 :
341 : @param t The awaited expression.
342 :
343 : @return An awaiter for the expression.
344 : */
345 : template<typename T>
346 6895 : auto await_transform(T&& t)
347 : {
348 : if constexpr (std::is_same_v<std::decay_t<T>, this_coro::stop_token_tag>)
349 : {
350 : struct awaiter
351 : {
352 : std::stop_token token_;
353 :
354 14 : bool await_ready() const noexcept
355 : {
356 14 : return true;
357 : }
358 :
359 1 : void await_suspend(coro) const noexcept
360 : {
361 1 : }
362 :
363 13 : std::stop_token await_resume() const noexcept
364 : {
365 13 : return token_;
366 : }
367 : };
368 15 : return awaiter{stop_token_};
369 : }
370 : else if constexpr (std::is_same_v<std::decay_t<T>, this_coro::executor_tag>)
371 : {
372 : struct awaiter
373 : {
374 : executor_ref executor_;
375 :
376 2 : bool await_ready() const noexcept
377 : {
378 2 : return true;
379 : }
380 :
381 1 : void await_suspend(coro) const noexcept
382 : {
383 1 : }
384 :
385 1 : executor_ref await_resume() const noexcept
386 : {
387 1 : return executor_;
388 : }
389 : };
390 3 : return awaiter{executor_};
391 : }
392 : else
393 : {
394 5629 : return static_cast<Derived*>(this)->transform_awaitable(
395 6877 : std::forward<T>(t));
396 : }
397 : }
398 : };
399 :
400 : } // namespace capy
401 : } // namespace boost
402 :
403 : #endif
|