Browse Source

add reconnect function

master
Philipp Balzarek 13 years ago
parent
commit
ae41225d54
  1. 1
      source/Network/Xmpp.hs
  2. 56
      source/Network/Xmpp/Concurrent.hs
  3. 17
      source/Network/Xmpp/Concurrent/Monad.hs
  4. 6
      source/Network/Xmpp/Concurrent/Threads.hs
  5. 10
      source/Network/Xmpp/Concurrent/Types.hs
  6. 2
      source/Network/Xmpp/Marshal.hs
  7. 2
      source/Network/Xmpp/Sasl/Mechanisms/Scram.hs
  8. 13
      source/Network/Xmpp/Stream.hs
  9. 4
      source/Network/Xmpp/Types.hs

1
source/Network/Xmpp.hs

@ -28,6 +28,7 @@ module Network.Xmpp @@ -28,6 +28,7 @@ module Network.Xmpp
Session
, session
, setConnectionClosedHandler
, reconnect
, StreamConfiguration(..)
, SessionConfiguration(..)
, ConnectionDetails(..)

56
source/Network/Xmpp/Concurrent.hs

@ -12,12 +12,13 @@ module Network.Xmpp.Concurrent @@ -12,12 +12,13 @@ module Network.Xmpp.Concurrent
, newSession
, session
, newStanzaID
, reconnect
) where
import Control.Concurrent.STM
import qualified Control.Exception as Ex
import Control.Monad
import Control.Monad.Error
import qualified Control.Exception as Ex
import qualified Data.Map as Map
import Data.Maybe
import Data.Text as Text
@ -30,13 +31,14 @@ import Network.Xmpp.Concurrent.Monad @@ -30,13 +31,14 @@ import Network.Xmpp.Concurrent.Monad
import Network.Xmpp.Concurrent.Presence
import Network.Xmpp.Concurrent.Threads
import Network.Xmpp.Concurrent.Types
import Network.Xmpp.IM.Roster.Types
import Network.Xmpp.IM.Roster
import Network.Xmpp.IM.Roster.Types
import Network.Xmpp.Sasl
import Network.Xmpp.Sasl.Types
import Network.Xmpp.Stream
import Network.Xmpp.Tls
import Network.Xmpp.Types
import System.Log.Logger
import Control.Monad.State.Strict
@ -119,13 +121,17 @@ handleIQ iqHands writeSem sta = do @@ -119,13 +121,17 @@ handleIQ iqHands writeSem sta = do
iqID (Right iq') = iqResultID iq'
-- | Creates and initializes a new Xmpp context.
newSession :: Stream -> SessionConfiguration -> IO (Either XmppFailure Session)
newSession stream config = runErrorT $ do
newSession :: Stream
-> SessionConfiguration
-> HostName
-> Maybe (ConnectionState -> [SaslHandler] , Maybe Text)
-> IO (Either XmppFailure Session)
newSession stream config realm mbSasl = runErrorT $ do
write' <- liftIO $ withStream' (gets $ streamSend . streamHandle) stream
writeSem <- liftIO $ newTMVarIO write'
stanzaChan <- lift newTChanIO
iqHands <- lift $ newTVarIO (Map.empty, Map.empty)
eh <- lift $ newTVarIO $ EventHandlers { connectionClosedHandler = onConnectionClosed config }
eh <- lift $ newEmptyTMVarIO
ros <- liftIO . newTVarIO $ Roster Nothing Map.empty
let rosterH = if (enableRoster config) then handleRoster ros
else \ _ _ -> return True
@ -139,7 +145,7 @@ newSession stream config = runErrorT $ do @@ -139,7 +145,7 @@ newSession stream config = runErrorT $ do
]
(kill, wLock, streamState, reader) <- ErrorT $ startThreadsWith writeSem stanzaHandler eh stream
idGen <- liftIO $ sessionStanzaIDs config
return $ Session { stanzaCh = stanzaChan
let sess = Session { stanzaCh = stanzaChan
, iqHandlers = iqHands
, writeSemaphore = wLock
, readerThread = reader
@ -149,7 +155,12 @@ newSession stream config = runErrorT $ do @@ -149,7 +155,12 @@ newSession stream config = runErrorT $ do
, stopThreads = kill
, conf = config
, rosterRef = ros
, sRealm = realm
, sSaslCredentials = mbSasl
}
liftIO . atomically $ putTMVar eh $ EventHandlers { connectionClosedHandler =
onConnectionClosed config sess }
return sess
-- | Creates a 'Session' object by setting up a connection with an XMPP server.
--
@ -172,9 +183,40 @@ session realm mbSasl config = runErrorT $ do @@ -172,9 +183,40 @@ session realm mbSasl config = runErrorT $ do
case mbAuthError of
Nothing -> return ()
Just e -> throwError $ XmppAuthFailure e
ses <- ErrorT $ newSession stream config
ses <- ErrorT $ newSession stream config realm mbSasl
liftIO $ when (enableRoster config) $ initRoster ses
return ses
reconnect :: Session -> IO ()
reconnect sess@Session{conf = config} = do
debugM "Pontarius.Xmpp" "reconnecting"
_ <- flip withConnection sess $ \oldStream -> do
s <- runErrorT $ do
liftIO $ debugM "Pontarius.Xmpp" "reconnect: closing stream"
_ <- liftIO $ closeStreams oldStream
liftIO $ debugM "Pontarius.Xmpp" "reconnect: opening stream"
stream <- ErrorT $ openStream (sRealm sess)
(sessionStreamConfiguration config)
liftIO $ debugM "Pontarius.Xmpp" "reconnect: tls"
ErrorT $ tls stream
liftIO $ debugM "Pontarius.Xmpp" "reconnect: auth"
cs <- liftIO $ withStream (gets streamConnectionState) stream
mbAuthError <- case sSaslCredentials sess of
Nothing -> return Nothing
Just (handlers, resource) -> ErrorT $ auth (handlers cs)
resource stream
case mbAuthError of
Nothing -> return ()
Just e -> throwError $ XmppAuthFailure e
return stream
case s of
Left e -> do
errorM "Pontarius.Xmpp" $ "reconnect failed" ++ show e
return (Left e , oldStream )
Right r -> return (Right () , r )
when (enableRoster config) $ initRoster sess
newStanzaID :: Session -> IO StanzaID
newStanzaID = idGenerator

17
source/Network/Xmpp/Concurrent/Monad.hs

@ -51,7 +51,7 @@ withConnection a session = do @@ -51,7 +51,7 @@ withConnection a session = do
putTMVar (writeSemaphore session) wl
putTMVar (streamRef session) s'
return $ Right res
)
) -- TODO: DO we have to replace the MVars in case of ane exception?
-- We treat all Exceptions as fatal. If we catch a StreamError, we
-- return it. Otherwise, we throw an exception.
[ Ex.Handler $ \e -> return $ Left (e :: XmppFailure)
@ -61,15 +61,15 @@ withConnection a session = do @@ -61,15 +61,15 @@ withConnection a session = do
-- | Executes a function to update the event handlers.
modifyHandlers :: (EventHandlers -> EventHandlers) -> Session -> IO ()
modifyHandlers f session = atomically $ modifyTVar_ (eventHandlers session) f
modifyHandlers f session = atomically $ modifyTMVar_ (eventHandlers session) f
where
-- Borrowing modifyTVar from
-- http://hackage.haskell.org/packages/archive/stm/2.4/doc/html/src/Control-Concurrent-STM-TVar.html
-- as it's not available in GHC 7.0.
modifyTVar_ :: TVar a -> (a -> a) -> STM ()
modifyTVar_ var g = do
x <- readTVar var
writeTVar var (g x)
modifyTMVar_ :: TMVar a -> (a -> a) -> STM ()
modifyTMVar_ var g = do
x <- takeTMVar var
putTMVar var (g x)
-- | Changes the handler to be executed when the server connection is closed. To
-- avoid race conditions the initial value should be set in the configuration
@ -81,12 +81,13 @@ setConnectionClosedHandler eh session = do @@ -81,12 +81,13 @@ setConnectionClosedHandler eh session = do
runConnectionClosedHandler :: Session -> XmppFailure -> IO ()
runConnectionClosedHandler session e = do
h <- connectionClosedHandler <$> atomically (readTVar $ eventHandlers session)
h <- connectionClosedHandler <$> atomically (readTMVar
$ eventHandlers session)
h e
-- | Run an event handler.
runHandler :: (EventHandlers -> IO a) -> Session -> IO a
runHandler h session = h =<< atomically (readTVar $ eventHandlers session)
runHandler h session = h =<< atomically (readTMVar $ eventHandlers session)
-- | End the current Xmpp session.

6
source/Network/Xmpp/Concurrent/Threads.hs

@ -91,7 +91,7 @@ readWorker onStanza onCClosed stateRef = forever . Ex.mask_ $ do @@ -91,7 +91,7 @@ readWorker onStanza onCClosed stateRef = forever . Ex.mask_ $ do
-- connection.
startThreadsWith :: TMVar (BS.ByteString -> IO Bool)
-> (Stanza -> IO ())
-> TVar EventHandlers
-> TMVar EventHandlers
-> Stream
-> IO (Either XmppFailure (IO (),
TMVar (BS.ByteString -> IO Bool),
@ -116,9 +116,9 @@ startThreadsWith writeSem stanzaHandler eh con = do @@ -116,9 +116,9 @@ startThreadsWith writeSem stanzaHandler eh con = do
_ <- forM threads killThread
return ()
-- Call the connection closed handlers.
noCon :: TVar EventHandlers -> XmppFailure -> IO ()
noCon :: TMVar EventHandlers -> XmppFailure -> IO ()
noCon h e = do
hands <- atomically $ readTVar h
hands <- atomically $ readTMVar h
_ <- forkIO $ connectionClosedHandler hands e
return ()

10
source/Network/Xmpp/Concurrent/Types.hs

@ -13,8 +13,10 @@ import Data.Text (Text) @@ -13,8 +13,10 @@ import Data.Text (Text)
import qualified Data.Text as Text
import Data.Typeable
import Data.XML.Types (Element)
import Network
import Network.Xmpp.IM.Roster.Types
import Network.Xmpp.Types
import Network.Xmpp.Sasl.Types
-- | Configuration for the @Session@ object.
@ -22,7 +24,7 @@ data SessionConfiguration = SessionConfiguration @@ -22,7 +24,7 @@ data SessionConfiguration = SessionConfiguration
{ -- | Configuration for the @Stream@ object.
sessionStreamConfiguration :: StreamConfiguration
-- | Handler to be run when the session ends (for whatever reason).
, onConnectionClosed :: XmppFailure -> IO ()
, onConnectionClosed :: Session -> XmppFailure -> IO ()
-- | Function to generate the stream of stanza identifiers.
, sessionStanzaIDs :: IO (IO StanzaID)
, extraStanzaHandlers :: [StanzaHandler]
@ -31,7 +33,7 @@ data SessionConfiguration = SessionConfiguration @@ -31,7 +33,7 @@ data SessionConfiguration = SessionConfiguration
instance Default SessionConfiguration where
def = SessionConfiguration { sessionStreamConfiguration = def
, onConnectionClosed = \_ -> return ()
, onConnectionClosed = \_ _ -> return ()
, sessionStanzaIDs = do
idRef <- newTVarIO 1
return . atomically $ do
@ -69,10 +71,12 @@ data Session = Session @@ -69,10 +71,12 @@ data Session = Session
-- | Lock (used by withStream) to make sure that a maximum of one
-- Stream action is executed at any given time.
, streamRef :: TMVar Stream
, eventHandlers :: TVar EventHandlers
, eventHandlers :: TMVar EventHandlers
, stopThreads :: IO ()
, rosterRef :: TVar Roster
, conf :: SessionConfiguration
, sRealm :: HostName
, sSaslCredentials :: Maybe (ConnectionState -> [SaslHandler] , Maybe Text)
}
-- | IQHandlers holds the registered channels for incomming IQ requests and

2
source/Network/Xmpp/Marshal.hs

@ -282,5 +282,5 @@ xpJid :: PU Text Jid @@ -282,5 +282,5 @@ xpJid :: PU Text Jid
xpJid = ("xpJid", "") <?>
xpPartial ( \input -> case jidFromText input of
Nothing -> Left "Could not parse JID."
Just jid -> Right jid)
Just j -> Right j)
jidToText

2
source/Network/Xmpp/Sasl/Mechanisms/Scram.hs

@ -43,7 +43,7 @@ scram hToken authcid authzid password = do @@ -43,7 +43,7 @@ scram hToken authcid authzid password = do
scramhelper ac az pw
where
scramhelper authcid' authzid' pwd = do
cnonce <- liftIO $ makeNonce
cnonce <- liftIO makeNonce
_ <- saslInit "SCRAM-SHA-1" (Just $ cFirstMessage cnonce)
sFirstMessage <- saslFromJust =<< pullChallenge
prs <- toPairs sFirstMessage

13
source/Network/Xmpp/Stream.hs

@ -326,21 +326,26 @@ openStream realm config = runErrorT $ do @@ -326,21 +326,26 @@ openStream realm config = runErrorT $ do
-- | Send "</stream:stream>" and wait for the server to finish processing and to
-- close the connection. Any remaining elements from the server are returned.
-- Surpresses StreamEndFailure exceptions, but may throw a StreamCloseError.
closeStreams :: Stream -> IO (Either XmppFailure [Element])
closeStreams :: Stream -> IO ()
closeStreams = withStream closeStreams'
closeStreams' :: StateT StreamState IO (Either XmppFailure [Element])
closeStreams' :: StateT StreamState IO ()
closeStreams' = do
lift $ debugM "Pontarius.Xmpp" "Closing stream..."
lift $ debugM "Pontarius.Xmpp" "Closing stream"
send <- gets (streamSend . streamHandle)
cc <- gets (streamClose . streamHandle)
lift $ debugM "Pontarius.Xmpp" "Sending closing tag"
void . liftIO $ send "</stream:stream>"
lift $ debugM "Pontarius.Xmpp" "Waiting for stream to close"
void $ liftIO $ forkIO $ do
threadDelay 3000000 -- TODO: Configurable value
void ((Ex.try cc) :: IO (Either Ex.SomeException ()))
return ()
put xmppNoStream{ streamConnectionState = Finished }
collectElems []
lift $ debugM "Pontarius.Xmpp" "Collecting remaining elements"
-- es <- collectElems []
-- lift $ debugM "Pontarius.Xmpp" "Stream sucessfully closed"
-- return es
where
-- Pulls elements from the stream until the stream ends, or an error is
-- raised.

4
source/Network/Xmpp/Types.hs

@ -987,7 +987,7 @@ jidQ = QuasiQuoter { quoteExp = \s -> case jidFromText (Text.pack s) of @@ -987,7 +987,7 @@ jidQ = QuasiQuoter { quoteExp = \s -> case jidFromText (Text.pack s) of
-- validate; please refer to @jidFromText@ for a safe equivalent.
parseJid :: String -> Jid
parseJid s = case jidFromText $ Text.pack s of
Just jid -> jid
Just j -> j
Nothing -> error $ "Jid value (" ++ s ++ ") did not validate"
-- | Converts a Text to a JID.
@ -1037,7 +1037,7 @@ isFull = not . isBare @@ -1037,7 +1037,7 @@ isFull = not . isBare
-- | Returns the @Jid@ without the resourcepart (if any).
toBare :: Jid -> Jid
toBare jid = jid{resourcepart_ = Nothing}
toBare j = j{resourcepart_ = Nothing}
-- | Returns the localpart of the @Jid@ (if any).
localpart :: Jid -> Maybe Text

Loading…
Cancel
Save