Browse Source

do proper disconnects

master
Philipp Balzarek 14 years ago
parent
commit
4e24d6f16a
  1. 1
      source/Network/Xmpp.hs
  2. 9
      source/Network/Xmpp/Bind.hs
  3. 13
      source/Network/Xmpp/Concurrent/Monad.hs
  4. 17
      source/Network/Xmpp/Concurrent/Threads.hs
  5. 13
      source/Network/Xmpp/Monad.hs
  6. 1
      source/Network/Xmpp/Types.hs
  7. 22
      source/Text/XML/Stream/Elements.hs
  8. 12
      tests/Tests.hs

1
source/Network/Xmpp.hs

@ -37,6 +37,7 @@ module Network.Xmpp
, startTLS , startTLS
, simpleAuth , simpleAuth
, auth , auth
, closeConnection
, endSession , endSession
, setConnectionClosedHandler , setConnectionClosedHandler
-- * JID -- * JID

9
source/Network/Xmpp/Bind.hs

@ -32,16 +32,17 @@ xmppBind :: Maybe Text -> XmppConMonad Jid
xmppBind rsrc = do xmppBind rsrc = do
answer <- xmppSendIQ' "bind" Nothing Set Nothing (bindBody rsrc) answer <- xmppSendIQ' "bind" Nothing Set Nothing (bindBody rsrc)
jid <- case () of () | Right IQResult{iqResultPayload = Just b} <- answer jid <- case () of () | Right IQResult{iqResultPayload = Just b} <- answer
, Right jid <- unpickleElem jidP b , Right jid <- unpickleElem xpJid b
-> return jid -> return jid
| otherwise -> throw $ StreamXMLError | otherwise -> throw $ StreamXMLError
"Bind couldn't unpickle JID" ("Bind couldn't unpickle JID from " ++ show answer)
modify (\s -> s{sJid = Just jid}) modify (\s -> s{sJid = Just jid})
return jid return jid
where where
-- Extracts the character data in the `jid' element. -- Extracts the character data in the `jid' element.
jidP :: PU [Node] Jid xpJid :: PU [Node] Jid
jidP = xpBind $ xpElemNodes "jid" (xpContent xpPrim) xpJid = xpBind $ xpElemNodes jidName (xpContent xpPrim)
jidName = "{urn:ietf:params:xml:ns:xmpp-bind}jid"
-- A `bind' element pickler. -- A `bind' element pickler.
xpBind :: PU [Node] b -> PU [Node] b xpBind :: PU [Node] b -> PU [Node] b

13
source/Network/Xmpp/Concurrent/Monad.hs

@ -1,7 +1,9 @@
{-# LANGUAGE OverloadedStrings #-}
module Network.Xmpp.Concurrent.Monad where module Network.Xmpp.Concurrent.Monad where
import Network.Xmpp.Types import Network.Xmpp.Types
import Control.Applicative((<$>))
import Control.Concurrent import Control.Concurrent
import Control.Concurrent.STM import Control.Concurrent.STM
import qualified Control.Exception.Lifted as Ex import qualified Control.Exception.Lifted as Ex
@ -244,4 +246,13 @@ endSession = do -- TODO: This has to be idempotent (is it?)
-- | Close the connection to the server. -- | Close the connection to the server.
closeConnection :: Xmpp () closeConnection :: Xmpp ()
closeConnection = void $ withConnection xmppKillConnection closeConnection = Ex.mask_ $ do
write <- asks writeRef
send <- liftIO . atomically $ takeTMVar write
cc <- sCloseConnection <$> (liftIO . atomically . readTMVar =<< asks conStateRef)
liftIO . send $ "</stream:stream>"
void . liftIO . forkIO $ do
threadDelay 3000000
(Ex.try cc) :: IO (Either Ex.SomeException ())
return ()
liftIO . atomically $ putTMVar write (\_ -> return False)

17
source/Network/Xmpp/Concurrent/Threads.hs

@ -54,7 +54,10 @@ readWorker messageC presenceC stanzaC iqHands handlers stateRef =
[ Ex.Handler $ \(Interrupt t) -> do [ Ex.Handler $ \(Interrupt t) -> do
void $ handleInterrupts [t] void $ handleInterrupts [t]
return Nothing return Nothing
, Ex.Handler $ \(e :: StreamError) -> noCon handlers e , Ex.Handler $ \(e :: StreamError) -> do
hands <- atomically $ readTVar handlers
_ <- forkIO $ connectionClosedHandler hands e
return Nothing
] ]
liftIO . atomically $ do liftIO . atomically $ do
case res of case res of
@ -139,10 +142,14 @@ writeWorker stCh writeR = forever $ do
takeTMVar writeR <*> takeTMVar writeR <*>
readTChan stCh readTChan stCh
r <- write $ renderElement (pickleElem xpStanza next) r <- write $ renderElement (pickleElem xpStanza next)
unless r $ do -- If the writing failed, the connection is dead. atomically $ putTMVar writeR write
atomically $ unGetTChan stCh next unless r $ do
atomically $ unGetTChan stCh next -- If the writing failed, the
-- connection is dead.
threadDelay 250000 -- Avoid free spinning. threadDelay 250000 -- Avoid free spinning.
atomically $ putTMVar writeR write -- Put it back.
-- Two streams: input and output. Threads read from input stream and write to -- Two streams: input and output. Threads read from input stream and write to
-- output stream. -- output stream.
@ -236,4 +243,4 @@ connPersist lock = forever $ do
pushBS <- atomically $ takeTMVar lock pushBS <- atomically $ takeTMVar lock
_ <- pushBS " " _ <- pushBS " "
atomically $ putTMVar lock pushBS atomically $ putTMVar lock pushBS
threadDelay 30000000 threadDelay 30000000 -- 30s

13
source/Network/Xmpp/Monad.hs

@ -64,13 +64,16 @@ pullToSink snk = do
pullElement :: XmppConMonad Element pullElement :: XmppConMonad Element
pullElement = do pullElement = do
Ex.catch (do Ex.catches (do
e <- pullToSink (elements =$ CL.head) e <- pullToSink (elements =$ CL.head)
case e of case e of
Nothing -> liftIO $ Ex.throwIO StreamConnectionError Nothing -> liftIO $ Ex.throwIO StreamConnectionError
Just r -> return r Just r -> return r
) )
(\(InvalidEventStream s) -> liftIO . Ex.throwIO $ StreamXMLError s) [ Ex.Handler (\StreamEnd -> Ex.throwIO StreamStreamEnd)
, Ex.Handler (\(InvalidEventStream s)
-> liftIO . Ex.throwIO $ StreamXMLError s)
]
-- Pulls an element and unpickles it. -- Pulls an element and unpickles it.
pullPickle :: PU [Node] a -> XmppConMonad a pullPickle :: PU [Node] a -> XmppConMonad a
@ -95,6 +98,7 @@ catchPush p = Ex.catch
(p >> return True) (p >> return True)
(\e -> case GIE.ioe_type e of (\e -> case GIE.ioe_type e of
GIE.ResourceVanished -> return False GIE.ResourceVanished -> return False
GIE.IllegalOperation -> return False
_ -> Ex.throwIO e _ -> Ex.throwIO e
) )
@ -143,11 +147,12 @@ xmppNewSession :: XmppConMonad a -> IO (a, XmppConnection)
xmppNewSession action = runStateT action xmppNoConnection xmppNewSession action = runStateT action xmppNoConnection
-- Closes the connection and updates the XmppConMonad XmppConnection state. -- Closes the connection and updates the XmppConMonad XmppConnection state.
xmppKillConnection :: XmppConMonad () xmppKillConnection :: XmppConMonad (Either Ex.SomeException ())
xmppKillConnection = do xmppKillConnection = do
cc <- gets sCloseConnection cc <- gets sCloseConnection
void . liftIO $ (Ex.try cc :: IO (Either Ex.SomeException ())) err <- liftIO $ (Ex.try cc :: IO (Either Ex.SomeException ()))
put xmppNoConnection put xmppNoConnection
return err
-- Sends an IQ request and waits for the response. If the response ID does not -- Sends an IQ request and waits for the response. If the response ID does not
-- match the outgoing ID, an error is thrown. -- match the outgoing ID, an error is thrown.

1
source/Network/Xmpp/Types.hs

@ -533,6 +533,7 @@ data XmppStreamError = XmppStreamError
data StreamError = StreamError XmppStreamError data StreamError = StreamError XmppStreamError
| StreamWrongVersion Text | StreamWrongVersion Text
| StreamXMLError String -- If stream pickling goes wrong. | StreamXMLError String -- If stream pickling goes wrong.
| StreamStreamEnd -- received closing stream tag
| StreamConnectionError | StreamConnectionError
deriving (Show, Eq, Typeable) deriving (Show, Eq, Typeable)

22
source/Text/XML/Stream/Elements.hs

@ -1,22 +1,26 @@
{-# OPTIONS_HADDOCK hide #-} {-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
module Text.XML.Stream.Elements where module Text.XML.Stream.Elements where
import Control.Applicative ((<$>)) import Control.Applicative ((<$>))
import Control.Exception
import Control.Monad.Trans.Class import Control.Monad.Trans.Class
import Control.Monad.Trans.Resource as R import Control.Monad.Trans.Resource as R
import qualified Data.ByteString as BS import qualified Data.ByteString as BS
import Data.Conduit as C
import Data.Conduit.List as CL
import qualified Data.Text as Text import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text import qualified Data.Text.Encoding as Text
import Data.Typeable
import Data.XML.Types import Data.XML.Types
import qualified Text.XML.Stream.Render as TXSR
import Text.XML.Unresolved as TXU
import Data.Conduit as C
import Data.Conduit.List as CL
import System.IO.Unsafe(unsafePerformIO) import System.IO.Unsafe(unsafePerformIO)
import qualified Text.XML.Stream.Render as TXSR
import Text.XML.Unresolved as TXU
compressNodes :: [Node] -> [Node] compressNodes :: [Node] -> [Node]
compressNodes [] = [] compressNodes [] = []
compressNodes [x] = [x] compressNodes [x] = [x]
@ -24,6 +28,13 @@ compressNodes (NodeContent (ContentText x) : NodeContent (ContentText y) : z) =
compressNodes $ NodeContent (ContentText $ x `Text.append` y) : z compressNodes $ NodeContent (ContentText $ x `Text.append` y) : z
compressNodes (x:xs) = x : compressNodes xs compressNodes (x:xs) = x : compressNodes xs
streamName :: Name
streamName =
(Name "stream" (Just "http://etherx.jabber.org/streams") (Just "stream"))
data StreamEnd = StreamEnd deriving (Typeable, Show)
instance Exception StreamEnd
elements :: R.MonadThrow m => C.Conduit Event m Element elements :: R.MonadThrow m => C.Conduit Event m Element
elements = do elements = do
x <- C.await x <- C.await
@ -31,6 +42,7 @@ elements = do
Just (EventBeginElement n as) -> do Just (EventBeginElement n as) -> do
goE n as >>= C.yield goE n as >>= C.yield
elements elements
Just (EventEndElement streamName) -> lift $ R.monadThrow StreamEnd
Nothing -> return () Nothing -> return ()
_ -> lift $ R.monadThrow $ InvalidEventStream $ "not an element: " ++ show x _ -> lift $ R.monadThrow $ InvalidEventStream $ "not an element: " ++ show x
where where

12
tests/Tests.hs

@ -3,6 +3,7 @@ module Example where
import Control.Concurrent import Control.Concurrent
import Control.Concurrent.STM import Control.Concurrent.STM
import qualified Control.Exception.Lifted as Ex
import Control.Monad import Control.Monad
import Control.Monad.IO.Class import Control.Monad.IO.Class
@ -114,19 +115,20 @@ runMain debug number = do
debug . (("Thread " ++ show number ++ ":") ++) debug . (("Thread " ++ show number ++ ":") ++)
wait <- newEmptyTMVarIO wait <- newEmptyTMVarIO
withNewSession $ do withNewSession $ do
setSessionEndHandler (liftIO . atomically $ putTMVar wait ())
setConnectionClosedHandler (\e -> do setConnectionClosedHandler (\e -> do
liftIO (debug' $ "connection lost because " ++ show e) liftIO (debug' $ "connection lost because " ++ show e)
endSession ) endSession )
debug' "running" debug' "running"
withConnection $ do withConnection $ Ex.catch (do
connect "localhost" "species64739.dyndns.org" connect "localhost" "species64739.dyndns.org"
startTLS exampleParams startTLS exampleParams
saslResponse <- auth (fromJust $ localpart we) "pwd" (resourcepart we) saslResponse <- simpleAuth
(fromJust $ localpart we) "pwd" (resourcepart we)
case saslResponse of case saslResponse of
Right _ -> return () Right _ -> return ()
Left e -> error $ show e Left e -> error $ show e
debug' "session standing" debug' "session standing")
(\e -> liftIO (print (e ::Ex.SomeException) >> Ex.throwIO e) )
sendPresence presenceOnline sendPresence presenceOnline
fork autoAccept fork autoAccept
sendPresence $ presenceSubscribe them sendPresence $ presenceSubscribe them
@ -148,7 +150,7 @@ runMain debug number = do
sendUser "All tests done" sendUser "All tests done"
debug' "ending session" debug' "ending session"
liftIO . atomically $ putTMVar wait () liftIO . atomically $ putTMVar wait ()
endSession closeConnection
liftIO . atomically $ takeTMVar wait liftIO . atomically $ takeTMVar wait
return () return ()
return () return ()

Loading…
Cancel
Save