1  
//
1  
//
2  
// Copyright (c) 2025 Vinnie Falco (vinnie.falco@gmail.com)
2  
// Copyright (c) 2025 Vinnie Falco (vinnie.falco@gmail.com)
3  
//
3  
//
4  
// Distributed under the Boost Software License, Version 1.0. (See accompanying
4  
// Distributed under the Boost Software License, Version 1.0. (See accompanying
5  
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
5  
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6  
//
6  
//
7  
// Official repository: https://github.com/cppalliance/capy
7  
// Official repository: https://github.com/cppalliance/capy
8  
//
8  
//
9  

9  

10  
#include "src/ex/detail/strand_queue.hpp"
10  
#include "src/ex/detail/strand_queue.hpp"
11  
#include <boost/capy/ex/detail/strand_service.hpp>
11  
#include <boost/capy/ex/detail/strand_service.hpp>
12  
#include <boost/capy/continuation.hpp>
12  
#include <boost/capy/continuation.hpp>
13  
#include <atomic>
13  
#include <atomic>
14  
#include <coroutine>
14  
#include <coroutine>
15  
#include <mutex>
15  
#include <mutex>
16  
#include <thread>
16  
#include <thread>
17  
#include <utility>
17  
#include <utility>
18  

18  

19  
namespace boost {
19  
namespace boost {
20  
namespace capy {
20  
namespace capy {
21  
namespace detail {
21  
namespace detail {
22  

22  

23  
//----------------------------------------------------------
23  
//----------------------------------------------------------
24  

24  

25  
/** Implementation state for a strand.
25  
/** Implementation state for a strand.
26  

26  

27  
    Each strand_impl provides serialization for coroutines
27  
    Each strand_impl provides serialization for coroutines
28  
    dispatched through strands that share it.
28  
    dispatched through strands that share it.
29  
*/
29  
*/
30  
// Sentinel stored in cached_frame_ after shutdown to prevent
30  
// Sentinel stored in cached_frame_ after shutdown to prevent
31  
// in-flight invokers from repopulating a freed cache slot.
31  
// in-flight invokers from repopulating a freed cache slot.
32  
inline void* const kCacheClosed = reinterpret_cast<void*>(1);
32  
inline void* const kCacheClosed = reinterpret_cast<void*>(1);
33  

33  

34  
struct strand_impl
34  
struct strand_impl
35  
{
35  
{
36  
    std::mutex mutex_;
36  
    std::mutex mutex_;
37  
    strand_queue pending_;
37  
    strand_queue pending_;
38  
    bool locked_ = false;
38  
    bool locked_ = false;
39  
    std::atomic<std::thread::id> dispatch_thread_{};
39  
    std::atomic<std::thread::id> dispatch_thread_{};
40  
    std::atomic<void*> cached_frame_{nullptr};
40  
    std::atomic<void*> cached_frame_{nullptr};
41  
};
41  
};
42  

42  

43  
//----------------------------------------------------------
43  
//----------------------------------------------------------
44  

44  

45  
/** Invoker coroutine for strand dispatch.
45  
/** Invoker coroutine for strand dispatch.
46  

46  

47  
    Uses custom allocator to recycle frame - one allocation
47  
    Uses custom allocator to recycle frame - one allocation
48  
    per strand_impl lifetime, stored in trailer for recovery.
48  
    per strand_impl lifetime, stored in trailer for recovery.
49  
*/
49  
*/
50  
struct strand_invoker
50  
struct strand_invoker
51  
{
51  
{
52  
    struct promise_type
52  
    struct promise_type
53  
    {
53  
    {
54  
        // Used to post the invoker through the inner executor.
54  
        // Used to post the invoker through the inner executor.
55  
        // Lives in the coroutine frame (heap-allocated), so has
55  
        // Lives in the coroutine frame (heap-allocated), so has
56  
        // a stable address for the duration of the queue residency.
56  
        // a stable address for the duration of the queue residency.
57  
        continuation self_;
57  
        continuation self_;
58  

58  

59  
        void* operator new(std::size_t n, strand_impl& impl)
59  
        void* operator new(std::size_t n, strand_impl& impl)
60  
        {
60  
        {
61  
            constexpr auto A = alignof(strand_impl*);
61  
            constexpr auto A = alignof(strand_impl*);
62  
            std::size_t padded = (n + A - 1) & ~(A - 1);
62  
            std::size_t padded = (n + A - 1) & ~(A - 1);
63  
            std::size_t total = padded + sizeof(strand_impl*);
63  
            std::size_t total = padded + sizeof(strand_impl*);
64  

64  

65  
            void* p = impl.cached_frame_.exchange(
65  
            void* p = impl.cached_frame_.exchange(
66  
                nullptr, std::memory_order_acquire);
66  
                nullptr, std::memory_order_acquire);
67  
            if(!p || p == kCacheClosed)
67  
            if(!p || p == kCacheClosed)
68  
                p = ::operator new(total);
68  
                p = ::operator new(total);
69  

69  

70  
            // Trailer lets delete recover impl
70  
            // Trailer lets delete recover impl
71  
            *reinterpret_cast<strand_impl**>(
71  
            *reinterpret_cast<strand_impl**>(
72  
                static_cast<char*>(p) + padded) = &impl;
72  
                static_cast<char*>(p) + padded) = &impl;
73  
            return p;
73  
            return p;
74  
        }
74  
        }
75  

75  

76  
        void operator delete(void* p, std::size_t n) noexcept
76  
        void operator delete(void* p, std::size_t n) noexcept
77  
        {
77  
        {
78  
            constexpr auto A = alignof(strand_impl*);
78  
            constexpr auto A = alignof(strand_impl*);
79  
            std::size_t padded = (n + A - 1) & ~(A - 1);
79  
            std::size_t padded = (n + A - 1) & ~(A - 1);
80  

80  

81  
            auto* impl = *reinterpret_cast<strand_impl**>(
81  
            auto* impl = *reinterpret_cast<strand_impl**>(
82  
                static_cast<char*>(p) + padded);
82  
                static_cast<char*>(p) + padded);
83  

83  

84  
            void* expected = nullptr;
84  
            void* expected = nullptr;
85  
            if(!impl->cached_frame_.compare_exchange_strong(
85  
            if(!impl->cached_frame_.compare_exchange_strong(
86  
                expected, p, std::memory_order_release))
86  
                expected, p, std::memory_order_release))
87  
                ::operator delete(p);
87  
                ::operator delete(p);
88  
        }
88  
        }
89  

89  

90  
        strand_invoker get_return_object() noexcept
90  
        strand_invoker get_return_object() noexcept
91  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
91  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
92  

92  

93  
        std::suspend_always initial_suspend() noexcept { return {}; }
93  
        std::suspend_always initial_suspend() noexcept { return {}; }
94  
        std::suspend_never final_suspend() noexcept { return {}; }
94  
        std::suspend_never final_suspend() noexcept { return {}; }
95  
        void return_void() noexcept {}
95  
        void return_void() noexcept {}
96  
        void unhandled_exception() { std::terminate(); }
96  
        void unhandled_exception() { std::terminate(); }
97  
    };
97  
    };
98  

98  

99  
    std::coroutine_handle<promise_type> h_;
99  
    std::coroutine_handle<promise_type> h_;
100  
};
100  
};
101  

101  

102  
//----------------------------------------------------------
102  
//----------------------------------------------------------
103  

103  

104  
/** Concrete implementation of strand_service.
104  
/** Concrete implementation of strand_service.
105  

105  

106  
    Holds the fixed pool of strand_impl objects.
106  
    Holds the fixed pool of strand_impl objects.
107  
*/
107  
*/
108  
class strand_service_impl : public strand_service
108  
class strand_service_impl : public strand_service
109  
{
109  
{
110  
    static constexpr std::size_t num_impls = 211;
110  
    static constexpr std::size_t num_impls = 211;
111  

111  

112  
    strand_impl impls_[num_impls];
112  
    strand_impl impls_[num_impls];
113  
    std::size_t salt_ = 0;
113  
    std::size_t salt_ = 0;
114  
    std::mutex mutex_;
114  
    std::mutex mutex_;
115  

115  

116  
public:
116  
public:
117  
    explicit
117  
    explicit
118  
    strand_service_impl(execution_context&)
118  
    strand_service_impl(execution_context&)
119  
    {
119  
    {
120  
    }
120  
    }
121  

121  

122  
    strand_impl*
122  
    strand_impl*
123  
    get_implementation() override
123  
    get_implementation() override
124  
    {
124  
    {
125  
        std::lock_guard<std::mutex> lock(mutex_);
125  
        std::lock_guard<std::mutex> lock(mutex_);
126  
        std::size_t index = salt_++;
126  
        std::size_t index = salt_++;
127  
        index = index % num_impls;
127  
        index = index % num_impls;
128  
        return &impls_[index];
128  
        return &impls_[index];
129  
    }
129  
    }
130  

130  

131  
protected:
131  
protected:
132  
    void
132  
    void
133  
    shutdown() override
133  
    shutdown() override
134  
    {
134  
    {
135  
        for(std::size_t i = 0; i < num_impls; ++i)
135  
        for(std::size_t i = 0; i < num_impls; ++i)
136  
        {
136  
        {
137  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
137  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
138  
            impls_[i].locked_ = true;
138  
            impls_[i].locked_ = true;
139  

139  

140  
            void* p = impls_[i].cached_frame_.exchange(
140  
            void* p = impls_[i].cached_frame_.exchange(
141  
                kCacheClosed, std::memory_order_acquire);
141  
                kCacheClosed, std::memory_order_acquire);
142  
            if(p)
142  
            if(p)
143  
                ::operator delete(p);
143  
                ::operator delete(p);
144  
        }
144  
        }
145  
    }
145  
    }
146  

146  

147  
private:
147  
private:
148  
    static bool
148  
    static bool
149  
    enqueue(strand_impl& impl, std::coroutine_handle<> h)
149  
    enqueue(strand_impl& impl, std::coroutine_handle<> h)
150  
    {
150  
    {
151  
        std::lock_guard<std::mutex> lock(impl.mutex_);
151  
        std::lock_guard<std::mutex> lock(impl.mutex_);
152  
        impl.pending_.push(h);
152  
        impl.pending_.push(h);
153  
        if(!impl.locked_)
153  
        if(!impl.locked_)
154  
        {
154  
        {
155  
            impl.locked_ = true;
155  
            impl.locked_ = true;
156  
            return true;
156  
            return true;
157  
        }
157  
        }
158  
        return false;
158  
        return false;
159  
    }
159  
    }
160  

160  

161  
    static void
161  
    static void
162  
    dispatch_pending(strand_impl& impl)
162  
    dispatch_pending(strand_impl& impl)
163  
    {
163  
    {
164  
        strand_queue::taken_batch batch;
164  
        strand_queue::taken_batch batch;
165  
        {
165  
        {
166  
            std::lock_guard<std::mutex> lock(impl.mutex_);
166  
            std::lock_guard<std::mutex> lock(impl.mutex_);
167  
            batch = impl.pending_.take_all();
167  
            batch = impl.pending_.take_all();
168  
        }
168  
        }
169  
        impl.pending_.dispatch_batch(batch);
169  
        impl.pending_.dispatch_batch(batch);
170  
    }
170  
    }
171  

171  

172  
    static bool
172  
    static bool
173  
    try_unlock(strand_impl& impl)
173  
    try_unlock(strand_impl& impl)
174  
    {
174  
    {
175  
        std::lock_guard<std::mutex> lock(impl.mutex_);
175  
        std::lock_guard<std::mutex> lock(impl.mutex_);
176  
        if(impl.pending_.empty())
176  
        if(impl.pending_.empty())
177  
        {
177  
        {
178  
            impl.locked_ = false;
178  
            impl.locked_ = false;
179  
            return true;
179  
            return true;
180  
        }
180  
        }
181  
        return false;
181  
        return false;
182  
    }
182  
    }
183  

183  

184  
    static void
184  
    static void
185  
    set_dispatch_thread(strand_impl& impl) noexcept
185  
    set_dispatch_thread(strand_impl& impl) noexcept
186  
    {
186  
    {
187  
        impl.dispatch_thread_.store(std::this_thread::get_id());
187  
        impl.dispatch_thread_.store(std::this_thread::get_id());
188  
    }
188  
    }
189  

189  

190  
    static void
190  
    static void
191  
    clear_dispatch_thread(strand_impl& impl) noexcept
191  
    clear_dispatch_thread(strand_impl& impl) noexcept
192  
    {
192  
    {
193  
        impl.dispatch_thread_.store(std::thread::id{});
193  
        impl.dispatch_thread_.store(std::thread::id{});
194  
    }
194  
    }
195  

195  

196  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
196  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
197  
    // (repost after each batch to let other work run) - explore if starvation observed.
197  
    // (repost after each batch to let other work run) - explore if starvation observed.
198  
    static strand_invoker
198  
    static strand_invoker
199  
    make_invoker(strand_impl& impl)
199  
    make_invoker(strand_impl& impl)
200  
    {
200  
    {
201  
        strand_impl* p = &impl;
201  
        strand_impl* p = &impl;
202  
        for(;;)
202  
        for(;;)
203  
        {
203  
        {
204  
            set_dispatch_thread(*p);
204  
            set_dispatch_thread(*p);
205  
            dispatch_pending(*p);
205  
            dispatch_pending(*p);
206  
            if(try_unlock(*p))
206  
            if(try_unlock(*p))
207  
            {
207  
            {
208  
                clear_dispatch_thread(*p);
208  
                clear_dispatch_thread(*p);
209  
                co_return;
209  
                co_return;
210  
            }
210  
            }
211  
        }
211  
        }
212  
    }
212  
    }
213  

213  

214  
    static void
214  
    static void
215  
    post_invoker(strand_impl& impl, executor_ref ex)
215  
    post_invoker(strand_impl& impl, executor_ref ex)
216  
    {
216  
    {
217  
        auto invoker = make_invoker(impl);
217  
        auto invoker = make_invoker(impl);
218  
        auto& self = invoker.h_.promise().self_;
218  
        auto& self = invoker.h_.promise().self_;
219  
        self.h = invoker.h_;
219  
        self.h = invoker.h_;
220  
        ex.post(self);
220  
        ex.post(self);
221  
    }
221  
    }
222  

222  

223  
    friend class strand_service;
223  
    friend class strand_service;
224  
};
224  
};
225  

225  

226  
//----------------------------------------------------------
226  
//----------------------------------------------------------
227  

227  

228  
strand_service::
228  
strand_service::
229  
strand_service()
229  
strand_service()
230  
    : service()
230  
    : service()
231  
{
231  
{
232  
}
232  
}
233  

233  

234  
strand_service::
234  
strand_service::
235  
~strand_service() = default;
235  
~strand_service() = default;
236  

236  

237  
bool
237  
bool
238  
strand_service::
238  
strand_service::
239  
running_in_this_thread(strand_impl& impl) noexcept
239  
running_in_this_thread(strand_impl& impl) noexcept
240  
{
240  
{
241  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
241  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
242  
}
242  
}
243  

243  

244  
std::coroutine_handle<>
244  
std::coroutine_handle<>
245  
strand_service::
245  
strand_service::
246  
dispatch(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
246  
dispatch(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
247  
{
247  
{
248  
    if(running_in_this_thread(impl))
248  
    if(running_in_this_thread(impl))
249  
        return h;
249  
        return h;
250  

250  

251  
    if(strand_service_impl::enqueue(impl, h))
251  
    if(strand_service_impl::enqueue(impl, h))
252  
        strand_service_impl::post_invoker(impl, ex);
252  
        strand_service_impl::post_invoker(impl, ex);
253  
    return std::noop_coroutine();
253  
    return std::noop_coroutine();
254  
}
254  
}
255  

255  

256  
void
256  
void
257  
strand_service::
257  
strand_service::
258  
post(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
258  
post(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
259  
{
259  
{
260  
    if(strand_service_impl::enqueue(impl, h))
260  
    if(strand_service_impl::enqueue(impl, h))
261  
        strand_service_impl::post_invoker(impl, ex);
261  
        strand_service_impl::post_invoker(impl, ex);
262  
}
262  
}
263  

263  

264  
strand_service&
264  
strand_service&
265  
get_strand_service(execution_context& ctx)
265  
get_strand_service(execution_context& ctx)
266  
{
266  
{
267  
    return ctx.use_service<strand_service_impl>();
267  
    return ctx.use_service<strand_service_impl>();
268  
}
268  
}
269  

269  

270  
} // namespace detail
270  
} // namespace detail
271  
} // namespace capy
271  
} // namespace capy
272  
} // namespace boost
272  
} // namespace boost