diff --git a/capi/geos_c.cpp b/capi/geos_c.cpp index d8cbdd9f43..3833758d9b 100644 --- a/capi/geos_c.cpp +++ b/capi/geos_c.cpp @@ -116,12 +116,24 @@ extern "C" { return geos::util::Interrupt::registerCallback(cb); } + GEOSInterruptThreadCallback* + GEOS_interruptRegisterThreadCallback(GEOSInterruptThreadCallback* cb, void* data) + { + return geos::util::Interrupt::registerThreadCallback(cb, data); + } + void GEOS_interruptRequest() { geos::util::Interrupt::request(); } + void + GEOS_interruptThread() + { + geos::util::Interrupt::requestForCurrentThread(); + } + void GEOS_interruptCancel() { diff --git a/capi/geos_c.h.in b/capi/geos_c.h.in index 6a9d0fd1fa..b7f3c38a8c 100644 --- a/capi/geos_c.h.in +++ b/capi/geos_c.h.in @@ -305,8 +305,10 @@ typedef int (*GEOSTransformXYCallback)( */ typedef void (GEOSInterruptCallback)(void); +typedef void (GEOSInterruptThreadCallback)(void*); + /** -* Register a function to be called when processing is interrupted. +* Register a function to be called when processing is interrupted on any thread. * \param cb Callback function to invoke * \return the previously configured callback * \see GEOSInterruptCallback @@ -315,10 +317,29 @@ extern GEOSInterruptCallback GEOS_DLL *GEOS_interruptRegisterCallback( GEOSInterruptCallback* cb); /** -* Request safe interruption of operations +* Register a function to be called when processing is interrupted on the current thread. +* +* \param cb Callback function to invoke +* \param context pointer to a context object that will be passed to `cb` +* \return the previously configured callback +* \see GEOSInterruptCallback +*/ +extern GEOSInterruptThreadCallback GEOS_DLL *GEOS_interruptRegisterThreadCallback( + GEOSInterruptThreadCallback* cb, void* context); + +/** +* Request safe interruption of operations. The next thread to check for an +* interrupt will be interrupted. To request interruption of a specific thread, +* instead call `GEOS_interruptThread` from a callback executed by that thread. */ extern void GEOS_DLL GEOS_interruptRequest(void); +/** +* Request safe interruption of operations in the current thread. This function +* should be called from a callback registered by `GEOS_interruptRegisterThreadCallback`. +*/ +extern void GEOS_DLL GEOS_interruptThread(void); + /** * Cancel a pending interruption request */ diff --git a/include/geos/util/Interrupt.h b/include/geos/util/Interrupt.h index e52386d410..f4c355cc13 100644 --- a/include/geos/util/Interrupt.h +++ b/include/geos/util/Interrupt.h @@ -27,15 +27,25 @@ class GEOS_DLL Interrupt { public: typedef void (Callback)(void); + typedef void (ThreadCallback)(void*); /** * Request interruption of operations * * Operations will be terminated by a GEOSInterrupt - * exception at first occasion. + * exception at first occasion, by the first thread + * to check for an interrupt request. */ static void request(); + /** + * Request interruption of operations in the current thread + * + * Operations in the current thread will be terminated by + * a GEOSInterrupt at first occasion. + */ + static void requestForCurrentThread(); + /** Cancel a pending interruption request */ static void cancel(); @@ -43,17 +53,29 @@ class GEOS_DLL Interrupt { static bool check(); /** \brief - * Register a callback that will be invoked + * Register a callback that will be invoked by all threads * before checking for interruption requests. * * NOTE that interruption request checking may happen - * frequently so any callback would better be quick. + * frequently so the callback should execute quickly. * * The callback can be used to call Interrupt::request() - * + * or Interrupt::requestForCurrentThread(). */ static Callback* registerCallback(Callback* cb); + /** \brief + * Register a callback that will be invoked the current thread + * before checking for interruption requests. + * + * NOTE that interruption request checking may happen + * frequently so the callback shoudl execute quickly. + * + * The callback can be used to call Interrupt::request() + * or Interrupt::requestForCurrentThread(). + */ + static ThreadCallback* registerThreadCallback(ThreadCallback* cb, void* data); + /** * Invoke the callback, if any. Process pending interruption, if any. * diff --git a/src/util/Interrupt.cpp b/src/util/Interrupt.cpp index 0bc988221b..409c131c63 100644 --- a/src/util/Interrupt.cpp +++ b/src/util/Interrupt.cpp @@ -18,8 +18,12 @@ namespace { /* Could these be portably stored in thread-specific space ? */ bool requested = false; +thread_local bool requested_for_thread = false; geos::util::Interrupt::Callback* callback = nullptr; +thread_local geos::util::Interrupt::ThreadCallback* callback_thread = nullptr; +thread_local void* callback_thread_data = nullptr; + } namespace geos { @@ -37,16 +41,23 @@ Interrupt::request() requested = true; } +void +Interrupt::requestForCurrentThread() +{ + requested_for_thread = true; +} + void Interrupt::cancel() { requested = false; + requested_for_thread = false; } bool Interrupt::check() { - return requested; + return requested || requested_for_thread; } Interrupt::Callback* @@ -57,14 +68,25 @@ Interrupt::registerCallback(Interrupt::Callback* cb) return prev; } +Interrupt::ThreadCallback* +Interrupt::registerThreadCallback(ThreadCallback* cb, void* data) +{ + ThreadCallback* prev = callback_thread; + callback_thread = cb; + callback_thread_data = data; + return prev; +} + void Interrupt::process() { if(callback) { (*callback)(); } - if(requested) { - requested = false; + if(callback_thread) { + (*callback_thread)(callback_thread_data); + } + if(check()) { interrupt(); } } @@ -74,6 +96,7 @@ void Interrupt::interrupt() { requested = false; + requested_for_thread = false; throw InterruptedException(); } diff --git a/tests/unit/capi/GEOSInterruptTest.cpp b/tests/unit/capi/GEOSInterruptTest.cpp index 0bd0301143..a230713103 100644 --- a/tests/unit/capi/GEOSInterruptTest.cpp +++ b/tests/unit/capi/GEOSInterruptTest.cpp @@ -9,6 +9,7 @@ #include #include #include +#include namespace tut { // @@ -18,6 +19,7 @@ namespace tut { // Common data used in test cases. struct test_capiinterrupt_data { static int numcalls; + static int maxcalls; static GEOSInterruptCallback* nextcb; static void @@ -56,9 +58,18 @@ struct test_capiinterrupt_data { } } + static void + interruptAfterMaxCalls(void* data) + { + if (++*static_cast(data) >= maxcalls) { + GEOS_interruptThread(); + } + } + }; int test_capiinterrupt_data::numcalls = 0; +int test_capiinterrupt_data::maxcalls = 0; GEOSInterruptCallback* test_capiinterrupt_data::nextcb = nullptr; typedef test_group group; @@ -221,5 +232,40 @@ void object::test<5> } +// Test callback is thread-local +template<> +template<> +void object::test<6> +() +{ + maxcalls = 3; + int calls_1 = 0; + int calls_2 = 0; + + initGEOS(notice, notice); + + auto buffer = [](GEOSInterruptThreadCallback* cb, void* data) { + GEOSGeometry* geom1 = GEOSGeomFromWKT("LINESTRING (0 0, 1 0)"); + + GEOS_interruptRegisterThreadCallback(cb, data); + + GEOSGeometry* geom2 = GEOSBuffer(geom1, 1, 8); + GEOSGeom_destroy(geom2); + GEOSGeom_destroy(geom1); + }; + + std::thread t1(buffer, interruptAfterMaxCalls, &calls_1); + std::thread t2(buffer, interruptAfterMaxCalls, &calls_2); + + t1.join(); + t2.join(); + + ensure_equals(calls_1, maxcalls); + ensure_equals(calls_2, maxcalls); + + finishGEOS(); +} + + } // namespace tut diff --git a/tests/unit/util/InterruptTest.cpp b/tests/unit/util/InterruptTest.cpp new file mode 100644 index 0000000000..ceb3147764 --- /dev/null +++ b/tests/unit/util/InterruptTest.cpp @@ -0,0 +1,137 @@ +// tut +#include +// geos +#include +// std +#include +#include +#include + +using geos::util::Interrupt; + +namespace tut { +// +// Test Group +// + +// Common data used in test cases. +struct test_interrupt_data { + static void workForever() { + try { + std::cerr << "Started " << std::this_thread::get_id() << "." << std::endl; + while (true) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + GEOS_CHECK_FOR_INTERRUPTS(); + } + } catch (const std::exception&) { + std::cerr << "Interrupted " << std::this_thread::get_id() << "." << std::endl; + return; + } + } + + static void interruptNow() { + Interrupt::request(); + } + + static std::map* toInterrupt; + + static void interruptIfRequested() { + if (toInterrupt == nullptr) { + return; + } + + auto it = toInterrupt->find(std::this_thread::get_id()); + if (it != toInterrupt->end() && it->second) { + it->second = false; + Interrupt::request(); + } + } +}; + +std::map* test_interrupt_data::toInterrupt = nullptr; + +typedef test_group group; +typedef group::object object; + +group test_interrupt_group("geos::util::Interrupt"); + +// +// Test Cases +// + + +// Interrupt worker thread via global request from from main thead +template<> +template<> +void object::test<1> +() +{ + std::thread t(workForever); + Interrupt::request(); + + t.join(); +} + +// Interrupt worker thread via global requset from worker thread using a callback +template<> +template<> +void object::test<2> +() +{ + Interrupt::registerCallback(interruptIfRequested); + + std::thread t1(workForever); + std::thread t2(workForever); + + std::map shouldInterrupt; + shouldInterrupt[t1.get_id()] = false; + shouldInterrupt[t2.get_id()] = false; + toInterrupt = &shouldInterrupt; + + shouldInterrupt[t2.get_id()] = true; + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + shouldInterrupt[t1.get_id()] = true; + + t1.join(); + t2.join(); +} + +// Register separate callbacks for each thread. Each callback will +// request interruption of itself only. +template<> +template<> +void object::test<3> +() +{ + bool interrupt1 = false; + int numCalls2 = 0; + + auto cb1 = ([](void* data) { + if (*static_cast(data)) { + Interrupt::requestForCurrentThread(); + } + }); + + auto cb2 = ([](void* data) { + if (++*static_cast(data) > 5) { + Interrupt::requestForCurrentThread(); + } + }); + + + std::thread t1([&cb1, &interrupt1]() { + Interrupt::registerThreadCallback(cb1, &interrupt1); + }); + + std::thread t2([&cb2, &numCalls2]() { + Interrupt::registerThreadCallback(cb2, &numCalls2); + }); + + t2.join(); + + interrupt1 = true; + t1.join(); +} + +} // namespace tut +