diff --git a/source/Network/Xmpp/Concurrent.hs b/source/Network/Xmpp/Concurrent.hs index 4344875..c40090f 100644 --- a/source/Network/Xmpp/Concurrent.hs +++ b/source/Network/Xmpp/Concurrent.hs @@ -8,14 +8,12 @@ module Network.Xmpp.Concurrent , module Network.Xmpp.Concurrent.Message , module Network.Xmpp.Concurrent.Presence , module Network.Xmpp.Concurrent.IQ - , toChans + , StanzaHandler , newSession , writeWorker , session ) where -import Network.Xmpp.Concurrent.Monad -import Network.Xmpp.Concurrent.Threads import Control.Applicative((<$>),(<*>)) import Control.Concurrent import Control.Concurrent.STM @@ -23,44 +21,56 @@ import Control.Monad import qualified Data.ByteString as BS import Data.IORef import qualified Data.Map as Map +import Data.Maybe import Data.Maybe (fromMaybe) +import Data.Text as Text import Data.XML.Types +import Network +import qualified Network.TLS as TLS import Network.Xmpp.Concurrent.Basic import Network.Xmpp.Concurrent.IQ import Network.Xmpp.Concurrent.Message +import Network.Xmpp.Concurrent.Monad import Network.Xmpp.Concurrent.Presence -import Network.Xmpp.Concurrent.Types import Network.Xmpp.Concurrent.Threads +import Network.Xmpp.Concurrent.Threads +import Network.Xmpp.Concurrent.Types import Network.Xmpp.Marshal -import Network.Xmpp.Types -import Network -import Data.Text as Text -import Network.Xmpp.Tls -import qualified Network.TLS as TLS import Network.Xmpp.Sasl import Network.Xmpp.Sasl.Mechanisms import Network.Xmpp.Sasl.Types -import Data.Maybe import Network.Xmpp.Stream +import Network.Xmpp.Tls +import Network.Xmpp.Types import Network.Xmpp.Utilities import Control.Monad.Error -import Data.Default -import System.Log.Logger -import Control.Monad.State.Strict +import Data.Default +import System.Log.Logger +import Control.Monad.State.Strict + +runHandlers :: (TChan Stanza) -> [StanzaHandler] -> Stanza -> IO () +runHandlers _ [] _ = return () +runHandlers outC (h:hands) sta = do + res <- h outC sta + case res of + True -> runHandlers outC hands sta + False -> return () + +toChan :: TChan Stanza -> StanzaHandler +toChan stanzaC _ sta = do + atomically $ writeTChan stanzaC sta + return True + -toChans :: TChan Stanza - -> TChan Stanza - -> TVar IQHandlers - -> Stanza - -> IO () -toChans stanzaC outC iqHands sta = atomically $ do - writeTChan stanzaC sta +handleIQ :: TVar IQHandlers + -> StanzaHandler +handleIQ iqHands outC sta = atomically $ do case sta of - IQRequestS i -> handleIQRequest iqHands i - IQResultS i -> handleIQResponse iqHands (Right i) - IQErrorS i -> handleIQResponse iqHands (Left i) - _ -> return () + IQRequestS i -> handleIQRequest iqHands i >> return False + IQResultS i -> handleIQResponse iqHands (Right i) >> return False + IQErrorS i -> handleIQResponse iqHands (Left i) >> return False + _ -> return True where -- If the IQ request has a namespace, send it through the appropriate channel. handleIQRequest :: TVar IQHandlers -> IQRequest -> STM () @@ -96,7 +106,11 @@ newSession stream config = runErrorT $ do stanzaChan <- lift newTChanIO iqHandlers <- lift $ newTVarIO (Map.empty, Map.empty) eh <- lift $ newTVarIO $ EventHandlers { connectionClosedHandler = sessionClosedHandler config } - let stanzaHandler = toChans stanzaChan outC iqHandlers + let stanzaHandler = runHandlers outC $ Prelude.concat [ [toChan stanzaChan] + , extraStanzaHandlers + config + , [handleIQ iqHandlers] + ] (kill, wLock, streamState, readerThread) <- ErrorT $ startThreadsWith stanzaHandler eh stream writer <- lift $ forkIO $ writeWorker outC wLock return $ Session { stanzaCh = stanzaChan diff --git a/source/Network/Xmpp/Types.hs b/source/Network/Xmpp/Types.hs index 35c766c..e3f8340 100644 --- a/source/Network/Xmpp/Types.hs +++ b/source/Network/Xmpp/Types.hs @@ -36,6 +36,7 @@ module Network.Xmpp.Types , StreamState(..) , ConnectionState(..) , StreamErrorInfo(..) + , StanzaHandler , StreamConfiguration(..) , langTag , Jid(..) @@ -1105,6 +1106,10 @@ hostnameP = do then fail "Hostname too long." else return $ Text.concat [label, Text.pack ".", r] +type StanzaHandler = TChan Stanza -- ^ outgoing stanza + -> Stanza -- ^ stanza to handle + -> IO Bool -- ^ True when processing should continue + -- | Configuration for the @Session@ object. data SessionConfiguration = SessionConfiguration { -- | Configuration for the @Stream@ object. @@ -1113,6 +1118,7 @@ data SessionConfiguration = SessionConfiguration , sessionClosedHandler :: XmppFailure -> IO () -- | Function to generate the stream of stanza identifiers. , sessionStanzaIDs :: IO StanzaID + , extraStanzaHandlers :: [StanzaHandler] } instance Default SessionConfiguration where @@ -1124,6 +1130,7 @@ instance Default SessionConfiguration where curId <- readTVar idRef writeTVar idRef (curId + 1 :: Integer) return . read. show $ curId + , extraStanzaHandlers = [] } -- | How the client should behave in regards to TLS.