diff --git a/strings/base_collections_base.h b/strings/base_collections_base.h index b2d58f602..22ba4e659 100644 --- a/strings/base_collections_base.h +++ b/strings/base_collections_base.h @@ -1,6 +1,6 @@ namespace winrt::impl { - struct nop_lock_guard {}; + struct [[maybe_unused]] nop_lock_guard {}; struct single_threaded_collection_base { diff --git a/strings/base_coroutine_threadpool.h b/strings/base_coroutine_threadpool.h index 0faaa1acd..616cb3390 100644 --- a/strings/base_coroutine_threadpool.h +++ b/strings/base_coroutine_threadpool.h @@ -332,224 +332,229 @@ namespace winrt::impl } }; - struct timespan_awaiter : cancellable_awaiter + template + struct threadpool_awaiter_base : cancellable_awaiter> { - explicit timespan_awaiter(Windows::Foundation::TimeSpan duration) noexcept : - m_duration(duration) - { - } - -#if defined(__GNUC__) && !defined(__clang__) - // HACK: GCC seems to require a move when calling operator co_await - // on the return value of resume_after. - // This might be related to upstream bug: - // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=99575 - timespan_awaiter(timespan_awaiter &&other) noexcept : - m_timer{std::move(other.m_timer)}, - m_duration{std::move(other.m_duration)}, - m_handle{std::move(other.m_handle)}, - m_state{other.m_state.load()} - {} -#endif - void enable_cancellation(cancellable_promise* promise) { promise->set_canceller([](void* context) { - auto that = static_cast(context); - if (that->m_state.exchange(state::canceled, std::memory_order_acquire) == state::pending) + auto that = static_cast(context); + if (that->m_state.exchange(state::canceled, std::memory_order_acquire) == state::suspended) { - that->fire_immediately(); + if (static_cast(that)->cancel(that->m_handle.get())) + { + static_cast(that)->fire_immediately(that->m_handle.get()); + } } }, this); } - bool await_ready() const noexcept + Result await_resume() { - return m_duration.count() <= 0; + if (m_state.exchange(state::idle, std::memory_order_relaxed) == state::canceled) + { + throw hresult_canceled(); + } + + if constexpr (!std::is_same_v) + { + return static_cast(this)->get_result(); + } } template - void await_suspend(impl::coroutine_handle handle) + bool await_suspend(impl::coroutine_handle resume) { - set_cancellable_promise_from_handle(handle); + handle_type new_handle; + new_handle.attach(check_pointer(static_cast(this)->create_threadpool_handle())); - m_handle = handle; - create_threadpool_timer(); + state expected = state::idle; + if (m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release)) + { + this->set_cancellable_promise_from_handle(resume); + + m_resume = resume; + m_handle = std::move(new_handle); + + if (m_state.load(std::memory_order_acquire) != state::canceled) + { + slim_lock_guard guard{m_mutex}; + + static_cast(this)->suspend_on_threadpool(m_handle.get()); + + expected = state::pending; + if (!m_state.compare_exchange_strong(expected, state::suspended, std::memory_order_release)) + { + // handle the case of the cancelation occurring while we where suspending on the thread pool + if (static_cast(this)->cancel(m_handle.get())) + { + // we canceled before a callback was scheduled, so we can short-circuit + return false; + } + } + + return true; + } + else + { + // short-circuit in case of an early cancelation + return false; + } + } + else + { + throw hresult_illegal_method_call(); + } } - void await_resume() + protected: + threadpool_awaiter_base() = default; + + void resume() { - if (m_state.exchange(state::idle, std::memory_order_relaxed) == state::canceled) { - throw hresult_canceled(); + // acquire the mutex to ensure await_suspend is finished executing + slim_lock_guard guard{m_mutex}; } + + m_resume(); } private: - void create_threadpool_timer() + enum class state { idle, pending, suspended, canceled }; + + handle_type m_handle; + impl::coroutine_handle<> m_resume{ nullptr }; + std::atomic m_state{ state::idle }; + slim_mutex m_mutex; + }; + + struct tp_timer_traits + { + using type = impl::ptp_timer; + + static void close(type value) noexcept { - m_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, this, nullptr))); - int64_t relative_count = -m_duration.count(); - WINRT_IMPL_SetThreadpoolTimer(m_timer.get(), &relative_count, 0, 0); + WINRT_IMPL_CloseThreadpoolTimer(value); + } - state expected = state::idle; - if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release)) - { - fire_immediately(); - } + static constexpr type invalid() noexcept + { + return nullptr; } + }; - void fire_immediately() noexcept + struct timespan_awaiter : threadpool_awaiter_base + { + explicit timespan_awaiter(Windows::Foundation::TimeSpan duration) noexcept : + m_duration(duration) { - if (WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), nullptr, 0, 0)) - { - int64_t now = 0; - WINRT_IMPL_SetThreadpoolTimer(m_timer.get(), &now, 0, 0); - } } - static void __stdcall callback(void*, void* context, void*) noexcept + bool await_ready() const noexcept { - auto that = reinterpret_cast(context); - that->m_handle(); + return m_duration.count() <= 0; } - struct timer_traits + impl::ptp_timer create_threadpool_handle() noexcept { - using type = impl::ptp_timer; + return WINRT_IMPL_CreateThreadpoolTimer(callback, this, nullptr); + } - static void close(type value) noexcept - { - WINRT_IMPL_CloseThreadpoolTimer(value); - } + void suspend_on_threadpool(impl::ptp_timer handle) const noexcept + { + int64_t relative_count = -m_duration.count(); + WINRT_IMPL_SetThreadpoolTimerEx(handle, &relative_count, 0, 0); + } - static constexpr type invalid() noexcept - { - return nullptr; - } - }; + bool cancel(impl::ptp_timer handle) const noexcept + { + return WINRT_IMPL_SetThreadpoolTimerEx(handle, nullptr, 0, 0); + } - enum class state { idle, pending, canceled }; + void fire_immediately(impl::ptp_timer handle) const noexcept + { + int64_t now = 0; + WINRT_IMPL_SetThreadpoolTimerEx(handle, &now, 0, 0); + } + + private: + static void __stdcall callback(void*, void* context, void*) noexcept + { + auto that = reinterpret_cast(context); + that->resume(); + } - handle_type m_timer; Windows::Foundation::TimeSpan m_duration; - impl::coroutine_handle<> m_handle; - std::atomic m_state{ state::idle }; }; - struct signal_awaiter : cancellable_awaiter + struct tp_wait_traits { - signal_awaiter(void* handle, Windows::Foundation::TimeSpan timeout) noexcept : - m_timeout(timeout), - m_handle(handle) - {} + using type = impl::ptp_wait; -#if defined(__GNUC__) && !defined(__clang__) - // HACK: GCC seems to require a move when calling operator co_await - // on the return value of resume_on_signal. - // This might be related to upstream bug: - // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=99575 - signal_awaiter(signal_awaiter &&other) noexcept : - m_wait{std::move(other.m_wait)}, - m_timeout{std::move(other.m_timeout)}, - m_handle{std::move(other.m_handle)}, - m_result{std::move(other.m_result)}, - m_resume{std::move(other.m_resume)}, - m_state{other.m_state.load()} - {} -#endif + static void close(type value) noexcept + { + WINRT_IMPL_CloseThreadpoolWait(value); + } - void enable_cancellation(cancellable_promise* promise) + static constexpr type invalid() noexcept { - promise->set_canceller([](void* context) - { - auto that = static_cast(context); - if (that->m_state.exchange(state::canceled, std::memory_order_acquire) == state::pending) - { - that->fire_immediately(); - } - }, this); + return nullptr; } + }; + + struct signal_awaiter : threadpool_awaiter_base + { + signal_awaiter(void* handle, Windows::Foundation::TimeSpan timeout) noexcept : + m_timeout(timeout), + m_handle(handle) + {} bool await_ready() const noexcept { return WINRT_IMPL_WaitForSingleObject(m_handle, 0) == 0; } - template - void await_suspend(impl::coroutine_handle resume) + bool get_result() const noexcept { - set_cancellable_promise_from_handle(resume); - - m_resume = resume; - create_threadpool_wait(); + return m_result == 0; } - bool await_resume() + impl::ptp_wait create_threadpool_handle() noexcept { - if (m_state.exchange(state::idle, std::memory_order_relaxed) == state::canceled) - { - throw hresult_canceled(); - } - return m_result == 0; + return WINRT_IMPL_CreateThreadpoolWait(callback, this, nullptr); } - private: - - void create_threadpool_wait() + void suspend_on_threadpool(impl::ptp_wait handle) const noexcept { - m_wait.attach(check_pointer(WINRT_IMPL_CreateThreadpoolWait(callback, this, nullptr))); int64_t relative_count = -m_timeout.count(); int64_t* file_time = relative_count != 0 ? &relative_count : nullptr; - WINRT_IMPL_SetThreadpoolWait(m_wait.get(), m_handle, file_time); + WINRT_IMPL_SetThreadpoolWaitEx(handle, m_handle, file_time, nullptr); + } - state expected = state::idle; - if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release)) - { - fire_immediately(); - } + bool cancel(impl::ptp_wait handle) const noexcept + { + return WINRT_IMPL_SetThreadpoolWaitEx(handle, nullptr, nullptr, nullptr); } - void fire_immediately() noexcept + void fire_immediately(impl::ptp_wait handle) const noexcept { - if (WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), nullptr, nullptr, nullptr)) - { - int64_t now = 0; - WINRT_IMPL_SetThreadpoolWait(m_wait.get(), WINRT_IMPL_GetCurrentProcess(), &now); - } + int64_t now = 0; + WINRT_IMPL_SetThreadpoolWaitEx(handle, WINRT_IMPL_GetCurrentProcess(), &now, nullptr); } + private: static void __stdcall callback(void*, void* context, void*, uint32_t result) noexcept { auto that = static_cast(context); that->m_result = result; - that->m_resume(); + that->resume(); } - struct wait_traits - { - using type = impl::ptp_wait; - - static void close(type value) noexcept - { - WINRT_IMPL_CloseThreadpoolWait(value); - } - - static constexpr type invalid() noexcept - { - return nullptr; - } - }; - - enum class state { idle, pending, canceled }; - - handle_type m_wait; Windows::Foundation::TimeSpan m_timeout; void* m_handle; uint32_t m_result{}; - impl::coroutine_handle<> m_resume{ nullptr }; - std::atomic m_state{ state::idle }; }; } diff --git a/strings/base_extern.h b/strings/base_extern.h index 92872d416..f2ab090cc 100644 --- a/strings/base_extern.h +++ b/strings/base_extern.h @@ -26,8 +26,6 @@ extern "C" { int32_t __stdcall WINRT_IMPL_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept WINRT_IMPL_LINK(RoGetActivationFactory, 12); int32_t __stdcall WINRT_IMPL_RoGetAgileReference(uint32_t options, winrt::guid const& iid, void* object, void** reference) noexcept WINRT_IMPL_LINK(RoGetAgileReference, 16); - int32_t __stdcall WINRT_IMPL_SetThreadpoolTimerEx(winrt::impl::ptp_timer, void*, uint32_t, uint32_t) noexcept WINRT_IMPL_LINK(SetThreadpoolTimerEx, 16); - int32_t __stdcall WINRT_IMPL_SetThreadpoolWaitEx(winrt::impl::ptp_wait, void*, void*, void*) noexcept WINRT_IMPL_LINK(SetThreadpoolWaitEx, 16); int32_t __stdcall WINRT_IMPL_RoOriginateLanguageException(int32_t error, void* message, void* exception) noexcept WINRT_IMPL_LINK(RoOriginateLanguageException, 12); void __stdcall WINRT_IMPL_RoFailFastWithErrorContext(int32_t) noexcept WINRT_IMPL_LINK(RoFailFastWithErrorContext, 4); int32_t __stdcall WINRT_IMPL_RoTransformError(int32_t, int32_t, void*) noexcept WINRT_IMPL_LINK(RoTransformError, 12); @@ -88,11 +86,11 @@ extern "C" uint32_t __stdcall WINRT_IMPL_WaitForSingleObject(void* handle, uint32_t milliseconds) noexcept WINRT_IMPL_LINK(WaitForSingleObject, 8); int32_t __stdcall WINRT_IMPL_TrySubmitThreadpoolCallback(void(__stdcall *callback)(void*, void* context), void* context, void*) noexcept WINRT_IMPL_LINK(TrySubmitThreadpoolCallback, 12); - winrt::impl::ptp_timer __stdcall WINRT_IMPL_CreateThreadpoolTimer(void(__stdcall *callback)(void*, void* context, void*), void* context, void*) noexcept WINRT_IMPL_LINK(CreateThreadpoolTimer, 12); - void __stdcall WINRT_IMPL_SetThreadpoolTimer(winrt::impl::ptp_timer timer, void* time, uint32_t period, uint32_t window) noexcept WINRT_IMPL_LINK(SetThreadpoolTimer, 16); + winrt::impl::ptp_timer __stdcall WINRT_IMPL_CreateThreadpoolTimer(void(__stdcall *callback)(void*, void* context, void*), void* context, void*) noexcept WINRT_IMPL_LINK(CreateThreadpoolTimer, 12); + int32_t __stdcall WINRT_IMPL_SetThreadpoolTimerEx(winrt::impl::ptp_timer, void*, uint32_t, uint32_t) noexcept WINRT_IMPL_LINK(SetThreadpoolTimerEx, 16); void __stdcall WINRT_IMPL_CloseThreadpoolTimer(winrt::impl::ptp_timer timer) noexcept WINRT_IMPL_LINK(CloseThreadpoolTimer, 4); winrt::impl::ptp_wait __stdcall WINRT_IMPL_CreateThreadpoolWait(void(__stdcall *callback)(void*, void* context, void*, uint32_t result), void* context, void*) noexcept WINRT_IMPL_LINK(CreateThreadpoolWait, 12); - void __stdcall WINRT_IMPL_SetThreadpoolWait(winrt::impl::ptp_wait wait, void* handle, void* timeout) noexcept WINRT_IMPL_LINK(SetThreadpoolWait, 12); + int32_t __stdcall WINRT_IMPL_SetThreadpoolWaitEx(winrt::impl::ptp_wait, void*, void*, void*) noexcept WINRT_IMPL_LINK(SetThreadpoolWaitEx, 16); void __stdcall WINRT_IMPL_CloseThreadpoolWait(winrt::impl::ptp_wait wait) noexcept WINRT_IMPL_LINK(CloseThreadpoolWait, 4); winrt::impl::ptp_io __stdcall WINRT_IMPL_CreateThreadpoolIo(void* object, void(__stdcall *callback)(void*, void* context, void* overlapped, uint32_t result, std::size_t bytes, void*) noexcept, void* context, void*) noexcept WINRT_IMPL_LINK(CreateThreadpoolIo, 16); void __stdcall WINRT_IMPL_StartThreadpoolIo(winrt::impl::ptp_io io) noexcept WINRT_IMPL_LINK(StartThreadpoolIo, 4); diff --git a/strings/base_lock.h b/strings/base_lock.h index cf70a001a..4d5a44514 100644 --- a/strings/base_lock.h +++ b/strings/base_lock.h @@ -50,7 +50,7 @@ WINRT_EXPORT namespace winrt impl::srwlock m_lock{}; }; - struct slim_lock_guard + struct [[maybe_unused]] slim_lock_guard { explicit slim_lock_guard(slim_mutex& m) noexcept : m_mutex(m) @@ -69,7 +69,7 @@ WINRT_EXPORT namespace winrt slim_mutex& m_mutex; }; - struct slim_shared_lock_guard + struct [[maybe_unused]] slim_shared_lock_guard { explicit slim_shared_lock_guard(slim_mutex& m) noexcept : m_mutex(m) diff --git a/test/old_tests/UnitTests/IInspectable_GetRuntimeClassName.cpp b/test/old_tests/UnitTests/IInspectable_GetRuntimeClassName.cpp index 7db7ed869..af02a43bf 100644 --- a/test/old_tests/UnitTests/IInspectable_GetRuntimeClassName.cpp +++ b/test/old_tests/UnitTests/IInspectable_GetRuntimeClassName.cpp @@ -18,7 +18,7 @@ struct Test_GetRuntimeClassName_NoOverride : implements { - hstring GetRuntimeClassName() + hstring GetRuntimeClassName() const override { return L"GetRuntimeClassName"; } diff --git a/test/old_tests/UnitTests/async.cpp b/test/old_tests/UnitTests/async.cpp index 037e16ab7..e3e0e0131 100644 --- a/test/old_tests/UnitTests/async.cpp +++ b/test/old_tests/UnitTests/async.cpp @@ -1553,6 +1553,31 @@ TEST_CASE("async, resume_after") REQUIRE(after != GetCurrentThreadId()); } +namespace +{ + IAsyncAction test_resume_after_illegal_state(winrt::impl::timespan_awaiter &awaiter) + { + co_await awaiter; + } +} + +#if defined(__clang__) && defined(_MSC_VER) +// FIXME: Test is known to segfault when built with Clang. +TEST_CASE("async, resume_after, illegal_state", "[.clang-crash]") +#else +TEST_CASE("async, resume_after, illegal_state") +#endif +{ + auto awaiter = resume_after(1s); + + IAsyncAction first = test_resume_after_illegal_state(awaiter); + IAsyncAction second = test_resume_after_illegal_state(awaiter); + + REQUIRE_THROWS_AS(second.get(), hresult_illegal_method_call); + + first.get(); // allow first coroutine to succeed +} + // // Other tests already excercise resume_on_signal so here we focus on testing the timeout. // @@ -1584,3 +1609,30 @@ TEST_CASE("async, resume_on_signal") SetEvent(event.get()); // allow final resume_on_signal to succeed async.get(); } + +namespace +{ + IAsyncAction test_resume_on_signal_illegal_state(winrt::impl::signal_awaiter &awaiter) + { + co_await awaiter; + } +} + +#if defined(__clang__) && defined(_MSC_VER) +// FIXME: Test is known to segfault when built with Clang. +TEST_CASE("async, resume_on_signal, illegal_state", "[.clang-crash]") +#else +TEST_CASE("async, resume_on_signal, illegal_state") +#endif +{ + handle event { CreateEvent(nullptr, false, false, nullptr) }; + auto awaiter = resume_on_signal(event.get()); + + IAsyncAction first = test_resume_on_signal_illegal_state(awaiter); + IAsyncAction second = test_resume_on_signal_illegal_state(awaiter); + + REQUIRE_THROWS_AS(second.get(), hresult_illegal_method_call); + + SetEvent(event.get()); // allow first coroutine to succeed + first.get(); +} diff --git a/test/old_tests/UnitTests/produce.cpp b/test/old_tests/UnitTests/produce.cpp index 149466002..a9be5cd63 100644 --- a/test/old_tests/UnitTests/produce.cpp +++ b/test/old_tests/UnitTests/produce.cpp @@ -124,7 +124,7 @@ struct produce_IInspectable_No_RuntimeClassName : implements { - hstring GetRuntimeClassName() + hstring GetRuntimeClassName() const override { return L"produce_IInspectable_RuntimeClassName"; }