diff --git a/src/Network/XMPP/Concurrent/Threads.hs b/src/Network/XMPP/Concurrent/Threads.hs index 7a4309a..6a57dbb 100644 --- a/src/Network/XMPP/Concurrent/Threads.hs +++ b/src/Network/XMPP/Concurrent/Threads.hs @@ -29,6 +29,13 @@ import Text.XML.Stream.Elements import GHC.IO (unsafeUnmask) +-- While waiting for the first semaphore(s) to flip we might receive +-- another interrupt. When that happens we add it's semaphore to the +-- list and retry waiting +handleInterrupts ts = + Ex.catch (atomically $ forM ts takeTMVar) + ( \(Interrupt t) -> handleInterrupts (t:ts)) + readWorker :: TChan (Either MessageError Message) -> TChan (Either PresenceError Presence) -> TVar IQHandlers @@ -36,22 +43,25 @@ readWorker :: TChan (Either MessageError Message) -> IO () readWorker messageC presenceC handlers stateRef = Ex.mask_ . forever $ do - s <- liftIO . atomically $ takeTMVar stateRef - (sta', s') <- flip runStateT s $ Ex.catch ( do - -- we don't know whether pull will necessarily be interruptible - liftIO $ allowInterrupt - Just <$> pull - ) - (\(Interrupt t) -> do - liftIO . atomically $ - putTMVar stateRef s - liftIO . atomically $ takeTMVar t - return Nothing - ) + res <- liftIO $ Ex.catch ( + Ex.bracket + (atomically $ takeTMVar stateRef) + (atomically . putTMVar stateRef ) + (\s -> do + -- we don't know whether pull will + -- necessarily be interruptible + allowInterrupt + Just <$> runStateT pull s + ) + ) + (\(Interrupt t) -> do + handleInterrupts [t] + return Nothing + ) liftIO . atomically $ do - case sta' of + case res of Nothing -> return () - Just sta -> do + Just (sta, s') -> do putTMVar stateRef s' case sta of MessageS m -> do writeTChan messageC $ Right m