diff --git a/pontarius-xmpp.cabal b/pontarius-xmpp.cabal index f47d74b..84f7d23 100644 --- a/pontarius-xmpp.cabal +++ b/pontarius-xmpp.cabal @@ -56,6 +56,7 @@ Library , split >=0.1.2.3 , stm >=2.1.2.1 , stringprep >=0.1.3 + , template-haskell >=2.5 , text >=0.11.1.5 , tls >=1.1.0 , tls-extra >=0.5.0 @@ -91,6 +92,7 @@ Library , split >=0.1.2.3 , stm >=2.1.2.1 , stringprep >=0.1.3 + , template-haskell >=2.5 , text >=0.11.1.5 , tls >=1.1.0 , tls-extra >=0.5.0 diff --git a/source/Network/Xmpp.hs b/source/Network/Xmpp.hs index b5aab7e..0ea3561 100644 --- a/source/Network/Xmpp.hs +++ b/source/Network/Xmpp.hs @@ -28,6 +28,7 @@ module Network.Xmpp Session , session , setConnectionClosedHandler + , reconnect , StreamConfiguration(..) , SessionConfiguration(..) , ConnectionDetails(..) @@ -45,6 +46,7 @@ module Network.Xmpp -- for addressing entities in the network. It is somewhat similar to an e-mail -- address, but contains three parts instead of two. , Jid + , jidQ , isBare , isFull , jidFromText diff --git a/source/Network/Xmpp/Concurrent.hs b/source/Network/Xmpp/Concurrent.hs index 059f9c6..df4ba04 100644 --- a/source/Network/Xmpp/Concurrent.hs +++ b/source/Network/Xmpp/Concurrent.hs @@ -12,12 +12,13 @@ module Network.Xmpp.Concurrent , newSession , session , newStanzaID + , reconnect ) where import Control.Concurrent.STM +import qualified Control.Exception as Ex import Control.Monad import Control.Monad.Error -import qualified Control.Exception as Ex import qualified Data.Map as Map import Data.Maybe import Data.Text as Text @@ -30,13 +31,14 @@ import Network.Xmpp.Concurrent.Monad import Network.Xmpp.Concurrent.Presence import Network.Xmpp.Concurrent.Threads import Network.Xmpp.Concurrent.Types -import Network.Xmpp.IM.Roster.Types import Network.Xmpp.IM.Roster +import Network.Xmpp.IM.Roster.Types import Network.Xmpp.Sasl import Network.Xmpp.Sasl.Types import Network.Xmpp.Stream import Network.Xmpp.Tls import Network.Xmpp.Types +import System.Log.Logger import Control.Monad.State.Strict @@ -119,13 +121,17 @@ handleIQ iqHands writeSem sta = do iqID (Right iq') = iqResultID iq' -- | Creates and initializes a new Xmpp context. -newSession :: Stream -> SessionConfiguration -> IO (Either XmppFailure Session) -newSession stream config = runErrorT $ do +newSession :: Stream + -> SessionConfiguration + -> HostName + -> Maybe (ConnectionState -> [SaslHandler] , Maybe Text) + -> IO (Either XmppFailure Session) +newSession stream config realm mbSasl = runErrorT $ do write' <- liftIO $ withStream' (gets $ streamSend . streamHandle) stream writeSem <- liftIO $ newTMVarIO write' stanzaChan <- lift newTChanIO iqHands <- lift $ newTVarIO (Map.empty, Map.empty) - eh <- lift $ newTVarIO $ EventHandlers { connectionClosedHandler = onConnectionClosed config } + eh <- lift $ newEmptyTMVarIO ros <- liftIO . newTVarIO $ Roster Nothing Map.empty let rosterH = if (enableRoster config) then handleRoster ros else \ _ _ -> return True @@ -139,7 +145,7 @@ newSession stream config = runErrorT $ do ] (kill, wLock, streamState, reader) <- ErrorT $ startThreadsWith writeSem stanzaHandler eh stream idGen <- liftIO $ sessionStanzaIDs config - return $ Session { stanzaCh = stanzaChan + let sess = Session { stanzaCh = stanzaChan , iqHandlers = iqHands , writeSemaphore = wLock , readerThread = reader @@ -149,7 +155,12 @@ newSession stream config = runErrorT $ do , stopThreads = kill , conf = config , rosterRef = ros + , sRealm = realm + , sSaslCredentials = mbSasl } + liftIO . atomically $ putTMVar eh $ EventHandlers { connectionClosedHandler = + onConnectionClosed config sess } + return sess -- | Creates a 'Session' object by setting up a connection with an XMPP server. -- @@ -172,9 +183,40 @@ session realm mbSasl config = runErrorT $ do case mbAuthError of Nothing -> return () Just e -> throwError $ XmppAuthFailure e - ses <- ErrorT $ newSession stream config + ses <- ErrorT $ newSession stream config realm mbSasl liftIO $ when (enableRoster config) $ initRoster ses return ses +reconnect :: Session -> IO () +reconnect sess@Session{conf = config} = do + debugM "Pontarius.Xmpp" "reconnecting" + _ <- flip withConnection sess $ \oldStream -> do + s <- runErrorT $ do + liftIO $ debugM "Pontarius.Xmpp" "reconnect: closing stream" + _ <- liftIO $ closeStreams oldStream + liftIO $ debugM "Pontarius.Xmpp" "reconnect: opening stream" + stream <- ErrorT $ openStream (sRealm sess) + (sessionStreamConfiguration config) + liftIO $ debugM "Pontarius.Xmpp" "reconnect: tls" + ErrorT $ tls stream + liftIO $ debugM "Pontarius.Xmpp" "reconnect: auth" + cs <- liftIO $ withStream (gets streamConnectionState) stream + mbAuthError <- case sSaslCredentials sess of + Nothing -> return Nothing + Just (handlers, resource) -> ErrorT $ auth (handlers cs) + resource stream + case mbAuthError of + Nothing -> return () + Just e -> throwError $ XmppAuthFailure e + return stream + case s of + Left e -> do + errorM "Pontarius.Xmpp" $ "reconnect failed" ++ show e + return (Left e , oldStream ) + Right r -> return (Right () , r ) + when (enableRoster config) $ initRoster sess + + + newStanzaID :: Session -> IO StanzaID newStanzaID = idGenerator diff --git a/source/Network/Xmpp/Concurrent/Monad.hs b/source/Network/Xmpp/Concurrent/Monad.hs index 2545164..0aeb4c1 100644 --- a/source/Network/Xmpp/Concurrent/Monad.hs +++ b/source/Network/Xmpp/Concurrent/Monad.hs @@ -51,7 +51,7 @@ withConnection a session = do putTMVar (writeSemaphore session) wl putTMVar (streamRef session) s' return $ Right res - ) + ) -- TODO: DO we have to replace the MVars in case of ane exception? -- We treat all Exceptions as fatal. If we catch a StreamError, we -- return it. Otherwise, we throw an exception. [ Ex.Handler $ \e -> return $ Left (e :: XmppFailure) @@ -61,15 +61,15 @@ withConnection a session = do -- | Executes a function to update the event handlers. modifyHandlers :: (EventHandlers -> EventHandlers) -> Session -> IO () -modifyHandlers f session = atomically $ modifyTVar_ (eventHandlers session) f +modifyHandlers f session = atomically $ modifyTMVar_ (eventHandlers session) f where -- Borrowing modifyTVar from -- http://hackage.haskell.org/packages/archive/stm/2.4/doc/html/src/Control-Concurrent-STM-TVar.html -- as it's not available in GHC 7.0. - modifyTVar_ :: TVar a -> (a -> a) -> STM () - modifyTVar_ var g = do - x <- readTVar var - writeTVar var (g x) + modifyTMVar_ :: TMVar a -> (a -> a) -> STM () + modifyTMVar_ var g = do + x <- takeTMVar var + putTMVar var (g x) -- | Changes the handler to be executed when the server connection is closed. To -- avoid race conditions the initial value should be set in the configuration @@ -81,12 +81,13 @@ setConnectionClosedHandler eh session = do runConnectionClosedHandler :: Session -> XmppFailure -> IO () runConnectionClosedHandler session e = do - h <- connectionClosedHandler <$> atomically (readTVar $ eventHandlers session) + h <- connectionClosedHandler <$> atomically (readTMVar + $ eventHandlers session) h e -- | Run an event handler. runHandler :: (EventHandlers -> IO a) -> Session -> IO a -runHandler h session = h =<< atomically (readTVar $ eventHandlers session) +runHandler h session = h =<< atomically (readTMVar $ eventHandlers session) -- | End the current Xmpp session. diff --git a/source/Network/Xmpp/Concurrent/Threads.hs b/source/Network/Xmpp/Concurrent/Threads.hs index aa9cc50..5f711b0 100644 --- a/source/Network/Xmpp/Concurrent/Threads.hs +++ b/source/Network/Xmpp/Concurrent/Threads.hs @@ -91,7 +91,7 @@ readWorker onStanza onCClosed stateRef = forever . Ex.mask_ $ do -- connection. startThreadsWith :: TMVar (BS.ByteString -> IO Bool) -> (Stanza -> IO ()) - -> TVar EventHandlers + -> TMVar EventHandlers -> Stream -> IO (Either XmppFailure (IO (), TMVar (BS.ByteString -> IO Bool), @@ -116,9 +116,9 @@ startThreadsWith writeSem stanzaHandler eh con = do _ <- forM threads killThread return () -- Call the connection closed handlers. - noCon :: TVar EventHandlers -> XmppFailure -> IO () + noCon :: TMVar EventHandlers -> XmppFailure -> IO () noCon h e = do - hands <- atomically $ readTVar h + hands <- atomically $ readTMVar h _ <- forkIO $ connectionClosedHandler hands e return () diff --git a/source/Network/Xmpp/Concurrent/Types.hs b/source/Network/Xmpp/Concurrent/Types.hs index d5a1c6a..30d460a 100644 --- a/source/Network/Xmpp/Concurrent/Types.hs +++ b/source/Network/Xmpp/Concurrent/Types.hs @@ -13,8 +13,10 @@ import Data.Text (Text) import qualified Data.Text as Text import Data.Typeable import Data.XML.Types (Element) +import Network import Network.Xmpp.IM.Roster.Types import Network.Xmpp.Types +import Network.Xmpp.Sasl.Types -- | Configuration for the @Session@ object. @@ -22,7 +24,7 @@ data SessionConfiguration = SessionConfiguration { -- | Configuration for the @Stream@ object. sessionStreamConfiguration :: StreamConfiguration -- | Handler to be run when the session ends (for whatever reason). - , onConnectionClosed :: XmppFailure -> IO () + , onConnectionClosed :: Session -> XmppFailure -> IO () -- | Function to generate the stream of stanza identifiers. , sessionStanzaIDs :: IO (IO StanzaID) , extraStanzaHandlers :: [StanzaHandler] @@ -31,7 +33,7 @@ data SessionConfiguration = SessionConfiguration instance Default SessionConfiguration where def = SessionConfiguration { sessionStreamConfiguration = def - , onConnectionClosed = \_ -> return () + , onConnectionClosed = \_ _ -> return () , sessionStanzaIDs = do idRef <- newTVarIO 1 return . atomically $ do @@ -69,10 +71,12 @@ data Session = Session -- | Lock (used by withStream) to make sure that a maximum of one -- Stream action is executed at any given time. , streamRef :: TMVar Stream - , eventHandlers :: TVar EventHandlers + , eventHandlers :: TMVar EventHandlers , stopThreads :: IO () , rosterRef :: TVar Roster , conf :: SessionConfiguration + , sRealm :: HostName + , sSaslCredentials :: Maybe (ConnectionState -> [SaslHandler] , Maybe Text) } -- | IQHandlers holds the registered channels for incomming IQ requests and diff --git a/source/Network/Xmpp/IM/Roster.hs b/source/Network/Xmpp/IM/Roster.hs index ba01377..6933543 100644 --- a/source/Network/Xmpp/IM/Roster.hs +++ b/source/Network/Xmpp/IM/Roster.hs @@ -145,23 +145,23 @@ retrieveRoster mbOldRoster sess = do is) toItem :: QueryItem -> Item -toItem qi = Item { approved = fromMaybe False (qiApproved qi) - , ask = qiAsk qi - , jid = qiJid qi - , name = qiName qi - , subscription = fromMaybe None (qiSubscription qi) - , groups = nub $ qiGroups qi +toItem qi = Item { riApproved = fromMaybe False (qiApproved qi) + , riAsk = qiAsk qi + , riJid = qiJid qi + , riName = qiName qi + , riSubscription = fromMaybe None (qiSubscription qi) + , riGroups = nub $ qiGroups qi } fromItem :: Item -> QueryItem fromItem i = QueryItem { qiApproved = Nothing , qiAsk = False - , qiJid = jid i - , qiName = name i - , qiSubscription = case subscription i of + , qiJid = riJid i + , qiName = riName i + , qiSubscription = case riSubscription i of Remove -> Just Remove _ -> Nothing - , qiGroups = nub $ groups i + , qiGroups = nub $ riGroups i } xpItems :: PU [Node] [QueryItem] diff --git a/source/Network/Xmpp/IM/Roster/Types.hs b/source/Network/Xmpp/IM/Roster/Types.hs index 04854b4..b5de0ef 100644 --- a/source/Network/Xmpp/IM/Roster/Types.hs +++ b/source/Network/Xmpp/IM/Roster/Types.hs @@ -25,13 +25,13 @@ data Roster = Roster { ver :: Maybe Text , items :: Map.Map Jid Item } deriving Show - -data Item = Item { approved :: Bool - , ask :: Bool - , jid :: Jid - , name :: Maybe Text - , subscription :: Subscription - , groups :: [Text] +-- | Roster Items +data Item = Item { riApproved :: Bool + , riAsk :: Bool + , riJid :: Jid + , riName :: Maybe Text + , riSubscription :: Subscription + , riGroups :: [Text] } deriving Show data QueryItem = QueryItem { qiApproved :: Maybe Bool diff --git a/source/Network/Xmpp/Marshal.hs b/source/Network/Xmpp/Marshal.hs index afa68f6..f594783 100644 --- a/source/Network/Xmpp/Marshal.hs +++ b/source/Network/Xmpp/Marshal.hs @@ -282,5 +282,5 @@ xpJid :: PU Text Jid xpJid = ("xpJid", "") xpPartial ( \input -> case jidFromText input of Nothing -> Left "Could not parse JID." - Just jid -> Right jid) + Just j -> Right j) jidToText diff --git a/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs b/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs index 01ce054..99e30c7 100644 --- a/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs +++ b/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs @@ -43,7 +43,7 @@ scram hToken authcid authzid password = do scramhelper ac az pw where scramhelper authcid' authzid' pwd = do - cnonce <- liftIO $ makeNonce + cnonce <- liftIO makeNonce _ <- saslInit "SCRAM-SHA-1" (Just $ cFirstMessage cnonce) sFirstMessage <- saslFromJust =<< pullChallenge prs <- toPairs sFirstMessage diff --git a/source/Network/Xmpp/Stream.hs b/source/Network/Xmpp/Stream.hs index 2302969..47b0da7 100644 --- a/source/Network/Xmpp/Stream.hs +++ b/source/Network/Xmpp/Stream.hs @@ -326,21 +326,26 @@ openStream realm config = runErrorT $ do -- | Send "" and wait for the server to finish processing and to -- close the connection. Any remaining elements from the server are returned. -- Surpresses StreamEndFailure exceptions, but may throw a StreamCloseError. -closeStreams :: Stream -> IO (Either XmppFailure [Element]) +closeStreams :: Stream -> IO () closeStreams = withStream closeStreams' -closeStreams' :: StateT StreamState IO (Either XmppFailure [Element]) +closeStreams' :: StateT StreamState IO () closeStreams' = do - lift $ debugM "Pontarius.Xmpp" "Closing stream..." + lift $ debugM "Pontarius.Xmpp" "Closing stream" send <- gets (streamSend . streamHandle) cc <- gets (streamClose . streamHandle) + lift $ debugM "Pontarius.Xmpp" "Sending closing tag" void . liftIO $ send "" + lift $ debugM "Pontarius.Xmpp" "Waiting for stream to close" void $ liftIO $ forkIO $ do threadDelay 3000000 -- TODO: Configurable value void ((Ex.try cc) :: IO (Either Ex.SomeException ())) return () put xmppNoStream{ streamConnectionState = Finished } - collectElems [] + lift $ debugM "Pontarius.Xmpp" "Collecting remaining elements" +-- es <- collectElems [] + -- lift $ debugM "Pontarius.Xmpp" "Stream sucessfully closed" + -- return es where -- Pulls elements from the stream until the stream ends, or an error is -- raised. diff --git a/source/Network/Xmpp/Types.hs b/source/Network/Xmpp/Types.hs index 8a651b1..4ee22ce 100644 --- a/source/Network/Xmpp/Types.hs +++ b/source/Network/Xmpp/Types.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE StandaloneDeriving #-} @@ -43,6 +44,7 @@ module Network.Xmpp.Types , StreamConfiguration(..) , langTag , Jid(..) + , jidQ , isBare , isFull , jidFromText @@ -76,6 +78,8 @@ import Data.Text (Text) import qualified Data.Text as Text import Data.Typeable(Typeable) import Data.XML.Types +import Language.Haskell.TH +import Language.Haskell.TH.Quote import Network import Network.DNS import Network.TLS hiding (Version) @@ -938,9 +942,7 @@ jidToTexts (Jid nd dmn res) = (nd, dmn, res) -- Produces a Jid value in the format "parseJid \"\"". instance Show Jid where - show (Jid nd dmn res) = - "parseJid \"" ++ maybe "" ((++ "@") . Text.unpack) nd ++ Text.unpack dmn ++ - maybe "" (('/' :) . Text.unpack) res ++ "\"" + show j = "parseJid " ++ show (jidToText j) -- The string must be in the format "parseJid \"\"". -- TODO: This function should produce its error values in a uniform way. @@ -960,6 +962,26 @@ instance Read Jid where [(parseJid (read s' :: String), r)] -- May fail with "Prelude.read: no parse" -- or the `parseJid' error message (see below) +jidQ :: QuasiQuoter +jidQ = QuasiQuoter { quoteExp = \s -> do + when (head s == ' ') . fail $ "Leading whitespaces in JID" ++ show s + let t = Text.pack s + when (Text.last t == ' ') . reportWarning $ "Trailing whitespace in JID " ++ show s + case jidFromText t of + Nothing -> fail $ "Could not parse JID " ++ s + Just j -> [| Jid $(mbTextE $ localpart_ j) + $(textE $ domainpart_ j) + $(mbTextE $ resourcepart_ j) + |] + , quotePat = fail "Jid patterns aren't implemented" + , quoteType = fail "jid QQ can't be used in type context" + , quoteDec = fail "jid QQ can't be used in declaration context" + } + where + textE t = [| Text.pack $(stringE $ Text.unpack t) |] + mbTextE Nothing = [| Nothing |] + mbTextE (Just s) = [| Just $(textE s) |] + -- | Parses a JID string. -- -- Note: This function is only meant to be used to reverse @Jid@ Show @@ -967,7 +989,7 @@ instance Read Jid where -- validate; please refer to @jidFromText@ for a safe equivalent. parseJid :: String -> Jid parseJid s = case jidFromText $ Text.pack s of - Just jid -> jid + Just j -> j Nothing -> error $ "Jid value (" ++ s ++ ") did not validate" -- | Converts a Text to a JID. @@ -1017,7 +1039,7 @@ isFull = not . isBare -- | Returns the @Jid@ without the resourcepart (if any). toBare :: Jid -> Jid -toBare jid = jid{resourcepart_ = Nothing} +toBare j = j{resourcepart_ = Nothing} -- | Returns the localpart of the @Jid@ (if any). localpart :: Jid -> Maybe Text